In [1]:
import torch
import glob
import tqdm

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from utilities import utils, train_utils
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader, TensorDataset
from pythae.models import AE, AEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOutput

In [19]:
paths = sorted(glob.glob('./data/environmental_embeddings_0001/0001/*.msgpack'))
device = train_utils.get_device()

X = []
for path in tqdm.tqdm(paths):
    pos_emb, neg_emb = utils.read_embedding_data(path)
    X.append(pos_emb)
X = np.concatenate(X, axis=0)
X = X.reshape(-1, 768)

Xtr, Xvl = train_test_split(X, test_size=0.2, random_state=42)

Xtr_scl = normalize(Xtr, norm='l2', axis=1)
Xvl_scl = normalize(Xvl, norm='l2', axis=1)

train_data = torch.tensor(Xtr_scl, dtype=torch.float32)
val_data = torch.tensor(Xvl_scl, dtype=torch.float32)

train_dataset = TensorDataset(train_data)
val_dataset = TensorDataset(val_data)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

100%|██████████| 1000/1000 [00:05<00:00, 197.27it/s]


In [20]:
class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Encoder, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        self.encoder_layers = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, self.latent_dim),
        )

    def forward(self, x):
        x = self.encoder_layers(x)

        return x
    
class Decoder(BaseDecoder):
    def __init__(self, input_dim, latent_dim):
        super(Decoder, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        self.decoder_layers = nn.Sequential(
            nn.Linear(self.latent_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, self.input_dim),
            nn.Tanh()
        )

    def forward(self, z):
        x = self.decoder_layers(z)

        return x
    
class AutoEncoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(AutoEncoder, self).__init__()

        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(input_dim, latent_dim)

    def forward(self, x):
        self.z = self.encoder(x)
        recon = self.decoder(self.z)

        return recon

In [21]:
model = AutoEncoder(input_dim=768, latent_dim=128)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, threshold=0.01)
criterion = torch.nn.MSELoss(reduction='sum')
model = model.to(device)

In [22]:
train_utils.train_loop(
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    criterion=criterion,
    device=device,
    optimizer=optimizer,
    epochs=500,
    save_path='experiments/fc_exp3/',
    sparsity_penalty_weight=0.001,
    scheduler=None
)

train loss: 0.429: 100%|██████████| 482/482 [00:01<00:00, 242.16it/s]
val loss: 0.417: 100%|██████████| 121/121 [00:00<00:00, 798.76it/s]


Weight saved: epoch 0
Epoch 0	Train Loss: 0.519  Val Loss: 0.401



train loss: 0.322: 100%|██████████| 482/482 [00:01<00:00, 243.40it/s]
val loss: 0.375: 100%|██████████| 121/121 [00:00<00:00, 848.39it/s]


Weight saved: epoch 1
Epoch 1	Train Loss: 0.381  Val Loss: 0.360



train loss: 0.339: 100%|██████████| 482/482 [00:01<00:00, 241.98it/s]
val loss: 0.348: 100%|██████████| 121/121 [00:00<00:00, 916.85it/s]


Weight saved: epoch 2
Epoch 2	Train Loss: 0.347  Val Loss: 0.334



train loss: 0.332: 100%|██████████| 482/482 [00:01<00:00, 241.68it/s]
val loss: 0.328: 100%|██████████| 121/121 [00:00<00:00, 875.75it/s]


Weight saved: epoch 3
Epoch 3	Train Loss: 0.323  Val Loss: 0.312



train loss: 0.285: 100%|██████████| 482/482 [00:02<00:00, 240.87it/s]
val loss: 0.306: 100%|██████████| 121/121 [00:00<00:00, 800.35it/s]


Weight saved: epoch 4
Epoch 4	Train Loss: 0.303  Val Loss: 0.294



train loss: 0.298: 100%|██████████| 482/482 [00:01<00:00, 243.73it/s]
val loss: 0.289: 100%|██████████| 121/121 [00:00<00:00, 817.20it/s]


Weight saved: epoch 5
Epoch 5	Train Loss: 0.287  Val Loss: 0.279



train loss: 0.295: 100%|██████████| 482/482 [00:01<00:00, 241.52it/s]
val loss: 0.276: 100%|██████████| 121/121 [00:00<00:00, 848.01it/s]


Weight saved: epoch 6
Epoch 6	Train Loss: 0.273  Val Loss: 0.268



train loss: 0.300: 100%|██████████| 482/482 [00:01<00:00, 244.97it/s]
val loss: 0.265: 100%|██████████| 121/121 [00:00<00:00, 858.36it/s]


Weight saved: epoch 7
Epoch 7	Train Loss: 0.261  Val Loss: 0.257



train loss: 0.283: 100%|██████████| 482/482 [00:01<00:00, 243.90it/s]
val loss: 0.253: 100%|██████████| 121/121 [00:00<00:00, 769.73it/s]


Weight saved: epoch 8
Epoch 8	Train Loss: 0.251  Val Loss: 0.248



train loss: 0.248: 100%|██████████| 482/482 [00:02<00:00, 237.27it/s]
val loss: 0.245: 100%|██████████| 121/121 [00:00<00:00, 793.19it/s]


Weight saved: epoch 9
Epoch 9	Train Loss: 0.243  Val Loss: 0.239



train loss: 0.234: 100%|██████████| 482/482 [00:02<00:00, 237.24it/s]
val loss: 0.236: 100%|██████████| 121/121 [00:00<00:00, 801.01it/s]


