In [14]:
import torch
from PIL import Image
import requests

from torch import nn
import pandas as pd
import os

from tqdm.auto import tqdm, trange

from sklearn.metrics import classification_report
import json

from torch.utils.data import Dataset
import torch.utils.data as data
from torch import optim
from torch.autograd import Variable
import numpy as np

# from transformers import CLIPProcessor, CLIPModel
from sentence_transformers import SentenceTransformer, util

In [40]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1,2,3"

# Load the model

In [2]:
# setup device to use
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [24]:
clip_img = SentenceTransformer('clip-ViT-B-32')
clip_txt = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1')

Downloading (…)24087/.gitattributes: 100%|██████| 690/690 [00:00<00:00, 171kB/s]
Downloading (…)_Pooling/config.json: 100%|█████| 190/190 [00:00<00:00, 53.7kB/s]
Downloading (…)/2_Dense/config.json: 100%|█████| 115/115 [00:00<00:00, 32.5kB/s]
Downloading pytorch_model.bin: 100%|███████| 1.57M/1.57M [00:00<00:00, 26.2MB/s]
Downloading (…)9a46024087/README.md: 100%|█| 5.63k/5.63k [00:00<00:00, 1.42MB/s]
Downloading (…)46024087/config.json: 100%|██████| 572/572 [00:00<00:00, 176kB/s]
Downloading (…)ce_transformers.json: 100%|█████| 122/122 [00:00<00:00, 33.3kB/s]
Downloading pytorch_model.bin: 100%|█████████| 539M/539M [00:09<00:00, 56.7MB/s]
Downloading (…)nce_bert_config.json: 100%|███| 53.0/53.0 [00:00<00:00, 15.0kB/s]
Downloading (…)cial_tokens_map.json: 100%|█████| 112/112 [00:00<00:00, 37.1kB/s]
Downloading (…)24087/tokenizer.json: 100%|█| 1.96M/1.96M [00:00<00:00, 4.97MB/s]
Downloading (…)okenizer_config.json: 100%|██████| 371/371 [00:00<00:00, 109kB/s]
Downloading (…)9a46024087/vo

# Define the Dataset class

In [54]:
class TwitterCOMMsDataset(Dataset):
    def __init__(self, csv_path, img_dir):
        """
        Args:
            csv_path (string): Path to the {train_completed|val_completed}.csv file.
            image_folder_dir (string): Directory containing the images
        """
        self.df = pd.read_csv(csv_path, index_col=0)
        self.img_dir = img_dir
        
        self.df['exists'] = self.df['filename'].apply(lambda filename: os.path.exists(os.path.join(img_dir, filename)))
        delete_row = self.df[self.df["exists"]==False].index
        self.df = self.df.drop(delete_row)
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        caption = item['full_text']
        img_filename = item['filename']
        topic = item['topic']
        falsified = int(item['falsified'])
        not_falsified = float(not item['falsified'])
        label = np.array(falsified)
        domain = topic.split('_')[0]
        diff = topic.split('_')[1]
        
        raw_image = Image.open(os.path.join(self.img_dir, img_filename))#.convert('RGB')
        img_emb = torch.Tensor(clip_img.encode(raw_image))
        txt_emb = torch.Tensor(clip_txt.encode(caption))
        
        multimodal_emb = torch.cat((img_emb, txt_emb))
        
        cos_sim = util.cos_sim(txt_emb, img_emb)

        return {"multimodal_emb": multimodal_emb,
                "topic": topic, 
                "label": label, 
                "domain": domain, 
                "difficulty": diff,
               "similarity":cos_sim}
        
        
        
    

In [16]:
# class ClimateAndCovidDataset(Dataset):
#     def __init__(self, csv_path, img_dir):
#         """
#         Args:
#             csv_path (string): Path to the {train_completed|val_completed}.csv file.
#             image_folder_dir (string): Directory containing the images
#         """
#         self.df = pd.read_csv(csv_path, index_col=0)
#         self.img_dir = img_dir
        
#         self.df['exists'] = self.df['filename'].apply(lambda filename: os.path.exists(os.path.join(img_dir, filename)))
#         delete_row = self.df[self.df["exists"]==False].index
#         self.df = self.df.drop(delete_row)
        
#         self.df['is_military'] = self.df['topic'].apply(lambda topic: 'military' in topic)
#         delete_row = self.df[self.df["is_military"]==True].index
#         self.df = self.df.drop(delete_row)
    
#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):
#         item = self.df.iloc[idx]
#         caption = item['full_text']
#         img_filename = item['filename']
#         topic = item['topic']
#         falsified = float(item['falsified'])
#         not_falsified = float(not item['falsified'])
#         label = np.array((falsified, not_falsified))
#         domain = topic.split('_')[0]
#         diff = topic.split('_')[1]
        
