In [1]:
# https://github.com/selimsef/dfdc_deepfake_challenge/blob/master/training/pipelines/train_classifier.py
# Adapted from the training script above: this notebook trains an LSTMNet over
# precomputed per-image embeddings loaded from .npz files (see the cells below).
import argparse
import json
import os
import glob
import pickle
import gc
import sys
import itertools
from collections import defaultdict, OrderedDict
import platform
# Project root: a local checkout on macOS (Darwin), a server path otherwise.
# NOTE(review): hard-coded absolute paths — consider making this configurable.
PATH = '/Users/dhanley/Documents/rsnastr' \
        if platform.system() == 'Darwin' else '/data/rsnastr'
# Work from the project root and put it on sys.path so the project packages
# (`utils`, `training`) imported below resolve; this must precede those imports.
os.chdir(PATH)
sys.path.append(PATH)
# Silences every Python warning for the rest of the session.
import warnings
warnings.filterwarnings("ignore")
from sklearn.metrics import log_loss
from utils.logs import get_logger
from utils.utils import RSNAWEIGHTS, RSNA_CFG as CFG
from training.tools.config import load_config
import pandas as pd
import cv2

import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch import nn
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast

from tqdm import tqdm
import torch.distributed as dist
from training.datasets.classifier_dataset import RSNASequenceDataset, collateseqfn, \
        valSeedSampler, examSampler
from training.zoo.sequence import SpatialDropout, LSTMNet
from training.tools.utils import create_optimizer, AverageMeter, collectPreds, collectLoss
from training.tools.utils import splitbatch, unmasklabels, unmasklogits
from training.losses import getLoss
from training import losses
from torch.optim.swa_utils import AveragedModel, SWALR
from tensorboardX import SummaryWriter

# Cap MKL / numexpr / OpenMP thread pools at one thread.
# NOTE(review): pandas (and hence numpy), torch and cv2 are already imported at this
# point; libraries that read these variables only when they load may ignore them —
# confirm they take effect, or move them above the imports.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

# Disable OpenCV's OpenCL backend and its internal threading.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)
import numpy as np
# NOTE(review): albumentations (A, ToTensor) is not used in the visible cells.
import albumentations as A
from albumentations.pytorch import ToTensor
# Notebook-wide logger; all progress messages below go through it.
logger = get_logger('LSTM', 'INFO') 

In [2]:
# Command-line style configuration. Inside a notebook there is no real command
# line, so an empty argv is parsed explicitly (instead of overwriting sys.argv and
# deleting the global `sys` name); every option therefore takes its default here.
logger.info('Load args')
parser = argparse.ArgumentParser()
arg = parser.add_argument
arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
arg('--workers', type=int, default=6, help='number of cpu threads to use')
arg('--device', type=str, default='cpu' if platform.system() == 'Darwin' else 'cuda', help='device for model - cpu/gpu')
arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
arg('--output-dir', type=str, default='weights/')
arg('--resume', type=str, default='')
arg('--fold', type=int, default=0)
arg('--batchsize', type=int, default=4)
arg('--lr', type=float, default=0.00001)
arg('--lrgamma', type=float, default=0.95)
arg('--labeltype', type=str, default='all')  # or 'single'
arg('--dropout', type=float, default=0.2)
arg('--prefix', type=str, default='classifier_')
arg('--data-dir', type=str, default="data")
arg('--folds-csv', type=str, default='folds.csv.gz')
# A class count is an integer; `type=str` (as before) would turn "--nclasses 3" into "3".
arg('--nclasses', type=int, default=1)
arg('--crops-dir', type=str, default='jpegip')
arg('--lstm_units', type=int, default=512)
arg('--epochs', type=int, default=12)
arg('--nbags', type=int, default=12)
arg('--label-smoothing', type=float, default=0.00)
arg('--logdir', type=str, default='logs/b2_1820')
arg("--local_rank", default=0, type=int)
arg("--seed", default=777, type=int)
args = parser.parse_args([])

2020-10-12 21:10:57,659 - LSTM - INFO - Load args


In [3]:
# Notebook overrides of the CLI defaults for this run.
args.lr = 0.00001
args.label_smoothing = 0.0
args.dropout = 0.3
# Use CUDA only when it is actually available. Forcing 'cuda' (as before) discarded the
# platform-aware default and fails on the first `.to(args.device)` on a CPU-only machine.
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.fold = 0
args.batchsize = 32
# Glob prefix of the per-image embedding dumps (looked up under emb/).
args.imgembrgx = 'weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_*__fold*_best_dice__all_size320.emb'
# NOTE(review): exmembrgx is not read by any visible cell.
args.exmembrgx = 'weights/exam_lstm_tf_efficientnet_b2_ns_epoch31_fold0.bin__all_size320.emb'


