In [1]:
# https://github.com/selimsef/dfdc_deepfake_challenge/blob/master/training/pipelines/train_classifier.py
import argparse
import json
import os
import glob
import pickle
import gc
import sys
import itertools
from collections import defaultdict, OrderedDict
import platform
PATH = '/Users/dhanley/Documents/rsnastr' \
        if platform.system() == 'Darwin' else '/data/rsnastr'
os.chdir(PATH)
sys.path.append(PATH)
import warnings
warnings.filterwarnings("ignore")
from sklearn.metrics import log_loss
from utils.logs import get_logger
from utils.utils import RSNAWEIGHTS, RSNA_CFG as CFG
from training.tools.config import load_config
import pandas as pd
import cv2


import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch import nn
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast


from tqdm import tqdm
import torch.distributed as dist
from training.datasets.classifier_dataset import RSNASequenceDataset, collateseqfn, \
        valSeedSampler
from training.zoo.sequence import SpatialDropout, LSTMNet
from training.tools.utils import create_optimizer, AverageMeter
from training.losses import getLoss
from training import losses
from torch.optim.swa_utils import AveragedModel, SWALR
from tensorboardX import SummaryWriter

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensor
logger = get_logger('LSTM', 'INFO') 

In [2]:
import sys; sys.argv=['']; del sys
logger.info('Load args')
parser = argparse.ArgumentParser()
arg = parser.add_argument
arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
arg('--workers', type=int, default=6, help='number of cpu threads to use')
arg('--device', type=str, default='cpu' if platform.system() == 'Darwin' else 'cuda', help='device for model - cpu/gpu')
arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
arg('--output-dir', type=str, default='weights/')
arg('--resume', type=str, default='')
arg('--fold', type=int, default=0)
arg('--batchsize', type=int, default=4)
arg('--lr', type=float, default = 0.00001)
arg('--lrgamma', type=float, default = 0.95)
arg('--labeltype', type=str, default='all') # or 'single'
arg('--dropout', type=float, default = 0.2)
arg('--prefix', type=str, default='classifier_')
arg('--data-dir', type=str, default="data")
arg('--folds-csv', type=str, default='folds.csv.gz')
arg('--nclasses', type=str, default=1)
arg('--crops-dir', type=str, default='jpegip')
arg('--lstm_units',   type=int, default=512)
arg('--epochs',   type=int, default=12)
arg('--nbags',   type=int, default=12)
arg('--label-smoothing', type=float, default=0.00)
arg('--logdir', type=str, default='logs/b2_1820')
arg("--local_rank", default=0, type=int)
arg("--seed", default=777, type=int)
args = parser.parse_args()

2020-10-03 22:20:27,501 - LSTM - INFO - Load args


In [3]:
args.lr=0.0001   
args.label_smoothing=0.0  
args.device='cuda' 
args.fold=0 
args.batchsize=4
args.embrgx='weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_*__fold*_best_dice__hflip0_transpose0_size320.emb' 


In [4]:
def takeimg(s):
    return s.split('/')[-1].replace('.jpg', '')
#embrgx = 'classifier_RSNAClassifier_resnext101_32x8d_*__fold*_epoch24__hflip*_transpose0_size320.emb'
#embrgx = 'classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_*__fold*_epoch24__hflip0_transpose0_size320.emb'
datals = sorted(glob.glob(f'emb/{args.embrgx}*data.pk'))
imgls = []
for i, f in enumerate(datals):
    logger.info(f'File load : {f}')
    dfname, embname, imgnm = f, f.replace('.data.pk', '.npz'), f.replace('.data.pk', '.imgnames.pk')
    if i == 0:
        datadf = pd.read_pickle(dfname)
        embmat = np.load(embname)['arr_0']
    if i>0:
        embmat = np.append( embmat, np.load(embname)['arr_0'], 0)
        datadf = pd.concat([datadf, pd.read_pickle(dfname)], 0)
    imgls += list(map(takeimg, pickle.load( open( imgnm, "rb" ) )))
    logger.info(f'Embedding shape : {embmat.shape}')
    logger.info(f'DataFrame shape : {datadf.shape}')
    logger.info(f'DataFrame shape : {len(imgls)}')
    gc.collect()
