In [1]:
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizerFast

# The project packages imported below (utils, data, model) are resolved via the
# parent directory, so the path must be extended *before* importing them;
# appending it afterwards only works if the kernel already happens to see them.
sys.path.append("../")

from data.dataset import Im2LatexDataset  # noqa: E402
from model.vit import ViT  # noqa: E402
from utils.inkml2img import convert_dir  # noqa: E402

warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


## Testing the dataset and model modules


In [2]:
# Smoke-test the project dataset: build it in classification mode and take the
# first item yielded by its test split (presumably one batch -- TODO confirm).
img_dims = [224, 640]  # presumably (height, width) -- TODO confirm against Im2LatexDataset
data = Im2LatexDataset(
    path_to_data="../data/",
    tokenizer="../data/tokenizer.json",
    img_dims=img_dims,
    classification=True,
)
imgs, labels = next(iter(data.test))

In [None]:
# Build a classifier-mode ViT (encoder=False, 10 output classes) and push a
# single image through it to inspect the output shape.
model = ViT(
    img_dims,
    16,  # presumably the patch size -- TODO confirm against model.vit.ViT
    n_embd=512,
    encoder=False,
    output_classes=10,
)
model(imgs[0][None]).shape

In [None]:
# Parameter count of the ViT built above, as reported by the project's
# ViT.get_num_params (model.vit); which parameters it counts is not visible here.
model.get_num_params()

In [None]:
# Pick an accelerator when one is available: Apple-silicon MPS first, then an
# NVIDIA GPU, otherwise fall back to the CPU.
# NOTE: torch has no "gpu" device type -- NVIDIA GPUs are addressed as "cuda".
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

device

## Preparing the data


In [None]:
PATH_TO_DATA = "../data/"
PATH_TO_HANDWRITTEN = PATH_TO_DATA + "handwritten/"


def load_handwritten(split, csv_name):
    """Return the index DataFrame for one handwritten (InkML) split.

    The CSV under PATH_TO_DATA acts as a cache: when it is missing it is built
    once with ``convert_dir`` (input and output directory both
    ``PATH_TO_HANDWRITTEN + split``, exactly as the former one-off conversion
    did); on every later run it is simply read back.

    split: sub-directory of PATH_TO_HANDWRITTEN, e.g. "train" or "test".
    csv_name: file name of the cached CSV inside PATH_TO_DATA.
    """
    import os

    csv_path = PATH_TO_DATA + csv_name
    if not os.path.exists(csv_path):
        split_dir = PATH_TO_HANDWRITTEN + split
        convert_dir(split_dir, split_dir).to_csv(csv_path, index=False)
    # Always read back from disk so a fresh conversion and a cached run yield
    # identically-typed frames.
    return pd.read_csv(csv_path)


train_handwritten_df = load_handwritten("train", "train_handwritten.csv")
val_handwritten_df = load_handwritten("test", "val_handwritten.csv")
val_handwritten_df.head()

In [None]:
def fix_path(path):
    """Prefix an im2latex image file name with the local image directory."""
    return PATH_TO_DATA + "images/" + path


train_df = pd.read_csv(PATH_TO_DATA + "im2latex_train.csv")
val_df = pd.read_csv(PATH_TO_DATA + "im2latex_validate.csv")
test_df = pd.read_csv(PATH_TO_DATA + "im2latex_test.csv")

# The im2latex CSVs list image file names relative to PATH_TO_DATA/images/.
for df in (train_df, val_df, test_df):
    df["image"] = df["image"].map(fix_path)

print(f"train len before {len(train_df)}")

# Append the handwritten samples. Their image column is NOT passed through
# fix_path -- presumably it already holds usable paths (verify).
# ignore_index=True gives a clean 0..n-1 index instead of duplicate row labels,
# which would make any label-based lookup (.loc) ambiguous.
train_df = pd.concat([train_df, train_handwritten_df], ignore_index=True)
val_df = pd.concat([val_df, val_handwritten_df], ignore_index=True)

print(f"train length after {len(train_df)}")

train_df.head()

In [None]:
def get_train_equations(equations=None, path="../data/train_equations.txt"):
    """Write formulas to a text file, one per line, as a tokenizer training corpus.

    equations: iterable of formulas; each value is written via ``str()``.
        Defaults to the ``formula`` column of the global ``train_df``
        (the original, argument-free behaviour).
    path: output file; overwritten if it already exists.
    """
    if equations is None:
        equations = train_df["formula"]
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # cannot encode every character that may appear in LaTeX sources.
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(str(value) + "\n" for value in equations)


# Run once to (re)create the tokenizer training corpus:
# get_train_equations()

In [None]:
import os


