# 0. Preliminaries

In [1]:
import json
import random
import os
from PIL import Image
from itertools import product
from tqdm import tqdm
import math
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from collections import Counter

from domain import SYM2PROG, Program, NULL_VALUE
import sys
from helper import *

# 1. Split handwritten symbols

In [2]:
# Split the handwritten images of every terminal symbol into train / val / test
# sets (75% / 5% / remainder) and store the file-name lists as JSON.
random.seed(777)
terminals = ['+', '-', 'times', 'div', '(', ')', '!'] + [str(d) for d in range(10)] + ['alpha', 'beta', 'gamma', 'theta', 'phi'] + list('abcdexyz')
sym_train_set = {}
sym_val_set = {}
sym_test_set = {}
for tok in terminals:
    # os.listdir returns entries in arbitrary (filesystem-dependent) order.
    # Sort before shuffling so that the seeded shuffle yields the same split
    # on every machine.
    # NOTE(review): symbol_images_dir is not defined in this notebook; it
    # presumably comes from `helper` and ends with a path separator.
    imgs = sorted(os.listdir(symbol_images_dir + tok))
    random.shuffle(imgs)
    n_train = int(len(imgs) * 0.75)
    n_val = int(len(imgs) * 0.05)
    # Each split is stored sorted so the JSON files are stable and diffable.
    sym_train_set[tok] = sorted(imgs[:n_train])
    sym_val_set[tok] = sorted(imgs[n_train:n_train+n_val])
    sym_test_set[tok] = sorted(imgs[n_train+n_val:])
    print(tok, len(imgs))

# Write each split with a context manager so the file is flushed and closed.
for out_path, sym_set_to_dump in (('sym_train.json', sym_train_set),
                                  ('sym_val.json', sym_val_set),
                                  ('sym_test.json', sym_test_set)):
    with open(out_path, 'w') as f:
        json.dump(sym_set_to_dump, f)

+ 5443
- 6022
times 600
div 157
( 3986
) 3978
! 224
0 1810
1 6327
2 6210
3 2469
4 1641
5 1008
6 812
7 753
8 731
9 742
alpha 383
beta 295
gamma 90
theta 543
phi 83
a 2724
b 1830
c 1194
d 1062
e 616
x 5333
y 1895
z 1075


# 2. Generate expressions

In [3]:
# Load the per-split symbol image lists written by the previous section and
# wrap each token's list in an `Iterator`.
# NOTE(review): `splits` and `Iterator` are not defined in this notebook; they
# presumably come from `helper`. The seed is set before the Iterators are built,
# presumably because Iterator uses the global `random` state -- confirm.
random.seed(777)
split2sym = {}
for split in splits:
    # Close the file deterministically instead of relying on garbage collection.
    with open('sym_%s.json'%split) as f:
        sym_lists = json.load(f)
    split2sym[split] = {k: Iterator(v) for k, v in sym_lists.items()}

In [4]:
# Generate expressions for every operator count 0..max_op and distribute them
# over the train / val / test splits.
random.seed(12306)

# --- configuration ---------------------------------------------------------
max_op = 20              # largest number of operators generated at all
max_op_train = 10        # largest number of operators that appears in train
max_value_train = 100    # max_value passed for the "small value" expressions
max_value_test = 10000   # max_value passed for the "large value" expressions
n_train = 100000         # requested train expressions per operator count
n_test = 1000            # requested test expressions per operator count / group
n_val = 100              # requested val expressions per operator count / group
res_max_ratio = 0.05     # forwarded unchanged to generate_expression

train_exprs, val_exprs, test_exprs = [], [], []

