In [2]:
# packages
import os
import sys
import pickle
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch.nn as nn
import torch.optim as optim
from IPython.display import clear_output
import wandb

# python files
sys.path.insert(0, "..")

from trainer import Trainer
from transforms import TextTransform, ImageTransform
from datasets import TestingDataset
from module import *
from vocabulary import Vocabulary
from metrics import accuracy

In [3]:
# Every location is resolved against the directory the kernel was started in.
CUR_DIR = os.path.abspath(os.curdir)
# Image folder first, cache folder second; the relative parts are kept verbatim
# (including the trailing slash) so downstream string joins keep working.
IMAGES_DIR, CACHE_PATH = (
    os.path.join(CUR_DIR, relative_part)
    for relative_part in (
        "../data/reintel2020/public_train_final_images/",
        "../.cache/",
    )
)

In [4]:
# Load the pre-built dataset from the on-disk cache.
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only point this at cache files produced by this project.
# The context manager closes the file handle as soon as loading finishes
# (the bare open() used previously left it open for the kernel's lifetime).
with open(os.path.join(CACHE_PATH, "pretrain_dataset.pkl"), "rb") as cache:
    loaded_dataset = pickle.load(cache)

In [5]:
# Wrap the cached records in the project's dataset class, then serve them in
# fixed-size batches. No shuffle flag is passed, so the DataLoader default
# (sequential order) applies.
dataset = TestingDataset(loaded_dataset)
dataloader = DataLoader(dataset=dataset, batch_size=128)

In [6]:
def tokenizer(x):
    """Split every string in ``x`` on whitespace.

    Args:
        x: An iterable of strings (one entry per sample).

    Returns:
        A list with one inner list of tokens per input string, in input order.
        An empty or whitespace-only string yields an empty token list.
    """
    # str.split() with no argument splits on any run of whitespace and drops
    # leading/trailing whitespace.
    return [text.split() for text in x]

In [7]:
class PrequentialMetrics:
    """Running arithmetic mean over a stream of metric values.

    Attributes:
        values: Mean of everything passed to ``update`` so far (or the
            constructor argument while no update has happened yet).
        t: Number of ``update`` calls made so far.
    """

    def __init__(self, values=0):
        # The counter starts at zero, so the initial ``values`` is weighted by
        # zero in the first update and does not contribute to the mean.
        self.t = 0
        self.values = values

    def update(self, value):
        """Fold one more observation into the running mean."""
        seen_so_far = self.t
        weighted_total = self.values * seen_so_far + value
        self.values = weighted_total / (seen_so_far + 1)
        self.t = seen_so_far + 1

    def __repr__(self) -> str:
        # Display exactly like the underlying mean value.
        return repr(self.values)

In [8]:
class Ensemble(nn.Module):
    """Combine a text model and an image model through a learnable weighted average.

    The two sub-model outputs are blended with one learnable scalar weight per
    branch, divided by the sum of the weights, and passed through a softmax
    over dimension 1.

    Args:
        text_model: Module called on the ``text`` input.
        image_model: Module called on the ``image`` input.
        num_classes: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, text_model, image_model, num_classes):
        super().__init__()
        self.text_model, self.image_model = text_model, image_model
        # One trainable scalar per branch, both starting at 0.5.
        initial_weights = torch.ones(2) / 2
        self.weights = nn.Parameter(initial_weights, requires_grad=True)

    def forward(self, text, image):
        """Return the softmax of the weight-normalised blend of both branches."""
        text_weight, image_weight = self.weights[0], self.weights[1]

        text_scores = self.text_model(text)
        image_scores = self.image_model(image)

        blended = text_scores * text_weight + image_scores * image_weight
        # Dividing by the weight sum keeps the blend a weighted average even
        # after training moves the weights away from 0.5 / 0.5.
        normalised = blended / self.weights.sum()

        return F.softmax(normalised, dim=1)

In [12]:
# model
# NOTE(review): `vc` (the vocabulary object) and `criterion` (the loss function)
# are used below but are not created in any visible cell -- the execution
# counter jumps from In [8] to In [12]. Fail fast with one clear message
# instead of part-way through the first batch.
_missing_names = [name for name in ("vc", "criterion") if name not in globals()]
if _missing_names:
    raise NameError(
        "define " + ", ".join(_missing_names) + " before running the training cell"
    )

MAX_TEXT_LENGTH = 20  # forwarded to TextTransform(max_length=...)
IMAGE_SHAPE = (3, 256, 256)  # per-sample shape of the placeholder image tensor

text_model = LSTMClassifier(
    vocab_size=100000, embedding_dim=100, hidden_dim=128, num_classes=2
)
image_model = Resnet18(num_classes=2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Ensemble(text_model=text_model, image_model=image_model, num_classes=2)
model = model.to(device)

# The optimizer must be built from the parameters of the model created in this
# cell; an optimizer left over from an earlier cell would keep stepping a
# different model instance and this one would never be updated.
# NOTE(review): Adam with default hyper-parameters is a placeholder -- the
# original optimizer definition is not present in the notebook; confirm.
optimizer = optim.Adam(model.parameters())

total_acc = PrequentialMetrics()
total_loss = PrequentialMetrics()

# token -> index -> padding; the configuration is the same for every batch, so
# the transform is constructed once instead of once per iteration.
text_transform = TextTransform(max_length=MAX_TEXT_LENGTH)

# train
model.train()

for batch in dataloader:
    token_lists = tokenizer(batch["post_message"])
    labels = batch["label"].to(device)

    # append new tokens
    for tokens in token_lists:
        for token in tokens:
            try:
                vc.append_token(token)
            except Exception:
                # Deliberately best-effort: a token that cannot be appended is
                # skipped. NOTE(review): append_token's failure modes are not
                # visible here; narrow this except clause once they are known.
                continue

    texts = torch.stack(
        [text_transform(vc.get_idxs(tokens)) for tokens in token_lists]
    ).to(device)

    # NOTE(review): the images are uniform random noise, not data read from
    # IMAGES_DIR -- presumably a stand-in until real image loading is wired in.
    images = torch.rand(len(token_lists), *IMAGE_SHAPE, device=device)

    optimizer.zero_grad()

    # NOTE(review): Ensemble.forward already applies softmax, so `predictions`
    # are probabilities; if `criterion` expects raw logits (as
    # nn.CrossEntropyLoss does) the softmax is effectively applied twice.
    predictions = model(texts, images)

    loss = criterion(predictions, labels)
    acc = accuracy(predictions, labels)
    loss.backward()
    optimizer.step()

    total_acc.update(acc)
    # .item() yields a plain Python number; accumulating the loss tensor itself
    # would keep every batch's autograd graph alive through the running mean.
    total_loss.update(loss.item())
    wandb.log({"acc": total_acc.values, "loss": total_loss.values})

VBox(children=(Label(value='0.002 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.150805…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…