In [1]:
import torch
from torch import nn

from utils.prototype import class_variance

class ClsModel(nn.Module):
    """Few-shot multi-label classifier on top of a frozen image/text encoder and attention model.

    Usage: register a labelled support set with `set_class_prototype_details` (this caches
    per-class prototypes), then call `forward` on query images to get one logit per
    (query image, class label) pair. `update_support_and_classify` does both in one call.

    Args:
        imgtxt_encoder: encoder exposing `set_trainable(a, b, c)`,
            `__call__(class_labels, images, pool=False) -> (text_embeddings, image_embeddings)`
            and `embed_image(images, pool=False)`.
        attn_model: module exposing `set_trainable(flag)`; called as
            `attn_model(text_embeddings, image_embeddings[, label_inds])` to build prototypes.
        embed_dim: width D of each prototype fed to the classifier head.
        class_prototype_aggregator: callable `(support_prototypes, support_label_inds)` returning
            the class prototypes (expected LxD, one row per class label).
        fc_hidden_size: hidden width of the classifier head.
        use_variance: if True, the per-class prototype variance is also fed to the head.
        activation: activation class, instantiated with no arguments, used in the head.
        dropout: dropout probability in the head.
    """

    def __init__(self, imgtxt_encoder, attn_model, embed_dim, class_prototype_aggregator, fc_hidden_size=16, use_variance=False, activation=nn.ReLU, dropout=0.3) -> None:
        super().__init__()
        self.encoder = imgtxt_encoder
        self.attn_model = attn_model
        self.class_prototype_aggregator = class_prototype_aggregator
        self.use_variance = use_variance
        # The head sees [class prototype, (class variance,) query prototype] concatenated
        # along the feature axis, each block embed_dim wide.
        n_feature_blocks = 3 if use_variance else 2
        self.cls = nn.Sequential(
            nn.Linear(embed_dim * n_feature_blocks, fc_hidden_size),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(fc_hidden_size, 1),
        )
        self.cls_loss_criterion = nn.BCEWithLogitsLoss()
        # Support-set state, filled in by set_class_prototype_details().
        self.class_label_embeddings = None
        self.class_prototypes = None
        self.class_prototypes_var = None
        self.freeze_child_models()

    def freeze_child_models(self):
        """Mark the encoder and the attention model as non-trainable via their `set_trainable` hooks."""
        self.encoder.set_trainable(False, False, False)
        self.attn_model.set_trainable(False)

    def set_class_prototype_details(self, class_labels, support_images, support_label_inds):
        """Encode a labelled support set and cache the per-class statistics used by `forward`.

        Caches `class_label_embeddings` (first encoder output for `class_labels`),
        `class_prototypes` (aggregated support prototypes) and, when `use_variance` is set,
        `class_prototypes_var`.

        Args:
            class_labels: class labels passed to the encoder together with the support images.
            support_images: support image batch.
            support_label_inds: per-image label indicators; presumably an (S, L) multi-hot
                tensor — TODO confirm.
        """
        text_embeddings, image_embeddings = self.encoder(class_labels, support_images, pool=False)
        self.class_label_embeddings = text_embeddings
        support_prototypes = self.attn_model(text_embeddings, image_embeddings, support_label_inds)
        self.class_prototypes = self.class_prototype_aggregator(support_prototypes, support_label_inds)
        if self.use_variance:
            self.class_prototypes_var = class_variance(support_prototypes, support_label_inds)

    def update_support_and_classify(self, class_labels, support_images, support_label_inds, query_images):
        """Register the given support set, then return `forward(query_images)`."""
        self.set_class_prototype_details(class_labels, support_images, support_label_inds)
        return self.forward(query_images)

    def forward(self, query_images):
        """Score every (query image, class label) pair against the cached class prototypes.

        Args:
            query_images: batch of N query images, passed to `encoder.embed_image`.

        Returns:
            NxL tensor of logits (L = number of class labels registered via
            `set_class_prototype_details`).

        Raises:
            RuntimeError: if no support set has been registered yet.
        """
        # getattr keeps this working for instances unpickled from before these attributes existed.
        if getattr(self, 'class_prototypes', None) is None:
            raise RuntimeError(
                "ClsModel has no class prototypes; call set_class_prototype_details() "
                "or update_support_and_classify() before forward()"
            )
        query_image_embeddings = self.encoder.embed_image(query_images, pool=False)
        query_prototypes = self.attn_model(self.class_label_embeddings, query_image_embeddings)
        n_queries = query_prototypes.shape[0]

        # Prototypes (and variance) are LxD; tile them N times so they line up with the
        # NxLxD query prototypes.
        features = [self.class_prototypes.repeat(n_queries, 1, 1)]
        if self.use_variance:
            features.append(self.class_prototypes_var.repeat(n_queries, 1, 1))
        features.append(query_prototypes)

        out = self.cls(torch.cat(features, dim=2))
        return out.squeeze(2)  # NxLx1 -> NxL

    def loss(self, predictions, label_inds):
        """BCE-with-logits loss between `predictions` (logits) and `label_inds` cast to float."""
        return self.cls_loss_criterion(predictions, label_inds.float())

