In [18]:
# google drive link
# !gdown --id '1BjXalPZxq9mybPKNjF3h5L3NcF7XKTS-' --output covid_train.csv
# !gdown --id '1B55t74Jg2E5FCsKCsUEkPKIuqaY7UIi1' --output covid_test.csv

# dropbox link
!wget -O covid_train.csv https://www.dropbox.com/s/lmy1riadzoy0ahw/covid.train.csv?dl=0
!wget -O covid_test.csv https://www.dropbox.com/s/zalbw42lu4nmhr2/covid.test.csv?dl=0
# Numerical Operations
import math
import numpy as np

# Reading/Writing Data
import pandas as pd
import os
import csv

# For Progress Bar
from tqdm import tqdm

# Pytorch
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# For plotting learning curve
from torch.utils.tensorboard import SummaryWriter
def same_seed(seed): 
    '''Fixes random number generator seeds for reproducibility.'''
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def train_valid_split(data_set, valid_ratio, seed):
    '''Split provided training data into training set and validation set'''
    valid_set_size = int(valid_ratio * len(data_set)) 
    train_set_size = len(data_set) - valid_set_size
    train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_set), np.array(valid_set)

def predict(test_loader, model, device):
    model.eval() # Set your model to evaluation mode.
    preds = []
    for x in tqdm(test_loader):
        x = x.to(device)                        
        with torch.no_grad():                   
            pred = model(x)                     
            preds.append(pred.detach().cpu())   
    preds = torch.cat(preds, dim=0).numpy()  
    return preds
class COVID19Dataset(Dataset):
    '''
    x: Features.
    y: Targets, if none, do prediction.
    '''
    def __init__(self, x, y=None):
        if y is None:
            self.y = y
        else:
            self.y = torch.FloatTensor(y)
        self.x = torch.FloatTensor(x)

    def __getitem__(self, idx):
        if self.y is None:
            return self.x[idx]
        else:
            return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

'wget' 不是內部或外部命令、可執行的程式或批次檔。
'wget' 不是內部或外部命令、可執行的程式或批次檔。


In [19]:
class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # TODO: modify model's structure, be aware of dimensions. 
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 64), 
            nn.ReLU(),
            nn.Linear(64, 32),#64 32
            nn.ReLU(),
            nn.Linear(32, 16),#32 16 
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x

In [20]:
def select_feat(train_data, valid_data, test_data, select_all=True):
    '''Selects useful features to perform regression'''
    y_train, y_valid = train_data[:,-1], valid_data[:,-1]
    raw_x_train, raw_x_valid, raw_x_test = train_data[:,:-1], valid_data[:,:-1], test_data

    if select_all:
        feat_idx = list(range(raw_x_train.shape[1]))
    else:
        feat_idx = [35,36,37,47,48,52,
                    53,54,55,65,66,70,
                    71,72,73,83,84]
        
    return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], raw_x_test[:,feat_idx], y_train, y_valid

In [21]:
# a = 2.6
def trainer(train_loader, valid_loader, model, config, device):

    criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.

    # Define your optimization algorithm. 
    # TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.
    # TODO: L2 regularization (optimizer(weight decay...) or implement by your self).
    #optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.7)
#     optimizer = torch.optim.RMSprop(model.parameters(), lr=config['learning_rate'], momentum=0.7)
#     optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'],weight_decay = 0.01)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'],weight_decay = 0.00)
    
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor = 0.0005,
#                                                            patience=5, verbose=True,min_lr = 0.001)
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
    factor=0.001, patience=3, threshold=0.001,verbose=True,min_lr = 0.001)
    
    writer = SummaryWriter() # Writer of tensoboard.

    if not os.path.isdir('./models'):
        os.mkdir('./models') # Create directory of saving models.

    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0

    for epoch in range(n_epochs):
        model.train() # Set your model to train mode.
        loss_record = []

        # tqdm is a package to visualize your training progress.
        train_pbar = tqdm(train_loader, position=0, leave=True)

        for x, y in train_pbar:
            optimizer.zero_grad()               # Set gradient to zero.
            x, y = x.to(device), y.to(device)   # Move your data to device. 
            pred = model(x)             
            loss = criterion(pred, y)
            loss.backward()                     # Compute gradient(backpropagation).
            optimizer.step()                    # Update parameters.
            step += 1
            loss_record.append(loss.detach().item())
            
            # Display current epoch number and loss on tqdm progress bar.
            train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')
            train_pbar.set_postfix({'loss': loss.detach().item()})

        mean_train_loss = sum(loss_record)/len(loss_record)
        writer.add_scalar('Loss/train', mean_train_loss, step)

        model.eval() # Set your model to evaluation mode.
        loss_record = []
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                loss = criterion(pred, y)

            loss_record.append(loss.item())
            
        mean_valid_loss = sum(loss_record)/len(loss_record)
        
        scheduler.step(mean_valid_loss)
        
        if epoch % 100 == 0:
          print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')
#         writer.add_scalar('Loss/valid', mean_valid_loss, step)

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            torch.save(model.state_dict(), config['save_path']) # Save your best model
            print('Saving model with loss {:.3f}...'.format(best_loss))
#             a = (best_loss)
            early_stop_count = 0
        else: 
            early_stop_count += 1

        if early_stop_count >= config['early_stop']:
            print('\nModel is not improving, so we halt the training session.')
            return mean_valid_loss , mean_train_loss,best_loss

In [22]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {
    'seed': 24,      # Your seed number, you can pick your lucky number. :)
    'select_all': False,   # Whether to use all features.
    'valid_ratio': 0.2,   # validation_size = train_size * valid_ratio
    #'n_epochs': 5000,     # Number of epochs.  
    'n_epochs': 4000,          
    'batch_size': 256, 
    # 'learning_rate': 1e-5,
    'learning_rate': 0.007,              
    'early_stop': 800,    # If model has not improved for this many consecutive epochs, stop training.     
    'save_path': './models/model.ckpt'  # Your model will be saved here.
}


In [23]:
same_seed(config['seed'])
train_data, test_data = pd.read_csv('./covid_train.csv').values, pd.read_csv('./covid_test.csv').values
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])

# Print out the data size.
print(f"""train_data size: {train_data.shape} 
valid_data size: {valid_data.shape} 
test_data size: {test_data.shape}""")

# Select features
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'])

# Print out the number of features.
print(f'number of features: {x_train.shape[1]}')

