In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pickle
from MIL_layers import *
import sklearn.metrics as metrics
from VarMIL import *
from CLAM import *

In [2]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score

## UNI

In [3]:
# import timm
# from timm.data import resolve_data_config
# from timm.data.transforms_factory import create_transform
# from huggingface_hub import login

# SECURITY: a real access token was hardcoded on this line; revoke it and never commit tokens to a notebook.
# login(token=os.environ["HF_TOKEN"])  # read your User Access Token (https://huggingface.co/settings/tokens) from the environment

# # pretrained=True needed to load UNI weights (and download weights for the first time)
# # init_values need to be passed in to successfully load LayerScale parameters (e.g. - block.0.ls1.gamma)
# model = timm.create_model("hf-hub:MahmoodLab/UNI", pretrained=True, init_values=1e-5, dynamic_img_size=True)
# transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
# model.eval()

In [4]:
# from PIL import Image
# image = Image.open("UNI/.github/uni.jpg")
# image = transform(image).unsqueeze(dim=0) # Image (torch.Tensor) with shape [1, 3, 224, 224] following image resizing and normalization (ImageNet parameters)
# with torch.inference_mode():
#     feature_emb = model(image) # Extracted features (torch.Tensor) with shape [1,1024]

## ABMIL

In [5]:
# Deserialize the train and test dictionaries from ./data.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
file_path = "./data/train_dict.pkl"
with open(file_path, mode='rb') as f:
    train_dict = pickle.load(f)

# `file_path` is deliberately re-bound so it ends up pointing at the test split.
file_path = "./data/test_dict.pkl"
with open(file_path, mode='rb') as f:
    test_dict = pickle.load(f)

In [6]:
# Pull features and labels out of each split.
# Index 0 along axis 1 of the embeddings is dropped for both splits
# (presumably a non-patch/aggregate embedding -- TODO confirm against the
# code that produced the pickles).
X_train, y_train = train_dict['embeddings'][:, 1:, :], train_dict['labels']
X_test, y_test = test_dict['embeddings'][:, 1:, :], test_dict['labels']

In [7]:
# One bag per batch: train()/test() read the bag label as label[0].
batch_size = 1  # Adjust batch size as needed

# Wrap the arrays as tensors: float32 features, int32 labels.
train_dataset = TensorDataset(
    torch.tensor(X_train, dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.int32),
)
test_dataset = TensorDataset(
    torch.tensor(X_test, dtype=torch.float32),
    torch.tensor(y_test, dtype=torch.int32),
)

# Training bags are reshuffled on every pass; test bags keep dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


In [8]:
def train(epoch, model, lr=0.001, weight_decay=0.0005, loader=None, optimizer=None):
    """Run one training pass over ``loader`` and print the mean loss and error.

    Parameters
    ----------
    epoch : int
        Epoch number; only used in the printed summary line.
    model : torch.nn.Module
        Model exposing ``calculate_objective(data, bag_label)`` and
        ``calculate_classification_error(data, bag_label)``. The first element
        of each returned pair is used as the loss / the error respectively.
    lr : float
        Adam learning rate (ignored when ``optimizer`` is given).
    weight_decay : float
        Adam weight decay (ignored when ``optimizer`` is given).
    loader : iterable of ``(data, label)`` batches, optional
        Batches to train on; must support ``len()``. Defaults to the
        notebook-level ``train_loader``.
    optimizer : torch.optim.Optimizer, optional
        Optimizer to reuse across calls. When omitted a fresh Adam optimizer is
        built on every call, which discards Adam's running moment estimates
        between epochs.

    Returns
    -------
    None
        The per-epoch summary is printed, not returned.
    """
    if loader is None:
        loader = train_loader  # notebook-level default

    # Keep model and data on the same device; the data is moved to CUDA
    # whenever it is available, so the model has to follow.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # test() leaves the model in eval mode; switch back so train-time layers
    # (e.g. dropout) are active again.
    model.train()

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_loss = 0.
    train_error = 0.

    for data, label in loader:
        bag_label = label[0]
        data, bag_label = data.to(device), bag_label.to(device)

        # reset gradients
        optimizer.zero_grad()
        # forward pass for the loss
        loss, _ = model.calculate_objective(data, bag_label)
        train_loss += loss.item()
        # The error is only a reported metric: evaluate it on the pre-update
        # weights (as before) but without building an autograd graph.
        with torch.no_grad():
            error, _ = model.calculate_classification_error(data, bag_label)
        train_error += float(error)
        # backward pass and parameter update
        loss.backward()
        optimizer.step()

    # average loss and error over the batches of this epoch
    train_loss /= len(loader)
    train_error /= len(loader)

    print('Epoch: {}, Loss: {:.4f}, Train error: {:.4f}'.format(epoch, train_loss, train_error))


