In [1]:
import os
import math
import cv2
import pickle
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init



from torchvision import transforms
from torchvision import models

from tqdm import tqdm
from PIL import Image
from sklearn.metrics import balanced_accuracy_score

# import mediapipe as mp


In [10]:
label_path='/kaggle/input/data-raf/datasets/raf-basic/EmoLabel/list_patition_label.txt'

raf_path='/kaggle/input/data-raf/datasets/raf-basic'

num_head = 4
workers=4
batch_size=64
lr=0.1
device_name=0
epochs=100

In [3]:
class DAN(nn.Module):
    def __init__(self, num_class=7,num_head=4, pretrained=True):
        super(DAN, self).__init__()
        
        resnet = models.resnet18(pretrained)
        
        if pretrained:
            checkpoint = torch.load('/kaggle/input/dan-def-weight/resnet18_msceleb.pth')
            resnet.load_state_dict(checkpoint['state_dict'],strict=True)

        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.num_head = num_head
        for i in range(num_head):
            setattr(self,"cat_head%d" %i, CrossAttentionHead())
        self.sig = nn.Sigmoid()
        self.fc = nn.Linear(512, num_class)
        self.bn = nn.BatchNorm1d(num_class)


    def forward(self, x):
        x = self.features(x)
        heads = []
        for i in range(self.num_head):
            heads.append(getattr(self,"cat_head%d" %i)(x))
        
        heads = torch.stack(heads).permute([1,0,2])
        if heads.size(1)>1:
            heads = F.log_softmax(heads,dim=1)
            
        out = self.fc(heads.sum(dim=1))
        out = self.bn(out)
   
        return out, x, heads

class CrossAttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.sa = SpatialAttention()
        self.ca = ChannelAttention()
        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
    def forward(self, x):
        sa = self.sa(x)
        ca = self.ca(sa)

        return ca


class SpatialAttention(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1),
            nn.BatchNorm2d(256),
        )
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3,padding=1),
            nn.BatchNorm2d(512),
        )
        self.conv_1x3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(1,3),padding=(0,1)),
            nn.BatchNorm2d(512),
        )
        self.conv_3x1 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(3,1),padding=(1,0)),
            nn.BatchNorm2d(512),
        )
        self.relu = nn.ReLU()


    def forward(self, x):
        y = self.conv1x1(x)
        y = self.relu(self.conv_3x3(y) + self.conv_1x3(y) + self.conv_3x1(y))
        y = y.sum(dim=1,keepdim=True) 
        out = x*y
        
        return out 

class ChannelAttention(nn.Module):

    def __init__(self):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.Sequential(
            nn.Linear(512, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 512),
            nn.Sigmoid()    
        )


    def forward(self, sa):
        sa = self.gap(sa)
        sa = sa.view(sa.size(0),-1)
        y = self.attention(sa)
        out = sa * y
        
        return out

In [4]:
class RafDataSet(data.Dataset):
    def __init__(self, raf_path, phase, transform = None):
        self.phase = phase
        self.transform = transform
        self.raf_path = raf_path

        df = pd.read_csv(os.path.join(self.raf_path, 'EmoLabel/list_patition_label.txt'), sep=' ', header=None,names=['name','label'])

        if phase == 'train':
            self.data = df[df['name'].str.startswith('train')]
        else:
            self.data = df[df['name'].str.startswith('test')]

        file_names = self.data.loc[:, 'name'].values
        self.label = self.data.loc[:, 'label'].values - 1 # 0:Surprise, 1:Fear, 2:Disgust, 3:Happiness, 4:Sadness, 5:Anger, 6:Neutral

        _, self.sample_counts = np.unique(self.label, return_counts=True)
        # print(f' distribution of {phase} samples: {self.sample_counts}')

        self.file_paths = []
        for f in file_names:
            f = f.split(".")[0]
            f = f +"_aligned.jpg"
            path = os.path.join(self.raf_path, 'Image/aligned', f)
            self.file_paths.append(path)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.label[idx]

        if self.transform is not None:
            image = self.transform(image)
        
        return image, label
    
    
