In [19]:
import torch
from torch import nn
from abc import abstractmethod
# from models.metaclassifier.base import MetaModelBase

class MetaModelBase(nn.Module):
    """Shared scaffolding for prototype-based meta-classifiers.

    Stores an image/text encoder and a class-prototype aggregator, caches the
    prototypes derived from a support set in ``self.class_prototypes``, and
    scores predictions with ``nn.BCEWithLogitsLoss``. Subclasses are expected
    to provide their own ``forward``.
    """

    def __init__(self, imgtxt_encoder, class_prototype_aggregator):
        """Keep references to the encoder and the aggregator and create the loss."""
        super().__init__()
        self.encoder = imgtxt_encoder
        self.class_prototype_aggregator = class_prototype_aggregator
        self.loss_fn = nn.BCEWithLogitsLoss()

    def set_class_prototype_details(self, class_labels, support_images, support_label_inds):
        """Compute and cache class prototypes from a support set.

        Args:
            class_labels: accepted for interface compatibility; not used here.
            support_images: batch passed to ``encoder.embed_image(..., pool=True)``.
            support_label_inds: per-image class indicators; its second
                dimension fixes how many times each embedding is repeated.
        """
        pooled = self.encoder.embed_image(support_images, pool=True)  # (B, D)
        num_classes = support_label_inds.shape[1]
        # Insert a class axis and broadcast every image embedding across it
        # so the aggregator receives one copy per class.
        per_class_embeddings = pooled.unsqueeze(1).expand(-1, num_classes, -1)

        self.class_prototypes = self.class_prototype_aggregator(per_class_embeddings, support_label_inds)

    def update_support_and_classify(self, class_labels, support_images, support_label_inds, query_images):
        """Refresh the cached prototypes from the support set, then run ``forward`` on the queries."""
        self.set_class_prototype_details(class_labels, support_images, support_label_inds)
        return self.forward(query_images)

    @abstractmethod
    def forward(self, query_images):
        """Placeholder forward pass that returns ``query_images`` unchanged.

        The ``abstractmethod`` marker is informational only: this class does
        not use ``ABCMeta``, so instantiation is not blocked.
        """
        return query_images

    def loss(self, predictions, label_inds):
        """Return the BCE-with-logits loss of ``predictions`` against ``label_inds`` cast to float."""
        targets = label_inds.float()
        return self.loss_fn(predictions, targets)


class ProtoNet(MetaModelBase):
    """Prototype classifier: logits are negated, scaled distances to class prototypes.

    Args:
        imgtxt_encoder: encoder exposing ``embed_image`` and ``set_trainable``.
        class_prototype_aggregator: callable that builds class prototypes
            (see ``MetaModelBase.set_class_prototype_details``).
        distance_func: callable ``(class_prototypes, query_embeddings) -> distances``.
        scale: initial value of the learnable logit scale.
        trainable_base: forwarded to ``set_trainable`` to (un)freeze the encoder.
    """

    def __init__(self, imgtxt_encoder, class_prototype_aggregator, distance_func, scale=1.0, trainable_base=True):
        super().__init__(imgtxt_encoder, class_prototype_aggregator)
        self.distance_func = distance_func
        # float() lets an integer ``scale`` be used: integer tensors cannot
        # require gradients, so nn.Parameter would reject them.
        self.scale = nn.Parameter(torch.tensor(float(scale)))
        # set_trainable takes a single flag; the extra encoder arguments are
        # supplied inside it. The previous four-argument call raised TypeError.
        self.set_trainable(trainable_base)

    def set_trainable(self, trainable):
        """Forward ``trainable`` to the encoder, excluding the text backbone and logit scale."""
        self.encoder.set_trainable(trainable, trainable, include_text_bb=False, include_logit_scale=False)

    def forward(self, query_images):
        """Score query images against the cached class prototypes.

        Requires ``self.class_prototypes`` to have been set beforehand.
        Returns ``-distance_func(prototypes, query_embeddings) * scale``.
        """
        query_image_embeddings = self.encoder.embed_image(query_images, pool=True)
        # Repeat each query embedding once per prototype along a new axis 1.
        query_image_embeddings = query_image_embeddings.unsqueeze(1).expand(-1, self.class_prototypes.shape[0], -1)
        return -self.distance_func(self.class_prototypes, query_image_embeddings) * self.scale
    