folddf = pd.read_csv(f'{args.data_dir}/{args.folds_csv}')
datadf = datadf.set_index('SOPInstanceUID').loc[imgls].reset_index()
datadf.iloc[0]

2020-10-03 22:20:27,528 - LSTM - INFO - File load : emb/weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_0__fold0_best_dice__hflip0_transpose0_size320.emb.data.pk
2020-10-03 22:20:51,268 - LSTM - INFO - Embedding shape : (359209, 2048)
2020-10-03 22:20:51,269 - LSTM - INFO - DataFrame shape : (359209, 17)
2020-10-03 22:20:51,269 - LSTM - INFO - DataFrame shape : 359209
2020-10-03 22:20:51,338 - LSTM - INFO - File load : emb/weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_1__fold1_best_dice__hflip0_transpose0_size320.emb.data.pk
2020-10-03 22:21:16,922 - LSTM - INFO - Embedding shape : (715642, 2048)
2020-10-03 22:21:16,923 - LSTM - INFO - DataFrame shape : (715642, 17)
2020-10-03 22:21:16,923 - LSTM - INFO - DataFrame shape : 715642
2020-10-03 22:21:16,996 - LSTM - INFO - File load : emb/weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_2__fold2_best_dice__hflip0_transpose0_size320.emb.data.pk
2020-10-03 22:21:44,379 - LSTM - INFO - Embedding shape : (1

SOPInstanceUID                c0f3cb036d06
StudyInstanceUID              6897fa9de148
SeriesInstanceUID             2bfbb7fd2e8b
pe_present_on_image                      0
negative_exam_for_pe                     0
qa_motion                                0
qa_contrast                              0
flow_artifact                            0
rv_lv_ratio_gte_1                        0
rv_lv_ratio_lt_1                         1
leftsided_pe                             1
chronic_pe                               0
true_filling_defect_not_pe               0
rightsided_pe                            1
acute_and_chronic_pe                     0
central_pe                               0
indeterminate                            0
Name: 0, dtype: object

In [5]:
logger.info('Create traindatasets')
trndataset = RSNASequenceDataset(datadf, 
                                   embmat, 
                                   folddf,
                                   mode="train",
                                   imgclasses=CFG["image_target_cols"],
                                   studyclasses=CFG['exam_target_cols'],
                                   fold=args.fold,
                                   label_smoothing=args.label_smoothing,
                                   folds_csv=args.folds_csv)
logger.info('Create valdatasets')
valdataset = RSNASequenceDataset(datadf, 
                                   embmat, 
                                   folddf,
                                   mode="valid",
                                   imgclasses=CFG["image_target_cols"],
                                   studyclasses=CFG['exam_target_cols'],
                                   fold=args.fold,
                                   label_smoothing=args.label_smoothing,
                                   folds_csv=args.folds_csv)

logger.info('Create loaders...')
trnloader = DataLoader(trndataset, batch_size=args.batchsize, shuffle=True, num_workers=4, collate_fn=collateseqfn)
valloader = DataLoader(valdataset, batch_size=args.batchsize*8, shuffle=False, num_workers=4, collate_fn=collateseqfn)
embed_size = embmat.shape[1]
del embmat
gc.collect()

2020-10-03 22:22:47,454 - LSTM - INFO - Create traindatasets
2020-10-03 22:22:48,530 - LSTM - INFO - Create valdatasets
2020-10-03 22:22:49,407 - LSTM - INFO - Create loaders...


16

In [22]:
logger.info('Create model')
model = LSTMNet(embed_size, 
                       nimgclasses = len(CFG["image_target_cols"]), 
                       nstudyclasses = len(CFG['exam_target_cols']),
                       LSTM_UNITS=args.lstm_units, 
                       DO = args.dropout)
model = model.to(args.device)
DECAY = 0.0
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
plist = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': DECAY},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer = torch.optim.Adam(plist, lr=args.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=args.lrgamma, last_epoch=-1)

2020-10-03 22:35:32,897 - LSTM - INFO - Create model


In [23]:
ypredls = []
ypredtstls = []
scaler = torch.cuda.amp.GradScaler()
bce_func_exam = torch.nn.BCEWithLogitsLoss(reduction='none', 
                    weight = torch.tensor(CFG['exam_weights']).to(args.device))
bce_func_img = torch.nn.BCEWithLogitsLoss(reduction='none')

In [24]:
def groupBy(samples, labels, unique_labels, labels_count, grptype = 'mean'):
    res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, samples)
    if grptype == 'sum':
        return res
    if grptype == 'mean':
        res = res / labels_count.float().unsqueeze(1)
        return res

