In [None]:
%load_ext autoreload
%autoreload 2
import transformers
import torch
import os.path
import math
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from mcllm.model.llm import *
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.utils.data as data
from mcllm.data.synthetic import LowRankDataset

import lightning.pytorch as pl
import mcllm.config
import mcllm.model

In [None]:
# Build the synthetic low-rank dataset used by the cells below.
# (Only a Dataset is constructed here; it is not wrapped in a DataLoader.)
m, n = 10, 10        # first two positional arguments; also reused as the heatmap grid size when plotting
rank = 3
frac_nan_mask = 0.1  # presumably the fraction of entries masked as NaN — confirm against LowRankDataset
seed = 13
n_registers = 0
use_rowcol_attn = 1
dataset = LowRankDataset(
    m, n, rank, frac_nan_mask,
    seed=seed,
    n_registers=n_registers,
    use_rowcol_attn=use_rowcol_attn,
)

In [None]:
# Take the first sample of the dataset and draw each returned array as a heatmap
# in a 2x3 grid, titled with the variable name and its mean value.
x_nan_t, x_clean_t, nan_mask_t, att_mask_t, register_mask_t = dataset[0]
mat_dict = dict(
    x_nan=x_nan_t,
    x_clean=x_clean_t,
    nan_mask=nan_mask_t,
    att_mask=att_mask_t,
    register_mask=register_mask_t,
)
grid_shape = (m + n_registers, n + n_registers)
for panel, (name, mat) in enumerate(mat_dict.items(), start=1):
    plt.subplot(2, 3, panel)
    # att_mask is shown in its stored shape; every other array is reshaped to the grid first
    image = mat if name == 'att_mask' else mat.reshape(*grid_shape)
    plt.imshow(image)
    plt.title(f'{name} mean: {mat.mean().item():.2f}', fontsize='x-small')

# Final panel (the slot after the dict entries): row 0 of att_mask reshaped to the grid
plt.subplot(2, 3, len(mat_dict) + 1)
plt.title('att_mask row 0', fontsize='x-small')
plt.imshow(att_mask_t[0].reshape(*grid_shape))

# Load model

In [None]:

# Path to the 'epoch=163-step=1312.ckpt' entry of the 'attn=rowcol__reg=2__small'
# run under <repo>/results.
# NOTE: the next cell rebinds `checkpoint` to a file nested below this path.
checkpoint = os.path.join(mcllm.config.path_to_repo,
                          'results', 'attn=rowcol__reg=2__small', 'epoch=163-step=1312.ckpt')

In [None]:


# Resolve the checkpoint in two steps: first the '*.ckpt' path of the run, then the
# model-states file stored beneath it.
checkpoint = os.path.join(mcllm.config.path_to_repo,
                          'results', 'attn=rowcol__reg=2__small', 'epoch=163-step=1312.ckpt')
checkpoint = os.path.join(checkpoint, 'checkpoint', 'mp_rank_00_model_states.pt')

# NOTE(review): the dict loaded from this same file further down exposes a 'module'
# entry (plus 'ds_config' / 'ds_version') but no 'state_dict' key, so it looks like a
# DeepSpeed model-states file rather than a plain Lightning checkpoint — confirm that
# load_from_checkpoint accepts it, or convert the checkpoint first.
module = mcllm.model.llm.TabLLM.load_from_checkpoint(checkpoint)

In [12]:
# Load the raw checkpoint file into `d` so the following cells can inspect its contents.
# NOTE(review): no map_location is passed and the tensors displayed below report
# device='cuda:0', so this presumably needs a CUDA device — confirm, or pass
# map_location='cpu' when only inspecting. torch.load unpickles the file, so only
# use it on trusted checkpoints.
d = torch.load(checkpoint)

In [15]:
# Show the top-level keys of the loaded checkpoint dict.
print(list(d))

['module', 'buffer_names', 'optimizer', 'param_shapes', 'frozen_param_shapes', 'shared_params', 'frozen_param_fragments', 'lr_scheduler', 'data_sampler', 'random_ltd', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'global_samples', 'dp_world_size', 'mp_world_size', 'ds_config', 'ds_version', 'epoch', 'global_step', 'pytorch-lightning_version', 'loops', 'callbacks', 'lr_schedulers']


In [16]:
# Display the 'module' entry of the checkpoint; per the output below it is an
# OrderedDict mapping parameter names to tensors.
d['module']

OrderedDict([('embedding.val_embeddings.weight',
              tensor([[-0.9729],
                      [-0.4033],
                      [ 0.4942],
                      [-0.8247],
                      [-0.9634],
                      [ 0.1756],
                      [-0.5361],
                      [ 0.9403],
                      [ 1.1353],
                      [-0.5213],
                      [-0.8651],
                      [-0.2994],
                      [ 0.9544],
                      [ 0.2103]], device='cuda:0')),
             ('embedding.val_embeddings.bias',
              tensor([ 0.0717, -0.4895, -0.5916, -0.3393,  0.5874, -0.8226,  0.8649, -0.0522,
                      -0.1545,  0.5970,  0.9147, -0.3651, -0.1513,  0.4147], device='cuda:0')),
             ('embedding.layer_norm.weight',
              tensor([1.1379, 0.9941, 0.9588, 0.9586, 0.9993, 0.9572, 0.9790, 1.0478, 1.0072,
                      0.9880, 0.9890, 0.9979, 0.9894, 1.0435, 1.0893, 1.0717],
              