In [1]:
import time
from tqdm import tqdm
import sys
import glob
import gc
import os
sys.path.append('./lib_models')
#sys.path.append('')

import pandas as pd
import numpy as np
import scipy as sp
import cv2
from PIL import Image
from matplotlib import pyplot as plt
import sklearn.metrics
import warnings
import pydicom
import dicomsdl
from joblib import Parallel, delayed
import pickle
import gzip
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from multiprocessing import Pool
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn
from torchvision.io import read_image
import segmentation_models_pytorch as smp
import timm
from timm.utils import AverageMeter
from timm.models import resnet
import timm_new

from monai.transforms import Resize
import  monai.transforms as transforms

from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


import wandb

# wandb.login(key = '585f58f321685308f7933861d9dde7488de0970b')
wandb.login(key ="e82da6563d93654dbd624681e5321e560269f670")

2023-10-15 02:54:18.875884: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-15 02:54:19.046849: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[34m[1mwandb[0m: Currently logged in as: [33mehdgnsdl[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/donghun/.netrc


True

In [2]:
!nvidia-smi

Sun Oct 15 02:54:21 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.113.01             Driver Version: 535.113.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off | 00000000:01:00.0 Off |                  Off |
| 31%   38C    P8              24W / 450W |      6MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off | 00000000:09:00.0 Off |  

# Parameters

In [3]:
backbone = 'timm/resnet10t.c3_in1k'
# backbone = 'timm/resnetrs50.tf_in1k'


IS_WANDB = True
PROJECT_NAME = 'RSNA_ABTD'
GROUP_NAME= 'preprocessing_test'
RUN_NAME=   f'{backbone}_2nd_std_opt-seg0.938_10t'

if not IS_WANDB:
    PROJECT_NAME = 'Dummy_Project'

BASE_PATH  = './kaggle/input/rsna-2023-abdominal-trauma-detection'
TRAIN_PATH = f'{BASE_PATH}/train_images'
DATA_PATH = f'{BASE_PATH}/3d_preprocessed'

seg_inference_dir = f'{BASE_PATH}/seg_infer_results'
cropped_img_dir   = f'{BASE_PATH}/3d_preprocessed_crop_ratio'

if not os.path.isdir(DATA_PATH):
    os.mkdir(DATA_PATH)

# RESOL = 128
UP_RESOL = 128
N_CHANNELS = 6
BATCH_SIZE = 8
ACCUM_STEPS = 4
N_WORKERS  = 8
# N_WORKERS  = 16
LR = 2e-4
N_EPOCHS = 300
EARLY_STOP_COUNT = 50
N_FOLDS  = 5
N_PREPROCESS_CHUNKS = 12
PCT_START = 0.1
n_blocks = 4
drop_rate = 0.2
drop_path_rate = 0.2
p_mixup = 0.0



DROP_REGION= {'HOLES': [3, 20],
                'SIZE': [5, 20],
                'PROB': 0.5,
                'FILL': (-3, 3)}

wandb_config = {
    'UP_RESOL': UP_RESOL,
    'BACKBONE': backbone,
    'N_CHANNELS': N_CHANNELS,
    'N_EPOCHS': N_EPOCHS,
    'N_FOLDS': N_FOLDS,
    'EARLY_STOP_COUNT': EARLY_STOP_COUNT,
    'BATCH_SIZE': BATCH_SIZE,    
    'LR': LR,
    'N_EPOCHS': N_EPOCHS,
    'DROP_RATE': drop_rate,
    'DROP_PATH_RATE': drop_path_rate,
    'MIXUP_RATE': p_mixup,
    'DROP_REGION': DROP_REGION,
    'PCT_START': PCT_START
}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [4]:
# Mask related parameters
# Order 0: Bowel, 1: left kidney, 2: right kidney, 3: liver, 4: spleen

chan_keys = ['bowel', 'left_kidney', 'right_kidney', 'liver', 'spleen', 'total']
chan_dict = {}
for i in range(0, 6):
    chan_dict[i] = chan_keys[i]

train_meta_df = pd.read_csv(f'{BASE_PATH}/train_meta.csv')
np.unique(train_meta_df['fold'].to_numpy(), return_counts = True)

(array([0, 1, 2, 3, 4]), array([929, 947, 948, 951, 936]))

In [5]:
def compress(name, data):
    with gzip.open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress(name):
    with gzip.open(name, 'rb') as f:
        data = pickle.load(f)
    return data


def compress_fast(name, data):
    with open(name, 'wb') as f:
        pickle.dump(data, f)

def decompress_fast(name):
    with open(name, 'rb') as f:
        data = pickle.load(f)
    return data

# Model

In [6]:
def convert_3d(module):

    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig
            
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        module_output.weight = torch.nn.Parameter(module.weight.unsqueeze(-1).repeat(1,1,1,1,module.kernel_size[0]))

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )
    del module

    return module_output

In [7]:
class Timm3DModel(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModel, self).__init__()
        self.n_labels = n_labels
        self.encoder = timm_new.create_model(
            backbone,
            in_chans=n_channels,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        g = self.encoder(torch.rand(1, n_channels, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]

        self.avgpool = nn.AvgPool2d(5, 4, 2)
        
        [_.shape[1] for _ in g]
        self.convs1x1 = nn.ModuleList()    
        self.batchnorms = nn.ModuleList()    
        self.batchnorms13 = nn.ModuleList()
        for i in range(0, len(g)):
            self.convs1x1.append(nn.Conv2d(g[i].shape[1], self.n_labels, 1))
        del g
        gc.collect()
        
    def forward(self,x):
        batch_size = x.shape[0]
        global_features = self.encoder(x)[:n_blocks]        
        for i in range(0, len(global_features)):
            global_features[i] = self.convs1x1[i](global_features[i])
        return global_features
    
    
class Timm3DModelClassifier(nn.Module):
    def __init__(self, backbone, n_channels, n_labels, segtype='unet', pretrained=False):
        super(Timm3DModelClassifier, self).__init__()
        self.model_3d = Timm3DModel(backbone, n_channels, n_labels, segtype, pretrained)
        self.model_3d = convert_3d(self.model_3d)
        self.n_channels = n_channels
        self.n_labels = n_labels                        
        
    def forward(self, x):
        batch_size = x.shape[0]
        x = self.model_3d(x)
        pooled_features = []
        for i in range(0, len(x)):
            pooled_features.append(torch.reshape(torch.mean(x[i], dim = (2, 3, 4)), (batch_size, self.n_labels, 1)))
        pooled_features = torch.cat(pooled_features, dim=2)     
        labels = torch.mean(pooled_features, dim = 2)
        return labels

In [8]:
class AbdominalClassifier(nn.Module):
    def __init__(self, device = DEVICE):
        super().__init__()
        self.device = device
        
        self.model3d_bowel        = Timm3DModelClassifier(backbone, 1, 2)      
        self.model3d_extrav       = Timm3DModelClassifier(backbone, 1, 2)
        self.model3d_kidney_left  = Timm3DModelClassifier(backbone, 1, 3)
        self.model3d_kidney_right = Timm3DModelClassifier(backbone, 1, 3)
        self.model3d_liver        = Timm3DModelClassifier(backbone, 1, 3)
        self.model3d_spleen       = Timm3DModelClassifier(backbone, 1, 3)
        
        self.flatten  = nn.Flatten()
        self.dropout  = nn.Dropout(p=0.5)
        self.softmax  = nn.Softmax(dim=1)        
        self.maxpool  = nn.MaxPool1d(5, 1)
        
    def forward(self, x_bowel, x_kidney_left, x_kidney_right, x_liver, x_spleen, x_total):
        bowel_label        = self.model3d_bowel(x_bowel)
        extrav_label       = self.model3d_extrav(x_total)
        kidney_label_left  = self.model3d_kidney_left(x_kidney_left)
        kidney_label_right = self.model3d_kidney_right(x_kidney_right)
        kidney_label       = (kidney_label_left + kidney_label_right)/2
        liver_label        = self.model3d_liver(x_liver)
        spleen_label       = self.model3d_spleen(x_spleen)
        
        
        labels = torch.cat([bowel_label, extrav_label, kidney_label, liver_label, spleen_label], dim = 1)
        
        bowel_soft = self.softmax(bowel_label)
        extrav_soft = self.softmax(extrav_label)
        kidney_soft = self.softmax(kidney_label)
        liver_soft = self.softmax(liver_label)
        spleen_soft = self.softmax(spleen_label)

        any_in = torch.cat([1-bowel_soft[:,0:1], 1-extrav_soft[:,0:1], 
                            1-kidney_soft[:,0:1], 1-liver_soft[:,0:1], 1-spleen_soft[:,0:1]], dim = 1) 
        any_in = self.maxpool(any_in)
        any_not_in = 1-any_in
        any_in = torch.cat([any_not_in, any_in], dim = 1)

        any_in = torch.log(any_in + 1e-6)  # 1e-6은 0을 처리하기 위한 작은 값
        return labels, any_in
    

In [9]:
model = AbdominalClassifier()

def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

print(get_n_params(model))
del model
gc.collect()

86478624


0

# Metric & Loss

In [10]:
weights = np.ones(2)
weights[1] = 2
crit_bowel  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
weights[1] = 6
crit_extrav = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_any = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

weights = np.ones((3))
weights[1] = 2
weights[2] = 4
crit_kidney = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_liver  = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))
crit_spleen = nn.CrossEntropyLoss(weight = torch.from_numpy(weights).to(DEVICE))

In [11]:
def normalize_to_one(tensor):
    norm = torch.sum(tensor, 1)
    for i in range(0, tensor.shape[1]):
        tensor[:,i]/=norm
    return tensor

def apply_softmax_to_labels(X_out):
    softmax = nn.Softmax(dim=1)

    X_out[:,:2]    = normalize_to_one(softmax(X_out[:,:2]))
    X_out[:,2:4]   = normalize_to_one(softmax(X_out[:,2:4]))
    X_out[:,4:7]   = normalize_to_one(softmax(X_out[:,4:7]))
    X_out[:,7:10]  = normalize_to_one(softmax(X_out[:,7:10]))
    X_out[:,10:13] = normalize_to_one(softmax(X_out[:,10:13]))

    return X_out

def calculate_score(X_outs, ys, step = 'train'):
    X_outs = X_outs.astype(np.float64)
    ys     = ys.astype(np.float64)

    isnan_x = np.isnan(X_outs).astype(int)
    isnan_y = np.isnan(ys).astype(int)
    
    if(np.max(isnan_x)>0):
        print('xnan')
    if(np.max(isnan_y)>0):
        print('ynan')
        
    X_outs[:, 13:15] = nn.Softmax(dim=1)(torch.from_numpy(X_outs[:, 13:15])).numpy()
    bowel_weights  =  ys[:,0] + 2*ys[:,1]
    extrav_weights = ys[:,2] + 6*ys[:,3]
    kidney_weights = ys[:,4] + 2*ys[:,5] + 4*ys[:,6]
    liver_weights  = ys[:,7] + 2*ys[:,8] + 4*ys[:,9]
    spleen_weights = ys[:,10] + 2*ys[:,11] + 4*ys[:,12]
    any_in_weights = ys[:,13] + 6*ys[:,14]

    bowel_loss  = sklearn.metrics.log_loss(ys[:,:2], X_outs[:,:2], sample_weight = bowel_weights.astype(np.float64))
    extrav_loss = sklearn.metrics.log_loss(ys[:,2:4], X_outs[:,2:4], sample_weight = extrav_weights.astype(np.float64))
    kidney_loss = sklearn.metrics.log_loss(ys[:,4:7], X_outs[:,4:7], sample_weight = kidney_weights.astype(np.float64))
    liver_loss  = sklearn.metrics.log_loss(ys[:,7:10], X_outs[:,7:10], sample_weight = liver_weights.astype(np.float64))
    spleen_loss = sklearn.metrics.log_loss(ys[:,10:13], X_outs[:,10:13], sample_weight = spleen_weights.astype(np.float64))
    any_in_loss = sklearn.metrics.log_loss(ys[:,13:15], X_outs[:,13:15], sample_weight =  any_in_weights.astype(np.float64))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6

    losses= {f'{step}_bowel_metric': bowel_loss, f'{step}_extrav_metric': extrav_loss, f'{step}_kidney_metric': kidney_loss,
             f'{step}_liver_metric': liver_loss, f'{step}_spleen_metric': spleen_loss, f'{step}_any_in_metric': any_in_loss,
             f'{step}_avg_metric': avg_loss}

    wandb.log(losses)
    return avg_loss

def calculate_loss(X_out, X_any, y):
    batch_size = X_out.shape[0]
    bowel_loss  = crit_bowel(X_out[:,:2], y[:,:2])
    extrav_loss = crit_extrav(X_out[:,2:4], y[:,2:4])
    kidney_loss = crit_kidney(X_out[:,4:7], y[:,4:7])
    liver_loss  = crit_liver(X_out[:,7:10], y[:,7:10])
    spleen_loss = crit_spleen(X_out[:,10:13], y[:,10:13])
    any_in_loss = crit_any(X_any,  torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1))
    
    avg_loss = (bowel_loss + extrav_loss + kidney_loss + liver_loss + spleen_loss + any_in_loss)/6
    return bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss

# Augmentations

In [12]:
def mixup(inputs, truth, clip=[0, 1]):
    indices = torch.randperm(inputs.size(0))
    shuffled_input = inputs[indices]
    shuffled_labels = truth[indices]

    lam = np.random.uniform(clip[0], clip[1])
    inputs = inputs * lam + shuffled_input * (1 - lam)
    return inputs, truth, shuffled_labels, lam

# transforms_train = transforms.Compose([
#     transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=0),    
#     transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=1),
#     transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=2),    
#     transforms.RandGridDistortiond(keys=chan_keys, prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),    
# ])

transforms_train = transforms.Compose([
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=0),    
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=1),
    transforms.RandFlipd(keys=chan_keys, prob=0.5, spatial_axis=2),    
    transforms.RandGridDistortiond(keys=chan_keys, prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest"),
    # transforms.RandAffined(keys=chan_keys, prob=0.5, translate_range=(5, 5), rotate_range=(np.pi/6, np.pi/6), scale_range=(0.15, 0.15)),
    # transforms.RandZoomd(keys=chan_keys, prob=0.5, min_zoom=0.9, max_zoom=1.1),
    # transforms.RandGaussianNoised(keys=chan_keys, prob=0.5, mean=0.0, std=0.1),
    transforms.RandAdjustContrastd(keys=chan_keys, prob=0.5, gamma=(0.9, 1.1)),
    # transforms.RandElasticTransformd(keys=chan_keys, prob=0.5, sigma=30, alpha=200, order=1),
    # transforms.RandGaussianSmoothd(keys=chan_keys, prob=0.5, sigma_x=(0.5, 1.5), sigma_y=(0.5, 1.5), sigma_z=(0.5, 1.5)),
    # transforms.NormalizeIntensityd(keys=chan_keys, subtrahend=127.5, divisor=127.5), # Normalization
])

remain_transforms_train = transforms.Compose([
    #transforms.RandCoarseDropout(holes = DROP_REGION['HOLES'][0], max_holes = DROP_REGION['HOLES'][1],
    #                        spatial_size = DROP_REGION['SIZE'][0]*np.ones(3, int), max_spatial_size =DROP_REGION['SIZE'][1]*np.ones(3, int), 
    #                        prob = DROP_REGION['PROB'], 
    #                        fill_value = DROP_REGION['FILL'])
])



transforms_common_preprocessing = transforms.Compose([
    #transforms.HistogramNormalize(num_bins = 256, min = 0, max = 255)
])

# Dataset

In [13]:
class AbdominalCTDataset(Dataset):
    def __init__(self, meta_df, is_train = True, transform_set = None, remain_transforms_set = None):
        self.meta_df = meta_df
        self.is_train = is_train
        self.transform_set = transform_set
        self.remain_transforms_set = remain_transforms_set
        self.data_3ds = []        
        for i in tqdm(range(0, len(self.meta_df))):
            tmp_data_3ds = {}
            base_name = self.meta_df.iloc[i]['cropped_path']            
            for j in range(0, 6):
                tmp_data_3d = decompress_fast(f'{base_name}_{j}').unsqueeze(0)
                #tmp_data_3d = torch.from_numpy(tmp_data_3d)
                tmp_data_3ds[chan_dict[j]] = tmp_data_3d            
            self.data_3ds.append(tmp_data_3ds)

    def __len__(self):
        return len(self.meta_df)
    
    def __getitem__(self, idx):
        row = self.meta_df.iloc[idx]
        label = row[['bowel_healthy','bowel_injury',
                    'extravasation_healthy','extravasation_injury',
                    'kidney_healthy','kidney_low','kidney_high',
                    'liver_healthy','liver_low','liver_high',
                    'spleen_healthy','spleen_low','spleen_high', 'any_injury']]
        
        data_3d = self.data_3ds[idx].copy()
        
        if self.is_train:
            if self.transform_set is not None:
                data_3d = self.transform_set(data_3d)

            if self.remain_transforms_set is not None:   
                for i in range(0, 6):
                    data_3d[chan_dict[i]] = self.remain_transforms_set(data_3d[chan_dict[i]])
        
        label = label.to_numpy().astype(np.float32)                    
        label = torch.from_numpy(label)
                    
        return data_3d['bowel'], data_3d['left_kidney'], data_3d['right_kidney'], \
                data_3d['liver'], data_3d['spleen'], data_3d['total'], label        


In [14]:
#data_3d= torch.rand((6, 128, 128, 128))*0.5
#data_3d = remain_transforms_train(data_3d)
#torch.max(data_3d)
#print(data_3d)

# Train loop

In [15]:
def train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale):
    train_meters = {'loss': AverageMeter()}
    model.train()
    X_outs=[]
    ys=[]
    accum_counter = 0
    counter = 0
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in train_loader:
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)
        current_lr = float(scheduler.get_last_lr()[0])
        
        batch_size = X_bowel.shape[0]
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):  
            X_out, X_any  = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)
            bowel_loss, extrav_loss, kidney_loss, liver_loss, spleen_loss, any_in_loss, avg_loss = calculate_loss(X_out, X_any, y)
                
            step = 'train'
            wandb.log({ 'lr': current_lr,
                        f'{step}_bowel_loss': bowel_loss.item(),
                        f'{step}_extrav_loss': extrav_loss.item(),
                        f'{step}_kidney_loss': kidney_loss.item(),
                        f'{step}_liver_loss': liver_loss.item(),
                        f'{step}_spleen_loss': spleen_loss.item(),
                        f'{step}_any_loss': any_in_loss.item(),
                        f'{step}_avg_loss': avg_loss.item()
                        })
            
            scaler.scale(avg_loss/accum_scale[accum_counter]).backward()
            if(counter==accum_points[accum_counter]):
                scaler.step(optimizer)
                scheduler.step()
                scaler.update()    
                accum_counter+=1                
        counter+=1                   

        #Metric calculation
        y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)    
        X_out = apply_softmax_to_labels(X_out).detach().to('cpu').numpy()
        X_any = X_any.detach().to('cpu').numpy()
        X_out = np.hstack([X_out, X_any])
        X_outs.append(X_out)

        y     = y.to('cpu').numpy()[:,:-1]
        y_any = y_any.to('cpu').numpy()
        y     = np.hstack([y, y_any])
        ys.append(y)

        trn_loss = avg_loss.item()      
        train_meters['loss'].update(trn_loss, n=batch_size)     
        #pbar.set_description(f'Train loss: {trn_loss}')   
        
        
    print('Epoch {:d} / trn/loss={:.4f}'.format(epoch+1, train_meters['loss'].avg))    

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'train')                 
    print('Epoch {:d} / train/metric={:.4f}'.format(epoch+1, metric))   

    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()
    torch.cuda.empty_cache()    
    return scheduler, scaler, optimizer


