In [1]:
#%%
from array import array
from cmath import nan
from pyexpat import model
import statistics
from tkinter.ttk import Separator
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchviz import make_dot
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms
from torch.autograd import variable
from itertools import chain
from sklearn import metrics as met
import pickle
from icecream import ic

import matplotlib.pyplot as plt
import pathlib
from sklearn.model_selection import train_test_split

from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from importlib import reload
# import util
# import model_torch_simple
# from torchmetrics import Accuracy
from tqdm import tqdm
import argparse
from icecream import ic
import numpy as np
from PIL import Image
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(42)

<torch._C.Generator at 0x7f819b1bf030>

In [2]:
seed = 42
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
def value_counts_list(lst):
    """
    Computes the frequency count of unique elements in a list and returns a dictionary, sorted by frequency count in
    descending order.

    Args:
    - lst (list): List of elements

    Returns:
    - dict: Dictionary with unique elements as keys and their frequency count as values, sorted by frequency count
    in descending order
    """
    value_counts = {}
    for item in lst:
        if item in value_counts:
            value_counts[item] += 1
        else:
            value_counts[item] = 1
    sorted_value_counts = dict(sorted(value_counts.items(), key=lambda x: x[1], reverse=True))
    return sorted_value_counts

def print_full(x):
    pd.set_option('display.max_rows', None)
    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', 2000)
    pd.set_option('display.float_format', '{:20,.2f}'.format)
    pd.set_option('display.max_colwidth', None)
    print(x)
    pd.reset_option('display.max_rows')
    pd.reset_option('display.max_columns')
    pd.reset_option('display.width')
    pd.reset_option('display.float_format')
    pd.reset_option('display.max_colwidth')

In [4]:
original_data = pd.read_csv('data_aa/aa_rpoB.csv', header=None)
original_featrues = pd.read_csv('data_aa/RIF_MIC.csv', header=None)
data = original_data

target = original_featrues

train_data_index = np.random.choice(data.shape[0], size=int(data.shape[0]*0.8), replace=False)
all_indices = np.arange(data.shape[0])
test_data_index = np.setdiff1d(all_indices, train_data_index)

train_data = data.iloc[train_data_index,:]
train_target = target.iloc[train_data_index,:]
train_data = train_data.reset_index(drop=True)
train_target = train_target.reset_index(drop=True)
#don't touch test data, split out validation data from training data during training
test_data = data.iloc[test_data_index,:]
test_target = target.iloc[test_data_index,:]
test_data = test_data.reset_index(drop=True)
test_target = test_target.reset_index(drop=True)

In [5]:
class Dataset(torch.utils.data.Dataset): #? what's the difference between using inheritance and not?
    def __init__(
        self,
        train_df,
        mic_df,
        transform = None,
    ):
        self.transform = transform
        self.train_df = train_df
        self.mic_df = mic_df
        if not self.train_df.index.equals(self.mic_df.index):
            raise ValueError(
                "Indices of training data and resistance data don't match up"
            )

    def __getitem__(self, index):
        """
        numerical index --> get `index`-th sample
        string index --> get sample with name `index`
        """
        if isinstance(index, int):
            train = self.train_df.iloc[index]
            mic = self.mic_df.loc[index]
            
        elif isinstance(index, str):
            trains = self.train_df.loc[index]
            mic = self.mic_df.loc[index]
        else:
            raise ValueError(
                "Index needs to be an integer or a sample name present in the dataset"
            )
        
        if self.transform:
            self.mic_mean = self.mic_df.mean()
            self.mic_std = self.mic_df.std()
            mic = (mic - self.mic_mean) / self.mic_std
        
        return  torch.tensor(train),  torch.tensor(mic)
    def __len__(self):
        return self.mic_df.shape[0]
    
training_dataset = Dataset(train_data, train_target, transform=False)
train_dataset, val_dataset = random_split(training_dataset, [int(len(training_dataset)*0.8), len(training_dataset)-int(len(training_dataset)*0.8)])

In [6]:
torch.tensor(train_target.loc[1])

tensor([0.1200], dtype=torch.float64)

In [7]:
def get_masked_loss(loss_fn):
    """
    Returns a loss function that ignores NaN values
    """

    def masked_loss(y_true, y_pred):
        y_pred = y_pred.view(-1, 1)  # Ensure y_pred has the same shape as y_true and non_nan_mask
        # ic(y_true)
        non_nan_mask = ~y_true.isnan()
        # ic(non_nan_mask)
        y_true_non_nan = y_true[non_nan_mask]
        y_pred_non_nan = y_pred[non_nan_mask]

        return loss_fn(y_pred_non_nan, y_true_non_nan)

    return masked_loss

masked_MSE = get_masked_loss(torch.nn.MSELoss())

# Model

In [8]:
class Model(nn.Module):
    def __init__(self, in_channel = 869, first_h_layer = 469, out_channel=1, batch_size=1, dropout_rate=0.0, num_dense_layers=3, filter_scaling_factor=1.5):
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.in_channel = in_channel
        self.first_h_layer = first_h_layer
        self.out_channel = out_channel
        self.dense_dropout_rate = dropout_rate
        self.num_dense_layers = num_dense_layers
        self.filter_scaling_factor=filter_scaling_factor
        
        self.dense_layers = nn.ModuleList()
        for i in range(self.num_dense_layers):
            layer = self._dense_layer(100,100)
            self.dense_layers.append(layer)
        
        # current_num_filters = self.cnn_output_size(self.in_channel, 12)
        current_num_filters = self.cnn_output_size_multilayer(self.in_channel, 12, num_layers=2)
        num_dense_neurons = 100
        num_dense_layers = 2
        self.dense_layers = nn.ModuleList(
            self._dense_layer(input_dim, num_dense_neurons)
            for input_dim in [current_num_filters]
            + [num_dense_neurons] * (num_dense_layers - 1) #how does this work?
        )
            # current_num_filters = int(current_num_filters * filter_scaling_factor)

        # self.feature_extraction = nn.Conv1d(in_channels, hidden, kernel_size=kernel_size),]
        # self.starting_layers = nn.Sequential(
        #     nn.Linear(self.in_channel, self.first_h_layer),
        #     nn.BatchNorm1d(self.first_h_layer),
        #     nn.ReLU(),
        #     nn.Dropout(self.dense_dropout_rate),  # Dropout layer after the first ReLU
        #     nn.Linear(self.first_h_layer, 100),
        #     nn.BatchNorm1d(100),
        #     nn.ReLU(),
        #     nn.Dropout(self.dense_dropout_rate))  # Dropout layer after the first ReLU
        
        self.starting_layers = nn.Sequential(
            nn.Conv1d(1, 1, 12),
            nn.BatchNorm1d(1),
            nn.ReLU(),
            nn.Dropout(self.dense_dropout_rate),  # Dropout layer after the first ReLU
            nn.Conv1d(1, 1, 12),
            nn.BatchNorm1d(1),
            nn.ReLU(),
            nn.Dropout(self.dense_dropout_rate))  # Dropout layer after the first ReLU
            
        self.out_layer = nn.Linear(100, self.out_channel)
        
        self.apply(self.init_weights)
        
    # def cnn_output_size(self, n, f, p = 0, s = 1):
    #     '''
    #     n: input size
    #     f: kernel size
    #     p: padding
    #     s: stride
    #     '''
    #     return int((n - f + 2 * p) / s) + 1
    
    def cnn_output_size_multilayer(self, n, f, p=0, s=1, num_layers=1):
        '''
        n: input size
        f: kernel size (assuming the same for all layers for simplicity)
        p: padding (assuming the same for all layers for simplicity)
        s: stride (assuming the same for all layers for simplicity)
        num_layers: number of convolutional layers
        '''
        output_size = n
        for _ in range(num_layers):
            output_size = int((output_size - f + 2 * p) / s) + 1
        return output_size
    
    def _dense_layer(self, n_in, n_out):
        return nn.Sequential(
            nn.Linear(n_in, n_out),
            nn.BatchNorm1d(n_out),
            nn.ReLU(),
            nn.Dropout(p=self.dense_dropout_rate)
        )

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
                
    def forward(self, x):
        # print('=== before starting layer', x.size())
        x = self.starting_layers(x)
        # print('=== end of starting layer', x.size())
        x = x.view(x.size(0), -1)
        # print('=== Flattened', x.size())
        
        current_num_filters1 = self.cnn_output_size_multilayer(self.in_channel, 12, num_layers=2)
        # print(current_num_filters1, current_num_filters2)

        for layer in self.dense_layers:
            x = layer(x)
        # print('=== After Dense layer', x.size())

        out = self.out_layer(x)
        # print('After out layer', out.size())
        return out

torch.cuda.empty_cache()

epoch = 300
batch_size = 32
lr = 0.0001

model = Model(in_channel = 869, 
              first_h_layer = 469, 
              out_channel=1, 
              num_dense_layers=4,
              batch_size=batch_size)

