# Preamble: Install and Import Packages

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Resize
from torchvision.io import read_image, ImageReadMode
from multilingual_clip import Config_MCLIP
import open_clip
import json
import pandas as pd
import random
from pathlib import Path
import numpy as np
import transformers as hf
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from PIL import Image
import os
import time
import math
import matplotlib.pyplot as plt
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
print(hf.__version__)
torch.autograd.set_detect_anomaly(True)

4.39.2


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x21a599329f0>

# Initialise the Configuration and Random Seeds

In [4]:
_text_model_config = {}

_image_model_config = {
    "attention_probs_dropout_prob": 0.0,
    "encoder_stride": 16,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 768,
    "image_size": 224,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-12,
    "num_attention_heads": 12,
    "num_channels": 3,
    "num_hidden_layers": 0,
    "patch_size": 16,
    "qkv_bias": True,
}

# Dual encoder/Concat
tokeniser_model_id = 'xlm-roberta-base'
text_model_id = 'xlm-roberta-base'
image_model_id = 'google/vit-base-patch16-224-in21k'

# CLIP
# multimodal_model_id = 'openai/clip-vit-base-patch32'

# M-CLIP
# tokeniser_model_id = 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus'
# text_model_id = 'M-CLIP/XLM-Roberta-Large-Vit-B-16Plus'
# image_model_id = 'ViT-B-16-plus-240'
image_training_id = 'laion400m_e32'

# ViLT
multimodal_model_id = 'dandelin/vilt-b32-mlm'


class CFG:
    use_multimodal = True
    use_dualencoder = False
    split_lang = False
    save_models = False
    use_lstm = False
    use_attn = False
    use_mask_split = False
    use_modal_attn = True
    is_mclip = False
    init_weights = False
    tokeniser_model_id = tokeniser_model_id
    text_model_id = text_model_id
    image_model_id = image_model_id
    multimodal_model_id = multimodal_model_id
    image_training_id = image_training_id
    text_model_config = hf.AutoConfig.from_pretrained(text_model_id) if not 'M-CLIP' in text_model_id else None
    image_model_config = hf.AutoConfig.from_pretrained(image_model_id) if not 'M-CLIP' in text_model_id else None
    multimodal_model_config = hf.AutoConfig.from_pretrained(multimodal_model_id, text_config=_text_model_config, vision_config=_image_model_config)
    images_base_path = Path(f'EXIST 2024 Lab/EXIST 2024 Memes Dataset/training/memes')
    images_base_path_test = Path('EXIST 2024 Lab/EXIST 2024 Memes Dataset/test/memes')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    debug = True
    print_freq = 300
    apex = True # for faster training
    epochs = 10
    learning_rate = 2e-4  # for adam optimizer
    eps = 1e-6
    betas = (0.9, 0.999)  # for adam optimizer
    batch_size = 64
    max_len = 512
    weight_decay = 0.01  # for adam optimizer regulaization parameter
    gradient_accumulation_steps = 1
    max_grad_norm = 1000
    seed = 42
    train = True
    num_class = 2  # Number of class in your dataset
    mlp_hidden_size = 256
    mlp_hidden_layers = 0
    mlp_dropout = 0.1
    mlp_grad_clip = 1.0
    mlp_init_range = 0.2
    mlp_attn_dim = 256

In [5]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(CFG.seed)

In [6]:
class MultilingualCLIP(hf.PreTrainedModel):
    config_class = Config_MCLIP.MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.transformer = hf.AutoModel.from_pretrained(config.modelBase, cache_dir=kwargs.get("cache_dir"))
        self.LinearTransformation = torch.nn.Linear(in_features=config.transformerDimensions,
                                                    out_features=config.numDims)

    def forward(self, tokens, mask):
        embs = self.transformer(tokens, attention_mask=mask)[0]
        embs = (embs * mask.unsqueeze(2)).sum(dim=1) / mask.sum(dim=1)[:, None]
        return self.LinearTransformation(embs)

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
        model.load_state_dict(state_dict)
        return model, [], [], []

In [7]:
task4_ea = ['211702', '211424', '110981', '110664']

# Preprocess the Dataset

