In [1]:
# https://github.com/selimsef/dfdc_deepfake_challenge/blob/master/training/pipelines/train_classifier.py
import argparse
import json
import os
import glob
import pickle
import gc
import sys
import itertools
from collections import defaultdict, OrderedDict
import platform
# Project root: one location on macOS ('Darwin'), another on every other platform.
# NOTE(review): both are hard-coded absolute paths — consider making this configurable.
PATH = '/Users/dhanley/Documents/rsnastr' \
        if platform.system() == 'Darwin' else '/data/rsnastr'
# Make the project root the working directory and add it to the import path;
# presumably so the relative paths and the `utils` / `training` imports below resolve.
os.chdir(PATH)
sys.path.append(PATH)
import warnings
warnings.filterwarnings("ignore")
from sklearn.metrics import log_loss
from utils.logs import get_logger
from utils.utils import RSNAWEIGHTS, RSNA_CFG as CFG
from training.tools.config import load_config
import pandas as pd
import cv2

import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch import nn
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast

from tqdm import tqdm
import torch.distributed as dist
from training.datasets.classifier_dataset import RSNASequenceDataset, collateseqfn, \
        valSeedSampler, examSampler
from training.zoo.sequence import SpatialDropout, LSTMNet
from training.tools.utils import create_optimizer, AverageMeter, collectPreds, collectLoss
from training.tools.utils import splitbatch, unmasklabels, unmasklogits
from training.losses import getLoss
from training import losses
from torch.optim.swa_utils import AveragedModel, SWALR
from tensorboardX import SummaryWriter

# Limit the MKL / NumExpr / OpenMP thread pools to a single thread each.
# NOTE(review): presumably to avoid oversubscription alongside DataLoader workers — confirm.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

# Turn off OpenCL in OpenCV and set OpenCV's thread count to 0.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensor
logger = get_logger('LSTM', 'INFO') 

In [2]:
# Blank out sys.argv so argparse only ever sees the defaults declared below
# instead of whatever arguments the notebook kernel was launched with.
import sys; sys.argv=['']; del sys
logger.info('Load args')
parser = argparse.ArgumentParser()
arg = parser.add_argument
arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
arg('--workers', type=int, default=6, help='number of cpu threads to use')
arg('--device', type=str, default='cpu' if platform.system() == 'Darwin' else 'cuda', help='device for model - cpu/gpu')
arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
arg('--output-dir', type=str, default='weights/')
arg('--resume', type=str, default='')
arg('--fold', type=int, default=0)
arg('--batchsize', type=int, default=4)
arg('--lr', type=float, default = 0.00001)
arg('--lrgamma', type=float, default = 0.95)
arg('--labeltype', type=str, default='all') # or 'single'
arg('--dropout', type=float, default = 0.2)
arg('--prefix', type=str, default='classifier_')
arg('--data-dir', type=str, default="data")
arg('--folds-csv', type=str, default='folds.csv.gz')
# The default is the integer 1, so parse command-line values as int as well;
# with type=str a supplied value would have arrived as a string.
arg('--nclasses', type=int, default=1)
arg('--crops-dir', type=str, default='jpegip')
arg('--lstm_units',   type=int, default=512)
arg('--epochs',   type=int, default=12)
arg('--nbags',   type=int, default=12)
arg('--label-smoothing', type=float, default=0.00)
arg('--logdir', type=str, default='logs/b2_1820')
arg("--local_rank", default=0, type=int)
arg("--seed", default=777, type=int)
args = parser.parse_args()

2020-10-04 16:04:36,041 - LSTM - INFO - Load args


In [3]:
# Notebook-level overrides of the parsed defaults.
args.lr=0.001
args.label_smoothing=0.0
# NOTE(review): this forces CUDA even on platforms where the parser default is 'cpu'.
args.device='cuda'
args.fold=0
args.batchsize=8
# Alternative embedding set, kept for reference only (it used to be assigned
# and then immediately overwritten by the line below):
# args.embrgx='weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_*__fold*_best_dice__hflip0_transpose0_size320.emb'
args.embrgx='weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_*__fold*_best_dice__all_size320.emb'
datals = sorted(glob.glob(f'emb/{args.embrgx}*data.pk'))
# Fail with an explicit message instead of a bare IndexError when nothing matches.
if not datals:
    raise FileNotFoundError(f'No embedding files match emb/{args.embrgx}*data.pk')
datals[0]

'emb/weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_0__fold0_best_dice__all_size320.emb.data.pk'

