In [29]:
import torch
from PIL import Image
import requests
from lavis.models import load_model_and_preprocess

from torch import nn
import pandas as pd
import os

from tqdm import tqdm

from sklearn.metrics import classification_report
import json

from torch.utils.data import Dataset
import torch.utils.data as data

# Load the model

In [2]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
# The result is always a torch.device (never a bare "cpu" string), so every
# consumer of `device` sees the same type on both code paths.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Load the BLIP feature extractor ("base" variant) in eval mode on `device`.
# The call also returns the image and text preprocessors that go with it.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="blip_feature_extractor",
    model_type="base",
    is_eval=True,
    device=device,
)

# Define the Dataset class

In [46]:
class TwitterCOMMsDataset(Dataset):
    """Twitter-COMMs rows served as BLIP multimodal embeddings.

    Rows whose image file is not present under ``img_dir`` are removed when
    the dataset is constructed. ``__getitem__`` relies on the notebook-level
    ``model``, ``vis_processors``, ``txt_processors`` and ``device`` objects.
    """

    def __init__(self, csv_path, img_dir):
        """
        Args:
            csv_path (string): Path to the {train_completed|val_completed}.csv file.
            img_dir (string): Directory containing the images
        """
        self.df = pd.read_csv(csv_path, index_col=0)
        self.img_dir = img_dir

        # Flag every row with whether its image is actually on disk.
        self.df['exists'] = self.df['filename'].apply(
            lambda filename: os.path.exists(os.path.join(img_dir, filename))
        )
        # Filter with a positional boolean mask instead of dropping by index
        # label: a label-based drop would also remove valid rows that happen
        # to share an index label with a missing one.
        keep_mask = self.df['exists'].to_numpy(dtype=bool)
        self.df = self.df.loc[keep_mask]

    def __len__(self):
        """Return the number of rows whose image exists on disk."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the embedding and metadata for the row at position ``idx``.

        Returns:
            dict with keys ``multimodal_emb`` (the ``[:, 0, :]`` slice of the
            model's multimodal embeddings), ``topic``, ``label``, ``domain``
            and ``difficulty``; or ``None`` when the image cannot be read.
            NOTE(review): the default DataLoader collate function cannot
            batch ``None`` - confirm whether unreadable images should raise.
        """
        item = self.df.iloc[idx]
        caption = item['full_text']
        img_filename = item['filename']
        topic = item['topic']
        label = item['falsified']
        # Topics look like "<domain>_<difficulty>", e.g. "climate_hard".
        topic_parts = topic.split('_')
        domain = topic_parts[0]
        diff = topic_parts[1]

        try:
            # Context manager closes the underlying file once pixels are decoded.
            with Image.open(os.path.join(self.img_dir, img_filename)) as img_file:
                raw_image = img_file.convert('RGB')
            image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
            text_input = txt_processors["eval"](caption)
            sample = {"image": image, "text_input": [text_input]}   # image shape: [1, 3, 224, 224]

            # The extractor is only used to produce fixed features, so skip
            # autograd bookkeeping; otherwise every sample keeps its graph alive.
            with torch.no_grad():
                features_multimodal = model.extract_features(sample, mode="multimodal")

            return {"multimodal_emb": features_multimodal.multimodal_embeds[:, 0, :],
                    "topic": topic,
                    "label": label,
                    "domain": domain,
                    "difficulty": diff}

        except IOError as e:
            print(f"Could not load image {img_filename!r}: {e}")
            return None
        
        
        
    

In [47]:
# Validation split: rows come from the completed CSV, images from the
# validation image directory. Rows without an image on disk are filtered out
# by the dataset constructor.
val_data = TwitterCOMMsDataset(
    csv_path='../data/val_completed.csv',
    img_dir='/import/network-temp/yimengg/data/twitter-comms/images/val_images/val_tweet_image_ids',
)

In [48]:
# Number of samples the DataLoader groups into each batch.
BATCH_SIZE = 64

