In [1]:
import os
import cv2
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import Sampler
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score

import warnings
warnings.filterwarnings(action='ignore')

N_EPOCHS = 1000
BATCH_SIZE = 128
LEARNING_RATE = 0.0005
PAITIENCE = 30

IM_HEIGHT = 256
IM_WIDTH = 256

model_name = "resnet50"

In [2]:
def generate_patch_df(flist, label):
    df = pd.DataFrame({"fpath": flist})
    df['slide_id'] = df['fpath'].map(lambda x: x.split("/")[-1].split(".")[0].split("_")[0])
    df['patient_id'] = df['slide_id'].map(lambda x: x.split("-")[0])
    df['target'] = label

    df = df.loc[:, ["patient_id", "slide_id", "fpath", "target"]]
    
    return df


def train_test_split(positive_df, negative_df, sampling_level=2, sampling_rate=0.2):
    # sampling_level {0: "patient_id", 1: "slide_id", 2: "patch"}
    if sampling_level == 0:
        column_name = "patient_id"
    elif sampling_level == 1:
        column_name = "slide_id"
    elif sampling_level == 2:
        column_name = "fpath"
    else:
        print("Set sampling level in [0, 1, 2]")
        raise
    
    N = len(pd.unique(positive_df[column_name]))
    
    test_index = np.random.choice(pd.unique(positive_df[column_name]), round(N * sampling_rate), replace=False)
    
    train_positive = positive_df[~positive_df[column_name].isin(test_index)]
    test_positive = positive_df[positive_df[column_name].isin(test_index)]

    train_negative = negative_df[~negative_df[column_name].isin(test_index)]
    test_negative = negative_df[negative_df[column_name].isin(test_index)]
    
    train_df = pd.concat([train_positive, train_negative]).reset_index(drop=True)
    test_df = pd.concat([test_positive, test_negative]).reset_index(drop=True)
    
    train_df, valid_df = train_valid_split(train_df, column_name, sampling_rate)
    
    return train_df, valid_df, test_df


def train_valid_split(train_df, column_name, sampling_rate):
    N = len(pd.unique(train_df[column_name]))
    valid_index = np.random.choice(pd.unique(train_df[column_name]), round(N * sampling_rate), replace=False)
        
    valid_df = train_df[train_df[column_name].isin(valid_index)]
    train_df = train_df[~train_df[column_name].isin(valid_index)]
    
    return train_df.reset_index(drop=True), valid_df.reset_index(drop=True)


positive_flist = glob.glob("../data/LVI_dataset/patch_image_size-400_overlap-100/positive/*.png")
negative_flist = glob.glob("../data/LVI_dataset/patch_image_size-400_overlap-100/negative/*.png")

positive_df = generate_patch_df(positive_flist, 1)
negative_df = generate_patch_df(negative_flist, 0)

train_df, valid_df, test_df = train_test_split(positive_df, negative_df, sampling_level=2, sampling_rate=0.2)
print(f"train_df: {train_df.shape}\nvalid_df: {valid_df.shape}\ntest_df: {test_df.shape}")

train_df: (0, 4)
valid_df: (0, 4)
test_df: (0, 4)


In [4]:
train_transforms = A.Compose([ 
    A.RandomCrop(width=IM_WIDTH, height=IM_HEIGHT, p=1.0),
    
    A.OneOf([
        A.Transpose(),
        A.HorizontalFlip(),
        A.VerticalFlip()
    ], p=0.5),

    A.OneOf([
       A.ShiftScaleRotate(),
       A.ElasticTransform(),
       A.RandomScale()
    ], p=0.5),

    A.OneOf([
       A.Blur(),
       A.GaussianBlur(),
       A.GaussNoise(),
       A.MedianBlur()
    ], p=0.5),

    A.OneOf([
       A.ChannelShuffle(),
       A.ColorJitter(),
       A.HueSaturationValue(),
       A.RandomBrightnessContrast()
    ], p=0.5),

    A.Normalize(p=1.0),
    ToTensorV2()
])


