In [1]:
# Fetch the project repository; later cells chdir into it, import its
# `callbacks`, `data` and `models` packages and run scripts/data_download.sh.
!git clone https://github.com/clementgr/lymphocytosis-mil.git

Cloning into 'lymphocytosis-mil'...
remote: Enumerating objects: 123, done.[K
remote: Counting objects: 100% (123/123), done.[K
remote: Compressing objects: 100% (84/84), done.[K
remote: Total 123 (delta 58), reused 91 (delta 29), pack-reused 0[K
Receiving objects: 100% (123/123), 20.76 KiB | 6.92 MiB/s, done.
Resolving deltas: 100% (58/58), done.


In [2]:
import os

# Work from inside the cloned repo so its packages (callbacks, data, models)
# and relative paths (requirements.txt, scripts/, data/) resolve.
# Guarded so re-running this cell in the same kernel does not try to descend
# into a non-existent nested 'lymphocytosis-mil' directory.
REPO_DIR = 'lymphocytosis-mil'
if os.path.basename(os.getcwd()) != REPO_DIR:
    os.chdir(REPO_DIR)

In [3]:
%%capture
!pip install -r requirements.txt
!rm -r /content/sample_data/

In [4]:
%%capture
# Run the repo's download script (output is captured by %%capture).
# Presumably it populates data/3md3070-dlmi/ (train.csv, test.csv, tile
# folders) read by the cells below — TODO confirm in scripts/data_download.sh.
!sh scripts/data_download.sh

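## Training

Each round evaluates the model on the training set, then trains it for one epoch with validation (see `run_training`).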

In [5]:
import torch.nn as nn
from pytorch_lightning import seed_everything, Trainer

from callbacks import ProgressBar
from data.data_module import LymphoDataModule
from models.mil_model import MILModel
from models.simple_cnn import SimpleCNN
from models.resnet import se_resnet50

In [6]:
# Seed all RNGs through Lightning before anything stochastic happens
# (the output confirms "Global seed set to 21").
SEED = 21
seed_everything(SEED)

# MIL data module over data/3md3070-dlmi/. setup() must run before
# run_training() reads data_module.train_dataset / data_module.val_dataset.
data_module = LymphoDataModule(
    num_workers=0,  # presumably forwarded to the DataLoaders (in-process loading) — confirm
    batch_size=32,
    data_dir='data/3md3070-dlmi/',
)
data_module.setup()

Global seed set to 21
100%|██████████| 130/130 [00:00<00:00, 3058.62it/s]
100%|██████████| 33/33 [00:00<00:00, 2795.69it/s]
100%|██████████| 42/42 [00:00<00:00, 552.42it/s]


In [7]:
import pandas as pd

# Peek at the training table: one row per tile image (`tiles` path), with the
# patient `id`, its `label` (repeated across a patient's rows in the sample
# shown) and clinical covariates gender / dob / lymph_count.
df = pd.read_csv('data/3md3070-dlmi/train.csv')
df.iloc[:5]

Unnamed: 0,id,label,gender,dob,lymph_count,tiles
0,P154,1,M,6/20/1945,11.97,data/3md3070-dlmi/train/P154/000012.jpg
1,P154,1,M,6/20/1945,11.97,data/3md3070-dlmi/train/P154/000027.jpg
2,P154,1,M,6/20/1945,11.97,data/3md3070-dlmi/train/P154/000035.jpg
3,P154,1,M,6/20/1945,11.97,data/3md3070-dlmi/train/P154/000086.jpg
4,P154,1,M,6/20/1945,11.97,data/3md3070-dlmi/train/P154/000065.jpg


In [8]:
# Checkpoint directory handed to the Trainer below (weights_save_path).
# exist_ok=True makes the cell safe to re-run; without it a second run raises
# FileExistsError and aborts Restart & Run All.
os.makedirs('checkpoints/exp1/', exist_ok=True)

In [9]:
# SE-ResNet50 backbone (pretrained weights are downloaded on construction,
# see the output), with its classifier swapped for a single-output head.
HEAD_IN_FEATURES = 2048  # in_features of the replacement last_linear layer
model = se_resnet50()
model.last_linear = nn.Linear(HEAD_IN_FEATURES, 1)

# MIL wrapper; topk / aggregation presumably select and pool the top-k tile
# scores per patient — TODO confirm in models/mil_model.py.
clf = MILModel(model, aggregation='mean', topk=10)

trainer = Trainer(
    gpus=1,
    callbacks=[ProgressBar()],
    weights_save_path='checkpoints/exp1',
    check_val_every_n_epoch=1,
    # Required by run_training(), which re-points the data module's dataset
    # references between test() and fit() on every call.
    reload_dataloaders_every_epoch=True,
)

Downloading: "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth" to /root/.cache/torch/hub/checkpoints/se_resnet50-ce0d4300.pth


HBox(children=(FloatProgress(value=0.0, max=112611220.0), HTML(value='')))




GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [10]:
def run_training(clf, trainer, data_module):
  """Run one round: evaluate on the training set, then fit for one epoch.

  The same trainer is reused across calls. Each call forces
  ``trainer.max_epochs = 1`` and re-points the data module's
  ``inference_dataset_reference`` / ``train_dataset_reference`` /
  ``validation_dataset_reference`` attributes before ``test`` and ``fit``.
  It therefore requires a trainer built with
  ``reload_dataloaders_every_epoch=True`` (presumably so the dataloaders are
  rebuilt from those references — confirm in the data module).

  Args:
    clf: the LightningModule being trained (``MILModel`` in this notebook).
    trainer: a ``pytorch_lightning.Trainer``.
    data_module: data module exposing ``train_dataset`` and ``val_dataset``.

  Returns:
    ``(inference_metrics, training_metrics, validation_metrics)``:
    ``inference_metrics`` is the first result dict of ``trainer.test`` on the
    training set, computed *before* this round's epoch. The other two are read
    from ``clf.training_metrics`` / ``clf.validation_metrics`` after ``fit``.

  Raises:
    ValueError: if ``trainer.reload_dataloaders_every_epoch`` is falsy.
  """
  # Explicit check rather than ``assert``, which is stripped under ``python -O``.
  if not trainer.reload_dataloaders_every_epoch:
    raise ValueError(
        'run_training requires a Trainer built with '
        'reload_dataloaders_every_epoch=True.')
  trainer.max_epochs = 1
  # Score the current model on the training set (before this round's update).
  data_module.inference_dataset_reference = data_module.train_dataset
  inference_metrics = trainer.test(clf, datamodule=data_module)[0]
  # One epoch of training on train_dataset, validated on val_dataset.
  data_module.train_dataset_reference = data_module.train_dataset
  data_module.validation_dataset_reference = data_module.val_dataset
  trainer.fit(clf, datamodule=data_module)
  # NOTE(review): these are returned by reference; if MILModel mutates the same
  # objects every epoch, callers that accumulate them will see aliased values.
  return inference_metrics, clf.training_metrics, clf.validation_metrics

In [11]:
# Each run_training() call evaluates on the training set, then trains one
# epoch with validation; collect all three metric sets per round.
N_ROUNDS = 20

all_inference_metrics = []  # train-set test() results, taken before each round's epoch
all_training_metrics = []
all_validation_metrics = []
for epoch in range(N_ROUNDS):
  inference_metrics, training_metrics, validation_metrics = run_training(clf, trainer, data_module)
  all_inference_metrics.append(inference_metrics)
  all_training_metrics.append(training_metrics)
  all_validation_metrics.append(validation_metrics)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.692307710647583,
 'auc': 0.7469444274902344,
 'precision': 0.692307710647583,
 'recall': 0.692307710647583}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.5230769515037537,
 'auc': 0.6838889122009277,
 'precision': 0.5230769515037537,
 'recall': 0.5230769515037537}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.7692307829856873,
 'auc': 0.8574999570846558,
 'precision': 0.7692307829856873,
 'recall': 0.7692307829856873}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.692307710647583,
 'auc': 0.8916666507720947,
 'precision': 0.692307710647583,
 'recall': 0.692307710647583}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.7076923251152039,
 'auc': 0.9102777242660522,
 'precision': 0.7076923251152039,
 'recall': 0.7076923251152039}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.692307710647583,
 'auc': 0.8938888311386108,
 'precision': 0.692307710647583,
 'recall': 0.692307710647583}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.7538461685180664,
 'auc': 0.9030556082725525,
 'precision': 0.7538461685180664,
 'recall': 0.7538461685180664}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.7230769395828247,
 'auc': 0.9127777814865112,
 'precision': 0.7230769395828247,
 'recall': 0.7230769395828247}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.8307692408561707,
 'auc': 0.9322222471237183,
 'precision': 0.8307692408561707,
 'recall': 0.8307692408561707}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.8999999761581421,
 'auc': 0.930555522441864,
 'precision': 0.8999999761581421,
 'recall': 0.8999999761581421}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.692307710647583,
 'auc': 0.9558333158493042,
 'precision': 0.692307710647583,
 'recall': 0.692307710647583}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.8538461327552795,
 'auc': 0.9358333349227905,
 'precision': 0.8538461327552795,
 'recall': 0.8538461327552795}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.8999999761581421,
 'auc': 0.9908332824707031,
 'precision': 0.8999999761581421,
 'recall': 0.8999999761581421}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.8461538553237915,
 'auc': 0.9477777481079102,
 'precision': 0.8461538553237915,
 'recall': 0.8461538553237915}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.7153846025466919,
 'auc': 0.9891666769981384,
 'precision': 0.7153846025466919,
 'recall': 0.7153846025466919}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.699999988079071,
 'auc': 0.9961111545562744,
 'precision': 0.699999988079071,
 'recall': 0.699999988079071}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.8461538553237915,
 'auc': 0.9730555415153503,
 'precision': 0.8461538553237915,
 'recall': 0.8461538553237915}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.7615384459495544,
 'auc': 0.9827777147293091,
 'precision': 0.7615384459495544,
 'recall': 0.7615384459495544}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.8769230842590332,
 'auc': 0.9491665959358215,
 'precision': 0.8769230842590332,
 'recall': 0.8769230842590332}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


  | Name  | Type    | Params
----------------------------------
0 | model | SENet   | 26.0 M
1 | loss  | BCELoss | 0     
----------------------------------
26.0 M    Trainable params
0         Non-trainable params
26.0 M    Total params



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0.692307710647583,
 'auc': 0.887222170829773,
 'precision': 0.692307710647583,
 'recall': 0.692307710647583}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [10]:
# NOTE(review): disabled inspection cell. Uncomment the line below to display
# the per-round validation metrics appended by the training loop.
# all_validation_metrics

## Test-set inference

Predict on the held-out test patients and write the submission CSV.

In [12]:
from torch.utils.data import DataLoader  # NOTE(review): not used in the visible cells
from data.dataset import MILImageDataset

# Held-out test bags, built in non-training mode.
TEST_CSV = 'data/3md3070-dlmi/test.csv'
test_df = pd.read_csv(TEST_CSV)
test_dataset = MILImageDataset(test_df, training=False)

# Re-point the shared data module / model at the test set for inference.
data_module.inference_dataset_reference = test_dataset
data_module.batch_size = 1
clf.compute_test_metrics = False  # presumably test.csv has no usable labels — confirm in MILModel

In [13]:
# Run inference on the test set.
# NOTE(review): with compute_test_metrics disabled, the "TEST RESULTS" printed
# below are not test scores. acc is 0, and auc/precision/recall are identical
# to the last training-round evaluation (apparently stale logged values).
test_results = trainer.test(clf, datamodule=data_module)
test_metrics = test_results[0]

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc': 0,
 'auc': 0.887222170829773,
 'precision': 0.692307710647583,
 'recall': 0.692307710647583}
--------------------------------------------------------------------------------


In [15]:
# Per-patient predictions; inference.csv is presumably written by MILModel
# during the test() call above — TODO confirm.
preds_df = pd.read_csv('inference.csv')
submission_df = (
    preds_df
    .loc[:, ['id', 'preds']]
    .rename(columns={'id': 'Id', 'preds': 'Predicted'})
)
submission_df.head()

Unnamed: 0,Id,Predicted
0,P108,0
1,P114,1
2,P119,0
3,P120,1
4,P132,1


In [17]:
# The filename tags mirror the MILModel config (topk=10, aggregation='mean');
# keep them in sync if the model setup changes.
SUBMISSION_PATH = 'submission_k=10_mean.csv'
submission_df.to_csv(SUBMISSION_PATH, index=False)