In [8]:
with open('EXIST 2024 Lab/EXIST 2024 Memes Dataset/training/EXIST2024_training.json', 'r', encoding='utf-8') as fp:
    annotations = json.load(fp)
df = pd.DataFrame.from_dict(annotations).T
df = df[df['id_EXIST'].isin(task4_ea)]
print(df.shape)
df.head()

(4, 16)


Unnamed: 0,id_EXIST,lang,text,meme,path_memes,number_annotators,annotators,gender_annotators,age_annotators,ethnicities_annotators,study_levels_annotators,countries_annotators,labels_task4,labels_task5,labels_task6,split
110664,110664,es,EL FISICO ATRAE PERO LA ENFERMERA ENAMORA,110664.jpeg,memes/110664.jpeg,6,"[Annotator_145, Annotator_146, Annotator_147, ...","[F, F, F, M, M, M]","[18-22, 23-45, 46+, 46+, 18-22, 23-45]","[White or Caucasian, White or Caucasian, White...","[High school degree or equivalent, Bachelor’s ...","[Spain, Italy, Portugal, Spain, Portugal, Mexico]","[YES, YES, YES, YES, YES, YES]","[DIRECT, DIRECT, DIRECT, DIRECT, DIRECT, UNKNOWN]","[[SEXUAL-VIOLENCE], [STEREOTYPING-DOMINANCE, O...",TRAIN-MEME_ES
110981,110981,es,QUIERE IGUALDAD DE OPORTUNIDADES EN LA EDUCACI...,110981.jpeg,memes/110981.jpeg,6,"[Annotator_217, Annotator_218, Annotator_219, ...","[F, F, F, M, M, M]","[18-22, 23-45, 46+, 46+, 18-22, 23-45]","[Hispano or Latino, White or Caucasian, White ...","[Bachelor’s degree, Master’s degree, Bachelor’...","[Peru, Portugal, Portugal, Chile, Germany, Spain]","[NO, NO, NO, NO, NO, NO]","[-, -, -, -, -, -]","[[-], [-], [-], [-], [-], [-]]",TRAIN-MEME_ES
211424,211424,en,SEXUAL ORIENTATION? ORIENTED TO WHEREVER THE H...,211424.jpeg,memes/211424.jpeg,6,"[Annotator_752, Annotator_753, Annotator_754, ...","[F, F, F, M, M, M]","[18-22, 23-45, 46+, 18-22, 23-45, 46+]","[White or Caucasian, White or Caucasian, White...","[Master’s degree, High school degree or equiva...","[Portugal, Slovenia, Australia, Germany, Austr...","[YES, YES, YES, YES, YES, YES]","[DIRECT, DIRECT, DIRECT, DIRECT, DIRECT, DIRECT]","[[OBJECTIFICATION, SEXUAL-VIOLENCE], [OBJECTIF...",TRAIN-MEME_EN
211702,211702,en,me laying in bed at night thinking about how m...,211702.jpeg,memes/211702.jpeg,6,"[Annotator_817, Annotator_818, Annotator_819, ...","[F, F, F, M, M, M]","[18-22, 23-45, 46+, 18-22, 23-45, 46+]","[Asian, White or Caucasian, White or Caucasian...","[High school degree or equivalent, Master’s de...","[China, Hungary, United States, United Kingdom...","[NO, NO, NO, NO, NO, NO]","[-, -, -, -, -, -]","[[-], [-], [-], [-], [-], [-]]",TRAIN-MEME_EN


In [9]:
mini_df = df[['id_EXIST', 'meme', 'text', 'lang']].reset_index(drop=True)
mini_df['id_EXIST'] = pd.to_numeric(mini_df['id_EXIST'])
mini_df.head()

Unnamed: 0,id_EXIST,meme,text,lang
0,110664,110664.jpeg,EL FISICO ATRAE PERO LA ENFERMERA ENAMORA,es
1,110981,110981.jpeg,QUIERE IGUALDAD DE OPORTUNIDADES EN LA EDUCACI...,es
2,211424,211424.jpeg,SEXUAL ORIENTATION? ORIENTED TO WHEREVER THE H...,en
3,211702,211702.jpeg,me laying in bed at night thinking about how m...,en


