## Set up datasets and dataloaders

In [None]:
from utils.device import get_device
from utils.data import DatasetConfig
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS

# Image location passed as the first DatasetConfig argument.
IMG_PATH = 'datasets/vindr-cxr-png'

# Constants handed to ControlledMetaTrainer in the experiment cells:
# NUM_SHOTS and NUM_WAYS positionally, TRAIN_NUM_WAYS as `train_n_ways`,
# N_QUERY as `n_query`.
# NOTE(review): in the cells visible here only one trainer passes
# `train_n_ways` and only one passes `n_query`; the others rely on the
# trainer's defaults — confirm that is intended.
NUM_SHOTS = 5
NUM_WAYS = 7
TRAIN_NUM_WAYS = 7
N_QUERY = 20

# NOTE(review): the normalisation statistics are looked up under the
# 'chestmnist' key although IMG_PATH points at the VinDr-CXR images —
# confirm this is the intended mean/std.
dataset_config = DatasetConfig(
    IMG_PATH,
    'data/vindr_cxr_split_labels.pkl',
    'data/vindr_train_query_set.pkl',
    VINDR_CXR_LABELS,
    VINDR_SPLIT,
    MEAN_STDS['chestmnist'],
)

# Device object shared by every model-loading and trainer cell.
device = get_device()

import torch
import random
import numpy as np

def set_seed(seed):
    """Seed the NumPy, PyTorch and Python `random` generators with `seed`.

    The generators are seeded in that order, each with the same value.
    Returns None.
    """
    for seed_rng in (np.random.seed, torch.manual_seed, random.seed):
        seed_rng(seed)

## Models with attention
### Run experiments on proposed model

In [None]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.model import ClsModel


# Build the proposed classifier: a pre-trained image/text encoder plus a
# pre-trained attention model, wrapped in ClsModel and handed to the trainer.
#
# Both files appear to hold whole pickled modules (attributes are read from the
# loaded objects directly), so their classes must be importable when loading.
# map_location=device remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads on a machine without that GPU.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, full-module checkpoints like these presumably need
# weights_only=False — confirm against the installed version.
encoder = torch.load(
    'models/embedding/model/vindr1/imgtext_model_trained.pth',
    map_location=device,
)
# Point the text model's `device` attribute at the device used in this session.
encoder.text_model.device = device
attention = torch.load(
    'models/attention/model/vindr1/attention-model-8h4l.pth',
    map_location=device,
)
model = ClsModel(encoder, attention, 512, class_prototype_inf, fc_hidden_size=16)
mtrainer = ControlledMetaTrainer(
    model,
    NUM_SHOTS,
    NUM_WAYS,
    dataset_config,
    train_n_ways=TRAIN_NUM_WAYS,
    device=device,
)

In [None]:
# Evaluate the trainer's current model on its test loader. This cell sits above
# the first run_train call, so when the notebook is run top-to-bottom it
# evaluates the loaded checkpoints before any further training.
mtrainer.run_eval(mtrainer.model, mtrainer.test_loader)


In [None]:
# Evaluate the trainer's current model on its validation loader (this cell is
# placed before the first run_train call).
mtrainer.run_eval(mtrainer.model, mtrainer.val_loader)

In [None]:
# Author's open question, kept from the original: "5 times?"
# NOTE(review): presumably asks whether this call should be repeated five
# times — confirm with the author.
# Calls run_train with 2 as its first argument and a learning rate of 1e-5.
# NOTE(review): the first argument is presumably the number of epochs —
# confirm against ControlledMetaTrainer.run_train.
mtrainer.run_train(2, lr=1e-5)

In [None]:
# Evaluate the trainer's `best_model` (not `mtrainer.model`) on the test loader.
# NOTE(review): best_model is presumably the model selected during run_train —
# confirm in ControlledMetaTrainer.
mtrainer.run_eval(mtrainer.best_model, mtrainer.test_loader)