# val_dataset = RafDataSet(raf_path, phase = 'test', transform = data_transforms_val)

# img, lb = val_dataset[0]
# print(img)

In [5]:
class AffinityLoss(nn.Module):
    def __init__(self, device, num_class=8, feat_dim=512):
        super(AffinityLoss, self).__init__()
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.device = device

        self.centers = nn.Parameter(torch.randn(self.num_class, self.feat_dim).to(device))

    def forward(self, x, labels):
        x = self.gap(x).view(x.size(0), -1)

        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_class) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_class, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_class).long().to(self.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_class)
        mask = labels.eq(classes.expand(batch_size, self.num_class))

        dist = distmat * mask.float()
        dist = dist / self.centers.var(dim=0).sum()

        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss

In [6]:
class PartitionLoss(nn.Module):
    def __init__(self, ):
        super(PartitionLoss, self).__init__()
    
    def forward(self, x):
        num_head = x.size(1)

        if num_head > 1:
            var = x.var(dim=1).mean()
            ## add eps to avoid empty var case
            loss = torch.log(1+num_head/(var+eps))
        else:
            loss = 0
            
        return loss

In [7]:

data_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
                transforms.RandomRotation(20),
                transforms.RandomCrop(224, padding=32)
            ], p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(scale=(0.02,0.25)),
        ])
    
train_dataset = RafDataSet(raf_path, phase = 'train', transform = data_transforms)    
    
print('Whole train set size:', train_dataset.__len__())

train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size = batch_size,
                                               num_workers = workers,
                                               shuffle = True,  
                                               pin_memory = True)

data_transforms_val = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])   

val_dataset = RafDataSet(raf_path, phase = 'test', transform = data_transforms_val)   

print('Validation set size:', val_dataset.__len__())
    
val_loader = torch.utils.data.DataLoader(val_dataset,
                                               batch_size = batch_size,
                                               num_workers = workers,
                                               shuffle = False,  
                                               pin_memory = True)

Whole train set size: 12271
Validation set size: 3068


In [8]:

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True
    
    device = torch.device(device_name)

    model = DAN(num_head=num_head)
    model.to(device)

    device = torch.device(device_name)

    criterion_cls = torch.nn.CrossEntropyLoss()
    criterion_af = AffinityLoss(device)
    criterion_pt = PartitionLoss()

    params = list(model.parameters()) + list(criterion_af.parameters())
    optimizer = torch.optim.SGD(params,lr=lr, weight_decay = 1e-4, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

  f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [11]:
import sys
eps = sys.float_info.epsilon

best_acc = 0
for epoch in tqdm(range(1, epochs + 1)):
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    model.train()

    for (imgs, targets) in train_loader:
        iter_cnt += 1
        optimizer.zero_grad()

        imgs = imgs.to(device)
        targets = targets.to(device)

        out, feat, heads = model(imgs)

        loss = criterion_cls(out, targets) + 1 * criterion_af(feat, targets) + 1 * criterion_pt(heads)  # 89.3 89.4

        loss.backward()
        optimizer.step()

        running_loss += loss
        _, predicts = torch.max(out, 1)
        correct_num = torch.eq(predicts, targets).sum()
        correct_sum += correct_num

    acc = correct_sum.float() / float(train_dataset.__len__())
    running_loss = running_loss / iter_cnt
    tqdm.write('[Epoch %d] Training accuracy: %.4f. Loss: %.3f. LR %.6f' % (
    epoch, acc, running_loss, optimizer.param_groups[0]['lr']))

    with torch.no_grad():
        running_loss = 0.0
        iter_cnt = 0
        bingo_cnt = 0
        sample_cnt = 0

        ## for calculating balanced accuracy
        y_true = []
        y_pred = []

        model.eval()
        for (imgs, targets) in val_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)

            out, feat, heads = model(imgs)
            loss = criterion_cls(out, targets) + criterion_af(feat, targets) + criterion_pt(heads)

            running_loss += loss
            iter_cnt += 1
            _, predicts = torch.max(out, 1)
            correct_num = torch.eq(predicts, targets)
            bingo_cnt += correct_num.sum().cpu()
            sample_cnt += out.size(0)

            y_true.append(targets.cpu().numpy())
            y_pred.append(predicts.cpu().numpy())

        running_loss = running_loss / iter_cnt
        scheduler.step()

        acc = bingo_cnt.float() / float(sample_cnt)
        acc = np.around(acc.numpy(), 4)
        best_acc = max(acc, best_acc)

        y_true = np.concatenate(y_true)
        y_pred = np.concatenate(y_pred)
        balanced_acc = np.around(balanced_accuracy_score(y_true, y_pred), 4)

        tqdm.write(
            "[Epoch %d] Validation accuracy:%.4f. bacc:%.4f. Loss:%.3f" % (epoch, acc, balanced_acc, running_loss))
        tqdm.write("best_acc:" + str(best_acc))

        if acc > 0.89 and acc == best_acc:
            torch.save({'iter': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(), },
                       os.path.join("rafdb_epoch" + str(epoch) + "_acc" + str(acc) + "_bacc" + str(
                           balanced_acc) + ".pth"))
            tqdm.write('Model saved.')

  0%|          | 0/50 [00:35<?, ?it/s]

