In [1]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
import pandas as pd
import lmdb
import io
import pickle
import os
import json
import urllib
from transformers import CLIPTokenizer
from transformers import CLIPFeatureExtractor
from transformers import CLIPProcessor
from transformers import CLIPModel, CLIPConfig
from transformers import get_scheduler
from PIL import Image
from tqdm import tqdm
import math

In [2]:
class MyDataset(torch.utils.data.Dataset):
    """Image / hypothesis / label triples prepared for a CLIP classifier.

    Rows come from the SNLI-VE style CSV (``dataset_path``), from one VSR
    split (``only_vsr=True``), or from the CSV augmented with all three VSR
    splits (``vsr`` truthy and ``only_vsr=False``).

    Labels are encoded as 0 = contradiction, 1 = neutral, 2 = entailment.
    VSR only has binary labels, which are mapped to 0 and 2.
    """

    # Row labels removed from each VSR split right after loading.
    # NOTE(review): presumably rows whose image is missing/unreadable - confirm.
    _VSR_DROPPED_ROWS = {
        'train.json': [1786,3569,4553,4912],
        'test.json': [135,614,1071,1621,1850],
        'dev.json': [807],
    }

    def __init__(self,dataset_path, image_path,only_vsr = False, vsr = None, max_length = 90):
        """
        Args:
            dataset_path: CSV with ``hypothesis``, ``Flickr30kID`` and
                ``gold_label`` columns (ignored when ``only_vsr`` is true).
            image_path: directory prefix prepended to every ``Flickr30kID``.
            only_vsr: load only the VSR split named by ``vsr``.
            vsr: VSR split file name when ``only_vsr`` is true; otherwise any
                truthy value appends the train/test/dev VSR splits.
            max_length: token length every caption is padded/truncated to.
                CLIP's text encoder only has 77 position embeddings, so
                values above 77 (including the default) will not fit it.

        Raises:
            ValueError: if ``only_vsr`` is true but ``vsr`` is not given.
        """
        self.max_length = max_length
        self.image_path = image_path
        if only_vsr:
            if not vsr:
                raise ValueError("only_vsr=True requires `vsr` to name a VSR split file, e.g. 'train.json'")
            self.dataset = self.read_vsr_dataset(vsr)
        else:
            self.dataset = self.read_dataset(dataset_path)
            if vsr:
                # Extra VSR rows on top of the CSV data.
                extra = [self.read_vsr_dataset(name) for name in ('train.json', 'test.json', 'dev.json')]
                self.dataset = pd.concat([self.dataset, *extra], ignore_index=True)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def read_vsr_dataset(self,dataset_name, dataset_path = '../visual-spatial-reasoning/',splits_path='splits/',
                         image_path = 'images/',sort = False):
        """Load one VSR JSON-lines split into the common three-column layout.

        Args:
            dataset_name: file name inside ``dataset_path + splits_path``.
            dataset_path: root directory of the VSR checkout.
            splits_path: sub-directory holding the split files.
            image_path: sub-directory holding the images.
            sort: order rows by caption length (shortest first).

        Returns:
            DataFrame with columns ``hypothesis``, ``Flickr30kID`` (full image
            path) and ``gold_label``, indexed 0..n-1 in the returned order.
        """
        labels_encoding = {0:0,1:2}  # keep 0, turn VSR's 1 into 2 (entailment)
        raw = pd.read_json(dataset_path+splits_path+dataset_name, lines =True)
        # rename() returns a fresh frame, so the assignments below never write
        # into a view of `raw`.
        dataset = raw[['caption','image','label']].rename(
            columns = {'caption':'hypothesis', 'image':'Flickr30kID', 'label' : 'gold_label'})
        dataset['gold_label'] = dataset['gold_label'].apply(lambda label: labels_encoding[label])
        dataset['Flickr30kID'] = dataset['Flickr30kID'].apply(lambda path: dataset_path+image_path+path)
        dropped = self._VSR_DROPPED_ROWS.get(dataset_name)
        if dropped:
            dataset = dataset.drop(labels=dropped, axis=0)
        if sort:
            dataset = dataset.sort_values(by="hypothesis", key=lambda x: x.str.len())
        # Re-index last: __getitem__ looks rows up by label, so the index has
        # to follow the final (possibly sorted) row order.
        return dataset.reset_index(drop=True)

    def read_dataset(self, url,sort = False):
        """Load the SNLI-VE style CSV into the common three-column layout.

        Args:
            url: path, URL or file-like object accepted by ``pd.read_csv``.
            sort: order rows by hypothesis length (shortest first).

        Returns:
            DataFrame with columns ``hypothesis``, ``Flickr30kID`` (prefixed
            with ``self.image_path``) and integer ``gold_label``.

        Raises:
            KeyError: if a label is not contradiction/neutral/entailment.
        """
        labels_encoding = {'contradiction':0,'neutral': 1,
                           'entailment':2}
        dataset = pd.read_csv(url)[['hypothesis','Flickr30kID','gold_label']].copy()
        dataset['gold_label'] = dataset['gold_label'].apply(lambda label: labels_encoding[label])
        dataset['Flickr30kID'] = dataset['Flickr30kID'].apply(lambda path: self.image_path+path)
        if sort:
            dataset = dataset.sort_values(by="hypothesis", key=lambda x: x.str.len())
            # Keep label-based lookup in __getitem__ aligned with sorted order.
            dataset = dataset.reset_index(drop=True)
        return dataset

    def get_visual_features(self,img):
        """Run only the image half of the CLIP processor on ``img``."""
        return self.processor(images=img)

    def get_text_features(self,text):
        """Run only the tokenizer half of the CLIP processor on ``text``."""
        return self.processor(text=text)

    def __getitem__(self, idx):
        """Return the processed sample at row ``idx``.

        The result holds ``input_ids``, ``attention_mask`` and
        ``pixel_values`` with the batch dimension stripped, plus a scalar
        long ``label`` tensor.
        """
        img_name = self.dataset.loc[idx,'Flickr30kID']
        text = self.dataset.loc[idx,'hypothesis']
        label = self.dataset.loc[idx,'gold_label']
        # The context manager releases the file handle; convert() decodes the
        # pixels first so the processor never sees a closed file.
        with Image.open(img_name) as img:
            item = self.processor(text=text, images=img.convert("RGB"), return_tensors="pt",padding="max_length",
                                  max_length= self.max_length,truncation=True)
        item['input_ids'] = item['input_ids'][0]
        item['attention_mask'] = item['attention_mask'][0]
        item['pixel_values'] = item['pixel_values'][0]
        item['label'] = torch.tensor(int(label), dtype=torch.long)
        return item

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.dataset.index)

    def __enter__(self):
        """Allow ``with MyDataset(...) as ds:`` usage."""
        return self

    def __exit__(self, *exc_info):
        """Close optional ``img_env`` / ``env`` handles if they were ever set.

        Nothing in this class assigns those attributes, so each one is only
        closed when present. Exceptions are never suppressed.
        """
        for handle_name in ('img_env', 'env'):
            handle = getattr(self, handle_name, None)
            if handle is not None:
                handle.close()
        return False

