In [1]:
import os
import pandas as pd
import numpy as np
import dataframe as df
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import torchmetrics
from torchmetrics.functional import accuracy

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
X_train = pd.read_csv('../data2.csv')

In [4]:
X_train.head()

Unnamed: 0,time,id,group,state,HEOL,HEOR,FP1,FP2,VEOU,VEOL,...,P4,T6,A2,O1,OZ,O2,FT9,FT10,PO1,PO2
0,0,0,0,Fatigue,1.215492e-13,-4.8697030000000006e-17,-6.140989e-14,2.5985100000000002e-17,-1.160435e-13,-3.1678430000000003e-17,...,1.651714e-13,-3.30743e-17,3.455894e-13,-6.391263000000001e-17,8.724439e-14,-3.155565e-17,-7.326835e-14,-3.155565e-17,4.404571e-14,-2.665396e-17
1,1,1,0,Fatigue,-78.31237,-0.01362717,-56.90807,-0.04192949,321.2726,-0.0139765,...,-264.692,-0.006988248,-228.1396,0.04192949,-58.1866,0.03494124,104.0445,0.03494124,-32.91532,0.02795299
2,2,2,0,Fatigue,-153.3161,-0.02671759,-111.4424,-0.08209756,629.0321,-0.02736585,...,-518.3021,-0.01368293,-445.9774,0.08209756,-113.8533,0.06841464,203.7677,0.06841464,-64.39023,0.05473171
3,3,3,0,Fatigue,-221.9502,-0.03877185,-161.4074,-0.1188756,910.7824,-0.03962519,...,-750.5852,-0.0198126,-644.0388,0.1188756,-164.6722,0.09906298,295.1723,0.09906298,-93.09419,0.07925038
4,4,4,0,Fatigue,-281.6333,-0.04936209,-204.9461,-0.1508884,1155.975,-0.05029612,...,-952.8851,-0.02514806,-814.4362,0.1508884,-208.684,0.1257403,374.8789,0.1257403,-117.9092,0.1005922


In [5]:
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(X_train.state)

label_encoder.classes_

array(['Fatigue', 'Normal'], dtype=object)

In [6]:
X_train['label'] = encoded_labels
X_train.head()

Unnamed: 0,time,id,group,state,HEOL,HEOR,FP1,FP2,VEOU,VEOL,...,T6,A2,O1,OZ,O2,FT9,FT10,PO1,PO2,label
0,0,0,0,Fatigue,1.215492e-13,-4.8697030000000006e-17,-6.140989e-14,2.5985100000000002e-17,-1.160435e-13,-3.1678430000000003e-17,...,-3.30743e-17,3.455894e-13,-6.391263000000001e-17,8.724439e-14,-3.155565e-17,-7.326835e-14,-3.155565e-17,4.404571e-14,-2.665396e-17,0
1,1,1,0,Fatigue,-78.31237,-0.01362717,-56.90807,-0.04192949,321.2726,-0.0139765,...,-0.006988248,-228.1396,0.04192949,-58.1866,0.03494124,104.0445,0.03494124,-32.91532,0.02795299,0
2,2,2,0,Fatigue,-153.3161,-0.02671759,-111.4424,-0.08209756,629.0321,-0.02736585,...,-0.01368293,-445.9774,0.08209756,-113.8533,0.06841464,203.7677,0.06841464,-64.39023,0.05473171,0
3,3,3,0,Fatigue,-221.9502,-0.03877185,-161.4074,-0.1188756,910.7824,-0.03962519,...,-0.0198126,-644.0388,0.1188756,-164.6722,0.09906298,295.1723,0.09906298,-93.09419,0.07925038,0
4,4,4,0,Fatigue,-281.6333,-0.04936209,-204.9461,-0.1508884,1155.975,-0.05029612,...,-0.02514806,-814.4362,0.1508884,-208.684,0.1257403,374.8789,0.1257403,-117.9092,0.1005922,0


In [7]:
FEATURE_COLUMNS = X_train.columns.tolist()[4:-2]

In [8]:
g = X_train.groupby("group")

In [9]:
sequences = [] 

for name, group in g:
    sequence_features = group[FEATURE_COLUMNS]
    label = group.label.iloc[0]
    
#     print((sequence_features, label))
    sequences.append((sequence_features, label))

In [10]:
sequences[0]