train_dataset, valid_dataset, test_dataset = COVID19Dataset(x_train, y_train), \
                                            COVID19Dataset(x_valid, y_valid), \
                                            COVID19Dataset(x_test)

# Pytorch data loader loads pytorch dataset into batches.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)

train_data size: (2408, 89) 
valid_data size: (601, 89) 
test_data size: (997, 88)
number of features: 17


In [24]:
model = My_Model(input_dim=x_train.shape[1]).to(device) # put your model and data on the same computation device.
trainer(train_loader, valid_loader, model, config, device)

Epoch [1/4000]: 100%|██████████| 10/10 [00:00<00:00, 94.59it/s, loss=73.1]


Epoch [1/4000]: Train loss: 157.5667, Valid loss: 49.3653
Saving model with loss 49.365...


Epoch [2/4000]: 100%|██████████| 10/10 [00:00<00:00, 105.54it/s, loss=23.4]


Saving model with loss 17.856...


Epoch [3/4000]: 100%|██████████| 10/10 [00:00<00:00, 110.18it/s, loss=10.8]


Saving model with loss 10.489...


Epoch [4/4000]: 100%|██████████| 10/10 [00:00<00:00, 107.81it/s, loss=8.52]


Saving model with loss 9.945...


Epoch [5/4000]: 100%|██████████| 10/10 [00:00<00:00, 108.72it/s, loss=10.5]


Saving model with loss 9.803...


Epoch [6/4000]: 100%|██████████| 10/10 [00:00<00:00, 108.99it/s, loss=10.2]


Saving model with loss 8.168...


Epoch [7/4000]: 100%|██████████| 10/10 [00:00<00:00, 85.70it/s, loss=6.71]


Saving model with loss 7.398...


Epoch [8/4000]: 100%|██████████| 10/10 [00:00<00:00, 116.59it/s, loss=6.53]


Saving model with loss 6.338...


Epoch [9/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.98it/s, loss=6.15]


Saving model with loss 5.466...


Epoch [10/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.94it/s, loss=5.1]


Saving model with loss 4.493...


Epoch [11/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.38it/s, loss=3.51]


Saving model with loss 3.487...


Epoch [12/4000]: 100%|██████████| 10/10 [00:00<00:00, 106.67it/s, loss=2.38]


Saving model with loss 2.504...


Epoch [13/4000]: 100%|██████████| 10/10 [00:00<00:00, 103.37it/s, loss=1.92]


Saving model with loss 1.776...


Epoch [14/4000]: 100%|██████████| 10/10 [00:00<00:00, 106.66it/s, loss=1.47]


Saving model with loss 1.582...


Epoch [15/4000]: 100%|██████████| 10/10 [00:00<00:00, 75.96it/s, loss=1.61]


Saving model with loss 1.512...


Epoch [16/4000]: 100%|██████████| 10/10 [00:00<00:00, 97.65it/s, loss=1.37]


Saving model with loss 1.474...


Epoch [17/4000]: 100%|██████████| 10/10 [00:00<00:00, 111.39it/s, loss=1.34]


Saving model with loss 1.464...


Epoch [18/4000]: 100%|██████████| 10/10 [00:00<00:00, 116.45it/s, loss=1.24]


Saving model with loss 1.416...


Epoch [19/4000]: 100%|██████████| 10/10 [00:00<00:00, 112.66it/s, loss=1.35]
Epoch [20/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.99it/s, loss=1.53]


Saving model with loss 1.338...


Epoch [21/4000]: 100%|██████████| 10/10 [00:00<00:00, 116.57it/s, loss=1.15]
Epoch [22/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.94it/s, loss=1.18]


Saving model with loss 1.301...


Epoch [23/4000]: 100%|██████████| 10/10 [00:00<00:00, 100.61it/s, loss=1.61]
Epoch [24/4000]: 100%|██████████| 10/10 [00:00<00:00, 105.54it/s, loss=1.29]
Epoch [25/4000]: 100%|██████████| 10/10 [00:00<00:00, 97.35it/s, loss=1.06]


Saving model with loss 1.258...


Epoch [26/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.99it/s, loss=1.28]
Epoch [27/4000]: 100%|██████████| 10/10 [00:00<00:00, 111.57it/s, loss=0.971]
Epoch [28/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.94it/s, loss=1.18]
Epoch [29/4000]: 100%|██████████| 10/10 [00:00<00:00, 115.25it/s, loss=1.29]
Epoch [30/4000]: 100%|██████████| 10/10 [00:00<00:00, 117.96it/s, loss=1.11]
Epoch [31/4000]: 100%|██████████| 10/10 [00:00<00:00, 108.96it/s, loss=1.43]
Epoch [32/4000]: 100%|██████████| 10/10 [00:00<00:00, 88.73it/s, loss=1.54]


Saving model with loss 1.201...


Epoch [33/4000]: 100%|██████████| 10/10 [00:00<00:00, 101.28it/s, loss=0.999]


Saving model with loss 1.198...


Epoch [34/4000]: 100%|██████████| 10/10 [00:00<00:00, 93.71it/s, loss=1.13]
Epoch [35/4000]: 100%|██████████| 10/10 [00:00<00:00, 108.99it/s, loss=1.07]


Saving model with loss 1.181...


Epoch [36/4000]: 100%|██████████| 10/10 [00:00<00:00, 110.94it/s, loss=1.6]
Epoch [37/4000]: 100%|██████████| 10/10 [00:00<00:00, 103.37it/s, loss=1.48]
Epoch [38/4000]: 100%|██████████| 10/10 [00:00<00:00, 110.19it/s, loss=1.57]


Saving model with loss 1.178...


Epoch [39/4000]: 100%|██████████| 10/10 [00:00<00:00, 107.09it/s, loss=0.964]
Epoch [40/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.94it/s, loss=1.37]


Saving model with loss 1.166...


Epoch [41/4000]: 100%|██████████| 10/10 [00:00<00:00, 116.52it/s, loss=0.851]


Saving model with loss 1.152...


Epoch [42/4000]: 100%|██████████| 10/10 [00:00<00:00, 55.41it/s, loss=1.24]
Epoch [43/4000]: 100%|██████████| 10/10 [00:00<00:00, 93.71it/s, loss=1.26]
Epoch [44/4000]: 100%|██████████| 10/10 [00:00<00:00, 102.31it/s, loss=1.16]
Epoch [45/4000]: 100%|██████████| 10/10 [00:00<00:00, 143.24it/s, loss=1.04]
Epoch [46/4000]: 100%|██████████| 10/10 [00:00<00:00, 147.45it/s, loss=1.5]
Epoch [47/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.65it/s, loss=1.8]
Epoch [48/4000]: 100%|██████████| 10/10 [00:00<00:00, 125.33it/s, loss=1.54]


