## Installation

In [None]:
#For cpu only
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

In [None]:
!pip install pyyaml==5.1
!python -m pip install 'git+https://github.com/facebookresearch/detectron2.git@v0.5'
!pip install pandas
!pip install transformers

## Model
Reference: 
- https://colab.research.google.com/drive/1bLGxKdldwqnMVA5x4neY7-l_8fKGWQYI?usp=sharing#scrollTo=7-5rqN-vtlkq
- https://github.com/Ikea-179/Hateful-Meme-Detection/blob/main/VisualBERT.ipynb

In [7]:
%load_ext autoreload
%autoreload 2

In [8]:
import torch, torchvision
import matplotlib.pyplot as plt
import json
import cv2
import numpy as np
from copy import deepcopy
from visual_embedding.visual_embeding_detectron2 import VisualEmbedder
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.structures.image_list import ImageList
from detectron2.data import transforms as T
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputs
from detectron2.structures.boxes import Boxes
from detectron2.layers import nms
from detectron2 import model_zoo
from detectron2.config import get_cfg

### Dataset

In [63]:
data_path='../data/hateful_memes/train_df_wQuery_.jsonl'
import pandas as pd
img_data = pd.read_json(path_or_buf=data_path, lines=True).to_dict(orient='records')

In [64]:
print(len(img_data))
print(img_data[0].keys())

8500
dict_keys(['id', 'img', 'label', 'text', 'query_1'])


In [65]:
cfg_path="COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
MIN_BOXES=10 
MAX_BOXES=100
visualembedder=VisualEmbedder(cfg_path=cfg_path, min_boxes=MIN_BOXES, max_boxes=MAX_BOXES)

In [66]:
import os
img_inpainted_dir='../data/hateful_memes/img_inpainted'
visualembedder.visual_embeds_detectron2([cv2.imread(os.path.join(img_inpainted_dir, img_data[3]['img'].split('/')[-1]))])

The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


[tensor([[-3.6105, -2.2630,  2.2598,  ..., -2.5526,  0.3378,  1.5410],
         [ 2.1424,  1.1499,  0.4670,  ..., -0.5727,  0.6574,  0.3634],
         [ 1.9217, -0.5865, -0.4217,  ...,  0.8373,  1.3967,  1.7566],
         ...,
         [ 1.5050,  1.8730,  1.2205,  ..., -1.2324,  1.6751, -0.4564],
         [-1.5518,  0.5050, -0.0539,  ..., -1.0867,  1.0135,  0.4631],
         [ 0.2506, -0.3850, -0.0968,  ..., -1.3164,  1.0680,  0.9516]],
        device='cuda:0', grad_fn=<IndexBackward0>)]

In [67]:
# class HateMemeDataset(Dataset):
#     def __init__(self, data_path, img_dir, target_transform=None):
#         self.img_data = pd.read_json(path_or_buf=data_path, lines=True).to_dict(orient='records')
#         self.img_dir = img_dir
#         self.target_transform = target_transform
        

#     def __len__(self):
#         return len(self.img_data)

#     def __getitem__(self, idx):
#         image_embed=self.img_data[idx]['visual_embedding']
#         img_text=self.img_data[idx]['text']
#         img_query=self.img_data[idx]['query_1']
#         label=self.img_data[idx]['label']
#         data_id=self.img_data[idx]['id']
        
#         if self.target_transform:
#             label = self.target_transform(label)
#         return image_embed, label, img_text, img_query, data_id

In [3]:
from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image
import os
import pandas as pd
from torchvision.io import read_image
import torch
from torch.utils.data import Dataset