In [3]:
class MyTrainer():
    """Minimal fine-tuning loop: AdamW + linear LR decay, evaluated every epoch.

    Attributes:
        test_acc_list: accuracy returned by ``eval_test.evaluate`` after each
            epoch, in order.
        output_loss: callable ``(model_output, labels) -> loss tensor`` picked
            from ``problem_type``.
    """

    # Batch entries that are moved to ``self.device`` before the forward pass.
    _BATCH_KEYS = ('input_ids', 'attention_mask', 'pixel_values', 'label')

    def __init__(self,model,train,eval_test, device = None, num_labels = 3,
                 problem_type = "single_label_classification"):
        """
        Args:
            model: module whose output exposes ``.logits``.
            train: training dataset handed to a ``DataLoader``.
            eval_test: object with ``evaluate(batch_size=...)`` returning an
                accuracy; called once per epoch.
            device: device the batches are moved to.
            num_labels: number of output classes / regression targets.
            problem_type: ``"single_label_classification"`` (default),
                ``"regression"`` or ``"multi_label_classification"``.

        Raises:
            ValueError: if ``problem_type`` is not one of the three above.
        """
        self.model = model
        self.device = device
        self.train = train
        self.eval_test = eval_test
        self.test_acc_list = []
        self.model_path = "./models/my_model_epoch_"
        self.num_labels = num_labels
        self.config_problem_type = problem_type
        if problem_type == "single_label_classification":
            self.loss_fct = torch.nn.CrossEntropyLoss()
            self.output_loss = self._single_label_loss
        elif problem_type == "regression":
            self.loss_fct = torch.nn.MSELoss()
            self.output_loss = self._regression_loss
        elif problem_type == "multi_label_classification":
            self.loss_fct = torch.nn.BCEWithLogitsLoss()
            self.output_loss = self._multi_label_loss
        else:
            raise ValueError("Unsupported problem_type: %r" % (problem_type,))

    def _single_label_loss(self, output, labels):
        """Cross-entropy over ``num_labels`` classes."""
        return self.loss_fct(output.logits.view(-1, self.num_labels), labels.view(-1))

    def _regression_loss(self, output, labels):
        """Mean squared error; a single target is squeezed to 1-D first."""
        if self.num_labels == 1:
            return self.loss_fct(output.logits.squeeze(), labels.squeeze())
        return self.loss_fct(output.logits, labels)

    def _multi_label_loss(self, output, labels):
        """Binary cross-entropy with logits over independent labels."""
        return self.loss_fct(output.logits, labels)

    def train_model(self,batch_size = None, lr= None, epochs=None, num_workers = 4):
        """Train ``self.model`` and print test accuracy / mean loss per epoch.

        Args:
            batch_size: samples per optimisation step (also used for the
                per-epoch evaluation).
            lr: AdamW learning rate, decayed linearly to zero over the run.
            epochs: number of passes over the training set.
            num_workers: ``DataLoader`` worker processes.

        Raises:
            ValueError: if ``batch_size``, ``lr`` or ``epochs`` is missing.
        """
        if batch_size is None or lr is None or epochs is None:
            raise ValueError("train_model requires batch_size, lr and epochs to be set")
        optimizer = AdamW(self.model.parameters(), lr=lr)
        train_loader = DataLoader(self.train, batch_size=batch_size, shuffle=True, num_workers = num_workers)
        lr_scheduler = get_scheduler(
            name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps= epochs * len(train_loader)
        )
        for epoch in range(epochs):
            # The model may have been left in eval mode (e.g. after loading a checkpoint).
            self.model.train()
            train_losses = []
            for item in tqdm(train_loader):
                for key in self._BATCH_KEYS:
                    item[key] = item[key].to(self.device)
                optimizer.zero_grad()
                outputs = self.model(**item)
                loss = self.output_loss(outputs, item['label'])
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                # Detach before storing: keeping `loss` itself would keep every
                # step's autograd graph alive until the end of the epoch.
                train_losses.append(loss.detach())
            test_acc = self.eval_test.evaluate(batch_size = batch_size)
            self.test_acc_list.append(test_acc)
            print('--- Epoch ',epoch,' Acc: ',test_acc)
            mean_loss = torch.stack(train_losses).mean().item() if train_losses else float('nan')
            print('Training loss: %.4f' % (mean_loss))
        return