In [None]:
# Call set_trainable(True) on the model's attention sub-module; this cell sits
# before the next run_train cell.
# NOTE(review): presumably the attention weights were frozen up to this point —
# confirm in the attention model's set_trainable implementation.
mtrainer.model.attn_model.set_trainable(True)

In [None]:
# Second run_train call: first argument 3, learning rate 1e-5. It is placed
# after the cell that calls attn_model.set_trainable(True).
mtrainer.run_train(3, lr=1e-5)

In [None]:
# Model trained 3 epochs with lr=5e-6, 3 epochs with lr=1e-6, 3 epochs with lr=1e-5
# NOTE(review): the run_train cells visible in this notebook pass 2 and 3 as
# their first argument, both with lr=1e-5, which does not match the schedule
# noted above — the note may describe an earlier session; confirm.
# Evaluate the trainer's current (latest) model on the test loader.
mtrainer.run_eval(mtrainer.model, mtrainer.test_loader)

In [None]:
# Evaluate the trainer's `best_model` on the test loader, for comparison with
# the evaluation of `mtrainer.model` on the same loader.
mtrainer.run_eval(mtrainer.best_model, mtrainer.test_loader)

In [None]:
# Persist only the state_dict of the best model's `cls` sub-module.
cls_weights_path = 'models/metaclassifier/model/comb3/cls_weights-16.pkl'
torch.save(mtrainer.best_model.cls.state_dict(), cls_weights_path)

In [None]:
# Persist the best model's attention sub-module as a whole object (not a
# state_dict).
attention_model_path = 'models/metaclassifier/model/comb3/attention-model-8h4l.pth'
torch.save(mtrainer.best_model.attn_model, attention_model_path)

### Prototypical Network with attention

In [None]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.base import euclidean_distance
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.model import ProtoNetAttention

# Build the prototypical network with attention: the same pre-trained encoder
# and attention checkpoints, combined through ProtoNetAttention with a
# Euclidean distance function.
device = get_device()

# Both files appear to hold whole pickled modules, so their classes must be
# importable when loading. map_location=device remaps the stored tensors onto
# the active device, so a checkpoint saved on a GPU still loads on a machine
# without that GPU.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, full-module checkpoints like these presumably need
# weights_only=False — confirm against the installed version.
encoder = torch.load(
    'models/embedding/model/vindr1/imgtext_model_trained.pth',
    map_location=device,
)
# Point the text model's `device` attribute at the device used in this session.
encoder.text_model.device = device
attention = torch.load(
    'models/attention/model/vindr1/attention-model-8h4l.pth',
    map_location=device,
)
# Positional arguments, in order:
# imgtxt_encoder, attn_model, class_prototype_aggregator, distance_func
model = ProtoNetAttention(encoder, attention, class_prototype_inf, euclidean_distance)
# NOTE(review): no train_n_ways is passed here, so the trainer's default
# applies — confirm that is intended.
mtrainer = ControlledMetaTrainer(model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)

In [None]:
# Call the encoder's set_trainable with two positional True flags and
# include_logit_scale=False.
# NOTE(review): presumably unfreezes both encoder branches while leaving the
# logit scale fixed — confirm against the encoder's set_trainable signature.
mtrainer.model.encoder.set_trainable(True, True, include_logit_scale=False)

In [None]:
# Call set_trainable(True) on the model's attention sub-module.
# NOTE(review): this cell and the encoder set_trainable cell look like
# alternative experiment settings (backbone vs. attention training) — confirm
# which ones were run for a given result.
mtrainer.model.attn_model.set_trainable(True)

In [None]:
# Evaluate on the test loader using the model owned by the trainer rather than
# the notebook-global `model`, which other setup cells rebind; this matches the
# other evaluation cells and stays correct when cells are run out of order.
# NOTE(review): the third positional argument True is presumably the verbose
# flag that other cells pass as verbose=True — confirm.
mtrainer.run_eval(mtrainer.model, mtrainer.test_loader, True)