def generate_tokenizer(equations, output, vocab_size):
    """Train a byte-level BPE tokenizer on LaTeX formulas and save it as JSON.

    equations: path to a text file of training formulas (one per line, e.g. as
        written by ``get_train_equations``), or an iterable of such paths.
    output: destination path of the serialized tokenizer (``tokenizer.json``).
    vocab_size: target vocabulary size, including the special tokens.
    """
    from tokenizers import Tokenizer, decoders, pre_tokenizers
    from tokenizers.models import BPE
    from tokenizers.trainers import BpeTrainer

    if isinstance(equations, (str, os.PathLike)):
        files = [os.fspath(equations)]
    else:
        files = [os.fspath(p) for p in equations]

    tokenizer = Tokenizer(BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    # Without the matching decoder, decode() returns the byte-level surrogate
    # characters (e.g. "Ġ" for a space) instead of the original LaTeX text.
    tokenizer.decoder = decoders.ByteLevel()
    trainer = BpeTrainer(
        special_tokens=["[PAD]", "[BOS]", "[EOS]"],
        vocab_size=vocab_size,
        show_progress=True,
        # Seed the vocabulary with all 256 byte symbols so characters absent from
        # the training corpus still map to a token (BPE() here has no unk token).
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    tokenizer.train(files, trainer)
    tokenizer.save(path=output, pretty=False)


# Run once to (re)build the tokenizer from the training corpus:
# generate_tokenizer('../data/train_equations.txt' ,'../data/tokenizer.json', 8000)

# Register the special tokens trained above so padding and BOS/EOS lookups work
# through the transformers API (otherwise padding=True raises for a missing pad token).
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="../data/tokenizer.json",
    pad_token="[PAD]",
    bos_token="[BOS]",
    eos_token="[EOS]",
)

In [None]:
# Sanity check: encode a representative LaTeX formula with the loaded tokenizer.
tokenizer([r"\frac{a}{b} + x^{2}"])

In [None]:
class ImagesDataset(Dataset):
    """Map-style dataset pairing formula images with their LaTeX source.

    image_paths: pandas Series (indexed by position) or plain sequence of image paths.
    formulas: Series/sequence of formulas aligned element-for-element with image_paths.
    transform: optional callable applied to the loaded PIL image.

    Raises ValueError if image_paths and formulas differ in length.
    """

    def __init__(self, image_paths, formulas, transform=None):
        if len(image_paths) != len(formulas):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(formulas)} formulas"
            )
        self.image_paths = image_paths
        self.transform = transform
        self.formulas = formulas

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _at(seq, index):
        """Return element ``index`` by position for both Series (.iloc) and plain sequences."""
        return seq.iloc[index] if hasattr(seq, "iloc") else seq[index]

    def __getitem__(self, index):
        image_path = self._at(self.image_paths, index)
        # Image.open is lazy and keeps the file handle open until the image is
        # closed; copy() forces the pixel data to load so the handle is released
        # here instead of leaking one open file per sample.
        with Image.open(image_path) as img:
            image = img.copy()
        formula = self._at(self.formulas, index)

        if self.transform:
            image = self.transform(image)

        return image, formula


# Resize to the same (height, width) the ViT is built with (img_dims). Both
# sides must be divisible by the patch size (16 in this notebook's ViT /
# PatchEmbeddings calls); the former width of 600 was not, and disagreed with
# img_dims = [224, 640].
transform = transforms.Compose(
    [
        transforms.Resize(tuple(img_dims)),  # Resize to the model input size
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # Convert to tensor
    ]
)

In [None]:
# HYPERPARAMETERS
# -------------------------------------------------
BATCH_SIZE = 32
# -------------------------------------------------


def _make_dataset(frame):
    """Wrap a DataFrame's image/formula columns in an ImagesDataset using the shared transform."""
    return ImagesDataset(frame["image"], frame["formula"], transform=transform)


train_dataset = _make_dataset(train_df)
val_dataset = _make_dataset(val_df)
test_dataset = _make_dataset(test_df)

# Only the training split is shuffled; evaluation splits keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
im, label = next(iter(train_dataloader))

# Show channel 0 of the first image in the batch, next to its formula.
fig, ax = plt.subplots()
ax.imshow(im[0, 0], cmap="gray")
label[0]

## Writing the ViT building blocks