In [9]:
def test(model, loader=None):
    """Evaluate ``model`` on ``loader`` and print loss, error and sklearn metrics.

    Parameters
    ----------
    model : torch.nn.Module
        Model exposing ``calculate_objective(data, bag_label)`` and
        ``calculate_classification_error(data, bag_label)``; the latter must
        return ``(error, predicted_label)`` with a single-element prediction.
    loader : iterable of ``(data, label)`` batches, optional
        Batches to evaluate on; must support ``len()``. Defaults to the
        notebook-level ``test_loader``.

    Returns
    -------
    None
        All results are printed. The model is left in eval mode.
    """
    if loader is None:
        loader = test_loader  # notebook-level default

    # Keep model and data on the same device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    test_loss = 0.
    test_error = 0.
    # Ground truth is collected from the batches actually iterated, so the
    # metrics stay aligned with y_pred regardless of loader order or contents
    # (previously they were scored against the global y_test).
    y_true = []
    y_pred = []
    with torch.no_grad():
        for data, label in loader:
            bag_label = label[0]
            data, bag_label = data.to(device), bag_label.to(device)
            loss, _ = model.calculate_objective(data, bag_label)
            test_loss += loss.item()
            error, predicted_label = model.calculate_classification_error(data, bag_label)
            test_error += float(error)
            y_true.append(bag_label.cpu().numpy().item())
            y_pred.append(predicted_label.cpu().numpy().item())

    # average over the evaluated batches
    test_error /= len(loader)
    test_loss /= len(loader)

    print('\nTest Set, Loss: {:.4f}, Test error: {:.4f}'.format(test_loss, test_error))
    print('Accuracy :' , accuracy_score(y_true, y_pred))
    print('Precision :' , precision_score(y_true, y_pred))
    print('Recall :' , recall_score(y_true, y_pred))
    print('F1 Score :' , f1_score(y_true, y_pred))


### Baseline: Embedding + Mean

In [10]:
# Fresh Emb_mean baseline (class comes from one of the star imports above).
model = Emb_mean()
# range(1, 20) gives epochs 1..19, i.e. 19 passes over the training loader.
for epoch in range(1, 20):
    train(epoch, model, lr=0.001)

Epoch: 1, Loss: 0.1436, Train error: 0.0308
Epoch: 2, Loss: 0.0837, Train error: 0.0385
Epoch: 3, Loss: 0.0532, Train error: 0.0231
Epoch: 4, Loss: 0.0735, Train error: 0.0231
Epoch: 5, Loss: 0.0759, Train error: 0.0308
Epoch: 6, Loss: 0.0483, Train error: 0.0231
Epoch: 7, Loss: 0.0893, Train error: 0.0308
Epoch: 8, Loss: 0.0431, Train error: 0.0154
Epoch: 9, Loss: 0.0489, Train error: 0.0308
Epoch: 10, Loss: 0.0627, Train error: 0.0231
Epoch: 11, Loss: 0.0671, Train error: 0.0231
Epoch: 12, Loss: 0.0492, Train error: 0.0231
Epoch: 13, Loss: 0.0633, Train error: 0.0308
Epoch: 14, Loss: 0.0676, Train error: 0.0308
Epoch: 15, Loss: 0.0587, Train error: 0.0154
Epoch: 16, Loss: 0.0773, Train error: 0.0308
Epoch: 17, Loss: 0.0465, Train error: 0.0231
Epoch: 18, Loss: 0.0600, Train error: 0.0231
Epoch: 19, Loss: 0.0509, Train error: 0.0231


In [11]:
# Evaluate the Emb_mean model trained in the previous cell; test() prints loss, error and sklearn metrics.
test(model)


Test Set, Loss: 1.1652, Test error: 0.2667
Accuracy : 0.7333333333333333
Precision : 0.6666666666666666
Recall : 0.9333333333333333
F1 Score : 0.7777777777777778


### Baseline: Embedding + Max

In [12]:
# Fresh Emb_max baseline (class comes from one of the star imports above).
model = Emb_max()
# range(1, 20) gives epochs 1..19, i.e. 19 passes over the training loader.
for epoch in range(1, 20):
    train(epoch, model, lr=0.001)