Saving model with loss 1.063...


Epoch [49/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.65it/s, loss=1.27]
Epoch [50/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.55it/s, loss=1.26]
Epoch [51/4000]: 100%|██████████| 10/10 [00:00<00:00, 72.96it/s, loss=1.16]
Epoch [52/4000]: 100%|██████████| 10/10 [00:00<00:00, 106.67it/s, loss=1.07]
Epoch [53/4000]: 100%|██████████| 10/10 [00:00<00:00, 145.32it/s, loss=1.02]


Saving model with loss 1.063...


Epoch [54/4000]: 100%|██████████| 10/10 [00:00<00:00, 143.24it/s, loss=0.912]
Epoch [55/4000]: 100%|██████████| 10/10 [00:00<00:00, 137.35it/s, loss=1.43]
Epoch [56/4000]: 100%|██████████| 10/10 [00:00<00:00, 147.45it/s, loss=1.09]
Epoch [57/4000]: 100%|██████████| 10/10 [00:00<00:00, 147.45it/s, loss=0.883]
Epoch [58/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.55it/s, loss=0.984]
Epoch [59/4000]: 100%|██████████| 10/10 [00:00<00:00, 145.32it/s, loss=0.928]
Epoch [60/4000]: 100%|██████████| 10/10 [00:00<00:00, 137.35it/s, loss=1.01]
Epoch [61/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.49it/s, loss=1]


Saving model with loss 1.057...


Epoch [62/4000]: 100%|██████████| 10/10 [00:00<00:00, 139.26it/s, loss=1.34]
Epoch [63/4000]: 100%|██████████| 10/10 [00:00<00:00, 139.26it/s, loss=1.23]
Epoch [64/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.49it/s, loss=1.09]
Epoch [65/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.50it/s, loss=1.09]
Epoch [66/4000]: 100%|██████████| 10/10 [00:00<00:00, 133.69it/s, loss=1.06]
Epoch [67/4000]: 100%|██████████| 10/10 [00:00<00:00, 144.32it/s, loss=0.937]
Epoch [68/4000]: 100%|██████████| 10/10 [00:00<00:00, 164.37it/s, loss=1]


Saving model with loss 1.056...


Epoch [69/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.26it/s, loss=0.684]


Saving model with loss 1.034...


Epoch [70/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.26it/s, loss=1.11]
Epoch [71/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=1.01]


Saving model with loss 1.030...


Epoch [72/4000]: 100%|██████████| 10/10 [00:00<00:00, 152.02it/s, loss=1.32]


Saving model with loss 1.012...


Epoch [73/4000]: 100%|██████████| 10/10 [00:00<00:00, 147.35it/s, loss=1.02]


Saving model with loss 0.988...


Epoch [74/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.57it/s, loss=0.974]
Epoch [75/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.65it/s, loss=1.42]
Epoch [76/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.66it/s, loss=1.15]
Epoch [77/4000]: 100%|██████████| 10/10 [00:00<00:00, 145.36it/s, loss=1.13]
Epoch [78/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.93it/s, loss=1]
Epoch [79/4000]: 100%|██████████| 10/10 [00:00<00:00, 117.96it/s, loss=1.2]
Epoch [80/4000]: 100%|██████████| 10/10 [00:00<00:00, 120.80it/s, loss=1.15]
Epoch [81/4000]: 100%|██████████| 10/10 [00:00<00:00, 97.08it/s, loss=0.907]


Saving model with loss 0.973...


Epoch [82/4000]: 100%|██████████| 10/10 [00:00<00:00, 104.45it/s, loss=0.968]
Epoch [83/4000]: 100%|██████████| 10/10 [00:00<00:00, 133.69it/s, loss=1.03]
Epoch [84/4000]: 100%|██████████| 10/10 [00:00<00:00, 139.14it/s, loss=1.1]
Epoch [85/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.76it/s, loss=0.879]
Epoch [86/4000]: 100%|██████████| 10/10 [00:00<00:00, 161.72it/s, loss=1.18]
Epoch [87/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.94it/s, loss=1.35]
Epoch [88/4000]: 100%|██████████| 10/10 [00:00<00:00, 153.69it/s, loss=0.798]


Saving model with loss 0.972...


Epoch [89/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.74it/s, loss=1.24]
Epoch [90/4000]: 100%|██████████| 10/10 [00:00<00:00, 147.45it/s, loss=0.912]
Epoch [91/4000]: 100%|██████████| 10/10 [00:00<00:00, 141.22it/s, loss=1.17]
Epoch [92/4000]: 100%|██████████| 10/10 [00:00<00:00, 143.24it/s, loss=0.886]
Epoch [93/4000]: 100%|██████████| 10/10 [00:00<00:00, 139.26it/s, loss=0.929]
Epoch [94/4000]: 100%|██████████| 10/10 [00:00<00:00, 145.32it/s, loss=1.1]
Epoch [95/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.00it/s, loss=1.1]


Saving model with loss 0.922...


Epoch [96/4000]: 100%|██████████| 10/10 [00:00<00:00, 143.71it/s, loss=1.01]
Epoch [97/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.15it/s, loss=1.01]
Epoch [98/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.05it/s, loss=0.92]
Epoch [99/4000]: 100%|██████████| 10/10 [00:00<00:00, 158.15it/s, loss=0.942]
Epoch [100/4000]: 100%|██████████| 10/10 [00:00<00:00, 152.84it/s, loss=1.36]
Epoch [101/4000]: 100%|██████████| 10/10 [00:00<00:00, 153.98it/s, loss=0.782]


Epoch [101/4000]: Train loss: 0.9712, Valid loss: 1.3770


Epoch [102/4000]: 100%|██████████| 10/10 [00:00<00:00, 141.01it/s, loss=1.09]


Saving model with loss 0.910...