In [4]:
def takeimg(s):
    """Return the last '/'-separated component of `s` with every '.jpg' removed."""
    basename = s.rsplit('/', 1)[-1]
    return basename.replace('.jpg', '')
# Use the first matched embedding bundle; its sibling files share the same
# stem and differ only in suffix.
f = datals[0]
logger.info(f'File load : {f}')
dfname = f
embname = f.replace('.data.pk', '.npz')
imgnm = f.replace('.data.pk', '.imgnames.pk')
# Pickled frame stored with the bundle, and the embedding matrix saved under 'arr_0'.
datadf = pd.read_pickle(dfname)
embmat = np.load(embname)['arr_0']

2020-10-04 16:04:36,124 - LSTM - INFO - File load : emb/weights/classifier_RSNAClassifier_tf_efficientnet_b5_ns_04d_0__fold0_best_dice__all_size320.emb.data.pk


In [5]:
# Image names stored alongside the embeddings, reduced to bare ids by takeimg.
# NOTE(review): pickle.load executes arbitrary code — only use on trusted local files.
# The context manager closes the handle; the original open() call leaked it.
with open(imgnm, "rb") as imgnm_file:
    imgls = list(map(takeimg, pickle.load(imgnm_file)))
# Reload the labels from train.csv and select the rows for `imgls`, in that order.
datadf = pd.read_csv(f'{args.data_dir}/train.csv.zip')
datadf = datadf.set_index('SOPInstanceUID').loc[imgls].reset_index()
folddf = pd.read_csv(f'{args.data_dir}/{args.folds_csv}')

In [6]:
# Sanity check: dimensions of the label frame, then its first record.
print(datadf.shape)
datadf.iloc[0]

(1790593, 17)


SOPInstanceUID                c0f3cb036d06
StudyInstanceUID              6897fa9de148
SeriesInstanceUID             2bfbb7fd2e8b
pe_present_on_image                      0
negative_exam_for_pe                     0
qa_motion                                0
qa_contrast                              0
flow_artifact                            0
rv_lv_ratio_gte_1                        0
rv_lv_ratio_lt_1                         1
leftsided_pe                             1
chronic_pe                               0
true_filling_defect_not_pe               0
rightsided_pe                            1
acute_and_chronic_pe                     0
central_pe                               0
indeterminate                            0
Name: 0, dtype: object

In [7]:
# Both datasets are built from the same inputs and differ only in `mode`,
# so the shared keyword arguments are collected once.
dataset_kwargs = dict(
    imgclasses=CFG["image_target_cols"],
    studyclasses=CFG['exam_target_cols'],
    fold=args.fold,
    label_smoothing=args.label_smoothing,
    folds_csv=args.folds_csv,
)
logger.info('Create traindatasets')
trndataset = RSNASequenceDataset(datadf, embmat, folddf, mode="train", **dataset_kwargs)
logger.info('Create valdatasets')
valdataset = RSNASequenceDataset(datadf, embmat, folddf, mode="valid", **dataset_kwargs)

2020-10-04 16:09:02,386 - LSTM - INFO - Create traindatasets
2020-10-04 16:09:04,713 - LSTM - INFO - Create valdatasets


In [8]:
logger.info('Create loaders...')
# Validation batches are 8x the training batch size and are not shuffled.
# NOTE(review): num_workers is fixed at 4 here; args.workers is not consulted.
valloader = DataLoader(valdataset, batch_size=args.batchsize*8, shuffle=False, num_workers=4, collate_fn=collateseqfn)
# Width of one embedding row; used as the model's input size below.
embed_size = embmat.shape[1]
# del embmat
gc.collect()

logger.info('Create model')
model = LSTMNet(embed_size, 
                       nimgclasses = len(CFG["image_target_cols"]), 
                       nstudyclasses = len(CFG['exam_target_cols']),
                       LSTM_UNITS=args.lstm_units, 
                       DO = args.dropout)
model = model.to(args.device)
# Two optimizer parameter groups: parameters whose name contains one of the
# `no_decay` substrings get weight_decay 0.0, all others get DECAY.
# With DECAY = 0.0 both groups currently use the same weight decay.
DECAY = 0.0
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
plist = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': DECAY},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer = torch.optim.Adam(plist, lr=args.lr)
# StepLR with step size 1: the learning rate is multiplied by args.lrgamma on every scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=args.lrgamma, last_epoch=-1)

2020-10-04 16:09:06,469 - LSTM - INFO - Create loaders...
2020-10-04 16:09:06,612 - LSTM - INFO - Create model