Epoch: 1, Loss: 0.6581, Train error: 0.3385
Epoch: 2, Loss: 0.2305, Train error: 0.0923
Epoch: 3, Loss: 0.2007, Train error: 0.0615
Epoch: 4, Loss: 0.1821, Train error: 0.0462
Epoch: 5, Loss: 0.1352, Train error: 0.0538
Epoch: 6, Loss: 0.0986, Train error: 0.0462
Epoch: 7, Loss: 0.0684, Train error: 0.0308
Epoch: 8, Loss: 0.0664, Train error: 0.0385
Epoch: 9, Loss: 0.0709, Train error: 0.0308
Epoch: 10, Loss: 0.0778, Train error: 0.0308
Epoch: 11, Loss: 0.1193, Train error: 0.0308
Epoch: 12, Loss: 0.0477, Train error: 0.0077
Epoch: 13, Loss: 0.0325, Train error: 0.0077
Epoch: 14, Loss: 0.0419, Train error: 0.0154
Epoch: 15, Loss: 0.0643, Train error: 0.0154
Epoch: 16, Loss: 0.0437, Train error: 0.0154
Epoch: 17, Loss: 0.0563, Train error: 0.0308
Epoch: 18, Loss: 0.0398, Train error: 0.0077
Epoch: 19, Loss: 0.0943, Train error: 0.0462


In [13]:
# Evaluate the Emb_max model trained in the previous cell; test() prints loss, error and sklearn metrics.
test(model)


Test Set, Loss: 3.9925, Test error: 0.5000
Accuracy : 0.5
Precision : 0.5
Recall : 1.0
F1 Score : 0.6666666666666666


### ATTENTION

In [14]:
# Fresh Attention model with a 512-unit hidden layer and dropout 0.5.
model = Attention(
    hidden_size=512,
    dropout=0.5,
)
# range(1, 20) gives epochs 1..19, i.e. 19 passes over the training loader.
for epoch in range(1, 20):
    train(epoch, model, lr=0.001)

Epoch: 1, Loss: 0.2321, Train error: 0.0923
Epoch: 2, Loss: 0.0520, Train error: 0.0385
Epoch: 3, Loss: 0.0681, Train error: 0.0154
Epoch: 4, Loss: 0.0791, Train error: 0.0154
Epoch: 5, Loss: 0.0778, Train error: 0.0231
Epoch: 6, Loss: 0.0512, Train error: 0.0154
Epoch: 7, Loss: 0.2161, Train error: 0.0385
Epoch: 8, Loss: 0.0871, Train error: 0.0308
Epoch: 9, Loss: 0.0827, Train error: 0.0231
Epoch: 10, Loss: 0.0658, Train error: 0.0308
Epoch: 11, Loss: 0.0317, Train error: 0.0154
Epoch: 12, Loss: 0.0502, Train error: 0.0231
Epoch: 13, Loss: 0.0641, Train error: 0.0308
Epoch: 14, Loss: 0.0456, Train error: 0.0154
Epoch: 15, Loss: 0.0716, Train error: 0.0385
Epoch: 16, Loss: 0.0266, Train error: 0.0077
Epoch: 17, Loss: 0.0201, Train error: 0.0077
Epoch: 18, Loss: 0.0824, Train error: 0.0154
Epoch: 19, Loss: 0.0498, Train error: 0.0154


In [15]:
# Evaluate the Attention model trained in the previous cell; test() prints loss, error and sklearn metrics.
test(model)


Test Set, Loss: 1.6763, Test error: 0.4167
Accuracy : 0.5833333333333334
Precision : 0.5510204081632653
Recall : 0.9
F1 Score : 0.6835443037974683


### Gated Attention

In [16]:
# Fresh GatedAttention model with a 512-unit hidden layer and dropout 0.1.
model = GatedAttention(
    hidden_size=512,
    dropout=0.1,
)
# range(1, 20) gives epochs 1..19, i.e. 19 passes over the training loader.
for epoch in range(1, 20):
    train(epoch, model, lr=0.001)

Epoch: 1, Loss: 0.1443, Train error: 0.0692
Epoch: 2, Loss: 0.1035, Train error: 0.0462
Epoch: 3, Loss: 0.0730, Train error: 0.0231
Epoch: 4, Loss: 0.0980, Train error: 0.0231
Epoch: 5, Loss: 0.0736, Train error: 0.0154
Epoch: 6, Loss: 0.0417, Train error: 0.0231
Epoch: 7, Loss: 0.1275, Train error: 0.0385
Epoch: 8, Loss: 0.0563, Train error: 0.0231
Epoch: 9, Loss: 0.0674, Train error: 0.0231
Epoch: 10, Loss: 0.0765, Train error: 0.0308
Epoch: 11, Loss: 0.0749, Train error: 0.0154
Epoch: 12, Loss: 0.0664, Train error: 0.0154
Epoch: 13, Loss: 0.0511, Train error: 0.0077
Epoch: 14, Loss: 0.0610, Train error: 0.0308
Epoch: 15, Loss: 0.0507, Train error: 0.0308
Epoch: 16, Loss: 0.0329, Train error: 0.0154
Epoch: 17, Loss: 0.0691, Train error: 0.0231
Epoch: 18, Loss: 0.0567, Train error: 0.0154
Epoch: 19, Loss: 0.0515, Train error: 0.0231