Epoch [103/4000]: 100%|██████████| 10/10 [00:00<00:00, 77.15it/s, loss=0.65]
Epoch [104/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.87it/s, loss=1.06]
Epoch [105/4000]: 100%|██████████| 10/10 [00:00<00:00, 136.57it/s, loss=0.816]
Epoch [106/4000]: 100%|██████████| 10/10 [00:00<00:00, 112.06it/s, loss=1.25]
Epoch [107/4000]: 100%|██████████| 10/10 [00:00<00:00, 115.17it/s, loss=1.53]
Epoch [108/4000]: 100%|██████████| 10/10 [00:00<00:00, 161.72it/s, loss=0.733]
Epoch [109/4000]: 100%|██████████| 10/10 [00:00<00:00, 164.38it/s, loss=1.32]
Epoch [110/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.15it/s, loss=1.03]


Saving model with loss 0.902...


Epoch [111/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.64it/s, loss=1.45]
Epoch [112/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.77it/s, loss=0.803]
Epoch [113/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.61it/s, loss=0.799]


Saving model with loss 0.890...


Epoch [114/4000]: 100%|██████████| 10/10 [00:00<00:00, 152.60it/s, loss=1.35]
Epoch [115/4000]: 100%|██████████| 10/10 [00:00<00:00, 151.61it/s, loss=0.835]
Epoch [116/4000]: 100%|██████████| 10/10 [00:00<00:00, 143.24it/s, loss=1.36]
Epoch [117/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.42it/s, loss=1.03]
Epoch [118/4000]: 100%|██████████| 10/10 [00:00<00:00, 150.05it/s, loss=0.787]
Epoch [119/4000]: 100%|██████████| 10/10 [00:00<00:00, 142.30it/s, loss=0.839]
Epoch [120/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.20it/s, loss=0.841]
Epoch [121/4000]: 100%|██████████| 10/10 [00:00<00:00, 150.77it/s, loss=1.2]
Epoch [122/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.32it/s, loss=0.614]
Epoch [123/4000]: 100%|██████████| 10/10 [00:00<00:00, 158.31it/s, loss=1.12]


Saving model with loss 0.876...