In [49]:
# Batched iterator over the validation dataset; all other DataLoader options
# are left at their defaults.
val_iterator = data.DataLoader(
    val_data,
    batch_size=BATCH_SIZE,
)

In [55]:
class Net(nn.Module):
    """Single fully connected layer mapping ``in_dim`` features to ``out_dim`` logits."""

    def __init__(self, in_dim, out_dim=2):
        """
        Args:
            in_dim (int): Size of each input feature vector.
            out_dim (int): Number of output logits (defaults to 2).
        """
        super().__init__()

        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the raw logits."""
        return self.fc(x)

    def weight_init(self, mean=0.0, std=0.02):
        """Re-initialise every direct ``nn.Linear`` child.

        Weights are drawn from N(mean, std**2) and biases are set to zero.
        Both arguments have defaults so ``weight_init()`` can be called bare;
        passing them positionally still works as before.

        Args:
            mean (float): Mean of the normal distribution for the weights.
            std (float): Standard deviation of that distribution.
        """
        for child in self.children():
            if isinstance(child, nn.Linear):
                nn.init.normal_(child.weight, mean=mean, std=std)
                # A Linear layer built with bias=False has no bias tensor.
                if child.bias is not None:
                    nn.init.zeros_(child.bias)

In [54]:
def normal_init(m, mean, std):
    """Initialise an ``nn.Linear`` module in place; other modules are left untouched.

    Args:
        m (nn.Module): Module to initialise.
        mean (float): Mean of the normal distribution used for the weights.
        std (float): Standard deviation of that distribution.
    """
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(mean, std)
        # zero_() takes no arguments; a layer built with bias=False has no bias.
        if m.bias is not None:
            m.bias.data.zero_()

In [None]:
def main(train_loader=None, epochs=2, lr=0.0001):
    """Train a linear classifier on the multimodal embeddings.

    Args:
        train_loader: Iterable of batches, each a dict with
            ``"multimodal_emb"`` and ``"label"`` tensors. When ``None`` the
            notebook-level ``train_iterator`` is used (``NameError`` if that
            has not been defined).
        epochs (int): Number of passes over ``train_loader``.
        lr (float): Adam learning rate.

    Returns:
        The trained ``Net`` instance.
    """
    # Local import: the notebook's import cell does not bring in torch.optim.
    from torch import optim

    if train_loader is None:
        train_loader = train_iterator

    net = Net(768)
    # Follow the notebook-level device instead of forcing CUDA.
    net.to(device)
    net.train()
    # Pass the parameters explicitly so this works whether or not
    # Net.weight_init declares defaults.
    net.weight_init(mean=0.0, std=0.02)

    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

    criterion = nn.CrossEntropyLoss()
    criterion.to(device)

    for epoch in range(epochs):
        # Wrap the loader itself so tqdm can read its length for the progress bar.
        for batch in tqdm(train_loader, desc='iterations'):
            # detach(): the embeddings are fixed features, never backprop into
            # the extractor. flatten(): collapses a per-sample singleton
            # dimension (presumably (B, 1, 768) -> (B, 768); a (B, 768) batch
            # is left unchanged).
            inputs = batch["multimodal_emb"].detach().to(device).flatten(start_dim=1)
            # CrossEntropyLoss expects integer class indices, not booleans.
            labels = batch["label"].to(device).long()

            # Clear gradients from the previous step; without this they accumulate.
            optimizer.zero_grad()
            y_preds = net(inputs)
            loss = criterion(y_preds, labels)

            loss.backward()
            optimizer.step()

    return net

# Playground

In [36]:
# Load the validation metadata; the first CSV column becomes the index.
val_df = pd.read_csv('../data/val_completed.csv', index_col=0)
# Bare expression so the frame is rendered as the cell output.
val_df

Unnamed: 0,id,full_text,image_id,filename,falsified,topic
0,1422180237341777922,"As #COP26 approaches, how does the UK plan to ...",1422180237341777922,40/1422180237341777922-1422180234313584640.jpg,False,climate_hard
1,1422180237341777922,"As #COP26 approaches, how does the UK plan to ...",1422180237341777922,40/1422180237341777922-1422180234313584640.jpg,False,climate_random
2,1399648514079141894,"Devil Facial corruption, climate change, fight...",1422180237341777922,40/1422180237341777922-1422180234313584640.jpg,True,climate_hard
3,1407719137917419529,Almost 20% of Canada’s greenhouse gas emission...,1422180237341777922,40/1422180237341777922-1422180234313584640.jpg,True,climate_random
4,1422180237341777922,"As #COP26 approaches, how does the UK plan to ...",1424685512753885201,49/1424685512753885201-1424685510971244549.jpg,True,climate_hard
...,...,...,...,...,...,...
22115,1363526076429836292,@TostevinM They already changed the color of M...,1365950394992529412,66/1365950394992529412-1365950220350091266.jpg,True,military_hard
22116,1365950394992529412,@manilabulletin LOOK: President Duterte inspec...,1365950394992529412,66/1365950394992529412-1365950220350091266.jpg,False,military_hard
22117,1365950394992529412,@manilabulletin LOOK: President Duterte inspec...,1365950394992529412,66/1365950394992529412-1365950220350091266.jpg,False,military_random
22118,1097480948772454403,"A Tornado crew, consisting of a Pilot, Squadro...",1097480948772454403,02/1097480948772454403-1097480946121625602.jpg,False,military_hard


In [38]:
# Directory holding the validation images referenced by the `filename` column.
img_dir = '/import/network-temp/yimengg/data/twitter-comms/images/val_images/val_tweet_image_ids'
# Flag each row with whether its image file is present on disk.
val_df['exists'] = val_df['filename'].map(
    lambda name: os.path.exists(os.path.join(img_dir, name))
)

In [42]:
# Index labels of the rows whose image file was not found.
delete_row = val_df.loc[val_df["exists"].eq(False)].index

In [44]:
# Remove the rows flagged as missing; `val_df` is rebound to the filtered frame.
val_df = val_df.drop(index=delete_row)

In [45]:
# Bare expression so the filtered frame is rendered as the cell output.
val_df

Unnamed: 0,id,full_text,image_id,filename,falsified,topic,exists
0,1422180237341777922,"As #COP26 approaches, how does the UK plan to ...",1422180237341777922,40/1422180237341777922-1422180234313584640.jpg,False,climate_hard,True
1,1422180237341777922,"As #COP26 approaches, how does the UK plan to ...",1422180237341777922,40/1422180237341777922-1422180234313584640.jpg,False,climate_random,True
2,1399648514079141894,"Devil Facial corruption, climate change, fight...",1422180237341777922,40/1422180237341777922-1422180234313584640.jpg,True,climate_hard,True
3,1407719137917419529,Almost 20% of Canada’s greenhouse gas emission...,1422180237341777922,40/1422180237341777922-1422180234313584640.jpg,True,climate_random,True
4,1422180237341777922,"As #COP26 approaches, how does the UK plan to ...",1424685512753885201,49/1424685512753885201-1424685510971244549.jpg,True,climate_hard,True
...,...,...,...,...,...,...,...
22115,1363526076429836292,@TostevinM They already changed the color of M...,1365950394992529412,66/1365950394992529412-1365950220350091266.jpg,True,military_hard,True
22116,1365950394992529412,@manilabulletin LOOK: President Duterte inspec...,1365950394992529412,66/1365950394992529412-1365950220350091266.jpg,False,military_hard,True
22117,1365950394992529412,@manilabulletin LOOK: President Duterte inspec...,1365950394992529412,66/1365950394992529412-1365950220350091266.jpg,False,military_random,True
22118,1097480948772454403,"A Tornado crew, consisting of a Pilot, Squadro...",1097480948772454403,02/1097480948772454403-1097480946121625602.jpg,False,military_hard,True