def rsna_criterion(y_pred_exam_, 
                   y_true_exam_, 
                   y_pred_img_, 
                   y_true_img_,
                   le_study, 
                   img_wt,
                   verbose = False):
    # Groupby 
    labels = le_study.view(le_study.size(0), 1).expand(-1, 1)
    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)
    
    #logger.info('Exam loss')
    exam_loss = bce_func_exam(y_pred_exam_, y_true_exam_)
    exam_loss = exam_loss.sum(1).unsqueeze(1)
    exam_loss = groupBy(exam_loss, labels, unique_labels, labels_count, grptype = 'mean').sum()
    exam_wts = torch.tensor(le_study.unique().shape[0]).float()
    
    
    #logger.info('Image loss')
    image_loss = bce_func_img(y_pred_img_, y_true_img_)
    image_loss = groupBy(image_loss, labels, unique_labels, labels_count, grptype = 'sum')
    
    qi_all = groupBy(y_true_img_, labels, unique_labels, labels_count, grptype = 'mean')
    
    img_wts = img_wt * (y_true_img_).sum()
    #if verbose and (img_wts==0):
    #    logger.info((qi_all * image_loss.detach()).sum())
        
    img_loss = img_wt * (qi_all * image_loss.detach()).sum()
            
    #logger.info('Final loss')
    img_loss_out = img_loss if (img_loss == img_wts == 0) else img_loss / img_wts
    exam_loss_out = exam_loss / exam_wts
    final_loss = (img_loss + exam_loss)/(img_wts + exam_wts)
    if verbose:
        log = f'Final loss {final_loss:.3f} img loss {img_loss_out:.3f} wt {img_wts:.3f}'
        log += f' exam loss {exam_loss_out:.3f} wt {exam_wts:.3f}'
        logger.info(log)
        logger.info(50*'-')
        
    return final_loss, img_loss_out, exam_loss_out


def splitbatch(batch, device):
    img_names = batch['img_name']
    yimg = batch['imglabels'].to(args.device, dtype=torch.float)
    ystudy = batch['studylabels'].to(args.device, dtype=torch.float)
    mask = batch['mask'].to(args.device, dtype=torch.int)
    lelabels = batch['lelabels'].to(args.device, dtype=torch.int64)
    return img_names, yimg, ystudy, mask, lelabels

def unmasklabels(yimg, ystudy, lelabels, img_names, mask):
    ystudy = ystudy.unsqueeze(2).repeat(1, 1, imglogits.size(1))
    ystudy = ystudy.transpose(2, 1)
    # get the mask for masked img labels
    maskidx = mask.view(-1)==1
    # Flatten them all along batch and seq dimension and remove masked values
    yimg = yimg.view(-1, 1)[maskidx]
    ystudy = ystudy.reshape(-1, ystudy.size(-1))[maskidx]
    lelabels = lelabels.view(-1, 1)[maskidx] 
    lelabels = lelabels.flatten()
    img_names = img_names.flatten()[maskidx.detach().cpu().numpy()]
    return yimg, ystudy, lelabels, img_names
    