def valid_func(model, valid_loader, epoch):
    X_outs=[]
    ys=[]
    model.eval()
    for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
        batch_size = y.shape[0]
        X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
        X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
        y = y.to(DEVICE)           
        with torch.cuda.amp.autocast(enabled=True):                
            with torch.no_grad():                 
                X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
                y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
                X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

                X_any = X_any.to('cpu').numpy()
                X_out = np.hstack([X_out, X_any])
                X_outs.append(X_out)

                y     = y.to('cpu').numpy()[:,:-1]
                y_any = y_any.to('cpu').numpy()
                y     = np.hstack([y, y_any])
                ys.append(y)

    X_outs = np.vstack(X_outs) 
    ys     = np.vstack(ys)
    metric = calculate_score(X_outs, ys, 'valid')                
    print('Epoch {:d} / val/metric={:.4f}'.format(epoch+1, metric))           
    
    del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_outs, y, ys, X_any
    gc.collect()        
    torch.cuda.empty_cache()   
    return metric 

In [None]:
model = AbdominalClassifier()
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs")
    model = nn.DataParallel(model)  # This line enables multi-GPU training
model.to(DEVICE)

wandb.init(
    config = wandb_config,
    project= PROJECT_NAME,
    group  = GROUP_NAME,
    name   = RUN_NAME,
    dir    = BASE_PATH)

backbone = backbone.replace('/', '_')

if __name__ == '__main__':
    train_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']!=3], is_train = True, transform_set  = transforms_train, 
                                        remain_transforms_set = remain_transforms_train)
    valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['fold']==3], is_train = False, transform_set = None,
                                        remain_transforms_set = None)
    
    train_loader = DataLoader(dataset = train_dataset, shuffle = True, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)

    valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                            num_workers = N_WORKERS, drop_last = False)     
    
    ttl_iters = N_EPOCHS * len(train_loader)
    
    #gradient accumulation for stability of the training
    accum_len = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    accum_points = np.zeros(accum_len, int)
    accum_scale  = np.zeros(accum_len, int)
    
    prev_step = -1
    for i in range(0, accum_len):
        accum_points[i] = min(prev_step+ACCUM_STEPS, len(train_loader)-1)
        accum_scale[i]  = accum_points[i] - prev_step
        prev_step = accum_points[i]

    #Scheduler & optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr = LR)
    n_batch_iters = int(np.ceil(len(train_loader)/ACCUM_STEPS)+0.001)
    # scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, pct_start= PCT_START,
                                                    steps_per_epoch= n_batch_iters, epochs = N_EPOCHS)
    

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    val_metrics = np.ones(N_EPOCHS)*100

    gc.collect()

    for epoch in tqdm(range(0, N_EPOCHS), leave = False):     
        
        scheduler, scaler, optimizer = train_func(model, train_loader, scaler, scheduler, optimizer, epoch, accum_points, accum_scale)
        metric                       = valid_func(model, valid_loader, epoch)
        
        gc.collect()
        torch.cuda.empty_cache()    
        #Save the best model    
        if(metric < np.min(val_metrics)):
            try:
                os.makedirs(f'{BASE_PATH}/weights')
            except:
                a = 1
            best_metric = metric
            print(f'Best val_metric {best_metric} at epoch {epoch+1}!')
            torch.save(model.state_dict(), f'{BASE_PATH}/weights/{backbone}_lr{LR}_epochs_{N_EPOCHS}_resol{UP_RESOL}_batch{BATCH_SIZE*ACCUM_STEPS}_fold1.pt')    
            not_improve_counter=0
            val_metrics[epoch] = metric
            continue                    
        val_metrics[epoch] = metric                        
        
        #Early stopping
        not_improve_counter+=1
        if(not_improve_counter == EARLY_STOP_COUNT):
            print(f'Not improved for {not_improve_counter} epochs, terminate the train')
            break
wandb.log({'best_total_log_loss': best_metric})
wandb.finish()