Weight saved: epoch 10
Epoch 10	Train Loss: 0.235  Val Loss: 0.232



train loss: 0.251: 100%|██████████| 482/482 [00:01<00:00, 244.73it/s]
val loss: 0.231: 100%|██████████| 121/121 [00:00<00:00, 867.74it/s]


Weight saved: epoch 11
Epoch 11	Train Loss: 0.228  Val Loss: 0.227



train loss: 0.253: 100%|██████████| 482/482 [00:01<00:00, 241.76it/s]
val loss: 0.225: 100%|██████████| 121/121 [00:00<00:00, 781.96it/s]


Weight saved: epoch 12
Epoch 12	Train Loss: 0.222  Val Loss: 0.222



train loss: 0.237: 100%|██████████| 482/482 [00:02<00:00, 234.64it/s]
val loss: 0.218: 100%|██████████| 121/121 [00:00<00:00, 804.51it/s]


Weight saved: epoch 13
Epoch 13	Train Loss: 0.217  Val Loss: 0.216



train loss: 0.246: 100%|██████████| 482/482 [00:02<00:00, 229.70it/s]
val loss: 0.213: 100%|██████████| 121/121 [00:00<00:00, 819.44it/s]


Weight saved: epoch 14
Epoch 14	Train Loss: 0.212  Val Loss: 0.211



train loss: 0.208: 100%|██████████| 482/482 [00:02<00:00, 238.09it/s]
val loss: 0.207: 100%|██████████| 121/121 [00:00<00:00, 862.65it/s]


Weight saved: epoch 15
Epoch 15	Train Loss: 0.207  Val Loss: 0.207



train loss: 0.217: 100%|██████████| 482/482 [00:02<00:00, 240.83it/s]
val loss: 0.205: 100%|██████████| 121/121 [00:00<00:00, 851.54it/s]


Weight saved: epoch 16
Epoch 16	Train Loss: 0.202  Val Loss: 0.204



train loss: 0.182: 100%|██████████| 482/482 [00:01<00:00, 241.31it/s]
val loss: 0.199: 100%|██████████| 121/121 [00:00<00:00, 742.15it/s]


Weight saved: epoch 17
Epoch 17	Train Loss: 0.198  Val Loss: 0.199



train loss: 0.175: 100%|██████████| 482/482 [00:02<00:00, 237.07it/s]
val loss: 0.196: 100%|██████████| 121/121 [00:00<00:00, 786.41it/s]


Weight saved: epoch 18
Epoch 18	Train Loss: 0.195  Val Loss: 0.196



train loss: 0.173: 100%|██████████| 482/482 [00:02<00:00, 240.19it/s]
val loss: 0.192: 100%|██████████| 121/121 [00:00<00:00, 815.54it/s]


Weight saved: epoch 19
Epoch 19	Train Loss: 0.191  Val Loss: 0.193



train loss: 0.249: 100%|██████████| 482/482 [00:02<00:00, 233.45it/s]
val loss: 0.188: 100%|██████████| 121/121 [00:00<00:00, 860.46it/s]


Weight saved: epoch 20
Epoch 20	Train Loss: 0.188  Val Loss: 0.190



train loss: 0.161: 100%|██████████| 482/482 [00:01<00:00, 241.16it/s]
val loss: 0.185: 100%|██████████| 121/121 [00:00<00:00, 782.92it/s]


Weight saved: epoch 21
Epoch 21	Train Loss: 0.185  Val Loss: 0.187



train loss: 0.190: 100%|██████████| 482/482 [00:02<00:00, 236.32it/s]
val loss: 0.183: 100%|██████████| 121/121 [00:00<00:00, 923.56it/s]


Weight saved: epoch 22
Epoch 22	Train Loss: 0.182  Val Loss: 0.185



train loss: 0.166: 100%|██████████| 482/482 [00:02<00:00, 237.59it/s]
val loss: 0.180: 100%|██████████| 121/121 [00:00<00:00, 829.67it/s]


Weight saved: epoch 23
Epoch 23	Train Loss: 0.180  Val Loss: 0.182



train loss: 0.167: 100%|██████████| 482/482 [00:01<00:00, 244.31it/s]
val loss: 0.178: 100%|██████████| 121/121 [00:00<00:00, 834.40it/s]


Weight saved: epoch 24
Epoch 24	Train Loss: 0.177  Val Loss: 0.180



train loss: 0.190: 100%|██████████| 482/482 [00:01<00:00, 242.28it/s]
val loss: 0.175: 100%|██████████| 121/121 [00:00<00:00, 833.44it/s]


Weight saved: epoch 25
Epoch 25	Train Loss: 0.175  Val Loss: 0.178



train loss: 0.139: 100%|██████████| 482/482 [00:02<00:00, 237.71it/s]
val loss: 0.174: 100%|██████████| 121/121 [00:00<00:00, 866.96it/s]


Weight saved: epoch 26
Epoch 26	Train Loss: 0.173  Val Loss: 0.176



train loss: 0.182: 100%|██████████| 482/482 [00:01<00:00, 241.52it/s]
val loss: 0.172: 100%|██████████| 121/121 [00:00<00:00, 895.43it/s]


Weight saved: epoch 27
Epoch 27	Train Loss: 0.171  Val Loss: 0.174



