In [1]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
import pandas as pd
import lmdb
import io
import pickle
import os
import json
import urllib
from transformers import CLIPTokenizer
from transformers import CLIPFeatureExtractor, AutoFeatureExtractor
from transformers import CLIPProcessor
from transformers import CLIPModel, CLIPConfig
from transformers import get_scheduler
from PIL import Image
from tqdm import tqdm
import math
import numpy as np

In [11]:
import torch.nn.functional as F
from pytorch_metric_learning import losses

class SupervisedContrastiveLoss(torch.nn.Module):
    """Supervised contrastive loss over a batch of per-sample embeddings.

    Thin wrapper around ``pytorch_metric_learning.losses.NTXentLoss``: samples
    that share a label are used as positives for each other, every other
    sample in the batch as a negative.

    Args:
        temperature: Temperature forwarded to ``NTXentLoss``.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        # Build the criterion once rather than on every forward pass.
        self.criterion = losses.NTXentLoss(temperature=temperature)

    def forward(self, feature_vectors, labels):
        """Return the scalar contrastive loss for one batch.

        Args:
            feature_vectors: Tensor of shape (batch, dim), one embedding per sample.
            labels: Integer class labels, shape (batch,) or (batch, 1).

        Returns:
            The loss tensor produced by ``NTXentLoss``.
        """
        # L2-normalise so that pairwise similarity is cosine similarity.
        embeddings = F.normalize(feature_vectors, p=2, dim=1)
        # NTXentLoss takes the *embeddings* and computes the pairwise
        # similarity matrix and the temperature scaling itself. The previous
        # code handed it a pre-computed, temperature-scaled similarity matrix,
        # so both operations were applied twice.
        # reshape(-1) (instead of squeeze) keeps a 1-D label tensor even when
        # the batch holds a single sample.
        return self.criterion(embeddings, labels.reshape(-1))

In [3]:
class ContrastiveDataset(torch.utils.data.Dataset):
    """Image/caption/label dataset for the VSR or SNLI-VE splits.

    Each item is a dict with the raw caption (``'text'``), the decoded RGB
    image (``'image'``) and the label as a 0-d long tensor (``'label'``).
    Tokenisation and image preprocessing happen per batch in ``collate_fn``,
    which must be passed to the ``DataLoader``.

    Args:
        dataset_path: CSV file of the SNLI-VE split (used when ``vsr`` is falsy).
        image_path: Prefix prepended verbatim to SNLI-VE image names.
        vsr: File name of a VSR split (e.g. ``'train.json'``); when given, the
            VSR data is loaded instead of ``dataset_path``.
        vsr_image_path: Unused; kept for interface compatibility.
        max_length: Maximum number of text tokens kept by the tokenizer.
    """

    # Row labels removed from each VSR split after loading.
    # NOTE(review): presumably rows whose images are missing or unreadable — confirm.
    _VSR_EXCLUDED_ROWS = {
        'train.json': [1786, 3569, 4553, 4912],
        'test.json': [135, 614, 1071, 1621, 1850],
        'dev.json': [807],
    }

    def __init__(self, dataset_path, image_path, vsr=None, vsr_image_path='./data/vsr-images', max_length=77):
        self.max_length = max_length
        self.image_path = image_path
        if vsr:  # assess only for the vsr dataset
            self.dataset = self.read_vsr_dataset(vsr)
        else:  # assess for the SNLI-VE dataset
            self.dataset = self.read_dataset(dataset_path)
        # NOTE(review): the image preprocessor comes from a ResNet-50 checkpoint
        # while the text tokenizer (and the models in this notebook) are CLIP;
        # confirm that the differing image normalisation is intended.
        self.features_extract = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    def read_vsr_dataset(self, dataset_name, dataset_path='../visual-spatial-reasoning/', splits_path='splits/',
                         image_path='images/', sort=True, encode_labels=False):
        """Load one VSR split (JSON lines) into the common column layout.

        Args:
            dataset_name: Split file name inside ``dataset_path + splits_path``.
            dataset_path: Root of the VSR checkout (string prefix, trailing '/').
            splits_path: Sub-directory holding the split files.
            image_path: Sub-directory holding the images.
            sort: When true, order rows by ``relation`` (descending).
            encode_labels: When true, map label 1 to 2 so that it matches the
                SNLI-VE 'entailment' id.

        Returns:
            DataFrame with columns ``hypothesis``, ``Flickr30kID`` (image file
            path), ``gold_label`` and ``relation``, indexed 0..n-1 in the
            final row order.
        """
        frame = pd.read_json(dataset_path + splits_path + dataset_name, lines=True)
        # rename() returns a new frame, so the assignments below never write
        # into a slice of the freshly parsed data.
        frame = frame[['caption', 'image', 'label', 'relation']].rename(
            columns={'caption': 'hypothesis', 'image': 'Flickr30kID', 'label': 'gold_label'})
        frame['Flickr30kID'] = frame['Flickr30kID'].apply(lambda img_name: dataset_path + image_path + img_name)
        if encode_labels:
            labels_encoding = {0: 0, 1: 2}  # leave the label 0 the same and convert 1 to 2 to mean entailment
            frame['gold_label'] = frame['gold_label'].apply(lambda label: labels_encoding[label])
        excluded_rows = self._VSR_EXCLUDED_ROWS.get(dataset_name)
        if excluded_rows:
            frame = frame.drop(labels=excluded_rows, axis=0)
        if sort:
            frame = frame.sort_values(['relation'], ascending=False)
        # Re-index *after* sorting so that position i is also label i.
        return frame.reset_index(drop=True)

    def read_dataset(self, url, sort=False):
        """Load an SNLI-VE split from CSV.

        Args:
            url: Path or URL accepted by ``pandas.read_csv``.
            sort: When true, order rows by caption length (ascending).

        Returns:
            DataFrame with columns ``hypothesis``, ``Flickr30kID`` (image file
            path) and ``gold_label`` (0 contradiction, 1 neutral, 2 entailment).

        Raises:
            KeyError: If a row carries a label outside the three known names.
        """
        labels_encoding = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
        frame = pd.read_csv(url)[['hypothesis', 'Flickr30kID', 'gold_label']].copy()
        frame['gold_label'] = frame['gold_label'].apply(lambda label: labels_encoding[label])
        # NOTE(review): plain string concatenation — `image_path` must already
        # end with a path separator; confirm against the callers.
        frame['Flickr30kID'] = frame['Flickr30kID'].apply(lambda img_name: self.image_path + img_name)
        if sort:
            frame = frame.sort_values(by="hypothesis", key=lambda x: x.str.len())
        return frame.reset_index(drop=True)

    def get_visual_features(self, img):
        """Run the image preprocessor on one image or a list of images."""
        return self.features_extract(img, return_tensors="pt")

    def get_text_features(self, text):
        """Tokenise one caption or a list of captions, padded to the longest one."""
        # max_length was previously stored but never forwarded to the tokenizer.
        return self.tokenizer(text, return_tensors="pt", padding=True, truncation=True,
                              max_length=self.max_length)

    def __getitem__(self, idx):
        """Return the sample at *position* ``idx`` as a dict of text, image and label."""
        # Positional lookup: with label-based `.loc` the row order produced by
        # sorting was silently ignored.
        sample = self.dataset.iloc[idx]
        # convert() decodes the pixels into a new image, so the file handle can
        # be closed right away; RGB keeps grayscale/RGBA files usable by the
        # image preprocessor.
        with Image.open(sample['Flickr30kID']) as handle:
            image = handle.convert('RGB')
        return {'text': sample['hypothesis'],
                'image': image,
                'label': torch.tensor(sample['gold_label'], dtype=torch.long)}

    def collate_fn(self, batch):
        """Turn a list of ``__getitem__`` dicts into one batch of model inputs.

        Returns:
            Dict holding the image preprocessor outputs, the tokenizer outputs
            and ``'label'`` as a long tensor of shape (batch,).
        """
        texts = [item['text'] for item in batch]
        images = [item['image'] for item in batch]
        labels = [torch.as_tensor(item['label'], dtype=torch.long) for item in batch]
        collated = {**self.get_visual_features(images), **self.get_text_features(texts)}
        collated['label'] = torch.stack(labels)
        return collated

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.dataset.index)

In [4]:
class ContrastiveTrainer():
    """Trains a model with the supervised contrastive loss and reports accuracy.

    Args:
        model: Module called as ``model(**batch)``.
        train: Training dataset exposing ``collate_fn``.
        eval_test: Object with ``evaluate(batch_size=...)`` returning an accuracy;
            called once after every epoch.
        device: Device the batches are moved to.
        num_labels: Number of classes (stored only).
        batch_size: Optional batch size used to build ``train_loader`` up
            front; when omitted the loader is built by ``train_model``.
        num_workers: Worker processes used by the training ``DataLoader``.
    """

    _TENSOR_KEYS = ('input_ids', 'attention_mask', 'pixel_values', 'label')

    def __init__(self, model, train, eval_test, device=None, num_labels=3, batch_size=None, num_workers=4):
        self.device = device
        self.model = model
        self.train = train
        self.eval_test = eval_test
        self.test_acc_list = []  # one accuracy per finished epoch
        self.model_path = "./models/new_my_model_epoch_"
        self.num_labels = num_labels
        self.config_problem_type = "single_label_classification"
        if self.config_problem_type == "single_label_classification":
            self.loss_fct = SupervisedContrastiveLoss()
            self.output_loss = lambda outputs, labels: self.loss_fct(outputs, labels)
        self.num_workers = num_workers
        self.batch_size = batch_size
        # The loader used to be built from the notebook-level global
        # `batch_size`; it now only depends on explicit arguments.
        self.train_loader = self._build_loader(batch_size) if batch_size else None

    def _build_loader(self, batch_size):
        """Create the (unshuffled) training loader for the given batch size."""
        return DataLoader(self.train, batch_size=batch_size, shuffle=False, num_workers=self.num_workers,
                          collate_fn=self.train.collate_fn)

    def train_model(self, batch_size=None, lr=None, epochs=None):
        """Run the training loop and evaluate after each epoch.

        Args:
            batch_size: Batch size; falls back to the value given to ``__init__``.
            lr: Learning rate for AdamW (required).
            epochs: Number of passes over the training data (required).

        Raises:
            ValueError: If ``lr``, ``epochs`` or the batch size is missing.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if batch_size is None or lr is None or epochs is None:
            raise ValueError("train_model requires batch_size, lr and epochs")
        if self.train_loader is None or self.batch_size != batch_size:
            self.batch_size = batch_size
            self.train_loader = self._build_loader(batch_size)

        optimizer = AdamW(self.model.parameters(), lr=lr)
        lr_scheduler = get_scheduler(
            name="linear", optimizer=optimizer, num_warmup_steps=0,
            num_training_steps=epochs * len(self.train_loader)
        )
        for epoch in range(epochs):
            self.model.train()
            train_losses = []
            # The context manager closes the bar at the end of every epoch.
            with tqdm(total=len(self.train_loader)) as progress_bar:
                for item in self.train_loader:
                    for key in self._TENSOR_KEYS:
                        item[key] = item[key].to(self.device)
                    optimizer.zero_grad()
                    outputs = self.model(**item)
                    # The loss needs a (batch, dim) tensor, but the models in
                    # this notebook return an output object.
                    # NOTE(review): `.logits` is the only per-sample vector these
                    # models expose — confirm it is the intended embedding.
                    features = getattr(outputs, 'logits', outputs)
                    loss = self.output_loss(features, item['label'])
                    loss.backward()
                    optimizer.step()
                    lr_scheduler.step()
                    # Store a plain float: keeping the tensor would keep every
                    # step's autograd graph alive for the whole epoch.
                    train_losses.append(loss.item())
                    progress_bar.update(1)
            test_acc = self.eval_test.evaluate(batch_size=batch_size)
            self.test_acc_list.append(test_acc)
            print('--- Epoch ', epoch, ' Acc: ', test_acc)
            mean_loss = sum(train_losses) / len(train_losses) if train_losses else float('nan')
            print('Training loss: %.4f' % (mean_loss))
        return

