In [1]:
import torch
from torch import nn
from utils.contrastive_loss import SupConLoss

def image_text_logits(text_embeddings, prototypes, scale=1):
    """Score every per-label prototype against the text embedding of the same label.

    The (L, D) text embeddings are broadcast over the batch dimension of the
    (N, L, D) prototypes, and each pair is reduced to a dot product, giving
    (N, L) logits multiplied by ``scale``.
    Example shapes: (14, 512) x (140, 14, 512) -> (140, 14).
    """
    broadcast_text = text_embeddings.unsqueeze(0).expand_as(prototypes)
    dot_products = torch.mul(broadcast_text, prototypes).sum(dim=2)
    return dot_products * scale

class LabelImageAttention(nn.Module):
    """Decode one prototype per label by attending label text queries over image features.

    An ``nn.Transformer`` encodes the flattened image feature map (source) and decodes
    the label text embeddings (target). The decoded vectors are L2-normalised and
    trained with ``SupConLoss``.
    """

    def __init__(self, dim_in, n_head, dropout=0.1, num_layers=6, temperature=1):
        super().__init__()
        self.attn = nn.Transformer(dim_in, batch_first=True, nhead=n_head, dropout=dropout, num_decoder_layers=num_layers, num_encoder_layers=num_layers)
        self.con_loss = SupConLoss(temperature=temperature)

    def forward(self, texts, images, label_inds=None):
        """Compute unit-norm label prototypes for each image.

        Args:
            texts: label text embeddings, (L, D).
            images: image feature maps, (N, D, H, W).
            label_inds: optional (N, L) multi-hot indicator; nonzero marks a present label.

        Returns:
            (N, L, D) tensor of prototypes, L2-normalised along the last dimension.
        """
        batch_size = images.shape[0]
        # (L, D) -> (N, L, D): one copy of the label queries per image.
        texts = texts.repeat(batch_size, 1, 1)
        # (N, D, H, W) -> (N, H*W, D): every spatial location becomes a source token.
        images = images.flatten(start_dim=2).permute(0, 2, 1)

        mask = None
        if label_inds is not None:
            # True = absent label, hidden from the decoder self-attention.
            # `== 0` works for bool, int and float indicators alike.
            mask = label_inds == 0
            # If every label of an image is masked, the self-attention softmax can run over
            # only -inf scores and return NaN for that image, which then poisons backward.
            # Such images contribute no positives to `loss`, so leave them unmasked.
            mask = mask & ~mask.all(dim=1, keepdim=True)

        out = self.attn(images, texts, tgt_key_padding_mask=mask)
        # normalize() clamps the norm with a small eps, avoiding 0/0 on zero vectors.
        return nn.functional.normalize(out, dim=-1)

    def loss(self, results, label_inds):
        """Supervised contrastive loss over the prototypes of present labels.

        Args:
            results: (N, L, D) prototypes from ``forward``.
            label_inds: (N, L) indicator; nonzero entries select the prototypes used.

        Returns:
            The ``SupConLoss`` value, with each prototype's label index as its class.
        """
        # nonzero() and boolean indexing both enumerate in row-major order, so
        # classes[k] is the label index of prototypes[k].
        classes = torch.nonzero(label_inds)[:, 1]  # (Np,)
        prototypes = results[label_inds.bool()]  # (Np, D)
        # Add a singleton "views" dimension: (Np, 1, D).
        return self.con_loss(prototypes.unsqueeze(1), classes)


class LabelImagePrototypeModel(nn.Module):
    """Image-text encoder followed by a label/image cross-attention prototype head.

    Args:
        encoder: model called as ``encoder(class_labels, images, False)`` that returns
            ``(text_embedding, image_embedding)``; it must expose ``img_model.proj`` and
            ``text_model.proj`` (see ``unfreeze_encoder``).
        n_head, dim_in, dropout, num_layers, temperature: forwarded to
            ``LabelImageAttention``.
    """

    def __init__(self, encoder, n_head, dim_in=512, dropout=0.1, num_layers=4, temperature=1):
        super().__init__()
        self.encoder = encoder
        self.attention = LabelImageAttention(dim_in, n_head, dropout, num_layers, temperature)

    def freeze_encoder(self):
        """Disable gradients for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        """Make the encoder's image and text ``proj`` components trainable.

        Only the two projections are touched; other encoder parameters keep whatever
        ``requires_grad`` state they already have.
        """
        # requires_grad_() exists on both nn.Parameter and nn.Module, so this works
        # whether `proj` is a bare projection matrix or a layer. Assigning
        # `.requires_grad = True` to a Module only creates a plain attribute and
        # leaves its parameters frozen.
        self.encoder.img_model.proj.requires_grad_(True)
        self.encoder.text_model.proj.requires_grad_(True)

    def forward(self, class_labels, images, label_inds):
        """Return ``(text_embedding, image_embedding, prototypes)`` for a batch."""
        # NOTE(review): the meaning of the third positional encoder argument (False)
        # is defined by the encoder implementation — confirm.
        text_embedding, image_embedding = self.encoder(class_labels, images, False)
        prototypes = self.attention(text_embedding, image_embedding, label_inds)
        return text_embedding, image_embedding, prototypes

    def attention_loss(self, prototypes, label_inds):
        """Delegate to the attention head's contrastive loss."""
        return self.attention.loss(prototypes, label_inds)


In [2]:
import torch
import copy
from utils.metrics import AverageMeter, calculate_auc, multilabel_accuracy

class Trainer:
    """Fit a prototype model and keep the checkpoint with the lowest validation loss.

    The model is called as ``model(class_labels, images, class_inds)`` and must expose
    ``freeze_encoder``, ``unfreeze_encoder``, ``attention_loss`` and an ``encoder``
    with ``get_logit_scale`` and ``contrastive_logit_loss``.
    """

    def __init__(self, model, class_labels, device='cpu'):
        self.model = model
        self.device = device
        self.class_labels = class_labels
        # Deep copy of the epoch with the lowest validation loss; set by run_train.
        self.best_model = None

    def _batch_loss(self, model, text_embeddings, prototypes, class_inds, full_training):
        """Return ``(loss, logits_per_image)`` for one batch.

        The attention (prototype) loss is always included. With ``full_training``, the
        encoder's contrastive logit loss is added with weight 0.5. Logits are always
        returned because the evaluation metrics need them.
        """
        logits_per_image = image_text_logits(text_embeddings, prototypes, model.encoder.get_logit_scale())
        loss = model.attention_loss(prototypes, class_inds)
        if full_training:
            # Out-of-place add: never mutate a tensor that autograd may still need.
            loss = loss + 0.5 * model.encoder.contrastive_logit_loss(logits_per_image.t(), logits_per_image, class_inds)
        return loss, logits_per_image

    def run_train(self, epochs, dataloader, val_dataloader, lr=1e-4, full_training=False):
        """Train for ``epochs`` epochs, validating after each one.

        Args:
            epochs: number of passes over ``dataloader``.
            dataloader: yields ``(images, class_inds)`` training batches.
            val_dataloader: batches used for validation and model selection.
            lr: Adam learning rate.
            full_training: if True, call ``model.unfreeze_encoder()`` and add the
                contrastive term to the loss; otherwise call ``model.freeze_encoder()``.

        Sets ``self.model`` to the final model and ``self.best_model`` to a deep copy
        of the epoch with the lowest validation loss.
        """
        model = self.model.to(self.device)
        if full_training:
            model.unfreeze_encoder()
        else:
            model.freeze_encoder()
        # Create the optimizer after (un)freezing so it only tracks trainable parameters.
        optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)

        best_epoch = None
        best_loss = None
        for epoch in range(epochs):
            model.train()
            loss_meter = AverageMeter()
            for i, (images, class_inds) in enumerate(dataloader):
                images, class_inds = images.to(self.device), class_inds.to(self.device)
                optimizer.zero_grad()

                text_embeddings, _, prototypes = model(self.class_labels, images, class_inds)
                loss, _ = self._batch_loss(model, text_embeddings, prototypes, class_inds, full_training)
                loss.backward()
                optimizer.step()

                loss_meter.update(loss.item(), len(class_inds))
                print(f"Batch {i+1}: loss {loss_meter.average()}")

            print(f"Epoch {epoch+1}: Training loss {loss_meter.average()}")

            # Validate against the same objective that is being optimised.
            val_acc, val_auc, val_loss = self.run_eval(model, val_dataloader, full_training=full_training)
            print(f"Epoch {epoch+1}: Validation loss {val_loss} | Accuracy {val_acc} | AUC {val_auc}")

            if best_loss is None or val_loss < best_loss:
                best_loss = val_loss
                self.best_model = copy.deepcopy(model)
                best_epoch = epoch
        self.model = model
        if best_epoch is not None:  # epochs == 0 leaves nothing to report
            print('Best epoch: ', best_epoch+1)

    def run_eval(self, model, dataloader, full_training=False):
        """Evaluate ``model`` on ``dataloader`` without tracking gradients.

        Args:
            model: model to evaluate; moved to ``self.device``.
            dataloader: yields ``(images, class_inds)`` batches.
            full_training: include the contrastive term in the reported loss.

        Returns:
            ``(accuracy, auc, loss)``, each an ``AverageMeter`` average updated once per
            batch with the batch size as weight.
        """
        model.eval()
        model = model.to(self.device)

        loss_meter = AverageMeter()
        auc_meter = AverageMeter()
        acc_meter = AverageMeter()
        with torch.no_grad():
            for images, class_inds in dataloader:
                images, class_inds = images.to(self.device), class_inds.to(self.device)
                text_embeddings, _, prototypes = model(self.class_labels, images, class_inds)

                loss, logits_per_image = self._batch_loss(model, text_embeddings, prototypes, class_inds, full_training)

                batch_size = len(class_inds)
                loss_meter.update(loss.item(), batch_size)
                auc_meter.update(calculate_auc(logits_per_image, class_inds), batch_size)
                acc_meter.update(multilabel_accuracy(logits_per_image, class_inds), batch_size)

        return acc_meter.average(), auc_meter.average(), loss_meter.average()