model = model.float()
model = model.to(device)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, drop_last=True)
# criterion = nn.MSELoss()
criterion = masked_MSE
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2, verbose=True)

# Training

In [9]:
torch.cuda.empty_cache()
import gc; gc.collect()
# ic.enable()
ic.disable()

train_epoch_loss = []
test_epoch_loss = []

for e in tqdm(range(1, epoch+1)):
    model.train()
    train_batch_loss = []
    test_batch_loss = []
    # print('epoch:', e)
    for x, y in train_loader:
        # print('batch x size:',x.size() )
        x_batch = torch.unsqueeze(x, 1).to(device)
        y_batch = y.to(device)
        x_batch = x_batch.float()
        y_batch = y_batch.float()
        # y_batch = y_batch.view(-1)

        # y_batch = one_hot_torch(y).to(device)
        # print('batch y size before flatten:',y_batch.size())
        # y_batch = y_batch.flatten()
        # print('batch y size after flatten:',y_batch.size())
        # print(x_batch.size())
        # print(x_batch.size())
# For example, if you have a convolutional layer with 64 output channels, 3 input channels, and a kernel size of 3x3, the weight parameters would have a dimension of (64, 3, 3, 3)
        # print(x_batch.size())
        pred = model(x_batch.float())
        # print(x_batch)
        # print(pred)
        # pred = pred.unsqueeze(0)
        # ic(pred)
        # ic(y_batch)
        ic(pred.size())
        loss_train = criterion(y_batch, pred)
        ic(loss_train)
        train_batch_loss.append(loss_train.detach())
        
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()

        # print(f'Batch - GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB')

    train_epoch_loss.append(torch.mean(torch.stack(train_batch_loss)).detach().cpu().numpy())
    with torch.no_grad():
        # print('>> test')
        # for x, y in test_loader:
            # print('x size:',x.size() )
        x_batch = torch.unsqueeze(x, 1).to(device)
        # print('batch x size:',x_batch.size() )
        y_batch = y.to(device)
        # print(x_batch.size())
        # y_batch = torch.Tensor.float(y).to(device)
        # x_batch = x_batch.permute(0, 3, 1, 2).to(device)
        pred = model(x_batch.float())
        # pred = pred.unsqueeze(0)

        loss_test = criterion(y_batch, pred)
        test_batch_loss.append(loss_test )
    test_epoch_loss.append(torch.mean(torch.stack(test_batch_loss)).detach().cpu().numpy())
        
    print(f'Epoch {e}')
    print(f"Training loss: {torch.mean(torch.stack(train_batch_loss)).detach().cpu().numpy()}")
    print(f"Validation loss: {torch.mean(torch.stack(test_batch_loss)).detach().cpu().numpy()}") 
    # scheduler.step(torch.mean(torch.stack(test_batch_loss)))
    # print(train_batch_loss)
    # print(test_batch_loss)
    # print(f"Training loss: {np.mean(train_batch_loss)}")
    # print(f"Validation loss: {np.mean(test_batch_loss)}")
print('==='*10)
# torch.save(model.state_dict(), '/mnt/storageG1/lwang/Projects/tb_dr_MIC/saved_models/rif_model.pt')
# fig, ax = plt.subplots()
# x = np.arange(1, epoch+1, 1)
# ax.plot(x, train_epoch_loss,label='Training')
# ax.plot(x, test_epoch_loss,label='Validation')
# ax.legend()
# ax.set_xlabel("Number of Epoch")
# ax.set_ylabel("Loss")
# ax.set_xticks(np.arange(0, epoch+1, 10))
# ax.set_title(f'Loss: Learning_rate:{lr}, cnn_dr:{cnn_dr}, cnn_dr:{fc_dr}')
# # ax_2 = ax.twinx()
# # ax_2.plot(history["lr"], "k--", lw=1)
# # ax_2.set_yscale("log")
# # ax.set_ylim(ax.get_ylim()[0], history["training_losses"][0])
# ax.grid(axis="x")
# fig.tight_layout()
# fig.show()
# fig.savefig(f'./graphs1/loss_lr_{lr}_cnn_dr_{cnn_dr}_fc_dr_{fc_dr}.png')
# print(f'./graphs1/loss_lr_{lr}_cnn_dr_{cnn_dr}_fc_dr_{fc_dr}.png')

  0%|          | 1/300 [00:30<2:33:49, 30.87s/it]

Epoch 1
Training loss: 49.66083908081055
Validation loss: 55.50987452682339


  1%|          | 2/300 [01:00<2:28:38, 29.93s/it]

Epoch 2
Training loss: 33.56352996826172
Validation loss: 29.15038745023667


  1%|          | 3/300 [01:29<2:26:46, 29.65s/it]

Epoch 3
Training loss: 22.357574462890625
Validation loss: 18.985679772099424


  1%|▏         | 4/300 [01:58<2:25:38, 29.52s/it]

Epoch 4
Training loss: 15.404451370239258
Validation loss: 21.320027736264628


  2%|▏         | 5/300 [02:28<2:24:37, 29.42s/it]

Epoch 5
Training loss: 12.803400039672852
Validation loss: 13.339097853612266


  2%|▏         | 6/300 [02:57<2:23:49, 29.35s/it]

Epoch 6
Training loss: 11.291816711425781
Validation loss: 7.309680347510268


  2%|▏         | 7/300 [03:26<2:23:16, 29.34s/it]

Epoch 7
Training loss: 10.274328231811523
Validation loss: 18.336202262076604


  3%|▎         | 8/300 [03:55<2:22:39, 29.31s/it]

Epoch 8
Training loss: 9.993467330932617
Validation loss: 8.28920683379579


  3%|▎         | 9/300 [04:25<2:22:03, 29.29s/it]

Epoch 9
Training loss: 9.58013916015625
Validation loss: 10.820426511078638


  3%|▎         | 10/300 [04:54<2:21:28, 29.27s/it]

Epoch 10
Training loss: 9.368375778198242
Validation loss: 11.738287405553844


  4%|▎         | 11/300 [05:23<2:20:55, 29.26s/it]

Epoch 11
Training loss: 9.14955997467041
Validation loss: 5.2669991669221075


  4%|▍         | 12/300 [05:52<2:20:21, 29.24s/it]

Epoch 12
Training loss: 9.153104782104492
Validation loss: 11.940626812902755


  4%|▍         | 13/300 [06:21<2:19:45, 29.22s/it]

Epoch 13
Training loss: 9.178110122680664
Validation loss: 7.956749939080966


  5%|▍         | 14/300 [06:51<2:19:13, 29.21s/it]

Epoch 14
Training loss: 9.191662788391113
Validation loss: 11.669393446114052


  5%|▌         | 15/300 [07:20<2:18:51, 29.23s/it]

Epoch 15
Training loss: 8.7367582321167
Validation loss: 5.472449863816139


  5%|▌         | 16/300 [07:49<2:18:24, 29.24s/it]

Epoch 16
Training loss: 8.745396614074707
Validation loss: 7.081565644136534


  6%|▌         | 17/300 [08:18<2:17:53, 29.23s/it]

Epoch 17
Training loss: 8.731491088867188
Validation loss: 8.078559354176347


  6%|▌         | 18/300 [08:48<2:17:22, 29.23s/it]

Epoch 18
Training loss: 8.642000198364258
Validation loss: 9.034059423733876


  6%|▋         | 19/300 [09:17<2:16:51, 29.22s/it]

Epoch 19
Training loss: 8.697237014770508
Validation loss: 4.880261218293275


  7%|▋         | 20/300 [09:46<2:16:24, 29.23s/it]

Epoch 20
Training loss: 8.613138198852539
Validation loss: 8.628604181636291


  7%|▋         | 21/300 [10:15<2:15:53, 29.22s/it]

Epoch 21
Training loss: 8.79513168334961
Validation loss: 10.272270939893213


  7%|▋         | 22/300 [10:44<2:15:26, 29.23s/it]

Epoch 22
Training loss: 8.541104316711426
Validation loss: 4.00701676354429


  8%|▊         | 23/300 [11:14<2:15:00, 29.24s/it]

Epoch 23
Training loss: 8.430657386779785
Validation loss: 8.238205522252532


  8%|▊         | 24/300 [11:43<2:14:22, 29.21s/it]

Epoch 24
Training loss: 8.445185661315918
Validation loss: 5.390398495040644


  8%|▊         | 25/300 [12:12<2:13:59, 29.23s/it]

Epoch 25
Training loss: 8.443514823913574
Validation loss: 16.245963981979138


  9%|▊         | 26/300 [12:41<2:13:37, 29.26s/it]

Epoch 26
Training loss: 8.439979553222656
Validation loss: 4.848776109281932


  9%|▉         | 27/300 [13:11<2:13:05, 29.25s/it]

Epoch 27
Training loss: 8.356911659240723
Validation loss: 8.988897571562884


  9%|▉         | 28/300 [13:40<2:12:39, 29.26s/it]

