In [91]:
# Load the trained model

import torch
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from datasets import load_from_disk
from transformers import CLIPTokenizer, CLIPProcessor, AutoTokenizer
from transformers import AutoProcessor, FlavaModel
from PIL import Image
import base64
import io
import numpy as np

In [92]:
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset):
        self.image_size = 224
        self.dataset = dataset
        self.processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
       
        pixel_values = self.processor(images=base64str_to_PILobj(item["image"]).convert("RGB").resize((self.image_size, self.image_size)),
                                            return_tensors="pt")['pixel_values']
    
        text_output = self.processor(text = item['text'],
                                          padding='max_length', 
                                          return_tensors="pt", 
                                          truncation=True)

        label = torch.LongTensor([item['label']])

        return {
            'pixel_values': pixel_values,
            'input_ids': text_output['input_ids'],
            'attention_mask': text_output['attention_mask'],
            'labels': label,
            'idx_memes': item['id'],
            'image': item['image']
        }

In [93]:
class CustomBLIP(nn.Module):
    def __init__(self,
                 text_encoder=None,
                 tokenizer=None,
                 config=None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        self.num_labels = 1

        self.model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
        for param in self.model.parameters():
            param.requires_grad = False

        #self.text_encoder = self.model.text_model
        self.text_encoder = self.model.text_encoder
        
        self.vision_encoder = self.model.vision_model
        self.cross_entropy_loss = torch.nn.BCELoss(reduction='mean')

        self.cls_head = nn.Sequential(

            nn.Linear(self.text_encoder.config.hidden_size, 1024),
            nn.ReLU(),

            nn.Linear(1024, 512),
            nn.ReLU(),

            nn.Linear(512, 256),
            nn.ReLU(),

            nn.Linear(256, 128),
            nn.ReLU(),

            nn.Linear(128, 64),
            nn.ReLU(),

            nn.Linear(64, self.num_labels)

        )
        

    def forward(self, batch, train=True):

        pixel_values = batch['pixel_values'].to(device)
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        #image_embeds = vision_outputs[0]

        #image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)

        image_embeds = self.vision_encoder(pixel_values.squeeze(1))[0]
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)

        if train:
            #import pdb; pdb.set_trace()
            output = self.text_encoder(input_ids = input_ids.squeeze(1), 
                                        attention_mask=attention_mask.squeeze(1), 
                                        encoder_hidden_states=image_embeds, 
                                        encoder_attention_mask=image_atts, 
                                        return_dict=True)
            #import pdb; pdb.set_trace()

            prediction = self.cls_head(output.last_hidden_state[:, 0, :]) #[:, 0, :]
            prediction = torch.sigmoid(prediction)
            #print(prediction)
            
            return prediction

        else:
            output = self.text_encoder(input_ids,
                                       attention_mask=attention_mask,
                                       encoder_hidden_states=image_embeds,
                                       encoder_attention_mask=image_atts,
                                       return_dict=True
                                       )
            prediction = self.cls_head(output.last_hidden_state[:, 0, :]) #[:, 0, :]
            prediction = torch.sigmoid(prediction)
            return prediction

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)


In [94]:
model = torch.load('blip_entire_model.pt')

In [95]:
model.eval()

CustomBLIP(
  (model): BlipForImageTextRetrieval(
    (vision_model): BlipVisionModel(
      (embeddings): BlipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (encoder): BlipEncoder(
        (layers): ModuleList(
          (0): BlipEncoderLayer(
            (self_attn): BlipAttention(
              (dropout): Dropout(p=0.0, inplace=False)
              (qkv): Linear(in_features=768, out_features=2304, bias=True)
              (projection): Linear(in_features=768, out_features=768, bias=True)
            )
            (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): BlipMLP(
              (activation_fn): GELUActivation()
              (fc1): Linear(in_features=768, out_features=3072, bias=True)
              (fc2): Linear(in_features=3072, out_features=768, bias=True)
            )
            (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          

In [96]:
batch_size = 8

combined = load_from_disk('./processed_data/combined_hateful_memes_dataset')
# train_data = combined['train']
# print('processing image...')
# train_dataset = ImageCaptioningDataset(train_data)
# train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
dev_seen_data = combined['dev_unseen']
dev_seen_dataset = ImageCaptioningDataset(dev_seen_data)
dev_seen_loader = DataLoader(dev_seen_dataset, shuffle=True, batch_size=batch_size)

In [97]:
def base64str_to_PILobj(base64_string):
    '''
    Args
    - base64_string (str): based64 encoded representing an image

    Output
    - PIL object (use .show() to display)
    '''
    image_data = base64.b64decode(base64_string)
    img = Image.open(io.BytesIO(image_data))
    #img.show()
    return img

In [98]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()
correct_normal = 0
total = 0
model = model.to(device)
with torch.no_grad():
    for batch in dev_seen_loader:
        # print(batch)
        # input_ids = batch['input_ids'].to(device)
        # #token_type_ids  = batch['token_type_ids'].to(device)
        # attention_mask = batch['attention_mask'].to(device)
        # pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        labels = labels.view(-1)
        labels = labels.to(device)

        output = model(batch)
        predicted = torch.as_tensor((output - 0.5) > 0, dtype=torch.int32)
        # _, predicted = nn.sigmoid(output)#torch.max(output.data, 1)

        total += labels.size(0)
        this_batch_corrected = (predicted==labels.reshape(-1,1)).sum().item()
        correct_normal += this_batch_corrected
        print(f'{this_batch_corrected}/{labels.size(0)} correct for this batch. total corrected by far={correct_normal}/{total}')
        # break

        
accuracy = correct_normal/total

7/8 correct for this batch. total corrected by far=7/8
5/8 correct for this batch. total corrected by far=12/16
5/8 correct for this batch. total corrected by far=17/24
3/8 correct for this batch. total corrected by far=20/32
4/8 correct for this batch. total corrected by far=24/40
5/8 correct for this batch. total corrected by far=29/48
4/8 correct for this batch. total corrected by far=33/56
7/8 correct for this batch. total corrected by far=40/64
3/8 correct for this batch. total corrected by far=43/72
5/8 correct for this batch. total corrected by far=48/80
5/8 correct for this batch. total corrected by far=53/88
5/8 correct for this batch. total corrected by far=58/96
4/8 correct for this batch. total corrected by far=62/104
5/8 correct for this batch. total corrected by far=67/112
4/8 correct for this batch. total corrected by far=71/120
5/8 correct for this batch. total corrected by far=76/128
6/8 correct for this batch. total corrected by far=82/136
5/8 correct for this batch. 

In [99]:
accuracy

0.5962962962962963

In [102]:
predicted

tensor([[0],
        [1],
        [0],
        [0]], device='cuda:0', dtype=torch.int32)

In [103]:
output.data

tensor([[0.1624],
        [0.7426],
        [0.3069],
        [0.1408]], device='cuda:0')