In [1]:
import os
import json
import time
import random
import numpy as np
import pandas as pd
import pydicom
from PIL import Image
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings("ignore")

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
import matplotlib.pyplot as plt
import sklearn.metrics as metrics

In [3]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  
os.environ["CUDA_VISIBLE_DEVICES"]="3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
config = dict(
    saved_path="saved_models/efficientb4_cifar10.pt",
    best_saved_path = "saved/random_best.pt",
    lr=0.001, 
    EPOCHS = 3,
    BATCH_SIZE = 32,
    IMAGE_SIZE = 132,
    TRAIN_VALID_SPLIT = 0.2,
    device=device,
    SEED = 42,
    pin_memory=True,
    num_workers=2,
    USE_AMP = True,
    channels_last=False)

In [5]:
random.seed(config['SEED'])
# If you or any of the libraries you are using rely on NumPy, you can seed the global NumPy RNG 
np.random.seed(config['SEED'])
# Prevent RNG for CPU and GPU using torch
torch.manual_seed(config['SEED'])
torch.cuda.manual_seed(config['SEED'])
torch.backends.cudnn.benchmarks = True
torch.backends.cudnn.deterministic = True

torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

In [6]:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop((config['IMAGE_SIZE'],config['IMAGE_SIZE'])),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((config['IMAGE_SIZE'],config['IMAGE_SIZE'])),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize((config['IMAGE_SIZE'],config['IMAGE_SIZE'])),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [7]:
path1 = '../../dataset/CrisisMMD_v2.0/CrisisMMD_v2.0/'
classes = os.listdir('../../dataset/CrisisMMD_v2.0/CrisisMMD_v2.0/data_image/')
classes

['iraq_iran_earthquake',
 'srilanka_floods',
 'mexico_earthquake',
 'hurricane_harvey',
 'california_wildfires']

In [8]:
path = '../../dataset/CrisisMMD_v2.0/CrisisMMD_v2.0/crisismmd_datasplit_all/'
os.listdir(path)

['Readme.txt',
 'task_damage_text_img_dev.tsv',
 'task_humanitarian_text_img_train.tsv',
 'task_damage_text_img_test.tsv',
 'task_informative_text_img_test.tsv',
 'task_informative_text_img_train.tsv',
 'task_humanitarian_text_img_test.tsv',
 'task_damage_text_img_train.tsv',
 'task_informative_text_img_dev.tsv',
 '.ipynb_checkpoints',
 'task_humanitarian_text_img_dev.tsv']

In [9]:
df = pd.read_csv(path+'task_humanitarian_text_img_train.tsv',sep = '\t')
df_test = pd.read_csv(path+'task_humanitarian_text_img_test.tsv',sep = '\t')
print(df_test.shape)
print(df.shape)
df.head()

(2237, 9)
(13608, 9)


Unnamed: 0,event_name,tweet_id,image_id,tweet_text,image,label,label_text,label_image,label_text_image
0,california_wildfires,917791291823591425,917791291823591425_1,RT @Cal_OES: PLS SHARE: Weâ€™re capturing wild...,data_image/california_wildfires/10_10_2017/917...,not_humanitarian,other_relevant_information,not_humanitarian,Negative
1,california_wildfires,917791291823591425,917791291823591425_0,RT @Cal_OES: PLS SHARE: Weâ€™re capturing wild...,data_image/california_wildfires/10_10_2017/917...,other_relevant_information,other_relevant_information,infrastructure_and_utility_damage,Negative
2,california_wildfires,917793137925459968,917793137925459968_0,RT @KAKEnews: California wildfires destroy mor...,data_image/california_wildfires/10_10_2017/917...,infrastructure_and_utility_damage,infrastructure_and_utility_damage,infrastructure_and_utility_damage,Positive
3,california_wildfires,917793137925459968,917793137925459968_1,RT @KAKEnews: California wildfires destroy mor...,data_image/california_wildfires/10_10_2017/917...,infrastructure_and_utility_damage,infrastructure_and_utility_damage,infrastructure_and_utility_damage,Positive
4,california_wildfires,917793137925459968,917793137925459968_2,RT @KAKEnews: California wildfires destroy mor...,data_image/california_wildfires/10_10_2017/917...,infrastructure_and_utility_damage,infrastructure_and_utility_damage,infrastructure_and_utility_damage,Positive


