In [None]:
import warnings
warnings.filterwarnings(action='once')
import pandas as pd
import numpy as np
import matplotlib.pylab as plt
import seaborn as sns
import datetime
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import pickle
%load_ext autoreload
%autoreload 2
%matplotlib inline
from exp.misc import *
from exp.ProcessData import *
from exp.PytorchModels import *
from exp.LearnerClass import *
import torch
import torch.nn as nn
import torch.utils.data as D
import torch.nn.functional as F
import copy
from torchvision import transforms
import PIL.Image
from sklearn.metrics import roc_auc_score
import torchvision.transforms.functional as TF
from types import MethodType
import sandesh
import pydicom

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Paths and naming formats come from config.json.
params = json_to_parameters('config.json')

# Cross-validation setup and seeding.
num_folds = 5
SEED = 220
folds = [1]            # folds that the training cell iterates over
add_seed = 220221      # added to SEED + fold when seeding torch / numpy

# Compute device.
device = device_by_name('RTX')

# Identification of the pickled per-image features consumed by this notebook.
# ("tamplate" spelling is kept: other cells reference these exact names.)
model_type = 'tf_efficientnet_b5_ns'
basic_version = 'pdint'
basic_name_tamplate = 'image_{}'

# Transformer hyper-parameters.
num_heads = 2
num_layers = 4
ffdim = 2048
dropout = 0.1

# Name / version under which the transformer is saved; the trailing '{}' is
# left as a placeholder that the training cell fills with the seed.
name_tamplate = f'transformer_{num_heads}_{num_layers}_{ffdim}_{{}}'
model_version = 'pdint128n03f02'

# Loss focusing parameter and feature-noise level passed to the training dataset.
gamma = 0.2
fnoise = 0.3

# Maximum sequence lengths handed to the validation / training datasets.
val_max_len = 512
train_max_len = 128

torch.cuda.set_device(device)

In [None]:
# Full training table; it is read from the data directory configured in `params`.
df = pd.read_csv(filepath_or_buffer=params.path.data + 'full_train.csv')

In [None]:
def masked_bce_with_logits(y_pred,y_true,mask=None,reduction='mean',gamma=0.):
    """Binary cross-entropy on logits that ignores positions where ``mask`` is False.

    Args:
        y_pred: tensor of logits.
        y_true: tensor of targets, same shape as ``y_pred``.
        mask: optional boolean tensor, same shape as ``y_pred``; True marks the
            positions that contribute to the loss. When ``None`` the plain
            ``F.binary_cross_entropy_with_logits`` mean is returned and both
            ``reduction`` and ``gamma`` are ignored.
        reduction: ``'mean'`` divides the summed loss by the number of True
            mask entries, ``'sum'`` returns the summed loss, any other value
            (e.g. ``'none'``) returns the element-wise loss tensor.
        gamma: when non-zero, ``binary_focal_loss`` is used instead of plain
            BCE for the masked computation.

    Returns:
        A scalar tensor for ``'mean'`` / ``'sum'`` (or when ``mask`` is None),
        otherwise a tensor shaped like ``y_pred``.
    """
    if mask is None:
        return F.binary_cross_entropy_with_logits(y_pred,y_true)

    # Masked-out positions are replaced by logit -20 with target 0, whose BCE is
    # log(1 + exp(-20)) ~ 2e-9, i.e. they contribute (practically) nothing.
    yp=torch.where(mask,y_pred,-20*torch.ones_like(y_pred))
    yt=torch.where(mask,y_true,torch.zeros_like(y_pred))
    if gamma==0:
        loss = F.binary_cross_entropy_with_logits(yp,yt,reduction='none')
    else:
        loss = binary_focal_loss(yp,yt,gamma=gamma,reduction='none')

    # The branches must be mutually exclusive: with two independent `if`s the
    # 'mean' result was overwritten by the trailing `else` and never returned.
    if reduction=='mean':
        return loss.sum()/mask.sum()
    if reduction=='sum':
        return loss.sum()
    return loss

In [None]:
# Label column names, listed in the same order as indices 4..12 of OutputMap.
cols = [
    'rv_lv_ratio_gte_1', 'rv_lv_ratio_lt_1',
    'leftsided_pe', 'chronic_pe',
    'negative_exam_for_pe', 'rightsided_pe',
    'acute_and_chronic_pe', 'central_pe',
    'indeterminate',
]