train loss: 0.161: 100%|██████████| 482/482 [00:02<00:00, 232.95it/s]
val loss: 0.170: 100%|██████████| 121/121 [00:00<00:00, 801.21it/s]


Weight saved: epoch 28
Epoch 28	Train Loss: 0.168  Val Loss: 0.172



train loss: 0.192: 100%|██████████| 482/482 [00:01<00:00, 242.15it/s]
val loss: 0.168: 100%|██████████| 121/121 [00:00<00:00, 874.18it/s]


Weight saved: epoch 29
Epoch 29	Train Loss: 0.167  Val Loss: 0.170



train loss: 0.152: 100%|██████████| 482/482 [00:02<00:00, 237.83it/s]
val loss: 0.167: 100%|██████████| 121/121 [00:00<00:00, 821.42it/s]


Weight saved: epoch 30
Epoch 30	Train Loss: 0.165  Val Loss: 0.169



train loss: 0.187: 100%|██████████| 482/482 [00:01<00:00, 242.81it/s]
val loss: 0.164: 100%|██████████| 121/121 [00:00<00:00, 838.28it/s]


Weight saved: epoch 31
Epoch 31	Train Loss: 0.163  Val Loss: 0.167



train loss: 0.188: 100%|██████████| 482/482 [00:01<00:00, 242.57it/s]
val loss: 0.163: 100%|██████████| 121/121 [00:00<00:00, 893.62it/s]


Weight saved: epoch 32
Epoch 32	Train Loss: 0.161  Val Loss: 0.165



train loss: 0.172: 100%|██████████| 482/482 [00:01<00:00, 244.37it/s]
val loss: 0.161: 100%|██████████| 121/121 [00:00<00:00, 822.52it/s]


Weight saved: epoch 33
Epoch 33	Train Loss: 0.160  Val Loss: 0.164



train loss: 0.175: 100%|██████████| 482/482 [00:02<00:00, 240.09it/s]
val loss: 0.159: 100%|██████████| 121/121 [00:00<00:00, 858.00it/s]


Weight saved: epoch 34
Epoch 34	Train Loss: 0.158  Val Loss: 0.163



train loss: 0.168: 100%|██████████| 482/482 [00:02<00:00, 231.39it/s]
val loss: 0.158: 100%|██████████| 121/121 [00:00<00:00, 807.02it/s]


Weight saved: epoch 35
Epoch 35	Train Loss: 0.157  Val Loss: 0.161



train loss: 0.121: 100%|██████████| 482/482 [00:02<00:00, 235.85it/s]
val loss: 0.157: 100%|██████████| 121/121 [00:00<00:00, 803.27it/s]


Weight saved: epoch 36
Epoch 36	Train Loss: 0.156  Val Loss: 0.160



train loss: 0.172: 100%|██████████| 482/482 [00:02<00:00, 236.63it/s]
val loss: 0.156: 100%|██████████| 121/121 [00:00<00:00, 810.47it/s]


Weight saved: epoch 37
Epoch 37	Train Loss: 0.154  Val Loss: 0.159



train loss: 0.154: 100%|██████████| 482/482 [00:02<00:00, 232.29it/s]
val loss: 0.154: 100%|██████████| 121/121 [00:00<00:00, 778.65it/s]


Weight saved: epoch 38
Epoch 38	Train Loss: 0.153  Val Loss: 0.158



train loss: 0.181: 100%|██████████| 482/482 [00:02<00:00, 235.98it/s]
val loss: 0.154: 100%|██████████| 121/121 [00:00<00:00, 833.73it/s]


Weight saved: epoch 39
Epoch 39	Train Loss: 0.152  Val Loss: 0.157



train loss: 0.116: 100%|██████████| 482/482 [00:02<00:00, 231.76it/s]
val loss: 0.152: 100%|██████████| 121/121 [00:00<00:00, 847.64it/s]


Weight saved: epoch 40
Epoch 40	Train Loss: 0.151  Val Loss: 0.155



train loss: 0.184: 100%|██████████| 482/482 [00:02<00:00, 230.50it/s]
val loss: 0.151: 100%|██████████| 121/121 [00:00<00:00, 836.41it/s]


Weight saved: epoch 41
Epoch 41	Train Loss: 0.149  Val Loss: 0.155



train loss: 0.147: 100%|██████████| 482/482 [00:01<00:00, 241.86it/s]
val loss: 0.151: 100%|██████████| 121/121 [00:00<00:00, 840.94it/s]


Weight saved: epoch 42
Epoch 42	Train Loss: 0.148  Val Loss: 0.154



train loss: 0.137: 100%|██████████| 482/482 [00:02<00:00, 235.12it/s]
val loss: 0.150: 100%|██████████| 121/121 [00:00<00:00, 790.18it/s]


Weight saved: epoch 43
Epoch 43	Train Loss: 0.147  Val Loss: 0.153



train loss: 0.123: 100%|██████████| 482/482 [00:02<00:00, 233.43it/s]
val loss: 0.149: 100%|██████████| 121/121 [00:00<00:00, 842.01it/s]


Weight saved: epoch 44
Epoch 44	Train Loss: 0.146  Val Loss: 0.152



train loss: 0.127: 100%|██████████| 482/482 [00:02<00:00, 235.92it/s]
val loss: 0.147: 100%|██████████| 121/121 [00:00<00:00, 839.27it/s]


Weight saved: epoch 45
Epoch 45	Train Loss: 0.145  Val Loss: 0.151