In [10]:
# Filtering out required 5 classes
for i in range(len(df)-1,-1,-1):
    d = 1
    for j in classes:
        if j in df['image'][i]:
            d = 0
            continue
    if d: df.drop(i, inplace = True)
    
    if i<len(df_test):
        d = 1
        for j in classes:
            if j in df_test['image'][i]:
                d = 0
                continue
        if d: df_test.drop(i, inplace = True)
print(df.shape)
print(df_test.shape)

(6765, 9)
(1160, 9)


In [11]:
from transformers import BertTokenizer, VisualBertModel
# Initialize the VilBERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_text(text):
    # Tokenize the text and add special tokens
    tokenized_text = tokenizer.encode_plus(text, add_special_tokens=True, max_length=128, truncation=True)
    return tokenized_text
tokenized = df['tweet_text'].apply(tokenize_text)
tokenized_test = df_test['tweet_text'].apply(tokenize_text)

def f(x):
    x = x['input_ids']
    l = len(x)
    if l==45:
        return x
    if l>45:
        return x[-45:]
    if l<45:
        return [0]*(45-l)+x

df['tweet_text'] = tokenized.apply(f)
df_test['tweet_text'] = tokenized_test.apply(f)
df.head()

Unnamed: 0,event_name,tweet_id,image_id,tweet_text,image,label,label_text,label_image,label_text_image
0,california_wildfires,917791291823591425,917791291823591425_1,"[2050, 30102, 30108, 2890, 11847, 3748, 10273,...",data_image/california_wildfires/10_10_2017/917...,not_humanitarian,other_relevant_information,not_humanitarian,Negative
1,california_wildfires,917791291823591425,917791291823591425_0,"[2050, 30102, 30108, 2890, 11847, 3748, 10273,...",data_image/california_wildfires/10_10_2017/917...,other_relevant_information,other_relevant_information,infrastructure_and_utility_damage,Negative
2,california_wildfires,917793137925459968,917793137925459968_0,"[1024, 2662, 3748, 26332, 6033, 2062, 2084, 27...",data_image/california_wildfires/10_10_2017/917...,infrastructure_and_utility_damage,infrastructure_and_utility_damage,infrastructure_and_utility_damage,Positive
3,california_wildfires,917793137925459968,917793137925459968_1,"[1024, 2662, 3748, 26332, 6033, 2062, 2084, 27...",data_image/california_wildfires/10_10_2017/917...,infrastructure_and_utility_damage,infrastructure_and_utility_damage,infrastructure_and_utility_damage,Positive
4,california_wildfires,917793137925459968,917793137925459968_2,"[1024, 2662, 3748, 26332, 6033, 2062, 2084, 27...",data_image/california_wildfires/10_10_2017/917...,infrastructure_and_utility_damage,infrastructure_and_utility_damage,infrastructure_and_utility_damage,Positive


In [12]:
from sklearn.preprocessing import LabelEncoder
enc = LabelEncoder()
enc.fit(df['label_text'])
df['label_text'] = enc.transform(df['label_text'])
df_test['label_text'] = enc.transform(df_test['label_text'])


enc = LabelEncoder()
enc.fit(df['label_image'])
df['label_image'] = enc.transform(df['label_image'])
df_test['label_image'] = enc.transform(df_test['label_image'])

df.head()