In [None]:
import torch
from torch import nn

def euclidean_distance(prototype, query):
    """Squared Euclidean distance between class prototypes and per-class query embeddings.

    Args:
        prototype: (L, D) tensor, one row per class.
        query: (N, L, D) tensor.

    Returns:
        (N, L) tensor whose entry [n, l] is sum_d (prototype[l, d] - query[n, l, d])**2.
        No square root is taken.
    """
    n_queries = query.size(0)
    tiled = prototype.unsqueeze(0).expand(n_queries, -1, -1)
    sq_diff = (tiled - query).pow(2)
    return sq_diff.sum(dim=2)
    
def cosine_distance(prototype, query, eps=1e-8):
    """Negative cosine similarity between class prototypes and per-class query embeddings.

    Args:
        prototype: (L, D) tensor, one row per class.
        query: (N, L, D) tensor.
        eps: lower bound on vector norms, so an all-zero vector yields 0 instead of NaN.

    Returns:
        (N, L) tensor of -cos(prototype[l], query[n, l]). Lower means more similar, and the
        shape matches `euclidean_distance`, so the two can be swapped as a distance function.
    """
    prototype = prototype / prototype.norm(dim=-1, keepdim=True).clamp_min(eps)
    query = query / query.norm(dim=-1, keepdim=True).clamp_min(eps)
    prototype = prototype.unsqueeze(0).expand(query.shape[0], -1, -1)
    # Reduce over the feature axis: the cosine is the dot product of the unit vectors.
    cos = (prototype * query).sum(-1)
    return -cos
    

class ProtoNet(nn.Module):
    """Prototypical-network scaffold: image encoder, prototype aggregator and distance function.

    Only the constructor exists so far; no forward pass is defined yet.

    Args:
        img_encoder: image encoder, stored as `encoder` (registered as a submodule when it is
            an nn.Module).
        class_prototype_aggregator: callable that builds class prototypes; stored as-is.
        distance_func: distance callable, presumably `euclidean_distance` or
            `cosine_distance`; stored as-is.
    """

    def __init__(self, img_encoder, class_prototype_aggregator, distance_func):
        super(ProtoNet, self).__init__()
        # Keep the original assignment order so submodule registration order is unchanged.
        self.encoder = img_encoder
        self.class_prototype_aggregator = class_prototype_aggregator
        self.distance_func = distance_func
        # Multi-label objective on logits.
        bce_on_logits = nn.BCEWithLogitsLoss()
        self.loss_fn = bce_on_logits
    
    

In [2]:
import copy
from utils.metrics import AverageMeter, calculate_auc, multilabel_accuracy

