In [None]:
import mil_configs

In [None]:
# Hyper-parameters for the MNIST MIL experiment below.
# Alternative aggregators defined in mil_configs can be swapped into "model_config":
#   mil_configs.abmil_config, mil_configs.sdp_config, mil_configs.avg_config
train_config = dict(
    lr=1e-4,
    # If False, only the bag-level loss is used: the instance classifier is never
    # trained and the model config's k_sample / subtyping options have no effect.
    enable_clam_training=True,
    model_config=mil_configs.clam_config,
)

In [None]:
# MNIST 6, 8, 9 classification
import tqdm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from utils import MILMNISTDataset, calc_acc, resnet50_baseline
from models import ClamWrapper

# Shared preprocessing for both splits; (0.1307, 0.3081) are the commonly used
# MNIST pixel mean / std.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


def make_mil_loader(train: bool, shuffle: bool, **dataset_kwargs) -> torch.utils.data.DataLoader:
    """Build a DataLoader over MNIST bags of the target digits 6, 8 and 9.

    Args:
        train: select the MNIST training split (True) or the test split (False).
        shuffle: whether the DataLoader shuffles bag order each epoch.
        **dataset_kwargs: extra keyword arguments forwarded to MILMNISTDataset
            (e.g. ``download=True``).

    batch_size stays at 1 because the training loop calls ``labels.item()``
    on each batch.
    """
    dataset = MILMNISTDataset(root='.', target_digits=[6, 8, 9], bag_size=64, train=train,
                              transform=mnist_transform, **dataset_kwargs)
    return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=0)


# Training dataset: shuffled so bag order is not fixed across training.
train_loader = make_mil_loader(train=True, shuffle=True, download=True)
# Test dataset: evaluation order does not matter, keep it deterministic.
test_loader = make_mil_loader(train=False, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ClamWrapper(train_config['model_config'], base_encoder=resnet50_baseline(pretrained=True)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=train_config['lr'])
# Loss scaling only matters for fp16 autocast on CUDA; when disabled the scaler
# is a pass-through (scale() returns the loss unchanged, step() calls optimizer.step()).
scaler = GradScaler(enabled=device.type == 'cuda')
criterion = nn.CrossEntropyLoss(reduction='mean')  # bag-level classification loss
# Per-entry instance loss (no reduction) so the training loop can mask it to the
# selected top-k instances before averaging.
inst_criterion = nn.BCEWithLogitsLoss(reduction='none')

# Set to True to report test accuracy of the untrained model as a baseline.
EVAL_BEFORE_TRAINING = False
if EVAL_BEFORE_TRAINING:
    calc_acc(model, test_loader, device)

def build_instance_targets(inst_logit: torch.Tensor, top_p_ids, top_n_ids, label: int,
                           subtyping: bool):
    """Build target and mask tensors for the CLAM instance-level loss.

    For every entry of ``zip(top_p_ids, top_n_ids)``:
      * CLAM-MB (index tensor with dim > 1): the row ``label`` of the positive /
        negative index tensors is used; CLAM-SB uses the index tensors directly.
      * positive instances get target 1 in column ``label``;
      * the mask enables column ``label`` for positives (or every column when
        ``subtyping`` is True, so other classes see them as negatives) and
        column ``label`` for negatives.

    Assumes ``inst_logit`` is indexed as (instance, class) — TODO confirm
    against ClamWrapper.

    Returns:
        (instance_target, instance_mask), both float32 and shaped like inst_logit.
    """
    instance_target = torch.zeros_like(inst_logit, dtype=torch.float32)
    instance_mask = torch.zeros_like(inst_logit, dtype=torch.float32)
    for p_index, n_index in zip(top_p_ids, top_n_ids):
        if p_index.dim() > 1:  # CLAM-MB: one row of top-k indices per class
            pos, neg = p_index[label], n_index[label]
        else:  # CLAM-SB
            pos, neg = p_index, n_index
        instance_target[pos, label] = 1.
        if subtyping:
            instance_mask[pos, :] = 1.
        else:
            instance_mask[pos, label] = 1.
        instance_mask[neg, label] = 1.
    return instance_target, instance_mask


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean of ``values`` over entries where ``mask`` is 1 (0 if the mask is empty)."""
    return (values * mask).sum() / mask.sum().clamp_min(1.0)


BAG_LOSS_WEIGHT = 0.7   # weight of the bag-level cross-entropy
INST_LOSS_WEIGHT = 0.3  # weight of the masked instance-level BCE
EVAL_EVERY = 500        # evaluate on the test set every N training bags

# fp16 autocast is only used on CUDA; on CPU autocast is disabled (bfloat16 is
# passed only because it is the dtype CPU autocast accepts without warning).
use_amp = device.type == 'cuda'
amp_dtype = torch.float16 if use_amp else torch.bfloat16
clam_training = train_config['enable_clam_training']
# Only CLAM configs are guaranteed to carry 'subtyping'; don't read it otherwise.
subtyping = train_config['model_config']['subtyping'] if clam_training else False

for itr, (images, labels) in enumerate(tqdm.tqdm(train_loader), start=1):
    # Re-enter train mode every step in case calc_acc switched the model to eval.
    model.train()
    images = images.to(device)
    labels = labels.to(device)

    optimizer.zero_grad()
    with autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
        if clam_training:
            bag_logit, inst_logit, top_p_ids, top_n_ids = model(images)
            bag_loss = criterion(bag_logit, labels)
            # batch_size is 1, so the single bag label selects the class column.
            instance_target, instance_mask = build_instance_targets(
                inst_logit, top_p_ids, top_n_ids, labels.item(), subtyping)
            per_entry_loss = inst_criterion(inst_logit.view(-1), instance_target.view(-1))
            # Average only over the selected instances; a plain .mean() would divide
            # by every (instance, class) entry and shrink the instance loss.
            inst_loss = masked_mean(per_entry_loss, instance_mask.view(-1))

            loss = BAG_LOSS_WEIGHT * bag_loss + INST_LOSS_WEIGHT * inst_loss
        else:
            bag_logit = model.eval_forward(images)
            loss = criterion(bag_logit, labels)

    # No gradient clipping is done, so scaler.step() unscales internally.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    if itr % EVAL_EVERY == 0:
        tqdm.tqdm.write(f'{itr} iterations')
        calc_acc(model, test_loader, device)