train loss: 0.130: 100%|██████████| 482/482 [00:02<00:00, 234.79it/s]
val loss: 0.147: 100%|██████████| 121/121 [00:00<00:00, 860.58it/s]


Weight saved: epoch 46
Epoch 46	Train Loss: 0.144  Val Loss: 0.150



train loss: 0.154: 100%|██████████| 482/482 [00:02<00:00, 236.85it/s]
val loss: 0.146: 100%|██████████| 121/121 [00:00<00:00, 861.12it/s]


Weight saved: epoch 47
Epoch 47	Train Loss: 0.144  Val Loss: 0.149



train loss: 0.164: 100%|██████████| 482/482 [00:02<00:00, 237.67it/s]
val loss: 0.145: 100%|██████████| 121/121 [00:00<00:00, 793.03it/s]


Weight saved: epoch 48
Epoch 48	Train Loss: 0.143  Val Loss: 0.149



train loss: 0.131: 100%|██████████| 482/482 [00:02<00:00, 240.61it/s]
val loss: 0.144: 100%|██████████| 121/121 [00:00<00:00, 806.55it/s]


Weight saved: epoch 49
Epoch 49	Train Loss: 0.142  Val Loss: 0.149



train loss: 0.141: 100%|██████████| 482/482 [00:01<00:00, 242.95it/s]
val loss: 0.144: 100%|██████████| 121/121 [00:00<00:00, 810.52it/s]


Weight saved: epoch 50
Epoch 50	Train Loss: 0.141  Val Loss: 0.148



train loss: 0.140: 100%|██████████| 482/482 [00:02<00:00, 238.20it/s]
val loss: 0.143: 100%|██████████| 121/121 [00:00<00:00, 878.25it/s]


Weight saved: epoch 51
Epoch 51	Train Loss: 0.140  Val Loss: 0.147



train loss: 0.139: 100%|██████████| 482/482 [00:01<00:00, 241.64it/s]
val loss: 0.142: 100%|██████████| 121/121 [00:00<00:00, 839.94it/s]


Weight saved: epoch 52
Epoch 52	Train Loss: 0.140  Val Loss: 0.146



train loss: 0.120: 100%|██████████| 482/482 [00:02<00:00, 239.33it/s]
val loss: 0.142: 100%|██████████| 121/121 [00:00<00:00, 829.25it/s]


Weight saved: epoch 53
Epoch 53	Train Loss: 0.139  Val Loss: 0.146



train loss: 0.139: 100%|██████████| 482/482 [00:02<00:00, 239.45it/s]
val loss: 0.141: 100%|██████████| 121/121 [00:00<00:00, 832.32it/s]


Weight saved: epoch 54
Epoch 54	Train Loss: 0.138  Val Loss: 0.145



train loss: 0.151: 100%|██████████| 482/482 [00:02<00:00, 238.68it/s]
val loss: 0.140: 100%|██████████| 121/121 [00:00<00:00, 822.78it/s]


Weight saved: epoch 55
Epoch 55	Train Loss: 0.137  Val Loss: 0.145



train loss: 0.158: 100%|██████████| 482/482 [00:02<00:00, 238.26it/s]
val loss: 0.140: 100%|██████████| 121/121 [00:00<00:00, 815.00it/s]


Weight saved: epoch 56
Epoch 56	Train Loss: 0.137  Val Loss: 0.144



train loss: 0.115: 100%|██████████| 482/482 [00:01<00:00, 242.06it/s]
val loss: 0.138: 100%|██████████| 121/121 [00:00<00:00, 822.11it/s]


Weight saved: epoch 57
Epoch 57	Train Loss: 0.136  Val Loss: 0.143



train loss: 0.122: 100%|██████████| 482/482 [00:02<00:00, 234.01it/s]
val loss: 0.138: 100%|██████████| 121/121 [00:00<00:00, 836.73it/s]


Weight saved: epoch 58
Epoch 58	Train Loss: 0.135  Val Loss: 0.143



train loss: 0.143: 100%|██████████| 482/482 [00:02<00:00, 236.84it/s]
val loss: 0.137: 100%|██████████| 121/121 [00:00<00:00, 749.43it/s]


Weight saved: epoch 59
Epoch 59	Train Loss: 0.135  Val Loss: 0.142



train loss: 0.138: 100%|██████████| 482/482 [00:02<00:00, 230.66it/s]
val loss: 0.137: 100%|██████████| 121/121 [00:00<00:00, 809.08it/s]


Weight saved: epoch 60
Epoch 60	Train Loss: 0.134  Val Loss: 0.142



train loss: 0.111: 100%|██████████| 482/482 [00:02<00:00, 220.89it/s]
val loss: 0.137: 100%|██████████| 121/121 [00:00<00:00, 767.77it/s]


Weight saved: epoch 61
Epoch 61	Train Loss: 0.134  Val Loss: 0.142



train loss: 0.121: 100%|██████████| 482/482 [00:02<00:00, 230.85it/s]
val loss: 0.135: 100%|██████████| 121/121 [00:00<00:00, 794.52it/s]


Weight saved: epoch 62
Epoch 62	Train Loss: 0.133  Val Loss: 0.141



train loss: 0.155: 100%|██████████| 482/482 [00:02<00:00, 230.35it/s]
val loss: 0.136: 100%|██████████| 121/121 [00:00<00:00, 763.06it/s]


