In [1]:
# https://github.com/selimsef/dfdc_deepfake_challenge/blob/master/training/pipelines/train_classifier.py
import argparse
import json
import os
import glob
import pickle
import gc
import sys
import itertools
from collections import defaultdict, OrderedDict
import platform
try:
    # Local dev box (macOS) vs. training server; fall back to the container mount.
    PATH = '/Users/dhanley/Documents/rsnastr' if platform.system() == 'Darwin' else '/data/rsnastr'
    os.chdir(PATH)
except OSError:
    # os.chdir signals a missing/invalid/forbidden directory with OSError
    # subclasses (FileNotFoundError, NotADirectoryError, PermissionError).
    # Catching only those avoids swallowing KeyboardInterrupt or genuine bugs
    # the way a bare `except:` would.
    PATH = '/mount'
    os.chdir(PATH)
sys.path.append(PATH)
import warnings
warnings.filterwarnings("ignore")
from sklearn.metrics import log_loss
from utils.logs import get_logger
from utils.utils import RSNAWEIGHTS, RSNA_CFG as CFG
from training.tools.config import load_config
import pandas as pd
import cv2

import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch import nn
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast

from tqdm import tqdm
import torch.distributed as dist
from training.datasets.classifier_dataset import RSNASequenceDataset, collateseqfn,         valSeedSampler, examSampler
from training.zoo.sequence import SpatialDropout, LSTMNet
from training.tools.utils import create_optimizer, AverageMeter, collectPreds, collectLoss
from training.tools.utils import splitbatch, unmasklabels, unmasklogits
from training.losses import getLoss
from training import losses
from torch.optim.swa_utils import AveragedModel, SWALR
from tensorboardX import SummaryWriter

# Pin the MKL / numexpr / OpenMP thread pools to a single thread each.
# NOTE(review): these are set after torch, cv2, pandas and sklearn are already
# imported at the top of this cell; libraries that read them only at load time
# may ignore them. Confirm whether they need to move above the imports.
os.environ.update({
    "MKL_NUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
    "OMP_NUM_THREADS": "1",
})

In [2]:
# Turn off OpenCV's OpenCL path and its internal threading (0 = run
# sequentially); presumably so it does not oversubscribe CPUs alongside the
# DataLoader workers — confirm.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensor
# Project logger named 'LSTM' at INFO level; all logger.info calls below use it.
logger = get_logger('LSTM', 'INFO') 

In [3]:
#!pip install albumentations

In [4]:
# Notebook-friendly argument parsing: parse an explicit empty argv so the
# Jupyter kernel's own command-line flags are never read as script options
# (and without clobbering sys.argv or deleting the `sys` module binding).
logger.info('Load args')
parser = argparse.ArgumentParser()
arg = parser.add_argument
arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
arg('--workers', type=int, default=6, help='number of cpu threads to use')
arg('--device', type=str, default='cpu' if platform.system() == 'Darwin' else 'cuda', help='device for model - cpu/gpu')
arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
arg('--output-dir', type=str, default='weights/')
arg('--resume', type=str, default='')
arg('--fold', type=int, default=0)
arg('--batchsize', type=int, default=4)
arg('--lr', type=float, default = 0.00001)
arg('--lrgamma', type=float, default = 0.95)
arg('--labeltype', type=str, default='all') # or 'single'
arg('--dropout', type=float, default = 0.2)
arg('--prefix', type=str, default='classifier_')
arg('--data-dir', type=str, default="data")
arg('--folds-csv', type=str, default='folds.csv.gz')
# A class count: parse as int so a CLI value has the same type as the default.
arg('--nclasses', type=int, default=1)
arg('--crops-dir', type=str, default='jpegip')
arg('--lstm_units',   type=int, default=512)
arg('--epochs',   type=int, default=12)
# Boolean flag: any spelling of "true" (case-insensitive) enables it.
arg("--delta", default=False, type=lambda x: (str(x).lower() == 'true'))
arg('--nbags',   type=int, default=12)
arg('--label-smoothing', type=float, default=0.00)
arg('--logdir', type=str, default='logs/b2_1820')
arg("--local_rank", default=0, type=int)
arg("--seed", default=777, type=int)
arg("--imgembrgx", type=str, default='')
args = parser.parse_args(args=[])


2020-10-22 21:27:23,311 - LSTM - INFO - Load args
2020-10-22 21:27:23,312 - LSTM - INFO - Load args


