In [None]:
import pandas as pd
import os
import numpy as np
import warnings
from tqdm import tqdm
import transformers
from transformers import BertTokenizer

# Suppress every Python warning for the rest of the session (keeps notebook output clean).
warnings.filterwarnings(action="ignore")
# Expose only GPU index 1 to this process; this only takes effect if CUDA has
# not been initialised yet in the kernel.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
from torch.nn.utils import weight_norm

import torchvision
from torchvision import transforms, models, datasets
import PIL
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
from transformers import BertModel, AdamW, get_linear_schedule_with_warmup

# Let Pillow decode truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Raise Pillow's decompression-bomb pixel limit so very large images can be opened.
Image.MAX_IMAGE_PIXELS = 180_960_000

# Prefer the (single visible) GPU when available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# --- Dataset ------------------------------------------------------------------
DATASET_NUM_CLASSES = 2
DATASET_SIZE = "small"
DATASET_BASE_DIR = f"../data/processed_data/{DATASET_NUM_CLASSES}_{DATASET_SIZE}"
IMAGE_BASE_DIR = f"{DATASET_BASE_DIR}/images/"
DATA_TYPE = "text_comments"  # CSV column used as the model's text input

# --- BERT ---------------------------------------------------------------------
PRE_TRAINED_MODEL_NAME = "bert-base-cased"
MAX_LEN = 512  # every text is padded / truncated to this many tokens

# --- Training -----------------------------------------------------------------
BATCH_SIZE = 8
NUM_EPOCHES = 9
# How text and image features are combined; see MultiModalForClassification.forward
# for the values it actually supports.
FUSION_TYPE = "add"

# --- Output paths -------------------------------------------------------------
SAVE_PATH = f"./trained_models/{DATASET_NUM_CLASSES}_{DATASET_SIZE}"
MODEL_PATH = f"{SAVE_PATH}/multimodal_models_{DATA_TYPE}_{FUSION_TYPE}.pth"
CONFUSION_MATRIX_PATH = f"{SAVE_PATH}/confusion_matrix_{DATA_TYPE}_{FUSION_TYPE}.png"

# Seeding is currently disabled; to make runs repeatable, uncomment:
# seed = 20
# torch.manual_seed(seed)

In [None]:
# Create the checkpoint/figure directory if needed. exist_ok=True replaces the
# racy exists()-then-makedirs() check, and still raises if SAVE_PATH exists
# but is not a directory (torch.save into it would fail later anyway).
os.makedirs(SAVE_PATH, exist_ok=True)

In [None]:
# Load the pre-split train/test CSVs from the processed dataset directory.
train_data_raw, test_data_raw = (
    pd.read_csv(f"{DATASET_BASE_DIR}/{split}_data.csv") for split in ("train", "test")
)

In [None]:
# Keep only the sample id, the configured text column and the n-way label column.
# .copy() makes these independent frames rather than slices of *_raw, so any later
# in-place edit cannot trigger pandas' SettingWithCopyWarning (which the global
# warnings filter would silently hide).
train_data = train_data_raw[["id", DATA_TYPE, str(DATASET_NUM_CLASSES) + "_way_label"]].copy()
test_data = test_data_raw[["id", DATA_TYPE, str(DATASET_NUM_CLASSES) + "_way_label"]].copy()

In [None]:
# Rename to the generic "text"/"label" column names used when building the datasets.
# Reassign instead of inplace=True: in-place mutation of a column-subset frame can
# raise SettingWithCopyWarning (hidden by the global warnings filter) and blocks chaining.
train_data = train_data.rename(columns = {DATA_TYPE : "text", str(DATASET_NUM_CLASSES) + "_way_label" : "label"})
test_data = test_data.rename(columns = {DATA_TYPE : "text", str(DATASET_NUM_CLASSES) + "_way_label" : "label"})

In [None]:
# Tokenizer for the same checkpoint the model loads (PRE_TRAINED_MODEL_NAME).
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path = PRE_TRAINED_MODEL_NAME)

