In [1]:
import logging
import os
import random
import sys
import time
from collections import defaultdict

import fire
import mup
import numpy as np
import torch
from omegaconf import OmegaConf

import lib.datasets
import lib.ddp
import lib.models
import lib.utils
from lib.datasets import get_dataloaders

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Evaluation configuration. The trailing comment on each entry records the
# reference value (or the accepted options) for that setting.
args = {
    # experiment
    "runs": 1,              # 1
    "fix_src": True,        # fix context
    "cot": False,           # False; thought-level diffusion: question + previous CoT -> next thought
    "cot_steps": None,      # 12
    "digit": True,          # use the Digits pre_tokenizer
    "limit": True,          # debug (< 5 instances); NOTE(review): no active cell here reads args.limit

    # dataset
    "dataset": "4by4",      # 4by4, 5by5, gsm8k, boolean
    "seq_len": 256,         # 384 for "boolean", 256 otherwise
    "vocab_size": 1024,     # follows the pretrained LM
    "model_name": "gpt",    # gpt, sedd, mdlm
    "model_size": "small",  # small, medium

    # model
    "dim": 2048,            # 2048
    "n_blocks": 24,         # 24
    "n_heads": 32,          # 32
    "embed_dim": 16,        # 16

    # training
    "batch_size": 168,      # 168

    # diffusion
    "sampling_timesteps": 64,  # 64

    # sampling
    "logit_sample": False,  # False
    "logit_temp": 0.5,      # 0.5

    # noise schedule (possibly only relevant to Plaid models — TODO confirm).
    # Both bounds are written as floats, matching each other and their
    # reference values, so GammaBounds receives the same type for each.
    "gamma_0": -3.,         # -3.
    "gamma_1": 6.,          # 6.
    # "initial_noise_scale": None,  # 1.0
    # "dpm_solver": None,           # False
    "score_temp": 0.5,      # 0.5

    # result aggregation — presumably self-consistency over `runs` samples (TODO confirm)
    "apply_sc": False,      # True: report acc / False: report mean
}
args = OmegaConf.create(args)
# Checkpoint (and log) directory name, e.g. "gpt-small" for the values above.
args.weights_path = f"{args.model_name}-{args.model_size}"

In [3]:
# Log-file stem: "eval-<timesteps>-score_<score_temp>", then an optional "-sc"
# suffix when apply_sc is set and "-logit-<logit_temp>" when logit_sample is set.
eval_log_name = (
    f"eval-{args.sampling_timesteps}-score_{args.score_temp}"
    + ("-sc" if args.apply_sc else "")
    + (f"-logit-{args.logit_temp}" if args.logit_sample else "")
)

In [4]:
# Send log records both to stdout and to <weights_path>/<eval_log_name>.log.
args.eval_log = os.path.join(args.weights_path, f"{eval_log_name}.log")

# Detach and close any handlers already on the root logger *before* touching the
# log file. logging.basicConfig() silently does nothing once the root logger has
# handlers, so re-running this cell in the same kernel would otherwise keep the
# old FileHandler (still writing into the file deleted just below) and leak the
# freshly opened one.
for _stale_handler in list(logging.getLogger().handlers):
    logging.getLogger().removeHandler(_stale_handler)
    _stale_handler.close()

if lib.ddp.rank() == 0:
    # Start from an empty log; only the rank-0 process deletes the old file.
    if os.path.exists(args.eval_log):
        os.remove(args.eval_log)

targets = logging.StreamHandler(sys.stdout), logging.FileHandler(args.eval_log, mode='w')
logging.basicConfig(format='[%(asctime)s] %(message)s', level=logging.INFO, handlers=targets)
# lib.utils.print_args(args)

# Allow TF32 math for CUDA matmuls and cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# New floating-point tensors default to float64 from here on; create_modules()
# casts each module back to float32 with .float().
torch.set_default_dtype(torch.float64)

def create_modules(dim, n_heads, cfg=None):
    """Build the diffusion components, each cast to float32 with ``.float()``.

    Args:
        dim: width passed to ``lib.models.DiffusionModel`` (the dimension the
            notebook varies when building muP base/delta copies).
        n_heads: number of heads passed to ``lib.models.DiffusionModel``.
        cfg: config object supplying ``gamma_0``, ``gamma_1``, ``vocab_size``,
            ``embed_dim`` and ``n_blocks``. Defaults to the notebook-level
            ``args`` so existing two-argument calls behave as before.

    Returns:
        dict mapping component name ('noise_schedule', 'gamma_bounds',
        'embedding_matrix', 'model') to its module. The weight-loading cell
        uses these names as checkpoint file stems (``<name>.pt``).
    """
    if cfg is None:
        # Resolve the global at call time, exactly as the original body did.
        cfg = args
    return {
        'noise_schedule': lib.models.NoiseSchedule().float(),
        'gamma_bounds': lib.models.GammaBounds(cfg.gamma_0, cfg.gamma_1).float(),
        'embedding_matrix': lib.models.EmbeddingMatrix(cfg.vocab_size, cfg.embed_dim).float(),
        'model': lib.models.DiffusionModel(dim, cfg.embed_dim, cfg.n_blocks, n_heads, cfg.vocab_size).float(),
    }

In [5]:
# Build the full-size modules plus two narrower copies. Per the mup API,
# set_base_shapes compares main/base/delta to decide which parameter
# dimensions scale with width; only the main modules are moved to the GPU.
modules = create_modules(args.dim, args.n_heads)
base_modules = create_modules(256, 4)
delta_modules = create_modules(128, 2)
for key, main in modules.items():
    mup.set_base_shapes(main, base_modules[key], delta=delta_modules[key])
    main.cuda()

# Restore every component from <weights_path>/<name>.pt directly onto the GPU.
logging.info(f'Loading weights from {args.weights_path}')
gpu = torch.device('cuda')
for name, module in modules.items():
    checkpoint_file = os.path.join(args.weights_path, f'{name}.pt')
    # NOTE(review): torch.load unpickles the file — only point weights_path at trusted checkpoints.
    # The loaded state dict is passed straight through so no global keeps the GPU copy alive.
    module.load_state_dict(torch.load(checkpoint_file, map_location=gpu))

# Log a summary of each component.
for key, module in modules.items():
    logging.info(key + ':')
    lib.utils.print_model(module)

# Only the test split is requested, so exactly one loader comes back.
(test_loader,), (word2idx, idx2word), tokenizer = get_dataloaders(
    args.dataset, args.batch_size, args.seq_len, args.cot, args.digit, only_test=True
)

# TODO: `evaluate` is not defined or imported anywhere in this notebook;
# bring it into scope before re-enabling this call.
# evaluate(
#     args,
#     test_loader,
#     tokenizer,
#     modules,
#     log_interval=True,
#     runs=args.runs,
#     apply_sc=args.apply_sc
# )

ModuleNotFoundError: No module named 'fused_layer_norm_cuda'

In [None]:
from ddms.sedd import SEDD

# Load a SEDD model through its `from_pretrained` constructor.
# NOTE(review): with the config cell's values, weights_path is "gpt-small"
# (model_name "gpt", not "sedd") — the same directory the previous cell reads
# `<name>.pt` state dicts from. Confirm it also holds a checkpoint that
# SEDD.from_pretrained understands.
sedd_source = args.weights_path
model = SEDD.from_pretrained(sedd_source)