In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import fastai
from fastai.vision import *
from fastai.callbacks import SaveModelCallback
import os
from radam import *
from csvlogger import *
from mish import *
import cv2
from albumentations import *
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import confusion_matrix

import warnings
warnings.filterwarnings("ignore")

fastai.__version__

'1.0.60'

In [2]:
sz = 128
bs = 64
nfolds = 4
fold = 0
SEED = 43
N = 12
TRAIN = 'data/train_16x128x128'
LABELS = 'data/train_d1.csv'
OUT = 'd1'
NUM_WORKERS = 12

os.makedirs(OUT, exist_ok=True)

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

In [3]:
df = pd.read_csv(LABELS).set_index('image_id')
files = sorted(set([p[:32] for p in os.listdir(TRAIN)]))
df.gleason_score = df.gleason_score.replace('negative','0+0')
df = df.loc[files]
df = df.reset_index()
splits = StratifiedKFold(n_splits=nfolds, random_state=SEED, shuffle=True)
splits = list(splits.split(df,df.isup_grade))
folds_splits = np.zeros(len(df)).astype(np.int)
for i in range(nfolds):
    folds_splits[splits[i][1]] = i
df['split'] = folds_splits
Ng, Ns = df.nunique()[2], df.nunique()[3]
score_map = {s:i for i,s in enumerate(df.gleason_score.unique())}
df['score'] = df.gleason_score.map(score_map)
df['provider'] = df.data_provider == 'karolinska'
df.head()

Unnamed: 0,image_id,data_provider,isup_grade,gleason_score,isup_grade0,split,score,provider
0,0005f7aaab2800f6170c399693a96917,karolinska,0,0+0,0,0,0,True
1,000920ad0b612851f8e01bcc880d9b3d,karolinska,0,0+0,0,1,0,True
2,0018ae58b01bdadc8e347995b69f99aa,radboud,4,4+4,4,2,1,False
3,001c62abd11fa4b57bf7a6c603a11bb9,karolinska,4,4+4,4,3,1,True
4,001d865e65ef5d2579c190a0e0350d8f,karolinska,0,0+0,0,3,0,True


In [4]:
mean = np.array([1.0-0.90949707, 1.0-0.8188697, 1.0-0.87795304])
std = np.array([0.36357649, 0.49984502, 0.40477625])

def img2tensor(img,dtype:np.dtype=np.float32):
    if img.ndim==2 : img = np.expand_dims(img,2)
    img = np.transpose(img,(2,0,1))
    return torch.from_numpy(img.astype(dtype, copy=False))

class PANDADataset(Dataset):
    def __init__(self, df, fold=fold, train=True, tfms=None):
        self.df = df.loc[df.split != fold].copy() if train else df.loc[df.split == fold].copy()
        self.df = self.df.reset_index(drop=True)
        self.train = train
        self.tfms = tfms
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        labels = self.df.iloc[idx][['isup_grade','score','provider','isup_grade0']].astype(np.long).values
        provider = self.df.iloc[idx].data_provider
        
        idx = self.df.iloc[idx].image_id
        imgs = []
        for i in range(N):
            img = cv2.cvtColor(cv2.imread(os.path.join(TRAIN,idx+'_'+str(i)+'.png')), cv2.COLOR_BGR2RGB)
            img = 255 - img
            if self.tfms is not None:
                augmented = self.tfms(image=img)
                img = augmented['image']
            imgs.append(img)
        imgs = [img2tensor((img/255.0 - mean)/std,np.float32) for img in imgs]

        return torch.stack(imgs,0), labels

In [5]:
def get_aug(p=1.0):
    return Compose([
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.3, rotate_limit=15, p=0.9, 
                         border_mode=cv2.BORDER_CONSTANT),
        #OneOf([
        #    MotionBlur(blur_limit=3, p=0.1),
        #    MedianBlur(blur_limit=3, p=0.1),
        #    Blur(blur_limit=3, p=0.1),
        #], p=0.2),
        #OneOf([
        #    OpticalDistortion(p=0.3),
        #    GridDistortion(p=.1),
        #    IAAPiecewiseAffine(p=0.3),
        #], p=0.3),
        OneOf([
            HueSaturationValue(10,15,10),
            CLAHE(clip_limit=2),
            IAASharpen(),
            RandomBrightnessContrast(),            
        ], p=0.3),
    ], p=p)

