In [None]:
import torch
from torch import nn

from utils.prototype import class_variance

class ClsModel(nn.Module):
    """Scores query images against per-class prototypes built from a support set.

    Typical use: call ``set_class_prototype_details`` (or
    ``update_support_and_classify``) to cache the class-label embeddings, the
    class prototypes and their variance, then call ``forward`` on query images
    to obtain one logit per (query image, class) pair.

    Shape letters used in comments: N = number of query images, L = number of
    classes, D = ``embed_dim``. They follow the shape note in ``forward`` and
    are not enforced by this class.
    """

    def __init__(self, imgtxt_encoder, attn_model, embed_dim, class_prototype_aggregator, fc_hidden_size=16) -> None:
        """Wire up the sub-models and the classification head.

        Args:
            imgtxt_encoder: Called as ``encoder(class_labels, support_images,
                pool=False)`` and expected to return ``(text_embeddings,
                image_embeddings)``; must also expose
                ``embed_image(images, pool=False)``.
            attn_model: Called as ``attn_model(text_embeddings,
                image_embeddings)`` for queries and with an extra
                ``support_label_inds`` argument for the support set.
            embed_dim: Size D of each embedding; the head consumes ``3 * D``
                features (prototype, prototype variance, query prototype).
            class_prototype_aggregator: Called as
                ``aggregator(support_prototypes.permute(1, 0, 2),
                support_label_inds)`` to produce the class prototypes.
            fc_hidden_size: Width of the hidden layer of the head.
        """
        super().__init__()
        self.encoder = imgtxt_encoder
        self.attn_model = attn_model
        self.class_prototype_aggregator = class_prototype_aggregator
        # NOTE(review): there is no activation between the two linear layers,
        # so the head is equivalent to a single affine map. Left unchanged
        # because inserting a module would shift the state_dict keys.
        self.cls = nn.Sequential(
            nn.Linear(embed_dim*3, fc_hidden_size),
            nn.Linear(fc_hidden_size, 1)
        )
        self.cls_loss_criterion = nn.BCEWithLogitsLoss()

        # Support-set state, populated by set_class_prototype_details().
        # Initialised to None so forward() can detect a missing setup step.
        self.class_label_embeddings = None
        self.class_prototypes = None
        self.class_prototypes_var = None

    def set_class_prototype_details(self, class_labels, support_images, support_label_inds):
        """Encode the support set and cache everything ``forward`` needs.

        Stores ``class_label_embeddings``, ``class_prototypes`` and
        ``class_prototypes_var`` on the module; returns nothing.

        Args:
            class_labels: Class labels, passed to the encoder as-is.
            support_images: Support images, passed to the encoder as-is.
            support_label_inds: Label information for the support images,
                forwarded to the attention model, the prototype aggregator
                and ``class_variance``.
        """
        text_embeddings, image_embeddings = self.encoder(class_labels, support_images, pool=False)
        self.class_label_embeddings = text_embeddings
        support_prototypes = self.attn_model(text_embeddings, image_embeddings, support_label_inds)
        # Both reducers receive the prototypes with their first two axes
        # swapped; compute the permuted view once and share it.
        permuted_prototypes = support_prototypes.permute(1, 0, 2)
        self.class_prototypes = self.class_prototype_aggregator(permuted_prototypes, support_label_inds)
        self.class_prototypes_var = class_variance(permuted_prototypes, support_label_inds)

    def update_support_and_classify(self, class_labels, support_images, support_label_inds, query_images):
        """Refresh the cached support state, then classify ``query_images``.

        Returns:
            The result of ``forward(query_images)``.
        """
        self.set_class_prototype_details(class_labels, support_images, support_label_inds)
        return self.forward(query_images)

    def forward(self, query_images):
        """Return one logit per (query image, class) pair.

        Args:
            query_images: Query images, passed to ``encoder.embed_image``.

        Returns:
            Tensor of logits with the trailing singleton dimension removed
            (NxLx1 -> NxL).

        Raises:
            RuntimeError: If the support state has not been set yet.
        """
        if (
            self.class_label_embeddings is None
            or self.class_prototypes is None
            or self.class_prototypes_var is None
        ):
            raise RuntimeError(
                "Class prototypes are not set; call set_class_prototype_details() "
                "or update_support_and_classify() before forward()."
            )

        query_image_embeddings = self.encoder.embed_image(query_images, pool=False)
        query_prototypes = self.attn_model(self.class_label_embeddings, query_image_embeddings)

        # Prototypes: LxD (to repeat N times), variance: LxD (to repeat N times), query class prototype: NxLxD
        num_queries = query_prototypes.shape[0]
        class_prototypes = self.class_prototypes.expand(num_queries, -1, -1)
        class_prototypes_var = self.class_prototypes_var.expand(num_queries, -1, -1)
        # Concatenate the per-class query prototypes (not the raw, un-pooled
        # image embeddings) so all three operands share the NxLxD layout.
        features = torch.cat((class_prototypes, class_prototypes_var, query_prototypes), dim=2)
        out = self.cls(features)
        return out.squeeze(2) # NxLx1 -> NxL

    def loss(self, predictions, query_label_inds):
        """Binary cross-entropy-with-logits loss between predictions and targets.

        Args:
            predictions: Logits as returned by ``forward``.
            query_label_inds: Targets with the same shape as ``predictions``.
                They are cast to the dtype of ``predictions`` because
                ``BCEWithLogitsLoss`` needs floating-point targets; this is a
                no-op when they already match.

        Returns:
            The loss computed by ``nn.BCEWithLogitsLoss`` (default reduction).
        """
        targets = query_label_inds.to(dtype=predictions.dtype)
        return self.cls_loss_criterion(predictions, targets)

In [None]:
class MetaTrainer:
    """Placeholder for the meta-training loop; holds no state yet."""

    def __init__(self):
        """Create an empty trainer; takes no arguments and sets no attributes."""