In [1]:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from config import ModelConf, TrainConf, OrthoMappingConf
from moe import OlmoeModel


In [18]:
model_conf = ModelConf(
    D = 768, 
    H = 12,
    I = 3072,
    n_experts = 16,
    
    n_shared_experts = 0,
    top_k = 2,
    norm_topk_prob = False,
    n_layers = 12,
    max_position_embeddings = 1024,
    main_device = 'cuda:0'
)

or_conf = OrthoMappingConf()
seed = 3456

torch.set_default_dtype(torch.bfloat16)
torch.set_float32_matmul_precision('medium') # See https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html 
torch.manual_seed(seed)
model = OlmoeModel(
    model_conf,
    or_conf,
    primary_device = model_conf.main_device, 
    expert_device_map = [model_conf.main_device] * model_conf.n_experts 
)

In [22]:


checkpoint_path = "/workspace/interpretable-moes/experiments_cli/current/saves/baseline_small_v2/checkpoint_00032500.pt"
checkpoint = torch.load(checkpoint_path)
state_dict = checkpoint['model_state_dict']
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)



<All keys matched successfully>

In [49]:
"""
Qualitative test
"""
from transformers import AutoTokenizer
import torch

prompt = 'Answer the following question: What is the capital of France?'

tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)
inputs = tokenizer(prompt,  truncation = True, max_length = 128, padding = 'max_length',return_tensors = 'pt').to(model_conf.main_device)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']


In [53]:

# Iteratively generate tokens
with torch.no_grad():
    for _ in range(255):
        output = model(input_ids, attention_mask, moe_method = 'forward_slow', use_checkpointing = False)['logits']

        next_token_id = torch.argmax(output[0, -1, :], dim = -1).unsqueeze(0)

        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim = 1)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype = torch.long, device = input_ids.device)], dim = 1)

        if next_token_id.item() in [tokenizer.eos_token_id, tokenizer.encode('\n')[0]]:
            break

# Decode final sequence
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens = True)
print(generated_text)



Answer the following question: What is the capital of France?|||IP_ADDRESS||| of the problem.
- The problem is that the problem is not a problem at all.
- The problem is a problem at all.
- The problem is a problem at all.



In [54]:
from datasets import load_dataset

dataset = load_dataset('HuggingFaceTB/smoltalk', '1.0', split = 'train')





ValueError: BuilderConfig '1.0' not found. Available: ['all', 'smol-magpie-ultra', 'smol-constraints', 'smol-rewrite', 'smol-summarize', 'apigen-80k', 'everyday-conversations', 'explore-instruct-rewriting', 'longalign', 'metamathqa-50k', 'numina-cot-100k', 'openhermes-100k', 'self-oss-instruct', 'systemchats-30k']