(              HEOL          HEOR           FP1           FP2          VEOU  \
 0     1.215492e-13 -4.869703e-17 -6.140989e-14  2.598510e-17 -1.160435e-13   
 1    -7.831237e+01 -1.362717e-02 -5.690807e+01 -4.192949e-02  3.212726e+02   
 2    -1.533161e+02 -2.671759e-02 -1.114424e+02 -8.209756e-02  6.290321e+02   
 3    -2.219502e+02 -3.877185e-02 -1.614074e+02 -1.188756e-01  9.107824e+02   
 4    -2.816333e+02 -4.936209e-02 -2.049461e+02 -1.508884e-01  1.155975e+03   
 ...            ...           ...           ...           ...           ...   
 1019  9.189533e+00  2.849707e-03  1.783298e+01  6.329993e-03 -4.930598e+01   
 1020  7.436606e+00  2.869100e-03  1.661349e+01  6.348566e-03 -5.123579e+01   
 1021  5.528420e+00  2.921818e-03  1.508811e+01  6.324462e-03 -5.286875e+01   
 1022  3.726376e+00  3.024905e-03  1.349644e+01  6.342777e-03 -5.465339e+01   
 1023  1.962049e+00  3.137497e-03  1.184765e+01  6.318416e-03 -5.574251e+01   
 
               VEOL            F7            F3   

In [11]:
train_sequences, test_sequences = train_test_split(sequences, test_size=0.2)

In [12]:
len(train_sequences), len(test_sequences)

(11268, 2818)

## Data Set

In [13]:
class DrowsyDataset(Dataset):
    
    def __init__(self, sequences):
        self.sequences = sequences
        
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence, label = self.sequences[idx]
        return dict(
            sequence=torch.Tensor(sequence.to_numpy()),
            label=torch.tensor(label).long()
        )

In [14]:
class DrowsyDataModule(pl.LightningDataModule):
    
    def __init__(self, train_sequences, test_sequences, batch_size):
        super().__init__()
        self.train_sequences = train_sequences
        self.test_sequences = test_sequences
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = DrowsyDataset(self.train_sequences)
        self.test_dataset = DrowsyDataset(self.test_sequences)
        
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=os.cpu_count()
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=os.cpu_count()
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=os.cpu_count()
        ) 

In [15]:
N_EPOCHS = 128
BATCH_SIZE = 100

data_module = DrowsyDataModule(train_sequences, test_sequences, BATCH_SIZE)

## Models

In [16]:
class DrowsyModel(nn.Module):
    
    def __init__(self, n_features, n_classes, n_hidden=256, n_layers=3):
        super().__init__()
        
        self.n_hidden = n_hidden
        
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=0.75
        )
        
        self.classifier = nn.Linear(n_hidden, n_classes)
        
    def forward(self, x):
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(x)
        
        out = hidden[-1]
        return self.classifier(out)

In [17]:
class DrowsyPredictor(pl.LightningModule):
    
    def __init__(self, n_features: int, n_classes: int):
        super().__init__()
        self.model = DrowsyModel(n_features, n_classes)
        self.criterion = nn.CrossEntropyLoss()
        
    def forward(self, x, label=None):
        output = self.model(x)
        loss = 0
        if label is not None:
            loss = self.criterion(output, label)
        return loss, output
        
    def training_step(self, batch, batch_idx):
        sequences = batch["sequence"]
        labels = batch["label"]
        loss, outputs = self(sequences, labels)
        predictions = torch.argmax(outputs, dim=1)
        step_accuracy = accuracy(predictions, labels)
        
        self.log("train_loss", loss, prog_bar=True, logger=True)
        self.log("train_accuracy", step_accuracy, prog_bar=True, logger=True)
        
        return {"loss": loss, "accuracy": step_accuracy}
    
    def validation_step(self, batch, batch_idx):
        sequences = batch["sequence"]
        labels = batch["label"]
        loss, outputs = self(sequences, labels)
        predictions = torch.argmax(outputs, dim=1)
        step_accuracy = accuracy(predictions, labels)
        
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_accuracy", step_accuracy, prog_bar=True, logger=True)
        
        return {"loss": loss, "accuracy": step_accuracy}
    
    def testing_step(self, batch, batch_idx):
        sequences = batch["sequences"]
        labels = batch["label"]
        loss, outputs = self(sequence, labels)
        predictions = torch.argmax(outputs, dim=1)
        step_accuracy = accuracy(predictions, labels)
        
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_accuracy", step_accuracy, prog_bar=True, logger=True)
        
        return {"loss": loss, "accuracy": step_accuracy}
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.0001)

In [18]:
model = DrowsyPredictor(
    n_features=len(FEATURE_COLUMNS), 
    n_classes=len(label_encoder.classes_)
)

In [19]:
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs

In [None]:
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min"
)

