In [1]:
import torch
from PIL import Image
import requests
from lavis.models import load_model_and_preprocess

from torch import nn
import pandas as pd
import os

from tqdm.auto import tqdm, trange

from sklearn.metrics import classification_report
import json

from torch.utils.data import Dataset
import torch.utils.data as data
from torch import optim
from torch.autograd import Variable
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


# Load the model

In [2]:
# setup device to use
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

In [4]:
model, vis_processors, txt_processors = load_model_and_preprocess(
    name = "blip_feature_extractor", model_type="base", is_eval=True, device=device
)

# Define the Dataset class

In [98]:
class TwitterCOMMsDataset(Dataset):
    def __init__(self, csv_path, img_dir):
        """
        Args:
            csv_path (string): Path to the {train_completed|val_completed}.csv file.
            image_folder_dir (string): Directory containing the images
        """
        self.df = pd.read_csv(csv_path, index_col=0)
        self.img_dir = img_dir
        
        self.df['exists'] = self.df['filename'].apply(lambda filename: os.path.exists(os.path.join(img_dir, filename)))
        delete_row = self.df[self.df["exists"]==False].index
        self.df = self.df.drop(delete_row)
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        caption = item['full_text']
        img_filename = item['filename']
        topic = item['topic']
        falsified = int(item['falsified'])
        not_falsified = float(not item['falsified'])
#         label = np.array((falsified, not_falsified))
        label = np.array(falsified)
        domain = topic.split('_')[0]
        diff = topic.split('_')[1]
        
        raw_image = Image.open(os.path.join(self.img_dir, img_filename)).convert('RGB')
        image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
        text_input = txt_processors["eval"](caption)
        sample = {"image": image, "text_input": [text_input]}   # image shape: [1, 3, 224, 224]

#         features_multimodal = model.extract_features(sample, mode="multimodal")   # ??? image and text might mismatch
        features_image = model.extract_features(sample, mode="image")
        features_text = model.extract_features(sample, mode="text")
        features_image_proj = features_image.image_embeds_proj[:,0,:]   # [1, 256]
        features_text_proj = features_text.text_embeds_proj[:,0,:]   # [1, 256]
#         multimodal_emb = torch.cat((features_image_proj, features_text_proj), 1)
        multimodal_emb = features_image_proj * features_text_proj   # [1, 256]
#         print(multimodal_emb.shape)

        similarity = features_image_proj @ features_text_proj.t()

        return {"multimodal_emb": multimodal_emb,
                "topic": topic, 
                "label": label, 
                "domain": domain, 
                "difficulty": diff,
               "similarity": similarity}
        
        
        
    

In [93]:
# class ClimateAndCovidDataset(Dataset):
#     def __init__(self, csv_path, img_dir):
#         """
#         Args:
#             csv_path (string): Path to the {train_completed|val_completed}.csv file.
#             image_folder_dir (string): Directory containing the images
#         """
#         self.df = pd.read_csv(csv_path, index_col=0)
#         self.img_dir = img_dir
        
#         self.df['exists'] = self.df['filename'].apply(lambda filename: os.path.exists(os.path.join(img_dir, filename)))
#         delete_row = self.df[self.df["exists"]==False].index
#         self.df = self.df.drop(delete_row)
        
#         self.df['is_military'] = self.df['topic'].apply(lambda topic: 'military' in topic)
#         delete_row = self.df[self.df["is_military"]==True].index
#         self.df = self.df.drop(delete_row)
    
#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):
#         item = self.df.iloc[idx]
#         caption = item['full_text']
#         img_filename = item['filename']
#         topic = item['topic']
#         falsified = float(item['falsified'])
#         not_falsified = float(not item['falsified'])
#         label = np.array((falsified, not_falsified))
#         domain = topic.split('_')[0]
#         diff = topic.split('_')[1]
        
#         try:
#             raw_image = Image.open(os.path.join(self.img_dir, img_filename)).convert('RGB')
#             image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
#             text_input = txt_processors["eval"](caption)
#             sample = {"image": image, "text_input": [text_input]}   # image shape: [1, 3, 224, 224]
        