#         try:
#             raw_image = Image.open(os.path.join(self.img_dir, img_filename)).convert('RGB')
#             image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
#             text_input = txt_processors["eval"](caption)
#             sample = {"image": image, "text_input": [text_input]}   # image shape: [1, 3, 224, 224]
        
#             features_multimodal = model.extract_features(sample, mode="multimodal")
# #             features_image = model.extract_features(sample, mode="image")
# #             features_text = model.extract_features(sample, mode="text")
# #             print(features_multimodal.multimodal_embeds[:, 0, :].shape)
        
#             return {"multimodal_emb": features_multimodal.multimodal_embeds[:, 0, :],
#                     "topic": topic, 
#                     "label": label, 
#                     "domain": domain, 
#                     "difficulty": diff}
        
#         except IOError as e:
#             print(e)
        
        
        
    

In [55]:
# train_data = TwitterCOMMsDataset(csv_path='../data/train_completed.csv',
#                                     img_dir='/import/network-temp/yimengg/data/twitter-comms/train/images/train_image_ids')   # took ~one hour to construct the dataset
val_data = TwitterCOMMsDataset(csv_path='../data/val_completed.csv', 
                               img_dir='/import/network-temp/yimengg/data/twitter-comms/images/val_images/val_tweet_image_ids')

In [None]:
train_data.__len__()

In [56]:
BATCH_SIZE = 64

# train_iterator = data.DataLoader(train_data, 
#                                  shuffle = True, 
#                                  batch_size=BATCH_SIZE)
val_iterator = data.DataLoader(val_data, 
                               shuffle = False, 
                               batch_size=BATCH_SIZE)

In [50]:
class Net(nn.Module):
    def __init__(self, in_dim, out_dim=2):
        super(Net, self).__init__()
        
        self.fc = nn.Linear(in_dim, out_dim)
        self.in_dim = in_dim
    
    def forward(self, x):
        x = x.view(-1, self.in_dim)
        out = self.fc(x)
        return out
    
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

In [51]:
def normal_init(m, mean, std):
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

In [52]:
def main():
#     net = Net(256)
    net = Net(1024)
    net.cuda()
    net.train()
    net.weight_init(mean=0, std=0.02)
    
    lr = 0.0001
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
    
    criterion = nn.CrossEntropyLoss()
    criterion.to(device)
 
    EPOCHS = 2
    for epoch in range(EPOCHS):
        total_loss = 0
        num_correct = 0
        total = 0
        for i, batch in tqdm(enumerate(val_iterator, 0), desc='iterations'):
            inputs = batch["multimodal_emb"].to(device)
            labels = batch["label"].to(device)
            inputs, labels = Variable(inputs), Variable(labels)
            
            net.zero_grad()
            y_preds = net(inputs)
            loss = criterion(y_preds, labels)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            _, top_pred = y_preds.topk(1, 1)
            y = labels.cpu()
            batch_size = y.shape[0]
            top_pred = top_pred.cpu().view(batch_size)
            
#             num_correct += sum(top_pred == y[:, 0]).item()
            num_correct += sum(top_pred == y).item()
            total += batch_size
            
            if i % 50 == 0:
                print("Epoch [%d/%d]: Training accuracy %.2f" % (epoch+1, EPOCHS, num_correct/total))
#                 print(y_preds)
#                 print(top_pred)
#                 print(labels)
#                 print(num_correct)
#                 print(total)

    return net

In [53]:
net = main()

iterations: 1it [00:04,  4.58s/it]

Epoch [1/2]: Training accuracy 0.48


  "Palette images with Transparency expressed in bytes should be "
iterations: 51it [16:45, 18.91s/it]

Epoch [1/2]: Training accuracy 0.50


iterations: 101it [34:07, 19.01s/it]

Epoch [1/2]: Training accuracy 0.50


iterations: 151it [50:49, 18.99s/it]

Epoch [1/2]: Training accuracy 0.51


iterations: 201it [1:07:40, 25.69s/it]

Epoch [1/2]: Training accuracy 0.51


iterations: 251it [1:23:17, 15.21s/it]

Epoch [1/2]: Training accuracy 0.52


iterations: 301it [1:38:34, 20.41s/it]

Epoch [1/2]: Training accuracy 0.52


iterations: 346it [1:53:19, 19.65s/it]
iterations: 1it [00:16, 16.10s/it]

Epoch [2/2]: Training accuracy 0.53


iterations: 51it [17:31, 22.47s/it]

Epoch [2/2]: Training accuracy 0.50


iterations: 101it [34:51, 23.44s/it]

Epoch [2/2]: Training accuracy 0.50


iterations: 151it [51:26, 17.35s/it]

Epoch [2/2]: Training accuracy 0.51


iterations: 201it [1:08:35, 21.79s/it]

Epoch [2/2]: Training accuracy 0.52


iterations: 251it [1:25:07, 18.80s/it]

Epoch [2/2]: Training accuracy 0.52


iterations: 301it [1:41:24, 23.57s/it]