def unmasklogits(imglogits, studylogits, mask):
    imglogits = imglogits.squeeze()
    studylogits = studylogits.unsqueeze(2).repeat(1, 1, imglogits.size(1))
    # get the mask for masked img labels
    maskidx = mask.view(-1)==1
    # Flatten them all along batch and seq dimension and remove masked values
    imglogits = imglogits.view(-1, 1)[maskidx]
    studylogits = studylogits.reshape(-1, ystudy.size(-1))[maskidx]
    return imglogits, studylogits

class collectPreds:
    def __init__(self):
        self.lelabelsls = []
        self.imgnamesls = []
        self.imgpredsls = []
        self.imglabells = []
        self.studylabells = []
        self.studypredsls = []
        self.maxlelabel = 0

    def append(self, img_names, lelabels, imgpreds, studypreds, yimg, ystudy):
        lelabels = lelabels.detach().cpu()
        if len(self.lelabelsls)>0:
            increment = self.lelabelsls[-1].max() + torch.tensor(1).cpu()
            lelabels = lelabels + increment
        self.lelabelsls.append(lelabels)
        self.imgpredsls.append(imglogits.detach().cpu())
        self.imglabells.append(yimg.detach().cpu())
        self.studylabells.append(ystudy.detach().cpu())
        self.studypredsls.append(studylogits.detach().cpu())
        self.imgnamesls.append(img_names)

    def concat(self, device):
        lelabels = torch.cat(self.lelabelsls).to(device)
        imgpreds = torch.cat(self.imgpredsls).to(device)
        imglabels = torch.cat(self.imglabells).to(device)
        studylabels = torch.cat(self.studylabells).to(device)
        studypreds = torch.cat(self.studypredsls).to(device)
        return studypreds, studylabels, imgpreds, imglabels, lelabels
    
    def series(self, series):
        if series=='lelabels': return torch.cat(self.lelabelsls)
        if series=='img_preds': return torch.cat(self.imgpredsls)
        if series=='img_labels': return torch.cat(self.imglabells)
        if series=='study_labels': return torch.cat(self.studylabells)
        if series=='study_preds': return torch.cat(self.studypredsls)
        if series=='img_names': return np.concatenate(self.imgnamesls)

class collectLoss:
    def __init__(self, loader, mode = 'train'):
        self.mode = mode
        self.loss = 0.
        self.img_loss = 0.
        self.exam_loss = 0.
        self.step = 1
        self.loaderlen = len(loader)

    def increment(self, loss, img_loss, exam_loss):
        self.loss += loss.item()
        self.img_loss += img_loss.item()
        self.exam_loss += exam_loss.item()
        self.step += 1

    def log(self):
        logs = f'{self.mode} step {self.step} of {self.loaderlen} trn loss {(self.loss/(self.step)):.4f} '
        logs += f'img loss {(self.img_loss/(self.step)):.4f} exam loss {(self.exam_loss/(self.step)):.4f}'
        return logs

In [27]:

def rsna_criterion_all(y_pred_exam_, 
                   y_true_exam_, 
                   y_pred_img_, 
                   y_true_img_,
                   le_study, 
                   img_wt):
    # Groupby 
    labels = le_study.view(le_study.size(0), 1).expand(-1, 1)
    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)
    
    logger.info('Exam loss')
    exam_loss = bce_func_exam(y_pred_exam_, y_true_exam_)
    exam_loss = exam_loss.sum(1).unsqueeze(1)
    exam_loss = groupBy(exam_loss, labels, unique_labels, labels_count, grptype = 'mean').sum()
    exam_wts = torch.tensor(le_study.unique().shape[0]).float()
    
    logger.info('Image loss')
    image_loss = bce_func_img(y_pred_img_, y_true_img_)
    image_loss = groupBy(image_loss, labels, unique_labels, labels_count, grptype = 'sum')
    qi_all = groupBy(y_true_img_, labels, unique_labels, labels_count, grptype = 'mean')
    image_loss = (img_wt * qi_all * image_loss).sum()
    img_wts = (img_wt * y_true_img_).sum()
    
    logger.info('Final loss')
    img_loss_out =  image_loss / img_wts
    exam_loss_out = exam_loss / exam_wts
    final_loss = (image_loss + exam_loss)/(img_wts + exam_wts)
    return final_loss , img_loss_out, exam_loss_out


