In [45]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision

import matplotlib.pyplot as plt
import pandas as pd

from PIL import Image, ImageOps
import os
import random

parametri/config:

In [46]:
# Root directory of the dataset; the split CSV files are read from here.
data_path = "data"

# Seed value. NOTE(review): nothing in the visible part of the notebook applies
# it to `random` / `torch` — confirm it is actually used for seeding.
RANDOM_STATE = 1219
# Presumably the number of training epochs (not used in the visible part).
N_EPOCHS = 50
# Batch size handed to the DataLoader (and read by the collate function).
BATCH_SIZE = 1024
# Presumably the optimizer learning rate (not used in the visible part).
LEARNING_RATE = 0.1
# Passed to the DataLoader as `num_workers`.
WORKERS = 0

definisanje funkcija za premestanje podataka na GPU, ako je dostupan:

In [47]:
def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def bind_gpu(data):
    """Move `data` to the device chosen by `get_device()`.

    A list or tuple is processed element by element (recursively) and always
    comes back as a list; anything else must provide `.to(device, ...)` and is
    moved with `non_blocking=True`.
    """
    target = get_device()
    if not isinstance(data, (list, tuple)):
        return data.to(target, non_blocking=True)
    return [bind_gpu(element) for element in data]

dataset klasa:

In [None]:
class LatexDataset(Dataset):
    """
    Dataset for the image-to-LaTeX project.

    Each item is ``(image, formula)``: ``image`` is whatever
    ``torchvision.io.read_image`` returns for the row's image file (passed
    through ``transform`` when one was supplied) and ``formula`` is the row's
    ``formula`` value from the split CSV.

    Args:
        data_type: which split to load; one of 'train', 'test', 'validate',
            'testtest'. The rows come from ``<data_path>/im2latex_<data_type>.csv``,
            which must have ``image`` and ``formula`` columns.
        transform: optional callable applied to every loaded image.
    """
    def __init__(self, data_type: str, transform=None):
        super().__init__()
        assert data_type in ['train', 'test', 'validate', 'testtest'], 'Not found data type'
        self.transform = transform

        csv_path = os.path.join(data_path, f'im2latex_{data_type}.csv')
        df = pd.read_csv(csv_path)
        # Rewrite the `image` column so it holds the full path to each file.
        # The image folder is resolved against the same `data_path` root as
        # the CSV instead of a second, hard-coded "data" literal.
        image_dir = os.path.join(data_path, "formula_images_processed")
        df['image'] = df.image.map(lambda x: os.path.join(image_dir, f'{x}'))
        # `walker` is a list with one dict per CSV row (column name -> value),
        # so `__getitem__` can look a row up by position.
        self.walker = df.to_dict('records')

    def __len__(self):
        """Number of rows in the split."""
        return len(self.walker)

    def __getitem__(self, idx):
        """Load the image of row `idx` and return it together with its formula."""
        item = self.walker[idx]

        formula = item['formula']
        image = torchvision.io.read_image(str(item['image']))
        # Previously the transform was stored but never used.
        if self.transform is not None:
            image = self.transform(image)

        return image, formula

In [49]:

# Build one dataset per CSV split; 'testtest' is kept separately as `test_test`.
train_set, val_set, test_set, test_test = (
    LatexDataset(split) for split in ('train', 'validate', 'test', 'testtest')
)

len(train_set), len(val_set), len(test_set)

(75275, 8370, 10355)

In [50]:
import re
import json
from torch import Tensor

class Text():
    """
    Vocabulary plus tokenizer that converts LaTeX formulas to token ids and back.

    Attributes:
        pad_id, sos_id, eos_id: ids of the padding / start / end tokens, fixed
            to 0 / 1 / 2. NOTE(review): this assumes the vocabulary file
            reserves its first three entries for them — confirm.
        id2word: the vocabulary as loaded from the JSON file (indexed by id).
        word2id: token -> id mapping derived from `id2word`.
        n_class: number of vocabulary entries.

    Args:
        vocab_path: JSON file holding the vocabulary; defaults to the path the
            notebook has always used.
    """
    # Alternatives, tried in this order at each position:
    #   1. a backslash followed by letters (a LaTeX command),
    #   2. any number of backslashes followed by one punctuation character
    #      from the listed class,
    #   3. a single word character,
    #   4. a lone backslash.
    # Written as a raw string so `\[`, `\]` and `\w` reach the regex engine
    # without triggering Python's invalid-escape warning; the compiled pattern
    # is the same as before. Compiled once for the class instead of per instance.
    TOKENIZE_PATTERN = re.compile(
        r'(\\[a-zA-Z]+)|((\\)*[$-/:-?{-~!"^_`\[\]])|(\w)|(\\)'
    )

    def __init__(self, vocab_path: str = "data/vocab/100k_vocab.json"):
        self.pad_id = 0
        self.sos_id = 1
        self.eos_id = 2

        # Context manager so the file handle is closed right after loading.
        with open(vocab_path, "r", encoding="utf-8") as vocab_file:
            self.id2word = json.load(vocab_file)
        self.word2id = {word: idx for idx, word in enumerate(self.id2word)}
        self.n_class = len(self.id2word)

    def int2text(self, x: Tensor):
        """Join the tokens of the ids in `x` with spaces, skipping every id <= eos_id."""
        return " ".join(self.id2word[int(i)] for i in x if int(i) > self.eos_id)

    def text2int(self, formula: str):
        """Tokenize `formula` and return its token ids as a LongTensor.

        Raises:
            KeyError: if a token is not in the vocabulary.
        """
        return torch.LongTensor([self.word2id[token] for token in self.tokenize(formula)])

    def tokenize(self, formula: str):
        """Split `formula` into tokens; characters matching no alternative (e.g. spaces) are dropped."""
        matches = (match.group(0) for match in self.TOKENIZE_PATTERN.finditer(formula))
        return [token for token in matches if token]

  "(\\\\[a-zA-Z]+)|" + '((\\\\)*[$-/:-?{-~!"^_`\[\]])|' + "(\w)|" + "(\\\\)"
  "(\\\\[a-zA-Z]+)|" + '((\\\\)*[$-/:-?{-~!"^_`\[\]])|' + "(\w)|" + "(\\\\)"


data module:

In [51]:
from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence

# Shared vocabulary/tokenizer instance; the DataLoader's collate function uses it.
text = Text()

def collate_fn(batch, text):
    """Turn a list of `(image, formula)` samples into two batch tensors.

    Args:
        batch: sequence of `(image, formula)` pairs; each image is a
            `(c, h, w)` tensor and each formula a string.
        text: object providing `text2int(formula)` plus the `pad_id`,
            `sos_id` and `eos_id` token ids.

    Returns:
        `(image, formulas)` where `image` is a float tensor of shape
        `(len(batch), c, max_h, max_w)` and `formulas` is a long tensor of
        shape `(len(batch), max_formula_len + 2)` laid out per row as
        `[sos, tokens..., eos, pad...]`.
    """
    sos = torch.tensor([text.sos_id], dtype=torch.long)
    eos = torch.tensor([text.eos_id], dtype=torch.long)
    # SOS/EOS are attached to every formula *before* padding, so EOS directly
    # follows the last real token (previously it was appended after the
    # padding). Working per sample also removes the dependency on the global
    # BATCH_SIZE, which broke any batch of a different size, e.g. the last
    # batch of an epoch.
    sequences = [torch.cat((sos, text.text2int(sample[1]), eos)) for sample in batch]
    formulas = pad_sequence(sequences, batch_first=True, padding_value=text.pad_id)

    images = [sample[0] for sample in batch]
    max_height = max(img.size(-2) for img in images)
    max_width = max(img.size(-1) for img in images)
    # NOTE(review): Resize stretches every image to the largest height/width in
    # the batch, changing its aspect ratio. The original variable was named
    # `pad`, so zero/border padding may have been intended — confirm.
    resize = torchvision.transforms.Resize(size=(max_height, max_width))
    image = torch.stack([resize(img) for img in images]).to(dtype=torch.float)
    return image, formulas


In [52]:
# Shuffled loader over the 'testtest' split; the lambda binds the shared
# `text` instance as the second argument of `collate_fn`.
test_test_loader = DataLoader(
    test_test,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
    collate_fn=lambda batch: collate_fn(batch, text),
)

primer:

In [53]:
# Sanity check: pull a single batch from the loader and print its shapes.
for images, formulas in test_test_loader:
    print(images.shape)      # [batch, 3, H, W]; batch is BATCH_SIZE (1024) here, not 32
    print(formulas.shape)    # [batch, max_formula_len + 2]; +2 for the SOS and EOS tokens
    break

torch.Size([1024, 3, 160, 480])
torch.Size([1024, 193])


### Enkoder
* input je slika s 3 kanala (RGB)
* output feature map ima `enc_dim` kanala
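* `nn.MaxPool2d(2, 1)` -> pozicioni argumenti su `kernel_size, stride`, pa je ovo `kernel_size=2, stride=1` (a `nn.MaxPool2d(1, 2)` je `kernel_size=1, stride=2`); asimetricni pooling, koji se praktikuje za slike teksta jer je tekst vise sirok nego visok, pisao bi se kao `nn.MaxPool2d((2, 1))` i `nn.MaxPool2d((1, 2))`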

##### Forward pass
input tenzor za forward pass: `x: (bs, c, h, w)`

`bs = batch size`

`c = number of channels`

`w = image width`

`h = image height`

1. enkodovati `x(bs, c_in, h_in, w_in) -> (bs, c_out, h_out, w_out)`
    * `c_out` je `enc_dim`
    * `h_out` i `w_out` su manji od dimenzija inputa (konvolucije bez paddinga)
2. permutovati -> da bi feature vector bio poslednji
3. flattenovati
* `encoder_out.shape = (bs, sequence_length = h_out * w_out, d = enc_dim)`

In [54]:
from torch import nn, Tensor

class ConvEncoder(nn.Module):
    """Convolutional encoder that turns an image batch into a sequence of feature vectors.

    Six 3x3 convolutions (stride 1) take the 3 input channels to `enc_dim`
    channels, with two max-pooling layers in between.

    NOTE(review): `MaxPool2d(2, 1)` / `MaxPool2d(1, 2)` are positional
    `(kernel_size, stride)`, i.e. kernel 2 / stride 1 and kernel 1 / stride 2.
    If asymmetric kernels were intended they would be `MaxPool2d((2, 1))` and
    `MaxPool2d((1, 2))` — confirm which is wanted.
    """

    def __init__(self, enc_dim: int):
        super().__init__()
        self.enc_dim = enc_dim
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1),
            nn.Conv2d(64, 128, kernel_size=3, stride=1),
            nn.Conv2d(128, 256, kernel_size=3, stride=1),
            nn.Conv2d(256, 256, kernel_size=3, stride=1),
            nn.MaxPool2d(kernel_size=2, stride=1),
            nn.Conv2d(256, 512, kernel_size=3, stride=1),
            nn.MaxPool2d(kernel_size=1, stride=2),
            nn.Conv2d(512, enc_dim, kernel_size=3, stride=1),
        ]
        self.feature_encoder = nn.Sequential(*layers)

    def forward(self, x: Tensor):
        """Encode `x` and flatten the spatial grid into a sequence.

        Args:
            x: image batch of shape (bs, 3, h, w).

        Returns:
            Tensor of shape (bs, h_out * w_out, enc_dim).
        """
        features = self.feature_encoder(x)           # (bs, enc_dim, h_out, w_out)
        features = features.permute(0, 2, 3, 1)      # (bs, h_out, w_out, enc_dim)
        batch_size, _, _, feature_dim = features.size()
        # Merge the two spatial axes; the channel vector stays last.
        return features.view(batch_size, -1, feature_dim)