Unnamed: 0,event_name,tweet_id,image_id,tweet_text,image,label,label_text,label_image,label_text_image
0,california_wildfires,917791291823591425,917791291823591425_1,"[2050, 30102, 30108, 2890, 11847, 3748, 10273,...",data_image/california_wildfires/10_10_2017/917...,not_humanitarian,5,4,Negative
1,california_wildfires,917791291823591425,917791291823591425_0,"[2050, 30102, 30108, 2890, 11847, 3748, 10273,...",data_image/california_wildfires/10_10_2017/917...,other_relevant_information,5,1,Negative
2,california_wildfires,917793137925459968,917793137925459968_0,"[1024, 2662, 3748, 26332, 6033, 2062, 2084, 27...",data_image/california_wildfires/10_10_2017/917...,infrastructure_and_utility_damage,1,1,Positive
3,california_wildfires,917793137925459968,917793137925459968_1,"[1024, 2662, 3748, 26332, 6033, 2062, 2084, 27...",data_image/california_wildfires/10_10_2017/917...,infrastructure_and_utility_damage,1,1,Positive
4,california_wildfires,917793137925459968,917793137925459968_2,"[1024, 2662, 3748, 26332, 6033, 2062, 2084, 27...",data_image/california_wildfires/10_10_2017/917...,infrastructure_and_utility_damage,1,1,Positive


In [13]:
class CustomDataset(Dataset):
    def __init__(self, df):
        self.df = df
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        #print(idx)
        # Load image from path and convert to tensor
        img_path = path1+self.df.iloc[idx]['image']
        img = Image.open(img_path).convert('RGB')
        img_tensor = data_transforms['test'](img)
        
        # Load text and convert to tensor
        text = self.df.iloc[idx]['tweet_text']
        text_tensor = torch.tensor(text)
        
        # Load labels and convert to tensor
        img_label = self.df.iloc[idx]['label_image']
        img_label_tensor = torch.tensor(img_label)
        
        text_label = self.df.iloc[idx]['label_text']
        text_label_tensor = torch.tensor(text_label)
        
        return (img_tensor, text_tensor), (img_label_tensor, text_label_tensor)
    
train_data = CustomDataset(df)
test_data = CustomDataset(df_test)
valid_data = test_data

print(len(train_data), len(valid_data))
train_dl = torch.utils.data.DataLoader(train_data, batch_size=32,shuffle=True, num_workers = config['num_workers'],
                                          pin_memory = config['pin_memory'])
test_dl = torch.utils.data.DataLoader(test_data, batch_size=32,shuffle=True, num_workers = config['num_workers'],
                                          pin_memory = config['pin_memory'])
valid_dl = torch.utils.data.DataLoader(valid_data, batch_size=32,shuffle=True, num_workers = config['num_workers'],
                                          pin_memory = config['pin_memory'])

6765 1160


In [14]:
a = iter(valid_dl)
b = next(a)
print(b[0][0].shape, b[0][1].shape)
print(b[1][0], b[1][1])

torch.Size([32, 3, 132, 132]) torch.Size([32, 45])
tensor([4, 6, 4, 1, 4, 4, 1, 4, 4, 0, 1, 1, 1, 4, 4, 0, 5, 1, 4, 4, 4, 6, 0, 4,
        4, 4, 4, 4, 1, 6, 6, 4]) tensor([4, 5, 4, 5, 4, 4, 6, 4, 4, 5, 6, 2, 5, 4, 4, 5, 0, 7, 4, 5, 1, 6, 6, 6,
        4, 6, 1, 4, 6, 6, 6, 4])


# Finetune Pretrained model

In [15]:
from torchvision.models import efficientnet_b0
from transformers import BertModel

