In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import os
ROOT_DIR = os.path.dirname(os.path.abspath(''))
import sys
sys.path.append(ROOT_DIR)

from src.dataset import CustomDataset
from src.fcm import FCM
from src.fcm import train_epoch, eval_epoch, acc, f1
import torch

import matplotlib.pyplot as plt
from PIL import Image
import requests
import torch
from src.dataset import CustomDataset
from torch.utils.data import DataLoader
from torchvision import transforms
import pandas as pd

In [None]:
# Paths to the MMHS150K artefacts, all stored under data/MMHS150K in the project root.
DATA_DIR = os.path.join(ROOT_DIR, "data", "MMHS150K")
CSV_FILE = os.path.join(DATA_DIR, "MMHS150K_with_img_text.csv")
IMG_DIR = os.path.join(DATA_DIR, "img_resized/")

# Seed the global torch RNG so model initialisation in later cells is reproducible.
SEED = 42
torch.manual_seed(SEED)

# Load the per-channel normalization parameters. The first row of the CSV holds the
# mean_/std_ values for the red, green and blue channels.
means_std_path = os.path.join(DATA_DIR, "means_stds.csv")
means_stds = pd.read_csv(means_std_path)
norm_params = means_stds.iloc[0]
channels = ("red", "green", "blue")

# Minimal transformation for the images: resize to 299x299, convert to a tensor,
# then normalize each channel with the statistics loaded above.
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[float(norm_params[f"mean_{c}"]) for c in channels],
        std=[float(norm_params[f"std_{c}"]) for c in channels],
    ),
])

# Both splits come from the same CSV and image folder; only `split` differs.
dataset_kwargs = dict(csv_file=CSV_FILE, img_dir=IMG_DIR, transform=transform)
train_dataset = CustomDataset(split="train", **dataset_kwargs)
eval_dataset = CustomDataset(split="val", **dataset_kwargs)
# Fail early instead of silently iterating over an empty split.
assert len(train_dataset) > 0 and len(eval_dataset) > 0, "empty train/val split"

batch_size = 8
# A dedicated seeded generator keeps the shuffle order reproducible and independent
# of how many random numbers other code (e.g. model init) draws from the global RNG.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [None]:
def update_metrics_log(metrics_names, metrics_log, new_metrics_dict):
    """Append the latest value of every tracked metric to its history list.

    `metrics_log` is mutated in place. Position k of `metrics_log` is the history of
    the metric named `metrics_names[k]`, and it receives `new_metrics_dict[metrics_names[k]]`.
    The same list object is returned so callers can reassign it.
    """
    for position, name in enumerate(metrics_names):
        history = metrics_log[position]
        history.append(new_metrics_dict[name])
    return metrics_log


def train_cycle(model, optimizer, criterion, metrics, train_loader, test_loader, tokenizer, n_epochs, device,
                return_losses=False):
    """Train `model` for `n_epochs`, evaluating on `test_loader` after every epoch.

    Args:
        model, optimizer, criterion: forwarded unchanged to `train_epoch` / `eval_epoch`.
        metrics: dict mapping a metric name to a metric function. Its key order fixes the
            order of the per-metric lists in the returned logs.
        train_loader, test_loader: data loaders used for training and evaluation.
        tokenizer: text tokenizer forwarded to `train_epoch` / `eval_epoch`.
        n_epochs: number of epochs to run.
        device: torch device forwarded to `train_epoch` / `eval_epoch`.
        return_losses: if True, also return the per-epoch train and test losses, which
            are otherwise only printed.

    Returns:
        (train_metrics_log, test_metrics_log). Each log holds one list per metric, in
        `metrics` key order, with one value per epoch. When `return_losses` is True, the
        result is (train_metrics_log, test_metrics_log, train_loss_log, test_loss_log).
    """
    train_loss_log, test_loss_log = [], []
    metrics_names = list(metrics.keys())
    train_metrics_log = [[] for _ in metrics_names]
    test_metrics_log = [[] for _ in metrics_names]

    for epoch in range(n_epochs):
        # 1-based, so the last epoch prints as "Epoch N of N".
        print("Epoch {0} of {1}".format(epoch + 1, n_epochs))
        train_loss, train_metrics = train_epoch(model, optimizer, criterion, metrics, train_loader, tokenizer, device)
        test_loss, test_metrics = eval_epoch(model, criterion, metrics, test_loader, tokenizer, device)

        train_loss_log.append(train_loss)
        train_metrics_log = update_metrics_log(metrics_names, train_metrics_log, train_metrics)

        test_loss_log.append(test_loss)
        test_metrics_log = update_metrics_log(metrics_names, test_metrics_log, test_metrics)

        # Surface the losses; previously they were collected and then dropped unseen.
        print("  train loss: {0} | eval loss: {1}".format(train_loss, test_loss))

    if return_losses:
        return train_metrics_log, test_metrics_log, train_loss_log, test_loss_log
    return train_metrics_log, test_metrics_log


In [None]:
# Choose the tokenization function
from transformers import BertTokenizer

TOKENIZER_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(TOKENIZER_NAME)
# TODO: evaluate a tokenizer pretrained on tweets instead of generic cased BERT.

# Create the model. The vocabulary size comes from the tokenizer; presumably it sizes
# FCM's text embedding — confirm in src/fcm.py.
vocab_size = len(tokenizer)
# NOTE(review): FCM receives `batch_size`. If it relies on a fixed batch size, the last
# (possibly smaller) batch of each DataLoader (drop_last defaults to False) may break it — confirm.
fcm = FCM(device, vocab_size, batch_size, output_size=1, freeze_image_model=True, freeze_text_model=False).to(device)

# Choose the optimizer and the loss function.
# BCEWithLogitsLoss applies the sigmoid itself, so the model should output raw logits.
# NOTE(review): confirm FCM does not already apply a sigmoid.
learning_rate = 1e-3
optimizer = torch.optim.Adam(fcm.parameters(), lr=learning_rate)
criterion = torch.nn.BCEWithLogitsLoss()

# Choose the metrics; their key order is the order of the per-metric lists in the logs.
metrics = {'ACC': acc, 'F1-weighted': f1}

n_epochs = 5

train_metrics_log, test_metrics_log = train_cycle(fcm, optimizer, criterion, metrics, train_loader, eval_loader, tokenizer, n_epochs, device)

# Save the model weights. exist_ok avoids the check-then-create race and keeps the cell
# re-runnable. Note that each run overwrites the previous fcm.pth.
model_weights_dir = os.path.join(ROOT_DIR, "results", "model_weights")
os.makedirs(model_weights_dir, exist_ok=True)
torch.save(fcm.state_dict(), os.path.join(model_weights_dir, "fcm.pth"))