In [4]:
class MyEvaluator():
    """Computes classification accuracy of ``model.predict`` over a dataset."""

    def __init__(self,model,test, device = None):
        """
        Args:
            model: module exposing ``predict(batch) -> predicted labels``.
            test: dataset whose samples hold ``input_ids``, ``attention_mask``,
                ``pixel_values`` and ``label``.
            device: device the batches are moved to before prediction.
        """
        self.test_dataset = test
        self.model = model
        self.device = device

    def evaluate(self, batch_size = 64):
        """Return the fraction of samples whose predicted label equals ``label``.

        The model is switched to eval mode for the duration of the call and
        put back into whichever mode it was in before. Gradients are disabled
        so no autograd graph is built. An empty dataset yields ``0.0``.
        """
        loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle = False)
        was_training = self.model.training
        self.model.eval()
        n_correct = 0
        n_possible = 0
        try:
            with torch.no_grad():
                for item in loader:
                    for key in ('input_ids', 'attention_mask', 'pixel_values', 'label'):
                        item[key] = item[key].to(self.device)
                    y_hat = self.model.predict(item)
                    y = item['label']
                    n_correct += (y == y_hat).sum().item()
                    n_possible += y.shape[0]
        finally:
            # Restore the caller's mode even if prediction raised.
            self.model.train(was_training)
        if n_possible == 0:
            return 0.0
        return n_correct / n_possible