In [5]:
class MyEvaluator():
    """Computes classification accuracy of a model on a dataset.

    Args:
        model: Module exposing ``predict(batch) -> predicted label tensor``.
        test: Dataset exposing ``collate_fn``; its batches must contain
            ``input_ids``, ``attention_mask``, ``pixel_values`` and ``label``.
        device: Device the batches are moved to.
    """

    _TENSOR_KEYS = ('input_ids', 'attention_mask', 'pixel_values', 'label')

    def __init__(self, model, test, device=None):
        self.test_dataset = test
        self.model = model
        self.device = device

    def evaluate(self, batch_size=8, num_workers=4):
        """Return the fraction of correctly predicted samples.

        Args:
            batch_size: Evaluation batch size.
            num_workers: Worker processes used by the ``DataLoader``.

        Returns:
            Accuracy in [0, 1]; 0.0 when the dataset is empty.
        """
        was_training = self.model.training
        self.model.eval()
        loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            collate_fn=self.test_dataset.collate_fn)
        n_correct = 0
        n_possible = 0
        try:
            # No gradients are needed for evaluation; skipping them avoids
            # building autograd graphs for every batch.
            with torch.no_grad():
                for item in loader:
                    for key in self._TENSOR_KEYS:
                        item[key] = item[key].to(self.device)
                    y_hat = self.model.predict(item)
                    y = item['label']
                    n_correct += (y == y_hat).sum().item()
                    n_possible += y.shape[0]
        finally:
            # Restore whichever mode the model was in before the call.
            self.model.train(was_training)
        if n_possible == 0:
            return 0.0
        return n_correct / n_possible

