In [1]:
import os
import pandas as pd
import numpy as np
import json
from tqdm.auto import tqdm
from panda_challenge.models import ClassifcationMultiCropModel, ClassifcationMultiCropModelMultiHead
from panda_challenge.dataset import ClassifcationMultiCropInferenceDataset, ClassifcationDatasetMultiCropMultiHead
from torch.utils.data import DataLoader
from sklearn.metrics import cohen_kappa_score
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
with open('./configs/heavy_multicrop_regression_config.json', 'r') as f:
    params = json.load(f)
data_val = pd.read_csv('/data/personal_folders/skolchenko/panda/data_val.csv')
data_val.head(5)

Unnamed: 0,image_id,data_provider,isup_grade,gleason_score,has_slide
0,90b8c7eac5e1df7eb3310d2346a24ebc,karolinska,0,0+0,True
1,ecae863e7c478594aa4c84ce132b3825,karolinska,1,3+3,True
2,d2fa7b7e8b5d7ac158a745839005336a,karolinska,2,3+4,True
3,e3c973ccb8b37ad37da0acd96be64830,karolinska,2,3+4,True
4,5c869a5a8bd2bd340fe13b188be333c1,karolinska,1,3+3,True


Load model and weights

In [4]:
model = ClassifcationMultiCropModel(
    params['model_name'],
    **params['model_config'])
model.cuda()
checkpoint = torch.load(os.path.join(params['log_dir'], 'checkpoints/best.pth'))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

ClassifcationMultiCropModel(
  (enc): Sequential(
    (0): Conv2dSame(3, 40, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (1): BatchNorm2d(40, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (2): Swish()
    (3): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
          (bn1): BatchNorm2d(40, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (act1): Swish()
          (se): SqueezeExcite(
            (avg_pool): AdaptiveAvgPool2d(output_size=1)
            (conv_reduce): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1))
            (act1): Swish()
            (conv_expand): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1))
          )
          (conv_pw): Conv2d(40, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=

Define dataset and dataloader

In [5]:
dataset_val = ClassifcationDatasetMultiCropMultiHead(
        params['val_csv'],
        params['val_transformations'],
        params['val_image_dir'],
        params['val_mask_dir'],
        **params['dataset_config'])
validation_loader = DataLoader(
    dataset_val,
    batch_size=params['batch_size'],
    num_workers=params['n_workers'],
    pin_memory=True,
    shuffle=False
    )

Run predictions, if necesarry, run tta

In [6]:
validation_preds = []
with torch.no_grad():
    for data_b in validation_loader:
        preds = model(data_b['features'].cuda())
        isup_score = preds
        isup_score = isup_score.cpu().numpy()
        validation_preds.extend(isup_score)     
    if params['dataset_config']['output_type'] == 'regression':
        validation_preds = np.array([x[0] for x in validation_preds])
        validation_preds = np.rint(validation_preds)
        validation_preds[validation_preds < 0] = 0
        validation_preds[validation_preds > 5] = 5

Compute metrics

In [7]:
cohen_kappa_score(
    validation_preds, 
    data_val.isup_grade.values, 
    weights='quadratic')

0.8348904572175693

In [1]:
import timm

In [2]:
timm.list_models()

['adv_inception_v3',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dpn68',
 'dpn68b',
 'dpn92',
 'dpn98',
 'dpn107',
 'dpn131',
 'ecaresnet18',
 'ecaresnet50',
 'ecaresnext26tn_32x4d',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'efficientnet_b2a',
 'efficientnet_b3',
 'efficientnet_b3a',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_b8',
 'efficientnet_cc_b0_4e',
 'efficientnet_cc_b0_8e',
 'efficientnet_cc_b1_8e',
 'efficientnet_el',
 'efficientnet_em',
 'efficientnet_es',
 'efficientnet_l2',
 'efficientnet_lite0',
 'efficientnet_lite1',
 'efficientnet_lite2',
 'efficientnet_lite3',
 'efficientnet_lite4',
 'ens_adv_inception_resnet_v2',
 'fbnetc_100',
 'gluon_inception_v3',
 'gluon_resnet18_v1b',
 'gluon_resnet34_v1b',
 'gluon_resnet50_v1b',
 'gluon_