In [5]:
class CLIP(CLIPModel):
    """CLIP backbone with a small fusion transformer and a classification head.

    Every vision and text token is projected into the shared CLIP space,
    concatenated, passed through a 3-layer ``TransformerEncoder``, mean-pooled
    over the non-padded tokens and classified into ``num_labels`` classes.
    """

    def __init__(self, num_labels=3, pretrained_name="openai/clip-vit-base-patch32", load_pretrained=True):
        """
        Args:
            num_labels: size of the classification output.
            pretrained_name: hub id / path providing the CLIP configuration
                (and weights when ``load_pretrained`` is true).
            load_pretrained: copy the pretrained CLIP weights into the
                backbone. Building a model from a config alone leaves every
                weight randomly initialised, so this defaults to true.
        """
        super().__init__(CLIPConfig.from_pretrained(pretrained_name))
        if load_pretrained:
            pretrained = CLIPModel.from_pretrained(pretrained_name)
            # strict=False: tolerate buffers that differ between library versions.
            self.load_state_dict(pretrained.state_dict(), strict=False)
            del pretrained
        fusion_dim = self.config.projection_dim  # output width of both projections
        # TransformerEncoder deep-copies this layer, so the attribute itself is
        # not used in forward(); it is kept so previously saved state dicts
        # (which contain its parameters) still load.
        self.new_encoder_layer = torch.nn.TransformerEncoderLayer(d_model=fusion_dim, nhead=4)
        self.new_transformer_encoder = torch.nn.TransformerEncoder(self.new_encoder_layer, num_layers=3)
        self.classification = torch.nn.Linear(fusion_dim, num_labels, bias=True)
        self.num_labels = num_labels

    def forward(self, input_ids=None, pixel_values=None, attention_mask=None, position_ids=None, return_loss=None, output_attentions=None, output_hidden_states=None, label=None):
        """Run CLIP, fuse all tokens and attach classification logits.

        ``label`` is accepted (and ignored) so that a whole batch dict can be
        passed as ``model(**batch)``.

        Returns:
            The regular CLIP output object with an additional ``logits``
            attribute of shape ``(batch, num_labels)``.
        """
        output = super().forward(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_loss=return_loss,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        # Per-token hidden states (index 0 = last_hidden_state), projected into the shared space.
        vision_tokens = self.visual_projection(output.vision_model_output[0])
        text_tokens = self.text_projection(output.text_model_output[0])
        tokens = torch.cat((vision_tokens, text_tokens), dim=1)  # (batch, n_vis + n_txt, dim)

        if attention_mask is None:
            attention_mask = torch.ones(text_tokens.shape[:2], dtype=torch.long, device=text_tokens.device)
        # 1 = real token, 0 = padding; image patches are never padded.
        vision_keep = torch.ones(vision_tokens.shape[:2], dtype=attention_mask.dtype, device=vision_tokens.device)
        keep_mask = torch.cat((vision_keep, attention_mask.to(vision_tokens.device)), dim=1)

        # The encoder was built with batch_first=False, so it expects
        # (seq, batch, dim) together with a (batch, seq) boolean mask that is
        # True at the positions to IGNORE.
        encoded = self.new_transformer_encoder(tokens.transpose(0, 1), src_key_padding_mask=(keep_mask == 0))
        encoded = encoded.transpose(0, 1)  # back to (batch, seq, dim)

        # Mean over the non-padded tokens only.
        weights = keep_mask.unsqueeze(-1).to(encoded.dtype)
        pooled = (encoded * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

        output.logits = self.classification(pooled)
        return output

    def predict(self,item):
        """Return the arg-max class index for every example in the batch dict ``item``."""
        with torch.no_grad():
            scores = self(**item)  # logits: (n_examples x n_classes)
        return scores.logits.argmax(dim=-1)  # (n_examples)

    def save_model(self,path):
        """Write this model's state dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load_model(self,path):
        """Load a state dict written by ``save_model`` and switch to eval mode.

        Tensors are mapped to CPU first so a checkpoint saved on a GPU that
        does not exist on this machine still loads; ``load_state_dict`` then
        copies them onto whatever device the parameters live on.
        NOTE: ``torch.load`` unpickles the file - only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(path, map_location="cpu"))
        self.eval()

In [None]:
# Prefer the second GPU when the machine has one, otherwise fall back to the
# first GPU and finally to the CPU, so the cell also runs on single-GPU and
# CPU-only hosts.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda:%d" % gpu_index)
else:
    device = torch.device("cpu")
# Cell output: name of the selected device (get_device_name only accepts CUDA devices).
torch.cuda.get_device_name(device) if device.type == "cuda" else "cpu"

In [None]:
# Build the three VSR-only splits. With only_vsr=True the e-ViL CSV path is not
# read; the image directory is only stored on the dataset object.
max_length = 32
flickr_images_dir = '../e-ViL/data/flickr30k_images/flickr30k_images/'
vsr_only_options = dict(max_length=max_length, only_vsr=True)

train = MyDataset('../e-ViL/data/esnlive_train.csv', flickr_images_dir, vsr='train.json', **vsr_only_options)
test = MyDataset('../e-ViL/data/esnlive_test.csv', flickr_images_dir, vsr='test.json', **vsr_only_options)
dev = MyDataset('../e-ViL/data/esnlive_dev.csv', flickr_images_dir, vsr='dev.json', **vsr_only_options)

In [None]:
# Report how many samples each split holds, in train / test / dev order.
for split in (train, test, dev):
    print(len(split))

In [6]:
# Instantiate the CLIP-based classifier with its default number of labels (3).
model = CLIP()

In [7]:
# Print the module hierarchy to inspect the layers of the classifier.
print(model)

CLIP(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0): CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_

In [None]:
"""
#modules = [model.text_model.embeddings, model.vision_model.embeddings]
modules = [model.text_model, model.vision_model]
for module in modules:
    for param in module.parameters():
        param.requires_grad = False
"""

In [None]:
# Run configuration: `task` selects between fine-tuning ('train') and
# evaluating a previously saved checkpoint ('test').
task = 'train'
model_path = 'vsr_len32_batch32_lr1e-5'
batch_size = 32
epochs = 100
lr = 1e-5
if task =='train':
    test_evaluator = MyEvaluator(model,test, device = device)
    dev_evaluator = MyEvaluator(model,dev, device = device)
    # The trainer reports accuracy on the test split after every epoch.
    trainer = MyTrainer(model,train,test_evaluator, device = device)
    model = model.to(device)
    print("-----Training Model-----")
    trainer.train_model(epochs=epochs ,batch_size = batch_size, lr = lr)
    print('----Training finished-----')
    dev_acc = dev_evaluator.evaluate(batch_size = batch_size)
    print("---- Dev Acc: ",dev_acc)
    model.save_model(model_path)
elif task =='test':
    # Load the checkpoint the 'train' branch writes. The per-epoch files under
    # ./models/ are never produced here because that saving code is disabled.
    model.load_model(model_path)
    model = model.to(device)
    evaluator = MyEvaluator(model,dev, device = device)
    acc = evaluator.evaluate(batch_size = batch_size)
    print(acc)
else:
    # Fail loudly instead of silently doing nothing on a misspelled task.
    raise ValueError("Unknown task %r; expected 'train' or 'test'" % (task,))