In [1]:
import torch
import os
from train import id_to_string
from metrics import word_error_rate, sentence_acc
from checkpoint import load_checkpoint
from torchvision import transforms
from dataset import LoadEvalDataset, collate_eval_batch, START, PAD
from flags import Flags
from utils import get_network, get_optimizer
import csv
from torch.utils.data import DataLoader
import argparse
import random
from tqdm import tqdm

In [2]:
import numpy as np
import pandas as pd
import pdb

# DataFrame 생성

In [3]:
# dataframe 생성
def create_df(default_image_path:str, gt_path:str):
    data = dict()
    data['name'] = []
    data['latex_str'] = []
    data['latex'] = []
    data['invisible_braket_cnt'] = []
    all_latex_list = []
    with open(gt_path) as f:
        for idx,line in enumerate(f):
            image_name, latex=line.replace("\n","").split("\t")
            data['name'].append(image_name)
            data['latex_str'].append(latex) # 추가 4
            latex=latex.split(" ")
            cnt = 0
            for ch in latex:
                if ch in ['\left.', '\right.']:
                    cnt += 1
            data['invisible_braket_cnt'].append(cnt)
            data['latex'].append(latex)
            all_latex_list += latex
    df = pd.DataFrame.from_dict(data)
    return df, all_latex_list

In [4]:
df, all_latex_list = create_df(
    default_image_path="/opt/ml/input/data/train_dataset/images/",
    gt_path="/opt/ml/input/data/train_dataset/gt.txt",
)

In [5]:
df.head()

