In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import os
ROOT_DIR = os.path.dirname(os.path.abspath(''))
import sys
sys.path.append(ROOT_DIR)

from src.dataset import CustomDataset
from src.fcm import FCM
import torch

import matplotlib.pyplot as plt
from PIL import Image
import requests
import torch
from src.dataset import CustomDataset
from torch.utils.data import DataLoader
from torchvision import transforms
import pandas as pd

In [None]:
# Locations of the MMHS150K files used by this notebook.
mmhs_dir = os.path.join(ROOT_DIR, "data", "MMHS150K")
annotations_csv = os.path.join(mmhs_dir, "MMHS150K_with_img_text.csv")
images_dir = os.path.join(mmhs_dir, "img_resized/")

# Per-channel normalisation statistics, read from the first row of means_stds.csv.
means_std_path = os.path.join(mmhs_dir, "means_stds.csv")
means_stds = pd.read_csv(means_std_path)
channel_means = [
    means_stds.loc[0, "mean_red"],
    means_stds.loc[0, "mean_green"],
    means_stds.loc[0, "mean_blue"],
]
channel_stds = [
    means_stds.loc[0, "std_red"],
    means_stds.loc[0, "std_green"],
    means_stds.loc[0, "std_blue"],
]

# Minimal image pipeline: fixed 299x299 size, tensor conversion, channel normalisation.
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_means, std=channel_stds),
])


def _build_split(split_name):
    """Create the MMHS150K CustomDataset for one split, sharing `transform`."""
    return CustomDataset(
        csv_file=annotations_csv,
        img_dir=images_dir,
        split=split_name,
        transform=transform,
    )


train_dataset = _build_split("train")
eval_dataset = _build_split("val")

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

In [None]:
# Train the model
def train_epoch(model, optimizer, criterion, metrics, train_loader, tokenizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module called as ``model(image, tweet_tokens, img_tokens)``. Its
            output is flattened to one raw logit per sample, which assumes a
            single output unit (``output_size=1``). TODO: confirm for FCM.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking raw logits and float 0/1 targets of the same
            shape. It must apply the sigmoid itself (e.g. ``BCEWithLogitsLoss``).
        metrics: mapping ``name -> fn(preds, target)``. Each fn is called ONCE
            per epoch with 1-D numpy arrays of hard 0/1 predictions and labels.
        train_loader: iterable of dicts with keys ``"image"``, ``"label"``,
            ``"tweet_text"`` and ``"img_text"``.
        tokenizer: callable with the Hugging Face tokenizer call signature.
            Its result must support ``.to(device)``.
        device: device that images, labels and token tensors are moved to.

    Returns:
        ``(mean batch loss, {metric name: float})``. Metrics are computed over
        all predictions of the epoch, not averaged per batch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    all_preds, all_labels = [], []

    for data_dict in train_loader:
        image = data_dict["image"].to(device)
        label = data_dict["label"].to(device)

        # Tokenize both text fields and put the token tensors on the model's device.
        tweet_text = tokenizer(data_dict["tweet_text"], padding=True, truncation=True, return_tensors="pt").to(device)
        img_text = tokenizer(data_dict["img_text"], padding=True, truncation=True, return_tensors="pt").to(device)

        optimizer.zero_grad()

        # Flatten to one logit per sample. `.squeeze(0)` used to collapse a final
        # batch of size 1 to shape (1,), which no longer matched the target shape.
        logits = model(image, tweet_text, img_text).reshape(-1)

        # The criterion applies the sigmoid internally, so pass raw logits.
        # Applying Sigmoid beforehand squashed the outputs twice.
        loss = criterion(logits, label.float().reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # With a single logit, argmax over dim 1 is always 0, so threshold instead:
        # logit > 0  <=>  sigmoid(logit) > 0.5.
        all_preds.append((logits.detach() > 0).long().cpu())
        all_labels.append(label.detach().reshape(-1).cpu())

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    # Compute metrics once on CPU arrays. Averaging per-batch F1 is not the epoch F1,
    # and sklearn cannot read CUDA tensors.
    preds = torch.cat(all_preds).numpy()
    labels = torch.cat(all_labels).numpy()
    epoch_metrics = {name: float(metric(preds, labels)) for name, metric in metrics.items()}

    return total_loss / num_batches, epoch_metrics

# Evaluate the model
def eval_epoch(model, criterion, metrics, val_loader, tokenizer, device):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: module called as ``model(image, tweet_tokens, img_tokens)``. Its
            output is flattened to one raw logit per sample, which assumes a
            single output unit (``output_size=1``). TODO: confirm for FCM.
        criterion: loss taking raw logits and float 0/1 targets of the same
            shape (e.g. ``BCEWithLogitsLoss``).
        metrics: mapping ``name -> fn(preds, target)``. Each fn is called ONCE
            with 1-D numpy arrays of hard 0/1 predictions and labels.
        val_loader: iterable of dicts with keys ``"image"``, ``"label"``,
            ``"tweet_text"`` and ``"img_text"``.
        tokenizer: callable with the Hugging Face tokenizer call signature.
            Its result must support ``.to(device)``.
        device: device that images, labels and token tensors are moved to.

    Returns:
        ``(mean batch loss, {metric name: float})``. Metrics are computed over
        all predictions, not averaged per batch.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for data_dict in val_loader:
            image = data_dict["image"].to(device)
            label = data_dict["label"].to(device)

            # Tokenize both text fields and put the token tensors on the model's device.
            tweet_text = tokenizer(data_dict["tweet_text"], padding=True, truncation=True, return_tensors="pt").to(device)
            img_text = tokenizer(data_dict["img_text"], padding=True, truncation=True, return_tensors="pt").to(device)

            # One raw logit per sample (see train_epoch for the shape rationale).
            logits = model(image, tweet_text, img_text).reshape(-1)
            loss = criterion(logits, label.float().reshape(-1))

            total_loss += loss.item()
            num_batches += 1

            # With a single logit, argmax over dim 1 is always 0, so threshold instead:
            # logit > 0  <=>  sigmoid(logit) > 0.5.
            all_preds.append((logits > 0).long().cpu())
            all_labels.append(label.reshape(-1).cpu())

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")

    # Compute metrics once on CPU arrays rather than averaging per-batch scores.
    preds = torch.cat(all_preds).numpy()
    labels = torch.cat(all_labels).numpy()
    epoch_metrics = {name: float(metric(preds, labels)) for name, metric in metrics.items()}

    return total_loss / num_batches, epoch_metrics

In [None]:
def plot_training(train_losses, val_losses, train_metrics, val_metrics):
    """Plot loss and metric curves for the training and validation runs.

    Args:
        train_losses: sequence of per-epoch training losses.
        val_losses: sequence of per-epoch validation losses.
        train_metrics: mapping ``metric name -> sequence of per-epoch values``.
        val_metrics: mapping with the same keys as ``train_metrics``.

    Draws one panel for the loss plus one per metric, so any number of metrics
    fits. Train and validation curves are paired by metric name, not by the
    iteration order of two separate dicts.
    """
    n_panels = 1 + len(train_metrics)
    fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), squeeze=False)
    axs = axs[0]

    axs[0].plot(train_losses, label="train")
    axs[0].plot(val_losses, label="val")
    axs[0].set_title("Loss")
    axs[0].set_xlabel("epoch")
    axs[0].legend()

    for ax, name in zip(axs[1:], train_metrics):
        ax.plot(train_metrics[name], label="train")
        ax.plot(val_metrics[name], label="val")
        ax.set_title(name)
        ax.set_xlabel("epoch")
        ax.legend()

    fig.tight_layout()
    plt.show()
    # This is called once per epoch; close the figure so open figures do not pile up.
    plt.close(fig)
    
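# NOTE: properly, training decisions should be guided by a validation split and the test split reserved for a final check. This toy example simply tracks performance after every epoch on the loader passed as `test_loader` (in this notebook that is `eval_loader`, i.e. the "val" split).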
def update_metrics_log(metrics_names, metrics_log, new_metrics_dict):
    """Append this epoch's value of each metric to its history list.

    ``metrics_log[k]`` is the history of metric ``metrics_names[k]``. The lists
    are extended in place, and ``metrics_log`` itself is returned for convenience.
    """
    for position, metric_name in enumerate(metrics_names):
        history = metrics_log[position]
        history.append(new_metrics_dict[metric_name])
    return metrics_log


def train_cycle(model, optimizer, criterion, metrics, train_loader, test_loader, tokenizer, n_epochs, device):
    """Train for ``n_epochs`` epochs, evaluating on ``test_loader`` after each one.

    After every epoch the loss and metric curves so far are redrawn with
    ``plot_training``.

    Returns:
        ``(train_metrics_log, test_metrics_log)``: lists holding one inner list
        per metric (in ``metrics`` key order), each with one value per epoch.
    """
    metrics_names = list(metrics.keys())
    train_loss_log, test_loss_log = [], []
    train_metrics_log = [[] for _ in metrics_names]
    test_metrics_log = [[] for _ in metrics_names]

    for epoch in range(n_epochs):
        # Report epochs 1-based so the last one reads "Epoch n of n".
        print("Epoch {0} of {1}".format(epoch + 1, n_epochs))
        train_loss, train_metrics = train_epoch(model, optimizer, criterion, metrics, train_loader, tokenizer, device)
        test_loss, test_metrics = eval_epoch(model, criterion, metrics, test_loader, tokenizer, device)

        train_loss_log.append(train_loss)
        train_metrics_log = update_metrics_log(metrics_names, train_metrics_log, train_metrics)

        test_loss_log.append(test_loss)
        test_metrics_log = update_metrics_log(metrics_names, test_metrics_log, test_metrics)

        # plot_training takes (train_losses, val_losses, train_metrics, val_metrics),
        # with the metrics as {name: history} mappings. The previous call passed five
        # positional arguments and raised TypeError.
        plot_training(
            train_loss_log,
            test_loss_log,
            dict(zip(metrics_names, train_metrics_log)),
            dict(zip(metrics_names, test_metrics_log)),
        )
    return train_metrics_log, test_metrics_log


In [None]:
from sklearn.metrics import f1_score, accuracy_score
def f1(preds, target):
    """Macro-averaged F1 of ``preds`` against the ground-truth ``target``."""
    return f1_score(y_true=target, y_pred=preds, average='macro')


def acc(preds, target):
    """Plain accuracy: the fraction of ``preds`` that equal ``target``."""
    return accuracy_score(y_true=target, y_pred=preds)

In [None]:
# Choose the tokenization function
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
# TODO: consider a tokenizer (and text encoder) pre-trained on tweets.

# Seed torch's global RNG before building the model. DataLoader shuffling draws
# from it, and FCM's weight initialisation presumably does too.
SEED = 42
torch.manual_seed(SEED)

# Create the model
vocab_size = len(tokenizer)
# NOTE(review): FCM receives batch_size, but the final batch of an epoch can be
# smaller. Confirm FCM does not assume a fixed batch size.
fcm = FCM(device, vocab_size, batch_size, output_size=1, freeze_image_model=True, freeze_text_model=False).to(device)

# Choose the optimizer and the loss function.
# NOTE(review): lr=1e-3 is high for fine-tuning a pre-trained text encoder
# (freeze_text_model=False). Confirm this is intended.
optimizer = torch.optim.Adam(fcm.parameters(), lr=0.001)
# BCEWithLogitsLoss applies the sigmoid itself, so it must receive raw logits.
criterion = torch.nn.BCEWithLogitsLoss()

# Choose the metrics. `f1` uses average='macro', so it is labelled accordingly.
metrics = {'ACC': acc, 'F1-macro': f1}

n_epochs = 5

train_metrics_log, test_metrics_log = train_cycle(fcm, optimizer, criterion, metrics, train_loader, eval_loader, tokenizer, n_epochs, device)

# Save the model weights
model_weights_dir = os.path.join(ROOT_DIR, "results", "model_weights")
os.makedirs(model_weights_dir, exist_ok=True)
torch.save(fcm.state_dict(), os.path.join(model_weights_dir, "fcm.pth"))