## Setup model and dataset

In [1]:
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import pandas as pd
from tqdm import tqdm

In [2]:
class COVID19Dataset(Dataset):
    '''
    x: Features.
    y: Targets, if none, do prediction.
    '''
    def __init__(self, x, y=None):
        if y is None:
            self.y = y
        else:
            self.y = torch.FloatTensor(y)
        self.x = torch.FloatTensor(x)

    def __getitem__(self, idx):
        if self.y is None:
            return self.x[idx]
        else:
            return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

In [3]:
# setup model
class SimpleMLP(nn.Module):
    """
    简单的MLP模型
    """
    def __init__(self, input_dim, hidden_layers_dim=[64, 32, 8]):
        super().__init__()
        self.layers = []
        for i in range(len(hidden_layers_dim)):
            if i == 0: 
                self.layers.append(nn.Linear(input_dim, hidden_layers_dim[i]))
            else: 
                self.layers.append(nn.Linear(hidden_layers_dim[i-1], hidden_layers_dim[i]))
            self.layers.append(nn.ReLU())
        self.layers.append(nn.Linear(hidden_layers_dim[-1], 1))
        self.layers = nn.Sequential(*self.layers)
        
    def forward(self, x):
        x = self.layers(x)  # [B, 1]
        x = x.squeeze(1)
        return x

In [4]:
def same_seed(seed): 
    '''Fixes random number generator seeds for reproducibility.'''
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def train_valid_split(data_set, valid_ratio, seed):
    '''Split provided training data into training set and validation set'''
    valid_set_size = int(valid_ratio * len(data_set)) 
    train_set_size = len(data_set) - valid_set_size
    train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_set), np.array(valid_set)

def select_feat(train_data, valid_data, select_all=False):
    '''Selects useful features to perform regression'''
    y_train, y_valid = train_data[:,-1], valid_data[:,-1]
    raw_x_train, raw_x_valid = train_data[:,:-1], valid_data[:,:-1]

    if select_all:
        feat_idx = list(range(raw_x_train.shape[1]))
    else:
        # [35:length - 1]
        feat_idx = list(range(35, raw_x_train.shape[1])) # TODO: Select suitable feature columns.
        
    return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], y_train, y_valid

In [5]:
# config
config = {
    'seed': 52,      # Your seed number, you can pick your lucky number. :)
    'select_all': False,   # Whether to use all features.
    'valid_ratio': 0.1,   # validation_size = train_size * valid_ratio
    'n_epochs': 1000,     # Number of epochs.            
    'batch_size': 256, 
    'learning_rate': 1e-5,              
    'weight_decay': 1e-5,  # L2 Regularization (weight decay).
    'early_stop': 600,    # If model has not improved for this many consecutive epochs, stop training.     
    'save_path': './models/model.ckpt'  # Your model will be saved here.
}

In [6]:
same_seed(config['seed'])
train_data = pd.read_csv('datasets/covid/covid_train.csv').values
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])
x_train, x_valid, y_train, y_valid = select_feat(train_data, valid_data, config['select_all'])
print(f"x_train size: {x_train.shape}, y_train size: {y_train.shape}")
train_dataset, valid_dataset = COVID19Dataset(x_train, y_train), COVID19Dataset(x_valid, y_valid)
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)

x_train size: (2709, 53), y_train size: (2709,)


In [7]:
# train function
def train(num_epoch, model, optimizer, criterion, train_loader, val_loader, device, model_path, lr_scheduler=None):
    best_loss = float('inf')
    for epoch in range(num_epoch):
        train_loss = 0.0
        val_loss = 0.0
        
        # training
        model.train() # set the model to training mode
        for i, batch in enumerate(tqdm(train_loader)):
            features, labels = batch
            features = features.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # validation
        model.eval() # set the model to evaluation mode
        with torch.no_grad():
            for i, batch in enumerate(tqdm(val_loader)):
                features, labels = batch
                features = features.to(device)
                labels = labels.to(device)
                outputs = model(features)

                loss = criterion(outputs, labels)

                val_loss += loss.item()

        print(f'[{epoch+1:03d}/{num_epoch:03d}] Train Loss: {train_loss/len(train_loader):3.5f} | Val loss: {val_loss/len(val_loader):3.5f}')

        # if the model improves, save a checkpoint at this epoch
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), model_path)
            print(f'saving model with loss {val_loss/len(val_loader):.5f}')

        if lr_scheduler:
            lr_scheduler.step()

In [8]:
input_dim = x_train.shape[1]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Define loss as a function

In [9]:
def mse_loss(pred, target):
    return torch.mean((pred - target) ** 2)

