In [4]:
PROJECT_NAME = "vit"
LEARNING_RATE = 1e-4

import os
import cv2
import glob
import timm
import random
import pickle
import numpy as np
import pandas as pd
from collections import Counter
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler, RandomSampler

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import torchmetrics
from torchmetrics.functional import confusion_matrix
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

batch_size = 256

test_df = glob.glob("../data/*.png")
test_df = pd.DataFrame(test_df, columns=["fpath"])
test_df.loc[:, "Y"] = test_df.fpath.map(lambda x: 1 if x.split("/")[-1].split("_")[0] == "pos" else 0)
test_df.loc[:, "fname"] = test_df.fpath.map(lambda x: x.split("/")[-1])

test_df

Unnamed: 0,fpath,Y,fname
0,../data/pos_A00032_39919-4762.png,1,pos_A00032_39919-4762.png
1,../data/neg_A00020_30254-10225.png,0,neg_A00020_30254-10225.png
2,../data/pos_A00007_36844-16740.png,1,pos_A00007_36844-16740.png
3,../data/neg_A00041_32121-6353.png,0,neg_A00041_32121-6353.png
4,../data/neg_A00011_26462-14125.png,0,neg_A00011_26462-14125.png
...,...,...,...
968,../data/pos_A00036_24924-29724.png,1,pos_A00036_24924-29724.png
969,../data/pos_A00042_24414-32208.png,1,pos_A00042_24414-32208.png
970,../data/pos_A00008_8902-13743.png,1,pos_A00008_8902-13743.png
971,../data/pos_A00008_40683-11016.png,1,pos_A00008_40683-11016.png


In [5]:
valid_transforms = A.Compose([ 
    A.Resize(width=224, height=224, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2()
])


class LVIDataset(Dataset):
    def __init__(self, df, transforms):
        self.df = df
        self.transforms = transforms
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, idx):
        fname = self.df.loc[idx, "fname"]
        target = self.df.loc[idx, "Y"]
        fpath = self.df.loc[idx, "fpath"]
        image  = cv2.imread(fpath)
    
        augmented = self.transforms(image=image)
        image = augmented['image']  
        
        return image, torch.tensor(target).long(), fname
    
    
test_dataset = LVIDataset(test_df, valid_transforms)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, 
                              num_workers=16, pin_memory=True)

In [6]:
class LVIPatchClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    
    def step(self, batch):
        X, y, fname = batch
        pred = self.model(X).squeeze(-1)
        loss = F.cross_entropy(pred, y.long())
        
        pred = F.softmax(pred)[:, 1]
        
        recall_ = torchmetrics.functional.recall(pred, y.long(), task="binary")
        auroc_ = torchmetrics.functional.auroc(pred, y.long(), task="binary")
        auprc_ = torchmetrics.functional.average_precision(pred, y.long(), task="binary")
        
        return pred, y, fname, loss
    
    
    def logging(self, logging_object, mode="train"):
        auroc, auprc = logging_object
        
        self.log(f'{mode}_auroc', auroc, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f'{mode}_auprc', auprc, on_step=False, on_epoch=True, prog_bar=True)
    
    def training_step(self, batch, batch_idx):
        _, _, _, loss = self.step(batch)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        preds, y, _, loss = self.step(batch)
        self.log("valid_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
    
    def test_step(self, batch, batch_idx):
        preds, y, _, loss = self.step(batch)
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
   
    def predict_step(self, batch, batch_idx):
        preds, y, fname, loss = self.step(batch)

        return {"preds": preds, "target": y, "fname": fname}
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10 * len(train_dataloader))
        
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    
model = timm.create_model("convit_small", pretrained=True, num_classes=2)
classifier = LVIPatchClassifier(model)

callbacks = [
    ModelCheckpoint(monitor='valid_auroc', mode="max",
                    save_top_k=1, dirpath=f'weights/{PROJECT_NAME}', filename='{epoch:03d}-{valid_loss:.4f}-{valid_recall:.4f}-{valid_auroc:.4f}-{valid_auprc:.4f}'),
]

trainer = pl.Trainer(max_epochs=100, gpus=[1], 
                     enable_progress_bar=True, 
                     callbacks=callbacks, precision=16)

  rank_zero_deprecation(
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
from sklearn.metrics import roc_auc_score, average_precision_score, recall_score, confusion_matrix, accuracy_score, precision_score, f1_score

classifier = LVIPatchClassifier(model).load_from_checkpoint(
    "weights/convit_lightning.ckpt", model=model, strict=False)

pred = trainer.predict(classifier, test_dataloader)

fname_list = []
pred_list = []
target_list = []
for batch in pred:
    fname_list += batch['fname']
    pred_list += batch['preds'].detach().cpu().numpy().tolist()
    target_list += batch['target'].detach().cpu().numpy().tolist()
    
print(f"AUROC: {roc_auc_score(target_list, pred_list).round(4)}")
print(f"AUPRC: {average_precision_score(target_list, pred_list).round(4)}")
print(f"Accuracy: {accuracy_score(target_list, [1 if pred >= 0.5 else 0 for pred in pred_list]).round(4)}")
print(f"Precision: {precision_score(target_list, [1 if pred >= 0.5 else 0 for pred in pred_list]).round(4)}")
print(f"Recall: {recall_score(target_list, [1 if pred >= 0.5 else 0 for pred in pred_list]).round(4)}")
print(f"F1 score: {f1_score(target_list, [1 if pred >= 0.5 else 0 for pred in pred_list]).round(4)}")

print("Confusion matrix")
print(confusion_matrix(target_list, [1 if pred >= 0.5 else 0 for pred in pred_list]))

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

  pred = F.softmax(pred)[:, 1]


AUROC: 0.9184
AUPRC: 0.8726
Accuracy: 0.8674
Precision: 0.7781
Recall: 0.8013
F1 score: 0.7896
Confusion matrix
[[602  69]
 [ 60 242]]


In [8]:
pred_df = pd.DataFrame({
    "fname": fname_list,
    "target": target_list,
    "pred": pred_list
})

pred_df.to_csv("output/convit_inference.csv", index=False)