In [1]:
# Load the trained model

import base64
import io
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_from_disk
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
from transformers import CLIPTokenizer, CLIPProcessor, AutoTokenizer
from transformers import AutoProcessor, FlavaModel

In [2]:
class ImageCaptioningDataset(Dataset):
    """Map-style dataset that converts raw records into CLIP model inputs.

    Every record of the wrapped dataset is read through the keys ``"image"``
    (handed to ``base64str_to_PILobj``), ``"text"``, ``"label"`` and ``"id"``.

    Note: ``base64str_to_PILobj`` is defined in a later cell of this notebook,
    so that cell has to be executed before items are fetched from this dataset.
    """

    def __init__(self, dataset):
        """Keep a reference to ``dataset`` and load the CLIP image/text processors."""
        clip_checkpoint = 'openai/clip-vit-base-patch32'
        self.dataset = dataset
        self.image_processor = CLIPProcessor.from_pretrained(clip_checkpoint)
        self.text_processor = CLIPTokenizer.from_pretrained(clip_checkpoint)

    def __len__(self):
        """Number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processed sample at position ``idx`` as a dict.

        Keys: ``pixel_values``, ``input_ids``, ``attention_mask``, ``labels``
        (a one-element LongTensor), ``idx_memes`` (the record's ``"id"``) and
        ``image`` (the record's ``"image"`` value, passed through unchanged).
        """
        record = self.dataset[idx]

        # Decode the stored image and force three channels before preprocessing.
        rgb_image = base64str_to_PILobj(record["image"]).convert("RGB")
        image_encoding = self.image_processor(images=rgb_image, return_tensors="pt")

        # Pad to the tokenizer's maximum length so samples can be stacked into a batch.
        text_encoding = self.text_processor(record['text'],
                                            padding='max_length',
                                            return_tensors="pt",
                                            truncation=True)

        sample = {}
        sample['pixel_values'] = image_encoding['pixel_values']
        sample['input_ids'] = text_encoding['input_ids']
        sample['attention_mask'] = text_encoding['attention_mask']
        sample['labels'] = torch.LongTensor([record['label']])
        sample['idx_memes'] = record['id']
        sample['image'] = record['image']
        return sample