In [11]:
criterion_1 = mse_loss
model = SimpleMLP(input_dim=input_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
train(100, model, optimizer, criterion_1, train_loader, valid_loader, device, 'models/simple_mlp.pt')

100%|██████████| 11/11 [00:00<00:00, 277.77it/s]
100%|██████████| 2/2 [00:00<00:00, 499.95it/s]


[001/100] Train Loss: 393.93820 | Val loss: 348.84509
saving model with loss 348.84509


100%|██████████| 11/11 [00:00<00:00, 318.31it/s]
100%|██████████| 2/2 [00:00<00:00, 499.92it/s]


[002/100] Train Loss: 391.97064 | Val loss: 416.79640


100%|██████████| 11/11 [00:00<00:00, 327.50it/s]
100%|██████████| 2/2 [00:00<00:00, 666.08it/s]


[003/100] Train Loss: 389.82393 | Val loss: 351.57544


100%|██████████| 11/11 [00:00<00:00, 333.08it/s]
100%|██████████| 2/2 [00:00<00:00, 669.00it/s]


[004/100] Train Loss: 391.11426 | Val loss: 397.98251


100%|██████████| 11/11 [00:00<00:00, 333.34it/s]
100%|██████████| 2/2 [00:00<00:00, 436.13it/s]


[005/100] Train Loss: 387.85482 | Val loss: 390.87057


100%|██████████| 11/11 [00:00<00:00, 307.90it/s]
100%|██████████| 2/2 [00:00<00:00, 666.82it/s]


[006/100] Train Loss: 385.55175 | Val loss: 370.32950


100%|██████████| 11/11 [00:00<00:00, 293.20it/s]
100%|██████████| 2/2 [00:00<00:00, 666.61it/s]


[007/100] Train Loss: 385.65743 | Val loss: 404.88521


100%|██████████| 11/11 [00:00<00:00, 233.60it/s]
100%|██████████| 2/2 [00:00<00:00, 500.42it/s]


[008/100] Train Loss: 380.97885 | Val loss: 441.67381


100%|██████████| 11/11 [00:00<00:00, 292.69it/s]
100%|██████████| 2/2 [00:00<00:00, 500.16it/s]


[009/100] Train Loss: 385.49934 | Val loss: 346.06906
saving model with loss 346.06906


100%|██████████| 11/11 [00:00<00:00, 314.29it/s]
100%|██████████| 2/2 [00:00<00:00, 488.28it/s]


[010/100] Train Loss: 381.44677 | Val loss: 361.87415


100%|██████████| 11/11 [00:00<00:00, 309.02it/s]
100%|██████████| 2/2 [00:00<00:00, 666.71it/s]


[011/100] Train Loss: 381.30715 | Val loss: 347.59984


100%|██████████| 11/11 [00:00<00:00, 337.01it/s]
100%|██████████| 2/2 [00:00<00:00, 429.61it/s]


[012/100] Train Loss: 379.69528 | Val loss: 351.37102


100%|██████████| 11/11 [00:00<00:00, 206.43it/s]
100%|██████████| 2/2 [00:00<00:00, 500.01it/s]


[013/100] Train Loss: 379.40952 | Val loss: 365.36604


100%|██████████| 11/11 [00:00<00:00, 229.18it/s]
100%|██████████| 2/2 [00:00<00:00, 325.62it/s]


[014/100] Train Loss: 378.88252 | Val loss: 376.55412


100%|██████████| 11/11 [00:00<00:00, 309.36it/s]
100%|██████████| 2/2 [00:00<00:00, 499.92it/s]


[015/100] Train Loss: 376.54577 | Val loss: 397.49005


100%|██████████| 11/11 [00:00<00:00, 371.37it/s]
100%|██████████| 2/2 [00:00<00:00, 500.30it/s]


[016/100] Train Loss: 379.75498 | Val loss: 416.24019


100%|██████████| 11/11 [00:00<00:00, 308.90it/s]
100%|██████████| 2/2 [00:00<00:00, 666.61it/s]


[017/100] Train Loss: 376.30633 | Val loss: 408.62271


100%|██████████| 11/11 [00:00<00:00, 340.46it/s]
100%|██████████| 2/2 [00:00<00:00, 499.71it/s]


[018/100] Train Loss: 376.30938 | Val loss: 340.70827
saving model with loss 340.70827


100%|██████████| 11/11 [00:00<00:00, 343.76it/s]
100%|██████████| 2/2 [00:00<00:00, 666.71it/s]


[019/100] Train Loss: 374.37503 | Val loss: 360.13289


100%|██████████| 11/11 [00:00<00:00, 337.97it/s]
100%|██████████| 2/2 [00:00<00:00, 499.68it/s]


[020/100] Train Loss: 373.14528 | Val loss: 380.61720


100%|██████████| 11/11 [00:00<00:00, 341.74it/s]
100%|██████████| 2/2 [00:00<00:00, 669.96it/s]


[021/100] Train Loss: 373.27336 | Val loss: 419.03021


100%|██████████| 11/11 [00:00<00:00, 333.34it/s]
100%|██████████| 2/2 [00:00<00:00, 666.77it/s]


[022/100] Train Loss: 373.87718 | Val loss: 366.86935


100%|██████████| 11/11 [00:00<00:00, 297.32it/s]
100%|██████████| 2/2 [00:00<00:00, 667.09it/s]


[023/100] Train Loss: 370.49945 | Val loss: 380.25266


100%|██████████| 11/11 [00:00<00:00, 327.86it/s]
100%|██████████| 2/2 [00:00<00:00, 666.77it/s]


[024/100] Train Loss: 371.84031 | Val loss: 336.38823
saving model with loss 336.38823


100%|██████████| 11/11 [00:00<00:00, 293.33it/s]
100%|██████████| 2/2 [00:00<00:00, 499.98it/s]


[025/100] Train Loss: 368.44945 | Val loss: 357.11424


100%|██████████| 11/11 [00:00<00:00, 314.28it/s]
100%|██████████| 2/2 [00:00<00:00, 666.82it/s]


[026/100] Train Loss: 370.20652 | Val loss: 373.85757


100%|██████████| 11/11 [00:00<00:00, 366.70it/s]
100%|██████████| 2/2 [00:00<00:00, 666.56it/s]


[027/100] Train Loss: 366.30827 | Val loss: 351.35353


100%|██████████| 11/11 [00:00<00:00, 348.03it/s]
100%|██████████| 2/2 [00:00<00:00, 499.80it/s]


[028/100] Train Loss: 365.59146 | Val loss: 317.73834
saving model with loss 317.73834


100%|██████████| 11/11 [00:00<00:00, 274.99it/s]
100%|██████████| 2/2 [00:00<00:00, 285.76it/s]


[029/100] Train Loss: 365.04557 | Val loss: 399.39714


100%|██████████| 11/11 [00:00<00:00, 176.35it/s]
100%|██████████| 2/2 [00:00<00:00, 499.68it/s]


[030/100] Train Loss: 366.43323 | Val loss: 358.07520


100%|██████████| 11/11 [00:00<00:00, 235.86it/s]
100%|██████████| 2/2 [00:00<00:00, 500.30it/s]


[031/100] Train Loss: 363.05856 | Val loss: 350.07817


100%|██████████| 11/11 [00:00<00:00, 323.54it/s]
100%|██████████| 2/2 [00:00<00:00, 500.07it/s]


[032/100] Train Loss: 361.04658 | Val loss: 338.68642


100%|██████████| 11/11 [00:00<00:00, 328.23it/s]
100%|██████████| 2/2 [00:00<00:00, 666.40it/s]


[033/100] Train Loss: 360.28076 | Val loss: 355.18509


100%|██████████| 11/11 [00:00<00:00, 346.89it/s]
100%|██████████| 2/2 [00:00<00:00, 499.89it/s]


[034/100] Train Loss: 358.51991 | Val loss: 326.03290


100%|██████████| 11/11 [00:00<00:00, 343.75it/s]
100%|██████████| 2/2 [00:00<00:00, 666.87it/s]


[035/100] Train Loss: 358.49329 | Val loss: 340.48126


100%|██████████| 11/11 [00:00<00:00, 323.52it/s]
100%|██████████| 2/2 [00:00<00:00, 1001.15it/s]


[036/100] Train Loss: 355.80270 | Val loss: 364.86137


100%|██████████| 11/11 [00:00<00:00, 337.79it/s]
100%|██████████| 2/2 [00:00<00:00, 500.04it/s]


[037/100] Train Loss: 348.09075 | Val loss: 350.41275


100%|██████████| 11/11 [00:00<00:00, 358.96it/s]
100%|██████████| 2/2 [00:00<00:00, 500.13it/s]


[038/100] Train Loss: 347.45669 | Val loss: 353.35216


100%|██████████| 11/11 [00:00<00:00, 334.03it/s]
100%|██████████| 2/2 [00:00<00:00, 666.40it/s]


[039/100] Train Loss: 342.10945 | Val loss: 296.73550
saving model with loss 296.73550


100%|██████████| 11/11 [00:00<00:00, 189.66it/s]
100%|██████████| 2/2 [00:00<00:00, 500.57it/s]


[040/100] Train Loss: 334.05368 | Val loss: 314.62906


100%|██████████| 11/11 [00:00<00:00, 282.07it/s]
100%|██████████| 2/2 [00:00<00:00, 353.13it/s]


[041/100] Train Loss: 330.82802 | Val loss: 326.30276


100%|██████████| 11/11 [00:00<00:00, 112.17it/s]
100%|██████████| 2/2 [00:00<00:00, 666.66it/s]


[042/100] Train Loss: 324.73662 | Val loss: 307.49913


100%|██████████| 11/11 [00:00<00:00, 386.02it/s]
100%|██████████| 2/2 [00:00<00:00, 667.46it/s]


[043/100] Train Loss: 321.00876 | Val loss: 325.86070


100%|██████████| 11/11 [00:00<00:00, 366.67it/s]
100%|██████████| 2/2 [00:00<00:00, 666.77it/s]


[044/100] Train Loss: 317.79915 | Val loss: 325.62491


100%|██████████| 11/11 [00:00<00:00, 347.26it/s]
100%|██████████| 2/2 [00:00<00:00, 666.82it/s]


[045/100] Train Loss: 314.12261 | Val loss: 271.96263
saving model with loss 271.96263


100%|██████████| 11/11 [00:00<00:00, 301.26it/s]
100%|██████████| 2/2 [00:00<00:00, 500.13it/s]


[046/100] Train Loss: 310.08111 | Val loss: 318.44046


100%|██████████| 11/11 [00:00<00:00, 305.57it/s]
100%|██████████| 2/2 [00:00<00:00, 499.98it/s]


[047/100] Train Loss: 307.06041 | Val loss: 280.30881


100%|██████████| 11/11 [00:00<00:00, 323.55it/s]
100%|██████████| 2/2 [00:00<00:00, 666.66it/s]


[048/100] Train Loss: 305.73434 | Val loss: 298.27359


100%|██████████| 11/11 [00:00<00:00, 323.56it/s]
100%|██████████| 2/2 [00:00<00:00, 399.82it/s]


[049/100] Train Loss: 299.16126 | Val loss: 255.34028
saving model with loss 255.34028


100%|██████████| 11/11 [00:00<00:00, 297.31it/s]
100%|██████████| 2/2 [00:00<00:00, 666.77it/s]


[050/100] Train Loss: 294.50166 | Val loss: 313.80075


100%|██████████| 11/11 [00:00<00:00, 313.39it/s]
100%|██████████| 2/2 [00:00<00:00, 666.87it/s]


[051/100] Train Loss: 292.30799 | Val loss: 307.34181


100%|██████████| 11/11 [00:00<00:00, 333.34it/s]
100%|██████████| 2/2 [00:00<00:00, 667.14it/s]


[052/100] Train Loss: 287.44659 | Val loss: 263.88911


100%|██████████| 11/11 [00:00<00:00, 233.30it/s]
100%|██████████| 2/2 [00:00<00:00, 399.90it/s]


[053/100] Train Loss: 284.90030 | Val loss: 234.47654
saving model with loss 234.47654


100%|██████████| 11/11 [00:00<00:00, 337.16it/s]
100%|██████████| 2/2 [00:00<00:00, 519.87it/s]


[054/100] Train Loss: 280.94327 | Val loss: 255.97981


100%|██████████| 11/11 [00:00<00:00, 323.54it/s]
100%|██████████| 2/2 [00:00<00:00, 668.36it/s]


[055/100] Train Loss: 276.67465 | Val loss: 276.59926


100%|██████████| 11/11 [00:00<00:00, 318.06it/s]
100%|██████████| 2/2 [00:00<00:00, 666.87it/s]


[056/100] Train Loss: 274.65070 | Val loss: 244.83145


100%|██████████| 11/11 [00:00<00:00, 343.76it/s]
100%|██████████| 2/2 [00:00<00:00, 666.82it/s]


[057/100] Train Loss: 269.72358 | Val loss: 255.17815


100%|██████████| 11/11 [00:00<00:00, 333.33it/s]
100%|██████████| 2/2 [00:00<00:00, 500.30it/s]


[058/100] Train Loss: 265.14514 | Val loss: 262.79297


100%|██████████| 11/11 [00:00<00:00, 333.33it/s]
100%|██████████| 2/2 [00:00<00:00, 666.71it/s]


[059/100] Train Loss: 261.41701 | Val loss: 253.19979


100%|██████████| 11/11 [00:00<00:00, 351.16it/s]
100%|██████████| 2/2 [00:00<00:00, 649.02it/s]


[060/100] Train Loss: 255.97711 | Val loss: 230.77279
saving model with loss 230.77279


100%|██████████| 11/11 [00:00<00:00, 346.02it/s]
100%|██████████| 2/2 [00:00<00:00, 500.39it/s]


[061/100] Train Loss: 253.48066 | Val loss: 250.52488


100%|██████████| 11/11 [00:00<00:00, 343.75it/s]
100%|██████████| 2/2 [00:00<00:00, 666.93it/s]


[062/100] Train Loss: 249.36894 | Val loss: 213.04383
saving model with loss 213.04383


100%|██████████| 11/11 [00:00<00:00, 333.34it/s]
100%|██████████| 2/2 [00:00<00:00, 666.82it/s]


[063/100] Train Loss: 247.73480 | Val loss: 232.82159


100%|██████████| 11/11 [00:00<00:00, 321.02it/s]
100%|██████████| 2/2 [00:00<00:00, 666.82it/s]


[064/100] Train Loss: 242.07182 | Val loss: 231.01409


100%|██████████| 11/11 [00:00<00:00, 335.44it/s]
100%|██████████| 2/2 [00:00<00:00, 499.71it/s]


[065/100] Train Loss: 239.07569 | Val loss: 283.84110


100%|██████████| 11/11 [00:00<00:00, 338.38it/s]
100%|██████████| 2/2 [00:00<00:00, 666.87it/s]


[066/100] Train Loss: 234.19516 | Val loss: 236.70444


100%|██████████| 11/11 [00:00<00:00, 305.55it/s]
100%|██████████| 2/2 [00:00<00:00, 500.04it/s]


[067/100] Train Loss: 230.64863 | Val loss: 254.35416


100%|██████████| 11/11 [00:00<00:00, 314.29it/s]
100%|██████████| 2/2 [00:00<00:00, 666.87it/s]


[068/100] Train Loss: 227.78776 | Val loss: 225.83092


100%|██████████| 11/11 [00:00<00:00, 333.35it/s]
100%|██████████| 2/2 [00:00<00:00, 439.95it/s]


[069/100] Train Loss: 223.56452 | Val loss: 223.31838


100%|██████████| 11/11 [00:00<00:00, 305.55it/s]
100%|██████████| 2/2 [00:00<00:00, 500.13it/s]


[070/100] Train Loss: 221.51402 | Val loss: 202.82999
saving model with loss 202.82999


100%|██████████| 11/11 [00:00<00:00, 322.47it/s]
100%|██████████| 2/2 [00:00<00:00, 499.98it/s]


[071/100] Train Loss: 217.82715 | Val loss: 239.60362


100%|██████████| 11/11 [00:00<00:00, 314.27it/s]
100%|██████████| 2/2 [00:00<00:00, 666.98it/s]


[072/100] Train Loss: 213.75694 | Val loss: 211.01862


100%|██████████| 11/11 [00:00<00:00, 323.50it/s]
100%|██████████| 2/2 [00:00<00:00, 500.22it/s]


[073/100] Train Loss: 209.44803 | Val loss: 174.92123
saving model with loss 174.92123


100%|██████████| 11/11 [00:00<00:00, 301.10it/s]
100%|██████████| 2/2 [00:00<00:00, 399.93it/s]


[074/100] Train Loss: 208.63828 | Val loss: 179.12813


100%|██████████| 11/11 [00:00<00:00, 326.19it/s]
100%|██████████| 2/2 [00:00<00:00, 667.09it/s]


[075/100] Train Loss: 204.00215 | Val loss: 225.40251


100%|██████████| 11/11 [00:00<00:00, 327.80it/s]
100%|██████████| 2/2 [00:00<00:00, 499.92it/s]


[076/100] Train Loss: 201.18719 | Val loss: 175.64866


100%|██████████| 11/11 [00:00<00:00, 308.48it/s]
100%|██████████| 2/2 [00:00<00:00, 500.01it/s]


[077/100] Train Loss: 197.24863 | Val loss: 188.76631


100%|██████████| 11/11 [00:00<00:00, 340.87it/s]
100%|██████████| 2/2 [00:00<00:00, 499.98it/s]


[078/100] Train Loss: 194.22184 | Val loss: 218.30555


100%|██████████| 11/11 [00:00<00:00, 350.72it/s]
100%|██████████| 2/2 [00:00<00:00, 500.19it/s]


[079/100] Train Loss: 192.28506 | Val loss: 176.44527


100%|██████████| 11/11 [00:00<00:00, 359.96it/s]
100%|██████████| 2/2 [00:00<00:00, 666.71it/s]


[080/100] Train Loss: 188.38587 | Val loss: 166.60550
saving model with loss 166.60550


100%|██████████| 11/11 [00:00<00:00, 333.36it/s]
100%|██████████| 2/2 [00:00<00:00, 499.86it/s]


[081/100] Train Loss: 187.66266 | Val loss: 186.13226


100%|██████████| 11/11 [00:00<00:00, 343.74it/s]
100%|██████████| 2/2 [00:00<00:00, 666.98it/s]


[082/100] Train Loss: 183.46983 | Val loss: 195.56347


100%|██████████| 11/11 [00:00<00:00, 348.24it/s]
100%|██████████| 2/2 [00:00<00:00, 666.66it/s]


[083/100] Train Loss: 182.54137 | Val loss: 182.25615


100%|██████████| 11/11 [00:00<00:00, 337.04it/s]
100%|██████████| 2/2 [00:00<00:00, 499.86it/s]


[084/100] Train Loss: 177.68493 | Val loss: 138.49916
saving model with loss 138.49916


100%|██████████| 11/11 [00:00<00:00, 246.82it/s]
100%|██████████| 2/2 [00:00<00:00, 500.10it/s]


[085/100] Train Loss: 177.03505 | Val loss: 161.17171


100%|██████████| 11/11 [00:00<00:00, 314.29it/s]
100%|██████████| 2/2 [00:00<00:00, 666.66it/s]


[086/100] Train Loss: 174.91446 | Val loss: 168.45957


100%|██████████| 11/11 [00:00<00:00, 342.70it/s]
100%|██████████| 2/2 [00:00<00:00, 548.20it/s]


[087/100] Train Loss: 171.49157 | Val loss: 166.04636


100%|██████████| 11/11 [00:00<00:00, 328.91it/s]
100%|██████████| 2/2 [00:00<00:00, 632.20it/s]


[088/100] Train Loss: 169.51838 | Val loss: 214.27786


100%|██████████| 11/11 [00:00<00:00, 320.28it/s]
100%|██████████| 2/2 [00:00<00:00, 468.38it/s]


[089/100] Train Loss: 167.46731 | Val loss: 164.42302


100%|██████████| 11/11 [00:00<00:00, 352.54it/s]
100%|██████████| 2/2 [00:00<00:00, 667.09it/s]


[090/100] Train Loss: 163.90127 | Val loss: 169.44572


100%|██████████| 11/11 [00:00<00:00, 343.74it/s]
100%|██████████| 2/2 [00:00<00:00, 666.77it/s]


[091/100] Train Loss: 162.71009 | Val loss: 154.40158


100%|██████████| 11/11 [00:00<00:00, 340.29it/s]
100%|██████████| 2/2 [00:00<00:00, 667.09it/s]


[092/100] Train Loss: 160.79895 | Val loss: 142.01780


100%|██████████| 11/11 [00:00<00:00, 333.36it/s]
100%|██████████| 2/2 [00:00<00:00, 667.09it/s]


[093/100] Train Loss: 158.33581 | Val loss: 141.74205


100%|██████████| 11/11 [00:00<00:00, 328.20it/s]
100%|██████████| 2/2 [00:00<00:00, 499.98it/s]


[094/100] Train Loss: 156.33525 | Val loss: 158.49014


100%|██████████| 11/11 [00:00<00:00, 305.57it/s]
100%|██████████| 2/2 [00:00<00:00, 500.01it/s]


[095/100] Train Loss: 153.14508 | Val loss: 180.98647


100%|██████████| 11/11 [00:00<00:00, 196.42it/s]
100%|██████████| 2/2 [00:00<00:00, 400.03it/s]


[096/100] Train Loss: 151.53591 | Val loss: 160.80691


100%|██████████| 11/11 [00:00<00:00, 234.06it/s]
100%|██████████| 2/2 [00:00<00:00, 499.92it/s]


[097/100] Train Loss: 149.08411 | Val loss: 143.84098


100%|██████████| 11/11 [00:00<00:00, 333.33it/s]
100%|██████████| 2/2 [00:00<00:00, 499.83it/s]


[098/100] Train Loss: 146.60638 | Val loss: 144.35190


100%|██████████| 11/11 [00:00<00:00, 282.07it/s]
100%|██████████| 2/2 [00:00<00:00, 399.88it/s]


[099/100] Train Loss: 145.18762 | Val loss: 143.07728


100%|██████████| 11/11 [00:00<00:00, 318.21it/s]
100%|██████████| 2/2 [00:00<00:00, 666.61it/s]

[100/100] Train Loss: 143.60540 | Val loss: 136.26231
saving model with loss 136.26231





## Define loss as a nn.Module

In [12]:
class MSELoss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, pred, target):
        return torch.mean((pred - target) ** 2)