def rsna_criterion(y_pred_exam_, 
                   y_true_exam_, 
                   y_pred_img_, 
                   y_true_img_,
                   le_study, 
                   img_wt,
                   verbose = False):
    #logger.info('Exam loss')
    exam_loss = bce_func_exam(y_pred_exam_, y_true_exam_).mean()
    #logger.info('Image loss')
    img_loss = bce_func_img(y_pred_img_, y_true_img_).mean()
    
    final_loss = (exam_loss + img_loss)/2
        
    return final_loss, img_loss, exam_loss


logger.info('Start training')
for epoch in range(args.epochs):
    for param in model.parameters():
        param.requires_grad = True
    model.train()  
    img_wt = torch.tensor(CFG['image_weight']).to(args.device, dtype=torch.float)
    trncollect = collectPreds()
    valcollect = collectPreds()
    trnloss = collectLoss(trnloader, mode = 'train')
    valloss = collectLoss(valloader, mode = 'valid')
    logger.info(50*'-')
    for step, batch in enumerate(trnloader):
        img_names, yimg, ystudy, masktrn, lelabels = splitbatch(batch, args.device)
        #if yimg.sum()==0: continue
        xtrn = batch['emb'].to(args.device, dtype=torch.float)
        xtrn = torch.autograd.Variable(xtrn, requires_grad=True)
        yimg = torch.autograd.Variable(yimg)
        ystudy = torch.autograd.Variable(ystudy)
        with autocast():
            studylogits, imglogits = model(xtrn, masktrn)#.to(args.device, dtype=torch.float)
            yimg, ystudy, lelabels, img_names = unmasklabels(yimg, ystudy, lelabels, img_names, masktrn)
            imglogits, studylogits = unmasklogits(imglogits, studylogits, masktrn)
            # Loss function
            loss, img_loss, exam_loss = rsna_criterion(studylogits, 
                                                       ystudy, 
                                                       imglogits, 
                                                       yimg, 
                                                       lelabels, 
                                                       img_wt,
                                                       verbose = False)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        
        trncollect.append(img_names, lelabels, imglogits, studylogits, yimg, ystudy)
        trnloss.increment(loss, img_loss, exam_loss)
        if step % 50==0: logger.info(f'LOOP:{trnloss.log()}')
    trn_loss, trn_img_loss, trn_exam_loss = rsna_criterion_all(*trncollect.concat(args.device), img_wt)
    logger.info(f'Train loss all {trn_loss:.4f} img {trn_img_loss:.4f} exam {trn_exam_loss:.4f}')
    logger.info(50*'-')
    scheduler.step()
    logger.info('Prep test sub...')
    model.eval()
    for step, batch in enumerate(valloader):
        img_names, yimg, ystudy, maskval, lelabels = splitbatch(batch, args.device)
        xval = batch['emb'].to(args.device, dtype=torch.float)
        studylogits, imglogits = model(xval, maskval)#.to(args.device, dtype=torch.float)
        # Repeat studies to have a prediction for every image
        yimg, ystudy, lelabels, img_names = unmasklabels(yimg, ystudy, lelabels, img_names, maskval)
        imglogits, studylogits = unmasklogits(imglogits, studylogits, maskval)
        loss, img_loss, exam_loss = rsna_criterion(studylogits, 
                                                   ystudy, 
                                                   imglogits, 
                                                   yimg, 
                                                   lelabels, 
                                                   img_wt)
        valcollect.append(img_names, lelabels, imglogits, studylogits, yimg, ystudy)
        valloss.increment(loss, img_loss, exam_loss)
    val_loss, val_img_loss, val_exam_loss = rsna_criterion_all(*valcollect.concat(args.device), img_wt)
    logger.info(f'Valid loss all {val_loss:.4f} img {val_img_loss:.4f} exam {val_exam_loss:.4f}')