Weight saved: epoch 63
Epoch 63	Train Loss: 0.133  Val Loss: 0.140



train loss: 0.129: 100%|██████████| 482/482 [00:02<00:00, 229.22it/s]
val loss: 0.136: 100%|██████████| 121/121 [00:00<00:00, 836.72it/s]


Weight saved: epoch 64
Epoch 64	Train Loss: 0.132  Val Loss: 0.140



train loss: 0.161: 100%|██████████| 482/482 [00:02<00:00, 238.16it/s]
val loss: 0.135: 100%|██████████| 121/121 [00:00<00:00, 819.67it/s]


Weight saved: epoch 65
Epoch 65	Train Loss: 0.132  Val Loss: 0.140



train loss: 0.141: 100%|██████████| 482/482 [00:02<00:00, 238.04it/s]
val loss: 0.135: 100%|██████████| 121/121 [00:00<00:00, 798.46it/s]


Weight saved: epoch 66
Epoch 66	Train Loss: 0.131  Val Loss: 0.140



train loss: 0.126: 100%|██████████| 482/482 [00:02<00:00, 237.20it/s]
val loss: 0.134: 100%|██████████| 121/121 [00:00<00:00, 715.24it/s]


Weight saved: epoch 67
Epoch 67	Train Loss: 0.131  Val Loss: 0.139



train loss: 0.132: 100%|██████████| 482/482 [00:02<00:00, 235.66it/s]
val loss: 0.135: 100%|██████████| 121/121 [00:00<00:00, 857.91it/s]


Weight saved: epoch 68
Epoch 68	Train Loss: 0.131  Val Loss: 0.139



train loss: 0.107: 100%|██████████| 482/482 [00:02<00:00, 237.46it/s]
val loss: 0.134: 100%|██████████| 121/121 [00:00<00:00, 849.15it/s]


Weight saved: epoch 69
Epoch 69	Train Loss: 0.130  Val Loss: 0.138



train loss: 0.088: 100%|██████████| 482/482 [00:02<00:00, 230.99it/s]
val loss: 0.134: 100%|██████████| 121/121 [00:00<00:00, 815.94it/s]


Epoch 70	Train Loss: 0.130  Val Loss: 0.139



train loss: 0.098: 100%|██████████| 482/482 [00:02<00:00, 237.40it/s]
val loss: 0.132: 100%|██████████| 121/121 [00:00<00:00, 764.35it/s]


Weight saved: epoch 71
Epoch 71	Train Loss: 0.129  Val Loss: 0.138



train loss: 0.133: 100%|██████████| 482/482 [00:02<00:00, 238.62it/s]
val loss: 0.133: 100%|██████████| 121/121 [00:00<00:00, 834.34it/s]


Epoch 72	Train Loss: 0.129  Val Loss: 0.138



train loss: 0.142: 100%|██████████| 482/482 [00:02<00:00, 239.83it/s]
val loss: 0.132: 100%|██████████| 121/121 [00:00<00:00, 789.32it/s]


Epoch 73	Train Loss: 0.129  Val Loss: 0.138



train loss: 0.096: 100%|██████████| 482/482 [00:02<00:00, 237.16it/s]
val loss: 0.132: 100%|██████████| 121/121 [00:00<00:00, 853.10it/s]


Weight saved: epoch 74
Epoch 74	Train Loss: 0.128  Val Loss: 0.137



train loss: 0.126: 100%|██████████| 482/482 [00:02<00:00, 237.01it/s]
val loss: 0.132: 100%|██████████| 121/121 [00:00<00:00, 811.39it/s]


Weight saved: epoch 75
Epoch 75	Train Loss: 0.128  Val Loss: 0.137



train loss: 0.102: 100%|██████████| 482/482 [00:02<00:00, 231.44it/s]
val loss: 0.132: 100%|██████████| 121/121 [00:00<00:00, 824.99it/s]


Weight saved: epoch 76
Epoch 76	Train Loss: 0.128  Val Loss: 0.137



train loss: 0.101: 100%|██████████| 482/482 [00:02<00:00, 235.72it/s]
val loss: 0.132: 100%|██████████| 121/121 [00:00<00:00, 834.18it/s]


Epoch 77	Train Loss: 0.128  Val Loss: 0.137



train loss: 0.144: 100%|██████████| 482/482 [00:02<00:00, 237.42it/s]
val loss: 0.132: 100%|██████████| 121/121 [00:00<00:00, 799.46it/s]


Weight saved: epoch 78
Epoch 78	Train Loss: 0.127  Val Loss: 0.137



train loss: 0.115: 100%|██████████| 482/482 [00:02<00:00, 235.66it/s]
val loss: 0.132: 100%|██████████| 121/121 [00:00<00:00, 892.98it/s]


Weight saved: epoch 79
Epoch 79	Train Loss: 0.127  Val Loss: 0.136



train loss: 0.165: 100%|██████████| 482/482 [00:02<00:00, 237.07it/s]
val loss: 0.132: 100%|██████████| 121/121 [00:00<00:00, 831.46it/s]


Epoch 80	Train Loss: 0.127  Val Loss: 0.137



train loss: 0.105: 100%|██████████| 482/482 [00:02<00:00, 236.60it/s]
val loss: 0.130: 100%|██████████| 121/121 [00:00<00:00, 859.51it/s]


Weight saved: epoch 81
Epoch 81	Train Loss: 0.126  Val Loss: 0.136



