## Check image sizes

In [9]:
import os
import imagesize
# Collect (filename, (width, height)) for every file in train/ whose name ends in 'jpg',
# using imagesize.get for the dimensions.
shapes = [
    (name, imagesize.get('train/' + name))
    for name in filter(lambda n: n.endswith('jpg'), os.listdir('train'))
]

In [None]:
# Plot the joint distribution of image widths and heights.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Element 1 of every record is the (width, height) pair from imagesize.get,
# so shapes_np is an (N, 2) array: column 0 = width, column 1 = height.
shapes_np = np.array([record[1] for record in shapes])
sns.jointplot(x=shapes_np[:, 0], y=shapes_np[:, 1], kind='scatter').set_axis_labels('width', 'height')


In [3]:
# Order the records by pixel area (width * height), smallest first.
shapes = sorted(shapes, key=lambda rec: rec[1][0] * rec[1][1])

In [4]:
import shutil

In [5]:
# Copy the first 10 and last 10 records' images (the smallest and largest by
# area, assuming the sort above has run) into train_small/.
# Create the target directory first: shutil.copy raises FileNotFoundError if it
# is missing. Index the record rather than unpacking `fn, _`, so this also
# works once records carry an area field and does not rebind IPython's `_`.
os.makedirs('train_small', exist_ok=True)
for rec in shapes[:10] + shapes[-10:]:
    shutil.copy('train/' + rec[0], 'train_small/' + rec[0])

In [None]:
# Show the first 10 and last 10 records side by side.
[*shapes[:10], *shapes[-10:]]

In [7]:
# Append the pixel area to every record: (filename, (width, height), area).
# Index rather than unpack so that re-running this cell on records that already
# carry an area is a no-op instead of a ValueError.
shapes = [(rec[0], rec[1], rec[1][0] * rec[1][1]) for rec in shapes]

In [None]:
# First 10 records whose pixel area exceeds 5,000.
list(filter(lambda rec: rec[2] > 5e3, shapes))[:10]

In [9]:
# Copy the images of the first 10 records with area > 5,000 pixels into train_small/.
# Create the target directory first: shutil.copy raises FileNotFoundError if it
# is missing. Indexing the record avoids rebinding IPython's `_`.
os.makedirs('train_small', exist_ok=True)
for rec in [s for s in shapes if s[2] > 5e3][:10]:
    shutil.copy('train/' + rec[0], 'train_small/' + rec[0])


In [None]:
# Fraction of images whose pixel area exceeds 5,000.
sum(1 for rec in shapes if rec[2] > 5e3) / len(shapes)

In [None]:
import matplotlib.pyplot as plt

# Histogram of pixel areas, restricted to images under 40,000 pixels
# (presumably so the few very large images do not flatten the plot).
small_areas = [rec[2] for rec in shapes if rec[2] < 4e4]
plt.hist(small_areas, bins=1000)
plt.show()

In [None]:
import numpy as np

# Per-column min / max / mean / std of width, height and area.
# np.array() on the raw (filename, (w, h), area) records cannot work: the
# records are ragged and mixed-type (recent NumPy raises ValueError; older NumPy
# builds an object array whose filename column cannot be averaged). Build a
# numeric (N, 3) array instead. It keeps the name `shapes` because the cells
# below index it as a 2-D numeric array. The isinstance guard makes the cell
# safe to re-run.
if not isinstance(shapes, np.ndarray):
    shapes = np.array([(w, h, area) for _name, (w, h), area in shapes])
shapes.min(axis=0), shapes.max(axis=0), shapes.mean(axis=0), shapes.std(axis=0)

In [None]:
# Inspect two hand-picked rows by index (presumably outliers spotted above).
tuple(shapes[i] for i in (8933, 10446))

In [None]:
# Row indices where column 0 (width) equals 7891, and where column 1 (height) equals 4686.
tuple(np.where(shapes[:, col] == value) for col, value in ((0, 7891), (1, 4686)))

## Check vocab, create vocab map

In [15]:
# Tab-split every line of the training label file; the context manager closes the file.
with open('train_ssml_sd.txt') as f:
    lines = [line.strip().split('\t') for line in f]

In [None]:
from collections import Counter

# Token frequencies over the caption field (column 1) of every line.
# Feed Counter from a generator; calling .update inside a list comprehension
# would build, and display as cell output, a throwaway list of None.
vocab_counter = Counter(token for line in lines for token in line[1].split())

In [17]:
# Map column 0 -> column 1 of every tab-separated line; each line is split once
# and the file is closed by the context manager.
with open('train_ssml_sd_zero.txt') as f:
    label_dict = {parts[0]: parts[1] for parts in (line.strip().split('\t') for line in f)}


In [None]:
# Spot-check a single label_dict entry by key.
label_dict['train_00000']

In [None]:
# The 10 most frequent tokens with their counts.
vocab_counter.most_common(10)

In [None]:
# Vocabulary size: number of distinct tokens.
len(vocab_counter)

In [None]:
# Tokens at frequency ranks 200-299 (0-based, most frequent first), with counts.
vocab_counter.most_common()[200:300]

In [None]:
# The 10 least frequent tokens, with counts.
vocab_counter.most_common()[-10:]

In [None]:
# Every (token, count) pair seen fewer than 10 times.
[pair for pair in vocab_counter.items() if pair[1] < 10]

In [None]:
import matplotlib.pyplot as plt

# Distribution of token counts, excluding the 100 most frequent tokens.
tail_counts = [count for _tok, count in vocab_counter.most_common()[100:]]
plt.hist(tail_counts, bins=500)
plt.show()

### Sort vocab and map rare tokens to `<unk>`