In [10]:
task4_gold_path = Path('EXIST 2024 Lab/evaluation/golds/EXIST2024_training_task4_gold_hard.json')
task5_gold_path = Path('EXIST 2024 Lab/evaluation/golds/EXIST2024_training_task5_gold_hard.json')
task6_gold_path = Path('EXIST 2024 Lab/evaluation/golds/EXIST2024_training_task6_gold_hard.json')
task4_gold = pd.read_json(task4_gold_path)

choices = ['YES', 'NO']
mini_df = pd.merge(mini_df, task4_gold, left_on='id_EXIST', right_on='id', how='left').drop(columns=['id', 'test_case']).rename(columns={'value': 'label_task4'})
mini_df['label_task4'] = mini_df['label_task4'].apply(lambda x: np.random.choice(choices) if pd.isna(x) else x)
mini_df['label_task4'] = pd.to_numeric(mini_df['label_task4'].map({'YES': 1, 'NO': 0}))
print(len(mini_df))
mini_df.head()

4


Unnamed: 0,id_EXIST,meme,text,lang,label_task4
0,110664,110664.jpeg,EL FISICO ATRAE PERO LA ENFERMERA ENAMORA,es,1
1,110981,110981.jpeg,QUIERE IGUALDAD DE OPORTUNIDADES EN LA EDUCACI...,es,0
2,211424,211424.jpeg,SEXUAL ORIENTATION? ORIENTED TO WHEREVER THE H...,en,1
3,211702,211702.jpeg,me laying in bed at night thinking about how m...,en,0


# Initialise the Processors/Tokenisers/Models

In [11]:
if CFG.is_mclip:
    tokenizer = hf.AutoTokenizer.from_pretrained(CFG.tokeniser_model_id)
    text_model = MultilingualCLIP.from_pretrained(CFG.text_model_id).to(CFG.device)
    image_model, _, image_processor = open_clip.create_model_and_transforms(CFG.image_model_id, pretrained=CFG.image_training_id)
    image_model = image_model.to(CFG.device)
elif CFG.use_multimodal:
    mm_processor = hf.AutoProcessor.from_pretrained(CFG.multimodal_model_id)
    mm_model = hf.AutoModel.from_pretrained(CFG.multimodal_model_id).to(CFG.device)
elif CFG.use_dualencoder:
    tokenizer = hf.AutoTokenizer.from_pretrained(CFG.tokeniser_model_id, padding=True, truncation=True)
    processor = hf.AutoImageProcessor.from_pretrained(CFG.image_model_id)
    de_processor = hf.VisionTextDualEncoderProcessor(image_processor=processor, tokenizer=tokenizer)
    text_model = hf.AutoModel.from_pretrained(CFG.text_model_id).to(CFG.device)
    image_model = hf.AutoModel.from_pretrained(CFG.image_model_id).to(CFG.device)
    de_model = hf.VisionTextDualEncoderModel(vision_model=image_model, text_model=text_model)
else:
    tokenizer = hf.AutoTokenizer.from_pretrained(CFG.tokeniser_model_id)
    text_model = hf.AutoModel.from_pretrained(CFG.text_model_id).to(CFG.device)
    # Adding a config to the image_model gets rid of lots of pretrained weights
    image_model = hf.AutoModel.from_pretrained(CFG.image_model_id).to(CFG.device)

# Custom Dataset Definition