In [14]:
criterion_2 = MSELoss()
model = SimpleMLP(input_dim=input_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
train(100, model, optimizer, criterion_2, train_loader, valid_loader, device, 'models/simple_mlp.pt')

100%|██████████| 11/11 [00:00<00:00, 233.97it/s]
100%|██████████| 2/2 [00:00<00:00, 400.07it/s]


[001/100] Train Loss: 415.25088 | Val loss: 355.43562
saving model with loss 355.43562


100%|██████████| 11/11 [00:00<00:00, 235.45it/s]
100%|██████████| 2/2 [00:00<00:00, 499.62it/s]


[002/100] Train Loss: 410.53878 | Val loss: 412.89938


100%|██████████| 11/11 [00:00<00:00, 297.28it/s]
100%|██████████| 2/2 [00:00<00:00, 641.13it/s]


[003/100] Train Loss: 401.92654 | Val loss: 360.08675


100%|██████████| 11/11 [00:00<00:00, 354.83it/s]
100%|██████████| 2/2 [00:00<00:00, 499.92it/s]


[004/100] Train Loss: 399.08572 | Val loss: 346.76186
saving model with loss 346.76186


100%|██████████| 11/11 [00:00<00:00, 332.01it/s]
100%|██████████| 2/2 [00:00<00:00, 665.92it/s]


[005/100] Train Loss: 390.37546 | Val loss: 372.53314


100%|██████████| 11/11 [00:00<00:00, 333.32it/s]
100%|██████████| 2/2 [00:00<00:00, 499.53it/s]


[006/100] Train Loss: 386.42104 | Val loss: 352.11662


100%|██████████| 11/11 [00:00<00:00, 349.84it/s]
100%|██████████| 2/2 [00:00<00:00, 665.92it/s]


[007/100] Train Loss: 383.71618 | Val loss: 369.05339


100%|██████████| 11/11 [00:00<00:00, 341.76it/s]
100%|██████████| 2/2 [00:00<00:00, 667.09it/s]


[008/100] Train Loss: 376.45032 | Val loss: 422.79732


100%|██████████| 11/11 [00:00<00:00, 354.01it/s]
100%|██████████| 2/2 [00:00<00:00, 500.30it/s]


[009/100] Train Loss: 370.12427 | Val loss: 323.39365
saving model with loss 323.39365


100%|██████████| 11/11 [00:00<00:00, 343.71it/s]
100%|██████████| 2/2 [00:00<00:00, 667.14it/s]


[010/100] Train Loss: 367.89822 | Val loss: 351.69429


100%|██████████| 11/11 [00:00<00:00, 107.84it/s]
100%|██████████| 2/2 [00:00<00:00, 399.80it/s]


[011/100] Train Loss: 361.17482 | Val loss: 310.90100
saving model with loss 310.90100


100%|██████████| 11/11 [00:00<00:00, 325.97it/s]
100%|██████████| 2/2 [00:00<00:00, 666.03it/s]


[012/100] Train Loss: 354.62376 | Val loss: 387.76668


100%|██████████| 11/11 [00:00<00:00, 336.51it/s]
100%|██████████| 2/2 [00:00<00:00, 666.29it/s]


[013/100] Train Loss: 347.81491 | Val loss: 329.46042


100%|██████████| 11/11 [00:00<00:00, 339.85it/s]
100%|██████████| 2/2 [00:00<00:00, 666.71it/s]


[014/100] Train Loss: 345.44505 | Val loss: 327.28943


100%|██████████| 11/11 [00:00<00:00, 314.25it/s]
100%|██████████| 2/2 [00:00<00:00, 499.98it/s]


[015/100] Train Loss: 340.60597 | Val loss: 349.16106


100%|██████████| 11/11 [00:00<00:00, 333.30it/s]
100%|██████████| 2/2 [00:00<00:00, 998.76it/s]


[016/100] Train Loss: 334.96383 | Val loss: 302.60980
saving model with loss 302.60980


100%|██████████| 11/11 [00:00<00:00, 343.25it/s]
100%|██████████| 2/2 [00:00<00:00, 665.97it/s]


[017/100] Train Loss: 332.74907 | Val loss: 332.49298


100%|██████████| 11/11 [00:00<00:00, 353.79it/s]
100%|██████████| 2/2 [00:00<00:00, 666.13it/s]


[018/100] Train Loss: 326.48591 | Val loss: 375.86018


100%|██████████| 11/11 [00:00<00:00, 333.32it/s]
100%|██████████| 2/2 [00:00<00:00, 665.97it/s]


[019/100] Train Loss: 320.94327 | Val loss: 314.80675


100%|██████████| 11/11 [00:00<00:00, 343.71it/s]
100%|██████████| 2/2 [00:00<00:00, 666.03it/s]


[020/100] Train Loss: 318.56206 | Val loss: 335.54543


100%|██████████| 11/11 [00:00<00:00, 333.34it/s]
100%|██████████| 2/2 [00:00<00:00, 499.71it/s]


[021/100] Train Loss: 314.03794 | Val loss: 302.19656
saving model with loss 302.19656


100%|██████████| 11/11 [00:00<00:00, 319.69it/s]
100%|██████████| 2/2 [00:00<00:00, 666.50it/s]


[022/100] Train Loss: 307.30606 | Val loss: 261.77443
saving model with loss 261.77443


100%|██████████| 11/11 [00:00<00:00, 183.33it/s]
100%|██████████| 2/2 [00:00<00:00, 333.40it/s]


[023/100] Train Loss: 303.99734 | Val loss: 308.63342


100%|██████████| 11/11 [00:00<00:00, 274.82it/s]
100%|██████████| 2/2 [00:00<00:00, 499.98it/s]


[024/100] Train Loss: 298.67432 | Val loss: 319.97108


100%|██████████| 11/11 [00:00<00:00, 229.17it/s]
100%|██████████| 2/2 [00:00<00:00, 666.87it/s]


[025/100] Train Loss: 296.26205 | Val loss: 282.32860


100%|██████████| 11/11 [00:00<00:00, 354.81it/s]
100%|██████████| 2/2 [00:00<00:00, 666.13it/s]


[026/100] Train Loss: 292.58768 | Val loss: 275.08389


100%|██████████| 11/11 [00:00<00:00, 343.76it/s]
100%|██████████| 2/2 [00:00<00:00, 666.29it/s]


[027/100] Train Loss: 288.93526 | Val loss: 271.97090


100%|██████████| 11/11 [00:00<00:00, 366.66it/s]
100%|██████████| 2/2 [00:00<00:00, 665.87it/s]


[028/100] Train Loss: 284.13320 | Val loss: 311.92975


100%|██████████| 11/11 [00:00<00:00, 341.60it/s]
100%|██████████| 2/2 [00:00<00:00, 665.92it/s]


[029/100] Train Loss: 279.99349 | Val loss: 275.82140


100%|██████████| 11/11 [00:00<00:00, 333.30it/s]
100%|██████████| 2/2 [00:00<00:00, 667.14it/s]


[030/100] Train Loss: 275.14777 | Val loss: 254.88579
saving model with loss 254.88579


100%|██████████| 11/11 [00:00<00:00, 340.87it/s]
100%|██████████| 2/2 [00:00<00:00, 666.82it/s]


[031/100] Train Loss: 276.09480 | Val loss: 242.12348
saving model with loss 242.12348


100%|██████████| 11/11 [00:00<00:00, 347.83it/s]
100%|██████████| 2/2 [00:00<00:00, 666.08it/s]


[032/100] Train Loss: 269.34133 | Val loss: 231.01942
saving model with loss 231.01942


100%|██████████| 11/11 [00:00<00:00, 359.07it/s]
100%|██████████| 2/2 [00:00<00:00, 665.97it/s]


[033/100] Train Loss: 265.33410 | Val loss: 264.03349


100%|██████████| 11/11 [00:00<00:00, 354.84it/s]
100%|██████████| 2/2 [00:00<00:00, 666.77it/s]


[034/100] Train Loss: 259.24427 | Val loss: 305.65553


100%|██████████| 11/11 [00:00<00:00, 354.82it/s]
100%|██████████| 2/2 [00:00<00:00, 668.68it/s]


[035/100] Train Loss: 256.13343 | Val loss: 256.65440


100%|██████████| 11/11 [00:00<00:00, 323.51it/s]
100%|██████████| 2/2 [00:00<00:00, 666.45it/s]


[036/100] Train Loss: 257.05483 | Val loss: 251.50847


100%|██████████| 11/11 [00:00<00:00, 371.57it/s]
100%|██████████| 2/2 [00:00<00:00, 666.03it/s]


[037/100] Train Loss: 249.49139 | Val loss: 259.73201


100%|██████████| 11/11 [00:00<00:00, 354.82it/s]
100%|██████████| 2/2 [00:00<00:00, 666.29it/s]


[038/100] Train Loss: 246.78289 | Val loss: 250.16164


100%|██████████| 11/11 [00:00<00:00, 282.02it/s]
100%|██████████| 2/2 [00:00<00:00, 400.07it/s]


[039/100] Train Loss: 245.43576 | Val loss: 263.01567


100%|██████████| 11/11 [00:00<00:00, 207.55it/s]
100%|██████████| 2/2 [00:00<00:00, 666.40it/s]


[040/100] Train Loss: 242.18148 | Val loss: 218.57642
saving model with loss 218.57642


100%|██████████| 11/11 [00:00<00:00, 305.54it/s]
100%|██████████| 2/2 [00:00<00:00, 666.87it/s]


[041/100] Train Loss: 238.53967 | Val loss: 223.39133


100%|██████████| 11/11 [00:00<00:00, 354.83it/s]
100%|██████████| 2/2 [00:00<00:00, 666.71it/s]


[042/100] Train Loss: 236.88574 | Val loss: 239.87509


100%|██████████| 11/11 [00:00<00:00, 354.80it/s]
100%|██████████| 2/2 [00:00<00:00, 666.98it/s]


[043/100] Train Loss: 234.87333 | Val loss: 252.45853


100%|██████████| 11/11 [00:00<00:00, 343.72it/s]
100%|██████████| 2/2 [00:00<00:00, 666.24it/s]


[044/100] Train Loss: 231.84683 | Val loss: 220.94212


100%|██████████| 11/11 [00:00<00:00, 366.66it/s]
100%|██████████| 2/2 [00:00<00:00, 666.77it/s]


[045/100] Train Loss: 230.49235 | Val loss: 223.99585


100%|██████████| 11/11 [00:00<00:00, 352.49it/s]
100%|██████████| 2/2 [00:00<00:00, 500.19it/s]


[046/100] Train Loss: 226.53855 | Val loss: 248.54328


100%|██████████| 11/11 [00:00<00:00, 333.31it/s]
100%|██████████| 2/2 [00:00<00:00, 666.40it/s]


[047/100] Train Loss: 224.58228 | Val loss: 208.28471
saving model with loss 208.28471


100%|██████████| 11/11 [00:00<00:00, 337.26it/s]
100%|██████████| 2/2 [00:00<00:00, 399.99it/s]


[048/100] Train Loss: 221.83262 | Val loss: 220.90668


100%|██████████| 11/11 [00:00<00:00, 344.51it/s]
100%|██████████| 2/2 [00:00<00:00, 666.98it/s]


[049/100] Train Loss: 218.45155 | Val loss: 269.36901


100%|██████████| 11/11 [00:00<00:00, 343.71it/s]
100%|██████████| 2/2 [00:00<00:00, 667.14it/s]


[050/100] Train Loss: 216.82826 | Val loss: 220.30611


100%|██████████| 11/11 [00:00<00:00, 337.50it/s]
100%|██████████| 2/2 [00:00<00:00, 537.01it/s]


[051/100] Train Loss: 214.23151 | Val loss: 220.07653


100%|██████████| 11/11 [00:00<00:00, 333.33it/s]
100%|██████████| 2/2 [00:00<00:00, 667.09it/s]


[052/100] Train Loss: 213.12594 | Val loss: 211.50231


100%|██████████| 11/11 [00:00<00:00, 354.79it/s]
100%|██████████| 2/2 [00:00<00:00, 499.68it/s]


[053/100] Train Loss: 209.63508 | Val loss: 216.26110


100%|██████████| 11/11 [00:00<00:00, 360.40it/s]
100%|██████████| 2/2 [00:00<00:00, 665.82it/s]


[054/100] Train Loss: 209.08576 | Val loss: 195.58816
saving model with loss 195.58816


100%|██████████| 11/11 [00:00<00:00, 343.73it/s]
100%|██████████| 2/2 [00:00<00:00, 400.09it/s]


[055/100] Train Loss: 205.49175 | Val loss: 227.97851


100%|██████████| 11/11 [00:00<00:00, 194.52it/s]
100%|██████████| 2/2 [00:00<00:00, 499.92it/s]


[056/100] Train Loss: 203.64022 | Val loss: 214.01001


100%|██████████| 11/11 [00:00<00:00, 282.06it/s]
100%|██████████| 2/2 [00:00<00:00, 499.56it/s]


[057/100] Train Loss: 198.56367 | Val loss: 190.04350
saving model with loss 190.04350


100%|██████████| 11/11 [00:00<00:00, 333.33it/s]
100%|██████████| 2/2 [00:00<00:00, 500.10it/s]


[058/100] Train Loss: 198.55096 | Val loss: 212.55478


100%|██████████| 11/11 [00:00<00:00, 331.88it/s]
100%|██████████| 2/2 [00:00<00:00, 499.53it/s]


[059/100] Train Loss: 195.81887 | Val loss: 200.49474


100%|██████████| 11/11 [00:00<00:00, 379.28it/s]
100%|██████████| 2/2 [00:00<00:00, 671.30it/s]


[060/100] Train Loss: 197.06642 | Val loss: 192.16066


100%|██████████| 11/11 [00:00<00:00, 331.76it/s]
100%|██████████| 2/2 [00:00<00:00, 550.83it/s]


[061/100] Train Loss: 193.84807 | Val loss: 168.54341
saving model with loss 168.54341


100%|██████████| 11/11 [00:00<00:00, 333.33it/s]
100%|██████████| 2/2 [00:00<00:00, 666.45it/s]


[062/100] Train Loss: 191.12086 | Val loss: 187.06903


100%|██████████| 11/11 [00:00<00:00, 354.70it/s]
100%|██████████| 2/2 [00:00<00:00, 499.83it/s]


[063/100] Train Loss: 189.51999 | Val loss: 192.86327


100%|██████████| 11/11 [00:00<00:00, 338.22it/s]
100%|██████████| 2/2 [00:00<00:00, 998.76it/s]


[064/100] Train Loss: 185.45644 | Val loss: 167.50639
saving model with loss 167.50639


100%|██████████| 11/11 [00:00<00:00, 366.47it/s]
100%|██████████| 2/2 [00:00<00:00, 666.71it/s]


[065/100] Train Loss: 185.99910 | Val loss: 152.88757
saving model with loss 152.88757


100%|██████████| 11/11 [00:00<00:00, 343.76it/s]
100%|██████████| 2/2 [00:00<00:00, 499.11it/s]


[066/100] Train Loss: 179.95446 | Val loss: 198.63578


100%|██████████| 11/11 [00:00<00:00, 366.62it/s]
100%|██████████| 2/2 [00:00<00:00, 666.13it/s]


[067/100] Train Loss: 180.98487 | Val loss: 192.29945


100%|██████████| 11/11 [00:00<00:00, 359.92it/s]
100%|██████████| 2/2 [00:00<00:00, 666.71it/s]


[068/100] Train Loss: 177.79755 | Val loss: 191.11934


100%|██████████| 11/11 [00:00<00:00, 319.95it/s]
100%|██████████| 2/2 [00:00<00:00, 666.03it/s]


[069/100] Train Loss: 175.63420 | Val loss: 161.32424


100%|██████████| 11/11 [00:00<00:00, 228.11it/s]
100%|██████████| 2/2 [00:00<00:00, 333.32it/s]


[070/100] Train Loss: 173.91065 | Val loss: 171.09642


100%|██████████| 11/11 [00:00<00:00, 213.32it/s]
100%|██████████| 2/2 [00:00<00:00, 500.10it/s]


[071/100] Train Loss: 171.90752 | Val loss: 185.02934


100%|██████████| 11/11 [00:00<00:00, 328.16it/s]
100%|██████████| 2/2 [00:00<00:00, 666.93it/s]


[072/100] Train Loss: 167.85520 | Val loss: 175.21474


100%|██████████| 11/11 [00:00<00:00, 354.85it/s]
100%|██████████| 2/2 [00:00<00:00, 666.93it/s]


[073/100] Train Loss: 167.11489 | Val loss: 185.17775


100%|██████████| 11/11 [00:00<00:00, 322.91it/s]
100%|██████████| 2/2 [00:00<00:00, 499.65it/s]


[074/100] Train Loss: 166.97223 | Val loss: 142.30447
saving model with loss 142.30447


100%|██████████| 11/11 [00:00<00:00, 353.67it/s]
100%|██████████| 2/2 [00:00<00:00, 499.68it/s]


[075/100] Train Loss: 163.08740 | Val loss: 178.48118


100%|██████████| 11/11 [00:00<00:00, 354.84it/s]
100%|██████████| 2/2 [00:00<00:00, 665.97it/s]


[076/100] Train Loss: 160.97202 | Val loss: 158.20760


100%|██████████| 11/11 [00:00<00:00, 354.81it/s]
100%|██████████| 2/2 [00:00<00:00, 666.98it/s]


[077/100] Train Loss: 158.67486 | Val loss: 155.23668


100%|██████████| 11/11 [00:00<00:00, 343.71it/s]
100%|██████████| 2/2 [00:00<00:00, 568.53it/s]


[078/100] Train Loss: 157.80767 | Val loss: 139.31641
saving model with loss 139.31641


100%|██████████| 11/11 [00:00<00:00, 346.26it/s]
100%|██████████| 2/2 [00:00<00:00, 666.71it/s]


[079/100] Train Loss: 155.94818 | Val loss: 161.56347


100%|██████████| 11/11 [00:00<00:00, 239.13it/s]
100%|██████████| 2/2 [00:00<00:00, 666.29it/s]


[080/100] Train Loss: 153.78936 | Val loss: 155.38445


100%|██████████| 11/11 [00:00<00:00, 359.35it/s]
100%|██████████| 2/2 [00:00<00:00, 666.50it/s]


[081/100] Train Loss: 152.71188 | Val loss: 155.11383


100%|██████████| 11/11 [00:00<00:00, 323.51it/s]
100%|██████████| 2/2 [00:00<00:00, 666.08it/s]


[082/100] Train Loss: 150.89781 | Val loss: 142.88160


100%|██████████| 11/11 [00:00<00:00, 327.51it/s]
100%|██████████| 2/2 [00:00<00:00, 665.87it/s]


[083/100] Train Loss: 148.87130 | Val loss: 132.03282
saving model with loss 132.03282


100%|██████████| 11/11 [00:00<00:00, 342.14it/s]
100%|██████████| 2/2 [00:00<00:00, 667.09it/s]


[084/100] Train Loss: 146.61216 | Val loss: 121.62381
saving model with loss 121.62381


100%|██████████| 11/11 [00:00<00:00, 335.83it/s]
100%|██████████| 2/2 [00:00<00:00, 666.19it/s]


[085/100] Train Loss: 146.22310 | Val loss: 165.96301


100%|██████████| 11/11 [00:00<00:00, 171.92it/s]
100%|██████████| 2/2 [00:00<00:00, 500.30it/s]


[086/100] Train Loss: 143.73740 | Val loss: 139.35241


100%|██████████| 11/11 [00:00<00:00, 271.45it/s]
100%|██████████| 2/2 [00:00<00:00, 499.86it/s]


[087/100] Train Loss: 141.69568 | Val loss: 139.14838


100%|██████████| 11/11 [00:00<00:00, 327.67it/s]
100%|██████████| 2/2 [00:00<00:00, 666.29it/s]


[088/100] Train Loss: 140.56241 | Val loss: 134.86542


100%|██████████| 11/11 [00:00<00:00, 318.14it/s]
100%|██████████| 2/2 [00:00<00:00, 638.94it/s]


[089/100] Train Loss: 138.28176 | Val loss: 137.83912


100%|██████████| 11/11 [00:00<00:00, 245.12it/s]
100%|██████████| 2/2 [00:00<00:00, 352.06it/s]


[090/100] Train Loss: 139.59013 | Val loss: 130.94886


100%|██████████| 11/11 [00:00<00:00, 286.83it/s]
100%|██████████| 2/2 [00:00<00:00, 500.04it/s]


[091/100] Train Loss: 136.50755 | Val loss: 128.07392


100%|██████████| 11/11 [00:00<00:00, 297.32it/s]
100%|██████████| 2/2 [00:00<00:00, 425.13it/s]


[092/100] Train Loss: 136.30792 | Val loss: 112.73536
saving model with loss 112.73536


100%|██████████| 11/11 [00:00<00:00, 297.30it/s]
100%|██████████| 2/2 [00:00<00:00, 540.36it/s]


[093/100] Train Loss: 133.45371 | Val loss: 139.72060


100%|██████████| 11/11 [00:00<00:00, 209.43it/s]
100%|██████████| 2/2 [00:00<00:00, 199.99it/s]


[094/100] Train Loss: 132.09804 | Val loss: 131.59442


100%|██████████| 11/11 [00:00<00:00, 220.00it/s]
100%|██████████| 2/2 [00:00<00:00, 666.61it/s]


[095/100] Train Loss: 131.70418 | Val loss: 128.11469


100%|██████████| 11/11 [00:00<00:00, 99.43it/s]
100%|██████████| 2/2 [00:00<00:00, 666.50it/s]


[096/100] Train Loss: 128.91473 | Val loss: 121.29679


100%|██████████| 11/11 [00:00<00:00, 343.74it/s]
100%|██████████| 2/2 [00:00<00:00, 499.71it/s]


[097/100] Train Loss: 128.51185 | Val loss: 129.57015


100%|██████████| 11/11 [00:00<00:00, 327.75it/s]
100%|██████████| 2/2 [00:00<00:00, 499.59it/s]


[098/100] Train Loss: 126.71949 | Val loss: 125.84037


100%|██████████| 11/11 [00:00<00:00, 342.27it/s]
100%|██████████| 2/2 [00:00<00:00, 666.98it/s]


[099/100] Train Loss: 125.48237 | Val loss: 102.93153
saving model with loss 102.93153


100%|██████████| 11/11 [00:00<00:00, 333.31it/s]
100%|██████████| 2/2 [00:00<00:00, 499.77it/s]

[100/100] Train Loss: 124.54389 | Val loss: 129.30560