for n_op in range(max_op + 1):
    # Operator counts beyond max_op_train contribute nothing to the train set.
    # n_train is overwritten on purpose and stays 0 for all later iterations.
    if n_op > max_op_train:
        n_train = 0

    n_requested = n_train + n_val + n_test
    expressions = generate_expression(n_op, n_requested, max_value=max_value_train, res_max_ratio=res_max_ratio)

    if len(expressions) < n_train:
        print('there are not enough expressions for train set. Repeat them.')
        if n_op in (0, 1):
            # Fixed repetition factors for the two smallest operator counts.
            factor = (20 if n_op == 0 else 2) * n_train // 1000
            expressions = expressions * factor
        else:
            # Repeat just often enough to reach n_train, then cut to size.
            factor = n_train // len(expressions) + 1
            expressions = (expressions * factor)[:n_train]

    # split 'I': val/test samples are drawn from the head of the train portion.
    if n_train > 0:
        train_exprs_i = list(expressions[:n_train])
        n_kept = len(train_exprs_i)
        val_exprs_i = list(expressions[:n_kept * n_val // n_train])
        test_exprs_i = list(expressions[:n_kept * n_test // n_train])
    else:
        train_exprs_i, val_exprs_i, test_exprs_i = [], [], []

    # split 'SS' or 'LS': whatever was generated beyond the train portion.
    val_exprs_i += expressions[n_train:n_train + n_val]
    test_exprs_i += expressions[n_train + n_val:]

    # split 'SL' or 'LL': a second batch restricted to values above max_value_train.
    expressions = generate_expression(n_op, n_val + n_test, min_value=max_value_train+1, max_value=max_value_test, res_max_ratio=res_max_ratio)
    val_exprs_i += expressions[:n_val]
    test_exprs_i += expressions[n_val:n_val + n_test]

    print(n_op, len(train_exprs_i), len(val_exprs_i), len(test_exprs_i))
    for bucket, per_op in ((train_exprs, train_exprs_i),
                           (val_exprs, val_exprs_i),
                           (test_exprs, test_exprs_i)):
        bucket.append(per_op)


# One list per operator count inside each split.
split2exprs = {'train': train_exprs, 'val': val_exprs, 'test': test_exprs}
print([(k, sum(map(len, v))) for k, v in split2exprs.items()])

there are not enough expressions for train set. Repeat them.
0 20000 20 200
there are not enough expressions for train set. Repeat them.
1 78000 78 780
there are not enough expressions for train set. Repeat them.


  1%|▍                                 | 1408/101100 [00:00<00:07, 14066.83it/s]

2 100000 200 1640


100%|██████████████████████████████████| 101100/101100 [05:58<00:00, 282.34it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 896.81it/s]
  1%|▎                                 | 1073/101100 [00:00<00:09, 10723.11it/s]

3 100000 300 3000


100%|██████████████████████████████████| 101100/101100 [06:14<00:00, 270.10it/s]
100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 1244.06it/s]
  1%|▎                                   | 947/101100 [00:00<00:10, 9423.17it/s]

4 100000 300 3000


100%|██████████████████████████████████| 101100/101100 [06:59<00:00, 241.23it/s]
100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 1221.19it/s]
  1%|▎                                   | 772/101100 [00:00<00:13, 7711.07it/s]

5 100000 300 3000


100%|██████████████████████████████████| 101100/101100 [07:41<00:00, 219.20it/s]
100%|█████████████████████████████████████| 1100/1100 [00:01<00:00, 1039.12it/s]
  1%|▏                                   | 585/101100 [00:00<00:17, 5847.33it/s]

6 100000 300 3000


100%|██████████████████████████████████| 101100/101100 [09:55<00:00, 169.83it/s]
100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 1101.55it/s]
  1%|▏                                   | 526/101100 [00:00<00:19, 5227.86it/s]

7 100000 300 3000


100%|██████████████████████████████████| 101100/101100 [12:29<00:00, 134.86it/s]
100%|█████████████████████████████████████| 1100/1100 [00:01<00:00, 1060.44it/s]
  0%|▏                                   | 380/101100 [00:00<00:26, 3796.60it/s]

8 100000 300 3000


100%|██████████████████████████████████| 101100/101100 [12:19<00:00, 136.76it/s]
100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 1108.09it/s]
  0%|▏                                   | 383/101100 [00:00<00:26, 3822.21it/s]

9 100000 300 3000


100%|██████████████████████████████████| 101100/101100 [14:22<00:00, 117.27it/s]
100%|█████████████████████████████████████| 1100/1100 [00:01<00:00, 1092.49it/s]
 30%|███████████▍                          | 330/1100 [00:00<00:00, 3299.63it/s]

10 100000 300 3000


100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 2191.80it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 939.04it/s]
 26%|█████████▉                            | 287/1100 [00:00<00:00, 2864.58it/s]

11 0 200 2000