class DataloaderIterator:
    """Cycle over a re-iterable dataloader indefinitely, one batch at a time.

    Used to draw support batches alongside a (possibly longer) query dataloader. When the
    current pass is exhausted, a fresh iterator is created from the same dataloader.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(self.dataloader)

    def next_batch(self):
        """Return the next batch, restarting the dataloader when the current pass is exhausted.

        Raises:
            ValueError: if the dataloader yields no batches at all, so it cannot be cycled.
        """
        try:
            return next(self.iterator)
        except StopIteration:
            pass
        # Exhausted: start a new pass. Catching StopIteration (instead of using a None
        # default) means a batch that happens to be None is returned rather than
        # mistaken for exhaustion.
        self.iterator = iter(self.dataloader)
        try:
            return next(self.iterator)
        except StopIteration:
            raise ValueError("dataloader yielded no batches; cannot cycle over an empty dataloader") from None

class MetaTrainer:
    """Episodic fine-tuning and evaluation loop for a few-shot classifier.

    The wrapped model must be an nn.Module providing
    `update_support_and_classify(class_labels, support_images, support_label_inds, query_images)`
    and `loss(predictions, label_inds)`.

    Attributes:
        model: the model being trained; after run_train it holds the final-epoch model.
        best_model: deep copy of the model from the epoch with the highest validation
            accuracy, or None until run_train has completed at least one epoch.
    """

    def __init__(self, model, train_class_labels, val_class_labels, device='cpu'):
        self.model = model
        self.device = device
        self.train_class_labels = train_class_labels
        self.val_class_labels = val_class_labels
        # Set by run_train. Initialised here so reading it before training yields None
        # instead of an AttributeError.
        self.best_model = None

    def run_train(self, epochs, query_dataloader, support_dataloader, val_dataloader, lr=1e-5):
        """Train for `epochs` epochs, validating after each and keeping the best checkpoint.

        Each query batch is paired with the next support batch. The support dataloader is
        cycled, so it may be shorter than the query dataloader.

        Args:
            epochs: number of passes over `query_dataloader`.
            query_dataloader: yields (images, class_inds) query batches.
            support_dataloader: yields (images, class_inds) support batches.
            val_dataloader: evaluated with run_eval after every epoch.
            lr: Adam learning rate.
        """
        model = self.model.to(self.device)
        best_epoch = None
        best_acc = None
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        support_iterator = DataloaderIterator(support_dataloader)
        for epoch in range(epochs):
            train_loss, train_acc = self._train_epoch(model, optimizer, query_dataloader, support_iterator)
            print(f"Epoch {epoch+1}: Training loss {train_loss} | Acc: {train_acc}")

            val_acc, val_auc, val_loss = self.run_eval(model, val_dataloader)
            print(f"Epoch {epoch+1}: Validation loss {val_loss} | Accuracy {val_acc} | AUC {val_auc}")

            # Strict '>' keeps the earliest epoch on ties.
            if best_acc is None or val_acc > best_acc:
                best_acc = val_acc
                self.best_model = copy.deepcopy(model)
                best_epoch = epoch
        self.model = model
        if best_epoch is None:
            # With epochs == 0 there is no best epoch to report.
            print('Best epoch: none (no epochs were run)')
        else:
            print('Best epoch: ', best_epoch+1)

    def _train_epoch(self, model, optimizer, query_dataloader, support_iterator):
        """Run one training pass over `query_dataloader`; return the (loss, accuracy) meter averages."""
        model.train()
        loss_meter = AverageMeter()
        acc_meter = AverageMeter()
        for i, (qimages, qclass_inds) in enumerate(query_dataloader):
            qimages, qclass_inds = qimages.to(self.device), qclass_inds.to(self.device)
            simages, sclass_inds = support_iterator.next_batch()
            simages, sclass_inds = simages.to(self.device), sclass_inds.to(self.device)

            optimizer.zero_grad()
            predictions = model.update_support_and_classify(self.train_class_labels, simages, sclass_inds, qimages)
            loss = model.loss(predictions, qclass_inds)

            acc = multilabel_accuracy(predictions, qclass_inds)
            acc_meter.update(acc, qclass_inds.shape[0])

            loss.backward()
            optimizer.step()

            loss_meter.update(loss.item(), len(qimages))
            print(f"Batch {i+1}: loss {loss_meter.average()} | Acc {acc}")
        return loss_meter.average(), acc_meter.average()

    def run_eval(self, model, val_dataloader):
        """Evaluate `model` episodically on `val_dataloader` using the validation class labels.

        Each batch is split in two. The first `len(batch) // 2` images are queries, and the
        rest form the support set (an odd batch gives the support set one extra image).

        Returns:
            (accuracy, AUC, loss) meter averages, each updated with weight = number of queries.

        Raises:
            ValueError: if `model` is None (e.g. `best_model` before run_train has run).
        """
        if model is None:
            raise ValueError("model is None; run run_train() first to populate best_model")
        model.eval()
        model = model.to(self.device)

        loss_meter = AverageMeter()
        auc_meter = AverageMeter()
        acc_meter = AverageMeter()

        with torch.no_grad():
            for images, class_inds in val_dataloader:
                images, class_inds = images.to(self.device), class_inds.to(self.device)
                shots = images.shape[0]//2
                qimages, simages = images[:shots,:,:], images[shots:,:,:]
                qclass_inds, sclass_inds = class_inds[:shots,:], class_inds[shots:,:]

                predictions = model.update_support_and_classify(self.val_class_labels, simages, sclass_inds, qimages)
                loss = model.loss(predictions, qclass_inds)

                loss_meter.update(loss.item(), shots)

                auc = calculate_auc(predictions, qclass_inds)
                auc_meter.update(auc, shots)

                acc = multilabel_accuracy(predictions, qclass_inds)
                acc_meter.update(acc, shots)
        return acc_meter.average(), auc_meter.average(), loss_meter.average()

In [3]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.sampling import FewShotBatchSampler
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset

# Root directory of the VinDr-CXR PNG images, read by the Dataset builders in this cell.
IMG_PATH = 'datasets/vindr-cxr-png'

# Per-image metadata table; the code in this cell reads its 'image_id' and 'meta_split' columns.
# read_pickle unpickles the file, so only point it at trusted data.
img_info = pd.read_pickle('data/vindr_cxr_split_labels.pkl')
# Image ids for the training query and support pools (these ids are built with the 'train'
# label split below); presumably selected via the stored query-set file — see
# get_query_and_support_ids.
query_image_ids, support_image_ids = get_query_and_support_ids(
    img_info,
    'data/vindr_train_query_set.pkl',
)

def meta_training_loader(dataset, shots, n_ways=None, include_query=False):
    """Wrap `dataset` in a DataLoader whose batches come from a FewShotBatchSampler.

    Args:
        dataset: dataset exposing `get_class_indicators()`, which feeds the sampler.
        shots: shots per class, forwarded to the sampler.
        n_ways: forwarded to the sampler (None keeps the sampler's default behaviour).
        include_query: forwarded to the sampler.
    """
    sampler = FewShotBatchSampler(
        dataset.get_class_indicators(),
        shots,
        n_ways=n_ways,
        include_query=include_query,
    )
    return DataLoader(dataset, batch_sampler=sampler)

def meta_training_dataset(img_info, split):
    """Build a Dataset over the images whose 'meta_split' column equals `split`.

    The label subset comes from `VINDR_SPLIT[split]`; normalization uses the
    'chestmnist' entry of MEAN_STDS.
    """
    in_split = img_info['meta_split'] == split
    split_ids = img_info.loc[in_split, 'image_id'].to_list()
    return Dataset(IMG_PATH, img_info, split_ids, VINDR_CXR_LABELS, VINDR_SPLIT[split], mean_std=MEAN_STDS['chestmnist'])

num_shots = 5  # shots per class requested from every FewShotBatchSampler below


def _train_split_dataset(image_ids):
    """Dataset over `image_ids` using the 'train' label split and 'chestmnist' normalization stats."""
    return Dataset(IMG_PATH, img_info, image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])


# Meta-training: query and support batches are sampled from two separate image-id lists.
train_query_dataset = _train_split_dataset(query_image_ids)
train_query_loader = meta_training_loader(train_query_dataset, num_shots)
train_support_dataset = _train_split_dataset(support_image_ids)
train_support_loader = meta_training_loader(train_support_dataset, num_shots)

# Validation / test: include_query=True presumably makes each batch carry query and support
# images together; MetaTrainer.run_eval uses the first half of a batch as queries and the
# rest as support.
val_dataset = meta_training_dataset(img_info, 'val')
val_loader = meta_training_loader(val_dataset, shots=num_shots, include_query=True)

test_dataset = meta_training_dataset(img_info, 'test')
test_loader = meta_training_loader(test_dataset, shots=num_shots, include_query=True)

In [4]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder
from utils.prototype import class_prototype_inf

device = get_device()

# Both checkpoints are whole pickled nn.Module objects (they are used directly as modules
# below), so torch.load must unpickle arbitrary objects: only load trusted files.
#  - map_location remaps the saved tensors onto `device`, so GPU-saved checkpoints also
#    load on CPU-only hosts.
#  - weights_only=False is required for full-module checkpoints on torch>=2.6, where the
#    default became True (the argument exists since torch 1.13).
encoder = torch.load('imgtext_model_trained.pth', map_location=device, weights_only=False)
encoder.text_model.device = device
attention = torch.load('attention-model-8h4l.pth', map_location=device, weights_only=False)

# Width of each prototype fed to ClsModel's head; must match the last dimension of the
# prototypes produced by the loaded attention model.
EMBED_DIM = 512
model = ClsModel(encoder, attention, EMBED_DIM, class_prototype_inf)
mtrainer = MetaTrainer(model, train_query_dataset.class_labels(), val_dataset.class_labels(), device=device)

In [9]:
# Test-set (accuracy, AUC, loss) of the best-validation checkpoint.
# NOTE: `best_model` is only populated by `mtrainer.run_train(...)` further down —
# on a fresh kernel, run the training cell first.
mtrainer.run_eval(model=mtrainer.best_model, val_dataloader=test_loader)

(79.96017919362869, 0.5280027641417706, 0.6473812199220425)

In [5]:
# Test-set (accuracy, AUC, loss) of the current `model`.
# Previously recorded: (80.06968641114983, 0.5413753165712019, 0.6457181776442179)
mtrainer.run_eval(model=model, val_dataloader=test_loader)

  classes_count = torch.nonzero(label_inds)[:,1].bincount()


(80.15928322548532, 0.5015153589662482, 0.6472647931517624)

In [8]:
# Validation (accuracy, AUC, loss) of the best-validation checkpoint.
# NOTE: requires `mtrainer.run_train(...)` to have run first so `best_model` is populated.
mtrainer.run_eval(model=mtrainer.best_model, val_dataloader=val_loader)

(79.39358600583091, 0.48764045693712416, 0.6482774785586766)

In [6]:
# Validation (accuracy, AUC, loss) of the current `model`.
# Previously recorded: (79.19533527696791, 0.49627472243658033, 0.646583594594683)
mtrainer.run_eval(model=model, val_dataloader=val_loader)

(78.98542274052481, 0.48061929521196595, 0.6488205228533064)

In [9]:
# Fine-tune for 4 epochs at lr=5e-5; the checkpoint with the best validation accuracy
# is kept in `mtrainer.best_model`.
mtrainer.run_train(
    epochs=4,
    query_dataloader=train_query_loader,
    support_dataloader=train_support_loader,
    val_dataloader=val_loader,
    lr=5e-5,
)

Batch 1: loss 0.6797685623168945 | Acc 59.693877551020414
Batch 2: loss 0.6813422739505768 | Acc 58.06122448979592
Batch 3: loss 0.6820537249247233 | Acc 57.244897959183675
Batch 4: loss 0.6819744408130646 | Acc 58.16326530612245
Batch 5: loss 0.6813647031784058 | Acc 60.204081632653065
Batch 6: loss 0.6800132294495901 | Acc 63.26530612244898
Batch 7: loss 0.6796564289501735 | Acc 60.71428571428571
Batch 8: loss 0.6793731153011322 | Acc 60.816326530612244
Batch 9: loss 0.6795809202724032 | Acc 58.57142857142858
Batch 10: loss 0.6796349585056305 | Acc 59.183673469387756
Batch 11: loss 0.6792561411857605 | Acc 61.73469387755102
Batch 12: loss 0.6788489570220312 | Acc 62.142857142857146
Batch 13: loss 0.6786480362598712 | Acc 60.816326530612244
Batch 14: loss 0.6787015966006688 | Acc 59.38775510204082
Batch 15: loss 0.6788699269294739 | Acc 58.36734693877551
Batch 16: loss 0.6788304671645164 | Acc 59.897959183673464
Batch 17: loss 0.6790687967749203 | Acc 56.83673469387755
Batch 18: loss 