In [None]:
# transformer block with skip connections and layernorm (pre-norm)
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: x + Attn(LN(x)), then x + MLP(LN(x)).

    n_embd: embedding width of the token sequence.
    num_heads: number of attention heads (must divide n_embd).
    dropout: rate used for the attention weights, the attention output
        projection and the MLP output.
    mlp_ratio: hidden width of the MLP as a multiple of n_embd.
    """

    def __init__(self, n_embd, num_heads, dropout=0.20, mlp_ratio=4):
        super().__init__()
        self.embed_dim = n_embd
        self.num_heads = num_heads
        self.dropout = dropout

        self.ln_1 = torch.nn.LayerNorm(n_embd)
        # Pass the rates by keyword: SelfAttention's third positional parameter
        # is `bias`, so passing `dropout` positionally enabled biases and left
        # both dropout rates at their defaults.
        self.attention = SelfAttention(
            n_embd, num_heads, attn_dropout=dropout, proj_dropout=dropout
        )
        self.ln_2 = torch.nn.LayerNorm(n_embd)
        # Position-wise feed-forward network (kept under the attribute name `head`).
        # A lone Linear layer here would leave the block linear apart from the
        # attention softmax.
        self.head = nn.Sequential(
            nn.Linear(n_embd, mlp_ratio * n_embd),
            nn.GELU(),
            nn.Linear(mlp_ratio * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """x: (B, T, n_embd) -> (B, T, n_embd)."""
        x = x + self.attention(self.ln_1(x))
        x = x + self.head(self.ln_2(x))
        return x


class PatchEmbeddings(nn.Module):
    """Split images into non-overlapping square patches and linearly embed each one.

    Implemented as a Conv2d whose kernel size and stride both equal the patch
    size, i.e. one shared linear projection applied to every patch.
    """

    def __init__(self, img_size, patch_size, channels=1, embed_dim=512):
        """
        img_size: (height, width) of the input images, or a single int for square images.
        patch_size: side length of each square patch; must divide both height and width.
        channels: number of input image channels.
        embed_dim: size of each patch embedding.
        """
        super().__init__()
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        h, w = img_size
        assert (
            w % patch_size == 0 and h % patch_size == 0
        ), "image not divisable by patch size"
        self.patch_size = patch_size
        self.n_patches = (h // patch_size) * (w // patch_size)

        self.projection = nn.Conv2d(
            channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        x: batch of images (B, channels, H, W)
        return: (B, n_patches, embed_dim), patches in row-major order
        """
        x = self.projection(x)  # (B, embed_dim, H / patch_size, W / patch_size)
        x = x.flatten(-2)  # (B, embed_dim, n_patches)
        x = x.transpose(-2, -1)  # (B, n_patches, embed_dim)
        return x


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no masking).

    n_embd: token embedding width; must be divisible by n_heads.
    n_heads: number of attention heads.
    bias: whether the fused k/q/v projection and the output projection have biases.
    attn_dropout: dropout applied to the attention weights.
    proj_dropout: dropout applied after the output projection.
    """

    def __init__(
        self, n_embd, n_heads=8, bias=False, attn_dropout=0.20, proj_dropout=0.20
    ):
        super().__init__()
        assert n_embd % n_heads == 0, "n_embd not divisible by num heads"
        self.n_heads = n_heads
        self.n_embd = n_embd
        self.head_dim = n_embd // n_heads
        self.dk = self.head_dim**-0.5  # 1 / sqrt(d_k) scaling factor

        self.kqv = nn.Linear(n_embd, n_embd * 3, bias=bias)
        self.projection = nn.Linear(n_embd, n_embd, bias=bias)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj_dropout = nn.Dropout(proj_dropout)

    def forward(self, x):
        """
        x: (B, T, n_embd)
        returns: (B, T, n_embd)
        """
        B, T, C = x.shape  # batch size, num tokens, n_embd
        assert C == self.n_embd, "input size does not equal n_embd"

        kqv = self.kqv(x)  # (B, T, 3 * n_embd)
        kqv = kqv.reshape(B, T, 3, self.n_heads, self.head_dim)
        kqv = kqv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, T, head_dim)
        k, q, v = kqv  # each (B, n_heads, T, head_dim)

        # (B, n_heads, T, head_dim) @ (B, n_heads, head_dim, T) -> (B, n_heads, T, T)
        attention = (q @ k.transpose(-1, -2)) * self.dk
        attention = attention.softmax(dim=-1)
        attention = self.attn_dropout(attention)

        # (B, n_heads, T, T) @ (B, n_heads, T, head_dim) -> (B, n_heads, T, head_dim)
        aggregated_attention = attention @ v
        x = aggregated_attention.transpose(1, 2)  # (B, T, n_heads, head_dim)
        x = x.flatten(2)  # (B, T, n_embd): heads concatenated
        x = self.projection(x)  # (B, T, n_embd)
        x = self.proj_dropout(x)
        return x


# Smoke test: one 64x64 single-channel image -> 4x4 grid of 16x16 patches
# embedded at width 512 -> one transformer block; the block must keep the shape.
dummy_img = torch.ones(1, 1, 64, 64)
patch = PatchEmbeddings([64, 64], 16)
patches = patch(dummy_img)  # (1, 16, 512)
block = TransformerBlock(512, 8)
out = block(patches)
assert out.shape == patches.shape == (1, patch.n_patches, 512)
out.shape