Epoch 28
Training loss: 8.380887031555176
Validation loss: 2.008327764157399


 10%|▉         | 29/300 [14:09<2:12:12, 29.27s/it]

Epoch 29
Training loss: 8.26388168334961
Validation loss: 5.166547120916279


 10%|█         | 30/300 [14:39<2:11:46, 29.28s/it]

Epoch 30
Training loss: 8.24036693572998
Validation loss: 8.201739760181866


 10%|█         | 31/300 [15:08<2:11:13, 29.27s/it]

Epoch 31
Training loss: 8.331669807434082
Validation loss: 3.788938098990146


 11%|█         | 32/300 [15:37<2:10:40, 29.26s/it]

Epoch 32
Training loss: 8.1993989944458
Validation loss: 3.6489230907426657


 11%|█         | 33/300 [16:06<2:10:12, 29.26s/it]

Epoch 33
Training loss: 8.210387229919434
Validation loss: 13.42974463797494


 11%|█▏        | 34/300 [16:36<2:09:45, 29.27s/it]

Epoch 34
Training loss: 8.243133544921875
Validation loss: 10.724645717448325


 12%|█▏        | 35/300 [17:05<2:09:16, 29.27s/it]

Epoch 35
Training loss: 8.192461967468262
Validation loss: 11.263926122770568


 12%|█▏        | 36/300 [17:34<2:08:57, 29.31s/it]

Epoch 36
Training loss: 8.152098655700684
Validation loss: 10.930930420457464


 12%|█▏        | 37/300 [18:04<2:08:26, 29.30s/it]

Epoch 37
Training loss: 8.237125396728516
Validation loss: 6.688894201741747


 13%|█▎        | 38/300 [18:33<2:07:51, 29.28s/it]

Epoch 38
Training loss: 8.128585815429688
Validation loss: 9.960670643748404


 13%|█▎        | 39/300 [19:02<2:07:22, 29.28s/it]

Epoch 39
Training loss: 8.069990158081055
Validation loss: 7.849955029630461


 13%|█▎        | 40/300 [19:31<2:06:55, 29.29s/it]

Epoch 40
Training loss: 8.055827140808105
Validation loss: 12.078702098093073


 14%|█▎        | 41/300 [20:01<2:06:31, 29.31s/it]

Epoch 41
Training loss: 8.165495872497559
Validation loss: 11.627730140750518


 14%|█▍        | 42/300 [20:30<2:06:06, 29.33s/it]

Epoch 42
Training loss: 8.128790855407715
Validation loss: 11.070181064813445


 14%|█▍        | 43/300 [21:01<2:07:11, 29.69s/it]

Epoch 43
Training loss: 8.176855087280273
Validation loss: 6.407235709423962


 15%|█▍        | 44/300 [21:31<2:08:02, 30.01s/it]

Epoch 44
Training loss: 8.099689483642578
Validation loss: 10.81352495924059


 15%|█▌        | 45/300 [22:02<2:08:18, 30.19s/it]

Epoch 45
Training loss: 8.088249206542969
Validation loss: 8.400149528376332


 15%|█▌        | 46/300 [22:33<2:08:25, 30.34s/it]

Epoch 46
Training loss: 8.074565887451172
Validation loss: 15.9977563353915


 16%|█▌        | 47/300 [23:02<2:06:35, 30.02s/it]

Epoch 47
Training loss: 8.06512451171875
Validation loss: 5.403563276978686


 16%|█▌        | 48/300 [23:31<2:04:57, 29.75s/it]

Epoch 48
Training loss: 7.961007118225098
Validation loss: 7.273782211629223


 16%|█▋        | 49/300 [24:00<2:03:46, 29.59s/it]

Epoch 49
Training loss: 8.030779838562012
Validation loss: 6.721518968654087


 17%|█▋        | 50/300 [24:31<2:04:39, 29.92s/it]

Epoch 50
Training loss: 7.981684684753418
Validation loss: 7.92495172827471


 17%|█▋        | 51/300 [25:02<2:05:15, 30.18s/it]

Epoch 51
Training loss: 7.961514472961426
Validation loss: 5.400659878188867


 17%|█▋        | 52/300 [25:33<2:05:31, 30.37s/it]

Epoch 52
Training loss: 8.08914566040039
Validation loss: 8.252520560014503


 18%|█▊        | 53/300 [26:03<2:05:13, 30.42s/it]

Epoch 53
Training loss: 7.928755283355713
Validation loss: 3.2397162586175168


 18%|█▊        | 54/300 [26:32<2:03:23, 30.09s/it]

Epoch 54
Training loss: 7.992559432983398
Validation loss: 7.906432301008593


 18%|█▊        | 55/300 [27:02<2:01:48, 29.83s/it]

Epoch 55
Training loss: 8.000497817993164
Validation loss: 5.733089159510429


 19%|█▊        | 56/300 [27:31<2:00:34, 29.65s/it]

Epoch 56
Training loss: 7.95219612121582
Validation loss: 3.5827295629619886


 19%|█▉        | 57/300 [28:00<1:59:31, 29.51s/it]

Epoch 57
Training loss: 7.915036201477051
Validation loss: 5.347194834925707


 19%|█▉        | 58/300 [28:29<1:58:39, 29.42s/it]

Epoch 58
Training loss: 7.997901439666748
Validation loss: 4.14484269438614


 20%|█▉        | 59/300 [28:58<1:57:50, 29.34s/it]

Epoch 59
Training loss: 7.889901638031006
Validation loss: 6.60486792341656


 20%|██        | 60/300 [29:28<1:57:12, 29.30s/it]

Epoch 60
Training loss: 7.9992876052856445
Validation loss: 13.974092855400313


 20%|██        | 61/300 [29:57<1:56:34, 29.26s/it]

Epoch 61
Training loss: 7.8683624267578125
Validation loss: 7.6672242358401155


 21%|██        | 62/300 [30:26<1:56:01, 29.25s/it]

Epoch 62
Training loss: 7.881592750549316
Validation loss: 7.156959930214505


 21%|██        | 63/300 [30:55<1:55:29, 29.24s/it]

Epoch 63
Training loss: 7.894551753997803
Validation loss: 6.25990277477222


 21%|██▏       | 64/300 [31:24<1:54:56, 29.22s/it]

Epoch 64
Training loss: 7.842770099639893
Validation loss: 9.876112834153023


 22%|██▏       | 65/300 [31:54<1:54:25, 29.22s/it]

Epoch 65
Training loss: 7.9019246101379395
Validation loss: 12.420518083477656


 22%|██▏       | 66/300 [32:23<1:53:57, 29.22s/it]

Epoch 66
Training loss: 7.886077880859375
Validation loss: 4.725978964992568


 22%|██▏       | 67/300 [32:52<1:53:28, 29.22s/it]

Epoch 67
Training loss: 7.808902740478516
Validation loss: 4.595295395261657


 23%|██▎       | 68/300 [33:21<1:53:01, 29.23s/it]

Epoch 68
Training loss: 7.823269844055176
Validation loss: 6.608535504780011


 23%|██▎       | 69/300 [33:50<1:52:22, 29.19s/it]

Epoch 69
Training loss: 7.802811145782471
Validation loss: 4.622199409142007


 23%|██▎       | 70/300 [34:20<1:51:59, 29.21s/it]

Epoch 70
Training loss: 7.845193862915039
Validation loss: 8.428539787061528


 24%|██▎       | 71/300 [34:49<1:51:27, 29.20s/it]

Epoch 71
Training loss: 7.829935073852539
Validation loss: 3.6351057056370726


 24%|██▍       | 72/300 [35:18<1:50:59, 29.21s/it]

Epoch 72
Training loss: 7.837051868438721
Validation loss: 15.001408735185915


 24%|██▍       | 73/300 [35:47<1:50:26, 29.19s/it]

Epoch 73
Training loss: 7.877902984619141
Validation loss: 4.023919236027955


 25%|██▍       | 74/300 [36:17<1:50:03, 29.22s/it]

Epoch 74
Training loss: 7.849288463592529
Validation loss: 9.821581419671547


 25%|██▌       | 75/300 [36:46<1:49:38, 29.24s/it]

Epoch 75
Training loss: 7.953643321990967
Validation loss: 11.945088108882356


 25%|██▌       | 76/300 [37:15<1:49:11, 29.25s/it]

Epoch 76
Training loss: 7.828729629516602
Validation loss: 7.293837091563297


 26%|██▌       | 77/300 [37:44<1:48:44, 29.26s/it]

Epoch 77
Training loss: 7.744184970855713
Validation loss: 13.50576986648915


 26%|██▌       | 78/300 [38:14<1:48:12, 29.25s/it]

Epoch 78
Training loss: 7.857835292816162
Validation loss: 7.258489700833014


 26%|██▋       | 79/300 [38:43<1:47:44, 29.25s/it]

Epoch 79
Training loss: 7.832492828369141
Validation loss: 2.768994236772821


 27%|██▋       | 80/300 [39:12<1:47:15, 29.25s/it]