In [None]:
class FakedditMultiModalDataset(Dataset):
    """Map-style dataset yielding tokenized text, a transformed RGB image and a label.

    Item ``i`` pairs ``texts[i]`` / ``labels[i]`` with the image file
    ``<image_base_dir><data_type>/<ids[i]>.jpg``.

    Args:
        texts: Indexable sequence of raw text strings.
        labels: Indexable sequence of integer class ids.
        ids: Indexable sequence of sample ids used to build image file names.
        tokenizer: Object exposing ``encode_plus`` (e.g. a BERT tokenizer).
        data_type: Image sub-directory name, e.g. ``"train"`` or ``"test"``.
        max_len: Length every token sequence is padded / truncated to.
        transform: Callable applied to the RGB ``PIL.Image`` (expected to return a tensor).
        image_base_dir: Root image directory, including the trailing separator.
            Defaults to the notebook-level ``IMAGE_BASE_DIR``.
    """

    def __init__(self, texts, labels, ids, tokenizer, data_type, max_len, transform,
                 image_base_dir = None):
        super(FakedditMultiModalDataset, self).__init__()
        self.texts = texts
        self.labels = labels
        self.ids = ids
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform
        base_dir = IMAGE_BASE_DIR if image_base_dir is None else image_base_dir
        # str() guards against ids that pandas parsed as numbers; "+" would raise TypeError.
        self.images = [base_dir + data_type + "/" + str(each_id) + ".jpg" for each_id in self.ids]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer.encode_plus(
            self.texts[idx],
            max_length = self.max_len,
            padding = "max_length",
            truncation = True,
            return_tensors = "pt",
        )
        # The context manager closes the file deterministically, even if decoding
        # fails; convert() returns a new, fully loaded image that outlives the handle.
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transform(image)
        label = torch.tensor(self.labels[idx], dtype = torch.long)

        return {
            # return_tensors="pt" adds a leading batch dimension; flatten it away so
            # the DataLoader stacks items into (batch, max_len).
            "attention_mask": encoding["attention_mask"].flatten(),
            "input_ids": encoding["input_ids"].flatten(),
            "image": image,
            "label": label
        }

In [None]:
def create_dataloader(df, tokenizer, data_type, max_len, transform, batch_size, shuffle = None):
    """Wrap one dataframe split in a FakedditMultiModalDataset and a DataLoader.

    Args:
        df: DataFrame with ``text``, ``label`` and ``id`` columns.
        tokenizer: Tokenizer passed through to the dataset.
        data_type: Split name; also the image sub-directory (``"train"``/``"test"``).
        max_len: Token sequence length.
        transform: Image transform passed through to the dataset.
        batch_size: Examples per batch.
        shuffle: Whether to reshuffle every epoch. ``None`` (default) shuffles only
            the ``"train"`` split, so evaluation batches come in a fixed order.

    Returns:
        A ``DataLoader`` yielding dict batches.
    """
    if shuffle is None:
        shuffle = data_type == "train"

    ds = FakedditMultiModalDataset(
        texts = df.text.to_numpy(),
        labels = df.label.to_numpy(),
        ids = df.id.to_numpy(),
        tokenizer = tokenizer,
        data_type = data_type,
        max_len = max_len,
        transform = transform
    )

    return DataLoader(ds, batch_size = batch_size, shuffle = shuffle)

In [None]:
# Both splits use the same deterministic pipeline: resize to 224x224, then
# convert to a float CHW tensor with values in [0, 1].
# NOTE(review): no ImageNet mean/std normalisation is applied even though the
# image backbone is loaded with pretrained ImageNet weights — confirm this is intended.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ])
    for split in ("train", "test")
}

In [None]:
# One DataLoader per split; each split reads its images from IMAGE_BASE_DIR/<split>/.
train_dataloader, test_dataloader = (
    create_dataloader(frame, tokenizer, split, MAX_LEN, data_transforms[split], BATCH_SIZE)
    for frame, split in ((train_data, "train"), (test_data, "test"))
)

In [None]:
# Pull a single training batch to sanity-check tensor shapes and the image pipeline.
sample_data = next(iter(train_dataloader))

In [None]:
# Expect (batch, MAX_LEN): every text is padded / truncated to MAX_LEN tokens.
sample_data["input_ids"].size()