In [9]:
# NOTE(review): ypredls / ypredtstls are not used anywhere in the visible cells.
ypredls = []
ypredtstls = []
# Gradient scaler used together with autocast in the training loop.
scaler = torch.cuda.amp.GradScaler()
# Element-wise (unreduced) BCE-with-logits losses; the exam loss weights its
# elements by CFG['exam_weights'], the image loss is unweighted.
bce_func_exam = torch.nn.BCEWithLogitsLoss(reduction='none', 
                    weight = torch.tensor(CFG['exam_weights']).to(args.device))
bce_func_img = torch.nn.BCEWithLogitsLoss(reduction='none')

In [10]:
def groupBy(samples, labels, unique_labels, labels_count, grptype = 'mean'):
    """Aggregate the rows of `samples` per group with a scatter-add.

    Args:
        samples: floating-point tensor whose rows are accumulated.
        labels: int64 tensor of group indices, used directly as row indices
            into the result, so every value must lie in
            ``[0, unique_labels.shape[0])``.
        unique_labels: tensor whose shape determines the shape of the result.
        labels_count: per-group row counts, aligned with the result rows;
            only used for ``grptype='mean'``.
        grptype: ``'sum'`` for per-group sums, ``'mean'`` for per-group means.

    Returns:
        Float tensor shaped like `unique_labels` with the per-group aggregate.

    Raises:
        ValueError: if `grptype` is neither ``'sum'`` nor ``'mean'``. The
            original fell through and returned ``None`` in that case.
    """
    if grptype not in ('sum', 'mean'):
        raise ValueError(f"grptype must be 'sum' or 'mean', got {grptype!r}")
    # Accumulate each sample row into the result row selected by its label.
    res = torch.zeros_like(unique_labels, dtype=torch.float).scatter_add_(0, labels, samples)
    if grptype == 'mean':
        res = res / labels_count.float().unsqueeze(1)
    return res
    
def rsna_criterion_all(y_pred_exam_, 
                   y_true_exam_, 
                   y_pred_img_, 
                   y_true_img_,
                   le_study, 
                   img_wt):
    """Combined exam-level and image-level weighted BCE loss.

    Rows are grouped by `le_study`. The exam loss is the per-row sum of the
    weighted exam BCE, averaged within each study and summed over studies.
    The image loss is the per-study sum of the image BCE, weighted by
    `img_wt` and by the study's mean image label.

    Args:
        y_pred_exam_, y_true_exam_: exam logits and targets, one row per image
            (assumes exam targets are repeated for every image — TODO confirm).
        y_pred_img_, y_true_img_: image logits and targets.
        le_study: 1-D tensor of study indices, one per row; the values are used
            as row indices by `groupBy`.
        img_wt: weight applied to the image-level loss.

    Returns:
        Tuple ``(final_loss, img_loss_out, exam_loss_out)``: the combined loss
        normalised by the total weight, and the image and exam losses each
        normalised by their own weight. `img_loss_out` is 0 rather than NaN
        when the image weight is zero (no positive image labels).
    """
    # One group per distinct study id.
    labels = le_study.view(le_study.size(0), 1).expand(-1, 1)
    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)

    # Exam loss: sum over exam targets per row, mean within study, sum over studies.
    exam_loss = bce_func_exam(y_pred_exam_, y_true_exam_)
    exam_loss = exam_loss.sum(1).unsqueeze(1)
    exam_loss = groupBy(exam_loss, labels, unique_labels, labels_count, grptype = 'mean').sum()
    # Number of studies; reuses the unique labels computed above instead of
    # calling unique() a second time.
    exam_wts = torch.tensor(unique_labels.shape[0]).float()

    # Image loss: per-study sum, scaled by the study's mean image label.
    image_loss = bce_func_img(y_pred_img_, y_true_img_)
    image_loss = groupBy(image_loss, labels, unique_labels, labels_count, grptype = 'sum')
    qi_all = groupBy(y_true_img_, labels, unique_labels, labels_count, grptype = 'mean')
    image_loss = (img_wt * qi_all * image_loss).sum()
    img_wts = (img_wt * y_true_img_).sum()

    # With no positive image labels both image_loss and img_wts are 0, and the
    # plain division gave 0/0 = NaN. Divide by 1 in that case so the reported
    # image loss is 0.
    safe_img_wts = torch.where(img_wts > 0, img_wts, torch.ones_like(img_wts))
    img_loss_out =  image_loss / safe_img_wts
    exam_loss_out = exam_loss / exam_wts
    final_loss = (image_loss + exam_loss)/(img_wts + exam_wts)
    return final_loss , img_loss_out, exam_loss_out

