In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (such as the local `dgp` and `model`) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import typing
from typing import Type

import numpy as np
import torch

from devinterp.optim import SGLD
from devinterp.slt.sampler import estimate_learning_coeff_with_summary
from devinterp.utils import Outputs
from devinterp.vis_utils import EpsilonBetaAnalyzer


from dgp import get_dataloader
from model import GPT

In [3]:
def load_model_for_iteration(it, dirname, epoch = 0, device = 'cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the checkpoint file written for a given epoch and iteration.

    Args:
        it: Iteration index embedded in the checkpoint file name.
        dirname: Directory containing the ``ckpt_epoch_<epoch>_iter_<it>.pt`` files.
        epoch: Epoch index embedded in the checkpoint file name.
        device: Passed to ``torch.load`` as ``map_location``. The default is
            evaluated once, when this function is defined.

    Returns:
        Whatever object ``torch.load`` deserialises from the checkpoint file.
    """
    ckpt_path = f'{dirname}/ckpt_epoch_{epoch}_iter_{it}.pt'
    checkpoint = torch.load(ckpt_path, map_location=device)
    return checkpoint

In [4]:
# Load the checkpoint for epoch 0, iteration 2 from results/scratch/tx0a4k6g.
# NOTE(review): `model_info` is not referenced again in the visible cells, and
# the GPT instance built further down is not populated from it. Confirm
# whether the checkpoint weights were meant to be loaded into the model
# before running the LLC estimation.
model_info = load_model_for_iteration(2, 'results/scratch/tx0a4k6g', epoch = 0)

  return torch.load(f'{dirname}/{fname}', map_location=device)


In [5]:
class modelcfg:
    """Model settings; an instance is handed to the ``GPT`` constructor via ``config.model``."""

    # Size-related settings
    n_layer = 2
    n_head = 2
    n_embd = 128
    context_size = 256

    # Boolean switches
    mlp = True
    bias = False
    compile = False

    dropout = 0.0

class datacfg:
    """Data settings; an instance is read by the cell that calls ``get_dataloader``."""

    n_relative_properties = 100
    n_descriptive_properties = 460
    n_descriptive_values = 40
    n_entities = 100
    num_of_classes_to_divide_over = 10
    prior_param = 0.1
    props_prior_type = 'structured_zeros'
    instr_ratio = 0.8
    max_sample_length = 128
    # Stored as an int (previously the float literal 1e5) so that the
    # `num_iters * batch_size` product passed to get_dataloader is an integer
    # count rather than a float. The value itself is unchanged.
    num_iters = 100_000
    batch_size = 128
    num_workers = 0

In [6]:
class configClass:
    """Top-level settings bundle: seed, model and data sub-configs, and device string."""

    # Use the GPU whenever torch reports one as available.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    data = datacfg()
    model = modelcfg()
    seed = 2


# Single instance read by the cells that build the dataloader and the model.
config = configClass()

In [7]:
# Fields of `config.data` that are forwarded to get_dataloader under the same name.
_passthrough_fields = (
    "n_relative_properties",
    "n_descriptive_properties",
    "n_descriptive_values",
    "num_of_classes_to_divide_over",
    "prior_param",
    "props_prior_type",
    "n_entities",
    "instr_ratio",
    "max_sample_length",
    "batch_size",
    "num_workers",
)
_dataloader_kwargs = {
    field: getattr(config.data, field) for field in _passthrough_fields
}

dataloader = get_dataloader(
    **_dataloader_kwargs,
    # `num_iters` is multiplied by the batch size before being handed over.
    num_iters=config.data.num_iters * config.data.batch_size,
    seed=config.seed,
)

In [8]:
# Build a GPT from the model config and the dataset's vocabulary size.
# NOTE(review): no checkpoint weights are loaded into `model` in the visible
# cells, so the cells below operate on whatever state GPT's constructor
# produces. Confirm this is intended.
model = GPT(config.model, dataloader.dataset.PCSG.vocab_size)

In [26]:
import torch.nn.functional as F
from utils import move_to_device

# Padding token id taken from the dataset; evaluate_fn reads this module-level
# value to mask padded label positions.
pad_token_id = dataloader.dataset.pad_token_id
def evaluate_fn(model, data):
    """Compute the next-token cross-entropy of one batch, averaged over non-padding tokens.

    Padded label positions are excluded from both the sum and the divisor, so
    the returned loss does not shrink as the amount of padding in a batch
    grows. (Averaging the ``reduction='none'`` output over every position would
    count each padded position as a zero-loss token.)

    Args:
        model: Callable mapping a batch of input token ids to logits whose
            last dimension indexes the vocabulary.
        data: A 5-element batch from the dataloader. Only the first element
            (``sequences``) is used; it is indexed as a 2-D tensor below.

    Returns:
        A ``(loss, extras)`` tuple: a scalar loss tensor and an empty dict.
    """
    sequences, _symb_sequences, _seq_lengths, _seq_logprobs, _ = data
    # Shift by one position: inputs drop the last token, labels drop the first.
    inputs, labels = move_to_device([sequences[:, :-1], sequences[:, 1:]], config.device)
    # Clone before the in-place write below, in case move_to_device returned a
    # view that shares storage with `sequences`.
    labels = labels.clone()
    labels[labels == pad_token_id] = -100  # Mask padding
    logits = model(inputs)  # (B, L-1, V)
    flat_labels = labels.reshape(-1)
    per_token_loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        flat_labels,
        ignore_index=-100,
        reduction='none',
    )  # (B*(L-1),); ignored positions contribute 0
    # Divide by the number of real (non-padding) targets. The clamp keeps an
    # all-padding batch from dividing by zero; its loss is then 0.
    n_valid = (flat_labels != -100).sum().clamp(min=1)
    loss = per_token_loss.sum() / n_valid
    return loss, {}

In [10]:
def estimate_llc_given_model(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    evaluate: typing.Callable,
    epsilon: float,
    beta: float,
    sampling_method: Type[torch.optim.Optimizer] = SGLD,
    localization: float = 100.0,
    num_chains: int = 5,
    num_draws: int = 300,
    num_burnin_steps: int = 0,
    num_steps_bw_draws: int = 1,
    device: torch.device = torch.device("cpu"),
    online: bool = True,
    verbose: bool = False,
):
    """Run ``estimate_learning_coeff_with_summary`` for one (epsilon, beta) setting.

    Args:
        model: Model forwarded to the estimator.
        loader: Dataloader forwarded to the estimator.
        evaluate: Evaluation callable forwarded to the estimator.
        epsilon: Forwarded to the sampler as ``lr``.
        beta: Forwarded to the sampler as ``nbeta``.
        sampling_method: Optimizer class used as the sampler.
        localization: Forwarded to the sampler as ``localization``.
        num_chains: How many independent chains to run.
        num_draws: How many samples to draw per chain.
        num_burnin_steps: How many samples to discard at the beginning of each chain.
        num_steps_bw_draws: How many steps to take between each sample.
        device: Device forwarded to the estimator.
        online: Forwarded to the estimator unchanged.
        verbose: Forwarded to the estimator unchanged.

    Returns:
        The estimator's result mapping, with its ``"llc/trace"`` entry replaced
        by a NumPy array built from the original value.
    """
    sampler_settings = {"lr": epsilon, "localization": localization, "nbeta": beta}

    results = estimate_learning_coeff_with_summary(
        model,
        loader=loader,
        evaluate=evaluate,
        sampling_method=sampling_method,
        optimizer_kwargs=sampler_settings,
        num_chains=num_chains,
        num_draws=num_draws,
        num_burnin_steps=num_burnin_steps,
        num_steps_bw_draws=num_steps_bw_draws,
        device=device,
        online=online,
        verbose=verbose,
    )

    # Replace the stored trace with an ndarray built from it.
    trace = results["llc/trace"]
    results["llc/trace"] = np.array(trace)
    return results

In [21]:
# Smoke test: run evaluate_fn on the first batch only and print what it returns.
# NOTE(review): the saved output under this cell shows a vector of losses,
# whereas evaluate_fn as defined in this notebook returns a scalar and a dict.
# The output looks stale; re-run to refresh it.
for first_batch in dataloader:
    a = evaluate_fn(model, first_batch)
    print(a)
    break

(tensor([1.6524, 1.3355, 1.2542, 1.2507, 1.2550, 2.5033, 2.0303, 2.6794, 1.0213,
        0.9344, 2.5077, 1.4020, 1.2489, 1.5605, 0.9323, 2.4930, 1.2510, 1.8688,
        1.2473, 1.2522, 1.5683, 1.4128, 1.5596, 1.2506, 0.9384, 1.6465, 1.4054,
        1.5553, 1.2573, 1.4103, 0.9306, 2.1877, 1.2576, 1.8736, 4.6827, 1.1718,
        2.2553, 1.0220, 1.8711, 1.2460, 1.0210, 0.9403, 0.9311, 3.7299, 1.2570,
        1.0198, 1.5526, 1.7133, 1.4072, 1.1841, 1.5742, 0.9407, 1.0966, 1.2432,
        1.1011, 1.5002, 0.9306, 1.1691, 2.0317, 1.1819, 1.0913, 1.2613, 1.0901,
        1.5583, 0.8598, 1.8004, 1.2568, 1.2561, 1.7272, 1.5683, 0.9394, 3.2715,
        0.9335, 1.0128, 1.2438, 3.5984, 1.0947, 1.0351, 0.9350, 1.7110, 1.4045,
        1.2592, 2.7183, 1.0263, 3.8890, 1.4931, 0.9389, 1.1781, 2.0471, 1.0212,
        1.2586, 0.9284, 1.2574, 1.8647, 1.5488, 1.3260, 1.2442, 1.0201, 1.8777,
        1.8766, 3.5798, 1.2548, 0.9314, 1.4752, 2.6526, 1.7247, 1.1709, 1.2625,
        1.4092, 2.5129, 1.0910, 1.2497,

In [29]:
# Fixed arguments given to the analyzer alongside the estimator; presumably they
# are passed to estimate_llc_given_model on each sweep point. TODO confirm.
llc_kwargs = {
    "model": model,
    "loader": dataloader,
    "evaluate": evaluate_fn,
    "device": config.device,
}

analyzer = EpsilonBetaAnalyzer()
analyzer.configure_sweep(
    llc_estimator=estimate_llc_given_model,
    llc_estimator_kwargs=llc_kwargs,
    # 8 epsilon values between 1e-6 and 1e-2
    min_epsilon=1e-6,
    max_epsilon=1e-2,
    epsilon_samples=8,
    # Automatically find a beta range from the optimal beta
    min_beta=None,
    max_beta=None,
    beta_samples=8,
    dataloader=dataloader,
)
analyzer.sweep()

  0%|          | 0/64 [00:00<?, ?it/s]

In [None]:
# Preview the first rows of the analyzer's `sweep_df`.
analyzer.sweep_df.head()

In [None]:
# Single verbose 30-draw chain with fixed SGLD settings, calling the estimator
# directly instead of going through estimate_llc_given_model.
sgld_settings = {"lr": 1e-3, "localization": 1.0, "nbeta": 1.0}

estimate_learning_coeff_with_summary(
    model,
    loader=dataloader,
    evaluate=evaluate_fn,
    sampling_method=SGLD,
    optimizer_kwargs=sgld_settings,
    num_chains=1,
    num_draws=30,
    num_burnin_steps=0,
    num_steps_bw_draws=1,
    device=config.device,
    online=True,
    verbose=True,
)