Epoch 80
Training loss: 7.7589826583862305
Validation loss: 3.747514580489585


 27%|██▋       | 81/300 [39:41<1:46:45, 29.25s/it]

Epoch 81
Training loss: 7.8219099044799805
Validation loss: 3.9043627019833824


 27%|██▋       | 82/300 [40:11<1:46:13, 29.24s/it]

Epoch 82
Training loss: 7.789309501647949
Validation loss: 6.219312469206535


 28%|██▊       | 83/300 [40:40<1:45:38, 29.21s/it]

Epoch 83
Training loss: 7.847796440124512
Validation loss: 7.797228475795444


 28%|██▊       | 84/300 [41:09<1:45:04, 29.19s/it]

Epoch 84
Training loss: 7.770894527435303
Validation loss: 10.681081725985651


 28%|██▊       | 85/300 [41:38<1:44:38, 29.20s/it]

Epoch 85
Training loss: 7.744667053222656
Validation loss: 7.475604361935453


 29%|██▊       | 86/300 [42:07<1:44:03, 29.18s/it]

Epoch 86
Training loss: 7.816633224487305
Validation loss: 10.3080302320554


 29%|██▉       | 87/300 [42:36<1:43:35, 29.18s/it]

Epoch 87
Training loss: 7.805539131164551
Validation loss: 7.9275323062311465


 29%|██▉       | 88/300 [43:06<1:43:04, 29.17s/it]

Epoch 88
Training loss: 7.7813029289245605
Validation loss: 3.6342073016767507


 30%|██▉       | 89/300 [43:35<1:42:35, 29.17s/it]

Epoch 89
Training loss: 7.8177618980407715
Validation loss: 7.2625789143837


 30%|███       | 90/300 [44:04<1:42:06, 29.18s/it]

Epoch 90
Training loss: 7.7509636878967285
Validation loss: 3.8896002858209826


 30%|███       | 91/300 [44:33<1:41:34, 29.16s/it]

Epoch 91
Training loss: 7.728334903717041
Validation loss: 9.008366521161642


 31%|███       | 92/300 [45:02<1:41:02, 29.15s/it]

Epoch 92
Training loss: 7.740721702575684
Validation loss: 6.862793862815344


 31%|███       | 93/300 [45:31<1:40:36, 29.16s/it]

Epoch 93
Training loss: 7.772917747497559
Validation loss: 14.809911497040773


 31%|███▏      | 94/300 [46:00<1:40:01, 29.14s/it]

Epoch 94
Training loss: 7.659872531890869
Validation loss: 8.80884855445936


 32%|███▏      | 95/300 [46:30<1:39:36, 29.15s/it]

Epoch 95
Training loss: 7.790480136871338
Validation loss: 3.2628377242769804


 32%|███▏      | 96/300 [46:59<1:39:08, 29.16s/it]

Epoch 96
Training loss: 7.7538676261901855
Validation loss: 6.341197272935222


 32%|███▏      | 97/300 [47:28<1:38:36, 29.14s/it]

Epoch 97
Training loss: 7.842715263366699
Validation loss: 3.080102342450854


 33%|███▎      | 98/300 [47:57<1:38:12, 29.17s/it]

Epoch 98
Training loss: 7.722519874572754
Validation loss: 5.248787109416444


 33%|███▎      | 99/300 [48:26<1:37:45, 29.18s/it]

Epoch 99
Training loss: 7.745506286621094
Validation loss: 5.618805893526567


 33%|███▎      | 100/300 [48:56<1:37:17, 29.19s/it]

Epoch 100
Training loss: 7.804781913757324
Validation loss: 13.465292784402434


 34%|███▎      | 101/300 [49:25<1:36:48, 29.19s/it]

Epoch 101
Training loss: 7.725213050842285
Validation loss: 12.414558151735148


 34%|███▍      | 102/300 [49:54<1:36:21, 29.20s/it]

Epoch 102
Training loss: 7.677433013916016
Validation loss: 8.230833841842166


 34%|███▍      | 103/300 [50:23<1:35:50, 29.19s/it]

Epoch 103
Training loss: 7.690713882446289
Validation loss: 8.112575225382304


 35%|███▍      | 104/300 [50:52<1:35:22, 29.20s/it]

Epoch 104
Training loss: 7.6954522132873535
Validation loss: 0.933753215264959


 35%|███▌      | 105/300 [51:22<1:34:54, 29.20s/it]

Epoch 105
Training loss: 7.749810695648193
Validation loss: 9.438146151153166


 35%|███▌      | 106/300 [51:51<1:34:23, 29.19s/it]

Epoch 106
Training loss: 7.647247791290283
Validation loss: 4.064257371345464


 36%|███▌      | 107/300 [52:20<1:34:01, 29.23s/it]

Epoch 107
Training loss: 7.746039867401123
Validation loss: 6.471990298576484


 36%|███▌      | 108/300 [52:49<1:33:33, 29.24s/it]

Epoch 108
Training loss: 7.744247913360596
Validation loss: 18.96993294951073


 36%|███▋      | 109/300 [53:19<1:33:04, 29.24s/it]

Epoch 109
Training loss: 7.759212017059326
Validation loss: 7.2879098478612


 37%|███▋      | 110/300 [53:48<1:32:32, 29.22s/it]

Epoch 110
Training loss: 7.617599010467529
Validation loss: 9.012102769920702


 37%|███▋      | 111/300 [54:17<1:32:01, 29.21s/it]

Epoch 111
Training loss: 7.6746649742126465
Validation loss: 10.441264233069287


 37%|███▋      | 112/300 [54:46<1:31:27, 29.19s/it]

Epoch 112
Training loss: 7.666688919067383
Validation loss: 2.4463145725820596


 38%|███▊      | 113/300 [55:15<1:30:52, 29.16s/it]

Epoch 113
Training loss: 7.705674171447754
Validation loss: 6.281906585267666


 38%|███▊      | 114/300 [55:44<1:30:26, 29.18s/it]

Epoch 114
Training loss: 7.646799564361572
Validation loss: 10.131962576418964


 38%|███▊      | 115/300 [56:14<1:29:58, 29.18s/it]

Epoch 115
Training loss: 7.753618240356445
Validation loss: 17.166197908063808


 39%|███▊      | 116/300 [56:43<1:29:30, 29.19s/it]

Epoch 116
Training loss: 7.704839706420898
Validation loss: 10.107736159076357


 39%|███▉      | 117/300 [57:12<1:28:59, 29.18s/it]

Epoch 117
Training loss: 7.707926273345947
Validation loss: 5.212863428189654


 39%|███▉      | 118/300 [57:41<1:28:29, 29.17s/it]

Epoch 118
Training loss: 7.666252136230469
Validation loss: 9.435593234936533


 40%|███▉      | 119/300 [58:10<1:27:59, 29.17s/it]

Epoch 119
Training loss: 7.717185974121094
Validation loss: 17.88000259139506


 40%|████      | 120/300 [58:39<1:27:31, 29.17s/it]

Epoch 120
Training loss: 7.7202043533325195
Validation loss: 7.4745816575162936


 40%|████      | 121/300 [59:09<1:27:04, 29.19s/it]

Epoch 121
Training loss: 7.693953990936279
Validation loss: 5.610586470362032


 41%|████      | 122/300 [59:38<1:26:34, 29.18s/it]

Epoch 122
Training loss: 7.655743598937988
Validation loss: 4.215632436028416


 41%|████      | 123/300 [1:00:07<1:26:06, 29.19s/it]

Epoch 123
Training loss: 7.666297912597656
Validation loss: 9.399258103857463


 41%|████▏     | 124/300 [1:00:36<1:25:36, 29.18s/it]

Epoch 124
Training loss: 7.674428462982178
Validation loss: 8.947182407588961


 42%|████▏     | 125/300 [1:01:05<1:25:05, 29.17s/it]

Epoch 125
Training loss: 7.651838302612305
Validation loss: 2.918806453744322


 42%|████▏     | 126/300 [1:01:34<1:24:32, 29.15s/it]

Epoch 126
Training loss: 7.682387351989746
Validation loss: 5.643839320446524


 42%|████▏     | 127/300 [1:02:04<1:24:03, 29.15s/it]

Epoch 127
Training loss: 7.611039638519287
Validation loss: 10.92624131998066


 43%|████▎     | 128/300 [1:02:33<1:23:35, 29.16s/it]

Epoch 128
Training loss: 7.611446857452393
Validation loss: 6.902343306888346


 43%|████▎     | 129/300 [1:03:02<1:23:05, 29.16s/it]

Epoch 129
Training loss: 7.62588357925415
Validation loss: 5.083057382282798


 43%|████▎     | 130/300 [1:03:31<1:22:37, 29.16s/it]

Epoch 130
Training loss: 7.621755599975586
Validation loss: 6.880023898106527


 44%|████▎     | 131/300 [1:04:00<1:22:09, 29.17s/it]

