In [28]:
# Re-import edited modules (e.g. the local dgp / model / utils imported in the
# next cell) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [36]:
import typing
from typing import Type

import numpy as np
import torch

from devinterp.optim import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary
import devinterp.utils as utils

import re
import yaml
from dataclasses import dataclass


from dgp import get_dataloader
from model import GPT

import torch.nn.functional as F
from utils import move_to_device

import pickle
import os

In [None]:
# Run-level settings: config path, compute device, and the checkpoint / output
# directories for this training run.
config_file = "config/conf.yaml"

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

hf_repo_name = "oohv6uys"
model_dir = f'results/scratch/{hf_repo_name}'  # checkpoints are read from here
dump_dir = f'{model_dir}/llc'                   # LLC summaries are written here

os.makedirs(dump_dir, exist_ok=True)

In [30]:
def load_model_for_iteration(it, dirname, epoch=0, device=device, weights_only=False):
    """Load the checkpoint file saved for a given epoch / iteration.

    Parameters
    ----------
    it : int
        Iteration number encoded in the checkpoint filename.
    dirname : str
        Directory holding ``ckpt_epoch_{epoch}_iter_{it}.pt`` files.
    epoch : int, default 0
        Epoch number encoded in the checkpoint filename.
    device : str or torch.device
        Passed to ``torch.load`` as ``map_location``. The default is the
        notebook-level ``device`` as it was when this cell ran.
    weights_only : bool, default False
        Forwarded to ``torch.load``. It is passed explicitly because newer torch
        releases changed the implicit default (and warn about it). ``False``
        unpickles the full checkpoint, so only load checkpoint files you trust.

    Returns
    -------
    The object stored in the checkpoint file (indexed with ``['net']`` by the
    LLC code in this notebook).
    """
    path = os.path.join(dirname, f'ckpt_epoch_{epoch}_iter_{it}.pt')
    return torch.load(path, map_location=device, weights_only=weights_only)

In [41]:
# PyYAML's default (YAML 1.1) float pattern requires a '.', so exponent-only
# values such as `1e-3` would otherwise load as strings. Register a float
# resolver that also accepts them.
# The resolver is added to a private subclass rather than to yaml.SafeLoader
# itself. Mutating SafeLoader would change every yaml.safe_load in the process
# and would append a duplicate resolver each time this cell is re-run.
class _SciNotationSafeLoader(yaml.SafeLoader):
    """SafeLoader that additionally resolves floats written like ``1e-3``."""


_SciNotationSafeLoader.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    re.compile(u'''^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$''', re.X),
    list(u'-+0123456789.'))

# Exposed under the name `loader`, used as `yaml.load(..., Loader=loader)`.
loader = _SciNotationSafeLoader

# Parse the run config with the float-aware safe loader. The `with` block
# guarantees the file handle is closed once parsing finishes.
with open(config_file) as config_fh:
    conf_yaml = yaml.load(config_fh, Loader=loader)