In [None]:
logger.info('Start training')
for epoch in range(args.epochs):
    # A fresh exam sampler, and therefore a fresh loader, is built every epoch.
    examsampler = examSampler(trndataset.datadf, trndataset.folddf)
    trnloader = DataLoader(trndataset, batch_size=args.batchsize, sampler = examsampler, num_workers=4, collate_fn=collateseqfn)
    for param in model.parameters():
        param.requires_grad = True
    model.train()  
    img_wt = torch.tensor(CFG['image_weight']).to(args.device, dtype=torch.float)
    trncollect = collectPreds()
    valcollect = collectPreds()
    trnloss = collectLoss(trnloader, mode = 'train')
    valloss = collectLoss(valloader, mode = 'valid')
    logger.info(50*'-')
    for step, batch in enumerate(trnloader):
        img_names, yimg, ystudy, masktrn, lelabels = splitbatch(batch, args.device)
        # Skip batches whose image labels sum to zero.
        if yimg.sum()==0: 
            logger.info('AAAAAA')
            continue
        xtrn = batch['emb'].to(args.device, dtype=torch.float)
        # NOTE(review): torch.autograd.Variable is a deprecated wrapper; kept as-is.
        xtrn = torch.autograd.Variable(xtrn, requires_grad=True)
        yimg = torch.autograd.Variable(yimg)
        ystudy = torch.autograd.Variable(ystudy)
        with autocast():
            studylogits, imglogits = model(xtrn, masktrn)#.to(args.device, dtype=torch.float)
            yimg, ystudy, lelabels, img_names = unmasklabels(yimg, ystudy, lelabels, img_names, masktrn)
            imglogits, studylogits = unmasklogits(imglogits, studylogits, masktrn)
            # Loss function
            loss, img_loss, exam_loss = rsna_criterion_all(studylogits, 
                                                       ystudy, 
                                                       imglogits, 
                                                       yimg, 
                                                       lelabels, 
                                                       img_wt)
        # Mixed-precision step: scale the loss, step through the scaler, update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        
        trncollect.append(img_names, lelabels, imglogits, studylogits, yimg, ystudy)
        trnloss.increment(loss, img_loss, exam_loss)
        if step % 50==0: logger.info(f'LOOP:{trnloss.log()}')
    # Epoch-level training loss over everything collected; it is only logged,
    # so no autograd graph is needed.
    with torch.no_grad():
        trn_loss, trn_img_loss, trn_exam_loss = rsna_criterion_all(*trncollect.concat(args.device), img_wt)
    logger.info(f'Train loss all {trn_loss:.4f} img {trn_img_loss:.4f} exam {trn_exam_loss:.4f}')
    logger.info(50*'-')
    scheduler.step()
    logger.info('Prep test sub...')
    model.eval()
    # Validation is inference only. Without no_grad every forward pass built an
    # autograd graph that the collected logits then kept alive.
    with torch.no_grad():
        for step, batch in enumerate(valloader):
            img_names, yimg, ystudy, maskval, lelabels = splitbatch(batch, args.device)
            xval = batch['emb'].to(args.device, dtype=torch.float)
            studylogits, imglogits = model(xval, maskval)#.to(args.device, dtype=torch.float)
            # Repeat studies to have a prediction for every image
            yimg, ystudy, lelabels, img_names = unmasklabels(yimg, ystudy, lelabels, img_names, maskval)
            imglogits, studylogits = unmasklogits(imglogits, studylogits, maskval)
            loss, img_loss, exam_loss = rsna_criterion_all(studylogits, 
                                                       ystudy, 
                                                       imglogits, 
                                                       yimg, 
                                                       lelabels, 
                                                       img_wt)
            valcollect.append(img_names, lelabels, imglogits, studylogits, yimg, ystudy)
            valloss.increment(loss, img_loss, exam_loss)
        val_loss, val_img_loss, val_exam_loss = rsna_criterion_all(*valcollect.concat(args.device), img_wt)
    logger.info(f'Valid loss all {val_loss:.4f} img {val_img_loss:.4f} exam {val_exam_loss:.4f}')