In [6]:
class CLIPClassifier(torch.nn.Module):
    """CLIP backbone followed by a small fusion transformer and a linear classifier.

    The projected vision tokens and text tokens of the backbone are
    concatenated, passed through three transformer encoder layers, mean-pooled
    over the non-padded tokens and classified.

    Args:
        clip: Backbone exposing ``visual_projection``, ``text_projection`` and a
            forward pass returning ``vision_model_output`` / ``text_model_output``.
        num_labels: Number of output classes.
    """

    def __init__(self, clip, num_labels=3):
        super().__init__()
        self.clip = clip
        # Width of the projected tokens; 512 when the backbone does not say.
        embed_dim = getattr(getattr(clip, 'config', None), 'projection_dim', 512)
        # batch_first=True: the tokens are laid out (batch, seq, dim). With the
        # default layout attention was computed across the batch dimension.
        # `new_encoder_layer` stays registered (as before) so that the
        # state_dict keys of existing checkpoints do not change.
        self.new_encoder_layer = torch.nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, batch_first=True)
        self.new_transformer_encoder = torch.nn.TransformerEncoder(self.new_encoder_layer, num_layers=3)
        self.classification = torch.nn.Linear(embed_dim, num_labels, bias=True)
        self.num_labels = num_labels

    def forward(self, input_ids=None, pixel_values=None, attention_mask=None, position_ids=None, return_loss=None, output_attentions=None, output_hidden_states=None, label=None):
        """Run backbone, fusion encoder and classifier.

        Args:
            input_ids, pixel_values, attention_mask, position_ids, return_loss,
            output_attentions, output_hidden_states: Forwarded to the backbone.
            label: Accepted so that a whole batch dict can be splatted in; unused.

        Returns:
            The backbone's output object with an added ``logits`` attribute of
            shape (batch, num_labels).
        """
        output = self.clip(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask,
                           position_ids=position_ids, return_loss=return_loss,
                           output_attentions=output_attentions, output_hidden_states=output_hidden_states,
                           return_dict=True)
        # Index 0 of each sub-output is the per-token hidden state sequence.
        vision_tokens = self.clip.visual_projection(output.vision_model_output[0])
        text_tokens = self.clip.text_projection(output.text_model_output[0])
        tokens = torch.cat((vision_tokens, text_tokens), dim=1)

        # Token mask (1 = real token, 0 = padding), created on the tokens'
        # own device and dtype instead of relying on a global `device`.
        vision_mask = torch.ones(vision_tokens.shape[:2], dtype=tokens.dtype, device=tokens.device)
        if attention_mask is None:
            text_mask = torch.ones(text_tokens.shape[:2], dtype=tokens.dtype, device=tokens.device)
        else:
            text_mask = attention_mask.to(device=tokens.device, dtype=tokens.dtype)
        token_mask = torch.cat((vision_mask, text_mask), dim=1)

        # src_key_padding_mask has shape (batch, seq) and marks the positions
        # to IGNORE, i.e. the inverse of the attention mask.
        encoded = self.new_transformer_encoder(tokens, src_key_padding_mask=(token_mask == 0))

        # Mean over the real tokens only.
        weights = token_mask.unsqueeze(-1)
        pooled = torch.sum(encoded * weights, 1) / torch.clamp(weights.sum(1), min=1e-9)

        output.logits = self.classification(pooled)
        return output

    def predict(self, item):
        """Return the arg-max class index per sample.

        Args:
            item: Batch dict that is splatted into ``forward``.

        Returns:
            Long tensor of shape (n_examples,).
        """
        scores = self(**item)  # was the notebook-global `model`
        predicted_labels = scores.logits.argmax(dim=-1)  # (n_examples)
        return predicted_labels

    def save_model(self, path):
        """Write the state_dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load_model(self, path):
        """Load a state_dict written by ``save_model`` and switch to eval mode."""
        # map_location='cpu' lets a checkpoint saved on a GPU be restored on
        # any machine; load_state_dict copies onto the module's own device.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        self.load_state_dict(torch.load(path, map_location='cpu'))
        self.eval()

In [7]:
class CLIP(CLIPModel):
    """CLIPModel subclass with a fusion transformer and a linear classifier on top.

    The projected vision and text tokens are concatenated, passed through
    three transformer encoder layers, mean-pooled over the non-padded tokens
    and classified.

    NOTE(review): the backbone is constructed from ``CLIPConfig`` only (no
    ``from_pretrained`` on the model), so it presumably starts from freshly
    initialised weights — confirm whether pretrained weights were intended.

    Args:
        num_labels: Number of output classes.
    """

    def __init__(self, num_labels=3):
        super().__init__(CLIPConfig.from_pretrained("openai/clip-vit-base-patch32"))
        embed_dim = self.config.projection_dim
        # batch_first=True: the tokens are laid out (batch, seq, dim). With the
        # default layout attention was computed across the batch dimension.
        # `new_encoder_layer` stays registered (as before) so that the
        # state_dict keys of existing checkpoints do not change.
        self.new_encoder_layer = torch.nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, batch_first=True)
        self.new_transformer_encoder = torch.nn.TransformerEncoder(self.new_encoder_layer, num_layers=3)
        self.classification = torch.nn.Linear(embed_dim, num_labels, bias=True)
        self.num_labels = num_labels

    def forward(self, input_ids=None, pixel_values=None, attention_mask=None, position_ids=None, return_loss=None, output_attentions=None, output_hidden_states=None, label=None):
        """Run the CLIP forward pass, the fusion encoder and the classifier.

        Args:
            input_ids, pixel_values, attention_mask, position_ids, return_loss,
            output_attentions, output_hidden_states: Forwarded to ``CLIPModel.forward``.
            label: Accepted so that a whole batch dict can be splatted in; unused.

        Returns:
            The ``CLIPModel`` output object with an added ``logits`` attribute
            of shape (batch, num_labels).
        """
        output = super().forward(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask,
                                 position_ids=position_ids, return_loss=return_loss,
                                 output_attentions=output_attentions,
                                 output_hidden_states=output_hidden_states, return_dict=True)
        # Index 0 of each sub-output is the per-token hidden state sequence.
        vision_tokens = self.visual_projection(output.vision_model_output[0])
        text_tokens = self.text_projection(output.text_model_output[0])
        tokens = torch.cat((vision_tokens, text_tokens), dim=1)

        # Token mask (1 = real token, 0 = padding), created on the tokens'
        # own device and dtype instead of relying on a global `device`.
        vision_mask = torch.ones(vision_tokens.shape[:2], dtype=tokens.dtype, device=tokens.device)
        if attention_mask is None:
            text_mask = torch.ones(text_tokens.shape[:2], dtype=tokens.dtype, device=tokens.device)
        else:
            text_mask = attention_mask.to(device=tokens.device, dtype=tokens.dtype)
        token_mask = torch.cat((vision_mask, text_mask), dim=1)

        # src_key_padding_mask has shape (batch, seq) and marks the positions
        # to IGNORE, i.e. the inverse of the attention mask.
        encoded = self.new_transformer_encoder(tokens, src_key_padding_mask=(token_mask == 0))

        # Mean over the real tokens only.
        weights = token_mask.unsqueeze(-1)
        pooled = torch.sum(encoded * weights, 1) / torch.clamp(weights.sum(1), min=1e-9)

        output.logits = self.classification(pooled)
        return output

    def predict(self, item):
        """Return the arg-max class index per sample.

        Args:
            item: Batch dict that is splatted into ``forward``.

        Returns:
            Long tensor of shape (n_examples,).
        """
        scores = self(**item)  # was the notebook-global `model`
        predicted_labels = scores.logits.argmax(dim=-1)  # (n_examples)
        return predicted_labels

    def save_model(self, path):
        """Write the state_dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load_model(self, path):
        """Load a state_dict written by ``save_model`` and switch to eval mode."""
        # map_location='cpu' lets a checkpoint saved on a GPU be restored on
        # any machine; load_state_dict copies onto the module's own device.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        self.load_state_dict(torch.load(path, map_location='cpu'))
        self.eval()