In [12]:
class ExistDataset(Dataset):
    def __init__(self, features, img_dir, labels=None, test=False, img_transform=None, caption_transform=None, target_transform=None):
        self.features = features
        self.labels = labels
        self.img_dir = img_dir
        self.test = test
        self.img_transform = img_transform
        self.caption_transform = caption_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        img_path = str(self.img_dir.joinpath(self.features['meme'].iloc[idx]))
        if CFG.is_mclip:
            image = Image.open(img_path)
        else:
            image = read_image(img_path, mode=ImageReadMode.RGB).to(device=CFG.device)
        caption = self.features['text'].iloc[idx]
        
        if not self.test:
            label = self.labels.iloc[idx]
        else:
            identity = self.features['id_EXIST'].iloc[idx]
        
        if self.img_transform:
            image = self.img_transform(image)
        if self.caption_transform:
            caption = self.caption_transform(caption)
        if not self.test and self.target_transform:
            label = self.target_transform(label)
            
        if CFG.split_lang:
            caption = f'Language: {self.features["lang"].iloc[idx]} - {caption}'
            
        if CFG.is_mclip:
            processed = tokenizer(caption, padding=True, return_tensors='pt')
            seq = processed['input_ids']
            mask = processed['attention_mask']
            image = image_processor(image)
        elif CFG.use_multimodal:
            processed = mm_processor(text=caption, images=image, return_tensors="pt", padding=True, truncation=True)
            seq = processed['input_ids']
            mask = processed['attention_mask']
            image = processed['pixel_values']
        elif CFG.use_dualencoder:
            processed = de_processor(text=caption, images=image, return_tensors="pt")
            seq = processed['input_ids']
            mask = processed['attention_mask']
            image = processed['pixel_values']
        else:
            processed = tokenizer.encode_plus(
                caption,
                padding='longest',
                truncation=True,
                return_tensors='pt'
            )
            seq = processed['input_ids']
            mask = processed['attention_mask']
        
        if not self.test:
            label = torch.tensor([label]).long()
            return image, seq, mask, label
        
        return identity, image, seq, mask

In [13]:
class Collator(object):
    def __init__(self, test=False):
        self.test = test
    def __call__(self, batch):
        if not self.test:
            images, seqs, masks, labels = zip(*batch)
            labels = torch.stack(labels)
        else:
            ids, images, seqs, masks = zip(*batch)

        seqs = [seq.squeeze(dim=0) for seq in seqs]
        masks = [mask.squeeze(dim=0) for mask in masks]
        images = [image.squeeze(dim=0) for image in images]

        seqs = nn.utils.rnn.pad_sequence(seqs, batch_first=True)
        masks = nn.utils.rnn.pad_sequence(masks, batch_first=True)

        images = torch.stack(images)
        
        if not self.test:
            return images, seqs, masks, labels
        
        return ids, images, seqs, masks

In [14]:
resizer = Resize((224, 224), antialias=True)

def resize_images(img_tensor):
    return resizer(img_tensor)

# Dataset Initialisation

In [15]:
val_dataset = ExistDataset(mini_df, CFG.images_base_path, labels=mini_df['label_task4'], img_transform=resize_images, test=True)
len(val_dataset)

4

# Model Architecture