2020-10-04 16:13:09,892 - LSTM - INFO - Start training
2020-10-04 16:13:11,545 - LSTM - INFO - --------------------------------------------------
2020-10-04 16:13:14,809 - LSTM - INFO - LOOP:train step 2 of 728 trn loss 0.1701 img loss 0.0855 exam loss 0.1966
2020-10-04 16:13:32,885 - LSTM - INFO - LOOP:train step 52 of 728 trn loss 0.2964 img loss 0.2027 exam loss 0.3766
2020-10-04 16:13:52,436 - LSTM - INFO - LOOP:train step 102 of 728 trn loss 0.2941 img loss 0.2012 exam loss 0.3697
2020-10-04 16:14:10,381 - LSTM - INFO - LOOP:train step 152 of 728 trn loss 0.2994 img loss 0.2134 exam loss 0.3721
2020-10-04 16:14:29,425 - LSTM - INFO - LOOP:train step 202 of 728 trn loss 0.2990 img loss 0.2134 exam loss 0.3753
2020-10-04 16:14:48,991 - LSTM - INFO - LOOP:train step 252 of 728 trn loss 0.3005 img loss 0.2162 exam loss 0.3758
2020-10-04 16:15:06,852 - LSTM - INFO - LOOP:train step 302 of 728 trn loss 0.3021 img loss 0.2185 exam loss 0.3762
2020-10-04 16:15:24,011 - LSTM - INFO - LOOP:

In [None]:
# Shapes of the image names and study labels held by the training collector.
print(trncollect.series('img_names').shape)
print(trncollect.series('study_labels').shape)

In [None]:
def myseries(collect, feat):
    """Return `collect.series(feat)` passed through a sigmoid, as a NumPy array.

    The series is cast to float and moved to the CPU before the sigmoid.
    """
    raw = collect.series(feat)
    probs = torch.sigmoid(raw.float().cpu())
    return probs.numpy()

# Column order: image-level targets first, then exam-level targets.
predcols = CFG['image_target_cols'] + CFG['exam_target_cols']

def _predframe(collect):
    """Build a frame of sigmoid predictions (image then study columns) indexed by image name."""
    probs = np.concatenate((myseries(collect, 'img_preds'),
                            myseries(collect, 'study_preds')), 1)
    return pd.DataFrame(probs, columns = predcols, index = collect.series('img_names'))

trnpreds = _predframe(trncollect)
valpreds = _predframe(valcollect)
valpreds.hist(figsize = (25,25))

In [None]:
# Histograms of the training predictions, one panel per column.
trnpreds.hist(figsize = (25,25))


In [None]:
# Last row of the validation prediction frame.
valpreds.tail(1)

In [None]:
# Image name of the last validation prediction, next to the last entry of the
# last item in the final validation batch's embeddings (xval is left over from the loop).
print(valpreds.tail(1).index)
print(xval[-1][-1])

In [None]:
# Image name of the last training prediction, next to the last entry of the
# last item in the final training batch's embeddings (xtrn is left over from the loop).
print(trnpreds.tail(1).index)
print(xtrn[-1][-1])

In [None]:
# Reload the first embedding bundle and show the embedding row(s) for one image id.
# NOTE(review): this overwrites the notebook-level `datadf` and `embmat`; re-running
# earlier cells afterwards would pick up these replaced values.
f = datals[0]
logger.info(f'File load : {f}')
dfname, embname, imgnm = f, f.replace('.data.pk', '.npz'), f.replace('.data.pk', '.imgnames.pk')
datadf = pd.read_pickle(dfname)
embmat = np.load(embname)['arr_0']
# Keep only the rows whose SOPInstanceUID matches the requested image.
embmat = embmat[np.where(datadf.SOPInstanceUID=='32ec58af84d9')[0]]
embmat

In [None]:
# Same lookup against the second embedding bundle, for a different image id.
# NOTE(review): this overwrites the notebook-level `datadf` and `embmat` as well.
f = datals[1]
logger.info(f'File load : {f}')
dfname, embname, imgnm = f, f.replace('.data.pk', '.npz'), f.replace('.data.pk', '.imgnames.pk')
datadf = pd.read_pickle(dfname)
embmat = np.load(embname)['arr_0']
# Keep only the rows whose SOPInstanceUID matches the requested image.
embmat = embmat[np.where(datadf.SOPInstanceUID=='dba5502f56b7')[0]]
embmat

In [None]:
# Mask from the last training batch processed by the loop.
masktrn

In [None]:
# Mask from the last validation batch processed by the loop.
maskval

In [None]:
# Overall mean of the last validation batch's embeddings (mean over dim 2, then over the rest).
xval.mean(2).mean()

In [None]:
# Overall mean of the last training batch's embeddings (mean over dim 2, then over the rest).
xtrn.mean(2).mean()