In [None]:
import sys
sys.path.append("../trainer")

In [None]:
import os

import torch
import torchvision as tv
from transformers import AutoTokenizer

from ignite.engine import (
    Engine,
    Events,
)
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine
from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers.neptune_logger import NeptuneLogger

In [None]:
from datamodule import SROIETask2DataModule
from model import TransformersEncoder
from cnn import CNN as VisualModel
from ctc import GreedyDecoder
from igmetrics import ExactMatch, WordF1

In [None]:
# Greedy CTC decoder; the 0 passed here presumably is the blank id, matching the
# CTCLoss(blank=0) used for training — confirm against GreedyDecoder's signature.
decoder = GreedyDecoder(0)
# Tokenizer saved alongside the trainer package; used both to size the
# recogniser's vocabulary and to turn ids back into text.
tokenizer = AutoTokenizer.from_pretrained("../trainer/tokenizer")

# Data loading

In [None]:
# Root of the SROIE task-2 dataset. Set the SROIE_DATA_PATH environment variable
# to point at your copy instead of editing this machine-specific default.
DATA_PATH = os.environ.get(
    "SROIE_DATA_PATH",
    "/Users/israelcampiotti/Documents/Github/msc/tmp-master/SROIETask2",
)
dm = SROIETask2DataModule(
    root_dir=os.path.join(DATA_PATH, "data"),
    label_file=os.path.join(DATA_PATH, "data.json"),
    tokenizer=tokenizer,
    height=32,
    num_workers=4,
    train_bs=2,
    valid_bs=2,
    val_pct=0.001,
    max_width=None,
    do_pool=True,
)

In [None]:
# Prepare the datamodule for the "fit" stage; presumably this builds the
# train/validation datasets that dm.train_dataloader() / dm.val_dataloader()
# use further down — confirm in SROIETask2DataModule.setup.
dm.setup("fit")

# Model

In [None]:
class OCRModel(torch.nn.Module):
    """Chain a visual feature extractor and a recognition model into one module.

    Args:
        visual_model: callable applied to the input images to produce features.
        rec_model: callable applied to those features (with an optional
            ``attention_mask`` keyword) to produce the output logits.
    """

    def __init__(self, visual_model, rec_model):
        super().__init__()
        # When these are nn.Module instances, assigning them as attributes
        # registers them as submodules (so .to()/.parameters() include them).
        self.visual_model = visual_model
        self.rec_model = rec_model

    def forward(self, images, attention_mask=None):
        """Return ``rec_model(visual_model(images), attention_mask=attention_mask)``."""
        return self.rec_model(
            self.visual_model(images),
            attention_mask=attention_mask,
        )

In [None]:
# Build the two stages in the original order (keeps RNG consumption for weight
# init unchanged), sizing the recogniser's output to the tokenizer vocabulary.
vis_model = VisualModel()
rec_model = TransformersEncoder(vocab_size=tokenizer.vocab_size)
model = OCRModel(visual_model=vis_model, rec_model=rec_model)

# Training loop (PyTorch Ignite)

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# nn.Module.to() moves parameters/buffers in place and returns the module itself;
# assigning the result to `_` just keeps the cell from echoing the model repr.
_ = model.to(device=device)

In [None]:
# Separate loaders for the two splits. Reusing val_loader for training is only
# appropriate as a deliberate overfitting smoke test, never for a real run.
val_loader = dm.val_dataloader()
train_loader = dm.train_dataloader()

In [None]:
LEARNING_RATE = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# blank=0 presumably matches the blank id given to GreedyDecoder(0).
# zero_infinity=True zeroes infinite losses (inputs too short to align to their
# targets) and their gradients instead of propagating inf/NaN.
criterion = torch.nn.CTCLoss(blank=0, zero_infinity=True)