In [None]:
# Show the first image of the batch; ToTensor yields CHW, imshow expects HWC.
pic = sample_data["image"][0].squeeze().permute(1, 2, 0).numpy()
plt.imshow(pic)

In [None]:
class MultiModalForClassification(nn.Module):
    """Text + image classifier fusing BERT and ResNet-50 features.

    The BERT pooled output (index 1 of the BertModel output) and the ResNet-50
    output, whose final ``fc`` layer is replaced to project into BERT's hidden
    size, are combined according to ``fusion_type``, then passed through dropout
    and a linear classification head.

    Args:
        num_classes: Number of output logits.
        fusion_type: ``"cat"`` (concatenate), ``"add"`` (element-wise sum),
            ``"max"`` (element-wise maximum) or ``"avg"`` (element-wise mean).

    Raises:
        ValueError: If ``fusion_type`` is not one of the supported values.
    """

    SUPPORTED_FUSION_TYPES = ("cat", "add", "max", "avg")

    def __init__(self, num_classes, fusion_type):
        super(MultiModalForClassification, self).__init__()
        # Fail fast, before loading any pretrained weights; previously an unknown
        # value only surfaced as an UnboundLocalError on the first forward pass.
        if fusion_type not in self.SUPPORTED_FUSION_TYPES:
            raise ValueError(
                f"Unsupported fusion_type {fusion_type!r}; "
                f"expected one of {self.SUPPORTED_FUSION_TYPES}"
            )
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(p = 0.3)
        self.image_model = models.resnet50(pretrained = True)
        self.fusion_type = fusion_type
        hidden_size = self.bert.config.hidden_size
        # Project ResNet features to BERT's hidden size so both modalities can be
        # combined element-wise.
        in_features = self.image_model.fc.in_features
        self.image_model.fc = nn.Linear(in_features, hidden_size)

        fused_size = hidden_size * 2 if self.fusion_type == "cat" else hidden_size
        self.fc = nn.Linear(fused_size, num_classes)

    def forward(self, input_ids = None, attention_mask = None, image = None):
        """Return class logits of shape (batch, num_classes)."""
        bert_outputs = self.bert(input_ids, attention_mask = attention_mask)
        image_outputs = self.image_model(image)
        fused_outputs = self._fuse(bert_outputs[1], image_outputs)
        pooled_outputs = self.drop(fused_outputs)
        return self.fc(pooled_outputs)

    def _fuse(self, text_features, image_features):
        """Combine the two feature tensors according to ``self.fusion_type``."""
        if self.fusion_type == "cat":
            return torch.cat([text_features, image_features], dim = 1)
        if self.fusion_type == "add":
            return text_features + image_features
        if self.fusion_type == "max":
            # Two-tensor torch.max is the element-wise maximum.
            return torch.max(text_features, image_features)
        if self.fusion_type == "avg":
            return (text_features + image_features) / 2
        raise ValueError(
            f"Unsupported fusion_type {self.fusion_type!r}; "
            f"expected one of {self.SUPPORTED_FUSION_TYPES}"
        )

In [None]:
def train_epoch(model, data_loader, loss_fn, optimizer, scheduler, device, n_examples):
    """Run one optimisation pass over ``data_loader``.

    For each batch, move tensors to ``device``, compute logits and loss, then
    backprop, clip gradients, step the optimizer and the scheduler (once per
    batch), and reset gradients.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=..., image=...)``.
        data_loader: Iterable of dict batches with ``input_ids``, ``attention_mask``,
            ``image`` and ``label`` tensors.
        loss_fn: Criterion called as ``loss_fn(logits, labels)``.
        optimizer: Optimizer over the model's parameters.
        scheduler: Learning-rate scheduler, stepped once per batch.
        device: Device the batch tensors are moved to.
        n_examples: Number of training examples; the accuracy denominator.

    Returns:
        ``(accuracy, mean_loss)``: accuracy as a float64 tensor
        (correct predictions / ``n_examples``) and the mean of the per-batch loss values.
    """
    model.train()
    n_correct = 0
    running_loss = 0
    batches = tqdm(enumerate(data_loader), total = len(data_loader))
    for _, batch in batches:
        labels = batch["label"].to(device)
        logits = model(
            input_ids = batch["input_ids"].to(device),
            attention_mask = batch["attention_mask"].to(device),
            image = batch["image"].to(device),
        )
        predictions = torch.max(logits, dim = 1)[1]
        loss = loss_fn(logits, labels)
        n_correct += (predictions == labels).sum()

        loss.backward()
        batch_loss = loss.item()
        running_loss += batch_loss
        # Clip to a global gradient L2 norm of 1.0 before stepping.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm = 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        batches.set_description(f"loss:{batch_loss:.4f}")

    return n_correct.double() / n_examples, running_loss / len(data_loader)

