In [1]:
!git clone https://github.com/domeGIT/ml_image_to_latex_2024

fatal: destination path 'ml_image_to_latex_2024' already exists and is not an empty directory.


In [2]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms.functional as TF
from torchvision import transforms

from torch.cuda.amp import autocast, GradScaler
from nltk.translate.bleu_score import corpus_bleu

import pandas as pd
import json
import os
import shutil # potrebno za google colab

from ml_image_to_latex_2024.image2latex import Text, LatexDataset, Image2LatexModel, exact_match, collate_fn, get_device, bind_gpu
from ml_image_to_latex_2024.image2latex import ConvEncoder, Decoder, Attention

In [4]:
# raspakivanje data.tar u /content/data.tar
dst = "/content/data"

if os.path.exists(dst):
     shutil.rmtree(dst)

!cp /content/ml_image_to_latex_2024/data.tar /content/
!tar -xf /content/data.tar -C /content

Priprema log fajla za čuvanje rezultata tessta

In [5]:
# MOUNTOVANJE DRAJVA
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
# uvođenje log fajla
os.makedirs("/content/drive/My Drive/im2latex/", exist_ok=True)

log_file = "/content/drive/My Drive/im2latex/test_log.json"

Uvođenje objekata i definisanje parametara potrebnih za test

In [7]:
# konfig/parametri
BATCH_SIZE = 16
WORKERS = 4
MAX_LENGTH = 150

In [8]:
device = get_device()

In [9]:
# kriterijum za loss
criterion = torch.nn.CrossEntropyLoss().to(device)

In [10]:
# Sa obzirom na to gde sve koristimo transform funkciju, bolje da smo je uključili u sam model
# ipak, vreme izrade projekta je ograničeno, neka je za sad ovde
transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor()
])

text_processor = Text()

test_dataset = LatexDataset('/content/data/im2latex_train.csv', transform=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # usually keep validation deterministic
    num_workers=WORKERS,
    collate_fn=lambda batch: collate_fn(batch, text_processor)
)

učitajmo prethodno sačuvani model:

In [11]:
model = torch.load("/content/ml_image_to_latex_2024/saved_models/model9.pt", map_location=device, weights_only=False)
model.eval()

Image2LatexModel(
  (encoder): ConvEncoder(
    (feature_encoder): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(128, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU()
    )
  )
  (decoder): Decoder(
    (embedding): Embedding(520, 80)
    (attention): Attention(
      (decoder_attention): Linear(in_features=512, out_features=512, bias=False)
      (encoder_attention): Linear(in_features=512, out_features=512, bias=False)
      (attention): Linear(in_features=512, out_features=1, bias=False)
      (softmax): Softmax(dim=-1)
    )
    (concat): Linear(in_features=592, out_features=512, bias=True)
    (rnn): LSTM(512, 512, b

In [12]:
def test(model, test_loader, text_processor, criterion=None, log_file=None):
    model.eval()
    test_loss = 0.0
    test_bleu = 0.0
    test_em = 0.0

    with torch.no_grad():
        for batch in test_loader:
            images, formulas, formula_len = bind_gpu(batch)
            formulas_in = formulas[:, :-1]
            formulas_out = formulas[:, 1:]

            # izračunajmo loss
            with autocast():
                outputs = model(images, formulas_in, formula_len)
                loss = criterion(outputs.reshape(-1, outputs.shape[-1]),
                                     formulas_out.reshape(-1))
            test_loss += loss.item()

            # generiši predikcije
            predicts = model.decode_greedy_batch(images, max_length=MAX_LENGTH)
            truths = [formula.tolist() for formula in formulas]

            predict_strings = [text_processor.tokenize(text_processor.int2text(p)) for p in predicts]
            truth_strings = [text_processor.tokenize(text_processor.int2text(t)) for t in truths]

            bleu4 = corpus_bleu([[t] for t in truth_strings], predict_strings)
            em = exact_match(predict_strings, truth_strings)
            print(bleu4)
            print(em)

            test_bleu += bleu4
            test_em += em

    avg_test_loss = test_loss / len(test_loader)
    avg_test_bleu = test_bleu / len(test_loader)
    avg_test_em = test_em / len(test_loader)

    results = {
        "test_loss": float(avg_test_loss) if avg_test_loss is not None else None,
        "bleu4": float(avg_test_bleu),
        "em": float(avg_test_em)
    }

    if log_file is not None:
        with open(log_file, "a") as file:
            file.write(json.dumps(results) + "\n")

    print(f"Test BLEU4: {avg_test_bleu:.4f}, EM: {avg_test_em:.4f}"
          + (f", Loss: {avg_test_loss:.4f}" if avg_test_loss is not None else ""))

    return results

In [13]:
print(test(model, test_loader, text_processor, criterion, log_file))

  with autocast():


0.31583944618279086
tensor(0., dtype=torch.float64)
0.13387989978881518
tensor(0., dtype=torch.float64)
0.17531312759668374
tensor(0., dtype=torch.float64)
0.2923333456737029
tensor(0., dtype=torch.float64)
0.23527334545349146
tensor(0., dtype=torch.float64)
0.3194239130137035
tensor(0., dtype=torch.float64)
0.21155810477248727
tensor(0., dtype=torch.float64)
0.27769150495769956
tensor(0., dtype=torch.float64)
0.18483250568148107
tensor(0., dtype=torch.float64)
0.24586317905573393
tensor(0., dtype=torch.float64)
0.24320359037040906
tensor(0., dtype=torch.float64)
0.2801166836246416
tensor(0., dtype=torch.float64)
0.36958155105596296
tensor(0., dtype=torch.float64)
0.28541420279336976
tensor(0., dtype=torch.float64)
0.30763387961670574
tensor(0., dtype=torch.float64)
0.19165090402806703
tensor(0., dtype=torch.float64)
0.3080216563620884
tensor(0., dtype=torch.float64)
0.17250309543002462
tensor(0., dtype=torch.float64)
0.20453990812626258
tensor(0., dtype=torch.float64)
0.19665627137751

KeyboardInterrupt: 