Using 2 GPUs


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3760/3760 [01:27<00:00, 43.16it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 951/951 [00:22<00:00, 42.73it/s]
  0%|                                                                                                           | 0/300 [00:00<?, ?it/s]

Epoch 1 / trn/loss=1.1546
Epoch 1 / train/metric=0.8901




Epoch 1 / val/metric=0.8652
Best val_metric 0.8651708001075941 at epoch 1!


  0%|▎                                                                                              | 1/300 [04:24<21:56:29, 264.18s/it]

Epoch 2 / trn/loss=1.0942
Epoch 2 / train/metric=0.8363




Epoch 2 / val/metric=0.8099
Best val_metric 0.8098588910064094 at epoch 2!


  1%|▋                                                                                              | 2/300 [08:47<21:48:26, 263.45s/it]

Epoch 3 / trn/loss=1.0376
Epoch 3 / train/metric=0.7855




Epoch 3 / val/metric=0.7566
Best val_metric 0.7565811091200624 at epoch 3!


  1%|▉                                                                                              | 3/300 [13:10<21:44:12, 263.48s/it]

Epoch 4 / trn/loss=0.9862
Epoch 4 / train/metric=0.7387




Epoch 4 / val/metric=0.7107
Best val_metric 0.7106522139702341 at epoch 4!


  1%|█▎                                                                                             | 4/300 [17:32<21:36:50, 262.87s/it]

Epoch 5 / trn/loss=0.9382
Epoch 5 / train/metric=0.6948




Epoch 5 / val/metric=0.6690
Best val_metric 0.669018545854443 at epoch 5!


  2%|█▌                                                                                             | 5/300 [21:54<21:31:36, 262.70s/it]

Epoch 6 / trn/loss=0.8973
Epoch 6 / train/metric=0.6569




Epoch 6 / val/metric=0.6319
Best val_metric 0.6318531958384861 at epoch 6!


  2%|█▉                                                                                             | 6/300 [26:16<21:25:42, 262.39s/it]

Epoch 7 / trn/loss=0.8648
Epoch 7 / train/metric=0.6264




Epoch 7 / val/metric=0.6074
Best val_metric 0.6074150297833403 at epoch 7!


  2%|██▏                                                                                            | 7/300 [30:40<21:23:29, 262.83s/it]

Epoch 8 / trn/loss=0.8427
Epoch 8 / train/metric=0.6065




Epoch 8 / val/metric=0.5927
Best val_metric 0.5926639128500232 at epoch 8!


  3%|██▌                                                                                            | 8/300 [35:01<21:16:51, 262.37s/it]

Epoch 9 / trn/loss=0.8294
Epoch 9 / train/metric=0.5929




Epoch 9 / val/metric=0.5786
Best val_metric 0.5785822024828636 at epoch 9!


  3%|██▊                                                                                            | 9/300 [39:24<21:13:03, 262.49s/it]

Epoch 10 / trn/loss=0.8175
Epoch 10 / train/metric=0.5824




Epoch 10 / val/metric=0.5725
Best val_metric 0.5725296776529669 at epoch 10!


  3%|███▏                                                                                          | 10/300 [43:47<21:08:37, 262.47s/it]

Epoch 11 / trn/loss=0.8088
Epoch 11 / train/metric=0.5749




Epoch 11 / val/metric=0.5676
Best val_metric 0.5675812783434294 at epoch 11!


  4%|███▍                                                                                          | 11/300 [48:08<21:02:54, 262.19s/it]

Epoch 12 / trn/loss=0.8056
Epoch 12 / train/metric=0.5708




Epoch 12 / val/metric=0.5624
Best val_metric 0.5623766491398288 at epoch 12!


  4%|███▊                                                                                          | 12/300 [52:32<21:01:13, 262.75s/it]

Epoch 13 / trn/loss=0.8043
Epoch 13 / train/metric=0.5711




Epoch 13 / val/metric=0.5608
Best val_metric 0.5607928079615071 at epoch 13!


  4%|████                                                                                          | 13/300 [56:54<20:55:25, 262.46s/it]

Epoch 14 / trn/loss=0.8000
Epoch 14 / train/metric=0.5663




Epoch 14 / val/metric=0.5580
Best val_metric 0.5579714389506719 at epoch 14!


  5%|████▎                                                                                       | 14/300 [1:01:17<20:51:19, 262.51s/it]

Epoch 15 / trn/loss=0.8023
Epoch 15 / train/metric=0.5666




Epoch 15 / val/metric=0.5569
Best val_metric 0.5568830744165502 at epoch 15!


  5%|████▌                                                                                       | 15/300 [1:05:38<20:45:53, 262.29s/it]

Epoch 16 / trn/loss=0.7941
Epoch 16 / train/metric=0.5598




Epoch 16 / val/metric=0.5592


  5%|████▉                                                                                       | 16/300 [1:10:00<20:40:25, 262.06s/it]

Epoch 17 / trn/loss=0.7968
Epoch 17 / train/metric=0.5642




Epoch 17 / val/metric=0.5545
Best val_metric 0.5544677202445699 at epoch 17!


  6%|█████▏                                                                                      | 17/300 [1:14:21<20:34:57, 261.83s/it]

Epoch 18 / trn/loss=0.7969
Epoch 18 / train/metric=0.5659




Epoch 18 / val/metric=0.5669


  6%|█████▌                                                                                      | 18/300 [1:18:44<20:32:24, 262.21s/it]

Epoch 19 / trn/loss=0.7905
Epoch 19 / train/metric=0.5579




Epoch 19 / val/metric=0.5575


  6%|█████▊                                                                                      | 19/300 [1:23:07<20:28:13, 262.26s/it]

Epoch 20 / trn/loss=0.7903
Epoch 20 / train/metric=0.5585




Epoch 20 / val/metric=0.5780


  7%|██████▏                                                                                     | 20/300 [1:27:29<20:23:55, 262.27s/it]

Epoch 21 / trn/loss=0.7865
Epoch 21 / train/metric=0.5560




Epoch 21 / val/metric=0.5641


  7%|██████▍                                                                                     | 21/300 [1:31:51<20:19:26, 262.25s/it]

Epoch 22 / trn/loss=0.7857
Epoch 22 / train/metric=0.5559




Epoch 22 / val/metric=0.5670


  7%|██████▋                                                                                     | 22/300 [1:36:13<20:15:00, 262.23s/it]

Epoch 23 / trn/loss=0.7863
Epoch 23 / train/metric=0.5545




Epoch 23 / val/metric=0.5493
Best val_metric 0.5493373358165324 at epoch 23!


  8%|███████                                                                                     | 23/300 [1:40:35<20:10:02, 262.10s/it]

Epoch 24 / trn/loss=0.7851
Epoch 24 / train/metric=0.5525




Epoch 24 / val/metric=0.5514


  8%|███████▎                                                                                    | 24/300 [1:44:57<20:05:34, 262.08s/it]

Epoch 25 / trn/loss=0.7804
Epoch 25 / train/metric=0.5476




Epoch 25 / val/metric=0.5487
Best val_metric 0.5487377602602158 at epoch 25!


  8%|███████▋                                                                                    | 25/300 [1:49:20<20:02:16, 262.31s/it]

Epoch 26 / trn/loss=0.7809
Epoch 26 / train/metric=0.5475




Epoch 26 / val/metric=0.5560


  9%|███████▉                                                                                    | 26/300 [1:53:42<19:57:11, 262.16s/it]

Epoch 27 / trn/loss=0.7829
Epoch 27 / train/metric=0.5502




Epoch 27 / val/metric=0.5769


  9%|████████▎                                                                                   | 27/300 [1:58:04<19:52:25, 262.07s/it]

Epoch 28 / trn/loss=0.7806
Epoch 28 / train/metric=0.5480




Epoch 28 / val/metric=0.5548


  9%|████████▌                                                                                   | 28/300 [2:02:26<19:48:16, 262.12s/it]

Epoch 29 / trn/loss=0.7781
Epoch 29 / train/metric=0.5477




Epoch 29 / val/metric=0.5554


 10%|████████▉                                                                                   | 29/300 [2:06:48<19:43:26, 262.02s/it]

Epoch 30 / trn/loss=0.7751
Epoch 30 / train/metric=0.5449




Epoch 30 / val/metric=0.5518


 10%|█████████▏                                                                                  | 30/300 [2:11:10<19:38:48, 261.96s/it]

Epoch 31 / trn/loss=0.7737
Epoch 31 / train/metric=0.5442




Epoch 31 / val/metric=0.5550


 10%|█████████▌                                                                                  | 31/300 [2:15:30<19:32:44, 261.58s/it]

Epoch 32 / trn/loss=0.7725
Epoch 32 / train/metric=0.5422




Epoch 32 / val/metric=0.5542


 11%|█████████▊                                                                                  | 32/300 [2:19:52<19:29:14, 261.77s/it]

Epoch 33 / trn/loss=0.7733
Epoch 33 / train/metric=0.5418




Epoch 33 / val/metric=0.5591


 11%|██████████                                                                                  | 33/300 [2:24:15<19:25:53, 262.00s/it]

Epoch 34 / trn/loss=0.7693
Epoch 34 / train/metric=0.5416




Epoch 34 / val/metric=0.5850


 11%|██████████▍                                                                                 | 34/300 [2:28:37<19:21:53, 262.08s/it]

Epoch 35 / trn/loss=0.7750
Epoch 35 / train/metric=0.5454




Epoch 35 / val/metric=0.5504


 12%|██████████▋                                                                                 | 35/300 [2:32:59<19:16:54, 261.94s/it]

Epoch 36 / trn/loss=0.7673
Epoch 36 / train/metric=0.5380




Epoch 36 / val/metric=0.5590


 12%|███████████                                                                                 | 36/300 [2:37:20<19:12:06, 261.84s/it]

Epoch 37 / trn/loss=0.7721
Epoch 37 / train/metric=0.5405




Epoch 37 / val/metric=0.5441
Best val_metric 0.5440717888251086 at epoch 37!


 12%|███████████▎                                                                                | 37/300 [2:41:43<19:08:18, 261.97s/it]

Epoch 38 / trn/loss=0.7624
Epoch 38 / train/metric=0.5362




Epoch 38 / val/metric=0.5559


 13%|███████████▋                                                                                | 38/300 [2:46:06<19:05:26, 262.32s/it]

Epoch 39 / trn/loss=0.7674
Epoch 39 / train/metric=0.5387




Epoch 39 / val/metric=0.5465


 13%|███████████▉                                                                                | 39/300 [2:50:27<19:00:11, 262.11s/it]

Epoch 40 / trn/loss=0.7623
Epoch 40 / train/metric=0.5345




Epoch 40 / val/metric=0.5741


 13%|████████████▎                                                                               | 40/300 [2:54:49<18:54:52, 261.89s/it]

Epoch 41 / trn/loss=0.7604
Epoch 41 / train/metric=0.5352




Epoch 41 / val/metric=0.5514


 14%|████████████▌                                                                               | 41/300 [2:59:11<18:50:14, 261.83s/it]

Epoch 42 / trn/loss=0.7578
Epoch 42 / train/metric=0.5320




Epoch 42 / val/metric=0.5482


 14%|████████████▉                                                                               | 42/300 [3:03:33<18:46:09, 261.90s/it]

Epoch 43 / trn/loss=0.7545
Epoch 43 / train/metric=0.5289




Epoch 43 / val/metric=0.5412
Best val_metric 0.5411532898091064 at epoch 43!


 14%|█████████████▏                                                                              | 43/300 [3:07:55<18:42:17, 262.01s/it]

Epoch 44 / trn/loss=0.7511
Epoch 44 / train/metric=0.5285




Epoch 44 / val/metric=0.5550


 15%|█████████████▍                                                                              | 44/300 [3:12:16<18:37:12, 261.84s/it]

Epoch 45 / trn/loss=0.7552
Epoch 45 / train/metric=0.5283




Epoch 45 / val/metric=0.5499


 15%|█████████████▊                                                                              | 45/300 [3:16:39<18:33:21, 261.97s/it]

Epoch 46 / trn/loss=0.7523
Epoch 46 / train/metric=0.5278




Epoch 46 / val/metric=0.5393
Best val_metric 0.5393036244538365 at epoch 46!


 15%|██████████████                                                                              | 46/300 [3:21:01<18:29:18, 262.04s/it]

Epoch 47 / trn/loss=0.7510
Epoch 47 / train/metric=0.5267




Epoch 47 / val/metric=0.5442


 16%|██████████████▍                                                                             | 47/300 [3:25:24<18:26:51, 262.50s/it]

Epoch 48 / trn/loss=0.7510
Epoch 48 / train/metric=0.5274




Epoch 48 / val/metric=0.5542


 16%|██████████████▋                                                                             | 48/300 [3:29:46<18:21:32, 262.27s/it]

Epoch 49 / trn/loss=0.7521
Epoch 49 / train/metric=0.5259




Epoch 49 / val/metric=0.5326
Best val_metric 0.5326098512121 at epoch 49!


 16%|███████████████                                                                             | 49/300 [3:34:08<18:16:46, 262.18s/it]

Epoch 50 / trn/loss=0.7537
Epoch 50 / train/metric=0.5279




Epoch 50 / val/metric=0.5455


 17%|███████████████▎                                                                            | 50/300 [3:38:29<18:11:13, 261.89s/it]

Epoch 51 / trn/loss=0.7479
Epoch 51 / train/metric=0.5257




Epoch 51 / val/metric=0.5442


 17%|███████████████▋                                                                            | 51/300 [3:42:52<18:07:26, 262.04s/it]

Epoch 52 / trn/loss=0.7510
Epoch 52 / train/metric=0.5269




Epoch 52 / val/metric=0.5714


 17%|███████████████▉                                                                            | 52/300 [3:47:14<18:03:22, 262.11s/it]

Epoch 53 / trn/loss=0.7454
Epoch 53 / train/metric=0.5233




Epoch 53 / val/metric=0.5607


 18%|████████████████▎                                                                           | 53/300 [3:51:38<18:01:12, 262.64s/it]

Epoch 54 / trn/loss=0.7499
Epoch 54 / train/metric=0.5250




Epoch 54 / val/metric=0.5596


 18%|████████████████▌                                                                           | 54/300 [3:56:00<17:55:43, 262.37s/it]

Epoch 55 / trn/loss=0.7452
Epoch 55 / train/metric=0.5253




Epoch 55 / val/metric=0.5380


 18%|████████████████▊                                                                           | 55/300 [4:00:21<17:50:46, 262.23s/it]

Epoch 56 / trn/loss=0.7450
Epoch 56 / train/metric=0.5225




Epoch 56 / val/metric=0.5572


 19%|█████████████████▏                                                                          | 56/300 [4:04:44<17:46:13, 262.19s/it]

Epoch 57 / trn/loss=0.7471
Epoch 57 / train/metric=0.5244




Epoch 57 / val/metric=0.5538


 19%|█████████████████▍                                                                          | 57/300 [4:09:06<17:41:46, 262.17s/it]

Epoch 58 / trn/loss=0.7452
Epoch 58 / train/metric=0.5224




Epoch 58 / val/metric=0.5389


 19%|█████████████████▊                                                                          | 58/300 [4:13:27<17:36:32, 261.95s/it]

Epoch 59 / trn/loss=0.7405
Epoch 59 / train/metric=0.5203




Epoch 59 / val/metric=0.5302
Best val_metric 0.5302132534030849 at epoch 59!


 20%|██████████████████                                                                          | 59/300 [4:17:49<17:32:18, 261.98s/it]

Epoch 60 / trn/loss=0.7495
Epoch 60 / train/metric=0.5258




Epoch 60 / val/metric=0.5448


 20%|██████████████████▍                                                                         | 60/300 [4:22:11<17:27:33, 261.89s/it]

Epoch 61 / trn/loss=0.7457
Epoch 61 / train/metric=0.5226




Epoch 61 / val/metric=0.5382


 20%|██████████████████▋                                                                         | 61/300 [4:26:33<17:23:33, 261.98s/it]

Epoch 62 / trn/loss=0.7314
Epoch 62 / train/metric=0.5150




Epoch 62 / val/metric=0.5567


 21%|███████████████████                                                                         | 62/300 [4:30:54<17:18:30, 261.81s/it]

Epoch 63 / trn/loss=0.7378
Epoch 63 / train/metric=0.5187




Epoch 63 / val/metric=0.5346


 21%|███████████████████▎                                                                        | 63/300 [4:35:18<17:15:47, 262.23s/it]

Epoch 64 / trn/loss=0.7398
Epoch 64 / train/metric=0.5193




Epoch 64 / val/metric=0.5329


 21%|███████████████████▋                                                                        | 64/300 [4:39:40<17:12:02, 262.38s/it]

Epoch 65 / trn/loss=0.7372
Epoch 65 / train/metric=0.5200




Epoch 65 / val/metric=0.5385


 22%|███████████████████▉                                                                        | 65/300 [4:44:03<17:08:12, 262.52s/it]

Epoch 66 / trn/loss=0.7428
Epoch 66 / train/metric=0.5213




Epoch 66 / val/metric=0.5305


 22%|████████████████████▏                                                                       | 66/300 [4:48:25<17:02:51, 262.27s/it]

Epoch 67 / trn/loss=0.7353
Epoch 67 / train/metric=0.5132




Epoch 67 / val/metric=0.5237
Best val_metric 0.5236593700465448 at epoch 67!


 22%|████████████████████▌                                                                       | 67/300 [4:52:46<16:57:28, 262.01s/it]

Epoch 68 / trn/loss=0.7375
Epoch 68 / train/metric=0.5157




Epoch 68 / val/metric=0.5590


 23%|████████████████████▊                                                                       | 68/300 [4:57:08<16:52:42, 261.91s/it]

Epoch 69 / trn/loss=0.7351
Epoch 69 / train/metric=0.5175




Epoch 69 / val/metric=0.5259


 23%|█████████████████████▏                                                                      | 69/300 [5:01:30<16:48:58, 262.07s/it]

Epoch 70 / trn/loss=0.7336
Epoch 70 / train/metric=0.5150




Epoch 70 / val/metric=0.5353


 23%|█████████████████████▍                                                                      | 70/300 [5:05:54<16:46:06, 262.46s/it]

Epoch 71 / trn/loss=0.7308
Epoch 71 / train/metric=0.5126




Epoch 71 / val/metric=0.5261


 24%|█████████████████████▊                                                                      | 71/300 [5:10:17<16:42:35, 262.69s/it]

Epoch 72 / trn/loss=0.7319
Epoch 72 / train/metric=0.5118




Epoch 72 / val/metric=0.5258


 24%|██████████████████████                                                                      | 72/300 [5:14:39<16:37:22, 262.47s/it]

Epoch 73 / trn/loss=0.7273
Epoch 73 / train/metric=0.5106




Epoch 73 / val/metric=0.5390


 24%|██████████████████████▍                                                                     | 73/300 [5:19:00<16:31:33, 262.09s/it]

Epoch 74 / trn/loss=0.7354
Epoch 74 / train/metric=0.5145




Epoch 74 / val/metric=0.5269


 25%|██████████████████████▋                                                                     | 74/300 [5:23:22<16:27:13, 262.10s/it]

Epoch 75 / trn/loss=0.7250
Epoch 75 / train/metric=0.5074




Epoch 75 / val/metric=0.5278


 25%|███████████████████████                                                                     | 75/300 [5:27:43<16:21:45, 261.80s/it]

Epoch 76 / trn/loss=0.7271
Epoch 76 / train/metric=0.5082




Epoch 76 / val/metric=0.5701


 25%|███████████████████████▎                                                                    | 76/300 [5:32:06<16:18:24, 262.08s/it]

Epoch 77 / trn/loss=0.7286
Epoch 77 / train/metric=0.5116




Epoch 77 / val/metric=0.5565


 26%|███████████████████████▌                                                                    | 77/300 [5:36:28<16:13:44, 261.99s/it]

Epoch 78 / trn/loss=0.7294
Epoch 78 / train/metric=0.5129




Epoch 78 / val/metric=0.5260


 26%|███████████████████████▉                                                                    | 78/300 [5:40:51<16:10:40, 262.35s/it]

Epoch 79 / trn/loss=0.7345
Epoch 79 / train/metric=0.5147




Epoch 79 / val/metric=0.5188
Best val_metric 0.5188134067186405 at epoch 79!


 26%|████████████████████████▏                                                                   | 79/300 [5:45:14<16:06:35, 262.42s/it]

Epoch 80 / trn/loss=0.7247
Epoch 80 / train/metric=0.5063




Epoch 80 / val/metric=0.5380


 27%|████████████████████████▌                                                                   | 80/300 [5:49:36<16:02:17, 262.44s/it]

Epoch 81 / trn/loss=0.7293
Epoch 81 / train/metric=0.5102




Epoch 81 / val/metric=0.5263


 27%|████████████████████████▊                                                                   | 81/300 [5:53:58<15:57:02, 262.20s/it]

Epoch 82 / trn/loss=0.7290
Epoch 82 / train/metric=0.5128




Epoch 82 / val/metric=0.5207


 27%|█████████████████████████▏                                                                  | 82/300 [5:58:19<15:51:57, 262.01s/it]

Epoch 83 / trn/loss=0.7307
Epoch 83 / train/metric=0.5134




Epoch 83 / val/metric=0.5217


 28%|█████████████████████████▍                                                                  | 83/300 [6:02:41<15:47:20, 261.94s/it]

Epoch 84 / trn/loss=0.7239
Epoch 84 / train/metric=0.5069




Epoch 84 / val/metric=0.5233


 28%|█████████████████████████▊                                                                  | 84/300 [6:07:03<15:42:45, 261.88s/it]

Epoch 85 / trn/loss=0.7196
Epoch 85 / train/metric=0.5045




Epoch 85 / val/metric=0.5197


 28%|██████████████████████████                                                                  | 85/300 [6:11:26<15:39:45, 262.26s/it]

Epoch 86 / trn/loss=0.7292
Epoch 86 / train/metric=0.5093




Epoch 86 / val/metric=0.5397


 29%|██████████████████████████▎                                                                 | 86/300 [6:15:48<15:35:18, 262.24s/it]

Epoch 87 / trn/loss=0.7334
Epoch 87 / train/metric=0.5129




Epoch 87 / val/metric=0.5158
Best val_metric 0.5157831883686291 at epoch 87!


 29%|██████████████████████████▋                                                                 | 87/300 [6:20:11<15:31:18, 262.34s/it]

Epoch 88 / trn/loss=0.7260
Epoch 88 / train/metric=0.5102




Epoch 88 / val/metric=0.5209


 29%|██████████████████████████▉                                                                 | 88/300 [6:24:32<15:26:05, 262.10s/it]

Epoch 89 / trn/loss=0.7251
Epoch 89 / train/metric=0.5101




Epoch 89 / val/metric=0.5311


 30%|███████████████████████████▎                                                                | 89/300 [6:28:55<15:21:57, 262.17s/it]

Epoch 90 / trn/loss=0.7149
Epoch 90 / train/metric=0.4995




Epoch 90 / val/metric=0.5347


 30%|███████████████████████████▌                                                                | 90/300 [6:33:18<15:19:01, 262.58s/it]

Epoch 91 / trn/loss=0.7201
Epoch 91 / train/metric=0.5040




Epoch 91 / val/metric=0.5299


 30%|███████████████████████████▉                                                                | 91/300 [6:37:39<15:12:40, 262.01s/it]

Epoch 92 / trn/loss=0.7190
Epoch 92 / train/metric=0.5040




Epoch 92 / val/metric=0.5314


 31%|████████████████████████████▏                                                               | 92/300 [6:42:02<15:09:36, 262.39s/it]

Epoch 93 / trn/loss=0.7200
Epoch 93 / train/metric=0.5054




Epoch 93 / val/metric=0.5280


 31%|████████████████████████████▌                                                               | 93/300 [6:46:23<15:03:27, 261.87s/it]

Epoch 94 / trn/loss=0.7224
Epoch 94 / train/metric=0.5072




Epoch 94 / val/metric=0.5325


 31%|████████████████████████████▊                                                               | 94/300 [6:50:44<14:58:40, 261.75s/it]

Epoch 95 / trn/loss=0.7238
Epoch 95 / train/metric=0.5059




Epoch 95 / val/metric=0.5418


 32%|█████████████████████████████▏                                                              | 95/300 [6:55:06<14:54:20, 261.76s/it]

Epoch 96 / trn/loss=0.7184
Epoch 96 / train/metric=0.5051




Epoch 96 / val/metric=0.5355


 32%|█████████████████████████████▍                                                              | 96/300 [6:59:29<14:50:41, 261.97s/it]

Epoch 97 / trn/loss=0.7201
Epoch 97 / train/metric=0.5046




Epoch 97 / val/metric=0.5249


 32%|█████████████████████████████▋                                                              | 97/300 [7:03:50<14:45:56, 261.86s/it]

Epoch 98 / trn/loss=0.7244
Epoch 98 / train/metric=0.5073




Epoch 98 / val/metric=0.5166


 33%|██████████████████████████████                                                              | 98/300 [7:08:12<14:41:36, 261.86s/it]

Epoch 99 / trn/loss=0.7167
Epoch 99 / train/metric=0.4990




Epoch 99 / val/metric=0.5144
Best val_metric 0.5144428932090376 at epoch 99!


 33%|██████████████████████████████▎                                                             | 99/300 [7:12:35<14:38:41, 262.29s/it]

Epoch 100 / trn/loss=0.7077
Epoch 100 / train/metric=0.4955




Epoch 100 / val/metric=0.5284


 33%|██████████████████████████████▎                                                            | 100/300 [7:16:58<14:34:39, 262.40s/it]

Epoch 101 / trn/loss=0.7161
Epoch 101 / train/metric=0.5030




Epoch 101 / val/metric=0.5182


 34%|██████████████████████████████▋                                                            | 101/300 [7:21:21<14:30:30, 262.47s/it]

Epoch 102 / trn/loss=0.7169
Epoch 102 / train/metric=0.5020




Epoch 102 / val/metric=0.5313


 34%|██████████████████████████████▉                                                            | 102/300 [7:25:42<14:25:15, 262.20s/it]

Epoch 103 / trn/loss=0.7153
Epoch 103 / train/metric=0.4995




Epoch 103 / val/metric=0.5174


 34%|███████████████████████████████▏                                                           | 103/300 [7:30:05<14:21:21, 262.34s/it]

Epoch 104 / trn/loss=0.7113
Epoch 104 / train/metric=0.4995




Epoch 104 / val/metric=0.5278


 35%|███████████████████████████████▌                                                           | 104/300 [7:34:29<14:18:36, 262.84s/it]

Epoch 105 / trn/loss=0.7190
Epoch 105 / train/metric=0.5037




Epoch 105 / val/metric=0.5220


 35%|███████████████████████████████▊                                                           | 105/300 [7:38:50<14:12:42, 262.37s/it]

Epoch 106 / trn/loss=0.7099
Epoch 106 / train/metric=0.4987




Epoch 106 / val/metric=0.5186


 35%|████████████████████████████████▏                                                          | 106/300 [7:43:12<14:07:57, 262.25s/it]

Epoch 107 / trn/loss=0.7119
Epoch 107 / train/metric=0.4991




Epoch 107 / val/metric=0.5217


 36%|████████████████████████████████▍                                                          | 107/300 [7:47:34<14:03:24, 262.20s/it]

Epoch 108 / trn/loss=0.7076
Epoch 108 / train/metric=0.4947




Epoch 108 / val/metric=0.5284


 36%|████████████████████████████████▊                                                          | 108/300 [7:51:55<13:57:48, 261.81s/it]

Epoch 109 / trn/loss=0.7077
Epoch 109 / train/metric=0.4959




Epoch 109 / val/metric=0.5161


 36%|█████████████████████████████████                                                          | 109/300 [7:56:17<13:53:40, 261.89s/it]

Epoch 110 / trn/loss=0.7168
Epoch 110 / train/metric=0.5009




Epoch 110 / val/metric=0.5238


 37%|█████████████████████████████████▎                                                         | 110/300 [8:00:41<13:51:14, 262.50s/it]

Epoch 111 / trn/loss=0.7115
Epoch 111 / train/metric=0.4982




Epoch 111 / val/metric=0.5253


 37%|█████████████████████████████████▋                                                         | 111/300 [8:05:04<13:47:07, 262.58s/it]

Epoch 112 / trn/loss=0.7080
Epoch 112 / train/metric=0.4959




Epoch 112 / val/metric=0.5137
Best val_metric 0.5136806356387583 at epoch 112!


 37%|█████████████████████████████████▉                                                         | 112/300 [8:09:27<13:43:36, 262.86s/it]

Epoch 113 / trn/loss=0.7118
Epoch 113 / train/metric=0.4969




Epoch 113 / val/metric=0.5265


 38%|██████████████████████████████████▎                                                        | 113/300 [8:13:50<13:38:56, 262.76s/it]

Epoch 114 / trn/loss=0.7076
Epoch 114 / train/metric=0.4952




Epoch 114 / val/metric=0.5482


 38%|██████████████████████████████████▌                                                        | 114/300 [8:18:13<13:34:44, 262.82s/it]

Epoch 115 / trn/loss=0.7095
Epoch 115 / train/metric=0.4963




Epoch 115 / val/metric=0.5378


 38%|██████████████████████████████████▉                                                        | 115/300 [8:22:37<13:31:50, 263.30s/it]

Epoch 116 / trn/loss=0.7011
Epoch 116 / train/metric=0.4926




Epoch 116 / val/metric=0.5190


 39%|███████████████████████████████████▏                                                       | 116/300 [8:26:59<13:26:20, 262.94s/it]

Epoch 117 / trn/loss=0.7035
Epoch 117 / train/metric=0.4933




Epoch 117 / val/metric=0.5429


 39%|███████████████████████████████████▍                                                       | 117/300 [8:31:22<13:21:41, 262.85s/it]

Epoch 118 / trn/loss=0.7164
Epoch 118 / train/metric=0.5021




Epoch 118 / val/metric=0.5197


 39%|███████████████████████████████████▊                                                       | 118/300 [8:35:44<13:16:22, 262.54s/it]

Epoch 119 / trn/loss=0.7081
Epoch 119 / train/metric=0.4944




Epoch 119 / val/metric=0.5225


 40%|████████████████████████████████████                                                       | 119/300 [8:40:06<13:11:14, 262.29s/it]

Epoch 120 / trn/loss=0.6997
Epoch 120 / train/metric=0.4908




Epoch 120 / val/metric=0.5162


 40%|████████████████████████████████████▍                                                      | 120/300 [8:44:28<13:07:13, 262.41s/it]

Epoch 121 / trn/loss=0.7049
Epoch 121 / train/metric=0.4925




Epoch 121 / val/metric=0.5367


 40%|████████████████████████████████████▋                                                      | 121/300 [8:48:51<13:03:32, 262.64s/it]

Epoch 122 / trn/loss=0.7068
Epoch 122 / train/metric=0.4953




Epoch 122 / val/metric=0.5325


 41%|█████████████████████████████████████                                                      | 122/300 [8:53:13<12:58:11, 262.31s/it]

Epoch 123 / trn/loss=0.6986
Epoch 123 / train/metric=0.4889




Epoch 123 / val/metric=0.5152


 41%|█████████████████████████████████████▎                                                     | 123/300 [8:57:35<12:53:44, 262.29s/it]

Epoch 124 / trn/loss=0.6969
Epoch 124 / train/metric=0.4896




Epoch 124 / val/metric=0.5145


 41%|█████████████████████████████████████▌                                                     | 124/300 [9:01:57<12:49:18, 262.26s/it]

Epoch 125 / trn/loss=0.7043
Epoch 125 / train/metric=0.4930




Epoch 125 / val/metric=0.5198


 42%|█████████████████████████████████████▉                                                     | 125/300 [9:06:19<12:44:03, 261.96s/it]

Epoch 126 / trn/loss=0.7012
Epoch 126 / train/metric=0.4895




Epoch 126 / val/metric=0.5071
Best val_metric 0.5071244031061604 at epoch 126!


 42%|██████████████████████████████████████▏                                                    | 126/300 [9:10:42<12:40:46, 262.34s/it]

Epoch 127 / trn/loss=0.7027
Epoch 127 / train/metric=0.4915




Epoch 127 / val/metric=0.5247


 42%|██████████████████████████████████████▌                                                    | 127/300 [9:15:04<12:36:25, 262.34s/it]

Epoch 128 / trn/loss=0.6939
Epoch 128 / train/metric=0.4853




Epoch 128 / val/metric=0.5115


 43%|██████████████████████████████████████▊                                                    | 128/300 [9:19:26<12:31:51, 262.28s/it]

Epoch 129 / trn/loss=0.7015
Epoch 129 / train/metric=0.4921




Epoch 129 / val/metric=0.5240


 43%|███████████████████████████████████████▏                                                   | 129/300 [9:23:50<12:28:56, 262.79s/it]

Epoch 130 / trn/loss=0.6996
Epoch 130 / train/metric=0.4895




Epoch 130 / val/metric=0.5172


 43%|███████████████████████████████████████▍                                                   | 130/300 [9:28:15<12:25:55, 263.26s/it]

Epoch 131 / trn/loss=0.7120
Epoch 131 / train/metric=0.4954




Epoch 131 / val/metric=0.5240


 44%|███████████████████████████████████████▋                                                   | 131/300 [9:32:37<12:20:24, 262.87s/it]

Epoch 132 / trn/loss=0.6983
Epoch 132 / train/metric=0.4881




Epoch 132 / val/metric=0.5272


 44%|████████████████████████████████████████                                                   | 132/300 [9:36:59<12:15:35, 262.71s/it]

Epoch 133 / trn/loss=0.6954
Epoch 133 / train/metric=0.4862




Epoch 133 / val/metric=0.5220


 44%|████████████████████████████████████████▎                                                  | 133/300 [9:41:21<12:10:41, 262.53s/it]

Epoch 134 / trn/loss=0.6917
Epoch 134 / train/metric=0.4819




Epoch 134 / val/metric=0.5199


 45%|████████████████████████████████████████▋                                                  | 134/300 [9:45:44<12:06:42, 262.67s/it]

Epoch 135 / trn/loss=0.7085
Epoch 135 / train/metric=0.4966




Epoch 135 / val/metric=0.5488


 45%|████████████████████████████████████████▉                                                  | 135/300 [9:50:07<12:02:41, 262.80s/it]

Epoch 136 / trn/loss=0.6973
Epoch 136 / train/metric=0.4908




Epoch 136 / val/metric=0.5221


 45%|█████████████████████████████████████████▎                                                 | 136/300 [9:54:33<12:00:44, 263.68s/it]

Epoch 137 / trn/loss=0.6934
Epoch 137 / train/metric=0.4856




Epoch 137 / val/metric=0.5319


 46%|█████████████████████████████████████████▌                                                 | 137/300 [9:59:00<11:59:00, 264.67s/it]

Epoch 138 / trn/loss=0.7045
Epoch 138 / train/metric=0.4959




Epoch 138 / val/metric=0.5190


 46%|█████████████████████████████████████████▍                                                | 138/300 [10:03:26<11:55:58, 265.18s/it]

Epoch 139 / trn/loss=0.6917
Epoch 139 / train/metric=0.4841




Epoch 139 / val/metric=0.5210


 46%|█████████████████████████████████████████▋                                                | 139/300 [10:07:53<11:52:42, 265.60s/it]

Epoch 140 / trn/loss=0.6915
Epoch 140 / train/metric=0.4859




Epoch 140 / val/metric=0.5148


 47%|██████████████████████████████████████████                                                | 140/300 [10:12:19<11:49:04, 265.90s/it]

Epoch 141 / trn/loss=0.6891
Epoch 141 / train/metric=0.4818




Epoch 141 / val/metric=0.5114


 47%|██████████████████████████████████████████▎                                               | 141/300 [10:16:53<11:50:30, 268.12s/it]

Epoch 142 / trn/loss=0.6928
Epoch 142 / train/metric=0.4855




Epoch 142 / val/metric=0.5250


 47%|██████████████████████████████████████████▌                                               | 142/300 [10:21:28<11:51:18, 270.12s/it]

Epoch 143 / trn/loss=0.6872
Epoch 143 / train/metric=0.4807




Epoch 143 / val/metric=0.5234


 48%|██████████████████████████████████████████▉                                               | 143/300 [10:25:54<11:43:54, 269.01s/it]

Epoch 144 / trn/loss=0.6982
Epoch 144 / train/metric=0.4881




Epoch 144 / val/metric=0.5279


 48%|███████████████████████████████████████████▏                                              | 144/300 [10:30:46<11:57:12, 275.85s/it]

Epoch 145 / trn/loss=0.6914
Epoch 145 / train/metric=0.4823




Epoch 145 / val/metric=0.5205


 48%|███████████████████████████████████████████▌                                              | 145/300 [10:35:17<11:48:42, 274.34s/it]

Epoch 146 / trn/loss=0.6949
Epoch 146 / train/metric=0.4872




Epoch 146 / val/metric=0.5174


 49%|███████████████████████████████████████████▊                                              | 146/300 [10:39:39<11:35:02, 270.80s/it]

Epoch 147 / trn/loss=0.6913
Epoch 147 / train/metric=0.4861




Epoch 147 / val/metric=0.5283


 49%|████████████████████████████████████████████                                              | 147/300 [10:44:02<11:24:07, 268.28s/it]

Epoch 148 / trn/loss=0.6904
Epoch 148 / train/metric=0.4847




Epoch 148 / val/metric=0.5305


 49%|████████████████████████████████████████████▍                                             | 148/300 [10:48:26<11:16:38, 267.09s/it]

Epoch 149 / trn/loss=0.6916
Epoch 149 / train/metric=0.4848




Epoch 149 / val/metric=0.5100


 50%|████████████████████████████████████████████▋                                             | 149/300 [10:52:49<11:09:13, 265.91s/it]

Epoch 150 / trn/loss=0.6918
Epoch 150 / train/metric=0.4835




Epoch 150 / val/metric=0.5268


 50%|█████████████████████████████████████████████                                             | 150/300 [10:57:12<11:02:29, 265.00s/it]

Epoch 151 / trn/loss=0.6882
Epoch 151 / train/metric=0.4820




Epoch 151 / val/metric=0.5164


 50%|█████████████████████████████████████████████▎                                            | 151/300 [11:01:35<10:56:24, 264.33s/it]

Epoch 152 / trn/loss=0.6910
Epoch 152 / train/metric=0.4832




Epoch 152 / val/metric=0.5155


 51%|█████████████████████████████████████████████▌                                            | 152/300 [11:06:03<10:55:20, 265.68s/it]

Epoch 153 / trn/loss=0.6875
Epoch 153 / train/metric=0.4799




Epoch 153 / val/metric=0.5183


 51%|█████████████████████████████████████████████▉                                            | 153/300 [11:10:50<11:06:29, 272.03s/it]

Epoch 154 / trn/loss=0.6849
Epoch 154 / train/metric=0.4772




Epoch 154 / val/metric=0.5167


 51%|██████████████████████████████████████████████▏                                           | 154/300 [11:15:44<11:18:06, 278.68s/it]

Epoch 155 / trn/loss=0.6953
Epoch 155 / train/metric=0.4878




Epoch 155 / val/metric=0.5108


 52%|██████████████████████████████████████████████▌                                           | 155/300 [11:20:39<11:24:45, 283.35s/it]

Epoch 156 / trn/loss=0.6898
Epoch 156 / train/metric=0.4822




Epoch 156 / val/metric=0.5239


 52%|██████████████████████████████████████████████▊                                           | 156/300 [11:25:30<11:25:51, 285.78s/it]

Epoch 157 / trn/loss=0.6887
Epoch 157 / train/metric=0.4816




Epoch 157 / val/metric=0.5196


 52%|███████████████████████████████████████████████                                           | 157/300 [11:30:25<11:27:24, 288.42s/it]

Epoch 158 / trn/loss=0.6786
Epoch 158 / train/metric=0.4752




Epoch 158 / val/metric=0.5229


 53%|███████████████████████████████████████████████▍                                          | 158/300 [11:35:18<11:26:05, 289.90s/it]

Epoch 159 / trn/loss=0.6820
Epoch 159 / train/metric=0.4770




Epoch 159 / val/metric=0.5229


 53%|███████████████████████████████████████████████▋                                          | 159/300 [11:40:12<11:24:01, 291.08s/it]

Epoch 160 / trn/loss=0.6819
Epoch 160 / train/metric=0.4777




Epoch 160 / val/metric=0.5315


 53%|████████████████████████████████████████████████                                          | 160/300 [11:45:07<11:21:41, 292.15s/it]

Epoch 161 / trn/loss=0.6833
Epoch 161 / train/metric=0.4758




Epoch 161 / val/metric=0.5185


 54%|████████████████████████████████████████████████▎                                         | 161/300 [11:49:59<11:17:00, 292.23s/it]

Epoch 162 / trn/loss=0.6821
Epoch 162 / train/metric=0.4773




Epoch 162 / val/metric=0.5189


 54%|████████████████████████████████████████████████▌                                         | 162/300 [11:54:51<11:12:10, 292.25s/it]

Epoch 163 / trn/loss=0.6812
Epoch 163 / train/metric=0.4782




Epoch 163 / val/metric=0.5189


 54%|████████████████████████████████████████████████▉                                         | 163/300 [11:59:44<11:07:49, 292.47s/it]

Epoch 164 / trn/loss=0.6871
Epoch 164 / train/metric=0.4808




Epoch 164 / val/metric=0.5324


 55%|█████████████████████████████████████████████████▏                                        | 164/300 [12:04:09<10:44:14, 284.23s/it]

Epoch 165 / trn/loss=0.6805
Epoch 165 / train/metric=0.4754




Epoch 165 / val/metric=0.5122


 55%|█████████████████████████████████████████████████▌                                        | 165/300 [12:08:34<10:26:23, 278.40s/it]

Epoch 166 / trn/loss=0.6795
Epoch 166 / train/metric=0.4743




Epoch 166 / val/metric=0.5327


 55%|█████████████████████████████████████████████████▊                                        | 166/300 [12:12:58<10:12:19, 274.18s/it]

Epoch 167 / trn/loss=0.6766
Epoch 167 / train/metric=0.4748




Epoch 167 / val/metric=0.5242


 56%|██████████████████████████████████████████████████                                        | 167/300 [12:17:22<10:00:39, 270.98s/it]

Epoch 168 / trn/loss=0.6779
Epoch 168 / train/metric=0.4748




Epoch 168 / val/metric=0.5109


 56%|██████████████████████████████████████████████████▉                                        | 168/300 [12:21:45<9:51:14, 268.75s/it]

Epoch 169 / trn/loss=0.6754
Epoch 169 / train/metric=0.4720




Epoch 169 / val/metric=0.5098


 56%|███████████████████████████████████████████████████▎                                       | 169/300 [12:26:10<9:43:49, 267.40s/it]

Epoch 170 / trn/loss=0.6787
Epoch 170 / train/metric=0.4733




Epoch 170 / val/metric=0.5074


 57%|███████████████████████████████████████████████████▌                                       | 170/300 [12:30:35<9:38:01, 266.78s/it]

Epoch 171 / trn/loss=0.6731
Epoch 171 / train/metric=0.4712




Epoch 171 / val/metric=0.5130


 57%|███████████████████████████████████████████████████▊                                       | 171/300 [12:35:00<9:32:06, 266.10s/it]

Epoch 172 / trn/loss=0.6702
Epoch 172 / train/metric=0.4687




Epoch 172 / val/metric=0.5082


 57%|████████████████████████████████████████████████████▏                                      | 172/300 [12:39:23<9:26:04, 265.35s/it]

Epoch 173 / trn/loss=0.6737
Epoch 173 / train/metric=0.4711




Epoch 173 / val/metric=0.5049
Best val_metric 0.5048790920778888 at epoch 173!


 58%|████████████████████████████████████████████████████▍                                      | 173/300 [12:43:47<9:20:36, 264.85s/it]

Epoch 174 / trn/loss=0.6734
Epoch 174 / train/metric=0.4720




Epoch 174 / val/metric=0.5223


 58%|████████████████████████████████████████████████████▊                                      | 174/300 [12:48:14<9:17:26, 265.45s/it]

Epoch 175 / trn/loss=0.6738
Epoch 175 / train/metric=0.4724




Epoch 175 / val/metric=0.5217


 58%|█████████████████████████████████████████████████████                                      | 175/300 [12:52:39<9:12:42, 265.30s/it]

Epoch 176 / trn/loss=0.6709
Epoch 176 / train/metric=0.4689




Epoch 176 / val/metric=0.5093


 59%|█████████████████████████████████████████████████████▍                                     | 176/300 [12:57:02<9:06:51, 264.61s/it]

Epoch 177 / trn/loss=0.6757
Epoch 177 / train/metric=0.4731




Epoch 177 / val/metric=0.5245


 59%|█████████████████████████████████████████████████████▋                                     | 177/300 [13:01:26<9:01:58, 264.38s/it]

Epoch 178 / trn/loss=0.6724
Epoch 178 / train/metric=0.4709




Epoch 178 / val/metric=0.5165


 59%|█████████████████████████████████████████████████████▉                                     | 178/300 [13:05:50<8:57:31, 264.35s/it]

Epoch 179 / trn/loss=0.6725
Epoch 179 / train/metric=0.4712




Epoch 179 / val/metric=0.5235


 60%|██████████████████████████████████████████████████████▎                                    | 179/300 [13:10:15<8:53:43, 264.66s/it]

Epoch 180 / trn/loss=0.6809
Epoch 180 / train/metric=0.4747




Epoch 180 / val/metric=0.5147


 60%|██████████████████████████████████████████████████████▌                                    | 180/300 [13:14:40<8:49:17, 264.65s/it]

Epoch 181 / trn/loss=0.6710
Epoch 181 / train/metric=0.4690




Epoch 181 / val/metric=0.5125


 60%|██████████████████████████████████████████████████████▉                                    | 181/300 [13:19:05<8:45:17, 264.85s/it]

Epoch 182 / trn/loss=0.6702
Epoch 182 / train/metric=0.4691




Epoch 182 / val/metric=0.5279


 61%|███████████████████████████████████████████████████████▏                                   | 182/300 [13:23:29<8:40:06, 264.46s/it]

Epoch 183 / trn/loss=0.6675
Epoch 183 / train/metric=0.4652




Epoch 183 / val/metric=0.5063


 61%|███████████████████████████████████████████████████████▌                                   | 183/300 [13:27:52<8:35:06, 264.16s/it]

Epoch 184 / trn/loss=0.6732
Epoch 184 / train/metric=0.4709




Epoch 184 / val/metric=0.5107


 61%|███████████████████████████████████████████████████████▊                                   | 184/300 [13:32:16<8:30:27, 264.03s/it]

Epoch 185 / trn/loss=0.6673
Epoch 185 / train/metric=0.4675




Epoch 185 / val/metric=0.5036
Best val_metric 0.5035819623789052 at epoch 185!


 62%|████████████████████████████████████████████████████████                                   | 185/300 [13:36:42<8:27:29, 264.78s/it]

Epoch 186 / trn/loss=0.6691
Epoch 186 / train/metric=0.4656




Epoch 186 / val/metric=0.5047


 62%|████████████████████████████████████████████████████████▍                                  | 186/300 [13:41:06<8:22:22, 264.41s/it]

Epoch 187 / trn/loss=0.6719
Epoch 187 / train/metric=0.4720




Epoch 187 / val/metric=0.5052


 62%|████████████████████████████████████████████████████████▋                                  | 187/300 [13:45:30<8:17:48, 264.32s/it]

Epoch 188 / trn/loss=0.6710
Epoch 188 / train/metric=0.4697




Epoch 188 / val/metric=0.5194


 63%|█████████████████████████████████████████████████████████                                  | 188/300 [13:49:54<8:13:15, 264.25s/it]

Epoch 189 / trn/loss=0.6608
Epoch 189 / train/metric=0.4618




Epoch 189 / val/metric=0.5099


 63%|█████████████████████████████████████████████████████████▎                                 | 189/300 [13:54:18<8:08:23, 263.99s/it]

Epoch 190 / trn/loss=0.6675
Epoch 190 / train/metric=0.4657




Epoch 190 / val/metric=0.5067


 63%|█████████████████████████████████████████████████████████▋                                 | 190/300 [13:58:42<8:04:09, 264.08s/it]

Epoch 191 / trn/loss=0.6672
Epoch 191 / train/metric=0.4662




Epoch 191 / val/metric=0.5039
Epoch 192 / trn/loss=0.6673
Epoch 192 / train/metric=0.4656




Epoch 192 / val/metric=0.5064


 64%|██████████████████████████████████████████████████████████▏                                | 192/300 [14:07:30<7:55:25, 264.12s/it]

Epoch 193 / trn/loss=0.6633
Epoch 193 / train/metric=0.4625




Epoch 193 / val/metric=0.5178


 64%|██████████████████████████████████████████████████████████▌                                | 193/300 [14:11:54<7:50:44, 263.96s/it]

Epoch 194 / trn/loss=0.6728
Epoch 194 / train/metric=0.4701




Epoch 194 / val/metric=0.5103


 65%|██████████████████████████████████████████████████████████▊                                | 194/300 [14:16:17<7:46:09, 263.86s/it]

Epoch 195 / trn/loss=0.6690
Epoch 195 / train/metric=0.4649




Epoch 195 / val/metric=0.5103


 65%|███████████████████████████████████████████████████████████▏                               | 195/300 [14:20:41<7:41:30, 263.72s/it]

Epoch 196 / trn/loss=0.6619
Epoch 196 / train/metric=0.4619




Epoch 196 / val/metric=0.5170


 65%|███████████████████████████████████████████████████████████▍                               | 196/300 [14:25:05<7:37:21, 263.86s/it]

Epoch 197 / trn/loss=0.6576
Epoch 197 / train/metric=0.4583




Epoch 197 / val/metric=0.5242


 66%|███████████████████████████████████████████████████████████▊                               | 197/300 [14:29:29<7:33:06, 263.95s/it]

Epoch 198 / trn/loss=0.6581
Epoch 198 / train/metric=0.4613




Epoch 198 / val/metric=0.5169


 66%|████████████████████████████████████████████████████████████                               | 198/300 [14:33:53<7:28:32, 263.85s/it]

Epoch 199 / trn/loss=0.6611
Epoch 199 / train/metric=0.4619




Epoch 199 / val/metric=0.5222


 66%|████████████████████████████████████████████████████████████▎                              | 199/300 [14:38:18<7:25:06, 264.43s/it]

Epoch 200 / trn/loss=0.6628
Epoch 200 / train/metric=0.4638




Epoch 200 / val/metric=0.5182


 67%|████████████████████████████████████████████████████████████▋                              | 200/300 [14:42:42<7:20:19, 264.19s/it]

Epoch 201 / trn/loss=0.6613
Epoch 201 / train/metric=0.4632




Epoch 201 / val/metric=0.5179


 67%|████████████████████████████████████████████████████████████▉                              | 201/300 [14:47:05<7:15:13, 263.78s/it]

Epoch 202 / trn/loss=0.6642
Epoch 202 / train/metric=0.4636




Epoch 202 / val/metric=0.5076


 67%|█████████████████████████████████████████████████████████████▎                             | 202/300 [14:51:29<7:11:11, 264.00s/it]

Epoch 203 / trn/loss=0.6607
Epoch 203 / train/metric=0.4630




Epoch 203 / val/metric=0.5135


 68%|█████████████████████████████████████████████████████████████▌                             | 203/300 [14:55:54<7:07:18, 264.31s/it]

Epoch 204 / trn/loss=0.6561
Epoch 204 / train/metric=0.4598




Epoch 204 / val/metric=0.5089


 68%|█████████████████████████████████████████████████████████████▉                             | 204/300 [15:00:18<7:02:46, 264.24s/it]

Epoch 205 / trn/loss=0.6620
Epoch 205 / train/metric=0.4606




Epoch 205 / val/metric=0.5107


 68%|██████████████████████████████████████████████████████████████▏                            | 205/300 [15:04:42<6:58:14, 264.15s/it]

Epoch 206 / trn/loss=0.6586
Epoch 206 / train/metric=0.4589




Epoch 206 / val/metric=0.5094


 69%|██████████████████████████████████████████████████████████████▍                            | 206/300 [15:09:07<6:54:09, 264.36s/it]

Epoch 207 / trn/loss=0.6614
Epoch 207 / train/metric=0.4607




Epoch 207 / val/metric=0.5092


 69%|██████████████████████████████████████████████████████████████▊                            | 207/300 [15:13:31<6:49:28, 264.17s/it]

Epoch 208 / trn/loss=0.6565
Epoch 208 / train/metric=0.4584




Epoch 208 / val/metric=0.5132


 69%|███████████████████████████████████████████████████████████████                            | 208/300 [15:17:55<6:44:58, 264.11s/it]

Epoch 209 / trn/loss=0.6577
Epoch 209 / train/metric=0.4602




Epoch 209 / val/metric=0.5048


 70%|███████████████████████████████████████████████████████████████▍                           | 209/300 [15:22:18<6:40:10, 263.85s/it]

Epoch 210 / trn/loss=0.6519
Epoch 210 / train/metric=0.4562




Epoch 210 / val/metric=0.5175


 70%|███████████████████████████████████████████████████████████████▋                           | 210/300 [15:26:42<6:35:57, 263.97s/it]

Epoch 211 / trn/loss=0.6555
Epoch 211 / train/metric=0.4585




Epoch 211 / val/metric=0.5138


 70%|████████████████████████████████████████████████████████████████                           | 211/300 [15:31:06<6:31:11, 263.73s/it]

Epoch 212 / trn/loss=0.6469
Epoch 212 / train/metric=0.4514




Epoch 212 / val/metric=0.5065


 71%|████████████████████████████████████████████████████████████████▎                          | 212/300 [15:35:29<6:26:35, 263.59s/it]

Epoch 213 / trn/loss=0.6499
Epoch 213 / train/metric=0.4523




Epoch 213 / val/metric=0.5024
Best val_metric 0.5023965168958828 at epoch 213!


 71%|████████████████████████████████████████████████████████████████▌                          | 213/300 [15:39:54<6:22:54, 264.07s/it]

Epoch 214 / trn/loss=0.6500
Epoch 214 / train/metric=0.4548




Epoch 214 / val/metric=0.5039


 71%|████████████████████████████████████████████████████████████████▉                          | 214/300 [15:44:20<6:19:28, 264.75s/it]

Epoch 215 / trn/loss=0.6482
Epoch 215 / train/metric=0.4538




Epoch 215 / val/metric=0.5060


 72%|█████████████████████████████████████████████████████████████████▏                         | 215/300 [15:48:44<6:14:29, 264.35s/it]

Epoch 216 / trn/loss=0.6588
Epoch 216 / train/metric=0.4589




Epoch 216 / val/metric=0.5064


 72%|█████████████████████████████████████████████████████████████████▌                         | 216/300 [15:53:07<6:09:30, 263.93s/it]

Epoch 217 / trn/loss=0.6488
Epoch 217 / train/metric=0.4531




Epoch 217 / val/metric=0.5130


 72%|█████████████████████████████████████████████████████████████████▊                         | 217/300 [15:57:34<6:06:39, 265.06s/it]

Epoch 218 / trn/loss=0.6492
Epoch 218 / train/metric=0.4518




Epoch 218 / val/metric=0.5065


 73%|██████████████████████████████████████████████████████████████████▏                        | 218/300 [16:01:58<6:01:39, 264.63s/it]

Epoch 219 / trn/loss=0.6509
Epoch 219 / train/metric=0.4544




Epoch 219 / val/metric=0.5148


 73%|██████████████████████████████████████████████████████████████████▍                        | 219/300 [16:06:21<5:56:33, 264.12s/it]

Epoch 220 / trn/loss=0.6598
Epoch 220 / train/metric=0.4593




Epoch 220 / val/metric=0.5110


 73%|██████████████████████████████████████████████████████████████████▋                        | 220/300 [16:10:44<5:51:33, 263.67s/it]

Epoch 221 / trn/loss=0.6540
Epoch 221 / train/metric=0.4581




Epoch 221 / val/metric=0.5074


 74%|███████████████████████████████████████████████████████████████████                        | 221/300 [16:15:07<5:47:03, 263.58s/it]

Epoch 222 / trn/loss=0.6490
Epoch 222 / train/metric=0.4518




Epoch 222 / val/metric=0.5087


 74%|███████████████████████████████████████████████████████████████████▎                       | 222/300 [16:19:31<5:42:55, 263.78s/it]

Epoch 223 / trn/loss=0.6585
Epoch 223 / train/metric=0.4592




Epoch 223 / val/metric=0.5031


 74%|███████████████████████████████████████████████████████████████████▋                       | 223/300 [16:23:56<5:38:57, 264.12s/it]

Epoch 224 / trn/loss=0.6535
Epoch 224 / train/metric=0.4543




Epoch 224 / val/metric=0.5047


 75%|███████████████████████████████████████████████████████████████████▉                       | 224/300 [16:28:20<5:34:19, 263.94s/it]

Epoch 225 / trn/loss=0.6494
Epoch 225 / train/metric=0.4555




Epoch 225 / val/metric=0.5087


 75%|████████████████████████████████████████████████████████████████████▎                      | 225/300 [16:32:44<5:30:14, 264.19s/it]

Epoch 226 / trn/loss=0.6525
Epoch 226 / train/metric=0.4570




Epoch 226 / val/metric=0.5074


 75%|████████████████████████████████████████████████████████████████████▌                      | 226/300 [16:37:07<5:25:20, 263.80s/it]

Epoch 227 / trn/loss=0.6569
Epoch 227 / train/metric=0.4592




Epoch 227 / val/metric=0.5183


 76%|████████████████████████████████████████████████████████████████████▊                      | 227/300 [16:41:32<5:21:22, 264.14s/it]

Epoch 228 / trn/loss=0.6481
Epoch 228 / train/metric=0.4530




Epoch 228 / val/metric=0.5065


 76%|█████████████████████████████████████████████████████████████████████▏                     | 228/300 [16:45:55<5:16:22, 263.65s/it]

Epoch 229 / trn/loss=0.6555
Epoch 229 / train/metric=0.4580




Epoch 229 / val/metric=0.5147


 76%|█████████████████████████████████████████████████████████████████████▍                     | 229/300 [16:50:18<5:11:44, 263.44s/it]

Epoch 230 / trn/loss=0.6462
Epoch 230 / train/metric=0.4497




Epoch 230 / val/metric=0.5039


 77%|█████████████████████████████████████████████████████████████████████▊                     | 230/300 [16:54:41<5:07:23, 263.48s/it]

Epoch 231 / trn/loss=0.6529
Epoch 231 / train/metric=0.4553




Epoch 231 / val/metric=0.5108


 77%|██████████████████████████████████████████████████████████████████████                     | 231/300 [16:59:05<5:02:58, 263.46s/it]

Epoch 232 / trn/loss=0.6450
Epoch 232 / train/metric=0.4507




Epoch 232 / val/metric=0.5094


 77%|██████████████████████████████████████████████████████████████████████▎                    | 232/300 [17:03:28<4:58:30, 263.40s/it]

Epoch 233 / trn/loss=0.6504
Epoch 233 / train/metric=0.4531




Epoch 233 / val/metric=0.5063


 78%|██████████████████████████████████████████████████████████████████████▋                    | 233/300 [17:07:52<4:54:26, 263.67s/it]

Epoch 234 / trn/loss=0.6467
Epoch 234 / train/metric=0.4513




Epoch 234 / val/metric=0.5080


 78%|██████████████████████████████████████████████████████████████████████▉                    | 234/300 [17:12:15<4:49:47, 263.45s/it]

Epoch 235 / trn/loss=0.6488
Epoch 235 / train/metric=0.4538




Epoch 235 / val/metric=0.5185


 78%|███████████████████████████████████████████████████████████████████████▎                   | 235/300 [17:16:40<4:45:41, 263.72s/it]

Epoch 236 / trn/loss=0.6498
Epoch 236 / train/metric=0.4518




Epoch 236 / val/metric=0.5103


 79%|███████████████████████████████████████████████████████████████████████▌                   | 236/300 [17:21:03<4:41:07, 263.55s/it]

Epoch 237 / trn/loss=0.6489
Epoch 237 / train/metric=0.4546




Epoch 237 / val/metric=0.5167


 79%|███████████████████████████████████████████████████████████████████████▉                   | 237/300 [17:25:26<4:36:47, 263.61s/it]

Epoch 238 / trn/loss=0.6442
Epoch 238 / train/metric=0.4489




Epoch 238 / val/metric=0.5164


 79%|████████████████████████████████████████████████████████████████████████▏                  | 238/300 [17:29:51<4:32:44, 263.94s/it]

Epoch 239 / trn/loss=0.6388
Epoch 239 / train/metric=0.4463




Epoch 239 / val/metric=0.5068


 80%|████████████████████████████████████████████████████████████████████████▍                  | 239/300 [17:34:16<4:28:28, 264.08s/it]

Epoch 240 / trn/loss=0.6418
Epoch 240 / train/metric=0.4479




Epoch 240 / val/metric=0.5067


 80%|████████████████████████████████████████████████████████████████████████▊                  | 240/300 [17:38:40<4:24:15, 264.25s/it]

Epoch 241 / trn/loss=0.6444
Epoch 241 / train/metric=0.4493




Epoch 241 / val/metric=0.5099


 80%|█████████████████████████████████████████████████████████████████████████                  | 241/300 [17:43:04<4:19:41, 264.09s/it]

Epoch 242 / trn/loss=0.6455
Epoch 242 / train/metric=0.4519




Epoch 242 / val/metric=0.5085


 81%|█████████████████████████████████████████████████████████████████████████▍                 | 242/300 [17:47:29<4:15:27, 264.27s/it]

Epoch 243 / trn/loss=0.6446
Epoch 243 / train/metric=0.4501




Epoch 243 / val/metric=0.5073


 81%|█████████████████████████████████████████████████████████████████████████▋                 | 243/300 [17:51:52<4:10:44, 263.94s/it]

Epoch 244 / trn/loss=0.6473
Epoch 244 / train/metric=0.4516




Epoch 244 / val/metric=0.5075


 81%|██████████████████████████████████████████████████████████████████████████                 | 244/300 [17:56:16<4:06:24, 264.02s/it]

Epoch 245 / trn/loss=0.6396
Epoch 245 / train/metric=0.4467




Epoch 245 / val/metric=0.5049


 82%|██████████████████████████████████████████████████████████████████████████▎                | 245/300 [18:00:41<4:02:15, 264.29s/it]

Epoch 246 / trn/loss=0.6382
Epoch 246 / train/metric=0.4460




Epoch 246 / val/metric=0.5062


 82%|██████████████████████████████████████████████████████████████████████████▌                | 246/300 [18:05:05<3:57:48, 264.23s/it]

Epoch 247 / trn/loss=0.6433
Epoch 247 / train/metric=0.4487




Epoch 247 / val/metric=0.5011
Best val_metric 0.5011127645624668 at epoch 247!


 82%|██████████████████████████████████████████████████████████████████████████▉                | 247/300 [18:09:29<3:53:27, 264.30s/it]

Epoch 248 / trn/loss=0.6432
Epoch 248 / train/metric=0.4487




Epoch 248 / val/metric=0.5028


 83%|███████████████████████████████████████████████████████████████████████████▏               | 248/300 [18:13:55<3:49:18, 264.59s/it]

Epoch 249 / trn/loss=0.6454
Epoch 249 / train/metric=0.4498




Epoch 249 / val/metric=0.5042


 83%|███████████████████████████████████████████████████████████████████████████▌               | 249/300 [18:18:19<3:44:41, 264.34s/it]

Epoch 250 / trn/loss=0.6400
Epoch 250 / train/metric=0.4476




Epoch 250 / val/metric=0.5048


 83%|███████████████████████████████████████████████████████████████████████████▊               | 250/300 [18:22:42<3:40:00, 264.00s/it]

Epoch 251 / trn/loss=0.6445
Epoch 251 / train/metric=0.4504




Epoch 251 / val/metric=0.5046


 84%|████████████████████████████████████████████████████████████████████████████▏              | 251/300 [18:27:04<3:35:15, 263.59s/it]

Epoch 252 / trn/loss=0.6384
Epoch 252 / train/metric=0.4469




Epoch 252 / val/metric=0.5041


 84%|████████████████████████████████████████████████████████████████████████████▍              | 252/300 [18:31:27<3:30:45, 263.46s/it]

Epoch 253 / trn/loss=0.6426
Epoch 253 / train/metric=0.4494




Epoch 253 / val/metric=0.5084


 84%|████████████████████████████████████████████████████████████████████████████▋              | 253/300 [18:35:51<3:26:17, 263.35s/it]

Epoch 254 / trn/loss=0.6410
Epoch 254 / train/metric=0.4476




Epoch 254 / val/metric=0.5103


 85%|█████████████████████████████████████████████████████████████████████████████              | 254/300 [18:40:16<3:22:20, 263.92s/it]

Epoch 255 / trn/loss=0.6415
Epoch 255 / train/metric=0.4481




Epoch 255 / val/metric=0.5062


 85%|█████████████████████████████████████████████████████████████████████████████▎             | 255/300 [18:44:39<3:17:40, 263.57s/it]

Epoch 256 / trn/loss=0.6390
Epoch 256 / train/metric=0.4467




Epoch 256 / val/metric=0.5070


 85%|█████████████████████████████████████████████████████████████████████████████▋             | 256/300 [18:49:02<3:13:16, 263.57s/it]

Epoch 257 / trn/loss=0.6425
Epoch 257 / train/metric=0.4490




Epoch 257 / val/metric=0.5080


 86%|█████████████████████████████████████████████████████████████████████████████▉             | 257/300 [18:53:25<3:08:38, 263.23s/it]

Epoch 258 / trn/loss=0.6451
Epoch 258 / train/metric=0.4523




Epoch 258 / val/metric=0.5149


 86%|██████████████████████████████████████████████████████████████████████████████▎            | 258/300 [18:57:48<3:04:14, 263.20s/it]

Epoch 259 / trn/loss=0.6437
Epoch 259 / train/metric=0.4497




Epoch 259 / val/metric=0.5096


 86%|██████████████████████████████████████████████████████████████████████████████▌            | 259/300 [19:02:12<3:00:06, 263.57s/it]

Epoch 260 / trn/loss=0.6344
Epoch 260 / train/metric=0.4434




Epoch 260 / val/metric=0.5108


 87%|██████████████████████████████████████████████████████████████████████████████▊            | 260/300 [19:06:37<2:55:54, 263.87s/it]

Epoch 261 / trn/loss=0.6462
Epoch 261 / train/metric=0.4526




Epoch 261 / val/metric=0.5064


 87%|███████████████████████████████████████████████████████████████████████████████▏           | 261/300 [19:11:01<2:51:33, 263.94s/it]

Epoch 262 / trn/loss=0.6423
Epoch 262 / train/metric=0.4485




Epoch 262 / val/metric=0.5056


 87%|███████████████████████████████████████████████████████████████████████████████▍           | 262/300 [19:15:24<2:47:06, 263.85s/it]

Epoch 263 / trn/loss=0.6377
Epoch 263 / train/metric=0.4445




Epoch 263 / val/metric=0.5095


 88%|███████████████████████████████████████████████████████████████████████████████▊           | 263/300 [19:19:48<2:42:39, 263.78s/it]

Epoch 264 / trn/loss=0.6380
Epoch 264 / train/metric=0.4443




Epoch 264 / val/metric=0.5028


 88%|████████████████████████████████████████████████████████████████████████████████           | 264/300 [19:24:11<2:38:05, 263.49s/it]

Epoch 265 / trn/loss=0.6347
Epoch 265 / train/metric=0.4421




Epoch 265 / val/metric=0.5060


 88%|████████████████████████████████████████████████████████████████████████████████▍          | 265/300 [19:28:35<2:33:45, 263.60s/it]

Epoch 266 / trn/loss=0.6310
Epoch 266 / train/metric=0.4404




Epoch 266 / val/metric=0.5094


 89%|████████████████████████████████████████████████████████████████████████████████▋          | 266/300 [19:32:59<2:29:26, 263.73s/it]

Epoch 267 / trn/loss=0.6424
Epoch 267 / train/metric=0.4474




Epoch 267 / val/metric=0.5064


 89%|████████████████████████████████████████████████████████████████████████████████▉          | 267/300 [19:37:22<2:25:01, 263.67s/it]

Epoch 268 / trn/loss=0.6393
Epoch 268 / train/metric=0.4460




Epoch 268 / val/metric=0.5064


 89%|█████████████████████████████████████████████████████████████████████████████████▎         | 268/300 [19:41:47<2:20:44, 263.89s/it]

Epoch 269 / trn/loss=0.6372
Epoch 269 / train/metric=0.4441




Epoch 269 / val/metric=0.5033


 90%|█████████████████████████████████████████████████████████████████████████████████▌         | 269/300 [19:46:11<2:16:23, 263.99s/it]

Epoch 270 / trn/loss=0.6330
Epoch 270 / train/metric=0.4427




Epoch 270 / val/metric=0.5090


 90%|█████████████████████████████████████████████████████████████████████████████████▉         | 270/300 [19:50:35<2:12:00, 264.01s/it]

Epoch 271 / trn/loss=0.6378
Epoch 271 / train/metric=0.4468




Epoch 271 / val/metric=0.5040


 90%|██████████████████████████████████████████████████████████████████████████████████▏        | 271/300 [19:54:58<2:07:24, 263.61s/it]

Epoch 272 / trn/loss=0.6375
Epoch 272 / train/metric=0.4461




Epoch 272 / val/metric=0.5057


 91%|██████████████████████████████████████████████████████████████████████████████████▌        | 272/300 [19:59:20<2:02:53, 263.34s/it]

Epoch 273 / trn/loss=0.6338
Epoch 273 / train/metric=0.4440




Epoch 273 / val/metric=0.5040


 91%|██████████████████████████████████████████████████████████████████████████████████▊        | 273/300 [20:03:43<1:58:25, 263.16s/it]

Epoch 274 / trn/loss=0.6367
Epoch 274 / train/metric=0.4450




Epoch 274 / val/metric=0.5103


 91%|███████████████████████████████████████████████████████████████████████████████████        | 274/300 [20:08:06<1:54:03, 263.19s/it]

Epoch 275 / trn/loss=0.6373
Epoch 275 / train/metric=0.4454




Epoch 275 / val/metric=0.5053


 92%|███████████████████████████████████████████████████████████████████████████████████▍       | 275/300 [20:12:29<1:49:37, 263.08s/it]

Epoch 276 / trn/loss=0.6373
Epoch 276 / train/metric=0.4455




Epoch 276 / val/metric=0.5050


 92%|███████████████████████████████████████████████████████████████████████████████████▋       | 276/300 [20:16:52<1:45:08, 262.87s/it]

Epoch 277 / trn/loss=0.6390
Epoch 277 / train/metric=0.4457




Epoch 277 / val/metric=0.5103


 92%|████████████████████████████████████████████████████████████████████████████████████       | 277/300 [20:21:14<1:40:39, 262.58s/it]

Epoch 278 / trn/loss=0.6375
Epoch 278 / train/metric=0.4446




Epoch 278 / val/metric=0.5041


 93%|████████████████████████████████████████████████████████████████████████████████████▎      | 278/300 [20:25:37<1:36:22, 262.86s/it]

Epoch 279 / trn/loss=0.6367
Epoch 279 / train/metric=0.4443




Epoch 279 / val/metric=0.5056


 93%|████████████████████████████████████████████████████████████████████████████████████▋      | 279/300 [20:30:02<1:32:11, 263.39s/it]

Epoch 280 / trn/loss=0.6418
Epoch 280 / train/metric=0.4488




Epoch 280 / val/metric=0.5133


 93%|████████████████████████████████████████████████████████████████████████████████████▉      | 280/300 [20:34:26<1:27:51, 263.57s/it]

Epoch 281 / trn/loss=0.6343
Epoch 281 / train/metric=0.4437




Epoch 281 / val/metric=0.5051


 94%|█████████████████████████████████████████████████████████████████████████████████████▏     | 281/300 [20:38:50<1:23:33, 263.87s/it]

Epoch 282 / trn/loss=0.6334
Epoch 282 / train/metric=0.4431




Epoch 282 / val/metric=0.5124


 94%|█████████████████████████████████████████████████████████████████████████████████████▌     | 282/300 [20:43:12<1:18:58, 263.25s/it]

Epoch 283 / trn/loss=0.6376
Epoch 283 / train/metric=0.4455




Epoch 283 / val/metric=0.5080


                                                                                                                                        

In [None]:
#Execute this cell to fininsh the wandb run when you stopped training.
import wandb
try: 
    wandb.log({'best_total_log_loss': best_metric})
    wandb.finish()
    
except:
    print('Wandb is already finished!')

In [None]:
ind = 23561
train_meta_df[train_meta_df['patient_id']==ind]

In [None]:
valid_dataset = AbdominalCTDataset(train_meta_df[train_meta_df['patient_id']==ind], is_train = False, transform_set = None,
                                    remain_transforms_set = None)

valid_loader = DataLoader(dataset = valid_dataset, shuffle = False, batch_size = BATCH_SIZE, pin_memory = False, 
                        num_workers = N_WORKERS, drop_last = False)     

In [None]:
X_outs=[]
ys=[]
model.eval()
model.load_state_dict(torch.load(f'{BASE_PATH}/weights/timm_resnet10t.c3_in1k_lr0.0002_epochs_500_resol128_batch24-cv0443.pt'))
for X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, y in valid_loader:
    batch_size = y.shape[0]
    X_bowel, X_lkid, X_rkid = X_bowel.to(DEVICE), X_lkid.to(DEVICE), X_rkid.to(DEVICE)
    X_liv, X_spl, X_tot     = X_liv.to(DEVICE), X_spl.to(DEVICE), X_tot.to(DEVICE)
    y = y.to(DEVICE)           
    with torch.cuda.amp.autocast(enabled=True):                
        with torch.no_grad():                 
            X_out, X_any = model(X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot)                                          
            y_any = torch.cat([torch.ones(batch_size, 1).to(DEVICE)- y[:,13:14],y[:,13:14]], dim = 1)              
            X_out = apply_softmax_to_labels(X_out).to('cpu').numpy()

            X_any = X_any.to('cpu').numpy()
            X_out = np.hstack([X_out, X_any])
            X_outs.append(X_out)

            y     = y.to('cpu').numpy()[:,:-1]
            y_any = y_any.to('cpu').numpy()
            y     = np.hstack([y, y_any])
            ys.append(y)

X_outs = np.vstack(X_outs) 
ys     = np.vstack(ys)
#metric = calculate_score(X_outs, ys, 'valid')                      

del X_bowel, X_lkid, X_rkid, X_liv, X_spl, X_tot, X_any
gc.collect()        
torch.cuda.empty_cache()   

In [None]:
np.average(X_outs, axis = 0)


In [None]:
len(X_outs)

In [None]:
!nvidia-smi