In [8]:
class Model(nn.Module):
    def __init__(self, arch='resnext50_32x4d_ssl', n=Ns, pre=True,ps=0.5):
        super().__init__()
        m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', arch)
        nc = list(m.children())[-1].in_features
        self.enc = nn.Sequential(*list(m.children())[:-2])
        self.head = nn.Sequential(AdaptiveConcatPool2d(),Flatten(),
                                  nn.Linear(2*nc,512),Mish(),nn.GroupNorm(32,512),
                                  nn.Dropout(ps),nn.Linear(512,n+1))
        
    def forward(self, x):
        shape = x.shape
        n = shape[1]
        x = x.view(-1,shape[2],shape[3],shape[4])
        x = self.enc(x)
        
        shape = x.shape
        x = x.view(-1,n,shape[1],shape[2],shape[3]).permute(0,2,1,3,4).contiguous()\
          .view(-1,shape[1],shape[2]*n,shape[3])
        x = self.head(x)
        return x[:,:1],x[:,1:]

In [9]:
y_shift = df.isup_grade.mean()
def Kloss(x, target):
    x = Ng*torch.sigmoid(x.float()).view(-1) - 0.5
    target = target.float()
    return 1.0 - (2.0*((x-y_shift)*(target-y_shift)).sum() - 1e-3)/\
        (((x-y_shift)**2).sum() + ((target-y_shift)**2).sum() + 1e-3)

def Combine_loss(x, target):
    loss_c = Kloss(x[0].float(),target[:,0])
    loss_caux = F.cross_entropy(x[1].float(),target[:,1])
    return loss_c + 0.1*loss_caux

In [10]:
class DConfusionMatrix(Callback):
    def __init__(self, provider=None, original=False, **kwargs):
        self.provider=provider
        self.original=original
        super().__init__(**kwargs)
    
    def on_train_begin(self, **kwargs):
        self.n_classes = 0

    def on_epoch_begin(self, **kwargs):
        self.cm = None

    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        if self.provider is None:
            last_output = last_output[0]
            last_target = last_target[:,0] if not self.original else last_target[:,3]
        else:
            idxs = last_target[:,2] == self.provider
            last_output = last_output[0][idxs]
            last_target = last_target[:,0][idxs]  if not self.original else last_target[:,3][idxs]
            if len(last_output)  == 0: return
        preds = torch.clamp((Ng*torch.sigmoid(last_output.float())).long().view(-1).cpu(),0,Ng-1)
        targs = last_target.cpu()
        if self.n_classes == 0:
            self.n_classes = Ng
            self.x = torch.arange(0, self.n_classes)
        cm = ((preds==self.x[:, None]) & (targs==self.x[:, None, None])).sum(dim=2, dtype=torch.float32)
        if self.cm is None: self.cm =  cm
        else:               self.cm += cm

    def on_epoch_end(self, **kwargs):
        self.metric = self.cm

class DKappaScore(DConfusionMatrix):
    def __init__(self, weights:Optional[str]=None, **kwargs):
        super().__init__(**kwargs)
        self.weights = weights

    def on_epoch_end(self, last_metrics, **kwargs):
        sum0 = self.cm.sum(dim=0)
        sum1 = self.cm.sum(dim=1)
        expected = torch.einsum('i,j->ij', (sum0, sum1)) / sum0.sum()
        if self.weights is None:
            w = torch.ones((self.n_classes, self.n_classes))
            w[self.x, self.x] = 0
        elif self.weights == "linear" or self.weights == "quadratic":
            w = torch.zeros((self.n_classes, self.n_classes))
            w += torch.arange(self.n_classes, dtype=torch.float)
            w = torch.abs(w - torch.t(w)) if self.weights == "linear" else (w - torch.t(w)) ** 2
        else: raise ValueError('Unknown weights. Expected None, "linear", or "quadratic".')
        k = torch.sum(w * self.cm) / torch.sum(w * expected)
        return add_metrics(last_metrics, 1-k)
    