Epoch 131
Training loss: 7.583136558532715
Validation loss: 7.857247449205209


 44%|████▍     | 132/300 [1:04:29<1:21:38, 29.16s/it]

Epoch 132
Training loss: 7.6586384773254395
Validation loss: 4.266144941351767


 44%|████▍     | 133/300 [1:04:59<1:21:08, 29.15s/it]

Epoch 133
Training loss: 7.611662864685059
Validation loss: 6.2140387979156415


 45%|████▍     | 134/300 [1:05:28<1:20:39, 29.15s/it]

Epoch 134
Training loss: 7.63029670715332
Validation loss: 2.3835385674324208


 45%|████▌     | 135/300 [1:05:57<1:20:11, 29.16s/it]

Epoch 135
Training loss: 7.645773887634277
Validation loss: 5.376766637690165


 45%|████▌     | 136/300 [1:06:26<1:19:44, 29.18s/it]

Epoch 136
Training loss: 7.694850444793701
Validation loss: 10.867344066306892


 46%|████▌     | 137/300 [1:06:55<1:19:13, 29.17s/it]

Epoch 137
Training loss: 7.687257766723633
Validation loss: 7.1381199653440115


 46%|████▌     | 138/300 [1:07:24<1:18:44, 29.16s/it]

Epoch 138
Training loss: 7.643862247467041
Validation loss: 6.986284953304931


 46%|████▋     | 139/300 [1:07:54<1:18:14, 29.16s/it]

Epoch 139
Training loss: 7.587247371673584
Validation loss: 7.013330539061923


 47%|████▋     | 140/300 [1:08:23<1:17:45, 29.16s/it]

Epoch 140
Training loss: 7.585099220275879
Validation loss: 3.729819722147866


 47%|████▋     | 141/300 [1:08:52<1:17:16, 29.16s/it]

Epoch 141
Training loss: 7.61516809463501
Validation loss: 7.980809091063156


 47%|████▋     | 142/300 [1:09:21<1:16:46, 29.15s/it]

Epoch 142
Training loss: 7.637420654296875
Validation loss: 4.7397942971839555


 48%|████▊     | 143/300 [1:09:50<1:16:16, 29.15s/it]

Epoch 143
Training loss: 7.582626819610596
Validation loss: 7.530870868185179


 48%|████▊     | 144/300 [1:10:19<1:15:52, 29.18s/it]

Epoch 144
Training loss: 7.57761812210083
Validation loss: 4.763000259056961


 48%|████▊     | 145/300 [1:10:49<1:15:22, 29.18s/it]

Epoch 145
Training loss: 7.598067760467529
Validation loss: 3.971963154655546


 49%|████▊     | 146/300 [1:11:18<1:14:55, 29.19s/it]

Epoch 146
Training loss: 7.620502471923828
Validation loss: 8.85059941876192


 49%|████▉     | 147/300 [1:11:47<1:14:22, 29.17s/it]

Epoch 147
Training loss: 7.618775844573975
Validation loss: 10.406233725410813


 49%|████▉     | 148/300 [1:12:16<1:13:55, 29.18s/it]

Epoch 148
Training loss: 7.570343971252441
Validation loss: 12.62831047255428


 50%|████▉     | 149/300 [1:12:45<1:13:27, 29.19s/it]

Epoch 149
Training loss: 7.589951992034912
Validation loss: 8.3241402672348


 50%|█████     | 150/300 [1:13:15<1:12:59, 29.19s/it]

Epoch 150
Training loss: 7.597195148468018
Validation loss: 5.490060640487943


 50%|█████     | 151/300 [1:13:44<1:12:31, 29.20s/it]

Epoch 151
Training loss: 7.6125359535217285
Validation loss: 6.672438319535319


 51%|█████     | 152/300 [1:14:13<1:12:01, 29.20s/it]

Epoch 152
Training loss: 7.64846658706665
Validation loss: 6.55164004890071


 51%|█████     | 153/300 [1:14:42<1:11:30, 29.19s/it]

Epoch 153
Training loss: 7.573565483093262
Validation loss: 5.123043301483299


 51%|█████▏    | 154/300 [1:15:11<1:10:58, 29.17s/it]

Epoch 154
Training loss: 7.543949604034424
Validation loss: 4.803538764877781


 52%|█████▏    | 155/300 [1:15:40<1:10:27, 29.16s/it]

Epoch 155
Training loss: 7.574248313903809
Validation loss: 12.235497446932625


 52%|█████▏    | 156/300 [1:16:09<1:09:55, 29.14s/it]

Epoch 156
Training loss: 7.580355167388916
Validation loss: 10.774117649490911


 52%|█████▏    | 157/300 [1:16:39<1:09:29, 29.16s/it]

Epoch 157
Training loss: 7.548698425292969
Validation loss: 5.554015470692425


 53%|█████▎    | 158/300 [1:17:08<1:09:00, 29.16s/it]

Epoch 158
Training loss: 7.564929008483887
Validation loss: 11.770978131521488


 53%|█████▎    | 159/300 [1:17:37<1:08:33, 29.17s/it]

Epoch 159
Training loss: 7.571467876434326
Validation loss: 3.7008461774802486


 53%|█████▎    | 160/300 [1:18:06<1:08:03, 29.17s/it]

Epoch 160
Training loss: 7.574387550354004
Validation loss: 2.4656590857365406


 54%|█████▎    | 161/300 [1:18:35<1:07:33, 29.16s/it]

Epoch 161
Training loss: 7.5866851806640625
Validation loss: 4.425951930617332


 54%|█████▍    | 162/300 [1:19:05<1:07:05, 29.17s/it]

Epoch 162
Training loss: 7.548247814178467
Validation loss: 3.488578563474487


 54%|█████▍    | 163/300 [1:19:34<1:06:34, 29.16s/it]

Epoch 163
Training loss: 7.547278881072998
Validation loss: 7.710997826943419


 55%|█████▍    | 164/300 [1:20:03<1:06:06, 29.16s/it]

Epoch 164
Training loss: 7.576850414276123
Validation loss: 9.814760993032511


 55%|█████▌    | 165/300 [1:20:32<1:05:36, 29.16s/it]

Epoch 165
Training loss: 7.576894283294678
Validation loss: 4.268253374370164


 55%|█████▌    | 166/300 [1:21:01<1:05:09, 29.17s/it]

Epoch 166
Training loss: 7.559891223907471
Validation loss: 6.8761398808141845


 56%|█████▌    | 167/300 [1:21:30<1:04:40, 29.18s/it]

Epoch 167
Training loss: 7.620884895324707
Validation loss: 13.191939783978585


 56%|█████▌    | 168/300 [1:22:00<1:04:09, 29.17s/it]

Epoch 168
Training loss: 7.532492637634277
Validation loss: 6.406117989731257


 56%|█████▋    | 169/300 [1:22:29<1:03:42, 29.18s/it]

Epoch 169
Training loss: 7.666147232055664
Validation loss: 7.844349335365931


 57%|█████▋    | 170/300 [1:22:58<1:03:13, 29.18s/it]

Epoch 170
Training loss: 7.525668621063232
Validation loss: 7.0497299240359


 57%|█████▋    | 171/300 [1:23:27<1:02:43, 29.17s/it]

Epoch 171
Training loss: 7.5616679191589355
Validation loss: 10.445336686939529


 57%|█████▋    | 172/300 [1:23:56<1:02:16, 29.19s/it]

Epoch 172
Training loss: 7.508922576904297
Validation loss: 6.020205092775967


 58%|█████▊    | 173/300 [1:24:26<1:01:49, 29.21s/it]

Epoch 173
Training loss: 7.572762489318848
Validation loss: 6.925372076240027


 58%|█████▊    | 174/300 [1:24:55<1:01:18, 29.19s/it]

Epoch 174
Training loss: 7.509922027587891
Validation loss: 7.3487796350710965


 58%|█████▊    | 175/300 [1:25:24<1:00:50, 29.20s/it]

Epoch 175
Training loss: 7.527528762817383
Validation loss: 12.694020353399138


 59%|█████▊    | 176/300 [1:25:53<1:00:22, 29.21s/it]

Epoch 176
Training loss: 7.542186260223389
Validation loss: 14.637236068016573


 59%|█████▉    | 177/300 [1:26:22<59:53, 29.21s/it]  

Epoch 177
Training loss: 7.582786560058594
Validation loss: 5.351180386050834


 59%|█████▉    | 178/300 [1:26:52<59:22, 29.20s/it]

Epoch 178
Training loss: 7.5277581214904785
Validation loss: 9.804983778268648


 60%|█████▉    | 179/300 [1:27:21<58:50, 29.18s/it]

Epoch 179
Training loss: 7.603463172912598
Validation loss: 12.63237577098815


 60%|██████    | 180/300 [1:27:50<58:19, 29.16s/it]

Epoch 180
Training loss: 7.534034252166748
Validation loss: 5.626439478615296


 60%|██████    | 181/300 [1:28:19<57:49, 29.16s/it]