In [22]:
# Build the token -> symbol map used to collapse rare tokens.
# Start from the hand-written mapping: two whitespace-separated columns per line.
with open('vocab_maps_init.txt') as f:
    vocab_map = dict(line.strip().split() for line in f)

# Alphabetically sorted vocabulary, computed once for both slices below.
# NOTE(review): the bounds 244 / 318 / 378 are hand-picked positions in this
# sorted list and shift silently if the token set changes -- confirm before reuse.
sorted_tokens = sorted(vocab_counter)
# Tokens seen exactly once in [244:318] map to <unk>, unless already mapped.
vocab_map.update({v: '<unk>' for v in sorted_tokens[244:318] if vocab_counter[v] == 1 and v not in vocab_map})
# Every token from position 378 onwards maps to <unk>, unless already mapped.
vocab_map.update({v: '<unk>' for v in sorted_tokens[378:] if v not in vocab_map})

In [23]:
# Persist the token -> symbol table, one tab-separated pair per line.
with open('vocab_map.txt', 'w') as out:
    out.writelines(f'{k}\t{v}\n' for k, v in vocab_map.items())

In [None]:
# Tokens with count < 10 inside the alphabetical slice [244:318] of the
# vocabulary; inspected when deciding which tokens to map to <unk>.
[(tok, vocab_counter[tok]) for tok in sorted(vocab_counter)[244:318] if vocab_counter[tok] < 10]


## Vocab symbols

In [None]:
# Reload the tab-separated token -> symbol map (the file is closed afterwards)
# and make it total: tokens with no explicit entry map to themselves.
with open('vocab_map.txt') as f:
    vocab_map = dict(line.strip().split('\t') for line in f)
vocab_map.update({v: v for v in vocab_counter if v not in vocab_map})

In [None]:
# Write the sorted set of *mapped* symbols (rare tokens collapsed to <unk>) to
# vocab_syms.txt, the file loaded by Vocab('vocab_syms.txt') below.
# NOTE: this must not share a filename with the raw-vocabulary dump
# (vocab_syms_full.txt), which would overwrite it.
with open('vocab_syms.txt', 'w') as f:
    f.write('\n'.join(sorted({vocab_map[v] for v in vocab_counter})))



In [5]:
# Dump the raw (unmapped) vocabulary alphabetically, one token per line.
with open('vocab_syms_full.txt', 'w') as out:
    out.write('\n'.join(sorted(vocab_counter)))



In [None]:
# Load the project Vocab from the symbol list file
# (presumably one symbol per line, as written by the cells above -- confirm in vocab.py).
from vocab import Vocab
vocab = Vocab('vocab_syms.txt')

In [None]:
# Sanity check: index of "<unk>" in word2idx next to Vocab.UNK_IDX
# (presumably the two should match -- confirm against vocab.py).
vocab.word2idx["<unk>"], vocab.UNK_IDX

## Check sequence length

In [4]:
# Caption field (second tab-separated column) of every training line;
# the context manager closes the file.
with open('train_ssml_sd.txt') as f:
    captions = [line.strip().split('\t')[1] for line in f]

In [6]:
# Whitespace-tokenize every caption.
captions_tokens = list(map(str.split, captions))

In [8]:
from collections import Counter

# Caption-length histogram: maps token count -> number of captions with that length.
# (Despite its name, token_counter counts lengths, not tokens.)
token_counter = Counter(len(tokens) for tokens in captions_tokens)

In [None]:
# Histogram of caption lengths, measured in tokens.
import matplotlib.pyplot as plt
caption_lengths = [len(tokens) for tokens in captions_tokens]
plt.hist(caption_lengths, bins=100)

### Check CROHME dataset (offline)

In [1]:

from typing import List, Tuple

from zipfile import ZipFile
from PIL import Image

Data = List[Tuple[str, Image.Image, List[str]]]


def extract_data(archive: ZipFile, dir_name: str) -> Data:
    """Extract all samples needed for one dataset split from a zip archive.

    ``{dir_name}/caption.txt`` holds one sample per line: an image name followed
    by whitespace-separated formula tokens. The image itself is read from
    ``{dir_name}/{image name}.bmp``. Blank caption lines are skipped.

    Args:
        archive (ZipFile): open zip archive containing the dataset.
        dir_name (str): dir name in archive zip (eg: train, test_2014......)

    Returns:
        Data: list of ``(img_name, img, formula)`` tuples, where ``img`` is a
        PIL image copied into memory and ``formula`` is the list of tokens.
    """
    with archive.open(f"{dir_name}/caption.txt", "r") as f:
        captions = f.readlines()
    data = []
    for line in captions:
        tmp = line.decode().strip().split()
        if not tmp:
            # A blank line has no image name; skip it instead of failing on tmp[0].
            continue
        img_name, *formula = tmp
        with archive.open(f"{dir_name}/{img_name}.bmp", "r") as f:
            # Copy the image into memory immediately: Image.open is lazy, and the
            # archive member is closed when this `with` block exits.
            img = Image.open(f).copy()
        data.append((img_name, img, formula))

    print(f"Extract data from: {dir_name}, with data size: {len(data)}")

    return data


In [None]:
# Open the archive in a context manager so its file handle is released.
# extract_data copies every image into memory, so nothing reads from the
# archive after it is closed.
with ZipFile('data.zip') as archive:
    train_data = extract_data(archive, 'train')

In [5]:
# (width, height) of every training image, taken from PIL's Image.size.
imgsizes = [img.size for _, img, _ in train_data]

In [None]:
import matplotlib.pyplot as plt

# Scatter plot of training image sizes: x = width, y = height
# (PIL's Image.size is a (width, height) pair).
widths, heights = zip(*imgsizes)
plt.scatter(widths, heights)
plt.xlabel('width')
plt.ylabel('height')
