In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision

import matplotlib.pyplot as plt
import pandas as pd

from PIL import Image, ImageOps
import os
import random

parametri/config:

In [2]:
data_path = "data"  # dataset root directory; the im2latex_*.csv files are read from here

RANDOM_STATE = 1219
N_EPOCHS = 50
BATCH_SIZE = 1024
LEARNING_RATE = 0.1

# Seed Python's and PyTorch's RNGs up front so a fresh Restart & Run All is reproducible.
random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)  # seeds the generators on all devices (CPU and CUDA)

Definisanje funkcija za premestanje podataka na GPU, ako je dostupan:

In [3]:
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def bind_gpu(data, device=None):
    """Move `data` to `device` (default: `get_device()`), recursing into lists/tuples.

    Args:
        data: a tensor (or anything with a ``.to`` method), or a (nested) list/tuple of them.
        device: target device; resolved once via ``get_device()`` when omitted.

    Returns:
        The moved object; lists and tuples come back as lists. Leaves without a
        ``.to`` method (e.g. batched formula strings) are returned unchanged instead
        of raising AttributeError.
    """
    if device is None:
        device = get_device()  # resolve once instead of on every recursive call
    if isinstance(data, (list, tuple)):
        return [bind_gpu(elem, device) for elem in data]
    if not hasattr(data, "to"):
        return data
    # non_blocking only overlaps the copy with compute when the source is in pinned memory.
    return data.to(device, non_blocking=True)

dataset klasa:

In [4]:
class LatexDataset(Dataset):
    """One im2latex split, yielding ``(image, formula)`` pairs.

    Rows come from ``<root>/im2latex_<data_type>.csv``, which must have an ``image``
    column (file name inside ``<root>/formula_images_processed``) and a ``formula`` column.
    """

    def __init__(self, data_type: str, transform=None, root=None):
        """
        Args:
            data_type: one of 'train', 'test', 'validate'; selects which CSV is loaded.
            transform: optional callable applied to the image tensor in ``__getitem__``.
            root: dataset directory; defaults to the notebook-level ``data_path``.
        """
        super().__init__()
        assert data_type in ['train', 'test', 'validate'], 'Not found data type'
        self.transform = transform
        # Look up the global lazily so an explicit `root` works without `data_path` defined.
        root = data_path if root is None else root

        csv_path = os.path.join(root, f'im2latex_{data_type}.csv')
        df = pd.read_csv(csv_path)
        # Turn the bare file name in `image` into a full path. Use `root` (not a hard-coded
        # "data") so the CSVs and the images are always read from the same directory.
        images_dir = os.path.join(root, "formula_images_processed")
        df['image'] = df['image'].map(lambda name: os.path.join(images_dir, str(name)))
        # `walker` is the list of CSV rows as dicts, one per sample, e.g.
        # {'formula': ..., 'image': <full path>, ...}; __getitem__ indexes into it.
        self.walker = df.to_dict('records')

    def __len__(self):
        """Number of samples (CSV rows) in this split."""
        return len(self.walker)

    def __getitem__(self, idx):
        """Return ``(image, formula)`` for sample ``idx``, with ``transform`` applied to the image if set."""
        item = self.walker[idx]
        image = self._read_image(str(item['image']))
        if self.transform is not None:
            image = self.transform(image)
        return image, item['formula']

    def _read_image(self, path):
        """Decode the image file at `path` into a tensor forced to 3 (RGB) channels."""
        # RGB mode guarantees 3 channels (the encoder in this notebook takes RGB input)
        # whether a PNG is stored as grayscale, RGB or RGBA.
        return torchvision.io.read_image(path, mode=torchvision.io.ImageReadMode.RGB)

In [5]:

# Build the three splits in a fixed order (train, validate, test), then show their sizes.
train_set, val_set, test_set = [LatexDataset(split) for split in ('train', 'validate', 'test')]

len(train_set), len(val_set), len(test_set)

(75275, 8370, 10355)

### Enkoder
* input je slika s 3 kanala (RGB)
* output feature map ima `enc_dim` kanala
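* `nn.MaxPool2d(2, 1)` je pooling sa kernelom 2×2 i stride-om 1, dakle simetrican (smanjuje i visinu i sirinu za 1), a ne asimetrican. Asimetricni pooling, koji se praktikuje za slike teksta jer je tekst vise sirok nego visok, zahteva tuple argumente, npr. `nn.MaxPool2d((2, 1), (2, 1))` (prepolovi visinu, a sirinu ostavi).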

##### Forward pass
input tenzor za forward pass: `x: (bs, c, h, w)` (PyTorch konvencija NCHW)

`bs = batch size`

`c = number of channels`

`w = image width`

`h = image height`

1. enkodovati `x(bs, c_in, h_in, w_in) -> (bs, c_out, h_out, w_out)`
    *  `c_out` je `enc_dim`
    * `h_out` i `w_out` su manji od dimenzija inputa (konvolucije bez paddinga i pooling smanjuju prostorne dimenzije)
2. permutovati -> da bi feature vector (kanali) bio poslednja dimenzija
3. flattenovati prostorne dimenzije `h_out` i `w_out` u jednu
* `encoder_out.shape = (bs, sequence_length = h_out * w_out, d = enc_dim)`

In [6]:
from torch import nn, Tensor

class ConvEncoder(nn.Module):
    """CNN feature extractor that turns an image batch into a sequence of feature vectors.

    Args:
        enc_dim: number of output channels, i.e. the size ``d`` of each feature vector.
        in_channels: number of input image channels (3 for RGB, the default).
    """

    def __init__(self, enc_dim: int, in_channels: int = 3):
        super().__init__()
        # All convolutions are unpadded (each 3x3 conv shrinks H and W by 2). A ReLU follows
        # every conv: without a nonlinearity in between, consecutive convolutions compose
        # into a single linear convolution and the extra depth adds nothing.
        self.feature_encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1),
            nn.ReLU(inplace=True),
            # kernel 2x2, stride 1: shrinks H and W by 1 each, i.e. symmetric.
            # NOTE(review): if asymmetric pooling for wide text images was intended,
            # use tuple arguments, e.g. nn.MaxPool2d((2, 1), (2, 1)).
            nn.MaxPool2d(2, 1),
            nn.Conv2d(256, 512, 3, 1),
            nn.ReLU(inplace=True),
            # kernel 1, stride 2: plain 2x subsampling of both H and W.
            nn.MaxPool2d(1, 2),
            nn.Conv2d(512, enc_dim, 3, 1),
            nn.ReLU(inplace=True),
        )
        self.enc_dim = enc_dim

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: image batch of shape (bs, c, h, w) (PyTorch NCHW layout), c == in_channels.

        Returns:
            Tensor of shape (bs, h_out * w_out, enc_dim): one feature vector per spatial
            position of the final feature map, ordered row by row (h major, then w).
        """
        features = self.feature_encoder(x)          # (bs, enc_dim, h_out, w_out)
        features = features.permute(0, 2, 3, 1)     # (bs, h_out, w_out, enc_dim)
        return features.flatten(1, 2)               # (bs, h_out * w_out, enc_dim)