class RelationNet(MetaModelBase):
    """Relation-style classifier: a small MLP scores each (prototype, query) pair.

    The encoder is frozen via ``encoder.set_trainable(False, False)``; only the
    MLP head ``self.cls`` is built here.
    """

    def __init__(self, imgtxt_encoder, embed_dim, class_prototype_aggregator, fc_hidden_size=16, activation=nn.ReLU, dropout=0.3):
        super().__init__(imgtxt_encoder, class_prototype_aggregator)
        self.loss_fn = nn.BCEWithLogitsLoss()
        # Head input is a prototype concatenated with a query embedding,
        # hence 2 * embed_dim features; output is a single logit per pair.
        head_layers = [
            nn.Linear(embed_dim * 2, fc_hidden_size),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(fc_hidden_size, 1),
        ]
        self.cls = nn.Sequential(*head_layers)
        self.encoder.set_trainable(False, False)

    def forward(self, query_images):
        """Return one logit per (query image, class prototype) pair.

        Requires ``self.class_prototypes`` to have been set beforehand.
        """
        num_prototypes = self.class_prototypes.shape[0]
        query_embeddings = self.encoder.embed_image(query_images, pool=True)
        query_embeddings = query_embeddings.unsqueeze(1).expand(-1, num_prototypes, -1)
        num_queries = query_embeddings.shape[0]
        # Tile the prototypes once per query so both tensors line up for the concat.
        tiled_prototypes = self.class_prototypes.repeat(num_queries, 1, 1)

        pair_features = torch.cat((tiled_prototypes, query_embeddings), dim=2)
        logits = self.cls(pair_features)
        return logits.squeeze(2)  # NxLx1 -> NxL

## Set up datasets and dataloaders

In [1]:
from torch.utils.data import DataLoader
import pandas as pd

from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.sampling import FewShotBatchSampler
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset

# Load the pickled image metadata table, then derive the query / support image
# id collections for the training split (presumably driven by the saved query
# set file — verify against utils.data.get_query_and_support_ids).
# NOTE(review): pd.read_pickle unpickles arbitrary objects; only load trusted files.
img_info = pd.read_pickle('data/vindr_cxr_split_labels.pkl')
query_image_ids, support_image_ids = get_query_and_support_ids(img_info, 'data/vindr_train_query_set.pkl')

# Image root passed as the first argument to every Dataset built in this notebook.
IMG_PATH = 'datasets/vindr-cxr-png'

def meta_training_loader(dataset, shots, n_ways=None, include_query=False):
    """Wrap ``dataset`` in a DataLoader whose batches come from a FewShotBatchSampler.

    The sampler is built from ``dataset.get_class_indicators()`` and ``shots``;
    ``n_ways`` and ``include_query`` are forwarded to it unchanged.
    """
    episode_sampler = FewShotBatchSampler(
        dataset.get_class_indicators(),
        shots,
        n_ways=n_ways,
        include_query=include_query,
    )
    return DataLoader(dataset, batch_sampler=episode_sampler)

def meta_training_dataset(img_info, split):
    """Build the Dataset for the images whose ``meta_split`` column equals ``split``.

    Uses the module-level ``IMG_PATH``, the full ``VINDR_CXR_LABELS`` list, the
    labels in ``VINDR_SPLIT[split]`` and the ``'chestmnist'`` mean/std entry.
    """
    in_split = img_info['meta_split'] == split
    split_rows = img_info[in_split]
    img_ids = split_rows['image_id'].to_list()
    return Dataset(
        IMG_PATH,
        img_info,
        img_ids,
        VINDR_CXR_LABELS,
        VINDR_SPLIT[split],
        mean_std=MEAN_STDS['chestmnist'],
    )