In [None]:
# Evaluate on the validation loader using the model owned by the trainer rather
# than the notebook-global `model`, which other setup cells rebind; this
# matches the other evaluation cells and stays correct when cells are run out
# of order.
mtrainer.run_eval(mtrainer.model, mtrainer.val_loader)

In [None]:
# Train with 3 as the first run_train argument and a learning rate of 1e-5.
# NOTE(review): the first argument is presumably the number of epochs or
# iterations — confirm against ControlledMetaTrainer.run_train.
mtrainer.run_train(3, lr=1e-5)

In [None]:
# Print the `scale` attribute of the trainer's best model.
# NOTE(review): presumably a learned scaling factor applied to the distances —
# confirm in ProtoNetAttention.
print(mtrainer.best_model.scale)

In [None]:
# Evaluate the best model on a dataloader built by
# create_query_eval_dataloader('train').
# NOTE(review): presumably the source of the "Training Query set" figures in
# the results notes, and the third positional True is presumably the verbose
# flag — confirm both.
mtrainer.run_eval(mtrainer.best_model, mtrainer.create_query_eval_dataloader('train'), True)

In [None]:
# Evaluate the best model on the validation loader, passing True as the third
# positional argument.
mtrainer.run_eval(mtrainer.best_model, mtrainer.val_loader, True)

In [None]:
# Evaluate the best model on the test loader, passing True as the third
# positional argument.
mtrainer.run_eval(mtrainer.best_model, mtrainer.test_loader, True)

#### Results after training the backbone for 3 iterations
- Training query set: 0.5583506872136611
- Test set: 0.8007964161274267
- Validation set: 0.7913702623906704


#### Results after training the attention model for 3 iterations
- Training query set: 0.558725531028738
- Test set: 0.7993031358885015
- Validation set: 0.7881049562682213


## Run experiments on baseline models without attention
### RelationNet

In [None]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.baselines import RelationNet

# Build the RelationNet baseline (no attention model) on top of the same
# pre-trained image/text encoder.
#
# The file appears to hold a whole pickled module, so its classes must be
# importable when loading. map_location=device remaps the stored tensors onto
# the active device, so a checkpoint saved on a GPU still loads on a machine
# without that GPU.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, a full-module checkpoint like this presumably needs
# weights_only=False — confirm against the installed version.
encoder = torch.load(
    'models/embedding/model/vindr1/imgtext_model_trained.pth',
    map_location=device,
)
# Point the text model's `device` attribute at the device used in this session.
encoder.text_model.device = device
base_model = RelationNet(encoder, 512, class_prototype_inf, fc_hidden_size=16)
btrainer = ControlledMetaTrainer(base_model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)

In [None]:
# Evaluate the baseline trainer's current model on its validation loader (cell
# placed before the run_train call for this baseline).
btrainer.run_eval(btrainer.model, btrainer.val_loader)

In [None]:
# Evaluate the baseline trainer's current model on its test loader (cell placed
# before the run_train call for this baseline).
btrainer.run_eval(btrainer.model, btrainer.test_loader)

In [None]:
# Train the baseline with 2 as the first run_train argument and lr=1e-6.
# NOTE(review): this learning rate differs from the 1e-5 used in the attention
# experiments in this notebook — confirm the comparison is intended.
btrainer.run_train(2, lr=1e-6)

In [None]:
# Evaluate the baseline trainer's `best_model` on the validation loader.
btrainer.run_eval(btrainer.best_model, btrainer.val_loader)

In [None]:
# Evaluate the baseline trainer's `best_model` on the test loader.
btrainer.run_eval(btrainer.best_model, btrainer.test_loader)

In [None]:
# Persist only the state_dict of the best baseline model's `cls` sub-module.
# The path is relative, so the file lands in the current working directory.
relnet_weights_path = 'relnet_weights-16.pkl'
torch.save(btrainer.best_model.cls.state_dict(), relnet_weights_path)

### Prototypical Network

In [None]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.base import euclidean_distance
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.baselines import ProtoNet