In [4]:
def takeimg(s):
    """Turn an image path into its image id.

    Keeps only the text after the last '/', then strips every '.jpg' occurrence.
    """
    _, _, fname = s.rpartition('/')
    return fname.replace('.jpg', '')
# Locate the per-image embedding dump for the configured pattern.
# NOTE(review): when several folds match the glob, the lexicographically first file
# is used regardless of args.fold — presumably fold 0; confirm.
embpattern = f'emb/{args.imgembrgx}*data.pk'
embmatches = sorted(glob.glob(embpattern))
if not embmatches:
    raise FileNotFoundError(f'No embedding files match {embpattern}')
fimg = embmatches[0]
# Sibling files of the same dump: .npz (embedding matrix) and .imgnames.pk (row names).
dfname, embname, imgnm = fimg, fimg.replace('.data.pk', '.npz'), fimg.replace('.data.pk', '.imgnames.pk')
# SECURITY: pickle.load can execute arbitrary code — only load locally produced files.
with open(imgnm, "rb") as fh:
    imgls = list(map(takeimg, pickle.load(fh)))
# Tag derived from the embedding file name; used in checkpoint names.
wtsname = embname.split('/')[-1].replace('.emb.npz', '')
# Close the NpzFile once the array is read into memory.
with np.load(embname) as npz:
    embmat = npz['arr_0']

In [6]:
datadf = pd.read_csv(f'{args.data_dir}/train.csv.zip')
# Reorder the label rows into imgls order (the order of the embedding's image names).
datadf = datadf.set_index('SOPInstanceUID').loc[imgls].reset_index()
# The datasets pair datadf with embmat; presumably row i of each describes the same
# image. At minimum the row counts must agree.
assert len(datadf) == len(embmat), \
    f'label rows ({len(datadf)}) do not match embedding rows ({len(embmat)})'
folddf = pd.read_csv(f'{args.data_dir}/{args.folds_csv}')

In [7]:
def make_sequence_dataset(mode):
    """Build an RSNASequenceDataset for `mode` ("train" or "valid").

    Both splits share the same label frame, embedding matrix, fold table and
    target columns. Only the mode differs, so the constructor call lives here once.
    """
    return RSNASequenceDataset(datadf,
                               embmat,
                               folddf,
                               mode=mode,
                               imgclasses=CFG["image_target_cols"],
                               studyclasses=CFG['exam_target_cols'],
                               fold=args.fold,
                               label_smoothing=args.label_smoothing,
                               folds_csv=args.folds_csv)

logger.info('Create traindatasets')
trndataset = make_sequence_dataset("train")
logger.info('Create valdatasets')
valdataset = make_sequence_dataset("valid")

2020-10-12 21:12:56,887 - LSTM - INFO - Create traindatasets
2020-10-12 21:12:57,917 - LSTM - INFO - Create valdatasets


In [10]:
logger.info('Create loaders...')
# Validation iterates in a fixed order; the worker count honours --workers
# (previously hard-coded to 4, leaving the option unused).
valloader = DataLoader(valdataset, batch_size=args.batchsize, shuffle=False,
                       num_workers=args.workers, collate_fn=collateseqfn)
# Width of each per-image embedding vector; passed to the model as its input size.
embed_size = embmat.shape[1]
gc.collect();

2020-10-12 21:13:07,982 - LSTM - INFO - Create loaders...


3

In [11]:
logger.info('Create model')
model = LSTMNet(
    embed_size,
    nimgclasses=len(CFG["image_target_cols"]),
    nstudyclasses=len(CFG['exam_target_cols']),
    LSTM_UNITS=args.lstm_units,
    DO=args.dropout,
).to(args.device)

# Two parameter groups: names containing any `no_decay` token never get weight
# decay; all others get DECAY (currently 0.0, so both groups are undecayed).
DECAY = 0.0
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
decay_params, nodecay_params = [], []
for pname, p in param_optimizer:
    if any(tok in pname for tok in no_decay):
        nodecay_params.append(p)
    else:
        decay_params.append(p)
plist = [
    {'params': decay_params, 'weight_decay': DECAY},
    {'params': nodecay_params, 'weight_decay': 0.0},
]
optimizer = torch.optim.Adam(plist, lr=args.lr)
# Multiply the learning rate by args.lrgamma after every scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lrgamma, last_epoch=-1)

2020-10-12 21:13:08,097 - LSTM - INFO - Create model