In [8]:
# Prefer GPU index 2 (the card used for the recorded run) but fall back to the
# highest available GPU, or to the CPU, so the notebook also runs elsewhere.
preferred_gpu = 2
if torch.cuda.is_available():
    device = torch.device("cuda:%d" % min(preferred_gpu, torch.cuda.device_count() - 1))
else:
    device = torch.device("cpu")
# Cell output: the name of the selected device. get_device_name is only valid
# for CUDA devices, so it is skipped when running on the CPU.
torch.cuda.get_device_name(device) if device.type == "cuda" else "cpu"

'NVIDIA GeForce RTX 3080'

In [9]:
# Run configuration.
dataset = 'vsr'      # 'vsr' or 'snli-ve'
task = 'train'       # 'train' or 'test'
batch_size = 32
epochs = 100
lr = 1e-5

# Number of classes per dataset: VSR is binary, SNLI-VE has three labels.
labels_per_dataset = {'vsr': 2, 'snli-ve': 3}
if dataset not in labels_per_dataset:
    # Fail here instead of with a NameError on `num_labels` in a later cell.
    raise ValueError("unknown dataset %r; expected one of %s" % (dataset, sorted(labels_per_dataset)))
num_labels = labels_per_dataset[dataset]

In [None]:
# Alternative model: the CLIPModel subclass defined above.
# NOTE(review): `model` is reassigned by the CLIPClassifier cell that follows,
# so this instance is discarded when the cells are run in order.
model = CLIP(num_labels = num_labels)