train loss: 0.156: 100%|██████████| 482/482 [00:02<00:00, 234.72it/s]
val loss: 0.131: 100%|██████████| 121/121 [00:00<00:00, 833.14it/s]


Epoch 82	Train Loss: 0.126  Val Loss: 0.136



train loss: 0.115: 100%|██████████| 482/482 [00:02<00:00, 235.99it/s]
val loss: 0.130: 100%|██████████| 121/121 [00:00<00:00, 849.23it/s]


Weight saved: epoch 83
Epoch 83	Train Loss: 0.126  Val Loss: 0.135



train loss: 0.132: 100%|██████████| 482/482 [00:02<00:00, 235.38it/s]
val loss: 0.130: 100%|██████████| 121/121 [00:00<00:00, 863.89it/s]


Epoch 84	Train Loss: 0.126  Val Loss: 0.135



train loss: 0.121: 100%|██████████| 482/482 [00:02<00:00, 235.36it/s]
val loss: 0.131: 100%|██████████| 121/121 [00:00<00:00, 852.26it/s]


Epoch 85	Train Loss: 0.125  Val Loss: 0.135



train loss: 0.092: 100%|██████████| 482/482 [00:02<00:00, 237.68it/s]
val loss: 0.130: 100%|██████████| 121/121 [00:00<00:00, 812.26it/s]


Weight saved: epoch 86
Epoch 86	Train Loss: 0.125  Val Loss: 0.135



train loss: 0.146: 100%|██████████| 482/482 [00:02<00:00, 235.10it/s]
val loss: 0.131: 100%|██████████| 121/121 [00:00<00:00, 882.62it/s]


Epoch 87	Train Loss: 0.125  Val Loss: 0.135



train loss: 0.120: 100%|██████████| 482/482 [00:02<00:00, 236.56it/s]
val loss: 0.130: 100%|██████████| 121/121 [00:00<00:00, 833.22it/s]


Weight saved: epoch 88
Epoch 88	Train Loss: 0.125  Val Loss: 0.134



train loss: 0.142: 100%|██████████| 482/482 [00:02<00:00, 234.77it/s]
val loss: 0.130: 100%|██████████| 121/121 [00:00<00:00, 732.43it/s]


Weight saved: epoch 89
Epoch 89	Train Loss: 0.124  Val Loss: 0.134



train loss: 0.120: 100%|██████████| 482/482 [00:02<00:00, 236.36it/s]
val loss: 0.130: 100%|██████████| 121/121 [00:00<00:00, 842.38it/s]


Epoch 90	Train Loss: 0.124  Val Loss: 0.135



train loss: 0.127: 100%|██████████| 482/482 [00:02<00:00, 234.29it/s]
val loss: 0.129: 100%|██████████| 121/121 [00:00<00:00, 786.01it/s]


Weight saved: epoch 91
Epoch 91	Train Loss: 0.124  Val Loss: 0.134



train loss: 0.104: 100%|██████████| 482/482 [00:02<00:00, 235.06it/s]
val loss: 0.130: 100%|██████████| 121/121 [00:00<00:00, 823.13it/s]


Epoch 92	Train Loss: 0.124  Val Loss: 0.134



train loss: 0.148: 100%|██████████| 482/482 [00:02<00:00, 235.24it/s]
val loss: 0.130: 100%|██████████| 121/121 [00:00<00:00, 886.22it/s]


Epoch 93	Train Loss: 0.123  Val Loss: 0.134



train loss: 0.108: 100%|██████████| 482/482 [00:02<00:00, 235.87it/s]
val loss: 0.129: 100%|██████████| 121/121 [00:00<00:00, 785.47it/s]


Weight saved: epoch 94
Epoch 94	Train Loss: 0.123  Val Loss: 0.134



train loss: 0.145: 100%|██████████| 482/482 [00:02<00:00, 238.39it/s]
val loss: 0.129: 100%|██████████| 121/121 [00:00<00:00, 863.92it/s]


Weight saved: epoch 95
Epoch 95	Train Loss: 0.123  Val Loss: 0.134



train loss: 0.103: 100%|██████████| 482/482 [00:02<00:00, 238.31it/s]
val loss: 0.129: 100%|██████████| 121/121 [00:00<00:00, 790.04it/s]


Weight saved: epoch 96
Epoch 96	Train Loss: 0.123  Val Loss: 0.133



train loss: 0.113: 100%|██████████| 482/482 [00:02<00:00, 236.72it/s]
val loss: 0.129: 100%|██████████| 121/121 [00:00<00:00, 849.64it/s]


Epoch 97	Train Loss: 0.123  Val Loss: 0.133



train loss: 0.121: 100%|██████████| 482/482 [00:02<00:00, 234.01it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 883.50it/s]


Weight saved: epoch 98
Epoch 98	Train Loss: 0.122  Val Loss: 0.133



train loss: 0.115: 100%|██████████| 482/482 [00:02<00:00, 234.50it/s]
val loss: 0.129: 100%|██████████| 121/121 [00:00<00:00, 770.27it/s]


Epoch 99	Train Loss: 0.122  Val Loss: 0.133



train loss: 0.124: 100%|██████████| 482/482 [00:02<00:00, 232.54it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 792.55it/s]


Epoch 100	Train Loss: 0.122  Val Loss: 0.133



train loss: 0.105: 100%|██████████| 482/482 [00:02<00:00, 233.80it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 842.11it/s]