class kappa_k(DKappaScore):
    def __init__(self):
        super().__init__(weights='quadratic',provider=1)
        
class kappa_r(DKappaScore):
    def __init__(self):
        super().__init__(weights='quadratic',provider=0)
        
class kappa_k0(DKappaScore):
    def __init__(self):
        super().__init__(weights='quadratic',provider=1,original=True)
        
class kappa_r0(DKappaScore):
    def __init__(self):
        super().__init__(weights='quadratic',provider=0,original=True)
        
class kappa0(DKappaScore):
    def __init__(self):
        super().__init__(weights='quadratic',original=True)

In [11]:
fname = 'RNXT50_s43'
pred,pred_y = [],[]
for fold in range(nfolds):
    ds_t = PANDADataset(df, fold=fold, train=True, tfms=get_aug())
    ds_v = PANDADataset(df, fold=fold, train=False)
    data = DataBunch.create(ds_t,ds_v,bs=bs,num_workers=NUM_WORKERS)
    model = nn.DataParallel(Model())
    learn = Learner(data, model, loss_func=Combine_loss, opt_func=Over9000, 
                metrics=[DKappaScore(weights='quadratic'),kappa_k(),kappa_r(),
                         kappa0(),kappa_k0(),kappa_r0()]).to_fp16()
    logger = CSVLogger(learn,os.path.join(OUT,f'log_{fname}_{fold}'))
    learn.clip_grad = 1.0
    learn.split([model.module.head])
    learn.unfreeze()

    learn.fit_one_cycle(36, max_lr=slice(0.5e-3,0.2e-2), div_factor=50, pct_start=0.0, 
      callbacks = [SaveModelCallback(learn,name=f'model',monitor='d_kappa_score')])
    torch.save(learn.model.module.state_dict(),os.path.join(OUT,f'{fname}_{fold}.pth'))
    
    learn.model.eval()
    with torch.no_grad():
        for step, (x, y) in progress_bar(enumerate(data.dl(DatasetType.Valid)),
                                     total=len(data.dl(DatasetType.Valid))):
            p = learn.model(x.cuda().half())
            pred.append(p[0].float().view(-1).cpu())
            pred_y.append(y[:,0].cpu())
            
p = torch.clamp((6.0*torch.sigmoid(torch.cat(pred))).long(),0,Ng-1)
t = torch.cat(pred_y)
print(cohen_kappa_score(p,t,weights='quadratic'))
print(confusion_matrix(t,p))

Using cache found in /home/iafoss/.cache/torch/hub/facebookresearch_semi-supervised-ImageNet1K-models_master


epoch,train_loss,valid_loss,d_kappa_score,kappa_k,kappa_r,kappa0,kappa_k0,kappa_r0,time
0,0.55985,0.40058,0.736352,0.662918,0.719059,0.702008,0.652163,0.661217,01:11
1,0.420296,0.369072,0.758162,0.72946,0.701591,0.720732,0.716369,0.636698,01:03
2,0.375028,0.347288,0.774504,0.734936,0.717273,0.736542,0.724937,0.648204,01:03
3,0.355469,0.290738,0.818576,0.757803,0.806316,0.777482,0.748402,0.733482,01:03
4,0.342032,0.315884,0.798101,0.766426,0.751933,0.756777,0.755097,0.678747,01:03
5,0.323345,0.277177,0.82745,0.790659,0.804889,0.787814,0.778517,0.736454,01:04
6,0.310577,0.47546,0.65192,0.512,0.727526,0.619953,0.503845,0.667388,01:04
7,0.304128,0.264425,0.840268,0.802979,0.820021,0.798671,0.792552,0.747311,01:05
8,0.275856,0.288222,0.819175,0.76373,0.813679,0.77907,0.755415,0.744826,01:05
9,0.294842,0.382508,0.735796,0.685274,0.707,0.700422,0.674888,0.645386,01:06


