In [None]:
# Clone the project repository; it provides the `ml_image_to_latex_2024.image2latex`
# package (Text, LatexDataset, Image2LatexModel, ...) imported below.
!git clone https://github.com/domeGIT/ml_image_to_latex_2024

In [None]:
# Let PyTorch's CUDA caching allocator use expandable segments, which can reduce
# memory fragmentation. The setting only takes effect if it is in the environment
# before torch initializes CUDA, which is why this cell runs before the imports.
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms.functional as TF
from torchvision import transforms

from torch.cuda.amp import autocast, GradScaler
from nltk.translate.bleu_score import corpus_bleu

import pandas as pd
import json
import os
import shutil # potrebno za google colab

from ml_image_to_latex_2024.image2latex import Text, LatexDataset, Image2LatexModel, exact_match, collate_fn


In [None]:
# deo koda potreban da bi u google colabu postojao pristup data folderu:
# potrebno pokrenuti samo tokom prvog pokretanja koda u sesiji
# pre toga je potrebno dostaviti data.tar na svoj google drive

dst = "/content/data"

if os.path.exists(dst):
     shutil.rmtree(dst)

!cp /content/drive/MyDrive/data.tar /content/
!tar -xf /content/data.tar -C /content

In [None]:
# Extra image transform used by the train/validation datasets: an int size makes
# Resize scale the smaller image edge to 128 px while keeping the aspect ratio,
# then ToTensor converts the image to a tensor.
# Added to speed up training.
transform = transforms.Compose(
    [
        transforms.Resize(size=128),
        transforms.ToTensor(),
    ]
)

In [None]:
# MOUNT GOOGLE DRIVE
# Needed for saving training logs and model checkpoints.
# They are first written to each user's personal Drive and then pushed to git
# manually, because we don't want every test run of this notebook to modify
# the GitHub repository.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Model-metric logging: the training loop appends one JSON object per epoch to
# this file on Drive, so its folder has to exist beforehand.
log_file = "/content/drive/My Drive/im2latex/training_log.jsonl"
os.makedirs(name="/content/drive/My Drive/im2latex/", exist_ok=True)

# Trening

Promenljive za konfiguraciju treninga:

In [None]:
# Training configuration.
data_path = "data"  # NOTE(review): not read anywhere in this notebook; dataset paths are hard-coded under /content/data

RANDOM_STATE = 1219
N_EPOCHS = 10
BATCH_SIZE = 16  # batch size of each DataLoader batch (one forward/backward pass)
LEARNING_RATE = 0.002  # peak LR; the optimizer/OneCycleLR cell hard-codes this same 0.002 (was 0.1 and unused)
WORKERS = 4
MAX_LENGTH = 150  # max_length passed to greedy decoding during validation
EFFECTIVE_BATCH_SIZE = 64  # batch size per optimizer step, reached via gradient accumulation

# Accumulation uses EFFECTIVE_BATCH_SIZE // BATCH_SIZE batches per step; if the
# two don't divide evenly the effective batch silently shrinks.
assert EFFECTIVE_BATCH_SIZE % BATCH_SIZE == 0, "EFFECTIVE_BATCH_SIZE must be a multiple of BATCH_SIZE"

# RANDOM_STATE was previously never applied. Seeding torch's generators (all
# devices) before the model and DataLoaders are created makes weight
# initialization and shuffle order repeatable across runs.
torch.manual_seed(RANDOM_STATE)

Ukoliko je moguće, računanje izvršavamo na grafičkoj kartici (GPU):

In [None]:
def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    use_cuda = torch.cuda.is_available()
    return torch.device("cuda") if use_cuda else torch.device("cpu")

def bind_gpu(data):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to get_device().

    Every list or tuple is rebuilt as a list; each leaf is moved with
    ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(get_device(), non_blocking=True)
    return list(map(bind_gpu, data))

Inicijalizacija dataloadera, modela, funkcije greške, optimizatora i schedulera:

In [None]:
# Data preparation
text_processor = Text()

device = get_device()

# Loading
train_dataset = LatexDataset('/content/data/im2latex_train.csv', transform=transform)
val_dataset = LatexDataset('/content/data/im2latex_validate.csv', transform=transform)

# pin_memory: page-locked host batches let the non_blocking=True copies done by
# bind_gpu actually run asynchronously; it only matters when training on a GPU.
# NOTE(review): the lambda collate_fn works with fork-based workers (Linux/Colab);
# spawn-based platforms would need a picklable function instead.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # important for training
    num_workers=WORKERS,
    pin_memory=(device.type == "cuda"),
    collate_fn=lambda batch: collate_fn(batch, text_processor)
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # usually keep validation deterministic
    num_workers=WORKERS,
    pin_memory=(device.type == "cuda"),
    collate_fn=lambda batch: collate_fn(batch, text_processor)
)

# Model, loss, optimizer and scheduler initialization
model = Image2LatexModel(
    n_class=text_processor.n_class,
    text=text_processor,
    beam_width=5,
    sos_id=text_processor.sos_id,
    eos_id=text_processor.eos_id
).to(device)

# NOTE(review): no ignore_index is set, so padded target positions (if
# collate_fn pads formulas) also contribute to the loss; confirm whether the
# padding id should be ignored.
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.002, betas=(0.9, 0.98))

# Number of loader batches whose gradients are accumulated per optimizer step.
accumulation_steps = EFFECTIVE_BATCH_SIZE // BATCH_SIZE