logger = TensorBoardLogger("lightning_logs", name="Drowsy")

trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    max_epochs=N_EPOCHS,
    gpus=[0],
#     progress_bar_refresh_rate=20
)

## Training 

In [None]:
trainer.fit(model, data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | DrowsyModel      | 1.4 M 
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.429     Total estimated model params size (MB)


Epoch 0:  80%|███████▉  | 113/142 [00:52<00:13,  2.14it/s, loss=0.672, v_num=0, train_loss=0.652, train_accuracy=0.676]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Epoch 0:  81%|████████  | 115/142 [00:55<00:12,  2.09it/s, loss=0.672, v_num=0, train_loss=0.652, train_accuracy=0.676]
Validating:   7%|▋         | 2/29 [00:02<00:27,  1.02s/it][A
Epoch 0:  82%|████████▏ | 117/142 [00:55<00:11,  2.12it/s, loss=0.672, v_num=0, train_loss=0.652, train_accuracy=0.676]
Validating:  14%|█▍        | 4/29 [00:02<00:10,  2.42it/s][A
Epoch 0:  84%|████████▍ | 119/142 [00:55<00:10,  2.14it/s, loss=0.672, v_num=0, train_loss=0.652, train_accuracy=0.676]
Validating:  21%|██        | 6/29 [00:02<00:05,  4.22it/s][A
Epoch 0:  85%|████████▌ | 121/142 [00:55<00:09,  2.17it/s, loss=0.672, v_num=0, train_loss=0.652, train_accuracy=0.676]
Validating:  28%|██▊       | 8/29 [00:03<00:03,  5.91it/s][A
Epoch 0:  87%|████████▋ | 123/142 [00:55<00:08,  2.20it/s, loss=

Epoch 0, global step 112: val_loss reached 0.64530 (best 0.64530), saving model to "/Workspace/code/checkpoints/sample-mnist-epoch=00-val_loss=0.65.ckpt" as top 1


Epoch 1:  80%|████████  | 114/142 [00:53<00:13,  2.13it/s, loss=0.576, v_num=0, train_loss=0.691, train_accuracy=0.574, val_loss=0.645, val_accuracy=0.621]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:17,  2.76s/it][A
Epoch 1:  82%|████████▏ | 116/142 [00:56<00:12,  2.06it/s, loss=0.576, v_num=0, train_loss=0.691, train_accuracy=0.574, val_loss=0.645, val_accuracy=0.621]
Validating:  10%|█         | 3/29 [00:02<00:18,  1.42it/s][A
Epoch 1:  83%|████████▎ | 118/142 [00:56<00:11,  2.08it/s, loss=0.576, v_num=0, train_loss=0.691, train_accuracy=0.574, val_loss=0.645, val_accuracy=0.621]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.95it/s][A
Epoch 1:  85%|████████▍ | 120/142 [00:56<00:10,  2.11it/s, loss=0.576, v_num=0, train_loss=0.691, train_accuracy=0.574, val_loss=0.645, val_accuracy=0.621]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.76it/s][A
Epoch 1:  86%|████████▌ | 122/142 [00:57

Epoch 1, global step 225: val_loss reached 0.54734 (best 0.54734), saving model to "/Workspace/code/checkpoints/sample-mnist-epoch=01-val_loss=0.55.ckpt" as top 1


Epoch 2:  80%|████████  | 114/142 [00:53<00:13,  2.12it/s, loss=0.59, v_num=0, train_loss=0.631, train_accuracy=0.632, val_loss=0.547, val_accuracy=0.739] 
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:19,  2.84s/it][A
Epoch 2:  82%|████████▏ | 116/142 [00:56<00:12,  2.05it/s, loss=0.59, v_num=0, train_loss=0.631, train_accuracy=0.632, val_loss=0.547, val_accuracy=0.739]
Validating:  10%|█         | 3/29 [00:03<00:18,  1.39it/s][A
Epoch 2:  83%|████████▎ | 118/142 [00:56<00:11,  2.07it/s, loss=0.59, v_num=0, train_loss=0.631, train_accuracy=0.632, val_loss=0.547, val_accuracy=0.739]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.88it/s][A
Epoch 2:  85%|████████▍ | 120/142 [00:57<00:10,  2.10it/s, loss=0.59, v_num=0, train_loss=0.631, train_accuracy=0.632, val_loss=0.547, val_accuracy=0.739]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.64it/s][A
Epoch 2:  86%|████████▌ | 122/142 [00:57<00

Epoch 2, global step 338: val_loss was not in top 1


Epoch 3:  80%|████████  | 114/142 [00:53<00:13,  2.12it/s, loss=0.589, v_num=0, train_loss=0.668, train_accuracy=0.603, val_loss=0.580, val_accuracy=0.718]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:20,  2.89s/it][A
Epoch 3:  82%|████████▏ | 116/142 [00:56<00:12,  2.04it/s, loss=0.589, v_num=0, train_loss=0.668, train_accuracy=0.603, val_loss=0.580, val_accuracy=0.718]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.37it/s][A
Epoch 3:  83%|████████▎ | 118/142 [00:57<00:11,  2.07it/s, loss=0.589, v_num=0, train_loss=0.668, train_accuracy=0.603, val_loss=0.580, val_accuracy=0.718]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.86it/s][A
Epoch 3:  85%|████████▍ | 120/142 [00:57<00:10,  2.10it/s, loss=0.589, v_num=0, train_loss=0.668, train_accuracy=0.603, val_loss=0.580, val_accuracy=0.718]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.58it/s][A
Epoch 3:  86%|████████▌ | 122/142 [00:57

Epoch 3, global step 451: val_loss was not in top 1


Epoch 4:  80%|████████  | 114/142 [00:54<00:13,  2.10it/s, loss=0.495, v_num=0, train_loss=0.403, train_accuracy=0.853, val_loss=0.613, val_accuracy=0.660]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:20,  2.86s/it][A
Epoch 4:  82%|████████▏ | 116/142 [00:57<00:12,  2.03it/s, loss=0.495, v_num=0, train_loss=0.403, train_accuracy=0.853, val_loss=0.613, val_accuracy=0.660]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.37it/s][A
Epoch 4:  83%|████████▎ | 118/142 [00:57<00:11,  2.05it/s, loss=0.495, v_num=0, train_loss=0.403, train_accuracy=0.853, val_loss=0.613, val_accuracy=0.660]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.87it/s][A
Epoch 4:  85%|████████▍ | 120/142 [00:57<00:10,  2.08it/s, loss=0.495, v_num=0, train_loss=0.403, train_accuracy=0.853, val_loss=0.613, val_accuracy=0.660]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.66it/s][A
Epoch 4:  86%|████████▌ | 122/142 [00:57

Epoch 4, global step 564: val_loss reached 0.51912 (best 0.51912), saving model to "/Workspace/code/checkpoints/sample-mnist-epoch=04-val_loss=0.52.ckpt" as top 1


Epoch 5:  80%|████████  | 114/142 [00:53<00:13,  2.11it/s, loss=0.527, v_num=0, train_loss=0.536, train_accuracy=0.809, val_loss=0.519, val_accuracy=0.760]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:23,  2.98s/it][A
Epoch 5:  82%|████████▏ | 116/142 [00:57<00:12,  2.03it/s, loss=0.527, v_num=0, train_loss=0.536, train_accuracy=0.809, val_loss=0.519, val_accuracy=0.760]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.33it/s][A
Epoch 5:  83%|████████▎ | 118/142 [00:57<00:11,  2.06it/s, loss=0.527, v_num=0, train_loss=0.536, train_accuracy=0.809, val_loss=0.519, val_accuracy=0.760]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.78it/s][A
Epoch 5:  85%|████████▍ | 120/142 [00:57<00:10,  2.09it/s, loss=0.527, v_num=0, train_loss=0.536, train_accuracy=0.809, val_loss=0.519, val_accuracy=0.760]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.53it/s][A
Epoch 5:  86%|████████▌ | 122/142 [00:57

Epoch 5, global step 677: val_loss was not in top 1


Epoch 6:  80%|████████  | 114/142 [00:54<00:13,  2.11it/s, loss=0.529, v_num=0, train_loss=0.665, train_accuracy=0.662, val_loss=0.528, val_accuracy=0.749]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:22,  2.94s/it][A
Epoch 6:  82%|████████▏ | 116/142 [00:57<00:12,  2.03it/s, loss=0.529, v_num=0, train_loss=0.665, train_accuracy=0.662, val_loss=0.528, val_accuracy=0.749]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.34it/s][A
Epoch 6:  83%|████████▎ | 118/142 [00:57<00:11,  2.06it/s, loss=0.529, v_num=0, train_loss=0.665, train_accuracy=0.662, val_loss=0.528, val_accuracy=0.749]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.81it/s][A
Epoch 6:  85%|████████▍ | 120/142 [00:57<00:10,  2.08it/s, loss=0.529, v_num=0, train_loss=0.665, train_accuracy=0.662, val_loss=0.528, val_accuracy=0.749]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.53it/s][A
Epoch 6:  86%|████████▌ | 122/142 [00:57

Epoch 6, global step 790: val_loss was not in top 1


Epoch 7:  80%|████████  | 114/142 [00:54<00:13,  2.10it/s, loss=0.48, v_num=0, train_loss=0.532, train_accuracy=0.735, val_loss=0.675, val_accuracy=0.692] 
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:22,  2.94s/it][A
Epoch 7:  82%|████████▏ | 116/142 [00:57<00:12,  2.02it/s, loss=0.48, v_num=0, train_loss=0.532, train_accuracy=0.735, val_loss=0.675, val_accuracy=0.692]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.34it/s][A
Epoch 7:  83%|████████▎ | 118/142 [00:57<00:11,  2.05it/s, loss=0.48, v_num=0, train_loss=0.532, train_accuracy=0.735, val_loss=0.675, val_accuracy=0.692]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.78it/s][A
Epoch 7:  85%|████████▍ | 120/142 [00:57<00:10,  2.07it/s, loss=0.48, v_num=0, train_loss=0.532, train_accuracy=0.735, val_loss=0.675, val_accuracy=0.692]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.55it/s][A
Epoch 7:  86%|████████▌ | 122/142 [00:58<00

Epoch 7, global step 903: val_loss reached 0.51711 (best 0.51711), saving model to "/Workspace/code/checkpoints/sample-mnist-epoch=07-val_loss=0.52.ckpt" as top 1


Epoch 8:  80%|████████  | 114/142 [00:54<00:13,  2.11it/s, loss=0.497, v_num=0, train_loss=0.775, train_accuracy=0.618, val_loss=0.517, val_accuracy=0.772]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:03<01:24,  3.01s/it][A
Epoch 8:  82%|████████▏ | 116/142 [00:57<00:12,  2.03it/s, loss=0.497, v_num=0, train_loss=0.775, train_accuracy=0.618, val_loss=0.517, val_accuracy=0.772]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.32it/s][A
Epoch 8:  83%|████████▎ | 118/142 [00:57<00:11,  2.06it/s, loss=0.497, v_num=0, train_loss=0.775, train_accuracy=0.618, val_loss=0.517, val_accuracy=0.772]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.78it/s][A
Epoch 8:  85%|████████▍ | 120/142 [00:57<00:10,  2.08it/s, loss=0.497, v_num=0, train_loss=0.775, train_accuracy=0.618, val_loss=0.517, val_accuracy=0.772]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.55it/s][A
Epoch 8:  86%|████████▌ | 122/142 [00:57

Epoch 8, global step 1016: val_loss was not in top 1


Epoch 9:  80%|████████  | 114/142 [00:54<00:13,  2.11it/s, loss=0.634, v_num=0, train_loss=0.609, train_accuracy=0.721, val_loss=0.859, val_accuracy=0.623]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:23,  2.96s/it][A
Epoch 9:  82%|████████▏ | 116/142 [00:57<00:12,  2.03it/s, loss=0.634, v_num=0, train_loss=0.609, train_accuracy=0.721, val_loss=0.859, val_accuracy=0.623]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.33it/s][A
Epoch 9:  83%|████████▎ | 118/142 [00:57<00:11,  2.06it/s, loss=0.634, v_num=0, train_loss=0.609, train_accuracy=0.721, val_loss=0.859, val_accuracy=0.623]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.79it/s][A
Epoch 9:  85%|████████▍ | 120/142 [00:57<00:10,  2.08it/s, loss=0.634, v_num=0, train_loss=0.609, train_accuracy=0.721, val_loss=0.859, val_accuracy=0.623]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.50it/s][A
Epoch 9:  86%|████████▌ | 122/142 [00:57

Epoch 9, global step 1129: val_loss was not in top 1


Epoch 10:  80%|████████  | 114/142 [00:54<00:13,  2.09it/s, loss=0.628, v_num=0, train_loss=0.618, train_accuracy=0.676, val_loss=0.654, val_accuracy=0.609]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:21,  2.92s/it][A
Epoch 10:  82%|████████▏ | 116/142 [00:57<00:12,  2.02it/s, loss=0.628, v_num=0, train_loss=0.618, train_accuracy=0.676, val_loss=0.654, val_accuracy=0.609]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.33it/s][A
Epoch 10:  83%|████████▎ | 118/142 [00:57<00:11,  2.04it/s, loss=0.628, v_num=0, train_loss=0.618, train_accuracy=0.676, val_loss=0.654, val_accuracy=0.609]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.80it/s][A
Epoch 10:  85%|████████▍ | 120/142 [00:57<00:10,  2.07it/s, loss=0.628, v_num=0, train_loss=0.618, train_accuracy=0.676, val_loss=0.654, val_accuracy=0.609]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.57it/s][A
Epoch 10:  86%|████████▌ | 122/142 [

Epoch 10, global step 1242: val_loss was not in top 1


Epoch 11:  80%|████████  | 114/142 [00:54<00:13,  2.09it/s, loss=0.632, v_num=0, train_loss=0.665, train_accuracy=0.574, val_loss=0.649, val_accuracy=0.612]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:22,  2.94s/it][A
Epoch 11:  82%|████████▏ | 116/142 [00:57<00:12,  2.02it/s, loss=0.632, v_num=0, train_loss=0.665, train_accuracy=0.574, val_loss=0.649, val_accuracy=0.612]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.33it/s][A
Epoch 11:  83%|████████▎ | 118/142 [00:57<00:11,  2.04it/s, loss=0.632, v_num=0, train_loss=0.665, train_accuracy=0.574, val_loss=0.649, val_accuracy=0.612]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.80it/s][A
Epoch 11:  85%|████████▍ | 120/142 [00:57<00:10,  2.07it/s, loss=0.632, v_num=0, train_loss=0.665, train_accuracy=0.574, val_loss=0.649, val_accuracy=0.612]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.57it/s][A
Epoch 11:  86%|████████▌ | 122/142 [

Epoch 11, global step 1355: val_loss was not in top 1


Epoch 12:  80%|████████  | 114/142 [00:54<00:13,  2.10it/s, loss=0.622, v_num=0, train_loss=0.633, train_accuracy=0.618, val_loss=0.640, val_accuracy=0.616]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:23,  2.97s/it][A
Epoch 12:  82%|████████▏ | 116/142 [00:57<00:12,  2.02it/s, loss=0.622, v_num=0, train_loss=0.633, train_accuracy=0.618, val_loss=0.640, val_accuracy=0.616]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.32it/s][A
Epoch 12:  83%|████████▎ | 118/142 [00:57<00:11,  2.05it/s, loss=0.622, v_num=0, train_loss=0.633, train_accuracy=0.618, val_loss=0.640, val_accuracy=0.616]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.78it/s][A
Epoch 12:  85%|████████▍ | 120/142 [00:57<00:10,  2.08it/s, loss=0.622, v_num=0, train_loss=0.633, train_accuracy=0.618, val_loss=0.640, val_accuracy=0.616]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.55it/s][A
Epoch 12:  86%|████████▌ | 122/142 [

Epoch 12, global step 1468: val_loss was not in top 1


Epoch 13:  80%|████████  | 114/142 [00:54<00:13,  2.10it/s, loss=0.617, v_num=0, train_loss=0.639, train_accuracy=0.618, val_loss=0.635, val_accuracy=0.621]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:03<01:25,  3.04s/it][A
Epoch 13:  82%|████████▏ | 116/142 [00:57<00:12,  2.02it/s, loss=0.617, v_num=0, train_loss=0.639, train_accuracy=0.618, val_loss=0.635, val_accuracy=0.621]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.31it/s][A
Epoch 13:  83%|████████▎ | 118/142 [00:57<00:11,  2.05it/s, loss=0.617, v_num=0, train_loss=0.639, train_accuracy=0.618, val_loss=0.635, val_accuracy=0.621]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.76it/s][A
Epoch 13:  85%|████████▍ | 120/142 [00:57<00:10,  2.08it/s, loss=0.617, v_num=0, train_loss=0.639, train_accuracy=0.618, val_loss=0.635, val_accuracy=0.621]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.51it/s][A
Epoch 13:  86%|████████▌ | 122/142 [

Epoch 13, global step 1581: val_loss was not in top 1


Epoch 14:  80%|████████  | 114/142 [00:54<00:13,  2.09it/s, loss=0.613, v_num=0, train_loss=0.626, train_accuracy=0.632, val_loss=0.631, val_accuracy=0.625]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:03<01:25,  3.06s/it][A
Epoch 14:  82%|████████▏ | 116/142 [00:57<00:12,  2.01it/s, loss=0.613, v_num=0, train_loss=0.626, train_accuracy=0.632, val_loss=0.631, val_accuracy=0.625]
Validating:  10%|█         | 3/29 [00:03<00:20,  1.29it/s][A
Epoch 14:  83%|████████▎ | 118/142 [00:57<00:11,  2.04it/s, loss=0.613, v_num=0, train_loss=0.626, train_accuracy=0.632, val_loss=0.631, val_accuracy=0.625]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.69it/s][A
Epoch 14:  85%|████████▍ | 120/142 [00:58<00:10,  2.07it/s, loss=0.613, v_num=0, train_loss=0.626, train_accuracy=0.632, val_loss=0.631, val_accuracy=0.625]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.43it/s][A
Epoch 14:  86%|████████▌ | 122/142 [

Epoch 14, global step 1694: val_loss was not in top 1


Epoch 15:  80%|████████  | 114/142 [00:54<00:13,  2.09it/s, loss=0.61, v_num=0, train_loss=0.671, train_accuracy=0.588, val_loss=0.627, val_accuracy=0.631] 
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:21,  2.92s/it][A
Epoch 15:  82%|████████▏ | 116/142 [00:57<00:12,  2.01it/s, loss=0.61, v_num=0, train_loss=0.671, train_accuracy=0.588, val_loss=0.627, val_accuracy=0.631]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.34it/s][A
Epoch 15:  83%|████████▎ | 118/142 [00:57<00:11,  2.04it/s, loss=0.61, v_num=0, train_loss=0.671, train_accuracy=0.588, val_loss=0.627, val_accuracy=0.631]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.82it/s][A
Epoch 15:  85%|████████▍ | 120/142 [00:58<00:10,  2.07it/s, loss=0.61, v_num=0, train_loss=0.671, train_accuracy=0.588, val_loss=0.627, val_accuracy=0.631]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.54it/s][A
Epoch 15:  86%|████████▌ | 122/142 [00:

Epoch 15, global step 1807: val_loss was not in top 1


Epoch 16:  80%|████████  | 114/142 [00:54<00:13,  2.10it/s, loss=0.597, v_num=0, train_loss=0.592, train_accuracy=0.676, val_loss=0.621, val_accuracy=0.636]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:03<01:25,  3.05s/it][A
Epoch 16:  82%|████████▏ | 116/142 [00:57<00:12,  2.02it/s, loss=0.597, v_num=0, train_loss=0.592, train_accuracy=0.676, val_loss=0.621, val_accuracy=0.636]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.30it/s][A
Epoch 16:  83%|████████▎ | 118/142 [00:57<00:11,  2.05it/s, loss=0.597, v_num=0, train_loss=0.592, train_accuracy=0.676, val_loss=0.621, val_accuracy=0.636]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.72it/s][A
Epoch 16:  85%|████████▍ | 120/142 [00:57<00:10,  2.07it/s, loss=0.597, v_num=0, train_loss=0.592, train_accuracy=0.676, val_loss=0.621, val_accuracy=0.636]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.48it/s][A
Epoch 16:  86%|████████▌ | 122/142 [

Epoch 16, global step 1920: val_loss was not in top 1


Epoch 17:  80%|████████  | 114/142 [00:54<00:13,  2.11it/s, loss=0.571, v_num=0, train_loss=0.558, train_accuracy=0.750, val_loss=0.608, val_accuracy=0.643]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:23,  2.98s/it][A
Epoch 17:  82%|████████▏ | 116/142 [00:57<00:12,  2.03it/s, loss=0.571, v_num=0, train_loss=0.558, train_accuracy=0.750, val_loss=0.608, val_accuracy=0.643]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.32it/s][A
Epoch 17:  83%|████████▎ | 118/142 [00:57<00:11,  2.05it/s, loss=0.571, v_num=0, train_loss=0.558, train_accuracy=0.750, val_loss=0.608, val_accuracy=0.643]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.79it/s][A
Epoch 17:  85%|████████▍ | 120/142 [00:57<00:10,  2.08it/s, loss=0.571, v_num=0, train_loss=0.558, train_accuracy=0.750, val_loss=0.608, val_accuracy=0.643]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.51it/s][A
Epoch 17:  86%|████████▌ | 122/142 [

Epoch 17, global step 2033: val_loss was not in top 1


Epoch 18:  80%|████████  | 114/142 [00:54<00:13,  2.08it/s, loss=0.579, v_num=0, train_loss=0.611, train_accuracy=0.676, val_loss=0.582, val_accuracy=0.687]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:23,  2.98s/it][A
Epoch 18:  82%|████████▏ | 116/142 [00:57<00:12,  2.01it/s, loss=0.579, v_num=0, train_loss=0.611, train_accuracy=0.676, val_loss=0.582, val_accuracy=0.687]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.32it/s][A
Epoch 18:  83%|████████▎ | 118/142 [00:58<00:11,  2.03it/s, loss=0.579, v_num=0, train_loss=0.611, train_accuracy=0.676, val_loss=0.582, val_accuracy=0.687]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.74it/s][A
Epoch 18:  85%|████████▍ | 120/142 [00:58<00:10,  2.06it/s, loss=0.579, v_num=0, train_loss=0.611, train_accuracy=0.676, val_loss=0.582, val_accuracy=0.687]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.45it/s][A
Epoch 18:  86%|████████▌ | 122/142 [

Epoch 18, global step 2146: val_loss was not in top 1


Epoch 19:  80%|████████  | 114/142 [00:54<00:13,  2.09it/s, loss=0.535, v_num=0, train_loss=0.553, train_accuracy=0.721, val_loss=0.619, val_accuracy=0.650]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:02<01:23,  2.99s/it][A
Epoch 19:  82%|████████▏ | 116/142 [00:57<00:12,  2.01it/s, loss=0.535, v_num=0, train_loss=0.553, train_accuracy=0.721, val_loss=0.619, val_accuracy=0.650]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.31it/s][A
Epoch 19:  83%|████████▎ | 118/142 [00:57<00:11,  2.04it/s, loss=0.535, v_num=0, train_loss=0.553, train_accuracy=0.721, val_loss=0.619, val_accuracy=0.650]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.76it/s][A
Epoch 19:  85%|████████▍ | 120/142 [00:58<00:10,  2.07it/s, loss=0.535, v_num=0, train_loss=0.553, train_accuracy=0.721, val_loss=0.619, val_accuracy=0.650]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.49it/s][A
Epoch 19:  86%|████████▌ | 122/142 [

Epoch 19, global step 2259: val_loss was not in top 1


Epoch 20:  80%|████████  | 114/142 [00:54<00:13,  2.10it/s, loss=0.515, v_num=0, train_loss=0.601, train_accuracy=0.706, val_loss=0.570, val_accuracy=0.719]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:03<01:24,  3.03s/it][A
Epoch 20:  82%|████████▏ | 116/142 [00:57<00:12,  2.02it/s, loss=0.515, v_num=0, train_loss=0.601, train_accuracy=0.706, val_loss=0.570, val_accuracy=0.719]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.31it/s][A
Epoch 20:  83%|████████▎ | 118/142 [00:57<00:11,  2.04it/s, loss=0.515, v_num=0, train_loss=0.601, train_accuracy=0.706, val_loss=0.570, val_accuracy=0.719]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.75it/s][A
Epoch 20:  85%|████████▍ | 120/142 [00:57<00:10,  2.07it/s, loss=0.515, v_num=0, train_loss=0.601, train_accuracy=0.706, val_loss=0.570, val_accuracy=0.719]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.49it/s][A
Epoch 20:  86%|████████▌ | 122/142 [

Epoch 20, global step 2372: val_loss was not in top 1


Epoch 21:  80%|████████  | 114/142 [00:54<00:13,  2.10it/s, loss=0.514, v_num=0, train_loss=0.497, train_accuracy=0.721, val_loss=0.543, val_accuracy=0.745]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/29 [00:00<?, ?it/s][A
Validating:   3%|▎         | 1/29 [00:03<01:24,  3.03s/it][A
Epoch 21:  82%|████████▏ | 116/142 [00:57<00:12,  2.02it/s, loss=0.514, v_num=0, train_loss=0.497, train_accuracy=0.721, val_loss=0.543, val_accuracy=0.745]
Validating:  10%|█         | 3/29 [00:03<00:19,  1.31it/s][A
Epoch 21:  83%|████████▎ | 118/142 [00:57<00:11,  2.05it/s, loss=0.514, v_num=0, train_loss=0.497, train_accuracy=0.721, val_loss=0.543, val_accuracy=0.745]
Validating:  17%|█▋        | 5/29 [00:03<00:08,  2.75it/s][A
Epoch 21:  85%|████████▍ | 120/142 [00:57<00:10,  2.08it/s, loss=0.514, v_num=0, train_loss=0.497, train_accuracy=0.721, val_loss=0.543, val_accuracy=0.745]
Validating:  24%|██▍       | 7/29 [00:03<00:04,  4.50it/s][A
Epoch 21:  86%|████████▌ | 122/142 [

Epoch 21, global step 2485: val_loss was not in top 1


Epoch 22:  48%|████▊     | 68/142 [00:33<00:36,  2.01it/s, loss=0.514, v_num=0, train_loss=0.516, train_accuracy=0.770, val_loss=0.539, val_accuracy=0.743] 

In [None]:
trainer.test()

## Train from check point

In [None]:
# trainer.fit(model, data_module, ckpt_path="/home/masalu/checkpoints/best-checkpoint.ckpt")