Epoch [2/2]: Training accuracy 0.52


iterations: 346it [1:56:16, 20.16s/it]


In [1]:
!nvidia-smi

Wed Apr 26 17:12:12 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 455.45.01    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  On   | 00000000:3B:00.0 Off |                    0 |
| N/A   30C    P0    41W / 250W |   1401MiB / 16280MiB |     24%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  On   | 00000000:D8:00.0 Off |                    0 |
| N/A   29C    P0    36W / 250W |   1401MiB / 16280MiB |      0%      Default |
|       

In [58]:
for i, batch in tqdm(enumerate(val_iterator, 0), desc='iterations'):
    num_correct = 0
    total = 0
    similarity = batch["similarity"].squeeze()
    labels = batch["label"]
    y_preds = torch.zeros_like(labels)
    y_preds[similarity < 0.5] = 1
    
    num_correct += sum(y_preds == labels).item()
    total += BATCH_SIZE
    
    if i % 50 == 0:
        print("Accuracy %.2f" % (num_correct/total))
        print(y_preds)
        print(similarity)

iterations: 1it [00:15, 15.71s/it]

Accuracy 0.47
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([0.3061, 0.3061, 0.2678, 0.2001, 0.2551, 0.2281, 0.2758, 0.2758, 0.1983,
        0.3132, 0.3132, 0.2053, 0.2629, 0.2629, 0.2473, 0.2695, 0.2587, 0.2564,
        0.2564, 0.1994, 0.2719, 0.2646, 0.2646, 0.2770, 0.3539, 0.3539, 0.2157,
        0.2358, 0.1939, 0.1811, 0.1811, 0.1838, 0.2478, 0.2478, 0.2072, 0.2465,
        0.2429, 0.2429, 0.2552, 0.2552, 0.2041, 0.2322, 0.2790, 0.2790, 0.2184,
        0.2183, 0.1556, 0.1934, 0.1934, 0.2562, 0.2562, 0.2291, 0.1822, 0.2427,
        0.2275, 0.2549, 0.2549, 0.1770, 0.1963, 0.1963, 0.2049, 0.3401, 0.3401,
        0.2143])


iterations: 51it [17:34, 24.11s/it]

Accuracy 0.50
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([0.2851, 0.2205, 0.2689, 0.2689, 0.1631, 0.1982, 0.3063, 0.3063, 0.2997,
        0.2997, 0.2142, 0.2050, 0.2959, 0.2734, 0.2734, 0.2100, 0.2261, 0.2261,
        0.1919, 0.2417, 0.1648, 0.3318, 0.3318, 0.2841, 0.2549, 0.2549, 0.2302,
        0.1957, 0.2503, 0.2163, 0.2956, 0.2956, 0.2004, 0.3533, 0.3533, 0.2714,
        0.2528, 0.2528, 0.3032, 0.1613, 0.2361, 0.2076, 0.3184, 0.3184, 0.2328,
        0.2299, 0.2299, 0.2514, 0.3112, 0.2517, 0.2517, 0.1720, 0.2967, 0.2967,
        0.3307, 0.2791, 0.3305, 0.3305, 0.2152, 0.2024, 0.2024, 0.2003, 0.1992,
        0.1720])


iterations: 101it [34:36, 23.91s/it]

Accuracy 0.45
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([0.3153, 0.3153, 0.3147, 0.3147, 0.1960, 0.2577, 0.2428, 0.2428, 0.2171,
        0.2043, 0.2385, 0.2085, 0.2085, 0.1882, 0.2924, 0.2497, 0.2497, 0.2368,
        0.2368, 0.1856, 0.1715, 0.2354, 0.2743, 0.2743, 0.1117, 0.2002, 0.2892,
        0.2892, 0.2220, 0.2421, 0.2572, 0.2572, 0.2047, 0.2430, 0.2430, 0.1829,
        0.3110, 0.3110, 0.2112, 0.2710, 0.2924, 0.2924, 0.2922, 0.2922, 0.2381,
        0.1732, 0.2078, 0.2104, 0.2753, 0.2753, 0.3605, 0.3401, 0.3401, 0.2484,
        0.3348, 0.3348, 0.1513, 0.2666, 0.2836, 0.1805, 0.2354, 0.2354, 0.1934,
        0.2605])


iterations: 123it [42:27, 20.72s/it]


KeyboardInterrupt: 

# Playground

In [None]:
val_df = pd.read_csv('../data/val_completed.csv', index_col=0)
val_df

In [None]:
img_dir = '/import/network-temp/yimengg/data/twitter-comms/images/val_images/val_tweet_image_ids'
val_df['exists'] = val_df['filename'].apply(lambda filename: os.path.exists(os.path.join(img_dir, filename)))

In [None]:
delete_row = val_df[val_df["exists"]==False].index

In [None]:
val_df = val_df.drop(delete_row)

In [None]:
val_df