In [3]:
class CLIPClassifier(nn.Module):
    """Binary classifier built on CLIP's vision and text encoders.

    The pooled output of each encoder is projected to ``map_dim`` features,
    L2-normalised, fused (concatenation or outer product) and passed through
    a small MLP that emits one value per sample. ``forward`` returns the
    sigmoid of that value.

    Args:
        map_dim: Width of the projected image/text features and of the hidden
            layer of the classification head.
        dropout_lst: Three dropout probabilities: [projection layers,
            head input, head hidden layer]. Defaults to ``[0.1, 0.4, 0.2]``.
        pretrained_model: Name or path passed to ``CLIPModel.from_pretrained``.
        freeze_image_encoder: If true, the vision encoder gets no gradients.
        freeze_text_encoder: If true, the text encoder gets no gradients.
        head: Fusion strategy, ``'concat'`` or ``'cross'`` (outer product).
        num_mapping_layers: Number of linear layers in each projection.

    Raises:
        ValueError: If ``head`` is not ``'concat'`` or ``'cross'``.

    NOTE(review): ``cross_entropy_loss`` is a ``BCEWithLogitsLoss`` (which
    applies its own sigmoid) while ``forward`` already returns probabilities;
    confirm the training code does not feed ``forward``'s output to it.
    """

    _SUPPORTED_HEADS = ('concat', 'cross')

    def __init__(self,
                 map_dim=32,
                 dropout_lst=None,
                 pretrained_model='openai/clip-vit-large-patch14',
                 freeze_image_encoder=True,
                 freeze_text_encoder=True,
                 head='concat',
                 num_mapping_layers=1
                 ):
        super().__init__()

        if head not in self._SUPPORTED_HEADS:
            raise ValueError(f"head must be one of {self._SUPPORTED_HEADS}, got {head!r}")

        self.map_dim = map_dim
        # Fresh list per instance: avoids sharing one mutable default between models.
        self.dropout_lst = [0.1, 0.4, 0.2] if dropout_lst is None else list(dropout_lst)
        self.num_mapping_layers = num_mapping_layers
        self.head = head

        # Keep only the two encoders. The full CLIP model is a local, so its
        # projection layers are dropped when __init__ returns; no deepcopy (and
        # no `copy` import) is needed.
        clip = CLIPModel.from_pretrained(pretrained_model)
        self.image_encoder = clip.vision_model
        self.text_encoder = clip.text_model

        # Not using CLIP's pretrained projections: new maps are trained from scratch.
        self.image_map = self._build_map(self.image_encoder.config.hidden_size)
        self.text_map = self._build_map(self.text_encoder.config.hidden_size)

        # 'concat' stacks both feature vectors; 'cross' flattens their outer product.
        if self.head == 'concat':
            pre_output_input_dim = self.map_dim * 2
        else:
            pre_output_input_dim = self.map_dim ** 2

        self.pre_output = nn.Sequential(
            nn.Dropout(p=self.dropout_lst[1]),
            nn.Linear(pre_output_input_dim, self.map_dim),
            nn.ReLU(),
            nn.Dropout(p=self.dropout_lst[2]),
        )
        self.output = nn.Linear(self.map_dim, 1)

        self.cross_entropy_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')

        if freeze_image_encoder:
            for p in self.image_encoder.parameters():
                p.requires_grad_(False)

        if freeze_text_encoder:
            for p in self.text_encoder.parameters():
                p.requires_grad_(False)

    def _build_map(self, in_dim):
        """Build the projection from an encoder's hidden size to ``map_dim``."""
        layers = [nn.Linear(in_dim, self.map_dim),
                  nn.Dropout(p=self.dropout_lst[0])]
        for _ in range(1, self.num_mapping_layers):
            layers.extend([nn.ReLU(),
                           nn.Linear(self.map_dim, self.map_dim),
                           nn.Dropout(p=self.dropout_lst[0])])
        return nn.Sequential(*layers)

    def forward(self, batch):
        """Return one probability per sample.

        Args:
            batch: Mapping with tensors under ``'pixel_values'``,
                ``'input_ids'`` and ``'attention_mask'``. Dimension 1 of each
                is squeezed before encoding (presumably the singleton batch
                dimension the per-sample processors add; verify against the
                dataset).

        Returns:
            Tensor of sigmoid outputs with one column per sample.

        Raises:
            ValueError: If ``self.head`` is not a supported fusion strategy.
        """
        # Use the device the model itself lives on rather than a notebook global.
        target_device = self.output.weight.device
        pixel_values = batch['pixel_values'].to(target_device)
        input_ids = batch['input_ids'].to(target_device)
        attention_mask = batch['attention_mask'].to(target_device)

        image_features = self.image_encoder(pixel_values=pixel_values.squeeze(1)).pooler_output
        image_features = self.image_map(image_features)

        text_features = self.text_encoder(input_ids=input_ids.squeeze(1),
                                          attention_mask=attention_mask.squeeze(1)).pooler_output
        text_features = self.text_map(text_features)

        image_features = F.normalize(image_features, p=2, dim=1)
        text_features = F.normalize(text_features, p=2, dim=1)

        if self.head == 'concat':
            features = torch.cat([image_features, text_features], dim=1)
        elif self.head == 'cross':
            # Outer product per sample: [B, d, 1] x [B, 1, d] -> [B, d, d] -> [B, d*d]
            features = torch.bmm(image_features.unsqueeze(2), text_features.unsqueeze(1))
            features = features.reshape(features.shape[0], -1)
        else:
            raise ValueError(f"Unsupported head: {self.head!r}")

        features = self.pre_output(features)
        logits = self.output(features)
        preds = torch.sigmoid(logits)

        return preds