[Epoch 1] Training accuracy: 0.9705. Loss: 0.253. LR 0.000010


  2%|▏         | 1/50 [00:43<35:11, 43.10s/it]

[Epoch 1] Validation accuracy:0.8872. bacc:0.8154. Loss:0.568
best_acc:0.8872


  2%|▏         | 1/50 [01:19<35:11, 43.10s/it]

[Epoch 2] Training accuracy: 0.9676. Loss: 0.256. LR 0.000010


  4%|▍         | 2/50 [01:26<34:46, 43.47s/it]

[Epoch 2] Validation accuracy:0.8875. bacc:0.8156. Loss:0.568
best_acc:0.8875


  4%|▍         | 2/50 [02:02<34:46, 43.47s/it]

[Epoch 3] Training accuracy: 0.9696. Loss: 0.253. LR 0.000010


  6%|▌         | 3/50 [02:09<33:42, 43.03s/it]

[Epoch 3] Validation accuracy:0.8856. bacc:0.8151. Loss:0.567
best_acc:0.8875


  6%|▌         | 3/50 [02:44<33:42, 43.03s/it]

[Epoch 4] Training accuracy: 0.9677. Loss: 0.261. LR 0.000010


  8%|▊         | 4/50 [02:51<32:44, 42.70s/it]

[Epoch 4] Validation accuracy:0.8898. bacc:0.8178. Loss:0.564
best_acc:0.8898


  8%|▊         | 4/50 [03:29<32:44, 42.70s/it]

[Epoch 5] Training accuracy: 0.9669. Loss: 0.260. LR 0.000010


 10%|█         | 5/50 [03:36<32:42, 43.61s/it]

[Epoch 5] Validation accuracy:0.8859. bacc:0.8144. Loss:0.569
best_acc:0.8898


 10%|█         | 5/50 [04:12<32:42, 43.61s/it]

[Epoch 6] Training accuracy: 0.9677. Loss: 0.261. LR 0.000010


 12%|█▏        | 6/50 [04:19<31:44, 43.29s/it]

[Epoch 6] Validation accuracy:0.8853. bacc:0.8153. Loss:0.565
best_acc:0.8898


 12%|█▏        | 6/50 [04:54<31:44, 43.29s/it]

[Epoch 7] Training accuracy: 0.9691. Loss: 0.256. LR 0.000010


 14%|█▍        | 7/50 [05:01<30:50, 43.04s/it]

[Epoch 7] Validation accuracy:0.8862. bacc:0.8136. Loss:0.563
best_acc:0.8898


 14%|█▍        | 7/50 [05:36<30:50, 43.04s/it]

[Epoch 8] Training accuracy: 0.9665. Loss: 0.263. LR 0.000010


 16%|█▌        | 8/50 [05:43<29:52, 42.69s/it]