In [16]:
class ConcatArch(nn.Module):
    def __init__(self, hidden_size, hidden_layers, dropout, num_classes, use_multimodal=False, use_dualencoder=False, is_mclip=False):
        super().__init__()
        
        self.hidden_size = hidden_size
        self.hidden_layers = hidden_layers
        self.use_multimodal = use_multimodal
        self.use_dualencoder = use_dualencoder
        self.is_mclip = is_mclip
        self.is_vilt = 'ViltForMaskedLM' in CFG.multimodal_model_config.architectures
        
        if self.is_mclip:
            self.text_model = text_model
            self.image_model = image_model
        elif self.use_multimodal:
            self.mm_model = mm_model
        elif self.use_dualencoder:
            self.de_model = de_model
        else:
            self.text_model = text_model
            self.image_model = image_model
        
        if self.is_mclip:
            self.fc1 = nn.Linear(1280, self.hidden_size)
        elif self.use_multimodal:
            if self.is_vilt and CFG.use_lstm:
                out_channels = CFG.mlp_hidden_size + CFG.multimodal_model_config.hidden_size
                self.lstm = nn.LSTM(CFG.multimodal_model_config.hidden_size, CFG.mlp_hidden_size, batch_first=True)
            elif self.is_vilt and CFG.use_mask_split:
                out_channels = CFG.multimodal_model_config.hidden_size * 3
            elif self.is_vilt and CFG.use_attn:
                self.attn = nn.Sequential(
                    nn.Linear(CFG.multimodal_model_config.hidden_size, CFG.mlp_attn_dim),
                    nn.Tanh(),
                    nn.Linear(CFG.mlp_attn_dim, 1),
                    nn.Softmax(dim=1)
                )
            elif self.is_vilt and CFG.use_modal_attn:
                self.attn1 = nn.Sequential(
                    nn.Linear(CFG.multimodal_model_config.hidden_size, CFG.mlp_attn_dim),
                    nn.Tanh(),
                    nn.Linear(CFG.mlp_attn_dim, 1),
                    nn.Softmax(dim=1)
                )
                self.attn2 = nn.Sequential(
                    nn.Linear(CFG.multimodal_model_config.hidden_size, CFG.mlp_attn_dim),
                    nn.Tanh(),
                    nn.Linear(CFG.mlp_attn_dim, 1),
                    nn.Softmax(dim=1)
                )
                out_channels = CFG.multimodal_model_config.hidden_size * 2
            elif self.is_vilt:
                out_channels = CFG.multimodal_model_config.hidden_size
            else:
                out_channels = 2 * CFG.multimodal_model_config.projection_dim
            self.fc1 = nn.Linear(out_channels, self.hidden_size)
        elif self.use_dualencoder:
            self.fc1 = nn.Linear(2 * 512, self.hidden_size)
        else:
            self.fc1 = nn.Linear(CFG.text_model_config.hidden_size + CFG.image_model_config.hidden_size, self.hidden_size)
        self.hiddens = nn.ModuleList([nn.Linear(self.hidden_size, self.hidden_size) for _ in range(self.hidden_layers)])
        self.fc2 = nn.Linear(self.hidden_size, num_classes)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        
        if CFG.init_weights:
            self._init_weights(self.fc1)
            for hidden in self.hiddens:
                self._init_weights(hidden)
            self._init_weights(self.fc2)

    def forward(self, tokens, mask, image):
        text_attentions = None
        img_attentions = None
        
        if self.is_mclip:
            emb_text = self.text_model.forward(tokens, mask)
            emb_img = self.image_model.encode_image(image)
            x = torch.cat([emb_text, emb_img], dim=1)
        elif self.use_multimodal:
            mm_output = self.mm_model(input_ids=tokens, attention_mask=mask, pixel_values=image, output_hidden_states=True)
            cats = [mm_output.pooler_output] if self.is_vilt else [mm_output.text_embeds, mm_output.image_embeds]
            
            if self.is_vilt and CFG.use_lstm:
                # First hidden state is apparently the embedding output
                # https://discuss.huggingface.co/t/hidden-states-embedding-tensors/3549/
                layerwise_cls = torch.stack([h[:, 0, :] for h in mm_output.hidden_states[1:]], dim=1)
                _, (h, _) = self.lstm(layerwise_cls)
                h = h.squeeze(dim=0)
                cats.append(h)

            if self.is_vilt and CFG.use_mask_split:
                last_h = mm_output.last_hidden_state
                mask_len = mask.shape[1]
                mean_pooled_text = torch.mean(last_h[:, :mask_len, :], dim=1)
                mean_pooled_img = torch.mean(last_h[:, mask_len:, :], dim=1)
                cats += [mean_pooled_text, mean_pooled_img]

            if self.is_vilt and CFG.use_attn:
                last_h = mm_output.last_hidden_state
                attentions = self.attn(last_h)
                x = torch.sum(attentions * last_h, dim=1)

                cls = last_h[:, 0, :]
                x += cls
            elif self.is_vilt and CFG.use_modal_attn:
                last_h = mm_output.last_hidden_state
                mask_len = mask.shape[1]
                text_split = last_h[:, :mask_len, :]
                img_split = last_h[:, mask_len:, :]
                text_attentions = self.attn1(text_split)
                img_attentions = self.attn2(img_split)
                x1 = torch.sum(text_attentions * text_split, dim=1)
                x2 = torch.sum(img_attentions * img_split, dim=1)

                x = torch.cat([x1, x2], dim=1)

                cls = last_h[:, 0, :]
                cls = torch.cat([cls, cls], dim=1)
                x += cls
            else:
                x = torch.cat(cats, dim=1)
        elif self.use_dualencoder:
            de_output = self.de_model(input_ids=tokens, attention_mask=mask, pixel_values=image)
            x = torch.cat([de_output.text_embeds, de_output.image_embeds], dim=1)
        else:
            cls_text = self.text_model(tokens, attention_mask=mask).last_hidden_state[:, 0, :]
            cls_img = self.image_model(image).last_hidden_state[:, 0, :]
            x = torch.cat([cls_text, cls_img], dim=1)

        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        for hidden in self.hiddens:
            x = hidden(x)
            x = self.activation(x)
            x = self.dropout(x)
        x = self.fc2(x)
        
        output = x

        if CFG.use_modal_attn:
            return output.float(), text_attentions, img_attentions
        
        return output.float()
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=CFG.mlp_init_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=CFG.mlp_init_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