In [10]:
# Wrap the pretrained CLIP checkpoint in the classifier head defined above.
model = CLIPClassifier(
    CLIPModel.from_pretrained("openai/clip-vit-base-patch32"),
    num_labels=num_labels,
)

In [None]:
# Move the model to the device selected earlier in the notebook.
model = model.to(device)

In [None]:
# Print the module tree to check the architecture that was built.
print(model)

In [None]:
"""
#modules = [model.text_model.embeddings, model.vision_model.embeddings]
modules = [model.text_model, model.vision_model]
for module in modules:
    for param in module.parameters():
        param.requires_grad = False
"""

In [None]:
# Build the three splits. Because `vsr` is given, the VSR split file is loaded
# and the CSV path is not read.
max_length = 32
train, test, dev = [
    ContrastiveDataset(csv_path, "./data/my_image_db", max_length=max_length, vsr=vsr_split)
    for csv_path, vsr_split in (
        ("../e-ViL/data/esnlive_train.csv", 'train.json'),
        ("../e-ViL/data/esnlive_test.csv", 'test.json'),
        ("../e-ViL/data/esnlive_dev.csv", 'dev.json'),
    )
]

In [None]:
# Entry point: either train (evaluating on the test split after every epoch and
# on dev/train once at the end) or load a saved checkpoint and score the dev split.
if task == 'train':
    test_evaluator = MyEvaluator(model, test, device=device)
    dev_evaluator = MyEvaluator(model, dev, device=device)
    trainer = ContrastiveTrainer(
        model, train, test_evaluator, device=device, num_labels=num_labels
    )

    print("-----Training Model-----")
    trainer.train_model(epochs=epochs, batch_size=batch_size, lr=lr)
    print('----Training finished-----')

    dev_acc = dev_evaluator.evaluate(batch_size=batch_size)
    print("---- Dev Acc: ", dev_acc)

    train_acc = MyEvaluator(model, train, device=device).evaluate(batch_size=batch_size)
    print("--- Train Acc: ", train_acc)

    # The file name records the settings of the run.
    model.save_model(f"{dataset}_len{max_length}_batch{batch_size}_lr{lr}")