[Epoch 8] Validation accuracy:0.8866. bacc:0.8179. Loss:0.568
best_acc:0.8898


 16%|█▌        | 8/50 [06:19<29:52, 42.69s/it]

[Epoch 9] Training accuracy: 0.9691. Loss: 0.253. LR 0.000010


 18%|█▊        | 9/50 [06:26<29:07, 42.63s/it]

[Epoch 9] Validation accuracy:0.8879. bacc:0.8145. Loss:0.568
best_acc:0.8898


 18%|█▊        | 9/50 [07:02<29:07, 42.63s/it]

[Epoch 10] Training accuracy: 0.9676. Loss: 0.258. LR 0.000010


 20%|██        | 10/50 [07:09<28:27, 42.68s/it]

[Epoch 10] Validation accuracy:0.8866. bacc:0.8133. Loss:0.561
best_acc:0.8898


 20%|██        | 10/50 [07:43<28:27, 42.68s/it]

[Epoch 11] Training accuracy: 0.9687. Loss: 0.260. LR 0.000001


 22%|██▏       | 11/50 [07:51<27:36, 42.47s/it]

[Epoch 11] Validation accuracy:0.8882. bacc:0.8124. Loss:0.564
best_acc:0.8898


 22%|██▏       | 11/50 [08:26<27:36, 42.47s/it]

[Epoch 12] Training accuracy: 0.9659. Loss: 0.265. LR 0.000001


 24%|██▍       | 12/50 [08:34<26:59, 42.61s/it]

[Epoch 12] Validation accuracy:0.8869. bacc:0.8209. Loss:0.571
best_acc:0.8898


 24%|██▍       | 12/50 [09:09<26:59, 42.61s/it]

[Epoch 13] Training accuracy: 0.9686. Loss: 0.257. LR 0.000001


 26%|██▌       | 13/50 [09:17<26:20, 42.73s/it]

[Epoch 13] Validation accuracy:0.8885. bacc:0.8135. Loss:0.562
best_acc:0.8898


 26%|██▌       | 13/50 [09:53<26:20, 42.73s/it]

[Epoch 14] Training accuracy: 0.9665. Loss: 0.260. LR 0.000001


 28%|██▊       | 14/50 [10:00<25:46, 42.96s/it]

[Epoch 14] Validation accuracy:0.8849. bacc:0.8143. Loss:0.567
best_acc:0.8898


 28%|██▊       | 14/50 [10:39<25:46, 42.96s/it]

[Epoch 15] Training accuracy: 0.9681. Loss: 0.261. LR 0.000001


 30%|███       | 15/50 [10:47<25:42, 44.06s/it]

[Epoch 15] Validation accuracy:0.8885. bacc:0.8165. Loss:0.566
best_acc:0.8898


 30%|███       | 15/50 [11:21<25:42, 44.06s/it]

[Epoch 16] Training accuracy: 0.9682. Loss: 0.258. LR 0.000001


 32%|███▏      | 16/50 [11:28<24:32, 43.32s/it]

[Epoch 16] Validation accuracy:0.8875. bacc:0.8208. Loss:0.565
best_acc:0.8898


 32%|███▏      | 16/50 [12:04<24:32, 43.32s/it]

[Epoch 17] Training accuracy: 0.9688. Loss: 0.250. LR 0.000001


 34%|███▍      | 17/50 [12:11<23:44, 43.17s/it]

[Epoch 17] Validation accuracy:0.8853. bacc:0.8131. Loss:0.565
best_acc:0.8898


 34%|███▍      | 17/50 [12:47<23:44, 43.17s/it]

[Epoch 18] Training accuracy: 0.9681. Loss: 0.260. LR 0.000001


 36%|███▌      | 18/50 [12:55<23:07, 43.36s/it]

[Epoch 18] Validation accuracy:0.8872. bacc:0.8157. Loss:0.563
best_acc:0.8898


 36%|███▌      | 18/50 [13:33<23:07, 43.36s/it]

