# Save the XKCD models for pre-trained use

In [1]:
import os
import glob
import json
import shutil
import yaml

from magis_sigdial2020.settings import REPO_ROOT
from magis_sigdial2020.models.xkcd_model import XKCDModel
from magis_sigdial2020.datasets.xkcd import XKCD
import numpy as np
import pyromancy
import torch

## Helper functions

In [93]:
def compute_xkcd_metrics(model, xkcd_coordinate_system="fft", device="cpu"):
    """Evaluate ``model`` on the validation split of the XKCD dataset.

    Args:
        model: callable taking ``batch_dict["x_color_value"]`` and returning a
            dict whose ``"S0_probability"`` entry is a tensor of color-name
            probabilities indexed along dim 1.
        xkcd_coordinate_system: coordinate system passed to
            ``XKCD.from_settings``. The default ``"fft"`` matches the value
            that was previously hard-coded.
        device: device the batches are generated on. The model itself is not
            moved; callers must place it on the same device.

    Returns:
        dict with:
            ``"nll"``: mean negative log-probability of the gold color name.
            ``"perplexity"``: ``exp(nll)``.
            ``"accuracy"``: fraction of examples whose argmax prediction
            equals the gold color name.
    """
    dataset = XKCD.from_settings(coordinate_system=xkcd_coordinate_system)
    dataset.set_split("val")
    gold_probas = []
    correct_preds = []
    # Evaluation only: skip building the autograd graph.
    # NOTE(review): the model's train/eval mode is not changed here. If it
    # uses dropout, make sure it is already in eval mode.
    with torch.no_grad():
        for batch_dict in dataset.generate_batches(
            batch_size=256, device=device, drop_last=False, shuffle=False
        ):
            targets = batch_dict["y_color_name"]
            batch_probas = model(batch_dict["x_color_value"])["S0_probability"]
            # Probability assigned to the gold color name of each example.
            # reshape(-1) (instead of squeeze()) keeps the result 1-D even
            # when the final batch holds a single example.
            gold_probas.append(
                batch_probas
                .gather(dim=1, index=targets.view(-1, 1))
                .reshape(-1)
                .cpu()
                .numpy()
            )
            correct_preds.append(
                torch.eq(batch_probas.argmax(dim=1), targets)
                .float()
                .reshape(-1)
                .cpu()
                .numpy()
            )
    log_probas = np.log(np.concatenate(gold_probas))
    nll = -log_probas.mean()
    return {
        "perplexity": np.exp(nll),
        "accuracy": np.concatenate(correct_preds).mean(),
        "nll": nll,
    }

def sanitize_numpy_types(dict_):
    """Return a shallow copy of ``dict_`` with array-library scalars unwrapped.

    Any value with a callable ``.item()`` method (e.g. ``np.float64`` or
    ``np.int64``) is replaced by the builtin Python value it returns, so that
    ``yaml.dump`` / ``json.dump`` emit plain numbers.

    ``.item()`` only works on size-1 arrays. When it raises, the value is
    converted with ``.tolist()`` instead, if that method exists; otherwise the
    error propagates. Values without a callable ``item`` are passed through
    unchanged.
    """
    out = {}
    for name, value in dict_.items():
        # Numpy types sneak in easily: compare type(np.arange(5)[0]) with
        # type(np.arange(5)[0].item()).
        to_item = getattr(value, "item", None)
        if callable(to_item):
            try:
                value = to_item()
            except (ValueError, RuntimeError):
                # Multi-element arrays cannot become a single scalar.
                to_list = getattr(value, "tolist", None)
                if not callable(to_list):
                    raise
                value = to_list()
        out[name] = value
    return out