# Value passed as `shots` to the FewShotBatchSampler behind every loader below.
num_shots = 5
# Training uses two separate datasets/loaders, built from the query and
# support image id lists respectively, both with the 'train' label split.
train_query_dataset = Dataset(IMG_PATH, img_info, query_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
train_query_loader = meta_training_loader(train_query_dataset, num_shots)
train_support_dataset = Dataset(IMG_PATH, img_info, support_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
train_support_loader = meta_training_loader(train_support_dataset, num_shots)

# Validation and test each use a single loader with include_query=True
# (presumably each batch then carries both support and query samples — TODO confirm).
val_dataset = meta_training_dataset(img_info, 'val')
val_loader = meta_training_loader(val_dataset, num_shots, include_query=True)

test_dataset = meta_training_dataset(img_info, 'test')
test_loader = meta_training_loader(test_dataset, num_shots, include_query=True)

## Models with attention
### Run experiments on proposed model

In [13]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.trainer import MetaTrainer
from models.metaclassifier.model import ClsModel

# Assemble the proposed classifier from the saved encoder and attention model.
device =  get_device()

# NOTE(review): both files are loaded as fully pickled module objects. Newer
# PyTorch releases default torch.load to weights_only=True, which rejects such
# files — confirm the installed version, and only load trusted checkpoints.
encoder = torch.load('models/embedding/model/imgtext_model_trained.pth')
# Point the loaded text model at the current device.
encoder.text_model.device = device
attention = torch.load('models/attention/model/attention-model-8h4l.pth')
# 512 is presumably the embedding dimension — TODO confirm against ClsModel.
model = ClsModel(encoder, attention, 512, class_prototype_inf, fc_hidden_size=64)
mtrainer = MetaTrainer(model, device=device)

In [14]:
# Baseline: evaluate the freshly assembled model on the test loader before any
# meta-training. Returns a 3-tuple; the first value looks like accuracy in
# percent — TODO confirm what the other two are.
mtrainer.run_eval(mtrainer.model, test_loader, test_dataset.class_labels())


(80.09955201592834, 0.46334128083966913, 0.6701876840940336)

In [15]:
# Same pre-training baseline, on the validation loader.
mtrainer.run_eval(mtrainer.model, val_loader, val_dataset.class_labels())

(79.35860058309035, 0.5253407048569427, 0.671541600567954)

In [16]:
# Mark the attention sub-model as trainable before running meta-training
# (presumably unfreezes its parameters — verify in the attention model).
mtrainer.model.attn_model.set_trainable(True)

In [17]:
# Observation: the model tends to overfit — improving performance on the
# training set quickly degrades test and validation performance.
# The leading 4 is presumably the number of epochs — TODO confirm in MetaTrainer.run_train.
mtrainer.run_train(4, train_query_loader, train_support_loader, val_loader,  train_query_dataset.class_labels(), val_dataset.class_labels(), lr=1e-5)

Batch 1: loss 0.6870514154434204 | Acc 60.0
Batch 2: loss 0.6872147917747498 | Acc 58.57142857142858
Batch 3: loss 0.6868315935134888 | Acc 60.30612244897959
Batch 4: loss 0.6869679391384125 | Acc 58.57142857142858
Batch 5: loss 0.6870496034622192 | Acc 58.57142857142858
Batch 6: loss 0.687052438656489 | Acc 58.97959183673469
Batch 7: loss 0.686966095651899 | Acc 59.795918367346935
Batch 8: loss 0.6868915781378746 | Acc 59.897959183673464
Batch 9: loss 0.6867133577664694 | Acc 61.3265306122449
Batch 10: loss 0.686601585149765 | Acc 60.91836734693877
Batch 11: loss 0.6865592382170937 | Acc 60.204081632653065
Batch 12: loss 0.6865238746007284 | Acc 60.204081632653065
Batch 13: loss 0.6864938919360821 | Acc 60.204081632653065
Batch 14: loss 0.686628269297736 | Acc 57.244897959183675
Batch 15: loss 0.6864869594573975 | Acc 62.34693877551021
Batch 16: loss 0.6865179091691971 | Acc 59.08163265306122
Batch 17: loss 0.6866225109380835 | Acc 57.3469387755102
Batch 18: loss 0.6867111722628275 | 

In [18]:
# Evaluate the trainer's best_model (presumably the best checkpoint kept during
# run_train — TODO confirm) on the validation loader.
mtrainer.run_eval(mtrainer.best_model, val_loader, val_dataset.class_labels())

(79.42857142857143, 0.5, 0.671340092590877)

In [19]:
# Evaluate the trainer's best_model on the test loader.
mtrainer.run_eval(mtrainer.best_model, test_loader, test_dataset.class_labels())

(80.19910403185662, 0.5, 0.6707499230780253)

In [21]:
# Persist only the classifier head's state_dict (weights, not the module object).
torch.save(mtrainer.best_model.cls.state_dict(), 'models/metaclassifier/model/comb2/cls_weights-64.pkl')

In [22]:
# Persist the whole attention module object (pickled), unlike the classifier
# head, which is saved as a state_dict; loading it requires the class
# definitions to be importable.
torch.save(mtrainer.best_model.attn_model, 'models/metaclassifier/model/comb2/attention-model-8h4l.pth')

### Prototypical Network with attention

In [2]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.base import euclidean_distance
from models.metaclassifier.trainer import MetaTrainer
from models.metaclassifier.model import ProtoNetAttention

# Assemble the prototypical network with attention. This rebinds `device`,
# `encoder`, `attention`, `model` and `mtrainer` from the previous experiment.
device =  get_device()

# NOTE(review): full pickled modules are loaded here; newer PyTorch releases
# default torch.load to weights_only=True, which rejects such files — confirm
# the installed version, and only load trusted checkpoints.
encoder = torch.load('models/embedding/model/imgtext_model_trained.pth')
encoder.text_model.device = device
attention = torch.load('models/attention/model/attention-model-8h4l.pth')
# Positional arguments: imgtxt_encoder, attn_model, class_prototype_aggregator, distance_func
model = ProtoNetAttention(encoder, attention, class_prototype_inf, euclidean_distance)
mtrainer = MetaTrainer(model, device=device)

In [3]:
# Pre-training baseline for the attention ProtoNet on the test loader.
mtrainer.run_eval(model, test_loader, test_dataset.class_labels())

  classes_count = torch.nonzero(label_inds)[:,1].bincount()


(80.39820806371328, 0.5402599822857049, 0.6664394343771586)

In [4]:
# Pre-training baseline for the attention ProtoNet on the validation loader.
mtrainer.run_eval(model, val_loader, val_dataset.class_labels())

(79.37026239067053, 0.5161583489962621, 0.6703500832830157)

In [5]:
# Meta-train the attention ProtoNet; the leading 2 is presumably the number of
# epochs — TODO confirm in MetaTrainer.run_train.
mtrainer.run_train(2, train_query_loader, train_support_loader, val_loader,  train_query_dataset.class_labels(), val_dataset.class_labels(), lr=1e-5)

Batch 1: loss 0.645804762840271 | Acc 59.795918367346935
Batch 2: loss 0.64562126994133 | Acc 58.877551020408156
Batch 3: loss 0.6446925004323324 | Acc 60.51020408163266
Batch 4: loss 0.6453973799943924 | Acc 59.285714285714285
Batch 5: loss 0.6450785756111145 | Acc 58.673469387755105
Batch 6: loss 0.6450744767983755 | Acc 59.38775510204082
Batch 7: loss 0.6451364670481 | Acc 58.57142857142858
Batch 8: loss 0.6450803205370903 | Acc 60.71428571428571
Batch 9: loss 0.6445973383055793 | Acc 61.530612244897966
Batch 10: loss 0.6443249225616455 | Acc 60.71428571428571
Batch 11: loss 0.6443740454587069 | Acc 59.591836734693885
Batch 12: loss 0.6442251602808634 | Acc 60.10204081632653
Batch 13: loss 0.6441494272305415 | Acc 60.0
Batch 14: loss 0.6442059363637652 | Acc 57.44897959183673
Batch 15: loss 0.6439297715822856 | Acc 62.755102040816325
Batch 16: loss 0.6437373459339142 | Acc 59.08163265306122
Batch 17: loss 0.6442396991393146 | Acc 56.93877551020409
Batch 18: loss 0.644419295920266 | 

In [7]:
# Inspect the `scale` parameter of the best model after training.
print(mtrainer.best_model.scale)

Parameter containing:
tensor(1.0008, device='mps:0', requires_grad=True)


In [8]:
# Evaluate the trained attention ProtoNet (trainer's best_model) on the test loader.
mtrainer.run_eval(mtrainer.best_model, test_loader, test_dataset.class_labels())

(80.38825286212047, 0.5417728137567698, 0.6569410795118751)

## Run experiments on baseline models without attention
### RelationNet

In [20]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.trainer import MetaTrainer
from models.metaclassifier.baselines import RelationNet

# Assemble the RelationNet baseline (no attention model).
# NOTE(review): `RelationNet` here is the binding from this cell's import of
# models.metaclassifier.baselines, which shadows any class of the same name
# defined directly in the notebook — confirm that is intended.
device =  get_device()

# NOTE(review): this path has no directory prefix, unlike the
# 'models/embedding/model/imgtext_model_trained.pth' path used by the other
# experiments — confirm both refer to the same checkpoint. The file is a fully
# pickled module; newer PyTorch defaults torch.load to weights_only=True.
encoder = torch.load('imgtext_model_trained.pth')
encoder.text_model.device = device
# Positional arguments follow RelationNet(imgtxt_encoder, embed_dim, class_prototype_aggregator, ...).
base_model = RelationNet(encoder, 512, class_prototype_inf, fc_hidden_size=64)
btrainer = MetaTrainer(base_model, device=device)

In [21]:
# Pre-training baseline for RelationNet on the validation loader.
btrainer.run_eval(btrainer.model, val_loader, val_dataset.class_labels())

(76.40816326530613, 0.4814422659229693, 0.6890060526984079)

In [22]:
# Pre-training baseline for RelationNet on the test loader.
btrainer.run_eval(btrainer.model, test_loader, test_dataset.class_labels())

(75.85863613738178, 0.46753415772972007, 0.6887184134343776)

In [23]:
# Meta-train the RelationNet baseline. Note the learning rate (5e-5) is higher
# than the 1e-5 used for the attention models above.
btrainer.run_train(4, train_query_loader, train_support_loader, val_loader,  train_query_dataset.class_labels(), val_dataset.class_labels(), lr=5e-5)

Batch 1: loss 0.6922392845153809 | Acc 55.81632653061225
Batch 2: loss 0.6922772824764252 | Acc 56.42857142857143
Batch 3: loss 0.6921975612640381 | Acc 56.53061224489796
Batch 4: loss 0.6921016573905945 | Acc 56.42857142857143
Batch 5: loss 0.6920598268508911 | Acc 56.93877551020409
Batch 6: loss 0.6919763882954916 | Acc 57.95918367346938
Batch 7: loss 0.6918906910078866 | Acc 57.14285714285714
Batch 8: loss 0.6917958334088326 | Acc 58.57142857142858
Batch 9: loss 0.6916449202431573 | Acc 59.897959183673464
Batch 10: loss 0.6915410220623016 | Acc 58.97959183673469
Batch 11: loss 0.6914029825817455 | Acc 59.897959183673464
Batch 12: loss 0.6912817309300104 | Acc 59.48979591836735
Batch 13: loss 0.6911432926471417 | Acc 59.38775510204082
Batch 14: loss 0.6910162568092346 | Acc 59.38775510204082
Batch 15: loss 0.6908564249674479 | Acc 60.40816326530612
Batch 16: loss 0.6907406598329544 | Acc 58.97959183673469
Batch 17: loss 0.6905901607345132 | Acc 59.285714285714285
Batch 18: loss 0.690

In [24]:
# Evaluate the trained RelationNet (trainer's best_model) on the validation loader.
btrainer.run_eval(btrainer.best_model, val_loader, val_dataset.class_labels())

(78.69387755102042, 0.5204702109376685, 0.6312191026551383)

In [25]:
# Evaluate the trained RelationNet (trainer's best_model) on the test loader.
btrainer.run_eval(btrainer.best_model, test_loader, test_dataset.class_labels())

(79.25335988053757, 0.510349916356135, 0.6320822573289638)

In [None]:
# Persist only the RelationNet classifier head's state_dict, written to the
# current working directory.
torch.save(btrainer.best_model.cls.state_dict(), 'relnet_weights-64.pkl')