In [None]:
from sklearn.metrics import classification_report, f1_score
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

def eval_model(model, data_loader, device, phase = "train"):
    """Predict on every batch of ``data_loader`` and report metrics.

    Prints sklearn's ``classification_report`` (5 digits). When ``phase`` is
    ``"eval"`` it also draws a confusion matrix, saves it to
    ``CONFUSION_MATRIX_PATH`` and shows it.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=..., image=...)``
            returning per-class scores.
        data_loader: Iterable of dict batches with ``input_ids``, ``attention_mask``,
            ``image`` and ``label`` tensors.
        device: Device the model inputs are moved to.
        phase: ``"eval"`` enables the confusion-matrix figure; any other value skips it.

    Returns:
        Weighted-average F1 score over all examples.
    """
    model.eval()
    y_true = []
    y_pred = []
    batches = tqdm(enumerate(data_loader), total=len(data_loader))
    with torch.no_grad():
        for _, batch in batches:
            scores = model(
                input_ids = batch["input_ids"].to(device),
                attention_mask = batch["attention_mask"].to(device),
                image = batch["image"].to(device),
            )
            y_pred.extend(torch.max(scores, dim = 1)[1].cpu().tolist())
            y_true.extend(batch["label"].tolist())

    weighted_f1 = f1_score(np.ravel(y_true), np.ravel(y_pred), average='weighted')
    print(classification_report(y_true, y_pred, digits = 5))
    if phase == "eval":
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap = "GnBu")
        plt.savefig(CONFUSION_MATRIX_PATH)
        plt.show()

    return weighted_f1

In [None]:
# Model, optimiser, learning-rate schedule and loss for the configured run.
model = MultiModalForClassification(DATASET_NUM_CLASSES, FUSION_TYPE).to(device)
# NOTE(review): transformers' AdamW is deprecated in recent transformers releases;
# torch.optim.AdamW has no `correct_bias` switch and a different default
# weight_decay, so replacing it is not a drop-in change.
optimizer = AdamW(model.parameters(), lr = 2e-5, correct_bias = False)
# No warm-up; linear decay over every batch of every epoch.
total_steps = NUM_EPOCHES * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps = 0,
    num_training_steps = total_steps,
)
loss_fn = nn.CrossEntropyLoss().to(device)

In [None]:
model.zero_grad()
# Best test F1 seen so far. Starting at -inf (not 0) guarantees the first epoch
# is always checkpointed, so MODEL_PATH exists even if weighted F1 stays at 0.0.
# NOTE(review): the checkpoint is selected on the *test* split, which makes the
# reported test score optimistic — consider a separate validation split.
prev_best = float("-inf")

for epoch in range(NUM_EPOCHES):
    print(f'Epoch {epoch+1}/{NUM_EPOCHES}')
    print('-' * 10)

    train_acc, train_loss = train_epoch(model, train_dataloader, loss_fn, optimizer, scheduler, device, len(train_data))
    print('\n')
    print(f"Train loss: {train_loss}  Accuracy: {train_acc}")
    test_metrics = eval_model(model, test_dataloader, device)
    print(f"Test F1: {test_metrics}")
    if test_metrics > prev_best:
        prev_best = test_metrics
        # Pickles the whole module (not just its state_dict); loading it back
        # requires the class definition to be importable.
        torch.save(model, MODEL_PATH)

In [None]:
# eval_model(model, test_dataloader, device, phase = "eval")

In [None]:
# e_model = torch.load(MODEL_PATH, map_location=torch.device(device))

In [None]:
# eval_model(e_model, test_dataloader, device)