In [6]:
MODEL_NAME = "resnet18"
DUTIL_PERCENT = 100
WEAK_STRONG = "strong"
PROJECT_NAME = f"{MODEL_NAME}_{str(DUTIL_PERCENT)}%_{WEAK_STRONG}-model"
LEARNING_RATE = 1e-4

import cv2
import glob
import numpy as np
import pandas as pd

import timm 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import Dataset, DataLoader

from collections import Counter
from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torchmetrics
from torchmetrics.functional import confusion_matrix
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

flist = glob.glob("data/whole_patches_in_pos_WSI_2048_only_labeled/*.png")

df = []
for f in flist:
    f_split = f.split("/")[-1].split("_")
    label = [1 if f_split[0] == "pos" else 0][0]
    p_id = f_split[1]
    s_id = f_split[2]

    df.append([p_id, s_id, label, f])

df = pd.DataFrame(df, columns=["p_id", "s_id", "label", "fpath"])
df_id = df.s_id.unique()
train_id, test_id = train_test_split(df_id, test_size=0.2, random_state=42)
train_id, valid_id = train_test_split(train_id, test_size=0.2, random_state=42)

train_df = df[df.s_id.isin(train_id)].reset_index(drop=True)
# _, train_df = train_test_split(train_df, test_size=0.2, random_state=42)
# train_df = train_df.reset_index(drop=True)

valid_df = df[df.s_id.isin(valid_id)].reset_index(drop=True)
test_df = df[df.s_id.isin(test_id)].reset_index(drop=True)
print(f"# of patches\t train: {len(train_df)} valid: {len(valid_df)} test: {len(test_df)}")
print(f"# of labels\t train: {Counter(train_df.label.values)} valid: {Counter(valid_df.label.values)} test: {Counter(test_df.label.values)}")

# of patches	 train: 622 valid: 156 test: 195
# of labels	 train: Counter({0: 441, 1: 181}) valid: Counter({0: 100, 1: 56}) test: Counter({0: 130, 1: 65})


In [7]:
train_df_pos = train_df[(train_df.label == 1)].reset_index(drop=True)
train_df_neg = train_df[(train_df.label == 0)].reset_index(drop=True)

train_df_pos_sample = train_df_pos.sample(int(len(train_df_pos) * 0.3), random_state=42).reset_index(drop=True)
train_df = pd.concat([train_df_pos_sample, train_df_neg], axis=0).reset_index(drop=True)
print(f"# of patches\t train: {len(train_df)} valid: {len(valid_df)} test: {len(test_df)}")
print(f"# of labels\t train: {Counter(train_df.label.values)} valid: {Counter(valid_df.label.values)} test: {Counter(test_df.label.values)}")

# of patches	 train: 495 valid: 156 test: 195
# of labels	 train: Counter({0: 441, 1: 54}) valid: Counter({0: 100, 1: 56}) test: Counter({0: 130, 1: 65})


In [8]:
weak_model_project_name = PROJECT_NAME.replace("strong", "weak")
aug_flist = glob.glob(f"results/aug-data_{MODEL_NAME}_{DUTIL_PERCENT}%/*.jpg")

aug_df = []
for aug_f in aug_flist:
    p_id = aug_f.split("/")[-1].split("_")[0]
    s_id = aug_f.split("/")[-1].split("_")[2]
    label = 0
    aug_df.append([p_id, s_id, label, aug_f])
    
aug_df = pd.DataFrame(aug_df, columns=["p_id", "s_id", "label", "fpath"])

train_df = train_df.append(aug_df).reset_index(drop=True)
train_df.label = train_df.label.astype(np.uint8)

print(f"# of P_id\t train: {len(train_id)} valid: {len(valid_id)} test: {len(test_id)}")
print(f"# of patches\t train: {len(train_df)} valid: {len(valid_df)} test: {len(test_df)}")
print(f"# of labels\t train: {Counter(train_df.label.values)} valid: {Counter(valid_df.label.values)} test: {Counter(test_df.label.values)}")