Epoch [124/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.16it/s, loss=0.934]
Epoch [125/4000]: 100%|██████████| 10/10 [00:00<00:00, 160.65it/s, loss=0.889]
Epoch [126/4000]: 100%|██████████| 10/10 [00:00<00:00, 158.46it/s, loss=1.15]
Epoch [127/4000]: 100%|██████████| 10/10 [00:00<00:00, 153.99it/s, loss=1.23]
Epoch [128/4000]: 100%|██████████| 10/10 [00:00<00:00, 157.19it/s, loss=1.02]
Epoch [129/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.11it/s, loss=1.11]
Epoch [130/4000]: 100%|██████████| 10/10 [00:00<00:00, 147.45it/s, loss=1.18]
Epoch [131/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.26it/s, loss=0.988]
Epoch [132/4000]: 100%|██████████| 10/10 [00:00<00:00, 161.72it/s, loss=0.848]
Epoch [133/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.15it/s, loss=1.27]
Epoch [134/4000]: 100%|██████████| 10/10 [00:00<00:00, 150.02it/s, loss=1.14]
Epoch [135/4000]: 100%|██████████| 10/10 [00:00<00:00, 152.57it/s, loss=0.995]
Epoch [136/4000]: 100%|██████████| 10/10 [00:00<00:00, 157.

Saving model with loss 0.864...


Epoch [140/4000]: 100%|██████████| 10/10 [00:00<00:00, 151.64it/s, loss=1.02]
Epoch [141/4000]: 100%|██████████| 10/10 [00:00<00:00, 150.03it/s, loss=0.835]
Epoch [142/4000]: 100%|██████████| 10/10 [00:00<00:00, 152.72it/s, loss=1.18]
Epoch [143/4000]: 100%|██████████| 10/10 [00:00<00:00, 148.36it/s, loss=1.08]
Epoch [144/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.65it/s, loss=0.806]
Epoch [145/4000]: 100%|██████████| 10/10 [00:00<00:00, 150.18it/s, loss=1.44]
Epoch [146/4000]: 100%|██████████| 10/10 [00:00<00:00, 144.06it/s, loss=1.21]
Epoch [147/4000]: 100%|██████████| 10/10 [00:00<00:00, 157.22it/s, loss=0.877]
Epoch [148/4000]: 100%|██████████| 10/10 [00:00<00:00, 151.92it/s, loss=1.26]
Epoch [149/4000]: 100%|██████████| 10/10 [00:00<00:00, 151.89it/s, loss=1.17]
Epoch [150/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.16it/s, loss=0.792]
Epoch [151/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.77it/s, loss=0.828]
Epoch [152/4000]: 100%|██████████| 10/10 [00:00<00:00, 141.

Saving model with loss 0.863...


Epoch [167/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=1.02]
Epoch [168/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=0.673]
Epoch [169/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.36it/s, loss=1.07]
Epoch [170/4000]: 100%|██████████| 10/10 [00:00<00:00, 161.72it/s, loss=1.14]
Epoch [171/4000]: 100%|██████████| 10/10 [00:00<00:00, 137.74it/s, loss=0.905]
Epoch [172/4000]: 100%|██████████| 10/10 [00:00<00:00, 112.20it/s, loss=1.27]
Epoch [173/4000]: 100%|██████████| 10/10 [00:00<00:00, 163.59it/s, loss=0.746]
Epoch [174/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.93it/s, loss=0.872]
Epoch [175/4000]: 100%|██████████| 10/10 [00:00<00:00, 94.79it/s, loss=0.833]
Epoch [176/4000]: 100%|██████████| 10/10 [00:00<00:00, 150.81it/s, loss=0.619]
Epoch [177/4000]: 100%|██████████| 10/10 [00:00<00:00, 90.26it/s, loss=0.76]
Epoch [178/4000]: 100%|██████████| 10/10 [00:00<00:00, 91.15it/s, loss=0.886]
Epoch [179/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.2

Saving model with loss 0.861...


Epoch [189/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.16it/s, loss=0.901]
Epoch [190/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.82it/s, loss=0.977]
Epoch [191/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.78it/s, loss=0.735]
Epoch [192/4000]: 100%|██████████| 10/10 [00:00<00:00, 157.81it/s, loss=0.834]
Epoch [193/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.95it/s, loss=0.828]
Epoch [194/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.57it/s, loss=0.821]
Epoch [195/4000]: 100%|██████████| 10/10 [00:00<00:00, 145.32it/s, loss=0.736]
Epoch [196/4000]: 100%|██████████| 10/10 [00:00<00:00, 94.59it/s, loss=0.798]
Epoch [197/4000]: 100%|██████████| 10/10 [00:00<00:00, 143.24it/s, loss=1.3]
Epoch [198/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.26it/s, loss=0.577]
Epoch [199/4000]: 100%|██████████| 10/10 [00:00<00:00, 111.41it/s, loss=0.93]
Epoch [200/4000]: 100%|██████████| 10/10 [00:00<00:00, 163.00it/s, loss=1.02]
Epoch [201/4000]: 100%|██████████| 10/10 [00:00<00:00, 10

Epoch [201/4000]: Train loss: 0.9842, Valid loss: 0.8644


Epoch [202/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.75it/s, loss=0.746]
Epoch [203/4000]: 100%|██████████| 10/10 [00:00<00:00, 88.08it/s, loss=0.924]
Epoch [204/4000]: 100%|██████████| 10/10 [00:00<00:00, 157.58it/s, loss=0.856]
Epoch [205/4000]: 100%|██████████| 10/10 [00:00<00:00, 157.70it/s, loss=0.869]
Epoch [206/4000]: 100%|██████████| 10/10 [00:00<00:00, 160.38it/s, loss=0.92]


Saving model with loss 0.842...


Epoch [207/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.15it/s, loss=0.927]
Epoch [208/4000]: 100%|██████████| 10/10 [00:00<00:00, 162.19it/s, loss=0.768]
Epoch [209/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.99it/s, loss=0.776]
Epoch [210/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.62it/s, loss=0.833]
Epoch [211/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=1.17]
Epoch [212/4000]: 100%|██████████| 10/10 [00:00<00:00, 153.62it/s, loss=0.637]
Epoch [213/4000]: 100%|██████████| 10/10 [00:00<00:00, 153.60it/s, loss=0.844]
Epoch [214/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.72it/s, loss=0.954]
Epoch [215/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.16it/s, loss=0.866]
Epoch [216/4000]: 100%|██████████| 10/10 [00:00<00:00, 151.78it/s, loss=0.805]
Epoch [217/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=1]
Epoch [218/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.15it/s, loss=0.712]
Epoch [219/4000]: 100%|██████████| 10/10 [00:00<00:00, 15

Saving model with loss 0.838...


Epoch [271/4000]: 100%|██████████| 10/10 [00:00<00:00, 157.34it/s, loss=0.857]
Epoch [272/4000]: 100%|██████████| 10/10 [00:00<00:00, 151.96it/s, loss=1.3]
Epoch [273/4000]: 100%|██████████| 10/10 [00:00<00:00, 119.36it/s, loss=1.08]
Epoch [274/4000]: 100%|██████████| 10/10 [00:00<00:00, 137.30it/s, loss=0.815]
Epoch [275/4000]: 100%|██████████| 10/10 [00:00<00:00, 161.72it/s, loss=0.94]
Epoch [276/4000]: 100%|██████████| 10/10 [00:00<00:00, 161.78it/s, loss=1.13]
Epoch [277/4000]: 100%|██████████| 10/10 [00:00<00:00, 153.84it/s, loss=0.706]
Epoch [278/4000]: 100%|██████████| 10/10 [00:00<00:00, 151.15it/s, loss=0.763]
Epoch [279/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.86it/s, loss=1.04]
Epoch [280/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.94it/s, loss=0.828]
Epoch [281/4000]: 100%|██████████| 10/10 [00:00<00:00, 161.72it/s, loss=0.849]
Epoch [282/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.16it/s, loss=0.972]
Epoch [283/4000]: 100%|██████████| 10/10 [00:00<00:00, 164

Epoch [301/4000]: Train loss: 0.9259, Valid loss: 0.8731


Epoch [302/4000]: 100%|██████████| 10/10 [00:00<00:00, 143.24it/s, loss=1.17]
Epoch [303/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.93it/s, loss=0.702]
Epoch [304/4000]: 100%|██████████| 10/10 [00:00<00:00, 126.92it/s, loss=1.52]
Epoch [305/4000]: 100%|██████████| 10/10 [00:00<00:00, 164.37it/s, loss=1.21]
Epoch [306/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.16it/s, loss=0.907]
Epoch [307/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=0.81]
Epoch [308/4000]: 100%|██████████| 10/10 [00:00<00:00, 161.85it/s, loss=0.806]
Epoch [309/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.76it/s, loss=0.799]
Epoch [310/4000]: 100%|██████████| 10/10 [00:00<00:00, 141.22it/s, loss=0.836]
Epoch [311/4000]: 100%|██████████| 10/10 [00:00<00:00, 93.39it/s, loss=0.543]
Epoch [312/4000]: 100%|██████████| 10/10 [00:00<00:00, 86.44it/s, loss=1.06]
Epoch [313/4000]: 100%|██████████| 10/10 [00:00<00:00, 90.33it/s, loss=0.862]
Epoch [314/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.4

Saving model with loss 0.836...


Epoch [382/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=0.891]
Epoch [383/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.69it/s, loss=0.69]
Epoch [384/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.65it/s, loss=1.11]
Epoch [385/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.93it/s, loss=1.14]
Epoch [386/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.39it/s, loss=1.03]
Epoch [387/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.41it/s, loss=1.04]
Epoch [388/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.49it/s, loss=0.927]
Epoch [389/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=0.779]
Epoch [390/4000]: 100%|██████████| 10/10 [00:00<00:00, 153.82it/s, loss=0.937]
Epoch [391/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=0.996]
Epoch [392/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.61it/s, loss=0.76]
Epoch [393/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.65it/s, loss=0.634]
Epoch [394/4000]: 100%|██████████| 10/10 [00:00<00:00, 155

Epoch [401/4000]: Train loss: 0.8974, Valid loss: 0.8531


Epoch [402/4000]: 100%|██████████| 10/10 [00:00<00:00, 109.59it/s, loss=0.888]
Epoch [403/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.36it/s, loss=1.08]
Epoch [404/4000]: 100%|██████████| 10/10 [00:00<00:00, 106.62it/s, loss=1]
Epoch [405/4000]: 100%|██████████| 10/10 [00:00<00:00, 101.27it/s, loss=0.982]
Epoch [406/4000]: 100%|██████████| 10/10 [00:00<00:00, 97.77it/s, loss=1.02]
Epoch [407/4000]: 100%|██████████| 10/10 [00:00<00:00, 98.30it/s, loss=0.885]
Epoch [408/4000]: 100%|██████████| 10/10 [00:00<00:00, 119.37it/s, loss=1]
Epoch [409/4000]: 100%|██████████| 10/10 [00:00<00:00, 84.97it/s, loss=0.812]
Epoch [410/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.26it/s, loss=0.844]
Epoch [411/4000]: 100%|██████████| 10/10 [00:00<00:00, 160.79it/s, loss=0.868]
Epoch [412/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.31it/s, loss=0.767]
Epoch [413/4000]: 100%|██████████| 10/10 [00:00<00:00, 151.92it/s, loss=0.777]
Epoch [414/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.38it/s

Saving model with loss 0.834...


Epoch [444/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.49it/s, loss=0.792]
Epoch [445/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.93it/s, loss=1.05]


Saving model with loss 0.832...


Epoch [446/4000]: 100%|██████████| 10/10 [00:00<00:00, 103.37it/s, loss=1.05]
Epoch [447/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.94it/s, loss=1.05]
Epoch [448/4000]: 100%|██████████| 10/10 [00:00<00:00, 130.22it/s, loss=1.04]
Epoch [449/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.50it/s, loss=0.937]
Epoch [450/4000]: 100%|██████████| 10/10 [00:00<00:00, 139.26it/s, loss=0.792]
Epoch [451/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.26it/s, loss=1.2]
Epoch [452/4000]: 100%|██████████| 10/10 [00:00<00:00, 137.35it/s, loss=0.778]
Epoch [453/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.50it/s, loss=0.858]
Epoch [454/4000]: 100%|██████████| 10/10 [00:00<00:00, 140.14it/s, loss=0.99]
Epoch [455/4000]: 100%|██████████| 10/10 [00:00<00:00, 142.77it/s, loss=0.75]
Epoch [456/4000]: 100%|██████████| 10/10 [00:00<00:00, 149.84it/s, loss=1.13]
Epoch [457/4000]: 100%|██████████| 10/10 [00:00<00:00, 143.25it/s, loss=1.05]
Epoch [458/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.74

Epoch [501/4000]: Train loss: 0.8933, Valid loss: 0.9743


Epoch [502/4000]: 100%|██████████| 10/10 [00:00<00:00, 160.23it/s, loss=0.888]
Epoch [503/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.67it/s, loss=0.955]
Epoch [504/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.16it/s, loss=0.937]
Epoch [505/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.29it/s, loss=0.973]
Epoch [506/4000]: 100%|██████████| 10/10 [00:00<00:00, 148.33it/s, loss=1.19]
Epoch [507/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.61it/s, loss=0.932]
Epoch [508/4000]: 100%|██████████| 10/10 [00:00<00:00, 143.35it/s, loss=1.17]
Epoch [509/4000]: 100%|██████████| 10/10 [00:00<00:00, 153.85it/s, loss=0.647]
Epoch [510/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.17it/s, loss=0.788]
Epoch [511/4000]: 100%|██████████| 10/10 [00:00<00:00, 161.72it/s, loss=1.41]
Epoch [512/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.83it/s, loss=1.35]
Epoch [513/4000]: 100%|██████████| 10/10 [00:00<00:00, 157.81it/s, loss=1]
Epoch [514/4000]: 100%|██████████| 10/10 [00:00<00:00, 158.1

Epoch [601/4000]: Train loss: 0.8994, Valid loss: 0.8793


Epoch [602/4000]: 100%|██████████| 10/10 [00:00<00:00, 158.48it/s, loss=1.15]
Epoch [603/4000]: 100%|██████████| 10/10 [00:00<00:00, 159.25it/s, loss=0.553]
Epoch [604/4000]: 100%|██████████| 10/10 [00:00<00:00, 156.66it/s, loss=1.4]
Epoch [605/4000]: 100%|██████████| 10/10 [00:00<00:00, 155.23it/s, loss=0.915]
Epoch [606/4000]: 100%|██████████| 10/10 [00:00<00:00, 154.17it/s, loss=0.888]


Saving model with loss 0.821...


Epoch [607/4000]: 100%|██████████| 10/10 [00:00<00:00, 112.20it/s, loss=1.06]
Epoch [608/4000]: 100%|██████████| 10/10 [00:00<00:00, 117.96it/s, loss=1.05]
Epoch [609/4000]: 100%|██████████| 10/10 [00:00<00:00, 100.27it/s, loss=0.753]
Epoch [610/4000]: 100%|██████████| 10/10 [00:00<00:00, 109.94it/s, loss=0.87]
Epoch [611/4000]: 100%|██████████| 10/10 [00:00<00:00, 102.32it/s, loss=1.02]
Epoch [612/4000]: 100%|██████████| 10/10 [00:00<00:00, 116.59it/s, loss=0.811]
Epoch [613/4000]: 100%|██████████| 10/10 [00:00<00:00, 111.46it/s, loss=0.834]
Epoch [614/4000]: 100%|██████████| 10/10 [00:00<00:00, 90.33it/s, loss=1.2] 
Epoch [615/4000]: 100%|██████████| 10/10 [00:00<00:00, 108.99it/s, loss=1.08]
Epoch [616/4000]: 100%|██████████| 10/10 [00:00<00:00, 93.71it/s, loss=0.833]
Epoch [617/4000]: 100%|██████████| 10/10 [00:00<00:00, 89.52it/s, loss=1.32]
Epoch [618/4000]: 100%|██████████| 10/10 [00:00<00:00, 88.73it/s, loss=0.995]
Epoch [619/4000]: 100%|██████████| 10/10 [00:00<00:00, 93.71it/

Saving model with loss 0.810...


Epoch [684/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.28it/s, loss=1.21]
Epoch [685/4000]: 100%|██████████| 10/10 [00:00<00:00, 123.44it/s, loss=0.992]
Epoch [686/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.28it/s, loss=0.831]
Epoch [687/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.28it/s, loss=1.03]
Epoch [688/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.28it/s, loss=0.797]
Epoch [689/4000]: 100%|██████████| 10/10 [00:00<00:00, 119.42it/s, loss=0.824]
Epoch [690/4000]: 100%|██████████| 10/10 [00:00<00:00, 115.94it/s, loss=0.827]
Epoch [691/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.23it/s, loss=0.877]
Epoch [692/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.55it/s, loss=0.816]
Epoch [693/4000]: 100%|██████████| 10/10 [00:00<00:00, 120.47it/s, loss=1.27]
Epoch [694/4000]: 100%|██████████| 10/10 [00:00<00:00, 110.05it/s, loss=1.08]
Epoch [695/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.77it/s, loss=0.771]
Epoch [696/4000]: 100%|██████████| 10/10 [00:00<00:00, 1

Epoch [701/4000]: Train loss: 0.9742, Valid loss: 1.0372


Epoch [702/4000]: 100%|██████████| 10/10 [00:00<00:00, 123.79it/s, loss=1.1]
Epoch [703/4000]: 100%|██████████| 10/10 [00:00<00:00, 123.87it/s, loss=0.683]
Epoch [704/4000]: 100%|██████████| 10/10 [00:00<00:00, 125.34it/s, loss=0.479]
Epoch [705/4000]: 100%|██████████| 10/10 [00:00<00:00, 126.88it/s, loss=1.13]
Epoch [706/4000]: 100%|██████████| 10/10 [00:00<00:00, 123.69it/s, loss=1.98]
Epoch [707/4000]: 100%|██████████| 10/10 [00:00<00:00, 126.85it/s, loss=1.19]
Epoch [708/4000]: 100%|██████████| 10/10 [00:00<00:00, 127.90it/s, loss=0.789]
Epoch [709/4000]: 100%|██████████| 10/10 [00:00<00:00, 125.33it/s, loss=0.973]
Epoch [710/4000]: 100%|██████████| 10/10 [00:00<00:00, 130.22it/s, loss=0.954]
Epoch [711/4000]: 100%|██████████| 10/10 [00:00<00:00, 127.28it/s, loss=0.737]
Epoch [712/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.49it/s, loss=0.799]
Epoch [713/4000]: 100%|██████████| 10/10 [00:00<00:00, 125.33it/s, loss=0.694]
Epoch [714/4000]: 100%|██████████| 10/10 [00:00<00:00, 13

Saving model with loss 0.781...


Epoch [726/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.19it/s, loss=0.919]
Epoch [727/4000]: 100%|██████████| 10/10 [00:00<00:00, 125.33it/s, loss=1.02]
Epoch [728/4000]: 100%|██████████| 10/10 [00:00<00:00, 124.31it/s, loss=1.2]
Epoch [729/4000]: 100%|██████████| 10/10 [00:00<00:00, 119.37it/s, loss=1.16]
Epoch [730/4000]: 100%|██████████| 10/10 [00:00<00:00, 113.94it/s, loss=0.979]
Epoch [731/4000]: 100%|██████████| 10/10 [00:00<00:00, 126.92it/s, loss=0.681]
Epoch [732/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.32it/s, loss=0.899]
Epoch [733/4000]: 100%|██████████| 10/10 [00:00<00:00, 119.07it/s, loss=0.738]
Epoch [734/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.28it/s, loss=0.896]
Epoch [735/4000]: 100%|██████████| 10/10 [00:00<00:00, 123.41it/s, loss=1.35]
Epoch [736/4000]: 100%|██████████| 10/10 [00:00<00:00, 121.66it/s, loss=1.28]
Epoch [737/4000]: 100%|██████████| 10/10 [00:00<00:00, 127.06it/s, loss=1.07]
Epoch [738/4000]: 100%|██████████| 10/10 [00:00<00:00, 120.

Epoch [801/4000]: Train loss: 1.0466, Valid loss: 1.0344


Epoch [802/4000]: 100%|██████████| 10/10 [00:00<00:00, 119.86it/s, loss=1.33]
Epoch [803/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.55it/s, loss=1.03]
Epoch [804/4000]: 100%|██████████| 10/10 [00:00<00:00, 123.79it/s, loss=1.01]
Epoch [805/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.55it/s, loss=0.947]
Epoch [806/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.93it/s, loss=0.908]
Epoch [807/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.77it/s, loss=0.667]
Epoch [808/4000]: 100%|██████████| 10/10 [00:00<00:00, 136.70it/s, loss=0.962]
Epoch [809/4000]: 100%|██████████| 10/10 [00:00<00:00, 133.69it/s, loss=1.09]
Epoch [810/4000]: 100%|██████████| 10/10 [00:00<00:00, 123.79it/s, loss=0.98]
Epoch [811/4000]: 100%|██████████| 10/10 [00:00<00:00, 130.22it/s, loss=0.986]
Epoch [812/4000]: 100%|██████████| 10/10 [00:00<00:00, 125.04it/s, loss=1.17]
Epoch [813/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.55it/s, loss=0.78]
Epoch [814/4000]: 100%|██████████| 10/10 [00:00<00:00, 127.

Epoch [901/4000]: Train loss: 0.9649, Valid loss: 0.8900


Epoch [902/4000]: 100%|██████████| 10/10 [00:00<00:00, 126.92it/s, loss=0.909]
Epoch [903/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.00it/s, loss=0.669]
Epoch [904/4000]: 100%|██████████| 10/10 [00:00<00:00, 123.79it/s, loss=0.757]
Epoch [905/4000]: 100%|██████████| 10/10 [00:00<00:00, 133.62it/s, loss=0.711]
Epoch [906/4000]: 100%|██████████| 10/10 [00:00<00:00, 120.24it/s, loss=1.03]
Epoch [907/4000]: 100%|██████████| 10/10 [00:00<00:00, 134.42it/s, loss=0.645]
Epoch [908/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.54it/s, loss=0.788]
Epoch [909/4000]: 100%|██████████| 10/10 [00:00<00:00, 134.31it/s, loss=1.05]
Epoch [910/4000]: 100%|██████████| 10/10 [00:00<00:00, 130.49it/s, loss=0.997]
Epoch [911/4000]: 100%|██████████| 10/10 [00:00<00:00, 130.95it/s, loss=0.816]
Epoch [912/4000]: 100%|██████████| 10/10 [00:00<00:00, 126.91it/s, loss=0.788]
Epoch [913/4000]: 100%|██████████| 10/10 [00:00<00:00, 106.90it/s, loss=1.1]
Epoch [914/4000]: 100%|██████████| 10/10 [00:00<00:00, 1

Epoch [1001/4000]: Train loss: 0.9420, Valid loss: 1.0049


Epoch [1002/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.28it/s, loss=1.22]
Epoch [1003/4000]: 100%|██████████| 10/10 [00:00<00:00, 81.87it/s, loss=0.836]
Epoch [1004/4000]: 100%|██████████| 10/10 [00:00<00:00, 120.80it/s, loss=1.14]
Epoch [1005/4000]: 100%|██████████| 10/10 [00:00<00:00, 126.92it/s, loss=0.89]
Epoch [1006/4000]: 100%|██████████| 10/10 [00:00<00:00, 117.68it/s, loss=0.969]
Epoch [1007/4000]: 100%|██████████| 10/10 [00:00<00:00, 120.50it/s, loss=0.61]
Epoch [1008/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.28it/s, loss=0.727]
Epoch [1009/4000]: 100%|██████████| 10/10 [00:00<00:00, 105.34it/s, loss=1.23]
Epoch [1010/4000]: 100%|██████████| 10/10 [00:00<00:00, 84.32it/s, loss=0.824]
Epoch [1011/4000]: 100%|██████████| 10/10 [00:00<00:00, 94.64it/s, loss=0.983]
Epoch [1012/4000]: 100%|██████████| 10/10 [00:00<00:00, 89.55it/s, loss=0.746]
Epoch [1013/4000]: 100%|██████████| 10/10 [00:00<00:00, 100.31it/s, loss=1.29]
Epoch [1014/4000]: 100%|██████████| 10/10 [00:00<0

Epoch [1101/4000]: Train loss: 0.9362, Valid loss: 0.9298


Epoch [1102/4000]: 100%|██████████| 10/10 [00:00<00:00, 130.58it/s, loss=0.999]
Epoch [1103/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.21it/s, loss=0.709]
Epoch [1104/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.93it/s, loss=0.917]
Epoch [1105/4000]: 100%|██████████| 10/10 [00:00<00:00, 126.25it/s, loss=1.72]
Epoch [1106/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.50it/s, loss=0.766]
Epoch [1107/4000]: 100%|██████████| 10/10 [00:00<00:00, 136.51it/s, loss=0.883]
Epoch [1108/4000]: 100%|██████████| 10/10 [00:00<00:00, 135.49it/s, loss=0.696]
Epoch [1109/4000]: 100%|██████████| 10/10 [00:00<00:00, 117.61it/s, loss=0.89]
Epoch [1110/4000]: 100%|██████████| 10/10 [00:00<00:00, 123.78it/s, loss=1.1]
Epoch [1111/4000]: 100%|██████████| 10/10 [00:00<00:00, 109.17it/s, loss=0.922]
Epoch [1112/4000]: 100%|██████████| 10/10 [00:00<00:00, 118.14it/s, loss=1.26]
Epoch [1113/4000]: 100%|██████████| 10/10 [00:00<00:00, 137.23it/s, loss=0.856]
Epoch [1114/4000]: 100%|██████████| 10/10 [00

Epoch [1201/4000]: Train loss: 0.8991, Valid loss: 0.8926


Epoch [1202/4000]: 100%|██████████| 10/10 [00:00<00:00, 119.42it/s, loss=0.794]
Epoch [1203/4000]: 100%|██████████| 10/10 [00:00<00:00, 107.82it/s, loss=0.816]
Epoch [1204/4000]: 100%|██████████| 10/10 [00:00<00:00, 53.62it/s, loss=1]   
Epoch [1205/4000]: 100%|██████████| 10/10 [00:00<00:00, 70.61it/s, loss=0.828]
Epoch [1206/4000]: 100%|██████████| 10/10 [00:00<00:00, 100.84it/s, loss=1.02]
Epoch [1207/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.61it/s, loss=0.93]
Epoch [1208/4000]: 100%|██████████| 10/10 [00:00<00:00, 133.69it/s, loss=0.935]
Epoch [1209/4000]: 100%|██████████| 10/10 [00:00<00:00, 130.56it/s, loss=0.899]
Epoch [1210/4000]: 100%|██████████| 10/10 [00:00<00:00, 130.22it/s, loss=0.985]
Epoch [1211/4000]: 100%|██████████| 10/10 [00:00<00:00, 108.98it/s, loss=0.962]
Epoch [1212/4000]: 100%|██████████| 10/10 [00:00<00:00, 117.96it/s, loss=0.865]
Epoch [1213/4000]: 100%|██████████| 10/10 [00:00<00:00, 122.28it/s, loss=1]
Epoch [1214/4000]: 100%|██████████| 10/10 [00:00<

Epoch [1301/4000]: Train loss: 0.9362, Valid loss: 0.9191


Epoch [1302/4000]: 100%|██████████| 10/10 [00:00<00:00, 129.04it/s, loss=1.07]
Epoch [1303/4000]: 100%|██████████| 10/10 [00:00<00:00, 129.65it/s, loss=0.975]
Epoch [1304/4000]: 100%|██████████| 10/10 [00:00<00:00, 131.93it/s, loss=0.918]
Epoch [1305/4000]: 100%|██████████| 10/10 [00:00<00:00, 119.36it/s, loss=1.17]
Epoch [1306/4000]: 100%|██████████| 10/10 [00:00<00:00, 93.68it/s, loss=0.771]
Epoch [1307/4000]: 100%|██████████| 10/10 [00:00<00:00, 88.47it/s, loss=0.852]
Epoch [1308/4000]: 100%|██████████| 10/10 [00:00<00:00, 130.49it/s, loss=1.03]
Epoch [1309/4000]: 100%|██████████| 10/10 [00:00<00:00, 128.55it/s, loss=1.22]
Epoch [1310/4000]: 100%|██████████| 10/10 [00:00<00:00, 132.85it/s, loss=1]
Epoch [1311/4000]: 100%|██████████| 10/10 [00:00<00:00, 127.37it/s, loss=0.9]
Epoch [1312/4000]: 100%|██████████| 10/10 [00:00<00:00, 119.19it/s, loss=1.46]
Epoch [1313/4000]: 100%|██████████| 10/10 [00:00<00:00, 121.26it/s, loss=1.06]
Epoch [1314/4000]: 100%|██████████| 10/10 [00:00<00:00


Model is not improving, so we halt the training session.


(0.8498845100402832, 0.893162339925766, 0.780752052863439)

In [26]:
%reload_ext tensorboard
%tensorboard --logdir=./runs/

Reusing TensorBoard on port 6006 (pid 11888), started 2 days, 22:03:54 ago. (Use '!kill 11888' to kill it.)

In [17]:
def save_pred(preds, file):
    ''' Save predictions to specified file '''
    with open(file, 'w') as fp:
        writer = csv.writer(fp)
        writer.writerow(['id', 'tested_positive'])
        for i, p in enumerate(preds):
            writer.writerow([i, p])

model = My_Model(input_dim=x_train.shape[1]).to(device)
model.load_state_dict(torch.load(config['save_path']))
preds = predict(test_loader, model, device) 
save_pred(preds, 'pred.csv')

100%|██████████| 4/4 [00:00<00:00, 294.14it/s]