Better model found at epoch 0 with d_kappa_score value: 0.7363522052764893.
Better model found at epoch 1 with d_kappa_score value: 0.7581620216369629.
Better model found at epoch 2 with d_kappa_score value: 0.7745035886764526.
Better model found at epoch 3 with d_kappa_score value: 0.8185763955116272.
Better model found at epoch 5 with d_kappa_score value: 0.8274499177932739.
Better model found at epoch 7 with d_kappa_score value: 0.8402683734893799.
Better model found at epoch 13 with d_kappa_score value: 0.8583625555038452.
Better model found at epoch 19 with d_kappa_score value: 0.8591977953910828.
Better model found at epoch 20 with d_kappa_score value: 0.8612195253372192.
Better model found at epoch 21 with d_kappa_score value: 0.8729861974716187.
Better model found at epoch 24 with d_kappa_score value: 0.8774484992027283.
Better model found at epoch 27 with d_kappa_score value: 0.8814295530319214.
Better model found at epoch 29 with d_kappa_score value: 0.8816941976547241.


Using cache found in /home/iafoss/.cache/torch/hub/facebookresearch_semi-supervised-ImageNet1K-models_master


epoch,train_loss,valid_loss,d_kappa_score,kappa_k,kappa_r,kappa0,kappa_k0,kappa_r0,time
0,0.560236,0.45172,0.69113,0.669094,0.565739,0.66317,0.647881,0.528366,01:02
1,0.429245,0.367537,0.757546,0.657544,0.758863,0.719243,0.635635,0.700069,01:02
2,0.390836,0.297869,0.816388,0.736172,0.80985,0.778098,0.712356,0.749561,01:02
3,0.36321,0.290589,0.822925,0.738895,0.821118,0.785645,0.718385,0.760756,01:02
4,0.351997,0.315488,0.799963,0.738488,0.779036,0.762633,0.717505,0.719845,01:02
5,0.32682,0.268986,0.840036,0.769217,0.843303,0.799945,0.745276,0.780448,01:02
6,0.319133,0.363689,0.756077,0.685015,0.752184,0.726587,0.669164,0.704532,01:02
7,0.300136,0.290187,0.826421,0.761738,0.818169,0.787404,0.738902,0.756541,01:02
8,0.298187,0.284694,0.82332,0.737524,0.823434,0.785505,0.716256,0.761976,01:02
9,0.290759,0.253101,0.849802,0.796548,0.842509,0.812091,0.773634,0.784139,01:02


Better model found at epoch 0 with d_kappa_score value: 0.6911299228668213.
Better model found at epoch 1 with d_kappa_score value: 0.7575457692146301.
Better model found at epoch 2 with d_kappa_score value: 0.8163880109786987.
Better model found at epoch 3 with d_kappa_score value: 0.8229246139526367.
Better model found at epoch 5 with d_kappa_score value: 0.840036153793335.
Better model found at epoch 9 with d_kappa_score value: 0.8498024940490723.
Better model found at epoch 11 with d_kappa_score value: 0.8553290367126465.
Better model found at epoch 19 with d_kappa_score value: 0.8648329973220825.
Better model found at epoch 25 with d_kappa_score value: 0.870124876499176.
Better model found at epoch 26 with d_kappa_score value: 0.8710570335388184.
Better model found at epoch 27 with d_kappa_score value: 0.8783380389213562.


Using cache found in /home/iafoss/.cache/torch/hub/facebookresearch_semi-supervised-ImageNet1K-models_master


epoch,train_loss,valid_loss,d_kappa_score,kappa_k,kappa_r,kappa0,kappa_k0,kappa_r0,time
0,0.538944,0.397812,0.746545,0.700709,0.679233,0.711507,0.680392,0.625794,01:05
1,0.420902,0.33081,0.792605,0.700553,0.779348,0.751125,0.677489,0.714617,01:05
2,0.380144,0.387816,0.754783,0.659041,0.722067,0.714882,0.639098,0.657467,01:04
3,0.343255,0.287814,0.82321,0.783934,0.794167,0.779363,0.760843,0.725324,01:04
4,0.330577,0.312066,0.800285,0.739963,0.784935,0.758344,0.714923,0.723726,01:04
5,0.309815,0.282124,0.828561,0.79382,0.809109,0.786083,0.772485,0.742787,01:04
6,0.310466,0.283575,0.829554,0.786447,0.80059,0.785857,0.762465,0.731599,01:04
7,0.304779,0.26751,0.839959,0.775581,0.832105,0.795344,0.751729,0.762199,01:04
8,0.29292,0.278199,0.832289,0.814182,0.792996,0.789342,0.790252,0.726567,01:04
9,0.292102,0.301543,0.815179,0.800236,0.775451,0.773288,0.777712,0.71173,01:05