Unnamed: 0,name,latex_str,latex,invisible_braket_cnt
0,train_00000.jpg,4 \times 7 = 2 8,"[4, \times, 7, =, 2, 8]",0
1,train_00001.jpg,a ^ { x } > q,"[a, ^, {, x, }, >, q]",0
2,train_00002.jpg,8 \times 9,"[8, \times, 9]",0
3,train_00003.jpg,\sum _ { k = 1 } ^ { n - 1 } b _ { k } = a _ {...,"[\sum, _, {, k, =, 1, }, ^, {, n, -, 1, }, b, ...",0
4,train_00004.jpg,I = d q / d t,"[I, =, d, q, /, d, t]",0


In [6]:
df[df['invisible_braket_cnt'] > 0]

Unnamed: 0,name,latex_str,latex,invisible_braket_cnt
42,train_00042.jpg,x \left. \right) > 0,"[x, \left., \right), >, 0]",1
280,train_00280.jpg,\left. L ^ { - 1 } \left( \frac { 1 } { 2 s } ...,"[\left., L, ^, {, -, 1, }, \left(, \frac, {, 1...",1
392,train_00392.jpg,"\left. i , j : 1 , 2 \right)","[\left., i, ,, j, :, 1, ,, 2, \right)]",1
400,train_00400.jpg,"\left. , x + y \neq 1 \right)","[\left., ,, x, +, y, \neq, 1, \right)]",1
425,train_00425.jpg,\left. P \right) \Rightarrow { N R P } = \frac...,"[\left., P, \right), \Rightarrow, {, N, R, P, ...",1
...,...,...,...,...
99669,train_99669.jpg,\left. + \Delta x \right),"[\left., +, \Delta, x, \right)]",1
99716,train_99716.jpg,\left. - a _ { n } \right) \rightarrow \left( ...,"[\left., -, a, _, {, n, }, \right), \rightarro...",1
99802,train_99802.jpg,\lim _ { n \to \infty } \left( \sqrt { \left. ...,"[\lim, _, {, n, \to, \infty, }, \left(, \sqrt,...",1
99888,train_99888.jpg,"\left. \frac { \overline { x } _ { n } , S _ {...","[\left., \frac, {, \overline, {, x, }, _, {, n...",1


# 모델

In [7]:
from dataset import dataset_loader, START, PAD,load_vocab

In [10]:
################################# paresr 부분 대체하는 코드 ########################
# 이후 parser에 접근 하는 코드 모두 알맞게 수정
cfg = dict()
# 1. 학습된 모델 체크포인트
cfg['checkpoint'] = '/opt/ml/code/log/satrn_50/checkpoints/0050.pth'
# 2. EDA에 사용할 데이터
cfg['file_path'] = '/opt/ml/input/data/train_dataset/gt.txt'
# 3. 사용할 데이터 범위 (train_dataset 십만개 기준, 한번에 하면 에러가 나서.)
cfg['start'] = 30000
cfg['end'] = 35000
# 4. 기타
cfg['max_sequence'] = 230
cfg['batch_size'] = 32
cfg['log_dir'] = '/opt/ml/code/eda/'
####################################################################################

# 1. 체크포인트 로드
is_cuda = torch.cuda.is_available()
checkpoint = load_checkpoint(cfg['checkpoint'], cuda=is_cuda) #checkpoint = load_checkpoint(parser.checkpoint, cuda=is_cuda)

# 2. 체크포인트로부터 옵션 로드 & seed pix
options = Flags(checkpoint["configs"]).get()
torch.manual_seed(options.seed)
random.seed(options.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

hardware = "cuda" if is_cuda else "cpu"
device = torch.device(hardware)
print("--------------------------------")
print("Running {} on device {}\n".format(options.network, device))

model_checkpoint = checkpoint["model"]
if model_checkpoint:
    print(
        "[+] Checkpoint\n",
        "Resuming from epoch : {}\n".format(checkpoint["epoch"]),
    )
print(options.input_size.height)

##################################################################
dummy_gt = "\sin " * cfg['max_sequence']  # set maximum inference sequence
root = os.path.join(os.path.dirname(cfg['file_path']), "images")
# data = []
test_data = []
gt = []
with open(cfg['file_path'], "r") as fd:
    reader = csv.reader(fd, delimiter="\t")
    for i, x in enumerate(reader):
        if i < cfg['start']:
            continue
        if i >= cfg['end']:
            break
        test_data.append([os.path.join(root, x[0]), x[0], dummy_gt])
        gt.append(x[1])
#################################################################
transformed = transforms.Compose(
        [
            transforms.Resize((options.input_size.height, options.input_size.width)),
            transforms.ToTensor(),
        ]
    )
test_dataset = LoadEvalDataset(
    test_data,
    checkpoint["token_to_id"],
    checkpoint["id_to_token"],
    crop=False,
    transform=transformed,
    rgb=options.data.rgb
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=cfg['batch_size'],
    shuffle=False,
    num_workers=options.num_workers,
    collate_fn=collate_eval_batch,
)

print("[+] Data\n",
      "The number of test samples : {}\n".format(len(test_dataset)),)

model = get_network(
    options.network,
    options,
    model_checkpoint,
    device,
    test_dataset,
)
model.eval()
print(
    "[+] Network\n",
    "Type: {}\n".format(options.network),
)

--------------------------------
Running SATRN on device cuda

[+] Checkpoint
 Resuming from epoch : 50

128
[+] Data
 The number of test samples : 5000

[+] Network
 Type: SATRN



In [11]:
# default_image_path="/opt/ml/input/data/train_dataset/images/"
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
def origin_img(img_path):
    img = mpimg.imread(img_path)
    plt.grid(True, color='w')
    plt.imshow(img)
    plt.show()

In [12]:
def cal_invisible_braket(gt, pred):
    cnt1, cnt2 = 0, 0
    for token in gt.split():
        if token in ['\left.', '\right.']:
            cnt1 += 1
    for token in pred.split():
        if token in ['\left.', '\right.']:
            cnt2 += 1
    return cnt1, cnt2

In [None]:
results = [] # file_name, pred, gt, invisible_cnt(gt's, pred's)
i = 0
for d in tqdm(test_data_loader):
    input = d["image"].to(device)
    expected = d["truth"]["encoded"].to(device) # dummy_gt
    output = model(input, expected, False, 0.0)
    decoded_values = output.transpose(1, 2)
    _, sequence = torch.topk(decoded_values, 1, dim=1)
    sequence = sequence.squeeze(1)
    sequence_str = id_to_string(sequence, test_data_loader, do_eval=1)
    for name, predicted in zip(d["file_path"], sequence_str):
        results.append((name, predicted, gt[i], cal_invisible_braket(gt[i], predicted)))
        i += 1

 86%|████████▌ | 135/157 [05:10<00:51,  2.32s/it]

In [None]:
with open(os.path.join(cfg['log_dir'], "error_eda1.csv"), "a") as a:
    for name, pred, gt, invisible_cnt in results:
        if pred.strip() != gt.strip():
#             print(name + "\t" + pred + "\t" + gt + "\t" + str(invisible_cnt[0]) + "\t" + str(invisible_cnt[0]-invisible_cnt[1]) + '\n')
            a.write(name + "\t" + pred + "\t" + gt + "\t" + str(invisible_cnt[0]) + "\t" + str(invisible_cnt[0]-invisible_cnt[1]) + '\n')
#             origin_img('/opt/ml/input/data/train_dataset/images/'+name)
#             print(f'name: {name}')
#             print(f'gt: {gt}')
#             print(f'pd: {pred}')
#             print(f'{invisible_cnt}')