In [1]:
import torch
from PIL import Image
import requests
from lavis.models import load_model_and_preprocess

from torch import nn
import pandas as pd
import os

from tqdm.auto import tqdm, trange

from sklearn.metrics import classification_report
import json

from torch.utils.data import Dataset
import torch.utils.data as data
from torch import optim
from torch.autograd import Variable
import numpy as np
from sentence_transformers import util

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from lavis.models import model_zoo

In [None]:
print(model_zoo)

In [None]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1,2,3"

# Load the model

In [2]:
# setup device to use
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [3]:
model, vis_processors, txt_processors = load_model_and_preprocess(
    name = "clip_feature_extractor", model_type="ViT-B-32", is_eval=True, device=device
)

# Define the Dataset class

In [4]:
class TwitterCOMMsDataset(Dataset):
    def __init__(self, csv_path, img_dir):
        """
        Args:
            csv_path (string): Path to the {train_completed|val_completed}.csv file.
            image_folder_dir (string): Directory containing the images
        """
        self.df = pd.read_csv(csv_path, index_col=0)
        self.img_dir = img_dir
        
        self.df['exists'] = self.df['filename'].apply(lambda filename: os.path.exists(os.path.join(img_dir, filename)))
        delete_row = self.df[self.df["exists"]==False].index
        self.df = self.df.drop(delete_row)
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        caption = item['full_text']
        img_filename = item['filename']
        topic = item['topic']
        falsified = int(item['falsified'])
        not_falsified = float(not item['falsified'])
#         label = np.array((falsified, not_falsified))
        label = np.array(falsified)
        domain = topic.split('_')[0]
        diff = topic.split('_')[1]
        
        raw_image = Image.open(os.path.join(self.img_dir, img_filename)).convert('RGB')
        image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
        text_input = txt_processors["eval"](caption)
        sample = {"image": image, "text_input": [text_input]}   # image shape: [1, 3, 224, 224]

        features = model.extract_features(sample)
        features_image = features.image_embeds   # [1, 512]
        features_text = features.text_embeds   # [1, 512]
        multimodal_emb = torch.cat((features_image, features_text), 1)
        cos_sim = util.cos_sim(features_text, features_image)
#         features_image_proj = features_image.image_embeds_proj[:,0,:]   # [1, 256]
#         features_text_proj = features_text.text_embeds_proj[:,0,:]   # [1, 256]
        
#         multimodal_emb = torch.cat((features_image_proj, features_text_proj), 1)
#         multimodal_emb = features_image_proj * features_text_proj   # [1, 256]
#         print(multimodal_emb.shape)

#         similarity = features_image_proj @ features_text_proj.t()

        return {"multimodal_emb": multimodal_emb,
                "topic": topic, 
                "label": label, 
                "domain": domain, 
                "difficulty": diff,
               "similarity": cos_sim}
        
        
        
    

In [None]:
# class ClimateAndCovidDataset(Dataset):
#     def __init__(self, csv_path, img_dir):
#         """
#         Args:
#             csv_path (string): Path to the {train_completed|val_completed}.csv file.
#             image_folder_dir (string): Directory containing the images
#         """
#         self.df = pd.read_csv(csv_path, index_col=0)
#         self.img_dir = img_dir
        
#         self.df['exists'] = self.df['filename'].apply(lambda filename: os.path.exists(os.path.join(img_dir, filename)))
#         delete_row = self.df[self.df["exists"]==False].index
#         self.df = self.df.drop(delete_row)
        
#         self.df['is_military'] = self.df['topic'].apply(lambda topic: 'military' in topic)
#         delete_row = self.df[self.df["is_military"]==True].index
#         self.df = self.df.drop(delete_row)
    
#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):
#         item = self.df.iloc[idx]
#         caption = item['full_text']
#         img_filename = item['filename']
#         topic = item['topic']
#         falsified = float(item['falsified'])
#         not_falsified = float(not item['falsified'])
#         label = np.array((falsified, not_falsified))
#         domain = topic.split('_')[0]
#         diff = topic.split('_')[1]
        
#         try:
#             raw_image = Image.open(os.path.join(self.img_dir, img_filename)).convert('RGB')
#             image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
#             text_input = txt_processors["eval"](caption)
#             sample = {"image": image, "text_input": [text_input]}   # image shape: [1, 3, 224, 224]
        
#             features_multimodal = model.extract_features(sample, mode="multimodal")
# #             features_image = model.extract_features(sample, mode="image")
# #             features_text = model.extract_features(sample, mode="text")
# #             print(features_multimodal.multimodal_embeds[:, 0, :].shape)
        
#             return {"multimodal_emb": features_multimodal.multimodal_embeds[:, 0, :],
#                     "topic": topic, 
#                     "label": label, 
#                     "domain": domain, 
#                     "difficulty": diff}
        
#         except IOError as e:
#             print(e)
        
        
        
    

In [5]:
# train_data = TwitterCOMMsDataset(csv_path='../data/train_completed.csv',
#                                     img_dir='/import/network-temp/yimengg/data/twitter-comms/train/images/train_image_ids')   # took ~one hour to construct the dataset
val_data = TwitterCOMMsDataset(csv_path='../data/val_completed.csv', 
                               img_dir='/import/network-temp/yimengg/data/twitter-comms/images/val_images/val_tweet_image_ids')

In [None]:
train_data.__len__()

In [6]:
BATCH_SIZE = 32