valid_transforms = A.Compose([ 
    A.Resize(width=IM_WIDTH, height=IM_HEIGHT, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2()
])


class LVIDataset(Dataset):
    def __init__(self, df, transforms):
        self.df = df
        self.transforms = transforms
        
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, idx):
        image  = cv2.imread(self.df.iloc[idx, 0])
        target = self.df.iloc[idx, 1]

        augmented = self.transforms(image=image)
        image = augmented['image']  
        
        return image, target
    
    def get_labels(self):
        return list(self.df.target.values)


class BalanceClassSampler(Sampler):
    """Abstraction over data sampler.
    Allows you to create stratified sample on unbalanced classes.
    """

    def __init__(self, labels, mode="downsampling"):
        """
        Args:
            labels (List[int]): list of class label
                for each elem in the datasety
            mode (str): Strategy to balance classes.
                Must be one of [downsampling, upsampling]
        """
        super().__init__(labels)

        labels = np.array(labels)
        samples_per_class = {
            label: (labels == label).sum() for label in set(labels)
        }

        self.lbl2idx = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }

        if isinstance(mode, str):
            assert mode in ["downsampling", "upsampling"]

        if isinstance(mode, int) or mode == "upsampling":
            samples_per_class = (
                mode
                if isinstance(mode, int)
                else max(samples_per_class.values())
            )
        else:
            samples_per_class = min(samples_per_class.values())

        self.labels = labels
        self.samples_per_class = samples_per_class
        self.length = self.samples_per_class * len(set(labels))

    def __iter__(self):
        """
        Yields:
            indices of stratified sample
        """
        indices = []
        for key in sorted(self.lbl2idx):
            replace_ = self.samples_per_class > len(self.lbl2idx[key])
            indices += np.random.choice(
                self.lbl2idx[key], self.samples_per_class, replace=replace_
            ).tolist()
        assert len(indices) == self.length
        np.random.shuffle(indices)

        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
             length of result sample
        """
        return self.length

    
train_dataset = LVIDataset(df=train_df, transforms=train_transforms)
train_iterator = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=8, shuffle=False, 
    sampler=BalanceClassSampler(labels=train_dataset.get_labels(), mode='downsampling'))

valid_dataset = LVIDataset(df=valid_df, transforms=valid_transforms)
valid_iterator = DataLoader(valid_dataset, batch_size=BATCH_SIZE, num_workers=8, shuffle=False,
    sampler=BalanceClassSampler(labels=valid_dataset.get_labels(), mode='downsampling'))

test_dataset = LVIDataset(df=test_df, transforms=valid_transforms)
test_iterator = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=8, shuffle=False)

ValueError: min() arg is an empty sequence

In [None]:
# device       = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device="cpu"
model        = timm.create_model(model_name, num_classes=2, pretrained=True).to(device)
optimizer    = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler    = ReduceLROnPlateau(optimizer, 'min')
class_weight = torch.tensor([1.1, 1.0]).to(device)
criterion    = nn.CrossEntropyLoss(weight=class_weight)

for name, param in model.named_parameters():
    param.requires_grad = True
#     if name[:4] == "head":
#         param.requires_grad = True
#     else:
#         param.requires_grad = False

In [None]:
def train(model, iterator, criterion, optimizer, device=device):  
    model.train()
    epoch_loss = 0
    correct    = 0    
    
    for image, target in iterator:
        optimizer.zero_grad()
        image  = image.to(device)
        target = target.long().to(device)
        
        output = model(image).squeeze()
        
        loss = criterion(output, target)
        loss.backward()
        epoch_loss += loss.item()
        
        pred     = torch.argmax(output, axis=1)
        correct += (pred == target).sum()

        optimizer.step()

    return epoch_loss / len(iterator.dataset), correct / len(iterator.dataset)


@torch.no_grad()
def evaluate(model, iterator, criterion, device=device):
    model.eval()
    epoch_loss = 0
    correct    = 0 
    
    for image, target in iterator:
        image  = image.to(device)
        target = target.long().to(device)

        output = model(image).squeeze()

        loss = criterion(output, target)
        epoch_loss += loss.item()

        pred     = torch.argmax(output, axis=1)
        correct += (pred == target).sum()
            
    return epoch_loss / len(iterator.dataset), correct / len(iterator.dataset)


@torch.no_grad()
def predict(model, iterator, device=device):
    model.eval()
    pred = []
    true = []
    
    for image, target in iterator:
        image  = image.to(device)
        target = target.long().to(device)

        output = model(image)

        pred.append(output.to("cpu").tolist()[0])
        true.append(target.to("cpu").tolist()[0])

    return np.argmax(pred, axis=1), true


def print_train_log(epoch_num, train_loss, valid_loss, train_acc, valid_acc):
    print(f"EPOCH: {epoch_num:04}")
    print(f"Train loss: {round(train_loss, 4)}\tTrain acc : {round(float(train_acc), 4)}\tValid loss: {round(valid_loss, 4)}\tValid acc : {round(float(valid_acc), 4)}")
    
    
def compute_test_metrics(true, pred):
    confusion_mat = confusion_matrix(true, pred)
    accuracy      = accuracy_score(true, pred)
    precision     = precision_score(true, pred)
    recall        = recall_score(true, pred)
    f1            = f1_score(true, pred)
    
    return confusion_mat, accuracy, precision, recall, f1


def print_test_log(epoch_num, accuracy, precision, recall, f1):
    print(f"EPOCH: {epoch_num:04} prediction results ")
    print(f"confusion matrix\n{confusion_mat}")
    print(f"accuracy score  : {round(accuracy, 4)}")
    print(f"precision score : {round(precision, 4)}")
    print(f"recall score    : {round(recall, 4)}")
    print(f"f1 score        : {round(f1, 4)}")

In [None]:
start_epoch = 0
if len(glob.glob()) != 0:
    print("load trained model ... ")
    start_epoch = len(glob.glob(os.path.join(os.path.join("output", model_name, "*.txt")))) - 1 
    model.load_state_dict(torch.load("weights/best_" + model_name + ".pt"))

n_paitience = 0
best_valid_loss = float('inf')

optimizer.zero_grad()
optimizer.step()

for epoch_num in range(start_epoch, N_EPOCHS):
    train_loss, train_acc = train(model, train_iterator, criterion, optimizer, device)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)
    
    scheduler.step(valid_loss)
    
    print_train_log(epoch_num, train_loss, valid_loss, train_acc, valid_acc)
    with open("output/resnet50/log.txt", "a") as f:
        f.write("epoch: {0:04d} train loss: {1:.4f}, valid loss: {2:.4f}, train Acc: {3:.4f}, valid Acc: {4:.4f}\n".format(epoch_num, train_loss, valid_loss, train_acc, valid_acc))

    if n_paitience < PAITIENCE:
        if best_valid_loss > valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), "weights/best_" + model_name + ".pt")
            n_paitience = 0
        elif best_valid_loss <= valid_loss:
            n_paitience += 1
    else:
        print("Early stop!")
        model.load_state_dict(torch.load("weights/best_" + model_name + ".pt"))
        break
        
    if epoch_num % 1 == 0:
        
        pred, true = predict(model, test_iterator)
        confusion_mat, accuracy, precision, recall, f1 = compute_test_metrics(true, pred)
        
        print_test_log(epoch_num, accuracy, precision, recall, f1)
        with open(os.path.join("output", model_name, f"epoch_{epoch_num:04d}_eval_metrics.txt"), "a") as f:
            f.write("accuracy score: {0:.4f}, precision score: {1:.4f}, recall score: {2:.4f}, f1 score: {3:.4f}\n".format(accuracy, precision, recall, f1))
            

In [7]:
model.load_state_dict(torch.load('weights/tumor_classification/best_resnet50.pt'))

pred, true = predict(model, test_iterator)
confusion_mat, accuracy, precision, recall, f1 = compute_test_metrics(true, pred)
        
print_test_log(epoch_num, accuracy, precision, recall, f1)

EPOCH: 0133 prediction results 
confusion matrix
[[ 736    9]
 [  17 1636]]
accuracy score  : 0.9892
precision score : 0.9945
recall score    : 0.9897
f1 score        : 0.9921