def convert_trial(trial_path, output_path):
    """Copy a pyromancy trial's weights and hyperparameters into ``output_path``.

    ``trial_path`` is split as ``<log_path>/<exp_name>/<trial_name>``. The
    trial's args are looked up through pyromancy, sanitized with
    ``sanitize_numpy_types`` and written to ``output_path/hparams.yaml``.
    ``trial_path/model.pth`` is copied into ``output_path``, which is created
    if needed.
    """
    # Normalize first: with a trailing separator, os.path.split would return
    # an empty trial name.
    trial_path = os.path.normpath(trial_path)
    remainder, trial_name = os.path.split(trial_path)
    log_path, exp_name = os.path.split(remainder)
    # NOTE(review): this changes pyromancy's root output path setting, which
    # is presumably process-wide state.
    pyromancy.settings.set_root_output_path(log_path)
    args = pyromancy.utils.get_specific_args(exp_name, trial_name)
    sanitized_args = sanitize_numpy_types(vars(args))
    os.makedirs(output_path, exist_ok=True)
    shutil.copy2(os.path.join(trial_path, "model.pth"), output_path)
    with open(os.path.join(output_path, "hparams.yaml"), "w", encoding="utf-8") as fp:
        yaml.dump(sanitized_args, fp)
        
def freeze_model(lab_subdir, experiment_name, trial_name, output_name, ModelClass=None, verbose=True):
    """Export a trained trial from ``lab/`` into ``models/<output_name>``.

    Copies ``REPO_ROOT/lab/<lab_subdir>/logs/<experiment_name>/<trial_name>``
    into ``REPO_ROOT/models/<output_name>`` via ``convert_trial``.

    If ``ModelClass`` is given, the model is loaded with
    ``ModelClass.from_pretrained`` from the source trial directory and
    evaluated with ``compute_xkcd_metrics``. The resulting metrics are written
    to ``metric.json`` in the output directory.

    When ``verbose`` is true, the paths and the metrics (if computed) are
    printed.
    """
    model_source = os.path.join(REPO_ROOT, "lab", lab_subdir, "logs", experiment_name, trial_name)
    model_target = os.path.join(REPO_ROOT, "models", output_name)
    convert_trial(model_source, model_target)
    if verbose:
        print(f"Model source: {model_source}")
        print(f"Model written to: {model_target}")

    if ModelClass is None:
        return

    model = ModelClass.from_pretrained(model_source)
    metric_info = sanitize_numpy_types(compute_xkcd_metrics(model))
    # Reuse model_target so the metrics always sit next to the exported weights.
    metric_target = os.path.join(model_target, "metric.json")
    with open(metric_target, "w", encoding="utf-8") as fp:
        json.dump(metric_info, fp)
    if verbose:
        print("- Metrics -")
        for name, value in metric_info.items():
            print(f" > {name:<10} = {value:0.4f}")

## Freeze the XKCD models

In [95]:
# Export experiment E001 (the uncalibrated XKCD model) to models/ and report
# its validation metrics.
freeze_model(
    ModelClass=XKCDModel,
    output_name="UncalibratedXKCDModel",
    lab_subdir="XKCD_model",
    experiment_name="E001_XKCDModel_uncalibrated",
    trial_name="published_version",
    verbose=True,
)

Model source: /r/code/paper_repos/speaker_strategies_sigdial2020/lab/XKCD_model/logs/E001_XKCDModel_uncalibrated/published_version
Model written to: /r/code/paper_repos/speaker_strategies_sigdial2020/models/UncalibratedXKCDModel
- Metrics -
 > perplexity = 12.1320
 > accuracy   = 0.4047
 > nll        = 2.4958


In [94]:
# Export experiment E003 (the calibrated XKCD model) to models/ and report
# its validation metrics.
freeze_model(
    ModelClass=XKCDModel,
    output_name="CalibratedXKCDModel",
    lab_subdir="XKCD_model",
    experiment_name="E003_XKCDModel_calibrated",
    trial_name="published_version",
    verbose=True,
)

Model source: /r/code/paper_repos/speaker_strategies_sigdial2020/lab/XKCD_model/logs/E003_XKCDModel_calibrated/published_version
Model written to: /r/code/paper_repos/speaker_strategies_sigdial2020/models/CalibratedXKCDModel
- Metrics -
 > perplexity = 12.1566
 > accuracy   = 0.4042
 > nll        = 2.4979