In [5]:
# Notebook overrides of the argparse defaults for this experiment.
# NOTE(review): device is forced to 'cuda' regardless of the platform-based
# default. `exmembrgx` is not a declared CLI option; it is only attached to
# the namespace here and is not read anywhere in this section.
vars(args).update(
    lr=0.00005,
    label_smoothing=0.0,
    dropout=0.3,
    device='cuda',
    fold=0,
    batchsize=32,
    imgembrgx='weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_*__fold*_best_dice__all_size320.emb',
    exmembrgx='weights/exam_lstm_tf_efficientnet_b2_ns_epoch31_fold0.bin__all_size320.emb',
)

In [6]:
def takeimg(s):
    """Return the last '/'-separated component of `s` with every '.jpg' occurrence removed."""
    basename = s.rsplit('/', 1)[-1]
    return basename.replace('.jpg', '')
# Locate the image-level embedding dump produced by the CNN stage. The glob
# pattern comes from args.imgembrgx; the first match in sorted order is used.
embpattern = f'emb/{args.imgembrgx}*data.pk'
embfiles = sorted(glob.glob(embpattern))
if not embfiles:
    raise FileNotFoundError(f'No embedding files match {embpattern}')
fimg = embfiles[0]
logger.info(f'Loading : {fimg}')
dfname, embname, imgnm = fimg, fimg.replace('.data.pk', '.npz'), fimg.replace('.data.pk', '.imgnames.pk')
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# dumps produced by this project. The context manager closes the handle.
with open(imgnm, 'rb') as fh:
    imgls = list(map(takeimg, pickle.load(fh)))
wtsname = embname.split('/')[-1].replace('.emb.npz', '')
# NpzFile keeps the archive open until closed; read the array and release it.
with np.load(embname) as npz:
    embmat = npz['arr_0']
logger.info(f'Weights : {wtsname}')

datadf = pd.read_csv(f'{args.data_dir}/train.csv.zip')
# Reorder label rows to follow imgls (presumably the same row order as embmat).
datadf = datadf.set_index('SOPInstanceUID').loc[imgls].reset_index()
folddf = pd.read_csv(f'{args.data_dir}/{args.folds_csv}')

2020-10-22 21:27:23,392 - LSTM - INFO - Loading : emb/weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_0__fold0_best_dice__all_size320.emb.data.pk
2020-10-22 21:29:14,777 - LSTM - INFO - Weights : classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_0__fold0_best_dice__all_size320


In [7]:
# Per-slice thickness features: absolute thickness and the change from the
# previous slice within the same study.
metadf = pd.read_csv(f'{args.data_dir}/train_meta.csv')
metadf = metadf.set_index('SOPInstanceUID')[['slice_thicknesses']]
# Align rows with datadf (row order of the label frame).
metadf = metadf.loc[datadf.SOPInstanceUID]
metadf['StudyInstanceUID'] = datadf['StudyInstanceUID'].values
metadf['slice_thicknesses1'] = metadf.slice_thicknesses.shift(1)
metadf['StudyInstanceUID1'] = metadf.StudyInstanceUID.shift(1)
metadf['thickdiff'] = metadf.slice_thicknesses - metadf.slice_thicknesses1
# Rows whose study differs from the previous row have no same-study predecessor:
# zero their diff. Use positional .loc instead of chained indexing —
# `df['col'][mask] = v` may write into a temporary copy (silently here, since
# warnings are suppressed) and never takes effect under pandas copy-on-write.
study_changed = (metadf.StudyInstanceUID != metadf.StudyInstanceUID1).values
metadf.loc[study_changed, 'thickdiff'] = 0.
metadf['thickdiff'] = metadf['thickdiff'].fillna(0)
metadf['thick'] = metadf['slice_thicknesses']
metadf = metadf[['thick', 'thickdiff']]

In [8]:
# In[7]:
logger.info('Create traindatasets')
logger.info(f'Embedding delta : {args.delta}')
# Train and validation datasets differ only in `mode`; share every other option.
seqdataset_kwargs = dict(
    delta=args.delta,
    imgclasses=CFG["image_target_cols"],
    studyclasses=CFG['exam_target_cols'],
    fold=args.fold,
    label_smoothing=args.label_smoothing,
    folds_csv=args.folds_csv,
)
trndataset = RSNASequenceDataset(datadf, embmat, folddf, mode="train", **seqdataset_kwargs)
logger.info('Create valdatasets')
valdataset = RSNASequenceDataset(datadf, embmat, folddf, mode="valid", **seqdataset_kwargs)

2020-10-22 21:29:28,078 - LSTM - INFO - Create traindatasets
2020-10-22 21:29:28,079 - LSTM - INFO - Embedding delta : False
2020-10-22 21:29:28,926 - LSTM - INFO - Create valdatasets