Epoch 101	Train Loss: 0.122  Val Loss: 0.133



train loss: 0.128: 100%|██████████| 482/482 [00:02<00:00, 233.82it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 778.31it/s]


Epoch 102	Train Loss: 0.121  Val Loss: 0.133



train loss: 0.115: 100%|██████████| 482/482 [00:02<00:00, 234.28it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 863.14it/s]


Weight saved: epoch 103
Epoch 103	Train Loss: 0.121  Val Loss: 0.133



train loss: 0.128: 100%|██████████| 482/482 [00:02<00:00, 227.43it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 809.01it/s]


Epoch 104	Train Loss: 0.121  Val Loss: 0.133



train loss: 0.123: 100%|██████████| 482/482 [00:02<00:00, 239.95it/s]
val loss: 0.129: 100%|██████████| 121/121 [00:00<00:00, 846.18it/s]


Epoch 105	Train Loss: 0.121  Val Loss: 0.133



train loss: 0.139: 100%|██████████| 482/482 [00:01<00:00, 242.03it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 771.17it/s]


Weight saved: epoch 106
Epoch 106	Train Loss: 0.121  Val Loss: 0.132



train loss: 0.083: 100%|██████████| 482/482 [00:01<00:00, 243.26it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 874.46it/s]


Weight saved: epoch 107
Epoch 107	Train Loss: 0.121  Val Loss: 0.132



train loss: 0.130: 100%|██████████| 482/482 [00:02<00:00, 237.75it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 778.29it/s]


Epoch 108	Train Loss: 0.120  Val Loss: 0.132



train loss: 0.096: 100%|██████████| 482/482 [00:02<00:00, 231.29it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 857.23it/s]


Epoch 109	Train Loss: 0.120  Val Loss: 0.132



train loss: 0.150: 100%|██████████| 482/482 [00:02<00:00, 238.36it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 853.29it/s]


Weight saved: epoch 110
Epoch 110	Train Loss: 0.120  Val Loss: 0.132



train loss: 0.096: 100%|██████████| 482/482 [00:01<00:00, 243.11it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 885.67it/s]


Epoch 111	Train Loss: 0.120  Val Loss: 0.132



train loss: 0.148: 100%|██████████| 482/482 [00:02<00:00, 234.61it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 836.86it/s]


Weight saved: epoch 112
Epoch 112	Train Loss: 0.120  Val Loss: 0.131



train loss: 0.094: 100%|██████████| 482/482 [00:02<00:00, 238.20it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 850.12it/s]


Epoch 113	Train Loss: 0.119  Val Loss: 0.132



train loss: 0.085: 100%|██████████| 482/482 [00:02<00:00, 240.13it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 840.09it/s]


Epoch 114	Train Loss: 0.119  Val Loss: 0.132



train loss: 0.106: 100%|██████████| 482/482 [00:02<00:00, 238.29it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 816.07it/s]


Weight saved: epoch 115
Epoch 115	Train Loss: 0.119  Val Loss: 0.131



train loss: 0.106: 100%|██████████| 482/482 [00:02<00:00, 237.94it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 840.94it/s]


Weight saved: epoch 116
Epoch 116	Train Loss: 0.119  Val Loss: 0.131



train loss: 0.127: 100%|██████████| 482/482 [00:02<00:00, 237.92it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 794.86it/s]


Epoch 117	Train Loss: 0.119  Val Loss: 0.132



train loss: 0.093: 100%|██████████| 482/482 [00:02<00:00, 234.33it/s]
val loss: 0.128: 100%|██████████| 121/121 [00:00<00:00, 806.60it/s]


Epoch 118	Train Loss: 0.119  Val Loss: 0.132



train loss: 0.102: 100%|██████████| 482/482 [00:02<00:00, 231.65it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 844.81it/s]


Epoch 119	Train Loss: 0.119  Val Loss: 0.131



train loss: 0.161: 100%|██████████| 482/482 [00:02<00:00, 240.48it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 821.02it/s]


Epoch 120	Train Loss: 0.118  Val Loss: 0.131



train loss: 0.110: 100%|██████████| 482/482 [00:02<00:00, 234.94it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 860.25it/s]


Epoch 121	Train Loss: 0.118  Val Loss: 0.131



train loss: 0.110: 100%|██████████| 482/482 [00:02<00:00, 237.39it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 819.90it/s]


Weight saved: epoch 122
Epoch 122	Train Loss: 0.118  Val Loss: 0.131



train loss: 0.136: 100%|██████████| 482/482 [00:02<00:00, 230.62it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 816.80it/s]


Weight saved: epoch 123
Epoch 123	Train Loss: 0.118  Val Loss: 0.131



train loss: 0.126: 100%|██████████| 482/482 [00:01<00:00, 242.81it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 864.38it/s]


Weight saved: epoch 124
Epoch 124	Train Loss: 0.118  Val Loss: 0.131



train loss: 0.129: 100%|██████████| 482/482 [00:01<00:00, 241.65it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 756.87it/s]


Epoch 125	Train Loss: 0.118  Val Loss: 0.131



train loss: 0.113: 100%|██████████| 482/482 [00:02<00:00, 235.36it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 878.26it/s]


Epoch 126	Train Loss: 0.117  Val Loss: 0.131



train loss: 0.119: 100%|██████████| 482/482 [00:02<00:00, 230.57it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 779.15it/s]