In [12]:
# Prediction accumulators (not filled by any visible cell).
ypredls, ypredtstls = [], []
# Loss scaler for mixed-precision training.
scaler = torch.cuda.amp.GradScaler()
# Element-wise BCE criteria.
# NOTE(review): neither is used by the loss helpers in the visible cells.
bce_func_exam = torch.nn.BCEWithLogitsLoss(
    reduction='none',
    weight=torch.tensor(CFG['exam_weights']).to(args.device),
)
bce_func_img = torch.nn.BCEWithLogitsLoss(reduction='none')

In [13]:
# Unreduced (element-wise) BCE-with-logits; default criterion of the loss helpers.
bcewLL_func = torch.nn.BCEWithLogitsLoss(reduction='none')
# Weights for the 9 exam-level labels (they sum to ~1.0).
# NOTE(review): presumably ordered like CFG['exam_target_cols'] — confirm.
label_w = torch.tensor(
    [0.0736196319, 0.2346625767, 0.0782208589,
     0.06257668712, 0.1042944785, 0.06257668712,
     0.1042944785, 0.1877300613, 0.09202453988]
).to(args.device, dtype=torch.float)
# Scalar weight for image-level terms (≈ label_w[0]).
image_w = torch.tensor(0.07361963).to(args.device, dtype=torch.float)

def exam_lossfn(studylogits, ystudy, criterion=bcewLL_func):
    """Label-weighted exam-level loss, summed over the batch.

    Args:
        studylogits: exam logits, assumed (n_exams, n_labels) — TODO confirm.
        ystudy: targets matching `studylogits`.
        criterion: element-wise (unreduced) loss, `bcewLL_func` by default.

    Returns:
        (loss_sum, n_exams): the sum over rows and labels of `label_w`-weighted
        losses, and the number of rows, used as the normalising weight.
    """
    per_label = criterion(studylogits, ystudy)
    n_exams = per_label.shape[0]
    return (per_label * label_w).sum(1).sum(), n_exams

def image_lossfn(imglogits, yimg, mask, criterion = bcewLL_func):
    """Weighted image-level loss, summed over a batch of exams.

    Each exam's masked image loss is scaled by `image_w * qi`, where `qi` is
    the exam's positive-image count divided by its unmasked-image count.

    Args:
        imglogits: per-image logits with a trailing singleton dim that is squeezed
            (assumed (n_exams, n_images, 1) — TODO confirm).
        yimg: per-image targets, assumed (n_exams, n_images).
        mask: presumably 1 for real images and 0 for padding; same shape as `yimg`.
        criterion: element-wise (unreduced) loss. It is now honoured; previously
            the first line reassigned it to `bcewLL_func`, ignoring the argument.

    Returns:
        (loss_sum, weight_sum): weighted image loss and matching total weight.
    """
    img_num = mask.sum(1)
    # clamp keeps an all-padding row from producing 0/0 = NaN, which would poison
    # the whole batch sum; rows with at least one image are unaffected.
    qi = yimg.sum(1) / img_num.clamp(min=1)
    image_loss = (criterion(imglogits.squeeze(-1), yimg) * mask).sum(1)
    image_loss = torch.sum(image_w * qi * image_loss)
    image_wt = torch.sum(image_w * qi * img_num)
    return image_loss, image_wt

class resultsfn:
    """Running sums of weighted losses and their weights for one pass over a loader.

    Class-level zeros act as starting values. `inst.x += ...` creates per-instance
    attributes, so each new instance starts from zero.
    """
    # Total (exam + image) loss sum and weight sum.
    loss = wts = 0.
    # Image-level loss sum and weight sum.
    imgloss = imgwts = 0.
    # Exam-level loss sum and weight sum.
    exmloss = exmwts = 0.
    exmwts    = 0.

In [14]:
# Unbind the `tqdm` name imported in the first cell; the training cell re-imports it
# before use (presumably to reset notebook progress-bar state).
del tqdm

In [15]:
logger.info('Start training')
# Lower is better; +inf guarantees the first validation counts as an improvement
# (the previous sentinel, 100., was an arbitrary magic number).
best_val_loss = float('inf')
from tqdm import tqdm


def _accumulate(res, exam_loss, exam_wts, image_loss, image_wts):
    """Add one batch's summed losses and weights into the running totals `res`."""
    res.loss += (exam_loss + image_loss).item()
    res.wts += (exam_wts + image_wts).item()
    res.imgloss += image_loss.item()
    res.imgwts += image_wts.item()
    res.exmloss += exam_loss.item()
    res.exmwts += exam_wts


def _postfix(res, total_key):
    """Running weighted-mean losses, formatted for a tqdm postfix."""
    return {total_key: res.loss / res.wts,
            'image loss': res.imgloss / res.imgwts,
            'exam loss': res.exmloss / res.exmwts}