# of P_id	 train: 622 valid: 156 test: 195
# of patches	 train: 600 valid: 156 test: 195
# of labels	 train: Counter({0: 546, 1: 54}) valid: Counter({0: 100, 1: 56}) test: Counter({0: 130, 1: 65})


In [9]:
model = timm.create_model("resnet18", pretrained=True, num_classes=2)

# for name, param in model.named_parameters():
#     if name[:6] == "layer4":
#         param.requires_grad = True
#     else:
#         param.requires_grad = False

train_transforms = A.Compose([ 
#     A.RandomSizedCrop(min_max_height=[224, 512], 
#                       height=224, width=224, p=1.0),
    A.Resize(width=224, height=224, p=1.0),
    A.ShiftScaleRotate(scale_limit=(-0.5, 0.5), rotate_limit=(0, 0), shift_limit=(0, 0), p=1),
#     A.ShiftScaleRotate(scale_limit=(-0.5, 0.5), rotate_limit=(-25, 25), shift_limit=(-0.3, 0.3), p=1),
    A.OneOf([
        A.Transpose(),
        A.HorizontalFlip(),
        A.VerticalFlip()
    ], p=0.5),
    
    A.OneOf([
        A.ElasticTransform(),
#         A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
#         A.GridDistortion(),
        A.Rotate(25),
    ], p=0.8),

    A.OneOf([
       A.Blur(),
       A.GaussianBlur(),
       A.GaussNoise(),
       A.MedianBlur()
    ], p=0.2),

    A.OneOf([
#        A.ChannelShuffle(),
#        A.ColorJitter(),
#        A.HueSaturationValue(),
       A.RandomBrightnessContrast()
    ], p=0.5),

    A.Normalize(p=1.0),
    ToTensorV2()
])


valid_transforms = A.Compose([ 
    A.Resize(width=224, height=224, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2()
])


class PatchDataset(Dataset):
    def __init__(self, df, transform):
        self.df = df
        self.transform = transform
           
    def __len__(self):    
        return len(self.df)
    
    def __getitem__(self, idx):
        x = cv2.imread(self.df.loc[idx, "fpath"])
        x = self.transform(image=x)['image']
        
        y = self.df.loc[idx, "label"]
                
        return x, torch.tensor(y).long()


train_dataset = PatchDataset(train_df, train_transforms)
valid_dataset = PatchDataset(valid_df, valid_transforms)
test_dataset = PatchDataset(test_df, valid_transforms)

counts = np.bincount(train_df.label)
labels_weights = 1. / counts
weights = labels_weights[train_df.label]
sampler = WeightedRandomSampler(weights, len(weights))

train_dataloader = DataLoader(train_dataset, batch_size=32, sampler=sampler,
                              pin_memory=True, num_workers=32)
valid_dataloader = DataLoader(valid_dataset, batch_size=256, shuffle=False,
                              pin_memory=True, num_workers=32)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False,
                             pin_memory=True, num_workers=32)

  cpuset_checked))


In [10]:
from sklearn.metrics import classification_report

class LVIClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.best_valid_loss = 999
        self.is_best_valid_loss_updated = 0
        
    
    def step(self, batch):
        X, y = batch
        preds = self.model(X)
        preds = F.softmax(preds)
        
        loss = F.cross_entropy(preds, y)
        try:
            auroc_ = torchmetrics.functional.auroc(preds, y, task="multiclass", num_classes=2)
            auprc_ = torchmetrics.functional.average_precision(preds, y, task="multiclass", num_classes=2)

            sensitivity_ = torchmetrics.functional.recall(preds, y, task="multiclass", num_classes=2)
            specificity_ = torchmetrics.functional.specificity(preds, y, task="multiclass", num_classes=2)
            f1_score_ = torchmetrics.functional.f1_score(preds, y, task="multiclass", num_classes=2)
        except:
            auroc_ = torchmetrics.functional.auroc(preds, y, num_classes=2)
            auprc_ = torchmetrics.functional.average_precision(preds, y, num_classes=2)

            sensitivity_ = torchmetrics.functional.recall(preds, y, num_classes=2)
            specificity_ = torchmetrics.functional.specificity(preds, y, num_classes=2)
            f1_score_ = torchmetrics.functional.f1_score(preds, y, num_classes=2)
        
        return preds, y, loss, auroc_, auprc_, sensitivity_, specificity_, f1_score_
    
    
    def logging(self, output, mode="train"):
        if mode != "test":
            loss, auroc_, auprc_ = output
        else:
            loss, auroc_, auprc_, sensitivity_, specificity_, f1_score_ = output
        
        self.log(f"{mode}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{mode}_auroc", auroc_, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{mode}_auprc", auprc_, on_step=False, on_epoch=True, prog_bar=True)
        
        if mode == "test":
            self.log(f"{mode}_sensitivity", sensitivity_, on_step=False, on_epoch=True, prog_bar=True)
            self.log(f"{mode}_specificity", specificity_, on_step=False, on_epoch=True, prog_bar=True)
            self.log(f"{mode}_f1_score", f1_score_, on_step=False, on_epoch=True, prog_bar=True)
    
    
    def print_confusion_mat(self, outputs):
        preds = torch.cat([tmp['preds'] for tmp in outputs])
        targets = torch.cat([tmp['target'] for tmp in outputs])
        try:
            conf_mat = confusion_matrix(preds, targets, task="multiclass", num_classes=2)
        except:
            conf_mat = confusion_matrix(preds, targets, num_classes=2)
        print(f"Validation Epoch {self.trainer.current_epoch}")
        print(conf_mat.detach().cpu().numpy())
                
        print(classification_report(
            targets.detach().to("cpu").numpy(),
            [1 if p >= 0.5 else 0 for p in preds.detach().to("cpu").numpy()[:, 1]]))
        
    
    def training_step(self, batch, batch_idx):
        _, _, loss, auroc, auprc, _, _, _ = self.step(batch)
        self.logging([loss, auroc, auprc], mode="train")
        
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        preds, y, loss, auroc, auprc, _, _, _ = self.step(batch)
        self.logging([loss, auroc, auprc], mode="valid")
        if self.best_valid_loss > loss:
            self.best_valid_loss = loss
            self.is_best_valid_loss_updated = 1
            
        return {"preds": preds, "target": y}
    
    
    def validation_epoch_end(self, outputs):
        if self.is_best_valid_loss_updated == 1:
            self.print_confusion_mat(outputs)
            self.is_best_valid_loss_updated = 0
    
    
    def test_step(self, batch, batch_idx):
        preds, y, loss, auroc, auprc, sensitivity, specificity, f1_score = self.step(batch)
        self.logging([loss, auroc, auprc, sensitivity, specificity, f1_score], mode="test")
        
        return {"preds": preds, "target": y}
    
    
    def test_epoch_end(self, outputs):
        self.print_confusion_mat(outputs)
    
    
    def predict_step(self, batch, batch_idx):
        X, y = batch
        feats = self.model.forward_features(X)
        feats = torch.mean(torch.mean(feats, dim=-1), dim=-1)
        
        return feats, y
    
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10 * len(train_dataloader))
        
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
          
    
callbacks = [
    ModelCheckpoint(monitor='valid_loss', mode="min",
                    save_top_k=1, dirpath=f'weights/{PROJECT_NAME}', filename='{epoch:03d}-{valid_loss:.4f}-{valid_auroc:.4f}-{valid_auprc:.4f}.pt'),
]

predictor = LVIClassifier(model)
trainer = pl.Trainer(max_epochs=10 * np.round(np.log(DUTIL_PERCENT)), gpus=1, enable_progress_bar=True, 
                     callbacks=callbacks, precision=16)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [11]:
import warnings
warnings.filterwarnings(action='ignore')

