# Introduction

**Please find the public repo at: https://github.com/jenniferG328/Final-Project_CS598_HealthCare**

## Background of the Problem
#### Context: 
Advances in deep learning have enhanced automated medical image analysis, but existing techniques face high computational requirements and performance drops with reduced batch sizes or training epochs.

#### Problem Type: 
The paper addresses issues in medical image analysis, particularly focusing on self-supervised learning approaches for processing label-free images efficiently.

#### Importance: 

Solving this problem is crucial because data labeling in medical imaging is expensive and time-consuming, and often data is scarce, especially for emerging diseases like certain autoimmune conditions.

#### Difficulty: 

Challenges include minimal data availability, the need for domain-specific knowledge for labeling, patient privacy issues, and an incomplete understanding of diseases.

#### State-of-the-Art Methods and Effectiveness

Current Techniques: The prevailing methods in medical image analysis primarily rely on self-supervised learning frameworks utilizing either Convolutional Neural Networks (CNNs) or Transformers. These techniques are heavily dependent on extensive datasets and large batch sizes.

Performance Challenges: A notable limitation of these existing approaches is their significant reduction in performance when the conditions of large datasets and batch sizes are not met. This issue becomes more pronounced with constrained computational resources.

Computational Demands: Current state-of-the-art methods require considerable computational power, which poses a barrier to their application, especially in settings with limited resources. Such extensive computational requirements limit the practical accessibility of these advanced techniques.

## Paper Explanation
#### Proposal: 
The paper introduces Cross Architectural Self-Supervision (CASS), a novel method combining CNNs and Transformers in a self-supervised learning setting. It addresses the challenges of limited data and computational resources in medical image analysis.

#### Innovations: 
CASS leverages both CNN and Transformer architectures simultaneously, improving robustness to changes in batch size and training epochs and reducing computational requirements.

#### Performance: 
Demonstrated improvements across four medical datasets in terms of F1 Score and Recall, using less labeled data and significantly less training time compared to existing methods.

#### Contribution: 
CASS represents a significant step in self-supervised learning for medical image analysis, especially beneficial for emerging diseases with limited data. It stands out in efficiency, effectiveness, and adaptability to resource constraints.

# Scope of Reproducibility 

**Hypotheses to be Tested**

-  Hypothesis 1  
Reproduced CASS-trained models outperform existing self-supervised learning methods in terms of
accuracy and efficiency on the healthcare tasks shown in the paper, i.e., disease cell classification, brain
tumor classification, and skin lesion classification.

-  Hypothesis 2  
Reproduced CASS demonstrates greater robustness to variations in batch size and pretraining epochs
than current methods.

# Methodology:

## Data

#### Data descriptions

Autoimmune Diseases Biopsy Slides Dataset: This dataset includes 198 TIFF images from muscle biopsies of dermatomyositis patients. These slides are stained with different proteins to help diagnose dermatomyositis, a type of autoimmune disease. The dataset involves multi-label classification for different cell classes, such as TFH-1, TFH-217, TFH-Like B cells, and others. The images are consistent in size, measuring 352 by 469 pixels in RGB format.

Dermofit Dataset: Comprising 1300 normal RGB images captured indoors with an SLR camera and ring lighting, this dataset categorizes images into 10 classes associated with skin lesions and conditions. The images vary in size, ranging from 205×205 to 1020×1020 pixels, with no two images being the same size. The dataset's primary task is multi-class classification.

Brain Tumor MRI Dataset: This dataset includes 7022 images of human brain MRIs, classified into four categories: glioma, meningioma, no tumor, and pituitary. The images vary in size from 512×512 to 219×234 pixels. The dataset's source is a combination of different datasets and includes 5712 images for training and 1310 for testing.


This report uses the Brain Tumor MRI Dataset as a demo. To run the experiments on another dataset, change the dataset path passed to `ImageFolder`.
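
Because `ImageFolder` infers class labels from sub-directory names, a quick check of the class balance under a dataset root can be useful before switching datasets. The following is a minimal sketch using only the standard library; the helper name `count_images_per_class` and the list of file extensions are assumptions made for illustration, not part of the original code.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to the dataset at hand.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}


def count_images_per_class(root):
    """Count image files in each class sub-directory of `root`.

    Hypothetical helper: it mirrors the <root>/<class_name>/<image>
    layout that torchvision's ImageFolder expects.
    """
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts
```

For example, `count_images_per_class('brain_tumor/Training')` would return a dictionary mapping each class folder name to its number of images, assuming that directory exists locally.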

#### Data Access

Data will be accessed through collaborations with healthcare institutions and the use of publicly available medical datasets.

The datasets are available at the following links:

(1) https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset

(2) https://homepages.inf.ed.ac.uk/rbf/DERMOFIT/datasets.htm

(3) https://challenge.isic-archive.com/data/

(4) https://github.com/pranavsinghps1/DEDL/data


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.models import resnet18
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
import numpy as np
import matplotlib.pyplot as plt

import os
import pytorch_lightning as pl
import pandas as pd
import timm
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import KFold
import math

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchcontrib.optim import SWA
from torchmetrics import Metric
from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Data preprocessing and augmentation
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


train_dataset = ImageFolder('brain_tumor/Training', transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

val_dataset = ImageFolder('brain_tumor/Testing', transform=data_transforms)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)


# Load a batch of images and labels for visualization
data_iter = iter(train_loader)
images, labels = next(data_iter)

# Convert images to numpy arrays and denormalize
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
images = (images.numpy().transpose((0, 2, 3, 1)) * std + mean).clip(0, 1)

# Create a grid of images
num_images = len(images)
rows = int(np.ceil(num_images / 4))
fig, axes = plt.subplots(rows, 4, figsize=(15, 15))

# Plot images with labels
for i, ax in enumerate(axes.flat):
    if i < num_images:
        ax.imshow(images[i])
        ax.set_title(f'Label: {train_dataset.classes[labels[i]]}')
    ax.axis('off')

plt.tight_layout()
plt.show()

## Model



#### Model Architecture
Layers and Types: CASS employs a dual architecture comprising a Convolutional Neural Network (CNN) and a Vision Transformer (ViT). Specific examples mentioned are ResNet-50 for the CNN and ViT Base/16 for the Transformer.

Activation Functions: The paper does not explicitly mention the activation functions used in the model architectures, but standard practices for ResNet and ViT typically involve ReLU and GELU activations, respectively.

#### Training Objectives
Loss Function: Self-supervised pretraining uses a cosine similarity-based loss that compares the logits output by the CNN and the ViT. Downstream supervised fine-tuning uses focal loss to counter class imbalance.
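
To illustrate why focal loss helps with class imbalance, the sketch below computes the per-element focal term in plain Python and compares it with ordinary cross-entropy. It is a simplified illustration, assuming a single sigmoid output and a binary target; the function names are hypothetical, and the epsilon and class weights used by the `FocalLoss` class in this notebook are omitted.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def cross_entropy_term(logit, target):
    """Binary cross-entropy for one logit and a target in {0, 1}."""
    p = sigmoid(logit)
    pt = p if target == 1 else 1.0 - p  # probability given to the true label
    return -math.log(pt)


def focal_term(logit, target, alpha=1.0, gamma=2.0):
    """Focal loss for one logit: cross-entropy scaled by alpha * (1 - pt)^gamma."""
    p = sigmoid(logit)
    pt = p if target == 1 else 1.0 - p
    return -alpha * (1.0 - pt) ** gamma * math.log(pt)


# A confident, correct prediction is down-weighted far more than a wrong one.
easy = focal_term(3.0, 1)   # confident and correct
hard = focal_term(-3.0, 1)  # confident and wrong
print(f"easy example: focal={easy:.6f}, cross-entropy={cross_entropy_term(3.0, 1):.6f}")
print(f"hard example: focal={hard:.6f}, cross-entropy={cross_entropy_term(-3.0, 1):.6f}")
```

With `gamma=0` the focal term reduces to plain cross-entropy; larger values of `gamma` shrink the contribution of well-classified examples.
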
#### Optimizer: 
The training employs Adam optimizer with a learning rate of 1e-3 for both the CNN and ViT, along with stochastic weight averaging (SWA) for optimization.
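
Stochastic weight averaging keeps a running average of the weights visited during training and uses that average as the final model. The sketch below shows only the averaging step on plain Python lists. It is not the `torchcontrib` implementation, and the function name and the weight snapshots are hypothetical.

```python
def update_swa(swa_weights, new_weights, n_averaged):
    """Fold one more weight snapshot into a running mean.

    swa_weights: current average over `n_averaged` snapshots
    new_weights: the latest snapshot from the base optimizer
    """
    return [
        s + (w - s) / (n_averaged + 1)
        for s, w in zip(swa_weights, new_weights)
    ]


# Hypothetical weight vectors collected at three points during training
snapshots = [[1.0, 4.0], [3.0, 0.0], [5.0, 2.0]]

swa = list(snapshots[0])
n = 1
for weights in snapshots[1:]:
    swa = update_swa(swa, weights, n)
    n += 1

print(swa)  # element-wise mean of the three snapshots
```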

#### Weight of Each Loss Term: 

Details about the weight of each loss term are not explicitly mentioned. However, the loss is computed as the mean value of all elements in the tensor derived from the cosine similarity calculation between the outputs of the CNN and ViT.
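
As a worked illustration of the similarity term, the sketch below evaluates `2 - 2 * cos(x, y)` on small vectors in plain Python. It assumes non-zero input vectors and uses hypothetical function names; the notebook's own `loss_fn` performs the same computation on batched tensors.

```python
import math


def l2_normalize(v):
    """Scale a non-zero vector to unit L2 norm."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def cosine_loss(x, y):
    """2 - 2 * cosine_similarity(x, y).

    Ranges from 0 (same direction) through 2 (orthogonal) to 4 (opposite).
    """
    xn, yn = l2_normalize(x), l2_normalize(y)
    return 2.0 - 2.0 * sum(a * b for a, b in zip(xn, yn))


print(cosine_loss([3.0, 4.0], [6.0, 8.0]))   # same direction, close to 0
print(cosine_loss([1.0, 0.0], [0.0, 1.0]))   # orthogonal
print(cosine_loss([1.0, 0.0], [-1.0, 0.0]))  # opposite direction
```

Because both vectors are normalized first, the loss depends only on the angle between the two logit vectors, not on their magnitudes.
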
#### Others
Pretraining: CASS is based on self-supervised learning, and it is mentioned that the models were trained from ImageNet initialization for 100 epochs.

Training Process: The paper uses the same set of augmentations throughout self-supervised training and details the hyper-parameters for both the self-supervised and supervised training phases.

In [None]:
"""
Define Focal-Loss
"""

class FocalLoss(nn.Module):
    """
    The focal loss for fighting against class-imbalance
    """
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = 1e-12  # prevent training from Nan-loss error
        self.cls_weights = torch.tensor([CFG.cls_weight],dtype=torch.float, requires_grad=False, device=CFG.device)

    def forward(self, logits, target):
        """
        logits & target should be tensors with shape [batch_size, num_classes]
        """
        probs = torch.sigmoid(logits)
        one_subtract_probs = 1.0 - probs
        # add epsilon
        probs_new = probs + self.epsilon
        one_subtract_probs_new = one_subtract_probs + self.epsilon
        # calculate focal loss
        log_pt = target * torch.log(probs_new) + (1.0 - target) * torch.log(one_subtract_probs_new)
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt
        focal_loss = focal_loss * self.cls_weights
        return torch.mean(focal_loss)

In [None]:
class CFG:
    # label_num2str = {'0': 'glioma', '1': 'meningioma', '2': 'notumor', '3': 'pituitary'}
    label_num2str = {0: 'glioma', 1: 'meningioma', 2: 'notumor', 3: 'pituitary'}
    label_str2num = {'glioma': '0', 'meningioma': '1', 'notumor': '2','pituitary': '3'}
    fl_alpha = 1.0  # alpha of focal_loss
    fl_gamma = 2.0  # gamma of focal_loss
    cls_weight =  [0.5, 0.5, 0.5, 0.5]
    cnn_name='resnet50'
    vit_name='vit_base_patch16_224'
    seed = 77
    num_classes = 4
    batch_size = 16
    t_max = 16
    lr = 1e-3
    min_lr = 1e-6
    n_fold = 6
    num_workers = 8
    gpu_idx = 0
    device = torch.device(f'cuda:{gpu_idx}' if torch.cuda.is_available() else 'cpu')
    gpu_list = [gpu_idx]
    CNN = True
    VIT = False
    
cfg=CFG()

In [None]:
model_cnn = timm.create_model(cfg.cnn_name, pretrained=True)
model_vit = timm.create_model(cfg.vit_name, pretrained=True)
model_cnn.to(device)
model_vit.to(device)

## Training

#### Computational requirements and GPU Utilization: 
In the paper, the model was trained on a single NVIDIA RTX8000 GPU. We use an NVIDIA RTX 2080 Ti to reproduce all the experiments.

#### Training Time Efficiency: 
The paper highlights that CASS took substantially less time compared to the DINO method for self-supervised training. For example, on the Autoimmune Diseases Biopsy Slides dataset, CASS required only 21 minutes compared to DINO's 1 hour 13 minutes. 

In [None]:
def ssl_train_model(train_loader, model_vit, criterion_vit, optimizer_vit, scheduler_vit,
                    model_cnn, criterion_cnn, optimizer_cnn, scheduler_cnn, num_epochs):
    # criterion_vit / criterion_cnn are unused here: pretraining is label-free and
    # relies only on the cross-architecture similarity loss (loss_fn).
    writer = SummaryWriter()
    model_cnn.train()
    model_vit.train()
    for i in tqdm(range(num_epochs)):
        for img, _ in tqdm(train_loader):
            img = img.to(device)
            # Clear gradients left over from the previous step
            model_cnn.zero_grad()
            model_vit.zero_grad()
            pred_vit = model_vit(img)
            pred_cnn = model_cnn(img)
            # Cosine-similarity loss between the logits of the two architectures
            loss = loss_fn(pred_vit, pred_cnn).mean()
            loss.backward()
            optimizer_cnn.step()
            optimizer_vit.step()
            scheduler_cnn.step()
            scheduler_vit.step()
        # Log the loss of the last batch of the epoch
        print('For -', i, 'Loss:', loss.item())
        writer.add_scalar("Self-Supervised Loss/train", loss.item(), i)
    writer.flush()

In [None]:
optimizer_cnn = SWA(torch.optim.Adam(model_cnn.parameters(), lr= 1e-3))
optimizer_vit = SWA(torch.optim.Adam(model_vit.parameters(), lr= 1e-3))
scheduler_cnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_cnn,
                                                                    T_max=16,
                                                                    eta_min=1e-6)
scheduler_vit = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vit,
                                                                    T_max=16,
                                                                    eta_min=1e-6)


fl_alpha = 1.0  # alpha of focal_loss
fl_gamma = 2.0  # gamma of focal_loss
cls_weight = [0.5, 0.5, 0.5, 0.5]
criterion_vit = FocalLoss(fl_alpha, fl_gamma)
criterion_cnn = FocalLoss(fl_alpha, fl_gamma)

In [None]:
def loss_fn(x, y):
    x =  torch.nn.functional.normalize(x, dim=-1, p=2)
    y =  torch.nn.functional.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)

In [None]:
ssl_train_model(train_loader,model_vit,criterion_vit,optimizer_vit,scheduler_vit,model_cnn,criterion_cnn,optimizer_cnn,scheduler_cnn,num_epochs=10)

# Save the self-supervised (SSL) models
print('Saving Cov-T')
torch.save(model_cnn, './cass-r50.pt')
torch.save(model_vit, './cass-vit.pt')

## Downstream Tuning and Evaluation

#### Metrics Descriptions in CASS Paper

- **Metric Used**: F1 Score
- **Definition**: 
  - F1 Score is calculated as `F1 = 2 × (Precision × Recall) / (Precision + Recall)`, 
    where Precision is the ratio of true positive predictions to the total positive predictions,
    and Recall is the ratio of true positive predictions to the total actual positives.
- **Reason for Choice**: 
  - Selected based on previous work or as defined by the dataset provider.
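
The definition above can be checked with a short worked example. The sketch below computes F1 from precision and recall and also from raw counts, `2·TP / (2·TP + FP + FN)`, which is the form used by the metric class in this notebook. The function names and the counts are hypothetical.

```python
def f1_from_precision_recall(tp, fp, fn):
    """F1 as the harmonic mean of precision and recall."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def f1_from_counts(tp, fp, fn):
    """Equivalent closed form: 2*TP / (2*TP + FP + FN)."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


# Hypothetical counts: 8 true positives, 2 false positives, 4 false negatives
tp, fp, fn = 8, 2, 4
print(f1_from_precision_recall(tp, fp, fn))
print(f1_from_counts(tp, fp, fn))
```

Both functions return the same value (16/22 for these counts), which confirms that the two formulations agree.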

In [None]:
"""
Define F1 score metric
"""
class MyF1Score(Metric):
    def __init__(self, cfg, threshold: float = 0.5, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.cfg = cfg
        self.threshold = threshold
        self.add_state("tp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fn", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        preds_str_batch = self.num_to_str(torch.sigmoid(preds))
        target_str_batch = self.num_to_str(target)
        tp, fp, fn = 0, 0, 0
        for pred_str_list, target_str_list in zip(preds_str_batch, target_str_batch):
            for pred_str in pred_str_list:
                if pred_str in target_str_list:
                    tp += 1
                if pred_str not in target_str_list:
                    fp += 1

            for target_str in target_str_list:
                if target_str not in pred_str_list:
                    fn += 1
        self.tp += tp
        self.fp += fp
        self.fn += fn

    def compute(self):
        f1 = 2.0 * self.tp / (2.0 * self.tp + self.fn + self.fp)
        return f1
    
    def num_to_str(self, ts: torch.Tensor) -> list:
        batch_bool_list = (ts > self.threshold).detach().cpu().numpy().tolist()
        batch_str_list = []
        for one_sample_bool in batch_bool_list:
            lb_str_list = [self.cfg.label_num2str[lb_idx] for lb_idx, bool_val in enumerate(one_sample_bool) if bool_val]
            batch_str_list.append(lb_str_list)
        return batch_str_list

In [None]:
if cfg.CNN:
    # Load the self-supervised CNN checkpoint saved above
    model_cnn = torch.load('./cass-r50.pt')
    last_loss = math.inf
    counter = 0
    print('*' * 10)

    # Train the corresponding supervised CNN
    print('Fine-tuning CNN')
    writer = SummaryWriter()
    # Replace the classification head, then move the whole model to the device
    model_cnn.fc = nn.Linear(in_features=2048, out_features=cfg.num_classes, bias=True)
    model_cnn.to(device)
    criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
    metric = MyF1Score(cfg)
    val_metric = MyF1Score(cfg)
    optimizer = torch.optim.Adam(model_cnn.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.t_max, eta_min=cfg.min_lr)
    best = 0
    best_val = 0
    for epoch in tqdm(range(200)):
        model_cnn.train()
        # Reset so the training F1 reflects this epoch only
        # (the metric otherwise accumulates counts across epochs)
        metric.reset()
        total_loss = 0
        for images, label in train_loader:
            images = images.to(device)
            label = label.to(device)
            pred_ts = model_cnn(images)
            label_one_hot = torch.nn.functional.one_hot(label, num_classes=cfg.num_classes).float()
            loss = criterion(pred_ts, label_one_hot)
            metric(pred_ts, label_one_hot)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            total_loss += loss.detach()
        avg_loss = total_loss / len(train_loader)
        train_score = metric.compute()
        logs = {'train_loss': avg_loss, 'train_f1': train_score, 'lr': optimizer.param_groups[0]['lr']}
        writer.add_scalar("CNN Supervised Loss/train", avg_loss, epoch)
        writer.add_scalar("CNN Supervised F1/train", train_score, epoch)
        print(logs)
        # Validate only when the training score improves
        if best < train_score:
            best = train_score
            model_cnn.eval()
            val_metric.reset()
            total_loss = 0
            with torch.no_grad():  # no gradients needed for validation
                for images, label in val_loader:
                    images = images.to(device)
                    label = label.to(device)
                    pred_ts = model_cnn(images)
                    label_one_hot = torch.nn.functional.one_hot(label, num_classes=cfg.num_classes).float()
                    val_metric(pred_ts, label_one_hot)
                    total_loss += criterion(pred_ts, label_one_hot).detach()
            # Average over validation batches (the original divided by len(train_loader))
            avg_loss = total_loss / len(val_loader)
            print('Val Loss:', avg_loss)
            val_score = val_metric.compute()
            print('CNN Validation Score:', val_score)
            writer.add_scalar("CNN Supervised F1/Validation", val_score, epoch)
            # Early stopping: count consecutive increases of the validation loss
            if avg_loss > last_loss:
                counter += 1
            else:
                counter = 0

            last_loss = avg_loss
            if counter > 5:
                print('Early Stopping!')
                break
            elif val_score > best_val:
                best_val = val_score
                print('Saving')
                torch.save(model_cnn, './cass-r50-tuned.pt')
    writer.flush()


if cfg.VIT:
    print('*' * 10)
    print('Fine-tuning ViT')
    writer = SummaryWriter()

    # Load the self-supervised ViT checkpoint saved above
    model_vit = torch.load('./cass-vit.pt')
    last_loss = math.inf
    counter = 0

    # Train the corresponding supervised ViT
    # Replace the classification head, then move the whole model to the device
    model_vit.head = nn.Linear(in_features=768, out_features=cfg.num_classes, bias=True)
    model_vit.to(device)
    criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
    metric = MyF1Score(cfg)
    val_metric = MyF1Score(cfg)
    optimizer = torch.optim.Adam(model_vit.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.t_max, eta_min=cfg.min_lr)
    best = 0
    best_val = 0
    for epoch in tqdm(range(200)):
        model_vit.train()
        # Reset so the training F1 reflects this epoch only
        # (the metric otherwise accumulates counts across epochs)
        metric.reset()
        total_loss = 0
        for images, label in train_loader:
            images = images.to(device)
            label = torch.nn.functional.one_hot(label.to(device), num_classes=cfg.num_classes).float()
            pred_ts = model_vit(images)
            loss = criterion(pred_ts, label)
            metric(pred_ts, label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            total_loss += loss.detach()
        avg_loss = total_loss / len(train_loader)
        train_score = metric.compute()
        logs = {'train_loss': avg_loss, 'train_f1': train_score, 'lr': optimizer.param_groups[0]['lr']}
        writer.add_scalar("ViT Supervised Loss/train", avg_loss, epoch)
        writer.add_scalar("ViT Supervised F1/train", train_score, epoch)
        print(logs)
        # Validate only when the training score improves
        if best < train_score:
            best = train_score
            model_vit.eval()
            val_metric.reset()
            total_loss = 0
            with torch.no_grad():  # no gradients needed for validation
                for images, label in val_loader:
                    images = images.to(device)
                    label = torch.nn.functional.one_hot(label.to(device), num_classes=cfg.num_classes).float()
                    pred_ts = model_vit(images)
                    val_metric(pred_ts, label)
                    total_loss += criterion(pred_ts, label).detach()
            # Average over validation batches (the original divided by len(train_loader))
            avg_loss = total_loss / len(val_loader)
            val_score = val_metric.compute()
            print('ViT Validation Score:', val_score)
            print('Val Loss:', avg_loss)
            writer.add_scalar("ViT Supervised F1/Validation", val_score, epoch)
            # Early stopping: count consecutive increases of the validation loss
            if avg_loss > last_loss:
                counter += 1
            else:
                counter = 0

            last_loss = avg_loss
            if counter > 5:
                print('Early Stopping!')
                break
            elif val_score > best_val:
                best_val = val_score
                print('Saving')
                torch.save(model_vit, './cass-vit-tuned.pt')

        writer.flush()
        print('*' * 10)

## Results

### Performance Summary
#### Autoimmune Diseases Biopsy Slides Dataset

- CASS ResNet-50
  - F1 Score: 0.8621
- CASS ViT B/16
  - F1 Score: 0.8781

#### Dermofit Dataset

- CASS ResNet-50
  - F1 Score: 0.7112
- CASS ViT B/16
  - F1 Score: 0.6675

#### Brain Tumor MRI Dataset

- CASS ResNet-50
  - F1 Score: 0.9859
- CASS ViT B/16
  - F1 Score: 0.9211


### Analysis
- **Replication**: The replication of the study yielded results that are comparable to those presented in the original paper, affirming the robustness and reliability of the CASS model.
- **Performance Across Datasets**: The model's effectiveness is highlighted across various datasets: in autoimmune disease biopsy analysis, it demonstrated improved F1 scores with 100% labeled data. In the Dermofit dataset, CASS outshone both supervised and other self-supervised methods, showcasing its proficiency in handling diverse skin lesion types. For the Brain MRI Classification dataset, the model showed a notable improvement in bringing the performance of CNNs and Transformers closer. In the challenging ISIC-2019 dataset, known for class imbalances and inconsistent images, CASS again proved superior, especially in scenarios with limited labeled data.
- **Model's Strengths**: These results across different medical imaging datasets emphasize CASS's adaptability to varied image characteristics and its robustness in scenarios with limited labeled data.

### Conclusion
- The CASS model stands out as a significant advancement in the field of medical image analysis. Its ability to maintain high accuracy and efficiency across diverse and challenging datasets positions it as a powerful tool for AI-driven medical diagnostics, especially in conditions with sparse data availability.

### Future Plan: Ablation Studies
- **Objective**: These studies aim to identify the optimal training configurations for the CASS model, ensuring it is well-tuned for diverse medical imaging tasks.
- **Focus Areas**:
  - **Training Epochs**: Assessing the impact of varying the number of epochs on the model's accuracy and learning efficiency.
  - **Optimizers**: Examining how different optimizer choices influence the model’s performance.
  - **Batch Size**: Understanding the effects of batch size variations on the model's training dynamics.
  - **Encoder Size**: Exploring the influence of different encoder sizes on the model’s capability to process and learn from medical images.
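
One way to organize such a study is to vary a single factor at a time around a fixed baseline. The sketch below generates the corresponding list of configurations. The baseline reuses settings from this notebook (100 pretraining epochs, Adam, batch size 16, ResNet-50), while the alternative values and the function name are hypothetical placeholders, not settings taken from the paper.

```python
# Baseline taken from the settings used in this notebook
BASELINE = {
    "epochs": 100,
    "optimizer": "adam",
    "batch_size": 16,
    "encoder": "resnet50",
}

# Hypothetical alternatives to try for each factor
ALTERNATIVES = {
    "epochs": [50, 200],
    "optimizer": ["sgd"],
    "batch_size": [8, 32],
    "encoder": ["resnet18"],
}


def ablation_configs(baseline, alternatives):
    """Return the baseline plus one config per changed factor value."""
    configs = [dict(baseline)]
    for key, values in alternatives.items():
        for value in values:
            cfg = dict(baseline)  # copy so the baseline stays untouched
            cfg[key] = value
            configs.append(cfg)
    return configs


configs = ablation_configs(BASELINE, ALTERNATIVES)
for cfg in configs:
    print(cfg)
```

Each configuration differs from the baseline in at most one setting, so any change in F1 score can be attributed to that factor alone.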




# Test

## To test the checkpoints, use the section below

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.models import resnet18
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
import numpy as np
import matplotlib.pyplot as plt

import os
import pytorch_lightning as pl
import pandas as pd
import timm
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import KFold
import math

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchcontrib.optim import SWA
from torchmetrics import Metric
from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
# Data preprocessing and augmentation
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


train_dataset = ImageFolder('brain_tumor/Training', transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

val_dataset = ImageFolder('brain_tumor/Testing', transform=data_transforms)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)


## Case 1: Load the pre-trained models from self-supervised learning

In [3]:
class CFG:
    # label_num2str = {'0': 'glioma', '1': 'meningioma', '2': 'notumor', '3': 'pituitary'}
    label_num2str = {0: 'glioma', 1: 'meningioma', 2: 'notumor', 3: 'pituitary'}
    label_str2num = {'glioma': '0', 'meningioma': '1', 'notumor': '2','pituitary': '3'}
    fl_alpha = 1.0  # alpha of focal_loss
    fl_gamma = 2.0  # gamma of focal_loss
    cls_weight =  [0.5, 0.5, 0.5, 0.5]
    cnn_name='resnet50'
    vit_name='vit_base_patch16_224'
    seed = 77
    num_classes = 4
    batch_size = 16
    t_max = 16
    lr = 1e-3
    min_lr = 1e-6
    n_fold = 6
    num_workers = 8
    gpu_idx = 0
    device = torch.device(f'cuda:{gpu_idx}' if torch.cuda.is_available() else 'cpu')
    gpu_list = [gpu_idx]
    CNN = True
    VIT = False

class MyF1Score(Metric):
    def __init__(self, cfg, threshold: float = 0.5, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.cfg = cfg
        self.threshold = threshold
        self.add_state("tp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fn", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        preds_str_batch = self.num_to_str(torch.sigmoid(preds))
        target_str_batch = self.num_to_str(target)
        tp, fp, fn = 0, 0, 0
        for pred_str_list, target_str_list in zip(preds_str_batch, target_str_batch):
            for pred_str in pred_str_list:
                if pred_str in target_str_list:
                    tp += 1
                if pred_str not in target_str_list:
                    fp += 1

            for target_str in target_str_list:
                if target_str not in pred_str_list:
                    fn += 1
        self.tp += tp
        self.fp += fp
        self.fn += fn

    def compute(self):
        f1 = 2.0 * self.tp / (2.0 * self.tp + self.fn + self.fp)
        return f1
    
    def num_to_str(self, ts: torch.Tensor) -> list:
        batch_bool_list = (ts > self.threshold).detach().cpu().numpy().tolist()
        batch_str_list = []
        for one_sample_bool in batch_bool_list:
            lb_str_list = [self.cfg.label_num2str[lb_idx] for lb_idx, bool_val in enumerate(one_sample_bool) if bool_val]
            batch_str_list.append(lb_str_list)
        return batch_str_list

class FocalLoss(nn.Module):
    """
    The focal loss for fighting against class-imbalance
    """
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = 1e-12  # prevent training from Nan-loss error
        self.cls_weights = torch.tensor([CFG.cls_weight],dtype=torch.float, requires_grad=False, device=CFG.device)

    def forward(self, logits, target):
        """
        logits & target should be tensors with shape [batch_size, num_classes]
        """
        probs = torch.sigmoid(logits)
        one_subtract_probs = 1.0 - probs
        # add epsilon
        probs_new = probs + self.epsilon
        one_subtract_probs_new = one_subtract_probs + self.epsilon
        # calculate focal loss
        log_pt = target * torch.log(probs_new) + (1.0 - target) * torch.log(one_subtract_probs_new)
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt
        focal_loss = focal_loss * self.cls_weights
        return torch.mean(focal_loss)
    
cfg=CFG()
val_metric=MyF1Score(cfg)

fl_alpha = 1.0  # alpha of focal_loss
fl_gamma = 2.0  # gamma of focal_loss
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)

# The checkpoints store the full pickled models, so they are loaded directly
# (timm must be imported) and placed on the target device.
model_vit = torch.load('./cass-vit.pt', map_location=device)
model_cnn = torch.load('./cass-r50.pt', map_location=device)

## Case 2: Load the fine-tuned model from downstream supervised learning

In [4]:
model_cnn = torch.load('./cass-r50-tuned.pt')
model_cnn.to(device)
model_cnn.eval()
with torch.no_grad():  # inference only, no gradients needed
    for images, label in val_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model_cnn(images)
        label_one_hot = torch.nn.functional.one_hot(label, num_classes=4).float()
        # Accumulates TP/FP/FN counts across batches
        val_metric(pred_ts, label_one_hot)
val_score = val_metric.compute()
print('CNN Validation Score:', val_score)

CNN Validation Score: tensor(0.9859)