for epoch in range(args.epochs):
    # Build a fresh exam-level sampler and loader for each epoch.
    examsampler = examSampler(trndataset.datadf, trndataset.folddf)
    trnloader = DataLoader(trndataset, batch_size=args.batchsize, sampler=examsampler,
                           num_workers=args.workers, collate_fn=collateseqfn)
    for param in model.parameters():
        param.requires_grad = True
    model.train()
    trnres = resultsfn()
    pbartrn = tqdm(enumerate(trnloader),
                   total=len(trndataset) // trnloader.batch_size,
                   desc=f"Train epoch {epoch}", ncols=0)
    for step, batch in pbartrn:
        img_names, yimg, ystudy, masktrn, lelabels = splitbatch(batch, args.device)
        # Batches without any positive image are skipped; this also keeps the image
        # weight sum positive for the postfix division.
        if yimg.sum() == 0:
            logger.info('AAAAAA')
            continue
        # The embeddings are fixed inputs: no gradient w.r.t. them is needed. The former
        # Variable(..., requires_grad=True) wrapper only made autograd compute one.
        xtrn = batch['emb'].to(args.device, dtype=torch.float)
        with autocast():
            studylogits, imglogits = model(xtrn, masktrn)
            exam_loss, exam_wts = exam_lossfn(studylogits, ystudy)
            image_loss, image_wts = image_lossfn(imglogits, yimg, masktrn)
        # Normalise by the batch's total weight so the loss is a weighted mean.
        loss = (exam_loss + image_loss) / (exam_wts + image_wts)
        scaler.scale(loss).backward()
        _accumulate(trnres, exam_loss, exam_wts, image_loss, image_wts)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        pbartrn.set_postfix(_postfix(trnres, 'train loss'))

        if step % 100 == 0:
            torch.cuda.empty_cache()

    # Per-epoch checkpoint, written every epoch.
    output_model_file = f'weights/exam_lstm_{wtsname}__epoch{epoch}.bin'
    torch.save(model.state_dict(), output_model_file)

    scheduler.step()
    model.eval()
    valres = resultsfn()
    pbarval = tqdm(enumerate(valloader),
                   total=len(valdataset) // valloader.batch_size,
                   desc=f"Valid epoch {epoch}", ncols=0)
    for step, batch in pbarval:
        img_names, yimg, ystudy, maskval, lelabels = splitbatch(batch, args.device)
        if yimg.sum() == 0:
            logger.info('AAAAAA')
            continue
        xval = batch['emb'].to(args.device, dtype=torch.float)
        with torch.no_grad():
            studylogits, imglogits = model(xval, maskval)
            exam_loss, exam_wts = exam_lossfn(studylogits, ystudy)
            image_loss, image_wts = image_lossfn(imglogits, yimg, maskval)
        _accumulate(valres, exam_loss, exam_wts, image_loss, image_wts)
        pbarval.set_postfix(_postfix(valres, 'valid loss'))
    # An empty validation pass (every batch skipped) must not crash on 0/0.
    val_loss = valres.loss / valres.wts if valres.wts else float('inf')
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Save the best model to its own file. It used to be written to the per-epoch
        # path above, which merely rewrote identical weights and kept no "best" copy.
        output_model_file = f'weights/exam_lstm_{wtsname}__best.bin'
        torch.save(model.state_dict(), output_model_file)
        logger.info(f'Best Epoch {epoch} val loss all {best_val_loss:.4f}')
        

2020-10-12 21:13:11,235 - LSTM - INFO - Start training
Train epoch 0: 182it [03:47,  1.25s/it, train loss=0.386, image loss=0.357, exam loss=0.415]             
Valid epoch 0: 46it [00:56,  1.22s/it, valid loss=0.293, image loss=0.285, exam loss=0.301]            
2020-10-12 21:17:57,034 - LSTM - INFO - Best Epoch 0 val loss all 0.2929
Train epoch 1: 182it [03:44,  1.23s/it, train loss=0.251, image loss=0.238, exam loss=0.263]             
Valid epoch 1: 46it [00:55,  1.22s/it, valid loss=0.253, image loss=0.259, exam loss=0.248]            
2020-10-12 21:22:39,614 - LSTM - INFO - Best Epoch 1 val loss all 0.2532
Train epoch 2: 182it [03:45,  1.24s/it, train loss=0.225, image loss=0.224, exam loss=0.226]             
Valid epoch 2: 46it [00:55,  1.22s/it, valid loss=0.247, image loss=0.26, exam loss=0.234]             
2020-10-12 21:27:23,195 - LSTM - INFO - Best Epoch 2 val loss all 0.2468
Train epoch 3: 182it [03:46,  1.25s/it, train loss=0.218, image loss=0.22, exam loss=0.216]     

KeyboardInterrupt: 