#             features_multimodal = model.extract_features(sample, mode="multimodal")
# #             features_image = model.extract_features(sample, mode="image")
# #             features_text = model.extract_features(sample, mode="text")
# #             print(features_multimodal.multimodal_embeds[:, 0, :].shape)
        
#             return {"multimodal_emb": features_multimodal.multimodal_embeds[:, 0, :],
#                     "topic": topic, 
#                     "label": label, 
#                     "domain": domain, 
#                     "difficulty": diff}
        
#         except IOError as e:
#             print(e)
        
        
        
    

In [99]:
# train_data = TwitterCOMMsDataset(csv_path='../data/train_completed.csv',
#                                     img_dir='/import/network-temp/yimengg/data/twitter-comms/train/images/train_image_ids')
val_data = TwitterCOMMsDataset(csv_path='../data/val_completed.csv', 
                               img_dir='/import/network-temp/yimengg/data/twitter-comms/images/val_images/val_tweet_image_ids')

In [89]:
train_data.__len__()

2292375

In [100]:
BATCH_SIZE = 64

# train_iterator = data.DataLoader(train_data, 
#                                  shuffle = True, 
#                                  batch_size=BATCH_SIZE)
val_iterator = data.DataLoader(val_data, 
                               shuffle = True, 
                               batch_size=BATCH_SIZE)

In [42]:
class Net(nn.Module):
    def __init__(self, in_dim, out_dim=2):
        super(Net, self).__init__()
        
        self.fc = nn.Linear(in_dim, out_dim)
        self.in_dim = in_dim
    
    def forward(self, x):
        x = x.view(-1, self.in_dim)
        out = self.fc(x)
        return out
    
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

In [43]:
def normal_init(m, mean, std):
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

In [79]:
def main():
    net = Net(256)
    net.cuda()
    net.train()
    net.weight_init(mean=0, std=0.02)
    
    lr = 0.001
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
    
    criterion = nn.CrossEntropyLoss()
    criterion.to(device)
 
    EPOCHS = 2
    for epoch in range(EPOCHS):
        total_loss = 0
        num_correct = 0
        total = 0
        for i, batch in tqdm(enumerate(val_iterator, 0), desc='iterations'):
            inputs = batch["multimodal_emb"].to(device)
            labels = batch["label"].to(device)
            inputs, labels = Variable(inputs), Variable(labels)
            
            net.zero_grad()
            y_preds = net(inputs)
            loss = criterion(y_preds, labels)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            _, top_pred = y_preds.topk(1, 1)
            y = labels.cpu()
            batch_size = y.shape[0]
            top_pred = top_pred.cpu().view(batch_size)
            
#             num_correct += sum(top_pred == y[:, 0]).item()
            num_correct += sum(top_pred == y).item()
            total += batch_size
            
            if i % 50 == 0:
                print("Epoch [%d/%d]: Training accuracy %.2f" % (epoch+1, EPOCHS, num_correct/total))
                print(y_preds)
                print(top_pred)
                print(labels)
                print(num_correct)
                print(total)

    return net

In [107]:
for i, batch in tqdm(enumerate(val_iterator, 0), desc='iterations'):
    num_correct = 0
    total = 0
    similarity = batch["similarity"].squeeze()
    labels = batch["label"]
    y_preds = torch.zeros_like(labels)
    y_preds[similarity < 0.5] = 1
    
    num_correct += sum(y_preds == labels).item()
    total += BATCH_SIZE
    
    if i % 50 == 0:
        print("Accuracy %.2f" % (num_correct/total))
        print(y_preds)

iterations: 1it [00:03,  3.63s/it]

Accuracy 0.36
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


iterations: 51it [02:51,  3.56s/it]

Accuracy 0.47
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


iterations: 101it [05:34,  3.67s/it]

Accuracy 0.55
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


iterations: 151it [08:25,  3.47s/it]

Accuracy 0.42
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


iterations: 201it [11:09,  3.20s/it]

Accuracy 0.59
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


iterations: 251it [13:47,  3.27s/it]

Accuracy 0.55
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


iterations: 301it [16:26,  3.07s/it]

Accuracy 0.53
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