In [17]:
# Evaluate the GatedAttention model trained in the previous cell; test() prints loss, error and sklearn metrics.
test(model)


Test Set, Loss: 0.5719, Test error: 0.2000
Accuracy : 0.8
Precision : 0.7647058823529411
Recall : 0.8666666666666667
F1 Score : 0.8125


## VARMIL

In [26]:
# Fresh VarMIL model; every constructor option is spelled out one per line.
model = VarMIL(
    embed_size=1024,
    hidden_size=500,
    separate_attn=False,
    dropout=0.5,
)

In [27]:
# Train the VarMIL model built in the previous cell.
# range(1, 20) gives epochs 1..19, i.e. 19 passes over the training loader.
for epoch in range(1, 20):
    train(epoch, model, lr=0.001)

Epoch: 1, Loss: 0.2162, Train error: 0.1154
Epoch: 2, Loss: 0.0812, Train error: 0.0385
Epoch: 3, Loss: 0.1095, Train error: 0.0308
Epoch: 4, Loss: 0.0740, Train error: 0.0231
Epoch: 5, Loss: 0.0865, Train error: 0.0231
Epoch: 6, Loss: 0.1092, Train error: 0.0385
Epoch: 7, Loss: 0.0683, Train error: 0.0231
Epoch: 8, Loss: 0.1585, Train error: 0.0308
Epoch: 9, Loss: 0.0927, Train error: 0.0308
Epoch: 10, Loss: 0.1347, Train error: 0.0231
Epoch: 11, Loss: 0.0467, Train error: 0.0231
Epoch: 12, Loss: 0.0799, Train error: 0.0308
Epoch: 13, Loss: 0.1070, Train error: 0.0231
Epoch: 14, Loss: 0.1063, Train error: 0.0385
Epoch: 15, Loss: 0.0234, Train error: 0.0000
Epoch: 16, Loss: 0.0382, Train error: 0.0231
Epoch: 17, Loss: 0.0795, Train error: 0.0385
Epoch: 18, Loss: 0.0994, Train error: 0.0308
Epoch: 19, Loss: 0.0505, Train error: 0.0308


In [28]:
# Evaluate the VarMIL model trained in the previous cell; test() prints loss, error and sklearn metrics.
test(model)


Test Set, Loss: 4.3909, Test error: 0.5000
Accuracy : 0.5
Precision : 0.5
Recall : 1.0
F1 Score : 0.6666666666666666


## CLAM

In [29]:
# Fresh CLAM_SB model with its default constructor arguments.
model = CLAM_SB()
# range(1, 20) gives epochs 1..19, i.e. 19 passes over the training loader.
for epoch in range(1, 20):
    train(epoch, model, lr=0.001)


Epoch: 1, Loss: 0.2497, Train error: 0.0769
Epoch: 2, Loss: 0.2194, Train error: 0.0385
Epoch: 3, Loss: 0.4246, Train error: 0.0462
Epoch: 4, Loss: 0.4142, Train error: 0.0538
Epoch: 5, Loss: 0.1606, Train error: 0.0308
Epoch: 6, Loss: 0.0800, Train error: 0.0231
Epoch: 7, Loss: 0.3839, Train error: 0.0385
Epoch: 8, Loss: 0.1210, Train error: 0.0154
Epoch: 9, Loss: 0.1895, Train error: 0.0231
Epoch: 10, Loss: 0.7841, Train error: 0.0462
Epoch: 11, Loss: 0.2468, Train error: 0.0308
Epoch: 12, Loss: 0.2805, Train error: 0.0308
Epoch: 13, Loss: 0.2389, Train error: 0.0308
Epoch: 14, Loss: 0.6565, Train error: 0.0154
Epoch: 15, Loss: 0.2073, Train error: 0.0308
Epoch: 16, Loss: 0.3936, Train error: 0.0231
Epoch: 17, Loss: 0.5886, Train error: 0.0462
Epoch: 18, Loss: 0.3669, Train error: 0.0462
Epoch: 19, Loss: 0.4109, Train error: 0.0385


In [30]:
# Evaluate the CLAM_SB model trained in the previous cell; test() prints loss, error and sklearn metrics.
test(model)


Test Set, Loss: 0.4810, Test error: 0.2667
Accuracy : 0.7333333333333333
Precision : 0.6842105263157895
Recall : 0.8666666666666667
F1 Score : 0.7647058823529411