trainer.fit(predictor, train_dataloader, valid_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
22.355    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Validation Epoch 0
[[93  7]
 [49  7]]
              precision    recall  f1-score   support

           0       0.65      0.93      0.77       100
           1       0.50      0.12      0.20        56

    accuracy                           0.64       156
   macro avg       0.58      0.53      0.48       156
weighted avg       0.60      0.64      0.56       156



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation Epoch 0
[[96  4]
 [24 32]]
              precision    recall  f1-score   support

           0       0.80      0.96      0.87       100
           1       0.89      0.57      0.70        56

    accuracy                           0.82       156
   macro avg       0.84      0.77      0.78       156
weighted avg       0.83      0.82      0.81       156



Validation: 0it [00:00, ?it/s]

Validation Epoch 1
[[91  9]
 [16 40]]
              precision    recall  f1-score   support

           0       0.85      0.91      0.88       100
           1       0.82      0.71      0.76        56

    accuracy                           0.84       156
   macro avg       0.83      0.81      0.82       156
weighted avg       0.84      0.84      0.84       156



Validation: 0it [00:00, ?it/s]

Validation Epoch 2
[[96  4]
 [12 44]]
              precision    recall  f1-score   support

           0       0.89      0.96      0.92       100
           1       0.92      0.79      0.85        56

    accuracy                           0.90       156
   macro avg       0.90      0.87      0.88       156
weighted avg       0.90      0.90      0.90       156



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation Epoch 4
[[95  5]
 [11 45]]
              precision    recall  f1-score   support

           0       0.90      0.95      0.92       100
           1       0.90      0.80      0.85        56

    accuracy                           0.90       156
   macro avg       0.90      0.88      0.89       156
weighted avg       0.90      0.90      0.90       156



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation Epoch 8
[[99  1]
 [15 41]]
              precision    recall  f1-score   support

           0       0.87      0.99      0.93       100
           1       0.98      0.73      0.84        56

    accuracy                           0.90       156
   macro avg       0.92      0.86      0.88       156
weighted avg       0.91      0.90      0.89       156



Validation: 0it [00:00, ?it/s]

Validation Epoch 9
[[98  2]
 [12 44]]
              precision    recall  f1-score   support

           0       0.89      0.98      0.93       100
           1       0.96      0.79      0.86        56

    accuracy                           0.91       156
   macro avg       0.92      0.88      0.90       156
weighted avg       0.91      0.91      0.91       156



Exception ignored in: <function _releaseLock at 0x7ff5a58c0b00>
Traceback (most recent call last):
  File "/home/user/miniconda3/envs/pytorch/lib/python3.7/logging/__init__.py", line 221, in _releaseLock
    def _releaseLock():
KeyboardInterrupt
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff4d5d1ce60>
Traceback (most recent call last):
  File "/home/user/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    
self._shutdown_workers()  File "/home/user/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1464, in _shutdown_workers
      File "/home/user/miniconda3/envs/pytorch/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
if w.is_alive():
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


AssertionError: 

In [12]:
predictor = LVIClassifier(model).load_from_checkpoint(
    f"weights/{PROJECT_NAME}/epoch=009-valid_loss=0.4002-valid_auroc=0.9377-valid_auprc=0.9433.pt.ckpt", 
    model=model)
trainer.test(predictor, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: 0it [00:00, ?it/s]

Validation Epoch 10
[[123   7]
 [ 10  55]]
              precision    recall  f1-score   support

           0       0.92      0.95      0.94       130
           1       0.89      0.85      0.87        65

    accuracy                           0.91       195
   macro avg       0.91      0.90      0.90       195
weighted avg       0.91      0.91      0.91       195

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_auprc           0.9246059060096741
       test_auroc           0.9515679478645325
      test_f1_score         0.9128205180168152
        test_loss           0.4012998044490814
    test_sensitivity        0.9128205180168152
    test_specificity        0.9128205180168152
──────────────────────────────────────────────────────────────

[{'test_loss': 0.4012998044490814,
  'test_auroc': 0.9515679478645325,
  'test_auprc': 0.9246059060096741,
  'test_sensitivity': 0.9128205180168152,
  'test_specificity': 0.9128205180168152,
  'test_f1_score': 0.9128205180168152}]

In [7]:
torch.save(predictor.model.state_dict(), f"weights/{PROJECT_NAME}/strong_model.pt")