# Build the plain ProtoNet baseline on the pre-trained image/text encoder, with
# Euclidean distance and trainable_base=False.
#
# The file appears to hold a whole pickled module, so its classes must be
# importable when loading. map_location=device remaps the stored tensors onto
# the active device, so a checkpoint saved on a GPU still loads on a machine
# without that GPU.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, a full-module checkpoint like this presumably needs
# weights_only=False — confirm against the installed version.
encoder = torch.load(
    'models/embedding/model/vindr1/imgtext_model_trained.pth',
    map_location=device,
)
# Point the text model's `device` attribute at the device used in this session.
encoder.text_model.device = device
base_model = ProtoNet(encoder, class_prototype_inf, euclidean_distance, trainable_base=False)
btrainer = ControlledMetaTrainer(base_model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)

In [None]:
# Evaluate the ProtoNet baseline's current model on the validation loader,
# passing True as the third positional argument.
# NOTE(review): presumably the verbose flag that later cells pass as
# verbose=True — confirm.
btrainer.run_eval(btrainer.model, btrainer.val_loader, True)

In [None]:
# Evaluate the ProtoNet baseline's current model on the test loader, passing
# True as the third positional argument.
btrainer.run_eval(btrainer.model, btrainer.test_loader, True)

In [None]:
import torchvision
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder, ImageOnlyEmbedding, DummyEncoder, resnet_backbone, load_pretrained_resnet, adapt_resnet_input_channels

from utils.prototype import class_prototype_inf, class_prototype_mean
from models.metaclassifier.base import euclidean_distance, cosine_distance
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.baselines import ProtoNet

# from torchmetrics.classification import MultilabelAccuracy, MultilabelAUROC

# encoder = torch.load('models/embedding/model/imgtext_model_trained.pth')
# encoder = torch.load('models/embedding/model/imgtext_model_from_medclip.pth')
# torchvision.models.resnet50(num_classes=len(chest_ds.info['label']), pretrained=False)
# img_backbone = resnet_backbone(adapt_resnet_input_channels(torchvision.models.resnet50(weights=None), 1))
# Image-only ProtoNet baseline: a torchvision ResNet-50 loaded with the
# 'IMAGENET1K_V1' weights is passed through adapt_resnet_input_channels(..., 1)
# and resnet_backbone, then wrapped in ImageOnlyEmbedding.
img_backbone = resnet_backbone(
    adapt_resnet_input_channels(
        torchvision.models.resnet50(weights='IMAGENET1K_V1'),
        1,
    )
)

encoder = ImageOnlyEmbedding(img_backbone, 512)
# Alternative encoder kept from the original cell:
# encoder = DummyEncoder(512)

# This baseline uses cosine_distance, with trainable_base=False.
base_model = ProtoNet(
    encoder,
    class_prototype_inf,
    cosine_distance,
    trainable_base=False,
)

btrainer = ControlledMetaTrainer(
    base_model,
    NUM_SHOTS,
    NUM_WAYS,
    dataset_config,
    device=device,
    n_query=N_QUERY,
)

In [None]:
# Evaluate the image-only baseline's current model on the validation loader
# with verbose output enabled (cell placed before its run_train call).
btrainer.run_eval(btrainer.model, btrainer.val_loader, verbose=True)

In [None]:
# Evaluate the image-only baseline's current model on the test loader with
# verbose output enabled (cell placed before its run_train call).
btrainer.run_eval(btrainer.model, btrainer.test_loader, verbose=True)

In [None]:
# Train the image-only baseline with 4 as the first run_train argument and
# lr=5e-5.
# NOTE(review): the model was built with trainable_base=False; which parameters
# are actually updated is not visible from this notebook — confirm in ProtoNet
# and ImageOnlyEmbedding.
btrainer.run_train(4, lr=5e-5)

In [None]:
# Evaluate the image-only baseline's `best_model` on the test loader.
btrainer.run_eval(btrainer.best_model, btrainer.test_loader)