iterations: 346it [18:48,  3.26s/it]


In [80]:
net = main()

iterations: 1it [00:03,  3.25s/it]

Epoch [1/2]: Training accuracy 0.48
tensor([[-1.6271e-03, -1.6783e-03],
        [ 4.2249e-04, -2.7659e-03],
        [-1.4922e-04, -1.6164e-03],
        [ 8.0145e-04, -8.1335e-04],
        [-9.9801e-04, -4.1436e-04],
        [-2.3779e-03, -1.7982e-03],
        [-3.2615e-04, -3.0065e-03],
        [-4.9869e-04, -3.1801e-03],
        [-1.5767e-03, -2.2283e-03],
        [-9.8567e-04, -4.9071e-04],
        [-6.6091e-04, -1.2831e-03],
        [-4.1325e-04, -1.4879e-03],
        [-2.9931e-03, -8.6936e-04],
        [-5.6451e-04, -3.2379e-03],
        [-1.7323e-03, -2.8489e-03],
        [ 1.0438e-03, -2.7642e-04],
        [-2.7389e-04, -1.3078e-03],
        [-1.3912e-03, -1.7819e-03],
        [-1.1076e-03, -1.1403e-03],
        [ 1.2752e-04, -8.2475e-04],
        [-5.3813e-04, -2.3014e-03],
        [-5.0159e-04, -3.5589e-03],
        [-5.4779e-04, -8.2641e-04],
        [-1.5276e-03, -2.2425e-03],
        [-1.4742e-03, -3.0347e-03],
        [-1.3799e-03, -2.0323e-03],
        [ 3.0430e-04, -2.135

iterations: 51it [02:52,  3.40s/it]

Epoch [1/2]: Training accuracy 0.52
tensor([[ 0.0249, -0.0284],
        [ 0.0256, -0.0284],
        [ 0.0241, -0.0257],
        [ 0.0252, -0.0286],
        [ 0.0245, -0.0262],
        [ 0.0277, -0.0303],
        [ 0.0229, -0.0271],
        [ 0.0249, -0.0269],
        [ 0.0259, -0.0292],
        [ 0.0263, -0.0301],
        [ 0.0266, -0.0295],
        [ 0.0229, -0.0282],
        [ 0.0247, -0.0266],
        [ 0.0231, -0.0285],
        [ 0.0249, -0.0288],
        [ 0.0268, -0.0289],
        [ 0.0245, -0.0289],
        [ 0.0247, -0.0275],
        [ 0.0244, -0.0279],
        [ 0.0282, -0.0295],
        [ 0.0259, -0.0296],
        [ 0.0252, -0.0288],
        [ 0.0256, -0.0262],
        [ 0.0264, -0.0288],
        [ 0.0250, -0.0295],
        [ 0.0273, -0.0297],
        [ 0.0243, -0.0292],
        [ 0.0259, -0.0284],
        [ 0.0253, -0.0295],
        [ 0.0269, -0.0288],
        [ 0.0248, -0.0288],
        [ 0.0255, -0.0283],
        [ 0.0272, -0.0277],
        [ 0.0242, -0.0287],
        [ 0.

iterations: 101it [05:38,  3.63s/it]

Epoch [1/2]: Training accuracy 0.53
tensor([[ 0.0378, -0.0407],
        [ 0.0429, -0.0460],
        [ 0.0405, -0.0420],
        [ 0.0443, -0.0462],
        [ 0.0438, -0.0457],
        [ 0.0425, -0.0461],
        [ 0.0428, -0.0487],
        [ 0.0418, -0.0454],
        [ 0.0430, -0.0473],
        [ 0.0438, -0.0464],
        [ 0.0432, -0.0455],
        [ 0.0386, -0.0397],
        [ 0.0400, -0.0426],
        [ 0.0442, -0.0447],
        [ 0.0415, -0.0433],
        [ 0.0439, -0.0453],
        [ 0.0411, -0.0465],
        [ 0.0404, -0.0441],
        [ 0.0430, -0.0483],
        [ 0.0416, -0.0436],
        [ 0.0393, -0.0427],
        [ 0.0420, -0.0426],
        [ 0.0422, -0.0425],
        [ 0.0437, -0.0463],
        [ 0.0410, -0.0443],
        [ 0.0391, -0.0426],
        [ 0.0428, -0.0460],
        [ 0.0449, -0.0477],
        [ 0.0405, -0.0413],
        [ 0.0406, -0.0459],
        [ 0.0394, -0.0418],
        [ 0.0406, -0.0417],
        [ 0.0410, -0.0419],
        [ 0.0418, -0.0455],
        [ 0.

iterations: 151it [08:26,  3.26s/it]

Epoch [1/2]: Training accuracy 0.53
tensor([[ 0.0453, -0.0462],
        [ 0.0516, -0.0536],
        [ 0.0482, -0.0502],
        [ 0.0509, -0.0528],
        [ 0.0443, -0.0457],
        [ 0.0484, -0.0508],
        [ 0.0509, -0.0540],
        [ 0.0452, -0.0473],
        [ 0.0465, -0.0477],
        [ 0.0475, -0.0478],
        [ 0.0493, -0.0551],
        [ 0.0438, -0.0478],
        [ 0.0466, -0.0488],
        [ 0.0565, -0.0575],
        [ 0.0451, -0.0490],
        [ 0.0510, -0.0554],
        [ 0.0508, -0.0522],
        [ 0.0540, -0.0567],
        [ 0.0454, -0.0486],
        [ 0.0421, -0.0455],
        [ 0.0501, -0.0530],
        [ 0.0491, -0.0535],
        [ 0.0471, -0.0476],
        [ 0.0492, -0.0508],
        [ 0.0473, -0.0507],
        [ 0.0526, -0.0532],
        [ 0.0491, -0.0508],
        [ 0.0468, -0.0463],
        [ 0.0479, -0.0526],
        [ 0.0487, -0.0509],
        [ 0.0479, -0.0499],
        [ 0.0477, -0.0520],
        [ 0.0494, -0.0513],
        [ 0.0475, -0.0489],
        [ 0.

iterations: 201it [11:14,  3.38s/it]

Epoch [1/2]: Training accuracy 0.53
tensor([[ 0.0593, -0.0617],
        [ 0.0541, -0.0589],
        [ 0.0532, -0.0567],
        [ 0.0519, -0.0576],
        [ 0.0566, -0.0577],
        [ 0.0606, -0.0624],
        [ 0.0535, -0.0567],
        [ 0.0577, -0.0590],
        [ 0.0606, -0.0622],
        [ 0.0603, -0.0618],
        [ 0.0564, -0.0611],
        [ 0.0547, -0.0557],
        [ 0.0608, -0.0615],
        [ 0.0607, -0.0642],
        [ 0.0552, -0.0536],
        [ 0.0601, -0.0630],
        [ 0.0588, -0.0631],
        [ 0.0603, -0.0623],
        [ 0.0513, -0.0578],
        [ 0.0593, -0.0615],
        [ 0.0545, -0.0587],
        [ 0.0542, -0.0561],
        [ 0.0566, -0.0602],
        [ 0.0611, -0.0628],
        [ 0.0603, -0.0633],
        [ 0.0544, -0.0579],
        [ 0.0523, -0.0564],
        [ 0.0626, -0.0656],
        [ 0.0656, -0.0697],
        [ 0.0590, -0.0618],
        [ 0.0600, -0.0623],
        [ 0.0640, -0.0659],
        [ 0.0571, -0.0606],
        [ 0.0599, -0.0654],
        [ 0.

iterations: 226it [12:41,  3.37s/it]


KeyboardInterrupt: 

# Playground

In [None]:
val_df = pd.read_csv('../data/val_completed.csv', index_col=0)
val_df

In [None]:
img_dir = '/import/network-temp/yimengg/data/twitter-comms/images/val_images/val_tweet_image_ids'
val_df['exists'] = val_df['filename'].apply(lambda filename: os.path.exists(os.path.join(img_dir, filename)))

In [None]:
delete_row = val_df[val_df["exists"]==False].index

In [None]:
val_df = val_df.drop(delete_row)

In [None]:
val_df