Epoch 181
Training loss: 7.587716579437256
Validation loss: 11.353417071281463


 61%|██████    | 182/300 [1:28:48<57:18, 29.14s/it]

Epoch 182
Training loss: 7.5763421058654785
Validation loss: 5.117779558926193


 61%|██████    | 183/300 [1:29:17<56:47, 29.13s/it]

Epoch 183
Training loss: 7.5713887214660645
Validation loss: 8.485479831197736


 61%|██████▏   | 184/300 [1:29:46<56:19, 29.13s/it]

Epoch 184
Training loss: 7.498690605163574
Validation loss: 4.015127309632519


 62%|██████▏   | 185/300 [1:30:15<55:50, 29.14s/it]

Epoch 185
Training loss: 7.5907063484191895
Validation loss: 10.456355691795165


 62%|██████▏   | 186/300 [1:30:45<55:24, 29.16s/it]

Epoch 186
Training loss: 7.486595153808594
Validation loss: 4.856349202944926


 62%|██████▏   | 187/300 [1:31:14<54:53, 29.14s/it]

Epoch 187
Training loss: 7.532492637634277
Validation loss: 13.05623519190461


 63%|██████▎   | 188/300 [1:31:43<54:24, 29.15s/it]

Epoch 188
Training loss: 7.57375431060791
Validation loss: 11.163012318720309


 63%|██████▎   | 189/300 [1:32:12<53:54, 29.14s/it]

Epoch 189
Training loss: 7.507139682769775
Validation loss: 8.882408341860145


 63%|██████▎   | 190/300 [1:32:41<53:25, 29.15s/it]

Epoch 190
Training loss: 7.506623268127441
Validation loss: 11.153859114377529


 64%|██████▎   | 191/300 [1:33:10<52:58, 29.16s/it]

Epoch 191
Training loss: 7.501060485839844
Validation loss: 7.179756603113264


 64%|██████▍   | 192/300 [1:33:40<52:31, 29.18s/it]

Epoch 192
Training loss: 7.6212687492370605
Validation loss: 8.29037021068407


 64%|██████▍   | 193/300 [1:34:09<52:00, 29.16s/it]

Epoch 193
Training loss: 7.506795883178711
Validation loss: 4.449526064938034


 65%|██████▍   | 194/300 [1:34:38<51:30, 29.16s/it]

Epoch 194
Training loss: 7.552479267120361
Validation loss: 7.389966089461446


 65%|██████▌   | 195/300 [1:35:07<51:02, 29.16s/it]

Epoch 195
Training loss: 7.568328857421875
Validation loss: 12.874584050794711


 65%|██████▌   | 196/300 [1:35:36<50:35, 29.19s/it]

Epoch 196
Training loss: 7.548863410949707
Validation loss: 5.945005354032691


 66%|██████▌   | 197/300 [1:36:05<50:06, 29.19s/it]

Epoch 197
Training loss: 7.549054145812988
Validation loss: 12.847956437599603


 66%|██████▌   | 198/300 [1:36:35<49:37, 29.19s/it]

Epoch 198
Training loss: 7.517399311065674
Validation loss: 5.827170677507313


 66%|██████▋   | 199/300 [1:37:04<49:07, 29.18s/it]

Epoch 199
Training loss: 7.525303363800049
Validation loss: 4.487982452209673


 67%|██████▋   | 200/300 [1:37:33<48:38, 29.19s/it]

Epoch 200
Training loss: 7.538659572601318
Validation loss: 8.826344437599914


 67%|██████▋   | 201/300 [1:38:02<48:08, 29.18s/it]

Epoch 201
Training loss: 7.565200328826904
Validation loss: 4.393674959734381


 67%|██████▋   | 202/300 [1:38:31<47:38, 29.17s/it]

Epoch 202
Training loss: 7.574506759643555
Validation loss: 6.905855331349123


 68%|██████▊   | 203/300 [1:39:01<47:10, 29.18s/it]

Epoch 203
Training loss: 7.519017696380615
Validation loss: 7.150977686952955


 68%|██████▊   | 204/300 [1:39:30<46:40, 29.18s/it]

Epoch 204
Training loss: 7.5667595863342285
Validation loss: 8.644158329075477


 68%|██████▊   | 205/300 [1:39:59<46:14, 29.20s/it]

Epoch 205
Training loss: 7.518638610839844
Validation loss: 6.977228777588833


 69%|██████▊   | 206/300 [1:40:28<45:42, 29.17s/it]

Epoch 206
Training loss: 7.503490447998047
Validation loss: 10.985379076136539


 69%|██████▉   | 207/300 [1:40:57<45:09, 29.14s/it]

Epoch 207
Training loss: 7.55114221572876
Validation loss: 7.61807347893283


 69%|██████▉   | 208/300 [1:41:26<44:39, 29.13s/it]

Epoch 208
Training loss: 7.513164520263672
Validation loss: 5.490614260775825


 70%|██████▉   | 209/300 [1:41:55<44:10, 29.13s/it]

Epoch 209
Training loss: 7.527181148529053
Validation loss: 4.030696280740624


 70%|███████   | 210/300 [1:42:25<43:42, 29.14s/it]

Epoch 210
Training loss: 7.524775505065918
Validation loss: 3.762412661049225


 70%|███████   | 211/300 [1:42:54<43:14, 29.15s/it]

Epoch 211
Training loss: 7.578694820404053
Validation loss: 4.194192775798861


 71%|███████   | 212/300 [1:43:23<42:43, 29.13s/it]

Epoch 212
Training loss: 7.565788269042969
Validation loss: 4.612949158171277


 71%|███████   | 213/300 [1:43:52<42:15, 29.14s/it]

Epoch 213
Training loss: 7.473491668701172
Validation loss: 8.915223835201925


 71%|███████▏  | 214/300 [1:44:21<41:43, 29.11s/it]

Epoch 214
Training loss: 7.528615951538086
Validation loss: 9.105396162982617


 72%|███████▏  | 215/300 [1:44:50<41:13, 29.10s/it]

Epoch 215
Training loss: 7.509617328643799
Validation loss: 10.724138445011388


 72%|███████▏  | 216/300 [1:45:19<40:45, 29.11s/it]

Epoch 216
Training loss: 7.527342319488525
Validation loss: 3.684471865175692


 72%|███████▏  | 217/300 [1:45:48<40:15, 29.10s/it]

Epoch 217
Training loss: 7.4399189949035645
Validation loss: 5.542046570616526


 73%|███████▎  | 218/300 [1:46:17<39:48, 29.13s/it]

Epoch 218
Training loss: 7.513025283813477
Validation loss: 6.881766537921571


 73%|███████▎  | 219/300 [1:46:47<39:18, 29.12s/it]

Epoch 219
Training loss: 7.52094841003418
Validation loss: 3.990941370729598


 73%|███████▎  | 220/300 [1:47:16<38:51, 29.14s/it]

Epoch 220
Training loss: 7.521109104156494
Validation loss: 5.882322753127791


 74%|███████▎  | 221/300 [1:47:45<38:20, 29.12s/it]

Epoch 221
Training loss: 7.454725742340088
Validation loss: 16.134046116491554


 74%|███████▍  | 222/300 [1:48:14<37:52, 29.14s/it]

Epoch 222
Training loss: 7.505536079406738
Validation loss: 3.1781993038379035


 74%|███████▍  | 223/300 [1:48:43<37:26, 29.17s/it]

Epoch 223
Training loss: 7.48244571685791
Validation loss: 9.75522909522234


 75%|███████▍  | 224/300 [1:49:12<36:56, 29.17s/it]

Epoch 224
Training loss: 7.537018299102783
Validation loss: 5.512400027668152


 75%|███████▌  | 225/300 [1:49:42<36:28, 29.18s/it]

Epoch 225
Training loss: 7.504829406738281
Validation loss: 10.143493215382192


 75%|███████▌  | 226/300 [1:50:11<35:58, 29.17s/it]

Epoch 226
Training loss: 7.533702850341797
Validation loss: 7.997235348524152


 76%|███████▌  | 227/300 [1:50:40<35:29, 29.17s/it]

Epoch 227
Training loss: 7.505074977874756
Validation loss: 9.970469239871193


 76%|███████▌  | 228/300 [1:51:09<35:00, 29.17s/it]

Epoch 228
Training loss: 7.510496616363525
Validation loss: 3.6112734004302327


 76%|███████▋  | 229/300 [1:51:38<34:29, 29.15s/it]

Epoch 229
Training loss: 7.546969413757324
Validation loss: 7.829356157685977


 77%|███████▋  | 230/300 [1:52:07<33:57, 29.11s/it]

Epoch 230
Training loss: 7.480377674102783
Validation loss: 12.280908956955658


 77%|███████▋  | 231/300 [1:52:36<33:29, 29.13s/it]

Epoch 231
Training loss: 7.506648063659668
Validation loss: 6.252298837565001


 77%|███████▋  | 232/300 [1:53:06<33:01, 29.13s/it]

