In [6]:
%reload_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")
sys.path.append("../")
import os
import numpy as np
from hwr_utils import *
from hwr_utils.stroke_plotting import *
from hwr_utils.stroke_recovery import *
import json
from gen_preds_offline import load_all_gts

from hwr_utils import visualize
from torch.utils.data import DataLoader
from loss_module.stroke_recovery_loss import StrokeLoss
from trainers import TrainerStrokeRecovery
from hwr_utils.stroke_dataset import BasicDataset
from hwr_utils.stroke_recovery import *
from hwr_utils import utils
from torch.optim import lr_scheduler
from models.stroke_model import StrokeRecoveryModel
from train_stroke_recovery import parse_args, graph
from hwr_utils.hwr_logger import logger
from pathlib import Path

In [9]:
def post_process(pred,gt):
    return move_bad_points(reference=gt, moving_component=pred, reference_is_image=True)

def eval_only(dataloader, model):
    final_out = []
    for i, item in enumerate(dataloader):
        preds = TrainerStrokeRecovery.eval(item["line_imgs"], model,
                                           label_lengths=item["label_lengths"],
                                           relative_indices=config.pred_relativefy,
                                           sigmoid_activations=config.sigmoid_indices)

        # Pred comes out of eval WIDTH x VOCAB
        preds_to_graph = [post_process(p, item["line_imgs"][i]).permute([1, 0]) for i,p in enumerate(preds)]

        # Get GTs, save to file
        if i<10:
            # Save a sample
            save_folder = graph(item, preds=preds_to_graph, _type="eval", epoch="current", config=config)
            output_path = (save_folder / "data")
            output_path.mkdir(exist_ok=True, parents=True)

        names = [Path(p).stem.lower() for p in item["paths"]]
        output = []
        for ii, name in enumerate(names):
            if name in GT_DATA:
                output.append({"stroke": preds[ii].detach().numpy(), "text":GT_DATA[name]})
            else:
                print(f"{name} not found")
        utils.pickle_it(output, output_path / f"{i}.pickle")
        np.save(output_path / f"{i}.npy", output)
        final_out += output
    utils.pickle_it(final_out, output_path / f"all_data.pickle")
    np.save(output_path / f"all_data.npy", final_out)


In [10]:
config_path = "/media/data/GitHub/simple_hwr/RESULTS/pretrained/brodie_123/stroke_number_with_BCE_RESUME2.yaml"
load_path_override = "/media/data/GitHub/simple_hwr/RESULTS/pretrained/brodie_123/stroke_number_with_BCE_RESUME2_model_123_epochs.pt"

# Make these the same as whereever the file is being loaded from; make the log_dir and results dir be a subset
# main_model_path, log_dir, full_specs, results_dir, load_path
config = utils.load_config(config_path, hwr=False)

# Free GPU memory if necessary
if config.device == "cuda":
    utils.kill_gpu_hogs()

batch_size = config.batch_size
vocab_size = config.vocab_size
device=torch.device(config.device)

output = Path(config.results_dir)
output.mkdir(parents=True, exist_ok=True)
folder = Path(config.dataset_folder)

# OVERLOAD
folder = Path("data/prepare_IAM_Lines/lines/")
gt_path = Path("./data/prepare_IAM_Lines/gts/lines/txt")

model = StrokeRecoveryModel(vocab_size=vocab_size, device=device, cnn_type=config.cnn_type, first_conv_op=config.coordconv, first_conv_opts=config.coordconv_opts).to(device)

## Loader
logger.info(("Current dataset: ", folder))

# Dataset - just expecting a folder
eval_dataset=BasicDataset(root=folder, cnn=model.cnn)
eval_loader=DataLoader(eval_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=6,
                              collate_fn=eval_dataset.collate, # this should be set to collate_stroke_eval
                              pin_memory=False)

config.n_train_instances = None
config.n_test_instances = len(eval_loader.dataset)
config.n_test_points = None

## Stats
if config.use_visdom:
    visualize.initialize_visdom(config["full_specs"], config)
utils.stat_prep_strokes(config)

# Create loss object
config.loss_obj = StrokeLoss(loss_names=config.loss_fns, loss_stats=config.stats, counter=config.counter)
optimizer = torch.optim.Adam(model.parameters(), lr=.0005 * batch_size/32)
config.scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=.95)
trainer = TrainerStrokeRecovery(model, optimizer, config=config, loss_criterion=config.loss_obj)

config.model = model
config.load_path = load_path_override if ("load_path_override" in locals()) else config.load_path

config.sigmoid_indices = TrainerStrokeRecovery.get_indices(config.pred_opts, "sigmoid")

# Load the GTs
load_all_gts(gt_path)
print("Number of images: {}".format(len(eval_loader.dataset)))
print("Number of GTs: {}".format(len(GT_DATA)))

## LOAD THE WEIGHTS
utils.load_model_strokes(config) # should be load_model_strokes??????
model = model.to(device)
model.eval()
eval_only(eval_loader, model)


/media/data/GitHub/simple_hwr/RESULTS/pretrained/brodie_123/stroke_number_with_BCE_RESUME2.yaml
Experiment: new_experiment02, Results Directory: /home/taylor/github/simple_hwr/RESULTS/ver4/20200229_223630-stroke_number_with_BCE_RESUME/new_experiment02
Effective logging level: 20
Using config file /media/data/GitHub/simple_hwr/RESULTS/pretrained/brodie_123/stroke_number_with_BCE_RESUME2.yaml
| ID | GPU  | MEM |
-------------------
|  0 |   0% |  8% |
|  1 | nan% | 46% |
Creating LSTM: in:1024 hidden:128 dropout:0.5 layers:2 out:4
13:09:35 INFO COORD CONV: y_rel
13:09:35 INFO COORD CONV: y_rel
13:09:35 INFO COORD CONV: y_rel
13:09:35 INFO Zero center False
13:09:35 INFO Zero center False
13:09:35 INFO Zero center False
13:09:35 INFO RECT X Coord: False
13:09:35 INFO RECT X Coord: False
13:09:35 INFO RECT X Coord: False
13:09:35 INFO Normalized X Coord: True
13:09:35 INFO Normalized X Coord: True
13:09:35 INFO Normalized X Coord: True
13:09:35 INFO Using ABS+REL X Coord Channels: False
13

  return fix_scientific_notation(yaml.load(config.open(mode="r")))


13:09:35 INFO ('Current dataset: ', PosixPath('data/prepare_IAM_Lines/lines'))
13:09:35 INFO ('Current dataset: ', PosixPath('data/prepare_IAM_Lines/lines'))
13:09:35 INFO ('Current dataset: ', PosixPath('data/prepare_IAM_Lines/lines'))


ValueError: num_samples should be a positive integer value, but got num_samples=0