[Epoch 19] Training accuracy: 0.9696. Loss: 0.253. LR 0.000001


 38%|███▊      | 19/50 [13:40<22:40, 43.88s/it]

[Epoch 19] Validation accuracy:0.8882. bacc:0.8189. Loss:0.569
best_acc:0.8898


 38%|███▊      | 19/50 [14:16<22:40, 43.88s/it]

[Epoch 20] Training accuracy: 0.9700. Loss: 0.253. LR 0.000001


 40%|████      | 20/50 [14:24<21:57, 43.92s/it]

[Epoch 20] Validation accuracy:0.8853. bacc:0.8148. Loss:0.562
best_acc:0.8898


 40%|████      | 20/50 [14:59<21:57, 43.92s/it]

[Epoch 21] Training accuracy: 0.9703. Loss: 0.256. LR 0.000000


 42%|████▏     | 21/50 [15:06<20:56, 43.33s/it]

[Epoch 21] Validation accuracy:0.8866. bacc:0.8125. Loss:0.564
best_acc:0.8898


 42%|████▏     | 21/50 [15:41<20:56, 43.33s/it]

[Epoch 22] Training accuracy: 0.9694. Loss: 0.254. LR 0.000000


 44%|████▍     | 22/50 [15:49<20:08, 43.15s/it]

[Epoch 22] Validation accuracy:0.8872. bacc:0.8139. Loss:0.563
best_acc:0.8898


 44%|████▍     | 22/50 [16:25<20:08, 43.15s/it]

[Epoch 23] Training accuracy: 0.9703. Loss: 0.251. LR 0.000000


 46%|████▌     | 23/50 [16:33<19:31, 43.39s/it]

[Epoch 23] Validation accuracy:0.8875. bacc:0.8106. Loss:0.569
best_acc:0.8898


 46%|████▌     | 23/50 [17:09<19:31, 43.39s/it]

[Epoch 24] Training accuracy: 0.9676. Loss: 0.259. LR 0.000000


 48%|████▊     | 24/50 [17:17<18:53, 43.61s/it]

[Epoch 24] Validation accuracy:0.8885. bacc:0.8166. Loss:0.570
best_acc:0.8898


 48%|████▊     | 24/50 [17:52<18:53, 43.61s/it]

[Epoch 25] Training accuracy: 0.9688. Loss: 0.253. LR 0.000000


 50%|█████     | 25/50 [18:00<18:07, 43.49s/it]

[Epoch 25] Validation accuracy:0.8869. bacc:0.8099. Loss:0.563
best_acc:0.8898


 50%|█████     | 25/50 [18:39<18:07, 43.49s/it]

[Epoch 26] Training accuracy: 0.9700. Loss: 0.254. LR 0.000000


 52%|█████▏    | 26/50 [18:46<17:43, 44.32s/it]

[Epoch 26] Validation accuracy:0.8856. bacc:0.8143. Loss:0.565
best_acc:0.8898


 52%|█████▏    | 26/50 [19:25<17:43, 44.32s/it]

[Epoch 27] Training accuracy: 0.9694. Loss: 0.256. LR 0.000000


 52%|█████▏    | 26/50 [19:32<17:43, 44.32s/it]

[Epoch 27] Validation accuracy:0.8905. bacc:0.8179. Loss:0.566
best_acc:0.8905


 54%|█████▍    | 27/50 [19:33<17:13, 44.95s/it]

Model saved.


 54%|█████▍    | 27/50 [20:11<17:13, 44.95s/it]

[Epoch 28] Training accuracy: 0.9663. Loss: 0.260. LR 0.000000


 56%|█████▌    | 28/50 [20:19<16:35, 45.27s/it]

[Epoch 28] Validation accuracy:0.8889. bacc:0.8145. Loss:0.566
best_acc:0.8905


 56%|█████▌    | 28/50 [20:57<16:35, 45.27s/it]

[Epoch 29] Training accuracy: 0.9697. Loss: 0.255. LR 0.000000


 58%|█████▊    | 29/50 [21:04<15:51, 45.31s/it]