# Recursively wrap the YAML-loaded dict so nested keys are reachable as attributes.
@dataclass
class Conf:
    """Attribute-access view over a nested config dict.

    Every keyword argument becomes an attribute. Dict values are wrapped in a
    nested ``Conf``; all other values (including lists) are stored as-is.

    NOTE: ``@dataclass`` is kept only so ``dataclasses.is_dataclass`` keeps
    returning True for any existing caller. The class declares no fields, and
    ``__init__``, ``__eq__`` and ``__repr__`` are all defined explicitly here,
    so the decorator does not replace them.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, Conf(**value) if isinstance(value, dict) else value)

    def to_dict(self) -> dict:
        """Return the contents as a plain nested dict (nested ``Conf`` -> dict)."""
        return {
            key: value.to_dict() if isinstance(value, Conf) else value
            for key, value in self.__dict__.items()
        }

    def __eq__(self, other):
        # A dataclass-generated __eq__ compares the (empty) field tuples, which
        # makes any two instances equal. Compare the stored attributes instead.
        if not isinstance(other, Conf):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        # Dump plain dicts. Dumping nested Conf objects directly makes yaml.dump
        # emit `!!python/object:...Conf` tags instead of readable mappings.
        return yaml.dump(self.to_dict())

    def __str__(self) -> str:
        return self.__repr__()

# Build the attribute-style run config, with the compute device selected in the
# configuration cell stored on it as `config.device` (overriding any YAML value).
config = Conf(**{**conf_yaml, 'device': device})

In [44]:
# Build the data loader from the `data` section of the run config. These
# settings are read off `config.data` and forwarded in this keyword order.
_DATALOADER_CONFIG_FIELDS = (
    "n_relative_properties",
    "n_descriptive_properties",
    "n_descriptive_values",
    "num_of_classes_to_divide_over",
    "prior_param",
    "props_prior_type",
    "n_entities",
    "instr_ratio",
    "max_sample_length",
    "num_iters",
    "batch_size",
    "num_workers",
)
_dataloader_kwargs = {
    field: getattr(config.data, field) for field in _DATALOADER_CONFIG_FIELDS
}
# num_iters is forwarded multiplied by batch_size (presumably a sample count
# rather than a batch count -- TODO confirm against dgp.get_dataloader).
_dataloader_kwargs["num_iters"] = (
    _dataloader_kwargs["num_iters"] * _dataloader_kwargs["batch_size"]
)
dataloader = get_dataloader(**_dataloader_kwargs, seed=config.seed)

In [7]:
# Token id used for padding; positions holding it are excluded from the loss.
pad_token_id = dataloader.dataset.pad_token_id


def evaluate_fn(model, data):
    """Mean next-token cross-entropy of ``model`` on one batch, ignoring padding.

    Parameters
    ----------
    model : torch.nn.Module
        Called on the input token ids. Its output is treated as logits with the
        vocabulary on the last axis (shape (B, L-1, V) per the original note).
    data : sequence of 5 items
        One batch from ``dataloader``. Only the first item (token ids, sliced
        as ``[:, :-1]`` / ``[:, 1:]`` so at least 2-D) is used; the other four
        are unpacked and ignored.

    Returns
    -------
    tuple
        ``(loss, {})``: the scalar loss tensor averaged over non-padding target
        tokens, and an empty dict of extra statistics.
    """
    sequences, _symb_sequences, _seq_lengths, _seq_logprobs, _ = data
    inputs, labels = move_to_device([sequences[:, :-1], sequences[:, 1:]], config.device)
    # Out-of-place masking: `labels` may be a view into the batch tensor.
    labels = labels.masked_fill(labels == pad_token_id, -100)
    logits = model(inputs)  # (B, L-1, V)
    token_losses = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
        reduction='none',
    )  # (B*(L-1),), 0 at ignored positions
    # Average over real (non-padding) targets only. A plain .mean() over every
    # position counts padded slots as zero loss and shrinks the result in
    # proportion to the amount of padding. clamp_min(1) avoids 0/0 on an
    # all-padding batch.
    n_tokens = (labels != -100).sum().clamp_min(1)
    loss = token_losses.sum() / n_tokens
    return loss, {}

In [8]:
def estimate_llc_given_model(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    evaluate: typing.Callable,
    epsilon: float,
    beta: float,
    sampling_method: Type[torch.optim.Optimizer] = SGLD,
    localization: float = 100.0,
    num_chains: int = 5,
    num_draws: int = 300,
    num_burnin_steps: int = 0,
    num_steps_bw_draws: int = 1,
    device: torch.device = torch.device("cpu"),
    online: bool = True,
    verbose: bool = False,
):
    """Estimate the local learning coefficient (LLC) of ``model``.

    Thin adapter over devinterp's ``estimate_learning_coeff_with_summary`` that
    takes the sampler step size as ``epsilon`` and the inverse temperature as
    ``beta``, matching the ``llc_estimator`` signature used by the
    (commented-out) epsilon/beta sweep in this notebook.

    Parameters
    ----------
    model, loader, evaluate
        Forwarded unchanged to devinterp.
    epsilon : float
        Sampler step size (``lr``).
    beta : float
        Sampler ``nbeta``.
    sampling_method : optimizer class, default SGLD
    localization : float
        Sampler ``localization`` strength.
    num_chains, num_draws, num_burnin_steps, num_steps_bw_draws : int
        Number of chains, draws kept per chain, steps discarded at the start of
        each chain, and sampler steps between kept draws.
    device, online, verbose
        Forwarded unchanged to devinterp.

    Returns
    -------
    dict
        devinterp's summary, with its ``"llc/trace"`` entry converted via ``np.array``.
    """
    sampler_kwargs = {"lr": epsilon, "localization": localization, "nbeta": beta}
    chain_settings = {
        "num_chains": num_chains,
        "num_draws": num_draws,
        "num_burnin_steps": num_burnin_steps,
        "num_steps_bw_draws": num_steps_bw_draws,
    }

    summary = estimate_learning_coeff_with_summary(
        model,
        loader=loader,
        evaluate=evaluate,
        sampling_method=sampling_method,
        optimizer_kwargs=sampler_kwargs,
        device=device,
        online=online,
        verbose=verbose,
        **chain_settings,
    )
    summary["llc/trace"] = np.array(summary["llc/trace"])
    return summary

In [9]:
# analyzer = EpsilonBetaAnalyzer()
# analyzer.configure_sweep(
#     llc_estimator=estimate_llc_given_model,
#     llc_estimator_kwargs=dict(
#         model=model, loader=dataloader, evaluate=evaluate_fn, device=config.device
#     ),
#     min_epsilon=1e-6,
#     max_epsilon=1e-2,
#     epsilon_samples=8,
#     min_beta=None,
#     max_beta=None,
#     beta_samples=8,
#     dataloader=dataloader,
# )  # Automatically find a beta range from the optimal beta
# analyzer.sweep()

In [10]:
# analyzer.sweep_df.head()

In [11]:
def calculate_llc_default(
    iteration,
    dataloader,
    epsilon=1e-3,
    localization=200.0,
    num_chains=20,
    num_draws=200,
):
    """Estimate the LLC of the epoch-0 checkpoint saved at ``iteration``.

    Loads ``ckpt_epoch_0_iter_{iteration}.pt`` from ``model_dir`` into a fresh
    ``GPT`` and runs devinterp's estimator with SGLD, ``evaluate_fn``, and
    ``nbeta = utils.default_nbeta(dataloader)``.

    Parameters
    ----------
    iteration : int
        Training iteration of the checkpoint to analyse.
    dataloader : torch.utils.data.DataLoader
        Batches for the sampler. Its dataset also supplies the vocab size and
        the default nbeta.
    epsilon : float, default 1e-3
        SGLD step size (``lr``).
    localization : float, default 200.0
        SGLD ``localization`` strength.
    num_chains, num_draws : int, defaults 20 and 200
        Number of chains and draws per chain.

    Returns
    -------
    dict
        The summary returned by ``estimate_learning_coeff_with_summary``.
    """
    if torch.cuda.is_available() and device == 'cuda':
        torch.cuda.empty_cache()
    checkpoint = load_model_for_iteration(iteration, model_dir, epoch=0)
    model = GPT(config.model, dataloader.dataset.PCSG.vocab_size)
    model.load_state_dict(checkpoint['net'])
    # Release the checkpoint before clearing the cache. While it is referenced,
    # its tensors (mapped to `device`) stay allocated for the whole LLC run and
    # empty_cache() cannot return that memory.
    del checkpoint
    if torch.cuda.is_available() and device == 'cuda':
        torch.cuda.empty_cache()
    return estimate_learning_coeff_with_summary(
        model,
        loader=dataloader,
        evaluate=evaluate_fn,
        sampling_method=SGLD,
        optimizer_kwargs=dict(
            lr=epsilon,
            localization=localization,
            nbeta=utils.default_nbeta(dataloader),
        ),
        num_chains=num_chains,
        num_draws=num_draws,
        num_burnin_steps=0,
        num_steps_bw_draws=1,
        device=config.device,
        online=True,
        verbose=True,
    )

In [12]:
# Estimate the LLC at checkpoint iterations 0, 100, ..., 1200. Each summary is
# pickled to dump_dir as soon as it is computed, so partial progress survives
# an interrupted run. Each checkpoint is expensive (about 25 min per checkpoint
# in the logged run: 20 chains x ~76 s).
# Set REUSE_CACHED_LLC = True to load summaries already on disk instead of
# recomputing them. Cached files are NOT checked against the current sampler
# settings, so clear dump_dir after changing hyperparameters.
REUSE_CACHED_LLC = False

llc_outputs = []
iters = list(range(0, 1201, 100))
# The loop variable is not called `iter`: it would stay bound at module scope
# after this cell and shadow the builtin iter() in later cells.
for iteration in iters:
    out_path = os.path.join(dump_dir, f'llc_output_it_{iteration}.pkl')
    if REUSE_CACHED_LLC and os.path.exists(out_path):
        # Only unpickle files this loop wrote itself; pickle is unsafe on untrusted data.
        with open(out_path, 'rb') as f:
            llc_output = pickle.load(f)
        llc_outputs.append(llc_output)
        continue
    llc_output = calculate_llc_default(iteration, dataloader)
    llc_outputs.append(llc_output)
    with open(out_path, 'wb') as f:
        pickle.dump(llc_output, f)

  return torch.load(f'{dirname}/{fname}', map_location=device)


Chain 0: 100%|██████████| 200/200 [01:16<00:00,  2.63it/s]
Chain 1: 100%|██████████| 200/200 [01:16<00:00,  2.62it/s]
Chain 2: 100%|██████████| 200/200 [01:16<00:00,  2.63it/s]
Chain 3: 100%|██████████| 200/200 [01:16<00:00,  2.62it/s]
Chain 4: 100%|██████████| 200/200 [01:16<00:00,  2.62it/s]
Chain 5: 100%|██████████| 200/200 [01:16<00:00,  2.62it/s]
Chain 6: 100%|██████████| 200/200 [01:16<00:00,  2.62it/s]
Chain 7: 100%|██████████| 200/200 [01:16<00:00,  2.63it/s]
Chain 8: 100%|██████████| 200/200 [01:16<00:00,  2.63it/s]
Chain 9: 100%|██████████| 200/200 [01:16<00:00,  2.61it/s]
Chain 10: 100%|██████████| 200/200 [01:16<00:00,  2.61it/s]
Chain 11: 100%|██████████| 200/200 [01:16<00:00,  2.63it/s]
Chain 12: 100%|██████████| 200/200 [01:16<00:00,  2.61it/s]
Chain 13: 100%|██████████| 200/200 [01:16<00:00,  2.61it/s]
Chain 14: 100%|██████████| 200/200 [01:15<00:00,  2.64it/s]
Chain 15: 100%|██████████| 200/200 [01:16<00:00,  2.62it/s]
Chain 16: 100%|██████████| 200/200 [01:16<00:00,  