Epoch 232
Training loss: 7.570308685302734
Validation loss: 9.409262607466381


 78%|███████▊  | 233/300 [1:53:35<32:32, 29.14s/it]

Epoch 233
Training loss: 7.445641994476318
Validation loss: 6.553747303373843


 78%|███████▊  | 234/300 [1:54:04<32:01, 29.11s/it]

Epoch 234
Training loss: 7.508319854736328
Validation loss: 2.4322988600069833


 78%|███████▊  | 235/300 [1:54:33<31:33, 29.13s/it]

Epoch 235
Training loss: 7.52794075012207
Validation loss: 4.689173520474178


 79%|███████▊  | 236/300 [1:55:02<31:03, 29.12s/it]

Epoch 236
Training loss: 7.513403415679932
Validation loss: 3.8184493573498


 79%|███████▉  | 237/300 [1:55:31<30:34, 29.13s/it]

Epoch 237
Training loss: 7.487625598907471
Validation loss: 7.862225023206776


 79%|███████▉  | 238/300 [1:56:00<30:05, 29.12s/it]

Epoch 238
Training loss: 7.46475076675415
Validation loss: 6.1542791952698765


 80%|███████▉  | 239/300 [1:56:29<29:36, 29.12s/it]

Epoch 239
Training loss: 7.5176801681518555
Validation loss: 17.035464724089938


 80%|████████  | 240/300 [1:56:59<29:08, 29.14s/it]

Epoch 240
Training loss: 7.452470779418945
Validation loss: 7.757182658474656


 80%|████████  | 241/300 [1:57:28<28:39, 29.14s/it]

Epoch 241
Training loss: 7.522359848022461
Validation loss: 7.617478519766965


 81%|████████  | 242/300 [1:57:57<28:11, 29.16s/it]

Epoch 242
Training loss: 7.498340606689453
Validation loss: 3.3017022388534265


 81%|████████  | 243/300 [1:58:26<27:41, 29.15s/it]

Epoch 243
Training loss: 7.536552906036377
Validation loss: 2.553026282162678


 81%|████████▏ | 244/300 [1:58:55<27:10, 29.12s/it]

Epoch 244
Training loss: 7.4688215255737305
Validation loss: 8.518943636341222


 82%|████████▏ | 245/300 [1:59:24<26:43, 29.15s/it]

Epoch 245
Training loss: 7.451802730560303
Validation loss: 9.275451964544594


 82%|████████▏ | 246/300 [1:59:54<26:15, 29.18s/it]

Epoch 246
Training loss: 7.48989725112915
Validation loss: 6.168320785837157


 82%|████████▏ | 247/300 [2:00:23<25:46, 29.18s/it]

Epoch 247
Training loss: 7.421236515045166
Validation loss: 2.64712277876293


 83%|████████▎ | 248/300 [2:00:52<25:16, 29.17s/it]

Epoch 248
Training loss: 7.445207118988037
Validation loss: 9.698662871015632


 83%|████████▎ | 249/300 [2:01:22<24:59, 29.39s/it]

Epoch 249
Training loss: 7.472497463226318
Validation loss: 3.460271924278179


 83%|████████▎ | 250/300 [2:01:52<24:49, 29.79s/it]

Epoch 250
Training loss: 7.501831531524658
Validation loss: 8.820810351062612


 84%|████████▎ | 251/300 [2:02:23<24:31, 30.03s/it]

Epoch 251
Training loss: 7.488330841064453
Validation loss: 7.19729303813569


 84%|████████▍ | 252/300 [2:02:54<24:11, 30.23s/it]

Epoch 252
Training loss: 7.470185279846191
Validation loss: 3.706715732469452


 84%|████████▍ | 253/300 [2:03:24<23:34, 30.11s/it]

Epoch 253
Training loss: 7.439122200012207
Validation loss: 2.8765897347503824


 85%|████████▍ | 254/300 [2:03:53<22:52, 29.85s/it]

Epoch 254
Training loss: 7.497878551483154
Validation loss: 4.665779221433514


 85%|████████▌ | 255/300 [2:04:22<22:15, 29.67s/it]

Epoch 255
Training loss: 7.423191547393799
Validation loss: 5.006347155990161


 85%|████████▌ | 256/300 [2:04:51<21:39, 29.54s/it]

Epoch 256
Training loss: 7.522618293762207
Validation loss: 6.74335510614792


 86%|████████▌ | 257/300 [2:05:21<21:05, 29.44s/it]

Epoch 257
Training loss: 7.4778242111206055
Validation loss: 4.8909123923844575


 86%|████████▌ | 258/300 [2:05:50<20:33, 29.36s/it]

Epoch 258
Training loss: 7.410022258758545
Validation loss: 6.829554110305537


 86%|████████▋ | 259/300 [2:06:19<20:00, 29.29s/it]

Epoch 259
Training loss: 7.453204154968262
Validation loss: 6.651974513586664


 87%|████████▋ | 260/300 [2:06:48<19:30, 29.26s/it]

Epoch 260
Training loss: 7.510406970977783
Validation loss: 8.692063280190691


 87%|████████▋ | 261/300 [2:07:17<19:01, 29.26s/it]

Epoch 261
Training loss: 7.46525239944458
Validation loss: 9.94239384276214


 87%|████████▋ | 262/300 [2:07:46<18:30, 29.21s/it]

Epoch 262
Training loss: 7.440524578094482
Validation loss: 2.560285203070139


 88%|████████▊ | 263/300 [2:08:16<18:00, 29.21s/it]

Epoch 263
Training loss: 7.507771015167236
Validation loss: 4.037597956837437


 88%|████████▊ | 264/300 [2:08:45<17:31, 29.22s/it]

Epoch 264
Training loss: 7.496967315673828
Validation loss: 3.821566898905026


 88%|████████▊ | 265/300 [2:09:14<17:02, 29.21s/it]

Epoch 265
Training loss: 7.480098724365234
Validation loss: 8.261417207141


 89%|████████▊ | 266/300 [2:09:43<16:33, 29.21s/it]

Epoch 266
Training loss: 7.463139057159424
Validation loss: 2.828212513466363


 89%|████████▉ | 267/300 [2:10:13<16:04, 29.23s/it]

Epoch 267
Training loss: 7.514607906341553
Validation loss: 13.082410254922266


 89%|████████▉ | 268/300 [2:10:42<15:34, 29.22s/it]

Epoch 268
Training loss: 7.51702880859375
Validation loss: 4.924328993422958


 90%|████████▉ | 269/300 [2:11:11<15:05, 29.22s/it]

Epoch 269
Training loss: 7.433343887329102
Validation loss: 3.331892995145436


 90%|█████████ | 270/300 [2:11:40<14:36, 29.22s/it]

Epoch 270
Training loss: 7.417924880981445
Validation loss: 9.62704534509603


 90%|█████████ | 271/300 [2:12:09<14:07, 29.21s/it]

Epoch 271
Training loss: 7.514335632324219
Validation loss: 7.977633213176439


 91%|█████████ | 272/300 [2:12:39<13:38, 29.22s/it]

Epoch 272
Training loss: 7.518860816955566
Validation loss: 5.040307167660089


 91%|█████████ | 273/300 [2:13:08<13:09, 29.22s/it]

Epoch 273
Training loss: 7.466367244720459
Validation loss: 9.036913929259837


 91%|█████████▏| 274/300 [2:13:37<12:39, 29.22s/it]

Epoch 274
Training loss: 7.454698085784912
Validation loss: 10.718332639744185


 92%|█████████▏| 275/300 [2:14:06<12:10, 29.21s/it]

Epoch 275
Training loss: 7.423634052276611
Validation loss: 4.60236483989994


 92%|█████████▏| 276/300 [2:14:35<11:40, 29.19s/it]

Epoch 276
Training loss: 7.467136859893799
Validation loss: 7.875640758413905


 92%|█████████▏| 277/300 [2:15:05<11:11, 29.18s/it]

Epoch 277
Training loss: 7.457156658172607
Validation loss: 11.258280293904981


 93%|█████████▎| 278/300 [2:15:34<10:41, 29.18s/it]

Epoch 278
Training loss: 7.45864200592041
Validation loss: 5.385488474452519


 93%|█████████▎| 279/300 [2:16:03<10:12, 29.18s/it]

Epoch 279
Training loss: 7.51370906829834
Validation loss: 3.759968009473715


 93%|█████████▎| 280/300 [2:16:32<09:43, 29.19s/it]

Epoch 280
Training loss: 7.469463348388672
Validation loss: 8.49011079425181


 94%|█████████▎| 281/300 [2:17:01<09:14, 29.19s/it]

Epoch 281
Training loss: 7.407552242279053
Validation loss: 3.3921693521022513


 94%|█████████▍| 282/300 [2:17:31<08:45, 29.20s/it]

Epoch 282
Training loss: 7.451141834259033
Validation loss: 4.625569417127903


 94%|█████████▍| 283/300 [2:18:00<08:16, 29.18s/it]