Better model found at epoch 0 with d_kappa_score value: 0.7465454339981079.
Better model found at epoch 1 with d_kappa_score value: 0.7926052808761597.
Better model found at epoch 3 with d_kappa_score value: 0.8232097625732422.
Better model found at epoch 5 with d_kappa_score value: 0.8285610675811768.
Better model found at epoch 6 with d_kappa_score value: 0.82955402135849.
Better model found at epoch 7 with d_kappa_score value: 0.8399591445922852.
Better model found at epoch 11 with d_kappa_score value: 0.8556734323501587.
Better model found at epoch 17 with d_kappa_score value: 0.8627267479896545.
Better model found at epoch 19 with d_kappa_score value: 0.8659891486167908.
Better model found at epoch 21 with d_kappa_score value: 0.8706778287887573.
Better model found at epoch 23 with d_kappa_score value: 0.8733838200569153.
Better model found at epoch 27 with d_kappa_score value: 0.8754150867462158.
Better model found at epoch 29 with d_kappa_score value: 0.8771340847015381.
Better 

Using cache found in /home/iafoss/.cache/torch/hub/facebookresearch_semi-supervised-ImageNet1K-models_master


epoch,train_loss,valid_loss,d_kappa_score,kappa_k,kappa_r,kappa0,kappa_k0,kappa_r0,time
0,0.550755,0.399086,0.745308,0.657568,0.719653,0.704506,0.634487,0.658604,01:02
1,0.435932,0.376373,0.751285,0.563407,0.782431,0.709021,0.541167,0.716801,01:02
2,0.371888,0.303651,0.813159,0.740175,0.788832,0.769571,0.714273,0.721448,01:02
3,0.368971,0.332822,0.78406,0.752128,0.729091,0.743782,0.730341,0.667384,01:02
4,0.329287,0.276173,0.826485,0.754317,0.816677,0.778438,0.728983,0.742116,01:02
5,0.315704,0.281812,0.823728,0.732106,0.823688,0.775669,0.704473,0.749491,01:02
6,0.300318,0.274189,0.833962,0.783102,0.82248,0.788038,0.757777,0.752913,01:02
7,0.305398,0.291682,0.818023,0.772518,0.782277,0.772987,0.748853,0.71124,01:02
8,0.286935,0.438386,0.695114,0.608522,0.664332,0.65707,0.585418,0.611304,01:02
9,0.279465,0.261604,0.842281,0.800506,0.82635,0.794212,0.774609,0.752098,01:02


Better model found at epoch 0 with d_kappa_score value: 0.7453080415725708.
Better model found at epoch 1 with d_kappa_score value: 0.751285195350647.
Better model found at epoch 2 with d_kappa_score value: 0.8131594657897949.
Better model found at epoch 4 with d_kappa_score value: 0.826484739780426.
Better model found at epoch 6 with d_kappa_score value: 0.8339618444442749.
Better model found at epoch 9 with d_kappa_score value: 0.8422808647155762.
Better model found at epoch 13 with d_kappa_score value: 0.8478544354438782.
Better model found at epoch 17 with d_kappa_score value: 0.8642752766609192.
Better model found at epoch 23 with d_kappa_score value: 0.8708791732788086.
Better model found at epoch 24 with d_kappa_score value: 0.8720924854278564.
Better model found at epoch 31 with d_kappa_score value: 0.8729544281959534.
Better model found at epoch 33 with d_kappa_score value: 0.8751765489578247.
Better model found at epoch 34 with d_kappa_score value: 0.8764822483062744.
Better 

0.8803126684578566
[[2190  431   62   21    3    0]
 [ 434 1673  454   82   17    2]
 [  66  407  639  283   52    5]
 [  20   90  227  507  342   52]
 [  18   44   69  248  713  383]
 [   6    6   15   40  191  724]]
