In [1]:
!nvidia-smi

Mon Jun  8 05:36:02 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.82       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   60C    P8    11W /  70W |      0MiB / 15079MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [2]:
from google.colab import drive
drive.mount('/content/gdrive')
%cd /content/gdrive/My\ Drive/Colab/ITSP

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
/content/gdrive/My Drive/Colab/ITSP


In [3]:
!ls

build_vocab.py	data.py		  model		    __pycache__  train.py
ckpts		evaluate.py	  preprocess.ipynb  results	 utils.py
data		evaluation.ipynb  preprocess.py     train.ipynb  vocab.ipynb


In [4]:
! pip install distance



In [0]:
# load checkpoint and evaluating
from os.path import join
from functools import partial
import argparse

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import Im2LatexDataset
from build_vocab import Vocab, load_vocab
from utils import collate_fn
from model import LatexProducer, Im2LatexModel
from model.score import score_files

In [12]:
from argparse import Namespace

args = Namespace(
    model_path = "ckpts/best_ckpt.pt",

    # model args
    data_path = "./data/",
    cuda = True,
    batch_size = 8,
    beam_size = 5,
    result_path = "./results/result.txt",
    ref_path = "./results/ref.txt",
    max_len = 64,
    split = "validate"
)

args

Namespace(batch_size=8, beam_size=5, cuda=True, data_path='./data/', max_len=64, model_path='ckpts/best_ckpt.pt', ref_path='./results/ref.txt', result_path='./results/result.txt', split='validate')

In [13]:
# Loading Model
checkpoint = torch.load(join(args.model_path))
model_args = checkpoint['args']

# Read the dictionary and set other related parameters
vocab = load_vocab(args.data_path)
use_cuda = True if args.cuda and torch.cuda.is_available() else False

Load vocab including 298 words!


In [14]:
# Load test set
data_loader = DataLoader(
    Im2LatexDataset(args.data_path, args.split, args.max_len),
    batch_size=args.batch_size,
    collate_fn=partial(collate_fn, vocab.sign2id),
    pin_memory=True if use_cuda else False,
    num_workers=4
)

model = Im2LatexModel(
    len(vocab), model_args.emb_dim, model_args.dec_rnn_h,
    add_pos_feat=model_args.add_position_features,
    dropout=model_args.dropout
)
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [0]:
result_file = open(args.result_path, 'w')
ref_file = open(args.ref_path, 'w')

In [16]:
latex_producer = LatexProducer(
    model, vocab, max_len=args.max_len,
    use_cuda=use_cuda, beam_size=args.beam_size)

for imgs, tgt4training, tgt4cal_loss in tqdm(data_loader):
    try:
        reference = latex_producer._idx2formulas(tgt4cal_loss)
        results = latex_producer(imgs)
    except RuntimeError:
        break

    result_file.write('\n'.join(results))
    ref_file.write('\n'.join(reference))

result_file.close()
ref_file.close()
score = score_files(args.result_path, args.ref_path)
print("beam search result:", score)

100%|██████████| 188/188 [00:42<00:00,  4.38it/s]


Loaded 1255 formulas from ./results/result.txt
Loaded 1255 formulas from ./results/ref.txt
beam search result: {'BLEU-4': 1.7542435826729157, 'EM': 0.0, 'Edit': 14.77458992228865}