In [None]:
# Maps an output name to its column index.
# Indices 0..3 are reachable under two aliases each ('dummyN' and a named
# output); indices 4..12 hold the nine remaining labels.
OutputMap = {
    **{f'dummy{i}': i for i in range(4)},
    **{name: i for i, name in enumerate((
        'true_filling_defect_not_pe',
        'qa_motion',
        'qa_contrast',
        'flow_artifact',
    ))},
    **{name: i for i, name in enumerate((
        'rv_lv_ratio_gte_1',
        'rv_lv_ratio_lt_1',
        'leftsided_pe',
        'chronic_pe',
        'negative_exam_for_pe',
        'rightsided_pe',
        'acute_and_chronic_pe',
        'central_pe',
        'indeterminate',
    ), start=4)},
}

In [None]:
# Column-index groups over which calc_pred applies a joint log-softmax.
# Both groups end with the same two columns (negative exam, indeterminate).
lvs = [OutputMap[label] for label in ('rv_lv_ratio_gte_1',
                                      'rv_lv_ratio_lt_1',
                                      'negative_exam_for_pe',
                                      'indeterminate')]
# 'dummy0' (column 0) is the extra member of the chronic group.
chron = [OutputMap[label] for label in ('chronic_pe',
                                        'acute_and_chronic_pe',
                                        'dummy0',
                                        'negative_exam_for_pe',
                                        'indeterminate')]

def calc_pred(y, lv_idx=None, chron_idx=None):
    """Replace two groups of columns of ``y`` by their group-wise log-softmax.

    A log-softmax is taken over the ``lv_idx`` columns and, independently, over
    the ``chron_idx`` columns of the *original* ``y``. In the result:

    * the first two ``lv_idx`` columns hold their log-softmax values,
    * the first two ``chron_idx`` columns hold their log-softmax values,
    * the last two ``lv_idx`` columns hold the average of the last two entries
      of both log-softmaxes (so the two groups are expected to end with the
      same two columns, as the module-level ``lvs`` / ``chron`` do),
    * every other column is copied unchanged.

    Args:
        y: 2-D tensor indexed as ``y[:, columns]``.
        lv_idx: list of column indices of the first group; defaults to the
            module-level ``lvs``.
        chron_idx: list of column indices of the second group; defaults to the
            module-level ``chron``.

    Returns:
        A new tensor shaped like ``y``. ``y`` itself is left untouched: writing
        into the argument would silently alter the caller's tensor (in MyLoss
        that is a view of the model output which is reused afterwards).
    """
    lv_idx = lvs if lv_idx is None else lv_idx
    chron_idx = chron if chron_idx is None else chron_idx
    lv = F.log_softmax(y[:,lv_idx],dim=1)
    chn = F.log_softmax(y[:,chron_idx],dim=1)
    out = y.clone()
    out[:,lv_idx[:2]]=lv[:,:2]
    out[:,chron_idx[:2]]=chn[:,:2]
    out[:,lv_idx[-2:]]=(lv[:,-2:]+chn[:,-2:])/2
    return out

In [None]:
# Per-column weights for the series-level loss / metric (13 entries).
# The first four columns carry zero weight; the remaining nine follow,
# in order, the columns at indices 4..12.
WEIGHTS = torch.tensor(
    [0.0] * 4
    + [
        0.2346625767,
        0.0782208589,
        0.06257668712,
        0.1042944785,
        0.0736196319,
        0.06257668712,
        0.1042944785,
        0.1877300613,
        0.09202453988,
    ]
)
class MyLoss():
    """Training loss combining a per-position (image) term and a per-sequence (series) term.

    The instance is called as ``loss(y_pred, y_true)`` where

    * ``y_pred`` is indexed as ``y_pred[..., 0]`` (per-position logit) and
      ``y_pred[:, -1, 1:]`` / ``y_pred.max(1)[0][:, 1:]`` (series logits),
      i.e. it is a 3-D tensor (batch, sequence, channels);
    * ``y_true[0]`` holds the per-position targets, with negative values
      marking positions to ignore, and ``y_true[1]`` the series targets.

    Args:
        weight_image: weight of the per-position term; it is scaled per sample
            by the fraction of positive positions.
        weights_series: per-column weights of the series term.
        weight_gen_img: weight of the additional plain masked per-position loss.
        conf_weight: weight of the extra un-weighted series focal term.
        mean_len: stored on the instance; not used by ``__call__``.
        eq_series: series balancing factor; stored halved.
        gamma: focal-loss focusing parameter (0 means plain BCE for the
            per-position term).
        image_series: balance between the per-position (``2*image_series``) and
            series (``2*(1-image_series)``) terms.
        eth: targets are clamped to ``[eth, 1-eth]``.
        use_max: take series logits as the max over the sequence axis instead
            of the last position.
        do_calc: pass the series logits through ``calc_pred`` before the loss.
    """
    def __init__(self,weight_image=0.07361963,weights_series=WEIGHTS,weight_gen_img=0.1,
                 conf_weight=0.3,mean_len=246.,eq_series=1.,gamma=0.,image_series=0.5,eth=0.,use_max=False,do_calc=False):
        self.weight_image=weight_image
        self.weights_series=weights_series
        self.weight_gen_img=weight_gen_img
        self.mean_len=mean_len
        self.eq_series=eq_series/2
        self.conf_weight=conf_weight
        self.gamma=gamma
        self.image_series=image_series
        self.eth=eth
        self.use_max=use_max
        self.calc = calc_pred if do_calc else lambda x:x

    def __call__(self,y_pred,y_true):
        """Return the scalar training loss for one batch (see class docstring)."""
        wim=2*self.image_series
        wis=2*(1-self.image_series)

        # Per-sample weight of the per-position term: weight_image times the
        # fraction of valid positions whose target is positive.
        mask=y_true[0]>=0
        mi=(y_true[0]*mask).sum(1).to(torch.float32)
        ni=mask.sum(1).to(torch.float32)
        qi=mi/ni
        wi=self.weight_image*qi

        yp= y_pred[...,0]
        yprs=y_pred.max(1)[0] if self.use_max else y_pred[:,-1,:]

        # Element-wise masked per-position loss and element-wise series loss.
        l0=masked_bce_with_logits(yp,torch.clamp(y_true[0],self.eth,1-self.eth),mask=mask,reduction='none',gamma=self.gamma)
        bc2=binary_focal_loss(self.calc(yprs[:,1:]),torch.clamp(y_true[1],self.eth,1-self.eth),gamma=self.gamma,reduction='none')

        # NOTE: an unused `conf` tensor (and its helper `npe`) used to be built
        # here; it never entered the loss, so it has been dropped. The
        # conf_weight term below is a plain mean focal loss on the raw series
        # logits, scaled by the batch size.
        bconf = binary_focal_loss(yprs[:,1:],torch.clamp(y_true[1],self.eth,1-self.eth),gamma=self.gamma,reduction='mean')*bc2.shape[0]*self.conf_weight

        # Weighted sum of both terms, normalised by the total weight.
        m= ((wis*bc2*self.weights_series).sum()*self.eq_series+bconf+wim*(l0.sum(1)*wi).sum())/(wim*(wi*ni).sum()+wis*self.weights_series.sum()*bc2.shape[0]*self.eq_series)
        # Extra un-weighted per-position loss averaged over all valid positions.
        li3=self.weight_gen_img*l0.sum()/mask.sum()
        return m+li3
        
        
    

In [None]:
class MyMetric():
    """Validation metric mirroring the weighted image / series loss.

    Called as ``metric(y_pred, y_true)``; inputs are converted with
    ``torch.tensor`` internally. Returns a dict with the combined value
    (``'metric'``) and its components (``'images'``, ``'series'``, ``'li21'``,
    ``'li3'``), all as Python floats.
    """

    def __init__(self,weight_image=0.07361963,weights_series=WEIGHTS,weight_gen_img=0.1,mean_len=246.,eq_series=1.,do_calc=False,use_max=False,image_series=0.5):
        self.weight_image = weight_image
        self.weights_series = weights_series
        self.weight_gen_img = weight_gen_img
        self.mean_len = mean_len
        self.image_series = image_series
        # Only half of the requested balancing factor is stored.
        self.eq_series = eq_series / 2
        self.use_max = use_max
        if do_calc:
            self.calc = calc_pred
        else:
            self.calc = lambda x: x

    def __call__(self,y_pred,y_true):
        image_weight = 2 * self.image_series
        series_weight = 2 * (1 - self.image_series)

        # Per-position targets; negative entries mark positions to ignore.
        image_true = torch.tensor(y_true[0])
        valid = image_true >= 0
        n_positive = (image_true > 0).sum(1).to(torch.float32)
        n_valid = (image_true >= 0).sum(1).to(torch.float32)
        sample_weight = self.weight_image * (n_positive / n_valid)

        # Element-wise masked BCE on the first output channel.
        image_logits = torch.tensor(y_pred[...,0])
        image_loss = masked_bce_with_logits(image_logits, image_true, mask=valid, reduction='none')

        # Series-level outputs: max over axis 1, or the last position.
        if self.use_max:
            series_raw = y_pred.max(1)
        else:
            series_raw = y_pred[:,-1,:]

        weighted_image_loss = (image_loss.sum(1) * sample_weight).sum()
        li1 = float(weighted_image_loss / (n_valid * sample_weight).sum())

        series_logits = self.calc(torch.tensor(series_raw[:,1:]))
        series_true = torch.tensor(y_true[1], dtype=torch.float32)
        series_loss = F.binary_cross_entropy_with_logits(series_logits, series_true, reduction='none')

        # Columns 4.. are the weighted labels; columns ..4 are reported separately.
        column_mean = series_loss.mean(0)
        li2 = float((column_mean[4:] * self.weights_series[4:]).sum())
        li21 = float((column_mean[:4] * self.weights_series[:4]).sum())
        li3 = float(self.weight_gen_img * image_loss.sum() / valid.sum())

        numerator = (series_weight * (series_loss[:,4:] * self.weights_series[4:]).sum() * self.eq_series
                     + image_weight * (image_loss.sum(1) * sample_weight).sum())
        denominator = (image_weight * (sample_weight * n_valid).sum()
                       + series_weight * self.weights_series[4:].sum() * series_loss.shape[0] * self.eq_series)
        m = numerator / denominator

        return {'metric': float(m), 'images': li1, 'series': li2, 'li21': li21, 'li3': li3}
    


In [None]:
# Training schedule: `reps` consecutive one-cycle runs of `epoch_in_rep` epochs.
epoch_in_rep = 10
reps = 2
num_epochs = epoch_in_rep * reps
batch_size = 32

# Peak learning rate of each run, scaled linearly with the batch size (reference 32).
reps_lr = [base_lr * batch_size / 32 for base_lr in (1e-4, 3e-5, 1e-5)]

# Linear embeddings handed to the transformer.
# Variant with an extra input: OrderedDict([('extra',[16,16]),('slice',[16,16])])
linear_embd = OrderedDict([('slice', [16, 16])])