class MultiTaskModel(nn.Module):
    def __init__(self, num_image_classes, num_text_classes):
        super(MultiTaskModel, self).__init__()
        
        # Define the image classification model using EfficientNet-B0
        efficientnet = efficientnet_b0(pretrained=True)
        efficientnet.classifier[1] = nn.Linear(in_features = 1280, out_features = num_image_classes, bias = True)
        self.image_model = efficientnet
        
        # Define the text classification model using BERT
        self.text_model = BertModel.from_pretrained('bert-base-uncased')
        self.text_classifier = nn.Linear(768, num_text_classes)
        
    def forward(self, img_input, text_input):
        # Forward pass for the image classification model
        img_output = self.image_model(img_input)
        
        # Forward pass for the text classification model
        text_output = self.text_model(text_input)[1]
        text_output = self.text_classifier(text_output)
        
        return img_output, text_output

In [16]:
model = MultiTaskModel(num_image_classes=8, num_text_classes=8)
model.to(device)

# Define the loss functions for each task
img_criterion = nn.CrossEntropyLoss()
text_criterion = nn.CrossEntropyLoss()

# Define the optimizer and learning rate scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
# Train the model

def train(model, num_epochs = 10):
    for epoch in range(num_epochs):
        # Training loop
        model.train()
        total_img_loss = 0
        total_text_loss = 0
        total_img_correct = 0
        total_text_correct = 0
        for batch_idx, batch_data in enumerate(train_dl):
            # Get the image and text inputs and labels for this batch
            img_inputs = batch_data[0][0].to(device)
            text_inputs = batch_data[0][1].to(device)
            img_labels = batch_data[1][0].to(device)
            text_labels = batch_data[1][1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            img_outputs, text_outputs = model(img_inputs, text_inputs)
            img_loss = img_criterion(img_outputs, img_labels)
            text_loss = text_criterion(text_outputs, text_labels)
            loss = img_loss + text_loss

            # Backward pass
            loss.backward()
            optimizer.step()

            # Accumulate the loss for this batch
            total_img_loss += img_loss.item()
            total_text_loss += text_loss.item()

            img_preds = torch.argmax(img_outputs, dim=1)
            text_preds = torch.argmax(text_outputs, dim=1)
            total_img_correct += (img_preds == img_labels).sum().item()
            total_text_correct += (text_preds == text_labels).sum().item()

        # Compute the accuracy for this epoch
        img_accuracy = 100.0 * total_img_correct / len(train_data)
        text_accuracy = 100.0 * total_text_correct / len(train_data)

        # Compute the average loss for this epoch
        avg_img_loss = total_img_loss / len(train_dl)
        avg_text_loss = total_text_loss / len(train_dl)
        print(f'Train Epoch: {epoch+1} Avg. Image Loss: {avg_img_loss:.4f} Avg. Text Loss: {avg_text_loss:.4f}')
        print(f'Image Accuracy: {img_accuracy}  Text Accuracy: {text_accuracy}')


        # Validation loop
        model.eval()
        total_img_correct = 0
        total_text_correct = 0
        with torch.no_grad():
            for batch_idx, batch_data in enumerate(valid_dl):
                # Get the image and text inputs and labels for this batch
                img_inputs = batch_data[0][0].to(device)
                text_inputs = batch_data[0][1].to(device)
                img_labels = batch_data[1][0].to(device)
                text_labels = batch_data[1][1].to(device)

                # Forward pass
                img_outputs, text_outputs = model(img_inputs, text_inputs)
                img_preds = torch.argmax(img_outputs, dim=1)
                text_preds = torch.argmax(text_outputs, dim=1)
                total_img_correct += (img_preds == img_labels).sum().item()
                total_text_correct += (text_preds == text_labels).sum().item()

        # Compute the accuracy for this epoch
        img_accuracy = 100.0 * total_img_correct / len(valid_data)
        text_accuracy = 100.0 * total_text_correct / len(valid_data)
        print(f'Validation data:\nImage Accuracy: {img_accuracy}  Text Accuracy: {text_accuracy}\n')
        # Step the learning rate scheduler
        scheduler.step()

train(model)

Train Epoch: 1 Avg. Image Loss: 1.1108 Avg. Text Loss: 1.6466
Image Accuracy: 60.916481892091646  Text Accuracy: 29.652623798965262
Validation data:
Image Accuracy: 68.44827586206897  Text Accuracy: 32.8448275862069

Train Epoch: 2 Avg. Image Loss: 0.8520 Avg. Text Loss: 1.6230
Image Accuracy: 69.53436807095343  Text Accuracy: 29.785661492978566
Validation data:
Image Accuracy: 69.91379310344827  Text Accuracy: 20.43103448275862

Train Epoch: 3 Avg. Image Loss: 0.6984 Avg. Text Loss: 1.6175
Image Accuracy: 75.28455284552845  Text Accuracy: 29.416112342941613
Validation data:
Image Accuracy: 67.41379310344827  Text Accuracy: 32.8448275862069

Train Epoch: 4 Avg. Image Loss: 0.5734 Avg. Text Loss: 1.6102
Image Accuracy: 80.10347376201035  Text Accuracy: 30.49519586104952
Validation data:
Image Accuracy: 68.1896551724138  Text Accuracy: 32.8448275862069

Train Epoch: 5 Avg. Image Loss: 0.4676 Avg. Text Loss: 1.6164
Image Accuracy: 82.64597191426459  Text Accuracy: 31.175166297117517
Valid

# Training the model from scratch

In [19]:
class ViLBERT(nn.Module):
    def __init__(self, num_image_classes, num_text_classes):
        super(MultiTaskModel, self).__init__()
        
        # Define the image classification model using EfficientNet-B0
        efficientnet = efficientnet_b0(pretrained=False)
        efficientnet.classifier[1] = nn.Linear(in_features = 1280, out_features = num_image_classes, bias = True)
        self.image_model = efficientnet
        
        # Define the text classification model using BERT
        self.text_model = BertModel
        self.text_classifier = nn.Linear(768, num_text_classes)
        
    def forward(self, img_input, text_input):
        # Forward pass for the image classification model
        img_output = self.image_model(img_input)
        
        # Forward pass for the text classification model
        text_output = self.text_model(text_input)[1]
        text_output = self.text_classifier(text_output)
        
        return img_output, text_output
    
    
model = MultiTaskModel(num_image_classes=8, num_text_classes=8)
model.to(device)

# Define the loss functions for each task
img_criterion = nn.CrossEntropyLoss()
text_criterion = nn.CrossEntropyLoss()

# Define the optimizer and learning rate scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [20]:
train(model)

Train Epoch: 1 Avg. Image Loss: 1.0997 Avg. Text Loss: 1.7276
Image Accuracy: 61.286031042128606  Text Accuracy: 29.534368070953438
Validation data:
Image Accuracy: 65.43103448275862  Text Accuracy: 32.8448275862069

Train Epoch: 2 Avg. Image Loss: 0.8684 Avg. Text Loss: 1.6187
Image Accuracy: 68.76570583887657  Text Accuracy: 30.554323725055433
Validation data:
Image Accuracy: 66.72413793103448  Text Accuracy: 32.8448275862069

Train Epoch: 3 Avg. Image Loss: 0.7124 Avg. Text Loss: 1.6073
Image Accuracy: 75.24020694752402  Text Accuracy: 29.59349593495935
Validation data:
Image Accuracy: 70.0  Text Accuracy: 32.8448275862069

Train Epoch: 4 Avg. Image Loss: 0.5802 Avg. Text Loss: 1.6071
Image Accuracy: 79.58610495195862  Text Accuracy: 30.199556541019955
Validation data:
Image Accuracy: 68.44827586206897  Text Accuracy: 20.43103448275862

Train Epoch: 5 Avg. Image Loss: 0.4756 Avg. Text Loss: 1.6150
Image Accuracy: 83.39985218033999  Text Accuracy: 30.066518847006652
Validation data:
