In [8]:
# packages
import os
import pickle
import sys

import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from IPython.display import clear_output
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Start a Weights & Biases run; everything logged below goes to the "odl" project.
wandb.init(project="odl")

# local project modules (importable once the parent directory is on sys.path)
sys.path.insert(0, "..")

from trainer import Trainer
from transforms import TextTransform, ImageTransform
from datasets import TestingDataset
from module import *
from vocabulary import Vocabulary
from metrics import accuracy

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.093490…

0,1
acc,█▅▂▁▂▃▄▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██
loss,▄▆▇█▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁

0,1
acc,0.84561
loss,0.46921


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333435779, max=1.0)…

In [9]:
# Filesystem layout, anchored at the directory the notebook is run from.
CUR_DIR = os.path.abspath(os.curdir)

# Training images of the dataset (one level above the notebook directory).
IMAGES_DIR = os.path.join(
    CUR_DIR,
    "../data/reintel2020/public_train_final_images/",
)

# Directory holding the pickled, preprocessed dataset.
CACHE_PATH = os.path.join(
    CUR_DIR,
    "../.cache/",
)

In [10]:
# Load the preprocessed dataset from the on-disk cache.
# The context manager closes the file handle as soon as unpickling finishes;
# previously the handle stayed open for the lifetime of the kernel.
# NOTE(review): pickle.load can execute arbitrary code while deserialising —
# only load cache files that were produced by this project.
with open(os.path.join(CACHE_PATH, "pretrain_dataset.pkl"), "rb") as cache:
    loaded_dataset = pickle.load(cache)

In [11]:
# Wrap the unpickled records in the project's dataset class and serve them in
# batches of 128; shuffling is left at the DataLoader default.
dataset = TestingDataset(loaded_dataset)
dataloader = DataLoader(dataset=dataset, batch_size=128)

In [12]:
def tokenizer(x):
    """Split every string in ``x`` on whitespace.

    Args:
        x: Iterable of strings, e.g. one batch of post messages.

    Returns:
        A list holding, for each input string, its list of
        whitespace-separated tokens (an empty list for a blank string).
    """
    # A named loop variable is used instead of `_`, which IPython also uses
    # for the last cell output.
    return [text.split() for text in x]

In [13]:
class PrequentialMetrics:
    """Running arithmetic mean of all values passed to ``update``.

    The mean is recomputed incrementally, so individual observations are not
    stored.
    """

    def __init__(self, values=0):
        # Current running mean. The starting value carries zero weight: it is
        # multiplied by t == 0 on the first update.
        self.values = values
        # Number of observations folded into the mean so far.
        self.t = 0

    def update(self, value):
        """Fold ``value`` into the running mean and count the observation."""
        seen = self.t
        self.values = (self.values * seen + value) / (seen + 1)
        self.t = seen + 1

    def __repr__(self) -> str:
        """Return the repr of the current running mean."""
        return f"{self.values!r}"

In [14]:
# Model, optimiser and loss.
# NOTE(review): the vocabulary below keeps growing inside the training loop;
# presumably its indices must stay below vocab_size — confirm.
model = LSTMClassifier(
    vocab_size=100000,
    embedding_dim=100,
    hidden_dim=128,
    num_classes=2,
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Vocabulary starts with the two special tokens only; lookups that miss fall
# back to the index of "<unk>".
vc = Vocabulary(specials=["<pad>", "<unk>"])
vc.set_default_idx(vc["<unk>"])

# Running means of accuracy and loss over all batches seen so far.
total_acc, total_loss = PrequentialMetrics(), PrequentialMetrics()

# Online training loop: each batch extends the vocabulary, is scored by the
# current model, and is then used for one optimisation step. Running means of
# accuracy and loss are logged to wandb after every batch.

# Built once: the transform's configuration is the same for every batch.
# NOTE(review): assumes TextTransform keeps no state between calls — confirm.
text_transform = TextTransform(max_length=20)

# Nothing inside the loop leaves training mode, so switching once is enough.
model.train()

for idx, batch in enumerate(dataloader):
    texts = tokenizer(batch["post_message"])
    labels = batch["label"]

    # Register this batch's tokens in the vocabulary before indexing them.
    for tokens in texts:
        for token in tokens:
            try:
                vc.append_token(token)
            except Exception:
                # Deliberate best-effort: presumably raised for tokens that
                # are already known — TODO narrow to the concrete exception.
                pass

    # token -> index -> padding
    texts = torch.stack([text_transform(vc.get_idxs(tokens)) for tokens in texts])

    # Forward pass, loss/accuracy on the pre-update predictions, then one step.
    optimizer.zero_grad()
    predictions = model(texts)

    loss = criterion(predictions, labels)
    acc = accuracy(predictions, labels)
    loss.backward()
    optimizer.step()

    # Store plain Python floats: feeding the graph-attached `loss` tensor into
    # the running mean would chain the autograd history of every batch.
    total_acc.update(float(acc))
    total_loss.update(loss.item())
    wandb.log({"acc": total_acc.values, "loss": total_loss.values})