# The training loop calls scheduler.step() once per *complete* accumulation
# window, i.e. len(train_loader) // accumulation_steps times per epoch.
# Deriving this from len(train_dataset) // EFFECTIVE_BATCH_SIZE can come out one
# short per epoch (when the final partial batch completes an extra window),
# which makes OneCycleLR raise a ValueError near the end of training.
# max(1, ...) keeps OneCycleLR's positive-total_steps requirement satisfied.
total_steps = max(1, (len(train_loader) // accumulation_steps) * N_EPOCHS)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.002, total_steps=total_steps)

# Loss scaling is only meaningful for CUDA mixed precision; without CUDA the
# scaler would disable itself anyway (with a warning), so state it explicitly.
scaler = GradScaler(enabled=(device.type == "cuda"))

In [None]:
# Training: each epoch trains with mixed precision and gradient accumulation,
# checkpoints the full model to Drive, evaluates on the validation set and
# appends the epoch metrics to `log_file` as one JSON line.
# The helpers below read the model, loaders, loss, optimizer, scheduler and
# scaler created in the earlier cells.


def _train_one_epoch():
    """Train `model` for one pass over `train_loader`.

    Returns:
        Mean per-batch training loss, with the 1/accumulation_steps scaling undone.
    """
    model.train()
    running_loss = 0.0
    n_batches = len(train_loader)

    for batch_idx, batch in enumerate(train_loader):
        images, formulas, formula_len = bind_gpu(batch)

        # The model is fed formulas[:, :-1] and trained to predict formulas[:, 1:].
        formulas_in = formulas[:, :-1]
        formulas_out = formulas[:, 1:]

        with autocast():
            outputs = model(images, formulas_in, formula_len)
            loss = criterion(outputs.reshape(-1, outputs.shape[-1]), formulas_out.reshape(-1))
            # Scale so gradients summed over one accumulation window average
            # over the effective batch.
            loss = loss / accumulation_steps

        scaler.scale(loss).backward()
        running_loss += loss.item() * accumulation_steps

        # Optimizer and LR scheduler step once per complete accumulation window.
        if (batch_idx + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        # Drop references to this batch before loading the next one.
        del images, formulas, outputs, loss
        torch.cuda.empty_cache()

    # Apply gradients left from a trailing, incomplete accumulation window (the
    # scheduler is not stepped for it). Using n_batches instead of the loop
    # variable keeps this safe when the loader yields no batches.
    if n_batches % accumulation_steps != 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    return running_loss / max(n_batches, 1)


def _evaluate():
    """Evaluate `model` on `val_loader`.

    Returns:
        (mean per-batch validation loss,
         BLEU-4 computed once over the whole validation set,
         mean of the per-batch exact_match values).
    """
    model.eval()
    loss_sum = 0.0
    em_sum = 0.0
    bleu_references = []  # one [reference] list per sample, as corpus_bleu expects
    bleu_hypotheses = []
    n_batches = len(val_loader)

    with torch.no_grad():
        for batch in val_loader:
            images, formulas, formula_len = bind_gpu(batch)
            formulas_in = formulas[:, :-1]
            formulas_out = formulas[:, 1:]

            with autocast():
                outputs = model(images, formulas_in, formula_len)
                loss = criterion(outputs.reshape(-1, outputs.shape[-1]), formulas_out.reshape(-1))
            loss_sum += loss.item()

            # Batch greedy decoding, compared with the targets at token level.
            predicts = model.decode_greedy_batch(images, max_length=MAX_LENGTH)
            truths = [formula.tolist() for formula in formulas]

            predict_strings = [text_processor.tokenize(text_processor.int2text(p)) for p in predicts]
            truth_strings = [text_processor.tokenize(text_processor.int2text(t)) for t in truths]

            em_sum += exact_match(predict_strings, truth_strings)
            bleu_references.extend([t] for t in truth_strings)
            bleu_hypotheses.extend(predict_strings)

    # Corpus BLEU pools n-gram counts and the brevity penalty over all sentences;
    # averaging per-batch corpus_bleu scores (as done previously) is a different,
    # batch-size-dependent number. Note: not directly comparable to older logs.
    bleu4 = corpus_bleu(bleu_references, bleu_hypotheses) if bleu_hypotheses else 0.0

    n = max(n_batches, 1)
    return loss_sum / n, bleu4, em_sum / n


os.makedirs("/content/drive/My Drive/im2latex/saved_models", exist_ok=True)

for epoch in range(N_EPOCHS):
    print(f"Epoch {epoch+1}/{N_EPOCHS}")

    avg_train_loss = _train_one_epoch()
    print(f"Epoch {epoch+1} - Average Train Loss: {avg_train_loss:.4f}")

    # Checkpoint the whole (pickled) model after every epoch.
    path = f"/content/drive/My Drive/im2latex/saved_models/model{epoch+1}.pt"
    torch.save(model, path)

    avg_val_loss, avg_val_bleu, avg_val_em = _evaluate()

    epoch_metrics = {
        "epoch": epoch + 1,
        "train_loss": float(avg_train_loss),
        "val_loss": float(avg_val_loss),
        "bleu4": float(avg_val_bleu),
        "em": float(avg_val_em)
    }

    # Append as a single line
    with open(log_file, "a", encoding="utf-8") as file:
        file.write(json.dumps(epoch_metrics) + "\n")

    print(f"Validation Loss: {avg_val_loss:.4f}, BLEU4: {avg_val_bleu:.4f}, EM: {avg_val_em:.4f}")

print("Training complete!")
# NOTE(review): relative path, i.e. the current working directory, unlike the
# per-epoch checkpoints written to Drive; confirm this final save should not
# also go to Drive.
torch.save(model.state_dict(), 'image2latex_model.pth')