Epoch 283
Training loss: 7.497371196746826
Validation loss: 11.737399537005652


 95%|█████████▍| 284/300 [2:18:29<07:46, 29.16s/it]

Epoch 284
Training loss: 7.510715007781982
Validation loss: 11.515507814704822


 95%|█████████▌| 285/300 [2:18:58<07:17, 29.15s/it]

Epoch 285
Training loss: 7.464388847351074
Validation loss: 10.267963757015957


 95%|█████████▌| 286/300 [2:19:27<06:48, 29.16s/it]

Epoch 286
Training loss: 7.471865653991699
Validation loss: 7.730173958752985


 96%|█████████▌| 287/300 [2:19:56<06:19, 29.16s/it]

Epoch 287
Training loss: 7.4345173835754395
Validation loss: 11.039017759143245


 96%|█████████▌| 288/300 [2:20:25<05:49, 29.13s/it]

Epoch 288
Training loss: 7.498251438140869
Validation loss: 1.51957017115726


 96%|█████████▋| 289/300 [2:20:54<05:20, 29.13s/it]

Epoch 289
Training loss: 7.4676713943481445
Validation loss: 9.290035552361044


 97%|█████████▋| 290/300 [2:21:23<04:51, 29.10s/it]

Epoch 290
Training loss: 7.434639930725098
Validation loss: 4.993773429039162


 97%|█████████▋| 291/300 [2:21:53<04:22, 29.12s/it]

Epoch 291
Training loss: 7.479793071746826
Validation loss: 4.417086883299271


 97%|█████████▋| 292/300 [2:22:22<03:52, 29.12s/it]

Epoch 292
Training loss: 7.461124897003174
Validation loss: 2.741163218751471


 98%|█████████▊| 293/300 [2:22:51<03:24, 29.15s/it]

Epoch 293
Training loss: 7.440001487731934
Validation loss: 4.4910416658923085


 98%|█████████▊| 294/300 [2:23:20<02:55, 29.17s/it]

Epoch 294
Training loss: 7.524068832397461
Validation loss: 8.569367018714956


 98%|█████████▊| 295/300 [2:23:49<02:25, 29.17s/it]

Epoch 295
Training loss: 7.4653496742248535
Validation loss: 7.389540070907872


 99%|█████████▊| 296/300 [2:24:19<01:56, 29.20s/it]

Epoch 296
Training loss: 7.417683124542236
Validation loss: 6.985678190285274


 99%|█████████▉| 297/300 [2:24:48<01:27, 29.24s/it]

Epoch 297
Training loss: 7.418878078460693
Validation loss: 10.009322519908489


 99%|█████████▉| 298/300 [2:25:17<00:58, 29.24s/it]

Epoch 298
Training loss: 7.509354591369629
Validation loss: 8.130785662576493


100%|█████████▉| 299/300 [2:25:46<00:29, 29.22s/it]

Epoch 299
Training loss: 7.443146705627441
Validation loss: 8.340312315299208


100%|██████████| 300/300 [2:26:16<00:00, 29.25s/it]

Epoch 300
Training loss: 7.466678142547607
Validation loss: 7.859086402183354





In [72]:
def save_to_file(file_path, appendix, epoch, lr, dr, train_loss, test_loss):
    train_loss = [float(arr) for arr in train_loss]
    test_loss = [float(arr) for arr in test_loss]

    with open(file_path, "a") as f:
        f.write(f">> {appendix}, Epoch: {epoch}, LR: {lr}, DR: {dr}\n")
        f.write(f"--- Train Loss: {train_loss}\n")
        f.write(f"--- Test Loss: {test_loss}\n")
        
def training(appendix:str, epoch:int, dropout_rate:float, lr:float, batch_size:int=16, train_dataset=train_dataset, val_dataset=val_dataset, verbose:bool=False, graphics:bool=False):
    print(f"====lr: {lr}, dropout rate: {dropout_rate}")
    torch.cuda.empty_cache()
    import gc; gc.collect()
    # ic.enable()
    ic.disable()
    
    model = Model(in_channel = 869, 
              first_h_layer = 469, 
              out_channel=1, 
              num_dense_layers=4,
              batch_size=batch_size,
              dropout_rate=dropout_rate)
    
    model = model.float()
    model = model.to(device)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, drop_last=True)
    # criterion = nn.MSELoss()
    criterion = masked_MSE
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2, verbose=True)

    train_epoch_loss = []
    test_epoch_loss = []

    for e in tqdm(range(1, epoch+1)):
        model.train()
        train_batch_loss = []
        test_batch_loss = []
        
        for x, y in train_loader:
            x_batch = torch.unsqueeze(x, 1).to(device)
            print(x_batch.size())
            # x_batch = x.permute(1, 0, 2)
            # print(x_batch.size())
            y_batch = y.to(device)
            x_batch = x_batch.float()
            y_batch = y_batch.float()

            pred = model(x_batch.float())

            ic(pred.size())
            loss_train = criterion(y_batch, pred)
            ic(loss_train)
            train_batch_loss.append(loss_train.detach())
            
            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()

            # print(f'Batch - GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB')

        train_epoch_loss.append(torch.mean(torch.stack(train_batch_loss)).detach().cpu().numpy())
        with torch.no_grad():
            # print('test')
            for x, y in test_loader:
                x_batch = x.to(device)
                y_batch = y.to(device)
                pred = model(x_batch.float())
                loss_test = criterion(y_batch, pred)
                test_batch_loss.append(loss_test )
            test_epoch_loss.append(torch.mean(torch.stack(test_batch_loss)).detach().cpu().numpy())
        if verbose:
            print(f'Epoch {e}')
            print(f"Training loss: {torch.mean(torch.stack(train_batch_loss)).detach().cpu().numpy()}")
            print(f"Validation loss: {torch.mean(torch.stack(test_batch_loss)).detach().cpu().numpy()}") 
            print('==='*10)
        
    save_to_file('trials_rif.txt', appendix ,epoch, lr, dropout_rate, train_epoch_loss, test_epoch_loss)

    if graphics:
        fig, ax = plt.subplots()
        x = np.arange(1, epoch+1, 1)
        ax.plot(x, train_epoch_loss,label='Training')
        ax.plot(x, test_epoch_loss,label='Validation')
        ax.legend()
        ax.set_xlabel("Number of Epoch")
        ax.set_ylabel("Loss")
        ax.set_xticks(np.arange(0, epoch+1, 10))
        ax.set_title(f'Loss: Learning_rate:{lr},dr:{dropout_rate}')
        # ax_2 = ax.twinx()
        # ax_2.plot(history["lr"], "k--", lw=1)
        # ax_2.set_yscale("log")
        # ax.set_ylim(ax.get_ylim()[0], history["training_losses"][0])
        ax.grid(axis="x")
        fig.tight_layout()
        fig.savefig(f'/mnt/storageG1/lwang/Projects/tb_dr_MIC/graph_rif/{appendix}_loss_lr_{lr}_dr_{dropout_rate}.png')
        fig.show()
    return torch.mean(torch.stack(test_batch_loss)).detach().cpu().numpy(), torch.mean(torch.stack(train_batch_loss)).detach().cpu().numpy()
# training(epoch=50, dropout_rate=0.2, lr=0.001, batch_size=128, train_dataset=train_dataset, val_dataset=val_dataset, graphics=True)

In [74]:
lr_values = [1e-5, 1e-3, 1e-1,]
# dropout_rates = [0.2, 0.3, 0.4,0.5]
dropout_rates = [0]

# lr_values = [0.001]
# dropout_rates = [0.5]

results_test = []
results_train = []

for lr in lr_values:
    for dropout_rate in dropout_rates:
        result_test, result_train = training('5layers', epoch=100, dropout_rate=dropout_rate, lr=lr, batch_size=16, train_dataset=train_dataset, val_dataset=val_dataset, graphics=True)
        results_test.append((lr, dropout_rate, result_test))
        results_train.append((lr, dropout_rate, result_train))
# Sort the results based on the validation loss

# with open(f'results_test_lr{lr}_dr{dropout_rate}.txt', 'w') as file:
#     for lr, dropout_rate, result in results_test:
#         file.write(f'Learning Rate: {lr}, Dropout Rate: {dropout_rate}, Test Result: {result}\n')

# # Save the train results to a text file
# with open(f'results_train_lr{lr}_dr{dropout_rate}.txt', 'w') as file:
#     for lr, dropout_rate, result in results_train:
#         file.write(f'Learning Rate: {lr}, Dropout Rate: {dropout_rate}, Train Result: {result}\n')

results_test.sort(key=lambda x: x[2])

# Print the best lr and dropout rate
best_lr, best_dropout_rate, best_result = results_test[0]
print(f"Best lr: {best_lr}, Best dropout rate: {best_dropout_rate}, Best result: {best_result}")

====lr: 1e-05, dropout rate: 0


  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([16, 1, 869])





AttributeError: 'Model' object has no attribute 'cnn_output_size'