# Optional extra per-series input for the datasets (disabled here).
# Training variant:   {'col':'kvp','norm':100.,'noise':0.03}
# Validation variant: {'col':'kvp','norm':100.,'noise':0.}
extra = None
extra_val = None
# Train one transformer per (seed, fold) on pre-extracted per-image features.
# NOTE: this loop rebinds the module-level SEED.
for SEED in [220]:
    # Only patients_val (per-fold lists of validation patients) is used below.
    val_folds, train_folds, patients_val = create_folds(df,num_folds,SEED)
    # Flatten the per-fold validation lists into a single list of patients.
    all_patients=[]
    for p in patients_val:
        all_patients.extend(p)
    for fold in folds:
        # Seed torch / numpy per fold and force deterministic cuDNN behaviour.
        torch.manual_seed(SEED+fold+add_seed)
        np.random.seed(SEED+fold+add_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # Feature file name: the base model's formatted name with its extension replaced by '.pkl'.
        basic_name=params.model_format.format(model_type,basic_name_tamplate.format(SEED),basic_version,fold).split('.')[0]+'.pkl'
        # SECURITY: pickle.load executes arbitrary code; only load feature files you produced yourself.
        with open(params.path.features+basic_name,'rb') as f:
            features0=pickle.load(f)
        # Report how many NaNs the features contain, then replace them via np.nan_to_num.
        print('nans=',np.isnan(features0).sum())
        features=torch.tensor(np.nan_to_num(features0),dtype=torch.float32)
        # NOTE(review): validation only receives features[:1] while training gets the
        # full tensor - presumably the first axis indexes feature variants; confirm.
        validate_ds=PatientFeaturesDataset(features[:1],df,patients_val[fold],max_len=val_max_len,
                                           rand_split=False,rep=1,new_z=False,extra=extra_val)
        # Training patients = all patients minus this fold's validation patients.
        train_ds=PatientFeaturesDataset(features,df,
                                        list(set(all_patients).difference(patients_val[fold])),
                                        max_len=train_max_len,rand_split=True,rep=1,new_z=False,fnoise=fnoise,extra=extra)
        # Ratio of patient count to dataset length; passed as eq_series to the metric / loss.
        eq_series_val =len(patients_val[fold])/len(validate_ds) 
        eq_series_train =len( list(set(all_patients).difference(patients_val[fold])))/len(train_ds) 
        print(eq_series_val,eq_series_train )
        epoch_size=len(train_ds)
        model = get_transformer_model(dim_feedforward=ffdim,n_heads=num_heads,linear_embd=linear_embd,freeze=False,
                                      n_encoders=num_layers,dropout=dropout,use_src_mask=False,res=False).to(device)
        # Name under which the learner logs / saves this model.
        name=params.model_format.format(model_type,name_tamplate.format(SEED),model_version,fold)
        my_loss=MyLoss(weights_series=WEIGHTS.to(device),eq_series=eq_series_train,
                       weight_gen_img=0.1,conf_weight=0.,gamma=gamma,image_series=0.5,eth=0.,use_max=False,do_calc=False)
        my_metric=MyMetric(eq_series=eq_series_val,image_series=0.5,use_max=False,do_calc=False)
        learner = Learner(model,None,loss_func=my_loss,name=name,scheduler=None,device=device)
        learner.metric=my_metric
        learner.optimizer = torch.optim.Adam(learner.model.parameters(), lr=1e-4)

        # The functions below are bound onto this learner instance (MethodType)
        # to adapt it to the batch layout of PatientFeaturesDataset.

        # Targets are the third- and second-to-last batch items; MyLoss reads them
        # as y_true[0] (per-position) and y_true[1] (series).
        def new_get_y(self,batch):
            return batch[-3],batch[-2]
        # Feeds the first 3 batch items (4 when `extra` is configured) to the model,
        # together with a mask that is True where batch[1] == -1.
        # NOTE: the mask is moved to the global `device`, not `self.device`.
        def run_model(self,model,batch):
            mask = batch[1]==-1
            return model(*(x.to(self.device) for x in batch[:3+(extra is not None)]),mask=mask.to(device))
        # Moves every target tensor to the learner's device before calling the loss.
        def calc_loss(self,y_pred,y_true):
            return self.loss_func(y_pred,tuple(y.to(self.device) for y in y_true))
        # Resets the training dataset at the start of every epoch.
        def on_epoch_begin(self,*args,**kargs):
            train_ds.reset()
        learner.get_y=MethodType(new_get_y, learner)
        learner.run_model=MethodType(run_model, learner)
        learner.calc_loss=MethodType(calc_loss, learner)
        learner.on_epoch_begin=MethodType(on_epoch_begin, learner)
        train_dl_args={'shuffle':True,'batch_size':batch_size}
        # One OneCycleLR run per repetition on the same optimizer, with peak LR reps_lr[t]
        # (reps_lr has more entries than `reps`; only the first `reps` are used).
        for t in range(reps):
            learner.scheduler = torch.optim.lr_scheduler.OneCycleLR(learner.optimizer, pct_start=0.01,final_div_factor= 10,
                                                                    max_lr=reps_lr[t], steps_per_epoch=epoch_size//batch_size+1, 
                                                                    epochs=num_epochs//reps)

            learner.fit(num_epochs//reps,train_ds,validate_ds,batch_size=batch_size,eval_batch=batch_size,path=params.path.models,
                        train_dl_args=train_dl_args,num_workers=12,send_log=False)
        # Report the best metric through sandesh and save the model.
        sandesh.send({'name':learner.name,'best_metric':learner.best_metric})
        learner.save_model(params.path.models)
        