# train_iterator = data.DataLoader(train_data, 
#                                  shuffle = True, 
#                                  batch_size=BATCH_SIZE)
val_iterator = data.DataLoader(val_data, 
                               shuffle = False, 
                               batch_size=BATCH_SIZE)

In [7]:
class Net(nn.Module):
    def __init__(self, in_dim, out_dim=2):
        super(Net, self).__init__()
        
        self.fc = nn.Linear(in_dim, out_dim)
        self.in_dim = in_dim
    
    def forward(self, x):
        x = x.view(-1, self.in_dim)
        out = self.fc(x)
        return out
    
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

In [8]:
def normal_init(m, mean, std):
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

In [9]:
def main():
#     net = Net(256)
    net = Net(1024)
    net.cuda()
    net.train()
    net.weight_init(mean=0, std=0.02)
    
    lr = 0.0001
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
    
    criterion = nn.CrossEntropyLoss()
    criterion.to(device)
    
    softmax = nn.Softmax(dim=1)
 
    EPOCHS = 2
    for epoch in range(EPOCHS):
        total_loss = 0
        num_correct = 0
        total = 0
        for i, batch in tqdm(enumerate(val_iterator, 0), desc='iterations'):
            inputs = batch["multimodal_emb"].to(device)
            labels = batch["label"].to(device)
            inputs, labels = Variable(inputs), Variable(labels)
            
            net.zero_grad()
            y_preds = net(inputs)
            loss = criterion(y_preds, labels)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
#             _, top_pred = y_preds.topk(1, 1)
            
            top_pred = torch.zeros_like(labels)
            y_preds = softmax(y_preds)
#             print(y_preds[:, 0])
            top_pred[y_preds[:, 1] < 0.25] = 1
            y = labels.cpu()
            batch_size = y.shape[0]
            top_pred = top_pred.cpu().view(batch_size)
            
#             num_correct += sum(top_pred == y[:, 0]).item()
            num_correct += sum(top_pred == y).item()
            total += batch_size
            
            if i % 50 == 0:
                print("Epoch [%d/%d]: Training accuracy %.2f" % (epoch+1, EPOCHS, num_correct/total))
#                 print(y_preds)
#                 print(top_pred)
#                 print(labels)
#                 print(num_correct)
#                 print(total)

    return net

In [None]:
net = main()

iterations: 1it [00:04,  4.77s/it]

Epoch [1/2]: Training accuracy 0.50


  "Palette images with Transparency expressed in bytes should be "
iterations: 51it [01:41,  2.33s/it]

Epoch [1/2]: Training accuracy 0.51


iterations: 101it [03:01,  1.59s/it]

Epoch [1/2]: Training accuracy 0.51


iterations: 151it [04:30,  1.84s/it]

Epoch [1/2]: Training accuracy 0.52


iterations: 201it [05:54,  1.98s/it]

Epoch [1/2]: Training accuracy 0.52


iterations: 251it [07:24,  2.07s/it]

Epoch [1/2]: Training accuracy 0.52


iterations: 301it [08:37,  1.33s/it]

Epoch [1/2]: Training accuracy 0.52


iterations: 351it [09:56,  1.32s/it]

Epoch [1/2]: Training accuracy 0.53


iterations: 401it [11:22,  2.33s/it]

Epoch [1/2]: Training accuracy 0.53


iterations: 451it [12:41,  1.41s/it]

Epoch [1/2]: Training accuracy 0.53


iterations: 501it [14:01,  1.19s/it]

Epoch [1/2]: Training accuracy 0.53


iterations: 551it [15:17,  1.45s/it]

Epoch [1/2]: Training accuracy 0.53


iterations: 601it [16:34,  1.91s/it]

Epoch [1/2]: Training accuracy 0.53


iterations: 651it [18:00,  1.69s/it]

Epoch [1/2]: Training accuracy 0.53


iterations: 691it [19:02,  1.65s/it]
iterations: 1it [00:01,  1.33s/it]

Epoch [2/2]: Training accuracy 0.50


iterations: 36it [01:04,  1.80s/it]


In [None]:
!nvidia-smi

In [None]:
num_correct = 0
total = 0
for i, batch in tqdm(enumerate(val_iterator, 0), desc='iterations'):
    similarity = batch["similarity"].squeeze()
    labels = batch["label"]
    y_preds = torch.zeros_like(labels)
    y_preds[similarity < 0.25] = 1
    
    num_correct += sum(y_preds == labels).item()
    total += BATCH_SIZE
    
    if i % 50 == 0:
        print("Accuracy %.2f" % (num_correct/total))
        print(y_preds)
        print(similarity)

In [None]:
idx = 6
print(val_data.df.iloc[idx]['full_text'])
print("Falsified? " + str(val_data.df.iloc[idx]['falsified']))
print(val_data.df.iloc[idx]['topic'])
print(val_data[idx]['similarity'])

In [None]:
raw_image = Image.open('/import/network-temp/yimengg/data/twitter-comms/images/val_images/val_tweet_image_ids/'+val_data.df.iloc[idx]['filename']).convert('RGB')   
display(raw_image.resize((596, 437)))

# Playground

In [None]:
val_df = pd.read_csv('../data/val_completed.csv', index_col=0)
val_df

In [None]:
img_dir = '/import/network-temp/yimengg/data/twitter-comms/images/val_images/val_tweet_image_ids'
val_df['exists'] = val_df['filename'].apply(lambda filename: os.path.exists(os.path.join(img_dir, filename)))

In [None]:
delete_row = val_df[val_df["exists"]==False].index

In [None]:
val_df = val_df.drop(delete_row)

In [None]:
val_df