# Utility Functions

In [17]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

In [18]:
def get_score(y_trues, y_preds):
    macro_f1 = f1_score(y_trues, y_preds, average='macro')
    return macro_f1

In [30]:
def test_loop(model, test_dataloader):
    all_soft = []
    all_hard = []
    all_ids = []
    all_tokens = []
    all_attentions = []
    
    model.eval()
    
    for identity, image, seq, mask in tqdm(test_dataloader):
        test_image = image.to(device=CFG.device)
        test_seq = seq.to(device=CFG.device)
        test_mask = mask.to(device=CFG.device)

        truncated_seq = [s[m.bool()] for s, m in zip(test_seq, test_mask)]

        with torch.no_grad():
            if CFG.use_modal_attn:
                output, text_attentions, img_attentions = model(test_seq, test_mask, test_image)
            else:
                output = model(test_seq, test_mask, test_image)
        
        soft = nn.functional.softmax(output, dim=1)
        hard = output.argmax(dim=1)
        
        all_ids += list(identity)
        all_soft.append(soft)
        all_hard.append(hard)

        if CFG.use_modal_attn:
            for ts in truncated_seq:
                tokens = mm_processor.tokenizer.convert_ids_to_tokens(ts.cpu().numpy())
                all_tokens.append(tokens)

            all_attentions.extend(text_attentions.squeeze().cpu().numpy())
        
    all_soft = torch.cat(all_soft, dim=0)
    all_hard = torch.cat(all_hard, dim=0)

    if CFG.use_modal_attn:
        return all_ids, all_hard, all_soft, all_tokens, all_attentions
    
    return all_ids, all_hard, all_soft

# Inference From Checkpoint

In [31]:
collate = Collator(test=True)
valid_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size, collate_fn=collate)
loss_fn = nn.CrossEntropyLoss()

inf_model_name = 'dandelin-vilt-b32-mlm-mattn_score_0.9115'
inf_model = ConcatArch(
    hidden_size=CFG.mlp_hidden_size,
    hidden_layers=CFG.mlp_hidden_layers,
    dropout=CFG.mlp_dropout,
    num_classes=CFG.num_class,
    use_multimodal=CFG.use_multimodal,
    use_dualencoder=CFG.use_dualencoder,
    is_mclip=CFG.is_mclip
).to(CFG.device)
inf_model.load_state_dict(torch.load('model_backup/T4/' + inf_model_name + '.pth', map_location=torch.device(CFG.device))['model'])
inf_model

ConcatArch(
  (mm_model): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_fea

In [33]:
ids, hards, softs, tokens, attentions = test_loop(inf_model, valid_dataloader)

100%|██████████| 1/1 [00:01<00:00,  1.18s/it]


In [62]:
def plot_attention_heatmap(tokens, attentions, idx):
    a = list(attentions[1:len(tokens) - 1])
    t = tokens[1:-1]

    print(a)
    print(t)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(a, annot=True, cmap='viridis', xticklabels=t, yticklabels=tokens)
    plt.title("Attention Heatmap")
    plt.ylabel('Input Tokens')
    plt.xlabel('Output Tokens')
    plt.show()

In [63]:
plot_attention_heatmap(tokens[0], attentions[0], 0)

[0.20530574, 0.03372386, 0.017652327, 0.029201869, 0.006606217, 0.015603379, 0.062646054, 0.27612308, 0.0023764307, 0.0015391929, 0.018184192, 0.20514084, 0.0031150633, 0.020074347, 0.07724695]
['el', 'fis', '##ico', 'at', '##rae', 'per', '##o', 'la', 'en', '##fer', '##mer', '##a', 'en', '##amo', '##ra']


IndexError: Inconsistent shape between the condition and the input (got (15, 1) and (15,))

<Figure size 1000x800 with 0 Axes>