[Epoch 29] Validation accuracy:0.8862. bacc:0.8173. Loss:0.566
best_acc:0.8905


 58%|█████▊    | 29/50 [21:42<15:51, 45.31s/it]

[Epoch 30] Training accuracy: 0.9676. Loss: 0.264. LR 0.000000


 60%|██████    | 30/50 [21:50<15:10, 45.54s/it]

[Epoch 30] Validation accuracy:0.8875. bacc:0.8145. Loss:0.563
best_acc:0.8905


 60%|██████    | 30/50 [22:29<15:10, 45.54s/it]

[Epoch 31] Training accuracy: 0.9720. Loss: 0.248. LR 0.000000


 62%|██████▏   | 31/50 [22:36<14:28, 45.72s/it]

[Epoch 31] Validation accuracy:0.8856. bacc:0.8130. Loss:0.563
best_acc:0.8905


 62%|██████▏   | 31/50 [23:13<14:28, 45.72s/it]

[Epoch 32] Training accuracy: 0.9680. Loss: 0.259. LR 0.000000


 64%|██████▍   | 32/50 [23:20<13:34, 45.26s/it]

[Epoch 32] Validation accuracy:0.8879. bacc:0.8178. Loss:0.568
best_acc:0.8905


 64%|██████▍   | 32/50 [23:58<13:34, 45.26s/it]

[Epoch 33] Training accuracy: 0.9680. Loss: 0.257. LR 0.000000


 66%|██████▌   | 33/50 [24:06<12:49, 45.26s/it]

[Epoch 33] Validation accuracy:0.8862. bacc:0.8101. Loss:0.566
best_acc:0.8905


 66%|██████▌   | 33/50 [24:44<12:49, 45.26s/it]

[Epoch 34] Training accuracy: 0.9698. Loss: 0.254. LR 0.000000


 68%|██████▊   | 34/50 [24:51<12:04, 45.31s/it]

[Epoch 34] Validation accuracy:0.8872. bacc:0.8137. Loss:0.567
best_acc:0.8905


 68%|██████▊   | 34/50 [25:30<12:04, 45.31s/it]

[Epoch 35] Training accuracy: 0.9685. Loss: 0.258. LR 0.000000


 70%|███████   | 35/50 [25:38<11:26, 45.75s/it]

[Epoch 35] Validation accuracy:0.8872. bacc:0.8204. Loss:0.568
best_acc:0.8905


 70%|███████   | 35/50 [26:16<11:26, 45.75s/it]

[Epoch 36] Training accuracy: 0.9677. Loss: 0.262. LR 0.000000


 72%|███████▏  | 36/50 [26:24<10:41, 45.81s/it]

[Epoch 36] Validation accuracy:0.8872. bacc:0.8141. Loss:0.561
best_acc:0.8905


 72%|███████▏  | 36/50 [27:02<10:41, 45.81s/it]

[Epoch 37] Training accuracy: 0.9694. Loss: 0.256. LR 0.000000


 74%|███████▍  | 37/50 [27:10<09:56, 45.90s/it]

[Epoch 37] Validation accuracy:0.8879. bacc:0.8176. Loss:0.573
best_acc:0.8905


 74%|███████▍  | 37/50 [27:48<09:56, 45.90s/it]

[Epoch 38] Training accuracy: 0.9694. Loss: 0.255. LR 0.000000


 76%|███████▌  | 38/50 [27:56<09:10, 45.84s/it]

[Epoch 38] Validation accuracy:0.8889. bacc:0.8189. Loss:0.564
best_acc:0.8905


 76%|███████▌  | 38/50 [28:34<09:10, 45.84s/it]

[Epoch 39] Training accuracy: 0.9682. Loss: 0.257. LR 0.000000


 78%|███████▊  | 39/50 [28:42<08:25, 45.98s/it]

[Epoch 39] Validation accuracy:0.8879. bacc:0.8169. Loss:0.570
best_acc:0.8905


 78%|███████▊  | 39/50 [29:20<08:25, 45.98s/it]

