In [7]:
import sys
sys.path.append("../")
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from utils.inkml2img import convert_dir
from tqdm.auto import tqdm
from data.dataset import Im2LatexDataset
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from model.vit import ViT
import lightning as L
import wandb
from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer, pre_tokenizers
from tokenizers.models import BPE
from model.decoder import Decoder, DecoderAttention, DecoderTransformerBlock
from model.model import Img2MathModel


In [8]:
# Load the Im2Latex data, build the model, and run one forward pass as a smoke test.
img_dims = [256, 256]
# Use Apple-silicon MPS when present, otherwise fall back to CPU instead of failing.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
data = Im2LatexDataset(path_to_data="../data/",
                       tokenizer="../data/tokenizer.json", img_dims=img_dims, batch_size=1, device=device)
img, label = next(iter(data.train))
vocab_size = len(data.tokenizer.get_vocab())
# NOTE(review): positional hyperparameters (512, 256, .75, 16) are defined by Img2MathModel in
# model/model.py and their meanings are not visible here; consider passing them by keyword.
model = Img2MathModel(512, 256, vocab_size, .75, img_dims, 16)
# NOTE(review): input_seq and trg_seq are the same unshifted ids here, so the decoder is presumably
# fed the very tokens it is scored on. Treat the returned loss as a shape/wiring check only, not as
# the teacher-forced training objective (which needs the input shifted right by one position).
model(img[0], input_seq=label['input_ids'], trg_seq=label['input_ids'], mask=label['attention_mask'])

(tensor([[[ 0.3815,  0.5246,  1.2717,  ..., -1.0268,  1.2818,  0.3260],
          [-0.4543,  0.6675,  1.1017,  ..., -0.0573,  0.1573,  0.5705],
          [ 0.1877,  0.3306,  0.9758,  ..., -0.3958,  0.4867, -0.7858],
          ...,
          [ 0.0654, -0.3791,  0.2006,  ...,  1.1883, -0.4036,  0.1054],
          [ 0.1864,  0.5316,  0.6864,  ..., -0.7703,  0.4862, -0.4997],
          [-0.4003,  0.3741,  0.8077,  ...,  0.0701, -0.6029,  0.5063]]],
        grad_fn=<ViewBackward0>),
 tensor(7.8465, grad_fn=<NllLossBackward0>))

In [13]:
def shift_right(seq: torch.Tensor, fill_value: int = 0) -> torch.Tensor:
    """Shift `seq` one step right along its last dimension.

    Prepends `fill_value` and drops the final position, so the length is unchanged.
    The prepended column is built with `full_like`, so it has the same dtype, device and
    batch size as `seq`. This keeps integer ids integral and works for any batch size.

    Args:
        seq: tensor of shape (..., seq_len).
        fill_value: value placed in the new first position.

    Returns:
        Tensor of the same shape and dtype as `seq`.
    """
    start = torch.full_like(seq[..., :1], fill_value)
    return torch.cat((start, seq[..., :-1]), dim=-1)


# Teacher forcing: the decoder input is the target shifted right by one position.
batch = next(iter(data.train))
img, labels = batch
trg_seq, trg_mask = labels['input_ids'], labels['attention_mask']
# NOTE(review): assumes token id 0 is the start-of-sequence token — confirm against tokenizer.json.
input_seq = shift_right(trg_seq).int()
# NOTE(review): the start position gets mask value 0. If the decoder follows the HF convention
# (1 = attend, 0 = padding) the start token is masked out; confirm how Decoder consumes `mask`.
input_mask = shift_right(trg_mask).int()
model(img[0], input_seq=input_seq, trg_seq=trg_seq, mask=input_mask)

(tensor([[[ 0.1854, -1.0171,  1.1940,  ..., -0.5155, -0.2047, -0.7936],
          [-0.1163,  0.2031,  1.0163,  ..., -0.7860,  0.8499, -0.9828],
          [-0.1543,  0.6852,  0.7890,  ...,  1.0069,  1.0043, -0.1216],
          ...,
          [ 0.4720, -0.0263,  0.6055,  ...,  0.2206, -0.2633, -0.3533],
          [-0.2364, -0.2472,  0.4905,  ...,  0.2950,  0.4586, -0.2331],
          [-0.2259,  1.5125,  1.0990,  ...,  0.0903,  0.7188, -0.4231]]],
        grad_fn=<ViewBackward0>),
 tensor(7.9361, grad_fn=<NllLossBackward0>))

In [6]:
# Train the model with Weights & Biases logging.
# NOTE(review): the test split is passed as the validation loader (it shows up as val/loss);
# confirm whether Im2LatexDataset exposes a separate validation split.
MAX_TRAIN_BATCHES = 10_000  # Trainer(limit_train_batches=...) cap on training batches per epoch
SEED = 42

# deterministic=True only selects deterministic kernels; without a seed, data shuffling and
# dropout still differ between runs.
L.seed_everything(SEED, workers=True)

# Same device rule as the data setup: MPS when available, otherwise CPU.
accelerator = 'mps' if torch.backends.mps.is_available() else 'cpu'

logger = WandbLogger(project='img2math')
trainer = L.Trainer(
    limit_train_batches=MAX_TRAIN_BATCHES,
    max_epochs=1,
    log_every_n_steps=20,
    deterministic=True,
    logger=logger,
    accelerator=accelerator,
)

try:
    trainer.fit(model, train_dataloaders=data.train, val_dataloaders=data.test)
finally:
    # Always close the W&B run, even when fit() raises, so the run is finalized and uploaded.
    wandb.finish()


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type    | Params
------------------------------------
0 | encoder | ViT     | 8.9 M 
1 | decoder | Decoder | 14.9 M
------------------------------------
23.9 M    Trainable params
0         Non-trainable params
23.9 M    Total params
95.443    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/loss,▅▁▄▃▆▇▅▆▆▅▄▅▃▃▂▂▂▅▄▅▂▃▄▇▃▆▅▃█▃▇▄▁▄▇▅▅█▅▆

0,1
epoch,0.0
trainer/global_step,5039.0
val/loss,1.25875


torch.Size([1])