In [4]:
# torch.load unpickles the entire CLIPClassifier object, so the class definition
# must already be in scope. Pickle can execute arbitrary code: only load
# checkpoints from a trusted source.
# map_location='cpu' lets the checkpoint load on machines that lack the GPU it
# was saved from; the evaluation cells move the model to the target device.
# On PyTorch releases where torch.load defaults to weights_only=True, loading a
# whole pickled module additionally requires weights_only=False.
model = torch.load('clip_entire_model_added_sigmoid_gradient_clip.pt', map_location='cpu')
model.eval()

CLIPClassifier(
  (image_encoder): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (position_embedding): Embedding(257, 1024)
    )
    (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0): CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
  

In [5]:
class ImageCaptioningDataset(Dataset):
    """Map-style dataset that converts raw records into CLIP model inputs.

    Every record of the wrapped dataset is read through the keys ``"image"``
    (handed to ``base64str_to_PILobj``), ``"text"``, ``"label"`` and ``"id"``.

    Note: a class with this name is also defined in an earlier cell of this
    notebook; running this cell rebinds the name to this definition.
    """

    def __init__(self, dataset):
        """Keep a reference to ``dataset`` and load the CLIP image/text processors."""
        clip_checkpoint = 'openai/clip-vit-base-patch32'
        self.dataset = dataset
        self.image_processor = CLIPProcessor.from_pretrained(clip_checkpoint)
        self.text_processor = CLIPTokenizer.from_pretrained(clip_checkpoint)

    def __len__(self):
        """Number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processed sample at position ``idx`` as a dict.

        Keys: ``pixel_values``, ``input_ids``, ``attention_mask``, ``labels``
        (a one-element LongTensor), ``idx_memes`` (the record's ``"id"``) and
        ``image`` (the record's ``"image"`` value, passed through unchanged).
        """
        record = self.dataset[idx]

        # Decode the stored image and force three channels before preprocessing.
        rgb_image = base64str_to_PILobj(record["image"]).convert("RGB")
        image_encoding = self.image_processor(images=rgb_image, return_tensors="pt")

        # Pad to the tokenizer's maximum length so samples can be stacked into a batch.
        text_encoding = self.text_processor(record['text'],
                                            padding='max_length',
                                            return_tensors="pt",
                                            truncation=True)

        sample = {}
        sample['pixel_values'] = image_encoding['pixel_values']
        sample['input_ids'] = text_encoding['input_ids']
        sample['attention_mask'] = text_encoding['attention_mask']
        sample['labels'] = torch.LongTensor([record['label']])
        sample['idx_memes'] = record['id']
        sample['image'] = record['image']
        return sample

def base64str_to_PILobj(base64_string):
    '''
    Decode a base64-encoded image and open it with PIL.

    Args
    - base64_string (str): base64 text whose decoded bytes are an image file

    Output
    - the object returned by Image.open (use .show() to display)
    '''
    raw_bytes = base64.b64decode(base64_string)
    # Image.open wants a file-like object, so wrap the bytes in an in-memory buffer.
    return Image.open(io.BytesIO(raw_bytes))

In [8]:
batch_size = 32

combined = load_from_disk('./processed_data/combined_hateful_memes_dataset')

# NOTE(review): these variables are named dev_seen_* but the split loaded is
# 'dev_unseen'. The names are kept because later cells reference them; confirm
# which split is intended.
dev_seen_data = combined['dev_unseen']
dev_seen_dataset = ImageCaptioningDataset(dev_seen_data)
# Evaluation only: no shuffling, so the per-batch log is reproducible across runs.
dev_seen_loader = DataLoader(dev_seen_dataset, shuffle=False, batch_size=batch_size)