[Epoch 40] Training accuracy: 0.9710. Loss: 0.253. LR 0.000000


 80%|████████  | 40/50 [29:28<07:38, 45.89s/it]

[Epoch 40] Validation accuracy:0.8892. bacc:0.8137. Loss:0.568
best_acc:0.8905


 80%|████████  | 40/50 [30:04<07:38, 45.89s/it]

[Epoch 41] Training accuracy: 0.9694. Loss: 0.255. LR 0.000000


 82%|████████▏ | 41/50 [30:11<06:46, 45.19s/it]

[Epoch 41] Validation accuracy:0.8866. bacc:0.8157. Loss:0.565
best_acc:0.8905


 82%|████████▏ | 41/50 [30:50<06:46, 45.19s/it]

[Epoch 42] Training accuracy: 0.9683. Loss: 0.258. LR 0.000000


 84%|████████▍ | 42/50 [30:58<06:04, 45.54s/it]

[Epoch 42] Validation accuracy:0.8859. bacc:0.8147. Loss:0.565
best_acc:0.8905


 84%|████████▍ | 42/50 [31:36<06:04, 45.54s/it]

[Epoch 43] Training accuracy: 0.9711. Loss: 0.251. LR 0.000000


 86%|████████▌ | 43/50 [31:43<05:19, 45.59s/it]

[Epoch 43] Validation accuracy:0.8875. bacc:0.8162. Loss:0.566
best_acc:0.8905


 86%|████████▌ | 43/50 [32:22<05:19, 45.59s/it]

[Epoch 44] Training accuracy: 0.9693. Loss: 0.256. LR 0.000000


 88%|████████▊ | 44/50 [32:30<04:34, 45.79s/it]

[Epoch 44] Validation accuracy:0.8869. bacc:0.8157. Loss:0.566
best_acc:0.8905


 88%|████████▊ | 44/50 [33:08<04:34, 45.79s/it]

[Epoch 45] Training accuracy: 0.9662. Loss: 0.260. LR 0.000000


 90%|█████████ | 45/50 [33:16<03:49, 45.90s/it]

[Epoch 45] Validation accuracy:0.8885. bacc:0.8166. Loss:0.566
best_acc:0.8905


 90%|█████████ | 45/50 [33:54<03:49, 45.90s/it]

[Epoch 46] Training accuracy: 0.9667. Loss: 0.260. LR 0.000000


 92%|█████████▏| 46/50 [34:01<03:03, 45.76s/it]

[Epoch 46] Validation accuracy:0.8875. bacc:0.8162. Loss:0.566
best_acc:0.8905


 92%|█████████▏| 46/50 [34:39<03:03, 45.76s/it]

[Epoch 47] Training accuracy: 0.9692. Loss: 0.259. LR 0.000000


 94%|█████████▍| 47/50 [34:46<02:16, 45.41s/it]

[Epoch 47] Validation accuracy:0.8898. bacc:0.8186. Loss:0.571
best_acc:0.8905


 94%|█████████▍| 47/50 [35:22<02:16, 45.41s/it]

[Epoch 48] Training accuracy: 0.9711. Loss: 0.252. LR 0.000000


 96%|█████████▌| 48/50 [35:30<01:29, 44.99s/it]

[Epoch 48] Validation accuracy:0.8869. bacc:0.8135. Loss:0.562
best_acc:0.8905


 96%|█████████▌| 48/50 [36:07<01:29, 44.99s/it]

[Epoch 49] Training accuracy: 0.9663. Loss: 0.257. LR 0.000000


 98%|█████████▊| 49/50 [36:15<00:44, 44.99s/it]

[Epoch 49] Validation accuracy:0.8882. bacc:0.8191. Loss:0.565
best_acc:0.8905


 98%|█████████▊| 49/50 [36:53<00:44, 44.99s/it]

[Epoch 50] Training accuracy: 0.9677. Loss: 0.260. LR 0.000000


100%|██████████| 50/50 [37:02<00:00, 44.44s/it]

[Epoch 50] Validation accuracy:0.8872. bacc:0.8211. Loss:0.569
best_acc:0.8905



