In [1]:
import torch
from torch import nn, Tensor
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision

import matplotlib.pyplot as plt
import pandas as pd

from PIL import Image, ImageOps
import os
import random

parametri/config:

In [2]:
data_path = "data"

RANDOM_STATE = 1219
N_EPOCHS = 50
BATCH_SIZE = 16
LEARNING_RATE = 0.1
WORKERS = 0

# Apply the seed so shuffling, weight init and random test inputs are reproducible.
# torch.manual_seed also seeds every CUDA device.
random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

Definisanje funkcija za premeštanje podataka na GPU, ako je dostupan:

In [3]:
def get_device():
  """Return the device tensors should live on: CUDA when available, otherwise CPU."""
  if torch.cuda.is_available():
    return torch.device("cuda")
  return torch.device("cpu")

def bind_gpu(data):
  """Move a tensor, or an arbitrarily nested list/tuple of tensors, to get_device().

  Both lists and tuples come back as lists; copies are requested as non-blocking.
  """
  target = get_device()

  def _move(obj):
    # Leaves are tensors (anything with .to); containers are rebuilt as lists.
    if not isinstance(obj, (list, tuple)):
      return obj.to(target, non_blocking=True)
    return [_move(child) for child in obj]

  return _move(data)

dataset klasa:

In [4]:
class LatexDataset(Dataset):
    """
    Dataset for the image-to-LaTeX project.

    Each item is ``(image_tensor, formula_string)``. Split metadata is read from
    ``<root>/im2latex_<data_type>.csv`` (columns ``image`` and ``formula``), and
    images are loaded from ``<root>/formula_images_processed/<image>``.

    Args:
        data_type: one of 'train', 'test', 'validate', 'testtest'.
        transform: optional callable applied to each image tensor in __getitem__.
        root: data directory; defaults to the notebook-level ``data_path``.
    """
    def __init__(self, data_type: str, transform=None, root: str = None):
        super().__init__()
        assert data_type in ['train', 'test', 'validate', 'testtest'], 'Not found data type'
        self.transform = transform

        root = data_path if root is None else root
        df = pd.read_csv(os.path.join(root, f'im2latex_{data_type}.csv'))
        # Turn bare image file names into full paths under the same data root as the CSVs.
        image_dir = os.path.join(root, "formula_images_processed")
        df['image'] = df['image'].map(lambda name: os.path.join(image_dir, f'{name}'))
        # `walker` is a list with one dict per CSV row (keys are the CSV columns,
        # e.g. 'image' and 'formula'), so items can be fetched by index without pandas.
        self.walker = df.to_dict('records')

    def __len__(self):
        return len(self.walker)

    def __getitem__(self, idx):
        """Return ``(image, formula)``; the image is passed through ``transform`` if one was given."""
        item = self.walker[idx]
        image = torchvision.io.read_image(str(item['image']))
        if self.transform is not None:
            image = self.transform(image)
        return image, item['formula']

In [5]:

# One dataset per split; 'testtest' is presumably a small split for quick checks — TODO confirm.
train_set = LatexDataset('train')
val_set = LatexDataset('validate')
test_set = LatexDataset('test')
test_test = LatexDataset('testtest')

# Sizes of the train / validation / test splits.
tuple(len(split) for split in (train_set, val_set, test_set))

(75275, 8370, 10355)

In [6]:
import re
import json
from torch import Tensor

class Text:
    """Vocabulary and tokenizer for LaTeX formulas.

    Token ids are indices into the JSON vocabulary list loaded from ``vocab_path``.
    Ids 0 / 1 / 2 are treated as padding / start-of-sequence / end-of-sequence;
    presumably these are the first three vocabulary entries — TODO confirm
    against the vocab file.
    """

    def __init__(self, vocab_path: str = "data/vocab/100k_vocab.json"):
        self.pad_id = 0
        self.sos_id = 1
        self.eos_id = 2

        # `with` closes the handle promptly; JSON is UTF-8 by spec, so don't
        # depend on the platform's default encoding.
        with open(vocab_path, "r", encoding="utf-8") as vocab_file:
            self.id2word = json.load(vocab_file)
        # Same semantics as dict(zip(...)): for duplicate words the last index wins.
        self.word2id = {word: idx for idx, word in enumerate(self.id2word)}
        self.TOKENIZE_PATTERN = re.compile(
            r"(\\[a-zA-Z]+)|"           # LaTeX commands like \frac, \sqrt
            r"((\\)*[$-/:-?{-~!\"^_`\[\]])|"  # math symbols / punctuation, optionally backslash-prefixed
            r"(\w)|"                    # single letters/digits
            r"(\\)"                     # stray backslashes
            )
        self.n_class = len(self.id2word)

    def int2text(self, x: Tensor):
        """Join token ids into a space-separated string, skipping ids <= eos_id (pad/sos/eos)."""
        return " ".join([self.id2word[i] for i in x if i > self.eos_id])

    def text2int(self, formula: str):
        """Tokenize ``formula`` and map tokens to ids (KeyError on out-of-vocabulary tokens)."""
        return torch.LongTensor([self.word2id[i] for i in self.tokenize(formula)])

    def tokenize(self, formula: str):
        """Split a LaTeX formula into tokens with TOKENIZE_PATTERN, dropping empty matches."""
        return [match.group(0) for match in self.TOKENIZE_PATTERN.finditer(formula) if match.group(0)]

data module:

In [7]:
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence

# Shared vocabulary/tokenizer instance; collate_fn receives it through the DataLoader's collate lambda.
text = Text()

def collate_fn(batch, text):
    """Merge ``(image, formula)`` samples into batch tensors for a DataLoader.

    Args:
        batch: list of ``(image, formula)`` pairs; images are (C, H, W) tensors.
        text: object providing ``text2int``, ``pad_id``, ``sos_id``, ``eos_id`` (e.g. ``Text``).

    Returns:
        image: float tensor (len(batch), C, max_H, max_W).
        formulas: long tensor (len(batch), longest formula + 2) laid out per row as
            [SOS, tokens..., EOS, PAD, ...].
    """
    # Wrap each formula *before* padding so EOS directly follows the last real
    # token, and size everything by the actual batch (the last batch may be short).
    sos = torch.tensor([text.sos_id], dtype=torch.long)
    eos = torch.tensor([text.eos_id], dtype=torch.long)
    sequences = [
        torch.cat((sos, text.text2int(sample[1]).to(dtype=torch.long), eos))
        for sample in batch
    ]
    formulas = pad_sequence(sequences, batch_first=True, padding_value=text.pad_id)

    images = [sample[0] for sample in batch]
    max_height = max(img.size(1) for img in images)
    max_width = max(img.size(2) for img in images)
    # NOTE: images of differing sizes are *stretched* to (max_height, max_width)
    # by Resize, not padded. Batches whose images already match skip the no-op resize.
    if any(tuple(img.shape[1:]) != (max_height, max_width) for img in images):
        resize = torchvision.transforms.Resize(size=(max_height, max_width))
        images = [resize(img) for img in images]
    image = torch.stack(images).to(dtype=torch.float)
    return image, formulas


In [8]:
test_test_loader = DataLoader(
    test_test,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
    # Bind the shared `text` vocabulary into the collate function.
    collate_fn=lambda samples: collate_fn(samples, text),
)

primer:

In [None]:
# Peek at a single batch to sanity-check tensor shapes:
#   images:   (batch_size, 3, H, W)
#   formulas: (batch_size, max_formula_len + 2)  — +2 for SOS/EOS
for images, formulas in test_test_loader:
    print(images.shape, formulas.shape, sep="\n")
    break

torch.Size([16, 3, 128, 480])
torch.Size([16, 187])


### Enkoder
* input je slika s 3 kanala (RGB)
* output feature map ima `enc_dim` kanala
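* `nn.MaxPool2d(2, 1)` -> pooling s kernelom 2×2 i korakom (stride) 1, dakle simetričan; visinu i širinu smanjuje za po 1. Asimetrični pooling, koji se praktikuje za slike teksta jer je tekst širi nego viši, zahtevao bi tuple argumente, npr. `nn.MaxPool2d((2, 1), (2, 1))`. Slično, `nn.MaxPool2d(1, 2)` (kernel 1, stride 2) samo uzima svaki drugi piksel po obe ose.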

##### Forward pass
input tenzor za forward pass: `x: (bs, c, h, w)` (PyTorch konvolucije koriste raspored NCHW)

`bs = batch size`

`c = number of channels`

`h = image height`

`w = image width`

1. enkodovati `x(bs, c_in, h_in, w_in) -> (bs, c_out, h_out, w_out)`
    * `c_out` je `enc_dim`
    * `h_out` i `w_out` su manji od dimenzija inputa (konvolucije bez paddinga i pooling smanjuju prostorne dimenzije)
2. permutovati -> `(bs, h_out, w_out, c_out)`, da bi feature vektor bio poslednja dimenzija
3. flattenovati prostorne dimenzije `(h_out, w_out)` u jedan niz
* `encoder_out.shape = (bs, sequence_length = h_out * w_out, d = enc_dim)`

In [10]:
class ConvEncoder(nn.Module):
    """Convolutional image encoder.

    Maps an image batch ``(bs, 3, H, W)`` to a sequence of feature vectors
    ``(bs, H_out * W_out, encoder_dim)``, one vector per spatial position.

    NOTE(review): there are no non-linear activations (e.g. ReLU) between the
    convolutions; only the max-pools are non-linear. Confirm this is intentional.
    """
    def __init__(self, encoder_dim: int):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1),
            nn.Conv2d(64, 128, kernel_size=3, stride=1),
            nn.Conv2d(128, 256, kernel_size=3, stride=1),
            nn.Conv2d(256, 256, kernel_size=3, stride=1),
            nn.MaxPool2d(kernel_size=2, stride=1),   # 2x2 window, stride 1: shrinks H and W by 1
            nn.Conv2d(256, 512, kernel_size=3, stride=1),
            nn.MaxPool2d(kernel_size=1, stride=2),   # keeps every other pixel along both axes
            nn.Conv2d(512, encoder_dim, kernel_size=3, stride=1),
        ]
        self.feature_encoder = nn.Sequential(*layers)
        self.encoder_dim = encoder_dim

    def forward(self, x: Tensor):
        """
            x: (bs, c, H, W) image batch (PyTorch NCHW layout)
            returns: (bs, seq_len = H_out * W_out, encoder_dim)
        """
        features = self.feature_encoder(x)          # (bs, encoder_dim, H_out, W_out)
        features = features.permute(0, 2, 3, 1)     # (bs, H_out, W_out, encoder_dim): channels last
        batch, height, width, depth = features.size()
        # Flatten the spatial grid row by row into one sequence of feature vectors.
        return features.view(batch, height * width, depth)

# Attention 

In [11]:
class Attention(nn.Module):
    """Additive attention over encoder features.

    Computes a context vector from the previous decoder hidden state and the
    encoder feature sequence V:
        e   = tanh(W_h h_{t-1} + W_V V)
        α_t = softmax(w_a · e)            (over the w*h positions)
        c_t = Σ_i α_t^i v_i,  v_i ∈ V
    All three projections are bias-free.
    """
    def __init__(self, enoder_dim: int = 512, decoder_dim: int = 512, attention_dim: int = 512):
        super().__init__()
        # `enoder_dim` (sic) keeps its original spelling so keyword callers keep working.
        self.decoder_attention = nn.Linear(decoder_dim, attention_dim, bias=False)  # W_h
        self.encoder_attention = nn.Linear(enoder_dim, attention_dim, bias=False)   # W_V
        self.attention = nn.Linear(attention_dim, 1, bias=False)                    # w_a
        # Turns raw scores into a probability distribution (attention weights).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, h: Tensor, V: Tensor):
        """Compute the attention-weighted context vector.

        Args:
            h: previous decoder hidden state, (batch_size, decoder_dim).
            V: encoder feature map, (batch_size, w*h, encoder_dim).

        Returns:
            context: (batch_size, encoder_dim) — weighted sum of the rows of V.
        """
        query = self.decoder_attention(h).unsqueeze(1)   # (b, 1, attention_dim)
        keys = self.encoder_attention(V)                 # (b, w*h, attention_dim)
        energy = torch.tanh(query + keys)                # (b, w*h, attention_dim)
        scores = self.attention(energy).squeeze(2)       # (b, w*h)
        weights = self.softmax(scores)                   # (b, w*h), each row sums to 1
        weighted = weights.unsqueeze(2) * V              # (b, w*h, encoder_dim)
        return weighted.sum(dim=1)                       # (b, encoder_dim)

# Dekoder

In [12]:
class Decoder(nn.Module):
    """Attention LSTM decoder for the image-to-LaTeX model.

    Each ``forward`` call performs one decoding step:
        c_t            = Attention(h_{t-1}, V)          (additive attention)
        e_t            = Embedding(y_t)
        z_t            = W_c [e_t; c_t] + b_c
        (o'_t, s'_t)   = LSTM1(z_t, (h_{t-1}, c_{t-1}))
        (o_t, s_t)     = LSTM2(Dropout(o'_t), s'_t)
        log p(y_{t+1} | y_1..y_t) = LogSoftmax(W_o o_t + b_o)
    The second LSTM starts from the first LSTM's final state, and the state it
    returns is what the caller feeds back in on the next step.

    NOTE(review): ``forward`` adds/removes a single layer dimension on (h, c) and
    feeds LSTM1's output (width ``decoder_dim``) into LSTM2, so only
    ``num_layers=1`` and ``bidirectional=False`` actually work.
    """
    def __init__(
        self,
        n_class: int,
        embedding_dim: int = 80,
        encoder_dim: int = 512,
        decoder_dim: int = 512,
        attention_dim: int = 512,
        num_layers: int = 1,
        dropout: float = 0.1,
        bidirectional: bool = False,
        sos_id: int = 1,
        eos_id: int = 2,
    ):
        super().__init__()

        self.sos_id = sos_id
        self.eos_id = eos_id

        # Token id -> dense vector: (n_class, embedding_dim)
        self.embedding = nn.Embedding(n_class, embedding_dim)

        # Produces the context vector c_t from h_{t-1} and the encoder features.
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        # Fuses [embedding; context] into the LSTM input: (embedding_dim + encoder_dim) -> decoder_dim
        self.concat = nn.Linear(embedding_dim + encoder_dim, decoder_dim)

        # First LSTM
        self.rnn = nn.LSTM(
            decoder_dim,
            decoder_dim,
            num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

        # Dropout between the two LSTMs, for regularization
        self.dropout = nn.Dropout(dropout)

        # Second LSTM, for a deeper model
        self.rnn2 = nn.LSTM(
            decoder_dim,
            decoder_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        # Projects to the vocabulary: decoder_dim -> n_class
        self.out = nn.Linear(decoder_dim, n_class)

        # Log-probabilities over the vocabulary
        self.logsoftmax = nn.LogSoftmax(dim=-1)

        self.apply(self.init_weights)

    def init_weights(self, layer):
        """Orthogonally initialize embedding and LSTM weight matrices (called through ``self.apply``)."""
        if isinstance(layer, nn.Embedding):
            nn.init.orthogonal_(layer.weight)
        elif isinstance(layer, nn.LSTM):
            # Initialize the LSTM currently being visited (both rnn and rnn2);
            # biases keep PyTorch's default initialization.
            for name, param in layer.named_parameters():
                if name.startswith("weight"):
                    nn.init.orthogonal_(param)

    def forward(self, y, encoder_out=None, hidden_state=None):
        """Run one decoding step.

        Args:
            y: token ids, (batch_size, seq_len); only the last column ``y[:, -1]`` is used.
            encoder_out: encoder features V, (batch_size, w*h, encoder_dim).
            hidden_state: previous (h, c), each (batch_size, decoder_dim). Required
                despite the ``None`` default.

        Returns:
            out: log-probabilities for the next token, (batch_size, 1, n_class).
            hidden_state: updated (h, c), each (batch_size, decoder_dim).
        """
        h, c = hidden_state  # (b, decoder_dim), (b, decoder_dim)

        # Only the most recent token is consumed, so embed just that column.
        embed = self.embedding(y[:, -1])  # (b, embedding_dim)
        attention_context = self.attention(h, encoder_out)  # (b, encoder_dim)

        rnn_input = torch.cat([embed, attention_context], dim=1)  # (b, embedding_dim + encoder_dim)
        rnn_input = self.concat(rnn_input).unsqueeze(1)  # (b, 1, decoder_dim)
        hidden_state = (h.unsqueeze(0), c.unsqueeze(0))  # (1, b, decoder_dim) each

        out, hidden_state = self.rnn(rnn_input, hidden_state)  # out: (b, 1, decoder_dim)
        out = self.dropout(out)

        # The second LSTM starts from the first LSTM's final state.
        out, hidden_state = self.rnn2(out, hidden_state)  # out: (b, 1, decoder_dim)
        out = self.logsoftmax(self.out(out))  # (b, 1, n_class)

        h, c = hidden_state
        return out, (h.squeeze(0), c.squeeze(0))  # drop the layer dimension

In [13]:
class Image2LatexModel(nn.Module):
    """Image-to-LaTeX encoder-decoder.

    ``ConvEncoder`` turns an image into a sequence of feature vectors V, and
    ``Decoder`` produces one token distribution per step while attending over V.
    Training uses teacher forcing (``forward``); inference uses beam search
    (``decode`` / ``decode_beam_search``).
    """
    def __init__(
        self,
        n_class: int,
        embedding_dim: int = 80,
        encoder_dim: int = 512,
        decoder_dim: int = 512,
        attention_dim: int = 512,
        num_layers: int = 1,
        dropout: float = 0.1,
        bidirectional: bool = False,
        text: Text = None,
        beam_width: int = 5,
        sos_id: int = 1,
        eos_id: int = 2,
    ):
        super().__init__()
        self.encoder = ConvEncoder(encoder_dim=encoder_dim)
        self.decoder = Decoder(
            n_class=n_class,
            embedding_dim=embedding_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            attention_dim=attention_dim,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            sos_id=sos_id,
            eos_id=eos_id,
        )

        # Project the mean encoder feature to the decoder's initial (h, c).
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.n_class = n_class
        self.text = text  # only needed by decode(); may be None otherwise
        self.beam_width = beam_width
        self.criterion = nn.CrossEntropyLoss()

    def init_decoder_hidden_state(self, V: Tensor):
        """Derive the decoder's initial LSTM state from the encoder output.

        Instead of starting from zeros, the mean feature vector over all spatial
        positions is projected through ``init_h`` / ``init_c`` (+ tanh). This way
        the first decoding step is already conditioned on the image. It is the
        initialization used in "Show, Attend and Tell" (Xu et al., 2015).

        Args:
            V: encoder output, (bs, w*h, encoder_dim).
        Returns:
            (h, c), each (bs, decoder_dim).
        """
        encoder_mean = V.mean(dim=1)
        h = torch.tanh(self.init_h(encoder_mean))
        c = torch.tanh(self.init_c(encoder_mean))
        return h, c

    def forward(self, x: Tensor, y: Tensor, y_len: Tensor):
        """Teacher-forced decoding.

        Args:
            x: image batch, (bs, 3, H, W).
            y: target token ids, (bs, T); step t feeds ``y[:, t]`` to the decoder.
            y_len: (bs,) sequence lengths. Only ``max(y_len)`` is used, as the
                number of decoding steps, and it must be <= T.
        Returns:
            (bs, max(y_len), n_class) log-probabilities, one row per step.
        """
        encoder_out = self.encoder(x)
        hidden_state = self.init_decoder_hidden_state(encoder_out)

        predictions = []
        for t in range(y_len.max().item()):
            dec_input = y[:, t].unsqueeze(1)  # (bs, 1)
            out, hidden_state = self.decoder(dec_input, encoder_out, hidden_state)
            predictions.append(out.squeeze(1))

        return torch.stack(predictions, dim=1)

    def decode_beam_search(self, x, max_length=150):
        """Beam-search decode a single image.

        Call ``model.eval()`` first so the decoder's dropout is disabled.

        Args:
            x: batch holding exactly one image, (1, C, H, W).
            max_length: maximum number of decoding steps.
        Returns:
            list[int]: best-scoring token ids. The list starts with ``sos_id`` and
            ends with ``eos_id`` if the model emitted it within ``max_length`` steps.
        Raises:
            ValueError: if ``x`` holds more than one image.
        """
        if x.size(0) != 1:
            raise ValueError(f"decode_beam_search expects a single image (batch size 1), got {x.size(0)}")
        eos_id = self.decoder.eos_id

        # Inference only: token ids are extracted with .item(), so no gradients are needed.
        with torch.no_grad():
            encoder_out = self.encoder(x)
            hidden_state = self.init_decoder_hidden_state(encoder_out)

            # Each candidate: (token ids so far, decoder state after the last token, summed log-prob)
            list_candidate = [([self.decoder.sos_id], hidden_state, 0)]
            for _ in range(max_length):
                new_candidates = []
                for inp, state, log_prob in list_candidate:
                    if inp[-1] == eos_id:
                        # Finished hypothesis: keep it as-is instead of extending past EOS.
                        new_candidates.append((inp, state, log_prob))
                        continue
                    y = torch.tensor([[inp[-1]]], dtype=torch.long, device=x.device)  # (1, 1)
                    out, next_state = self.decoder(y, encoder_out, state)  # out: (1, 1, n_class)

                    topk = out.topk(self.beam_width)
                    for val, idx in zip(topk.values.view(-1), topk.indices.view(-1)):
                        new_candidates.append((inp + [idx.item()], next_state, log_prob + val.item()))

                new_candidates = sorted(new_candidates, key=lambda cand: cand[2], reverse=True)
                list_candidate = new_candidates[: self.beam_width]
                if all(inp[-1] == eos_id for inp, _, _ in list_candidate):
                    break

        return list_candidate[0][0]

    def decode(self, x, max_length=150):
        """Decode one image, (1, C, H, W), into a space-separated LaTeX token string.

        Requires the model to have been constructed with ``text``.
        """
        predict = self.decode_beam_search(x, max_length)
        return self.text.int2text(predict)

    def compute_loss(self, outputs, formulas_out):
        """
        Compute the CrossEntropy loss between predictions and targets.

        outputs: (batch, seq_len, vocab_size). These are log-probabilities from the
            decoder's LogSoftmax, not raw logits. CrossEntropyLoss re-applies
            log_softmax, which leaves log-probabilities unchanged, so this equals
            the NLL loss on them.
        formulas_out: (batch, seq_len) target token ids.

        NOTE(review): no ignore_index is set, so padding positions contribute to the loss.
        """
        bs, t, _ = outputs.size()
        return self.criterion(
            outputs.reshape(bs * t, -1),   # flatten predictions
            formulas_out.reshape(-1)       # flatten targets
        )

In [14]:

# Test 1: does the model instantiate? (n_class=100 is an arbitrary vocabulary size for this check.)
model = Image2LatexModel(n_class=100)
print(model)

Image2LatexModel(
  (encoder): ConvEncoder(
    (feature_encoder): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
      (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
      (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      (4): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
      (6): MaxPool2d(kernel_size=1, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (decoder): Decoder(
    (embedding): Embedding(100, 80)
    (attention): Attention(
      (decoder_attention): Linear(in_features=512, out_features=512, bias=False)
      (encoder_attention): Linear(in_features=512, out_features=512, bias=False)
      (attention): Linear(in_features=512, out_features=1, bias=False)
      (softmax): Softmax(dim=-1)
    )
    (co

In [19]:
# Test 2: forward pass on random data.
# This only checks that the shapes line up; the inputs are meaningless noise.

batch_size = 1
channels = 3   # ConvEncoder's first convolution expects 3 input channels
height = 64
width = 256

# Random "image" batch.
x = torch.randn(batch_size, channels, height, width)

# Random target token ids, plus random valid lengths in [1, seq_len].
y = torch.randint(0, 10, (batch_size, 50), dtype=torch.int64)
seq_len = y.size(1)
y_len = torch.randint(1, seq_len + 1, (batch_size,))

# forward decodes max(y_len) teacher-forced steps.
out = model(x, y, y_len)
assert out.shape == (batch_size, y_len.max().item(), model.n_class)
print(out.shape)

torch.Size([1, 13, 100])