2020-10-03 22:41:15,203 - LSTM - INFO - Start training
2020-10-03 22:41:15,212 - LSTM - INFO - --------------------------------------------------
2020-10-03 22:41:16,665 - LSTM - INFO - LOOP:train step 2 of 1456 trn loss 0.0914 img loss 0.1514 exam loss 0.0315
2020-10-03 22:41:24,073 - LSTM - INFO - LOOP:train step 52 of 1456 trn loss 0.0651 img loss 0.0884 exam loss 0.0419
2020-10-03 22:41:31,331 - LSTM - INFO - LOOP:train step 102 of 1456 trn loss 0.0656 img loss 0.0888 exam loss 0.0424
2020-10-03 22:41:38,796 - LSTM - INFO - LOOP:train step 152 of 1456 trn loss 0.0665 img loss 0.0914 exam loss 0.0415
2020-10-03 22:41:46,290 - LSTM - INFO - LOOP:train step 202 of 1456 trn loss 0.0682 img loss 0.0945 exam loss 0.0418
2020-10-03 22:41:53,868 - LSTM - INFO - LOOP:train step 252 of 1456 trn loss 0.0671 img loss 0.0927 exam loss 0.0415
2020-10-03 22:42:01,407 - LSTM - INFO - LOOP:train step 302 of 1456 trn loss 0.0661 img loss 0.0908 exam loss 0.0413
2020-10-03 22:42:08,963 - LSTM - INFO 

2020-10-03 22:49:28,178 - LSTM - INFO - Prep test sub...


KeyboardInterrupt: 

In [None]:
print(trncollect.series('img_names').shape)
print(trncollect.series('study_labels').shape)

In [None]:
def myseries(collect, feat):
    return torch.sigmoid(collect.series(feat).float().cpu()).numpy()

trnpreds = pd.DataFrame(np.concatenate((myseries(trncollect, 'img_preds'), 
                    myseries(trncollect, 'study_preds')), 1),
                    columns = CFG['image_target_cols']+CFG['exam_target_cols'], 
                    index = trncollect.series('img_names'))
valpreds = pd.DataFrame(np.concatenate((myseries(valcollect, 'img_preds'), 
                    myseries(valcollect, 'study_preds')), 1),
                    columns = CFG['image_target_cols']+CFG['exam_target_cols'], 
                    index = valcollect.series('img_names'))   
valpreds.hist(figsize = (25,25))

In [None]:
trnpreds.hist(figsize = (25,25))


In [None]:
valpreds.tail(1)

In [None]:
print(valpreds.tail(1).index)
print(xval[-1][-1])

In [None]:
print(trnpreds.tail(1).index)
print(xtrn[-1][-1])

In [None]:
f = datals[0]
logger.info(f'File load : {f}')
dfname, embname, imgnm = f, f.replace('.data.pk', '.npz'), f.replace('.data.pk', '.imgnames.pk')
datadf = pd.read_pickle(dfname)
embmat = np.load(embname)['arr_0']
embmat = embmat[np.where(datadf.SOPInstanceUID=='32ec58af84d9')[0]]
embmat

In [None]:
f = datals[1]
logger.info(f'File load : {f}')
dfname, embname, imgnm = f, f.replace('.data.pk', '.npz'), f.replace('.data.pk', '.imgnames.pk')
datadf = pd.read_pickle(dfname)
embmat = np.load(embname)['arr_0']
embmat = embmat[np.where(datadf.SOPInstanceUID=='dba5502f56b7')[0]]
embmat

In [None]:
masktrn

In [None]:
maskval

In [None]:
xval.mean(2).mean()

In [None]:
xtrn.mean(2).mean()