In [9]:
# Accuracy of the loaded model over dev_seen_loader.
# Fall back to CPU when no GPU is present instead of crashing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
model.eval()
correct_normal = 0
total = 0
with torch.no_grad():
    for batch in dev_seen_loader:
        # Flatten labels to one entry per sample.
        labels = batch['labels'].to(device).reshape(-1)

        # The model's forward ends with a sigmoid, so 0.5 is the decision threshold.
        output = model(batch)
        predicted = (output > 0.5).long().reshape(-1)

        this_batch_corrected = (predicted == labels).sum().item()
        total += labels.size(0)
        correct_normal += this_batch_corrected
        print(f'{this_batch_corrected}/{labels.size(0)} correct for this batch. total corrected by far={correct_normal}/{total}')

# An empty loader yields NaN rather than a ZeroDivisionError.
accuracy = correct_normal / total if total else float('nan')
print(accuracy)

21/32 correct for this batch. total corrected by far=21/32
18/32 correct for this batch. total corrected by far=39/64
23/32 correct for this batch. total corrected by far=62/96
23/32 correct for this batch. total corrected by far=85/128
19/32 correct for this batch. total corrected by far=104/160
22/32 correct for this batch. total corrected by far=126/192
17/32 correct for this batch. total corrected by far=143/224
27/32 correct for this batch. total corrected by far=170/256
20/32 correct for this batch. total corrected by far=190/288
27/32 correct for this batch. total corrected by far=217/320
18/32 correct for this batch. total corrected by far=235/352
19/32 correct for this batch. total corrected by far=254/384
18/32 correct for this batch. total corrected by far=272/416
25/32 correct for this batch. total corrected by far=297/448
21/32 correct for this batch. total corrected by far=318/480
24/32 correct for this batch. total corrected by far=342/512
23/28 correct for this batch. t

In [8]:
# Wrap the 'train' split in the same dataset class and a shuffled loader.
train_data = combined['train']
print('processing image...')
train_dataset = ImageCaptioningDataset(train_data)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

processing image...


In [16]:
# Accuracy of the loaded model over train_loader.
# Fall back to CPU when no GPU is present instead of crashing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
model.eval()
correct_normal = 0
total = 0
with torch.no_grad():
    for batch in train_loader:
        # Flatten labels to one entry per sample.
        labels = batch['labels'].to(device).reshape(-1)

        # The model's forward ends with a sigmoid, so 0.5 is the decision threshold.
        output = model(batch)
        predicted = (output > 0.5).long().reshape(-1)

        this_batch_corrected = (predicted == labels).sum().item()
        total += labels.size(0)
        correct_normal += this_batch_corrected
        print(f'{this_batch_corrected}/{labels.size(0)} correct for this batch. total corrected by far={correct_normal}/{total}')

# An empty loader yields NaN rather than a ZeroDivisionError.
accuracy = correct_normal / total if total else float('nan')
print(accuracy)

29/32 correct for this batch. total corrected by far=29/32
27/32 correct for this batch. total corrected by far=56/64
29/32 correct for this batch. total corrected by far=85/96
27/32 correct for this batch. total corrected by far=112/128
24/32 correct for this batch. total corrected by far=136/160
26/32 correct for this batch. total corrected by far=162/192
29/32 correct for this batch. total corrected by far=191/224
29/32 correct for this batch. total corrected by far=220/256
25/32 correct for this batch. total corrected by far=245/288
29/32 correct for this batch. total corrected by far=274/320
30/32 correct for this batch. total corrected by far=304/352
28/32 correct for this batch. total corrected by far=332/384
25/32 correct for this batch. total corrected by far=357/416
28/32 correct for this batch. total corrected by far=385/448
26/32 correct for this batch. total corrected by far=411/480
29/32 correct for this batch. total corrected by far=440/512
27/32 correct for this batch. 

In [23]:
# Class balance of the training split: label value -> number of examples.
# Counter is imported with the other imports in the notebook's first cell.
all_labels = train_data['label']
Counter(all_labels)

Counter({0: 5481, 1: 3019})

In [24]:
# Ratio of label-1 to label-0 training examples, computed from the labels
# themselves instead of counts copied by hand from the previous output.
label_counts = Counter(all_labels)
label_counts[1] / label_counts[0]

0.5508118956394819