Weight saved: epoch 127
Epoch 127	Train Loss: 0.117  Val Loss: 0.131



train loss: 0.101: 100%|██████████| 482/482 [00:02<00:00, 237.49it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 832.76it/s]


Weight saved: epoch 128
Epoch 128	Train Loss: 0.117  Val Loss: 0.130



train loss: 0.108: 100%|██████████| 482/482 [00:02<00:00, 238.69it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 846.69it/s]


Epoch 129	Train Loss: 0.117  Val Loss: 0.130



train loss: 0.104: 100%|██████████| 482/482 [00:02<00:00, 240.53it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 845.84it/s]


Epoch 130	Train Loss: 0.117  Val Loss: 0.131



train loss: 0.106: 100%|██████████| 482/482 [00:02<00:00, 239.93it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 832.83it/s]


Epoch 131	Train Loss: 0.117  Val Loss: 0.130



train loss: 0.115: 100%|██████████| 482/482 [00:02<00:00, 233.45it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 832.07it/s]


Epoch 132	Train Loss: 0.117  Val Loss: 0.130



train loss: 0.130: 100%|██████████| 482/482 [00:02<00:00, 237.43it/s]
val loss: 0.125: 100%|██████████| 121/121 [00:00<00:00, 798.00it/s]


Weight saved: epoch 133
Epoch 133	Train Loss: 0.116  Val Loss: 0.129



train loss: 0.127: 100%|██████████| 482/482 [00:02<00:00, 238.01it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 802.06it/s]


Epoch 134	Train Loss: 0.116  Val Loss: 0.130



train loss: 0.104: 100%|██████████| 482/482 [00:02<00:00, 238.25it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 818.19it/s]


Epoch 135	Train Loss: 0.116  Val Loss: 0.130



train loss: 0.104: 100%|██████████| 482/482 [00:02<00:00, 239.15it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 790.90it/s]


Epoch 136	Train Loss: 0.116  Val Loss: 0.130



train loss: 0.124: 100%|██████████| 482/482 [00:02<00:00, 239.03it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 792.81it/s]


Epoch 137	Train Loss: 0.116  Val Loss: 0.130



train loss: 0.122: 100%|██████████| 482/482 [00:02<00:00, 234.69it/s]
val loss: 0.127: 100%|██████████| 121/121 [00:00<00:00, 808.63it/s]


Epoch 138	Train Loss: 0.116  Val Loss: 0.130



train loss: 0.090: 100%|██████████| 482/482 [00:02<00:00, 238.35it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 876.27it/s]


Epoch 139	Train Loss: 0.116  Val Loss: 0.130



train loss: 0.148: 100%|██████████| 482/482 [00:02<00:00, 233.81it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 812.58it/s]


Epoch 140	Train Loss: 0.116  Val Loss: 0.130



train loss: 0.126: 100%|██████████| 482/482 [00:02<00:00, 236.86it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 801.70it/s]


Epoch 141	Train Loss: 0.116  Val Loss: 0.130



train loss: 0.102: 100%|██████████| 482/482 [00:02<00:00, 237.58it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 875.98it/s]


Epoch 142	Train Loss: 0.115  Val Loss: 0.130



train loss: 0.137: 100%|██████████| 482/482 [00:02<00:00, 236.37it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 813.60it/s]


Epoch 143	Train Loss: 0.115  Val Loss: 0.130



train loss: 0.149: 100%|██████████| 482/482 [00:02<00:00, 230.68it/s]
val loss: 0.125: 100%|██████████| 121/121 [00:00<00:00, 816.85it/s]


Epoch 144	Train Loss: 0.115  Val Loss: 0.130



train loss: 0.090: 100%|██████████| 482/482 [00:02<00:00, 231.38it/s]
val loss: 0.125: 100%|██████████| 121/121 [00:00<00:00, 765.51it/s]


Weight saved: epoch 145
Epoch 145	Train Loss: 0.115  Val Loss: 0.129



train loss: 0.111: 100%|██████████| 482/482 [00:02<00:00, 234.87it/s]
val loss: 0.125: 100%|██████████| 121/121 [00:00<00:00, 823.42it/s]


Epoch 146	Train Loss: 0.115  Val Loss: 0.130



train loss: 0.128: 100%|██████████| 482/482 [00:02<00:00, 237.65it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 842.59it/s]


Epoch 147	Train Loss: 0.115  Val Loss: 0.130



train loss: 0.106: 100%|██████████| 482/482 [00:02<00:00, 239.25it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 849.94it/s]


Epoch 148	Train Loss: 0.115  Val Loss: 0.130



train loss: 0.098: 100%|██████████| 482/482 [00:02<00:00, 238.82it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 787.88it/s]


Epoch 149	Train Loss: 0.115  Val Loss: 0.130



train loss: 0.115: 100%|██████████| 482/482 [00:02<00:00, 234.51it/s]
val loss: 0.126: 100%|██████████| 121/121 [00:00<00:00, 808.73it/s]


Epoch 150	Train Loss: 0.114  Val Loss: 0.130



train loss: 0.130:  76%|███████▌  | 367/482 [00:01<00:00, 236.32it/s]


KeyboardInterrupt: 

In [24]:
model.eval()
with torch.no_grad():
    recon = model(val_data.to(device))

criterion = torch.nn.MSELoss(reduction='sum')
criterion(val_data.to(device), recon) / 15400

tensor(0.1290, device='cuda:0')