class HatefulMemesData(Dataset):
    def __init__(self, df,img_dir, tokenizer, sequence_length, visual_embed_model='vit', print_text=False):         

        self.sequence_length = sequence_length
        self.tokenizer = tokenizer
        self.print_text = print_text
        self.dataset = df
        self.img_dir = img_dir
        self.visual_embed_model = visual_embed_model
        self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
        self.feature_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k').to('cuda')

    def __len__(self):
        return len(self.dataset)


    def tokenize_data(self, example):
   
        idx = example['id']
        idx = [idx] if isinstance(idx, str) else idx
        
        encoded_dict = self.tokenizer(example['text'], padding='max_length', max_length=self.sequence_length, truncation=True, return_tensors='pt')
        tokens = encoded_dict['input_ids']
        token_type_ids = encoded_dict['token_type_ids']
        attn_mask = encoded_dict['attention_mask']
        
        captioning_encode_dict=self.tokenizer(example['query_1'], padding='max_length', max_length=self.sequence_length, truncation=True, return_tensors='pt')
        caption_token=captioning_encode_dict['input_ids']
        caption_token_type_ids=captioning_encode_dict['token_type_ids']
        caption_attn_mask=captioning_encode_dict['attention_mask']

        targets = torch.tensor(example['label']).type(torch.int64)

        ## Get Visual Embeddings
        try:
            if self.visual_embed_model=='vit':
                #TODO: make it work
                img = example['img'].split('/')[-1]
                img = Image.open(os.path.join(self.img_dir , img))
                img = np.array(img)
                img = img[...,:3]
                inputs = self.feature_extractor(images=img, return_tensors="pt")
                outputs = self.feature_model(**inputs.to('cuda'))
                visual_embeds = outputs.last_hidden_state
                visual_embeds = visual_embeds.cpu()
            elif self.visual_embed_model=='detectron2':
                visual_embeds=example['visual_embedding']
        except:
            # print("Error with Id: ", idx)
            if self.visual_embed_model=='vit':
                visual_embeds = np.zeros(shape=(197, 768), dtype=float)
            elif self.visual_embed_model=='detectron2':
                visual_embeds = np.zeros(shape=(100, 1024), dtype=float)

        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.int64)
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.int64)

        inputs={"input_ids": tokens.squeeze(),
            "attention_mask": attn_mask.squeeze(),
            "token_type_ids": token_type_ids.squeeze(),
            "visual_embeds": visual_embeds.squeeze(),
            "visual_token_type_ids": visual_token_type_ids.squeeze(),
            "visual_attention_mask": visual_attention_mask.squeeze(),
            "label": targets.squeeze(),
            "caption_input_ids": caption_token.squeeze(),
            "caption_attention_mask": caption_attn_mask.squeeze(),
            "caption_token_type_ids": caption_token_type_ids.squeeze()
        }
        
        return inputs
  
    def __getitem__(self, index):
        inputs = self.tokenize_data(self.dataset[index])
        
        if self.print_text:
            for k in inputs.keys():
                print(k, inputs[k].shape, inputs[k].dtype)

        return inputs

  from .autonotebook import tqdm as notebook_tqdm
2024-05-07 11:54:46.104574: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-07 11:54:46.261403: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-07 11:54:46.261459: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-07 11:54:46.279408: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-07 11:54:46.327554: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-07 11:54:46.328403:

In [4]:
from torch.utils.data import DataLoader
from transformers import BertTokenizer, VisualBertForPreTraining, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
train_data_path='../data/hateful_memes/train_df_wQuery_.jsonl'
validation_data_path='../data/hateful_memes/dev_seen_df_wQuery_.jsonl'
img_inpainted_dir='../data/hateful_memes/img_inpainted'
cfg_path="COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
MIN_BOXES=10 
MAX_BOXES=100

training_data = HatefulMemesData(train_data_path, img_inpainted_dir, tokenizer, sequence_length=50, visual_embed_model='vit')
validation_data = HatefulMemesData(validation_data_path, img_inpainted_dir,tokenizer, sequence_length=50, visual_embed_model='vit')


train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(validation_data, batch_size=32, shuffle=True)



KeyboardInterrupt: 

#### Check dataloader

In [69]:
# Display image and label.
train_features, train_labels, train_text, train_query, train_id = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


dict_keys(['p2', 'p3', 'p4', 'p5', 'p6'])


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m
  x = F.conv2d(


OutOfMemoryError: CUDA out of memory. Tried to allocate 60.00 MiB. GPU 