In [None]:
def train_step(engine, batch):
    """Run one CTC optimisation step on ``batch`` and return the scalar loss.

    Args:
        engine: the Ignite engine driving the loop (unused).
        batch: a 4-sequence ``(images, labels, attention_mask, attention_image)``.
            Row sums (over the last dim) of ``attention_image`` and
            ``attention_mask`` are used as the CTC input and target lengths.

    Returns:
        The loss value as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # The model lives on `device`; the batch must follow it, otherwise the
    # forward pass fails with a device mismatch whenever CUDA is used.
    images, labels, attention_mask, attention_image = (t.to(device) for t in batch)

    logits = model(images, attention_image)

    # Per-sample lengths: valid visual steps and valid label tokens.
    # NOTE(review): CTCLoss wants integer lengths — assumes the masks are
    # integer/bool so these sums are int64; confirm in the datamodule.
    input_length = attention_image.sum(-1)
    target_length = attention_mask.sum(-1)

    # CTCLoss expects log-probabilities shaped (T, N, C); the model presumably
    # returns (N, T, C), hence the permute — TODO confirm.
    log_probs = logits.permute(1, 0, 2).log_softmax(2)

    loss = criterion(log_probs, labels, input_length, target_length)
    loss.backward()
    optimizer.step()
    return loss.item()


trainer = Engine(train_step)

In [None]:
def val_step(engine, batch):
    """Greedy-decode one batch and return ``(y_pred, y)`` as lists of strings.

    Args:
        engine: the Ignite engine driving evaluation (unused).
        batch: a 4-sequence ``(images, labels, attention_mask, attention_image)``.

    Returns:
        ``(y_pred, y)``: decoded predictions and decoded reference labels, both
        produced by ``tokenizer.batch_decode(..., skip_special_tokens=True)``.
    """
    model.eval()
    images, labels, _, attention_image = batch
    with torch.no_grad():
        # Only the model inputs need to be on the model's device.
        logits = model(images.to(device), attention_image.to(device))

    # Best class per step, brought back to CPU so the decoder sees CPU tensors
    # for both the ids and the image mask.
    decoded_ids = logits.argmax(-1).cpu()
    if decoded_ids.dim() == 1:
        # Unbatched output: restore a batch dimension so zip() iterates samples.
        decoded_ids = decoded_ids.unsqueeze(0)

    decoded = [
        decoder(ids, att) for ids, att in zip(decoded_ids, attention_image.cpu())
    ]
    y_pred = tokenizer.batch_decode(decoded, skip_special_tokens=True)
    y = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return y_pred, y
    

In [None]:
# Two independent evaluation engines sharing one step function, so the metrics
# attached to each keep separate state.
train_evaluator, validation_evaluator = Engine(val_step), Engine(val_step)

In [None]:
# Give each evaluator its own fresh metric instances: exact string match
# reported as "accuracy" and word-level F1 reported as "f1".
for evaluator in (train_evaluator, validation_evaluator):
    ExactMatch().attach(evaluator, "accuracy")
    WordF1().attach(evaluator, "f1")

In [None]:
def log_validation_results(engine):
    """Evaluate on ``val_loader`` and print the validation metrics.

    Registered on the trainer's EPOCH_COMPLETED event, so ``engine`` is the
    trainer and ``engine.state.epoch`` is the epoch that just finished.
    """
    validation_evaluator.run(val_loader)
    metrics = validation_evaluator.state.metrics
    # Report every metric attached to the evaluator, not just accuracy.
    print(
        f"Validation Results - Epoch: {engine.state.epoch}  "
        f"Avg accuracy: {metrics['accuracy']:.3f}  F1: {metrics['f1']:.3f}"
    )


trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)

In [None]:
# Save the model after every epoch into ./models, keeping at most two checkpoint
# files; create the directory if needed and tolerate existing files in it.
checkpointer = ModelCheckpoint(
    dirname="models",
    filename_prefix="deberta-ocr",
    n_saved=2,
    create_dir=True,
    require_empty=False,
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {"model": model})

In [None]:
# The Neptune API token must come from the environment, never from the notebook:
# anything hard-coded here leaks to everyone with access to the file (the token
# that used to be embedded in this cell should be revoked).
NEPTUNE_API_TOKEN = os.environ.get("NEPTUNE_API_TOKEN")
if not NEPTUNE_API_TOKEN:
    raise RuntimeError(
        "Set the NEPTUNE_API_TOKEN environment variable before creating the Neptune logger."
    )

neptune_logger = NeptuneLogger(
    project="i155825/OCRMsc",
    api_token=NEPTUNE_API_TOKEN,
)


In [None]:
# Per-iteration training loss (train_step returns the bare float loss).
neptune_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED,
    tag="training",
    output_transform=lambda step_loss: dict(loss=step_loss),
)

# Validation metrics after each evaluator run, stepped by the trainer's
# progress rather than the evaluator's own counters.
neptune_logger.attach_output_handler(
    validation_evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["f1", "accuracy"],
    global_step_transform=global_step_from_engine(trainer),
)

# Store a snapshot of the trainer sources with the run.
neptune_logger["code"].upload_files(["../trainer/*.py"])

In [None]:
# Console progress bar showing the running loss returned by train_step.
pbar = ProgressBar()
pbar.attach(trainer, output_transform=lambda step_loss: dict(loss=step_loss))

In [None]:
MAX_EPOCHS = 10
trainer.run(data=train_loader, max_epochs=MAX_EPOCHS)