elif task == 'test':
    model.load_model("my_model_epoch_9")
    evaluator = MyEvaluator(model, dev, device=device)
    acc = evaluator.evaluate(batch_size=batch_size)
    print(acc)

In [None]:
# Rebuilds the train/test/dev splits with the same arguments as the
# ContrastiveDataset cell above, but through a class named `MyDataset`.
# NOTE(review): `MyDataset` is not defined in any cell shown in this notebook;
# unless it is provided elsewhere this cell raises NameError on a fresh kernel.
# It looks like a leftover from before a rename — confirm, then rename or remove.
max_length = 32
train = MyDataset("../e-ViL/data/esnlive_train.csv",
                      "./data/my_image_db",
                      max_length = max_length, vsr= 'train.json')
test = MyDataset("../e-ViL/data/esnlive_test.csv",
                      "./data/my_image_db",
                      max_length = max_length,
                        vsr= 'test.json')
dev = MyDataset("../e-ViL/data/esnlive_dev.csv",
                      "./data/my_image_db",
                      max_length = max_length, vsr= 'dev.json')

In [None]:
# Number of samples in the train, test and dev splits, one per line.
print(len(train), len(test), len(dev), sep='\n')

In [None]:
# Same train/test driver as the ContrastiveTrainer cell above, but using a
# class named `MyTrainer`.
# NOTE(review): `MyTrainer` is not defined in any cell shown in this notebook;
# unless it is provided elsewhere the 'train' branch raises NameError. It looks
# like a leftover from before a rename — confirm, then rename or remove.
if task =='train':
    # The test split is evaluated after every epoch by the trainer.
    test_evaluator = MyEvaluator(model,test, device = device)
    dev_evaluator = MyEvaluator(model,dev, device = device)
    trainer = MyTrainer(model,train,test_evaluator, device = device, num_labels = num_labels)
    print("-----Training Model-----")
    trainer.train_model(epochs=epochs ,batch_size = batch_size, lr = lr)
    print('----Training finished-----')
    dev_acc = dev_evaluator.evaluate(batch_size = batch_size)
    print("---- Dev Acc: ",dev_acc)
    train_acc = MyEvaluator(model,train,device=device).evaluate(batch_size = batch_size)
    print("--- Train Acc: ", train_acc)
    # The file name records dataset, max_length, batch size and learning rate.
    model.save_model(dataset+'_len'+str(max_length)+'_batch'+str(batch_size)+'_lr'+str(lr))
elif task =='test':
    # Load a fixed checkpoint file and report accuracy on the dev split.
    model.load_model("my_model_epoch_9")
    evaluator = MyEvaluator(model,dev, device = device)
    acc = evaluator.evaluate(batch_size = batch_size)
    print(acc)
    #output = run_example(model,train)