100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 2021.51it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 909.00it/s]
 23%|████████▋                             | 250/1100 [00:00<00:00, 2497.77it/s]

12 0 200 2000


100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 1791.30it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 857.66it/s]
 19%|███████▎                              | 213/1100 [00:00<00:00, 2122.20it/s]

13 0 200 2000


100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 1509.68it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 796.10it/s]
 18%|██████▊                               | 198/1100 [00:00<00:00, 1977.23it/s]

14 0 200 2000


100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 1290.03it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 762.75it/s]
 15%|█████▌                                | 162/1100 [00:00<00:00, 1619.28it/s]

15 0 200 2000


100%|█████████████████████████████████████| 1100/1100 [00:00<00:00, 1101.86it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 732.23it/s]
 14%|█████▏                                | 149/1100 [00:00<00:00, 1487.14it/s]

16 0 200 2000


100%|█████████████████████████████████████| 1100/1100 [00:01<00:00, 1051.18it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 679.92it/s]
 12%|████▍                                 | 127/1100 [00:00<00:00, 1264.74it/s]

17 0 200 2000


100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 867.75it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 598.33it/s]
  9%|███▌                                  | 104/1100 [00:00<00:00, 1033.58it/s]

18 0 200 2000


100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 808.02it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 609.16it/s]
  8%|███▏                                    | 89/1100 [00:00<00:01, 885.60it/s]

19 0 200 2000


100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 671.73it/s]
100%|██████████████████████████████████████| 1100/1100 [00:01<00:00, 565.41it/s]


20 0 200 2000
[('train', 998000), ('val', 4698), ('test', 46620)]


In [5]:
# Turn the generated expressions of every split into dataset samples (with
# image paths for each symbol) and write them to expr_<split>.json.
#
# Each expression record `e` is indexed as: e[0] expression, e[1] head,
# e[2] result, e[3] all intermediate results (may contain None).

# Build the set of training expressions up front, straight from split2exprs.
# The original code rebound `train_exprs` from a list of lists to a set only
# after the 'train' split had been processed, so the 'I' tag silently depended
# on `splits` yielding 'train' first.
train_expr_set = {e[0] for expr_list in split2exprs['train'] for e in expr_list}

split2dataset = {}
for split in splits:
    sym_set = split2sym[split]
    exprs = split2exprs[split]
    is_eval_split = split in ['val', 'test']
    dataset = []
    idx = 0
    for n_op, expr_list in enumerate(exprs):
        for e in expr_list:
            if is_eval_split:
                if e[0] in train_expr_set:
                    # Expression also occurs in the train set.
                    evaluation = 'I'
                else:
                    # First letter: operator count vs. max_op_train,
                    # second letter: largest value vs. max_value_train
                    # ('S' = within the training limit, 'L' = beyond it).
                    max_value = max([x for x in e[3] if x is not None])
                    evaluation = 'S' if n_op <= max_op_train else 'L'
                    evaluation += 'S' if max_value <= max_value_train else 'L'

            img_paths = generate_img_paths(e[0], sym_set)
            sample = {'id': '%s_%08d'%(split, idx), 'img_paths':img_paths,
                  'expr': e[0], 'head': e[1], 'res': e[2], 'res_all': e[3]}
            if is_eval_split:
                sample['eval'] = evaluation
            idx += 1
            dataset.append(sample)
    split2dataset[split] = dataset
    print(split, len(dataset))

# Kept for backward compatibility: the original cell left `train_exprs` bound
# to the set of training expression strings.
train_exprs = train_expr_set

for split in splits:
    # Close each output file deterministically.
    with open('expr_%s.json'%split, 'w') as f:
        json.dump(split2dataset[split], f)

train 998000
val 4698
test 46620


In [6]:
# Distribution of the evaluation tags in the test set:
# (tag, number of samples, percentage of the test set rounded to 2 decimals).
eval_tally = Counter(sample['eval'] for sample in split2dataset['test'])
total_count = sum(eval_tally.values())
counts = [
    (tag, n_samples, round(n_samples / total_count * 100, 2))
    for tag, n_samples in sorted(eval_tally.items())
]
print(counts)

[('I', 9980, 21.41), ('LL', 10000, 21.45), ('LS', 10000, 21.45), ('SL', 8640, 18.53), ('SS', 8000, 17.16)]