In [3]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from utils.device import get_device
from utils.data import get_query_and_support_ids
from models.embedding.dataset import Dataset

# Seed the global torch RNG so DataLoader shuffling and the attention head's
# weight initialisation repeat across runs.
SEED = 0
torch.manual_seed(SEED)

img_info = pd.read_pickle('data/vindr_cxr_split_labels.pkl')
query_image_ids, support_image_ids = get_query_and_support_ids(img_info, 'data/vindr_train_query_set.pkl')

IMG_PATH = 'datasets/vindr-cxr-png'
# NOTE(review): presumably 10 images per label x 14 VinDr-CXR labels — confirm.
batch_size = 10*14
query_dataset = Dataset(IMG_PATH, img_info, query_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(IMG_PATH, img_info, support_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, shuffle=True)

# Transformer width of the attention head; presumably the encoder's projection size.
PROJ_SIZE = 512
device = get_device()

# The checkpoint is a fully pickled module, not a state_dict. torch>=2.6 defaults to
# weights_only=True and would refuse to load it, so opt out explicitly. Unpickling can
# execute arbitrary code, so only load checkpoints you trust.
# map_location lets a checkpoint saved on another device load onto `device`.
encoder = torch.load('imgtext_model_trained1-newlib.pth', map_location=device, weights_only=False)
encoder.text_model.device = device
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE)  # 8 attention heads
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [None]:
# Anomaly detection slows down every backward pass; turn it on only while tracking
# down NaN/Inf gradients. Using it as a context manager keeps it from staying
# enabled for the rest of the kernel session.
DETECT_ANOMALY = False
EPOCHS = 5
LEARNING_RATE = 1e-4

# Train on the support set and select the best epoch on the query set.
with torch.autograd.set_detect_anomaly(DETECT_ANOMALY):
    mtrainer.run_train(EPOCHS, support_loader, query_loader, lr=LEARNING_RATE, full_training=False)

In [10]:
def run_eval(model, dataloader, class_labels, device, full_training=False):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: prototype model called as ``model(class_labels, images, class_inds)``.
        dataloader: yields ``(images, class_inds)`` batches.
        class_labels: label texts forwarded to the model.
        device: device the model and batches are moved to.
        full_training: include the 0.5-weighted contrastive term in the reported loss.

    Returns:
        ``(accuracy, auc, loss)``, each an ``AverageMeter`` average updated once per
        batch with the batch size as weight.
    """
    model.eval()
    model = model.to(device)

    loss_meter = AverageMeter()
    auc_meter = AverageMeter()
    acc_meter = AverageMeter()
    with torch.no_grad():
        for images, class_inds in dataloader:
            images, class_inds = images.to(device), class_inds.to(device)
            text_embeddings, _, prototypes = model(class_labels, images, class_inds)

            # The AUC/accuracy metrics need the logits, so compute them regardless of
            # full_training (previously they were undefined when it was False).
            logits_per_image = image_text_logits(text_embeddings, prototypes, model.encoder.get_logit_scale())

            loss = model.attention_loss(prototypes, class_inds).item()
            if full_training:
                loss += 0.5*model.encoder.contrastive_logit_loss(logits_per_image.t(), logits_per_image, class_inds).item()

            batch_size = len(class_inds)
            loss_meter.update(loss, batch_size)
            auc_meter.update(calculate_auc(logits_per_image, class_inds), batch_size)
            acc_meter.update(multilabel_accuracy(logits_per_image, class_inds), batch_size)
    return acc_meter.average(), auc_meter.average(), loss_meter.average()

In [None]:
# Evaluate the final-epoch model with the standalone helper; returns (accuracy, AUC, loss).
run_eval(
    model=mtrainer.model,
    dataloader=query_loader,
    class_labels=mtrainer.class_labels,
    device=device,
)

In [None]:
# Evaluate the checkpoint that had the lowest validation loss during run_train.
mtrainer.run_eval(
    model=mtrainer.best_model,
    dataloader=query_loader,
)