In [9]:
import timm
import numpy as np
import torch
from tqdm import tqdm
from sklearn.metrics import log_loss
import pandas as pd
from torch import nn
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d
import torch.nn.functional as F
from torch import nn

class SpatialDropout(nn.Dropout2d):
    """Dropout that zeroes whole feature channels across every time step.

    Maps (N, T, K) -> (N, T, K). The tensor is viewed as (N, K, 1, T) so that
    ``nn.Dropout2d`` treats each of the K features as a channel; a dropped
    feature is therefore zeroed at all T steps of that sample.
    """

    def forward(self, x):
        # (N, T, K) -> (N, T, 1, K) -> (N, K, 1, T): features become channels.
        as_channels = x.unsqueeze(2).permute(0, 3, 2, 1)
        dropped = super().forward(as_channels)
        # (N, K, 1, T) -> (N, T, 1, K) -> (N, T, K)
        return dropped.permute(0, 3, 2, 1).squeeze(2)
    
# https://www.kaggle.com/bminixhofer/speed-up-your-rnn-with-sequence-bucketing
class LSTMNet(nn.Module):
    """Two stacked bidirectional LSTMs over a sequence of per-slice embeddings.

    Produces one study-level output per sequence, computed from a masked
    max/mean pool over time of the second LSTM's outputs. It also produces one
    image-level output per time step, computed from residual heads on both
    LSTM layers.

    Args:
        embed_size: feature size of each input step (last dim of ``x``).
        nimgclasses: number of image-level outputs per step.
        nstudyclasses: number of study-level outputs per sequence.
        LSTM_UNITS: hidden size of each LSTM direction.
        DO: dropout rate given to ``embedding_dropout``.
    """
    def __init__(self, 
                 embed_size, 
                 nimgclasses = 1, 
                 nstudyclasses = 9, 
                 LSTM_UNITS=64, 
                 DO = 0.3):
        super(LSTMNet, self).__init__()
        
        self.nimgclasses = nimgclasses
        self.nstudyclasses = nstudyclasses
        self.embed_size = embed_size
        # NOTE(review): constructed but not applied in forward(), so DO has no
        # effect on the outputs; kept so the module layout is unchanged.
        # Confirm whether the LSTM input was meant to pass through it.
        self.embedding_dropout = SpatialDropout(DO)
        
        self.lstm1 = nn.LSTM(embed_size, LSTM_UNITS, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(LSTM_UNITS * 2, LSTM_UNITS, bidirectional=True, batch_first=True)

        self.img_linear1 = nn.Linear(LSTM_UNITS*2, LSTM_UNITS*2)
        self.img_linear2 = nn.Linear(LSTM_UNITS*2, LSTM_UNITS*2)
        self.study_linear1 = nn.Linear(LSTM_UNITS*4, LSTM_UNITS*4)

        self.img_linear_out = nn.Linear(LSTM_UNITS*2, self.nimgclasses)
        self.study_linear_out = nn.Linear(LSTM_UNITS*4, self.nstudyclasses)

    def forward(self, x, mask, diffmat = None, lengths=None):
        """Run both heads.

        Args:
            x: (N, T, embed_size) batch-first input sequence.
            mask: (N, T) tensor; nonzero entries mark valid (non-padded) steps.
            diffmat, lengths: accepted for call compatibility; unused.

        Returns:
            (study_output, img_output) with shapes (N, nstudyclasses) and
            (N, T, nimgclasses).
        """
        h_lstm1, _ = self.lstm1(x)
        h_lstm2, _ = self.lstm2(h_lstm1)

        # Masked mean and max pool over time for the study-level prediction.
        maskf = mask.unsqueeze(-1).float()                           # (N, T, 1)
        h_masked = h_lstm2 * maskf
        # clamp keeps an all-padding row finite (0 instead of 0/0 = NaN).
        nvalid = mask.sum(1).float().clamp(min=1.).unsqueeze(1)      # (N, 1)
        avg_pool = h_masked.sum(1) / nvalid
        # Fill padded steps with -inf so they can never win the max; multiplying
        # by the mask alone lets a padded 0 beat genuinely negative activations.
        max_pool, _ = h_masked.masked_fill(maskf == 0, float('-inf')).max(1)
        # A row with no valid steps would be all -inf; report 0 like the mean.
        empty = (mask.sum(1) == 0).unsqueeze(-1)                     # (N, 1)
        max_pool = max_pool.masked_fill(empty, 0.)

        # Study-level head: residual MLP on [max_pool, avg_pool].
        h_study_conc = torch.cat((max_pool, avg_pool), 1)
        h_study_conc_linear1 = nn.functional.relu(self.study_linear1(h_study_conc))
        study_hidden = h_study_conc + h_study_conc_linear1
        study_output = self.study_linear_out(study_hidden)

        # Image-level head: residual sum of both LSTM layers and their projections.
        h_img_conc_linear1 = nn.functional.relu(self.img_linear1(h_lstm1))
        h_img_conc_linear2 = nn.functional.relu(self.img_linear2(h_lstm2))
        img_hidden = h_lstm1 + h_lstm2 + h_img_conc_linear1 + h_img_conc_linear2
        img_output = self.img_linear_out(img_hidden)

        return study_output, img_output
    
def diffmat(img_names, mask, metadf):
    """Scatter per-image ``thickdiff`` values into a mask-shaped float16 tensor.

    Args:
        img_names: array of names, presumably shaped like ``mask``; padded
            positions hold the string 'mask'.
        mask: tensor whose entries equal 1 at real (non-padded) positions.
        metadf: DataFrame indexed by image name with a 'thickdiff' column.

    Returns:
        float16 tensor of ``mask``'s shape: ``thickdiff`` at positions where
        mask == 1, 0 elsewhere.
    """
    flat_names = img_names.flatten()
    real_names = flat_names[flat_names != 'mask']
    values = torch.tensor(metadf['thickdiff'].loc[real_names].values).float()
    out = torch.zeros(mask.shape).flatten()
    # Assumes the non-'mask' names appear in the same order as the mask == 1 slots.
    out[mask.flatten() == 1.] = values
    return out.reshape(mask.shape).half()

In [10]:
# In[10]:


logger.info('Create loaders...')
valloader = DataLoader(valdataset, batch_size=args.batchsize, shuffle=False, num_workers=4, collate_fn=collateseqfn)
embed_size = embmat.shape[1]
if args.delta:
    embed_size = embed_size * 3
gc.collect()

logger.info('Create model')
# +1 input feature: the per-slice thickness difference (diffmat) is
# concatenated to every embedding before it reaches the model.
# (No module-level `self` alias: that was a debugging leftover.)
model = LSTMNet(embed_size+1,
                nimgclasses = len(CFG["image_target_cols"]),
                nstudyclasses = len(CFG['exam_target_cols']),
                LSTM_UNITS=args.lstm_units,
                DO = args.dropout)
model = model.to(args.device)
# Both parameter groups currently use weight_decay 0.0 (DECAY = 0.0); the split
# only matters if DECAY is raised.
DECAY = 0.0
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
plist = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': DECAY},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer = torch.optim.Adam(plist, lr=args.lr)
# step_size=1: the learning rate is multiplied by args.lrgamma on every scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=args.lrgamma, last_epoch=-1)

# NOTE(review): ypredls, ypredtstls, bce_func_exam and bce_func_img are not read
# in the visible training loop.
ypredls = []
ypredtstls = []
scaler = torch.cuda.amp.GradScaler()
bce_func_exam = torch.nn.BCEWithLogitsLoss(reduction='none',
                    weight = torch.tensor(CFG['exam_weights']).to(args.device))
bce_func_img = torch.nn.BCEWithLogitsLoss(reduction='none')


# In[13]:

2020-10-22 21:29:29,792 - LSTM - INFO - Create loaders...
2020-10-22 21:29:29,888 - LSTM - INFO - Create model


In [11]:
# Element-wise BCE-with-logits; callers apply their own weighting and reduction.
bcewLL_func = torch.nn.BCEWithLogitsLoss(reduction='none')
# Nine exam-level label weights (they sum to ~1) and a single image-level
# weight; presumably the competition metric weights — confirm against CFG.
label_w = torch.tensor(
    [0.0736196319, 0.2346625767, 0.0782208589,
     0.06257668712, 0.1042944785, 0.06257668712,
     0.1042944785, 0.1877300613, 0.09202453988],
    dtype=torch.float,
).to(args.device)
image_w = torch.tensor(0.07361963, dtype=torch.float).to(args.device)

def exam_lossfn(studylogits, ystudy, criterion = bcewLL_func):
    """Label-weighted exam-level loss summed over a batch.

    Returns:
        (loss_sum, n_exams): the batch sum of per-exam losses, each being the
        ``label_w``-weighted sum over labels, and the number of rows (exams)
        to use as the normaliser.
    """
    per_label = criterion(studylogits, ystudy)
    n_exams = per_label.shape[0]
    weighted = per_label * label_w
    return weighted.sum(1).sum(), n_exams

def image_lossfn(imglogits, yimg, mask, criterion = None):
    """Image-level loss weighted by each exam's positive-image fraction.

    Each exam's masked, summed per-image loss is scaled by ``image_w * q_i``,
    where ``q_i = yimg.sum / mask.sum`` (the positive fraction of its valid
    images, assuming binary targets).

    Args:
        imglogits: per-image logits; a trailing singleton dim is squeezed.
        yimg: per-image targets, same shape as the squeezed logits.
        mask: 1 for valid images, 0 for padding.
        criterion: element-wise loss; ``None`` selects ``bcewLL_func``. An
            explicitly passed criterion is now honoured rather than overridden.

    Returns:
        (loss_sum, weight_sum): scalar tensors the caller accumulates and divides.
    """
    if criterion is None:
        criterion = bcewLL_func
    img_num = mask.sum(1)
    # clamp guards an all-padding exam against 0/0 = NaN; such an exam gets q_i = 0.
    qi = yimg.sum(1) / img_num.clamp(min=1)
    image_loss = (criterion(imglogits.squeeze(-1), yimg) * mask).sum(1)
    image_loss = torch.sum(image_w * qi * image_loss)
    image_wt = torch.sum(image_w * qi * img_num)
    return image_loss, image_wt

class resultsfn:
    """Running totals of weighted loss sums and their weights for one pass.

    The totals start as 0.0 class attributes; ``+=`` on an instance creates an
    instance attribute, so each ``resultsfn()`` accumulates independently.
    """
    # combined (exam + image) loss sum and weight sum
    loss: float = 0.
    wts: float = 0.
    # image-level loss sum and weight sum
    imgloss: float = 0.
    imgwts: float = 0.
    # exam-level loss sum and weight sum
    exmloss: float = 0.
    exmwts: float = 0.

In [None]:
logger.info('Start training')
best_val_loss = float('inf')
deltamsg = '_delta' if args.delta else ''


def _build_inputs(batch):
    """Append the per-slice thickness difference to each embedding and move to device."""
    slicediff = diffmat(batch['img_name'], batch['mask'], metadf)
    x = torch.cat((batch['emb'], slicediff.unsqueeze(-1)), -1)
    return x.to(args.device, dtype=torch.float)


def _accumulate(res, exam_loss, exam_wts, image_loss, image_wts):
    """Add one batch's loss sums and weights to the running totals in `res`."""
    res.loss += (exam_loss + image_loss).item()
    res.wts += (exam_wts + image_wts).item()
    res.imgloss += image_loss.item()
    res.imgwts += image_wts.item()
    res.exmloss += exam_loss.item()
    res.exmwts += exam_wts


for epoch in range(args.epochs):
    # The exam sampler (and hence the loader) is rebuilt every epoch.
    examsampler = examSampler(trndataset.datadf, trndataset.folddf)
    trnloader = DataLoader(trndataset, batch_size=args.batchsize, sampler = examsampler, num_workers=4, collate_fn=collateseqfn)
    for param in model.parameters():
        param.requires_grad = True
    model.train()
    trnres = resultsfn()
    pbartrn = tqdm(enumerate(trnloader),
                total = len(trndataset)//trnloader.batch_size,
                desc=f"Train epoch {epoch}", ncols=0)
    for step, batch in pbartrn:
        img_names, yimg, ystudy, masktrn, lelabels = splitbatch(batch, args.device)
        # Batches without any positive image carry zero image-loss weight; skip them.
        if yimg.sum()==0:
            logger.info('AAAAAA')
            continue
        # Inputs need no gradient: only the model parameters are optimised, so
        # tracking d(loss)/d(input) would just waste memory and compute.
        xtrn = _build_inputs(batch)
        with autocast():
            studylogits, imglogits = model(xtrn, masktrn)
            exam_loss, exam_wts = exam_lossfn(studylogits, ystudy)
            image_loss, image_wts = image_lossfn(imglogits, yimg, masktrn)
        # Step on the weighted mean: total weighted loss / total weight.
        loss = (exam_loss+image_loss)/(exam_wts+image_wts)
        scaler.scale(loss).backward()
        _accumulate(trnres, exam_loss, exam_wts, image_loss, image_wts)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        pbartrn.set_postfix({'train loss': trnres.loss/trnres.wts,
                          'image loss': trnres.imgloss/trnres.imgwts,
                          'exam loss': trnres.exmloss/trnres.exmwts})

        if step%100==0:
            torch.cuda.empty_cache()
    # Checkpoint every epoch, then decay the learning rate.
    output_model_file = f'weights/exam_lstm_{wtsname}{deltamsg}__epoch{epoch}.bin'
    torch.save(model.state_dict(), output_model_file)
    scheduler.step()

    model.eval()
    valres = resultsfn()
    pbarval = tqdm(enumerate(valloader),
                total = len(valdataset)//valloader.batch_size,
                desc=f"Valid epoch {epoch}", ncols=0)
    for step, batch in pbarval:
        img_names, yimg, ystudy, maskval, lelabels = splitbatch(batch, args.device)
        if yimg.sum()==0:
            logger.info('AAAAAA')
            continue
        xval = _build_inputs(batch)
        with torch.no_grad():
            studylogits, imglogits = model(xval, maskval)
            exam_loss, exam_wts = exam_lossfn(studylogits, ystudy)
            image_loss, image_wts = image_lossfn(imglogits, yimg, maskval)
        _accumulate(valres, exam_loss, exam_wts, image_loss, image_wts)
        pbarval.set_postfix({'valid loss': valres.loss/valres.wts,
                          'image loss': valres.imgloss/valres.imgwts,
                          'exam loss': valres.exmloss/valres.exmwts})
    val_loss = valres.loss/valres.wts
    if best_val_loss>val_loss:
        best_val_loss=val_loss
        # The per-epoch checkpoint above is already on disk; keep a separately
        # named copy of the best model so it can be found without reading logs.
        best_model_file = f'weights/exam_lstm_{wtsname}{deltamsg}__best.bin'
        torch.save(model.state_dict(), best_model_file)
        logger.info(f'Best Epoch {epoch} val loss all {best_val_loss:.4f}')

2020-10-22 21:35:18,972 - LSTM - INFO - Start training

Valid epoch 0:   0% 0/45 [00:00<?, ?it/s][A
Valid epoch 0:   0% 0/45 [00:04<?, ?it/s, valid loss=0.279, image loss=0.18, exam loss=0.334][A
Valid epoch 0:   2% 1/45 [00:04<03:26,  4.70s/it, valid loss=0.279, image loss=0.18, exam loss=0.334][A
Valid epoch 0:   2% 1/45 [00:04<03:26,  4.70s/it, valid loss=0.257, image loss=0.202, exam loss=0.293][A
Valid epoch 0:   4% 2/45 [00:04<02:25,  3.38s/it, valid loss=0.257, image loss=0.202, exam loss=0.293][A
Valid epoch 0:   4% 2/45 [00:05<02:25,  3.38s/it, valid loss=0.241, image loss=0.203, exam loss=0.27] [A
Valid epoch 0:   7% 3/45 [00:05<01:44,  2.48s/it, valid loss=0.241, image loss=0.203, exam loss=0.27][A
Valid epoch 0:   7% 3/45 [00:05<01:44,  2.48s/it, valid loss=0.244, image loss=0.208, exam loss=0.272][A
Valid epoch 0:   9% 4/45 [00:05<01:15,  1.84s/it, valid loss=0.244, image loss=0.208, exam loss=0.272][A
Valid epoch 0:   9% 4/45 [00:09<01:15,  1.84s/it, valid loss=0

Valid epoch 0:  89% 40/45 [00:52<00:04,  1.21it/s, valid loss=0.238, image loss=0.248, exam loss=0.229][A
Valid epoch 0:  91% 41/45 [00:52<00:07,  1.77s/it, valid loss=0.238, image loss=0.248, exam loss=0.229][A
Valid epoch 0:  91% 41/45 [00:53<00:07,  1.77s/it, valid loss=0.24, image loss=0.25, exam loss=0.231]  [A
Valid epoch 0:  93% 42/45 [00:53<00:04,  1.37s/it, valid loss=0.24, image loss=0.25, exam loss=0.231][A
Valid epoch 0:  93% 42/45 [00:53<00:04,  1.37s/it, valid loss=0.24, image loss=0.25, exam loss=0.231][A
Valid epoch 0:  96% 43/45 [00:53<00:02,  1.05s/it, valid loss=0.24, image loss=0.25, exam loss=0.231][A
Valid epoch 0:  96% 43/45 [00:53<00:02,  1.05s/it, valid loss=0.239, image loss=0.249, exam loss=0.229][A
Valid epoch 0:  98% 44/45 [00:53<00:00,  1.27it/s, valid loss=0.239, image loss=0.249, exam loss=0.229][A
Valid epoch 0:  98% 44/45 [00:57<00:00,  1.27it/s, valid loss=0.239, image loss=0.249, exam loss=0.229][A
Valid epoch 0: 100% 45/45 [00:57<00:00,  1.

Valid epoch 1:  76% 34/45 [00:43<00:13,  1.26s/it, valid loss=0.238, image loss=0.259, exam loss=0.22] [A
Valid epoch 1:  78% 35/45 [00:43<00:10,  1.08s/it, valid loss=0.238, image loss=0.259, exam loss=0.22][A
Valid epoch 1:  78% 35/45 [00:43<00:10,  1.08s/it, valid loss=0.236, image loss=0.254, exam loss=0.221][A
Valid epoch 1:  80% 36/45 [00:43<00:07,  1.23it/s, valid loss=0.236, image loss=0.254, exam loss=0.221][A
Valid epoch 1:  80% 36/45 [00:47<00:07,  1.23it/s, valid loss=0.235, image loss=0.251, exam loss=0.221][A
Valid epoch 1:  82% 37/45 [00:47<00:15,  1.90s/it, valid loss=0.235, image loss=0.251, exam loss=0.221][A
Valid epoch 1:  82% 37/45 [00:48<00:15,  1.90s/it, valid loss=0.236, image loss=0.249, exam loss=0.224][A
Valid epoch 1:  84% 38/45 [00:48<00:10,  1.54s/it, valid loss=0.236, image loss=0.249, exam loss=0.224][A
Valid epoch 1:  84% 38/45 [00:48<00:10,  1.54s/it, valid loss=0.237, image loss=0.249, exam loss=0.227][A
Valid epoch 1:  87% 39/45 [00:48<00:06

Valid epoch 2:  62% 28/45 [00:33<00:12,  1.41it/s, valid loss=0.242, image loss=0.261, exam loss=0.225][A
Valid epoch 2:  62% 28/45 [00:37<00:12,  1.41it/s, valid loss=0.241, image loss=0.263, exam loss=0.222][A
Valid epoch 2:  64% 29/45 [00:37<00:28,  1.76s/it, valid loss=0.241, image loss=0.263, exam loss=0.222][A
Valid epoch 2:  64% 29/45 [00:37<00:28,  1.76s/it, valid loss=0.241, image loss=0.263, exam loss=0.222][A
Valid epoch 2:  67% 30/45 [00:37<00:18,  1.26s/it, valid loss=0.241, image loss=0.263, exam loss=0.222][A
Valid epoch 2:  67% 30/45 [00:37<00:18,  1.26s/it, valid loss=0.239, image loss=0.262, exam loss=0.22] [A
Valid epoch 2:  69% 31/45 [00:37<00:12,  1.08it/s, valid loss=0.239, image loss=0.262, exam loss=0.22][A
Valid epoch 2:  69% 31/45 [00:37<00:12,  1.08it/s, valid loss=0.236, image loss=0.259, exam loss=0.217][A
Valid epoch 2:  71% 32/45 [00:37<00:08,  1.47it/s, valid loss=0.236, image loss=0.259, exam loss=0.217][A
Valid epoch 2:  71% 32/45 [00:42<00:08

Valid epoch 3:  51% 23/45 [00:28<00:16,  1.37it/s, valid loss=0.241, image loss=0.259, exam loss=0.225][A
Valid epoch 3:  51% 23/45 [00:28<00:16,  1.37it/s, valid loss=0.238, image loss=0.259, exam loss=0.221][A
Valid epoch 3:  53% 24/45 [00:28<00:11,  1.84it/s, valid loss=0.238, image loss=0.259, exam loss=0.221][A
Valid epoch 3:  53% 24/45 [00:32<00:11,  1.84it/s, valid loss=0.24, image loss=0.26, exam loss=0.223]  [A
Valid epoch 3:  56% 25/45 [00:32<00:33,  1.67s/it, valid loss=0.24, image loss=0.26, exam loss=0.223][A
Valid epoch 3:  56% 25/45 [00:33<00:33,  1.67s/it, valid loss=0.241, image loss=0.264, exam loss=0.222][A
Valid epoch 3:  58% 26/45 [00:33<00:22,  1.20s/it, valid loss=0.241, image loss=0.264, exam loss=0.222][A
Valid epoch 3:  58% 26/45 [00:33<00:22,  1.20s/it, valid loss=0.241, image loss=0.261, exam loss=0.224][A
Valid epoch 3:  60% 27/45 [00:33<00:16,  1.09it/s, valid loss=0.241, image loss=0.261, exam loss=0.224][A
Valid epoch 3:  60% 27/45 [00:33<00:16,

Valid epoch 4:  36% 16/45 [00:19<00:20,  1.43it/s, valid loss=0.243, image loss=0.256, exam loss=0.232][A
Valid epoch 4:  36% 16/45 [00:23<00:20,  1.43it/s, valid loss=0.242, image loss=0.256, exam loss=0.229][A
Valid epoch 4:  38% 17/45 [00:23<00:48,  1.72s/it, valid loss=0.242, image loss=0.256, exam loss=0.229][A
Valid epoch 4:  38% 17/45 [00:23<00:48,  1.72s/it, valid loss=0.244, image loss=0.26, exam loss=0.23]  [A
Valid epoch 4:  40% 18/45 [00:23<00:33,  1.25s/it, valid loss=0.244, image loss=0.26, exam loss=0.23][A
Valid epoch 4:  40% 18/45 [00:23<00:33,  1.25s/it, valid loss=0.253, image loss=0.273, exam loss=0.236][A
Valid epoch 4:  42% 19/45 [00:23<00:25,  1.03it/s, valid loss=0.253, image loss=0.273, exam loss=0.236][A
Valid epoch 4:  42% 19/45 [00:23<00:25,  1.03it/s, valid loss=0.249, image loss=0.268, exam loss=0.233][A
Valid epoch 4:  44% 20/45 [00:23<00:17,  1.41it/s, valid loss=0.249, image loss=0.268, exam loss=0.233][A
Valid epoch 4:  44% 20/45 [00:28<00:17,

Valid epoch 5:  20% 9/45 [00:14<01:07,  1.88s/it, valid loss=0.258, image loss=0.262, exam loss=0.254][A
Valid epoch 5:  22% 10/45 [00:14<00:47,  1.35s/it, valid loss=0.258, image loss=0.262, exam loss=0.254][A
Valid epoch 5:  22% 10/45 [00:14<00:47,  1.35s/it, valid loss=0.257, image loss=0.261, exam loss=0.253][A
Valid epoch 5:  24% 11/45 [00:14<00:34,  1.02s/it, valid loss=0.257, image loss=0.261, exam loss=0.253][A
Valid epoch 5:  24% 11/45 [00:14<00:34,  1.02s/it, valid loss=0.255, image loss=0.263, exam loss=0.248][A
Valid epoch 5:  24% 11/45 [00:18<00:34,  1.02s/it, valid loss=0.251, image loss=0.26, exam loss=0.244] [A
Valid epoch 5:  29% 13/45 [00:18<00:43,  1.36s/it, valid loss=0.251, image loss=0.26, exam loss=0.244][A
Valid epoch 5:  29% 13/45 [00:19<00:43,  1.36s/it, valid loss=0.247, image loss=0.259, exam loss=0.238][A
Valid epoch 5:  31% 14/45 [00:19<00:30,  1.01it/s, valid loss=0.247, image loss=0.259, exam loss=0.238][A
Valid epoch 5:  31% 14/45 [00:19<00:30,

Valid epoch 6:   7% 3/45 [00:05<02:34,  3.67s/it, valid loss=0.244, image loss=0.208, exam loss=0.272][A
Valid epoch 6:   9% 4/45 [00:05<01:46,  2.60s/it, valid loss=0.244, image loss=0.208, exam loss=0.272][A
Valid epoch 6:   9% 4/45 [00:09<01:46,  2.60s/it, valid loss=0.251, image loss=0.227, exam loss=0.274][A
Valid epoch 6:  11% 5/45 [00:09<02:02,  3.06s/it, valid loss=0.251, image loss=0.227, exam loss=0.274][A
Valid epoch 6:  11% 5/45 [00:09<02:02,  3.06s/it, valid loss=0.253, image loss=0.244, exam loss=0.262][A
Valid epoch 6:  13% 6/45 [00:09<01:24,  2.17s/it, valid loss=0.253, image loss=0.244, exam loss=0.262][A
Valid epoch 6:  13% 6/45 [00:09<01:24,  2.17s/it, valid loss=0.261, image loss=0.257, exam loss=0.265][A
Valid epoch 6:  16% 7/45 [00:09<00:59,  1.56s/it, valid loss=0.261, image loss=0.257, exam loss=0.265][A
Valid epoch 6:  16% 7/45 [00:10<00:59,  1.56s/it, valid loss=0.257, image loss=0.254, exam loss=0.261][A
Valid epoch 6:  16% 7/45 [00:14<00:59,  1.56s/

Valid epoch 6:  96% 43/45 [00:53<00:02,  1.04s/it, valid loss=0.239, image loss=0.249, exam loss=0.229][A
Valid epoch 6:  98% 44/45 [00:53<00:00,  1.28it/s, valid loss=0.239, image loss=0.249, exam loss=0.229][A