In [None]:
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXModel
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import sys
# Make CUDA the default device for tensors created from here on.
torch.set_default_device("cuda")

# Mamba checkpoint that is loaded further down via MambaLMHeadModel.from_pretrained.
model_name = "state-spaces/mamba-130m"


def _load_tokenizer(repo_id):
    """Load a tokenizer and use "<|endoftext|>" as both its EOS and PAD token.

    Args:
        repo_id: Identifier handed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer with ``eos_token`` and ``pad_token`` overridden.
    """
    loaded = AutoTokenizer.from_pretrained(repo_id)
    loaded.eos_token = "<|endoftext|>"
    # Padding reuses the end-of-text token.
    loaded.pad_token = loaded.eos_token
    return loaded


# Tokenizer used to encode the prompt and decode generations.
tokenizer = _load_tokenizer("EleutherAI/gpt-neox-20b")

# Tokenizer loaded from a Pythia checkpoint, configured the same way.
tokenizer_pythia = _load_tokenizer("EleutherAI/pythia-160m-deduped")

In [None]:
from transformers import GPTNeoXForCausalLM, AutoTokenizer, GPTNeoXModel

# Pythia checkpoint whose outputs are compared against Mamba in later cells.
pythia_checkpoint = "EleutherAI/pythia-410m-deduped"

# from_pretrained options the author left disabled:
#   output_hidden_states=True
#   revision="step3000"
#   cache_dir="./pythia-70m-deduped/step3000"
pythia = GPTNeoXForCausalLM.from_pretrained(pythia_checkpoint)

# Move the weights onto the first CUDA device.
pythia = pythia.to(torch.device("cuda:0"))


In [None]:
# Load the Mamba checkpoint named by `model_name` onto CUDA as float16.
mamba_load_options = {"device": "cuda", "dtype": torch.float16}
mamba = MambaLMHeadModel.from_pretrained(model_name, **mamba_load_options)

In [None]:
# Question to ask the model. It is hard-coded here; swap in the commented
# input(...) call to read it interactively instead.
user_message = "Give me three steps to improve my diet, and include some evidence"  # input("\n> ")

# Few-shot examples shown to the model before the real question.
n_shot_prompting = [
    {"question": "What is the capital of France?", "answer": "Paris"},
    {"question": "Who invented the segway?", "answer": "Dean Kamen"},
    {"question": "What is the fastest animal?", "answer": "Cheetah"},
]

# Prompt layout: instructions, one "Q: ... / A: ..." pair per example, then
# the user's question with an open "A:" for the model to complete. Sections
# are separated by a blank line.
prompt_header = "You are a Trivia QA bot.\nAnswer the following question succinctly and accurately."
prompt_examples = [
    f"Q: {shot['question']}\nA: {shot['answer']}" for shot in n_shot_prompting
]
prompt_query = f"Q: {user_message}\nA:"
prompt = "\n\n".join([prompt_header, *prompt_examples, prompt_query])

# Show the assembled prompt so it can be checked by eye.
print(prompt)

# Tokenize the prompt and wrap the ids in an outer list so the resulting
# int64 tensor has a leading dimension of 1; it is created directly on CUDA.
prompt_token_ids = tokenizer.encode(prompt)
input_ids = torch.tensor([prompt_token_ids], dtype=torch.long, device="cuda")


In [None]:
# Generate a continuation of the prompt with the Mamba model loaded earlier
# in this notebook (bound to `mamba`; no top-level `model` is ever defined,
# so the previous `model.generate(...)` raised NameError on a fresh kernel).
# "out" holds raw token ids (prompt included), capped at 256 tokens in total;
# generation stops early if the end-of-text token is produced.
out = mamba.generate(
    input_ids=input_ids,
    max_length=256,
    eos_token_id=tokenizer.eos_token_id,
)

In [None]:
# Both models get the same token ids and are asked for their hidden states,
# so the next cells can compare logits and hidden states between them.
# NOTE(review): later cells read `mamba_out.hidden_states`; confirm the
# installed MambaLMHeadModel.forward accepts `output_hidden_states`.
forward_kwargs = {"input_ids": input_ids, "output_hidden_states": True}

pythia_out = pythia(**forward_kwargs)
mamba_out = mamba(**forward_kwargs)

In [None]:
# Distance between the two models' softmax outputs over the last logits axis.
# The logit widths differ, so only the vocabulary entries both models share
# are compared. The width is derived from the tensors instead of the previous
# hard-coded 50280, which only matched one specific pair of checkpoints.
shared_vocab = min(pythia_out.logits.shape[-1], mamba_out.logits.shape[-1])

# Softmax is taken over each model's full vocabulary first, then truncated.
pythia_probs = pythia_out.logits.softmax(dim=2)[:, :, :shared_vocab]
mamba_probs = mamba_out.logits.softmax(dim=2)[:, :, :shared_vocab]

# Norm of the probability difference at each position, averaged over all positions.
teacher_loss = (pythia_probs - mamba_probs).norm(dim=2).mean()
teacher_loss

In [None]:
import torch.nn.functional as F
import math
# First entry of each model's hidden-state tuple.
pythia_h0 = pythia_out.hidden_states[0]
mamba_h0 = mamba_out.hidden_states[0]
pythia_width = pythia_h0.shape[-1]
mamba_width = mamba_h0.shape[-1]

# Random Gaussian matrix mapping Pythia's hidden width onto Mamba's so the
# two can be compared; the leading 1 broadcasts over the batch dimension.
mu = 0
std = math.sqrt(1.0 / mamba_width)
size = (1, pythia_width, mamba_width)
W = torch.normal(mu, std, size).to(torch.device("cuda:0"))

# Per-position cosine similarity between projected Pythia states and Mamba states.
# No RNG seed is set, so the values change from run to run.
F.cosine_similarity(pythia_h0 @ W, mamba_h0, dim=2)


In [None]:
# Print index and shape of every hidden state, Pythia's first and then Mamba's.
for model_outputs in (pythia_out, mamba_out):
    for k, X in enumerate(model_outputs.hidden_states):
        print(k, X.shape)

In [None]:
# Turn the generated token ids back into text; the first decoded sequence is used.
decoded_batch = tokenizer.batch_decode(out)
decoded = decoded_batch[0]

# Separator line followed by the full decoded text (prompt + continuation).
print("=" * 80, decoded, sep="\n")

# The decoded text still contains the prompt, so strip every occurrence of it.
cleaned = decoded.replace(prompt, "")

# The model keeps generating further Q/A pairs; to keep only the first
# answer, enable the next line:
# cleaned = cleaned.split("\n\n")[0]
print(cleaned)

In [1]:
from training.train_mamba_with_pythia import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def run(args):
    """Fine-tune a pretrained Mamba LM with ``MambaTrainer`` and save the result.

    Args:
        args: Object exposing the attributes read below: ``model``,
            ``tokenizer``, ``data_path``, ``learning_rate``, ``num_epochs``,
            ``batch_size``, ``gradient_accumulation_steps``, ``optim`` and
            ``output``.

    Returns:
        None. ``args.output`` serves both as the trainer's ``output_dir`` and
        as the destination passed to ``save_model``.
    """
    # Pretrained Mamba weights, loaded in bfloat16 on the GPU.
    mamba_model = MambaLMHeadModel.from_pretrained(
        args.model, dtype=torch.bfloat16, device="cuda"
    )

    # Tokenizer with "<|endoftext|>" acting as both EOS and PAD.
    tok = AutoTokenizer.from_pretrained(args.tokenizer)
    tok.eos_token = "<|endoftext|>"
    tok.pad_token = tok.eos_token

    # Supplies the training dataset and the matching collator.
    sft_data = SFTDataModule(tokenizer=tok, data_path=args.data_path)

    training_args = TrainingArguments(
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        optim=args.optim,
        output_dir=args.output,
        save_total_limit=2,
        logging_steps=50,
        save_steps=500,
    )

    mamba_trainer = MambaTrainer(
        model=mamba_model,
        train_dataset=sft_data.dataset,
        tokenizer=tok,
        args=training_args,
        data_collator=sft_data.data_collator,
    )

    mamba_trainer.train()
    mamba_trainer.save_model(args.output)

In [4]:
# (flag, type, default) for every option consumed by run().
_CLI_OPTIONS = (
    ("--model", str, "state-spaces/mamba-130m"),
    ("--output", str, "output"),
    ("--tokenizer", str, "EleutherAI/gpt-neox-20b"),
    ("--learning_rate", float, 5e-4),
    ("--batch_size", int, 4),
    ("--gradient_accumulation_steps", int, 1),
    ("--optim", str, "adamw_torch"),
    ("--data_path", str, "squad"),
    ("--num_epochs", int, 10),
)

parser = argparse.ArgumentParser()
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)

# Parse an empty argument list so every option takes its default; there is
# no real command line inside a notebook.
args = parser.parse_args([])

In [5]:
# Start fine-tuning with the parsed defaults. This is long-running: it loads
# the model and tokenizer, then trains and saves to args.output.
run(args)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Got 0 examples, preprocess...
Tokenizing dataset...


100%|██████████| 87599/87599 [01:57<00:00, 746.07it/s]


Step,Training Loss
50,0.2793
100,0.2428
150,0.3018
200,0.2643
250,0.2273
300,0.2673
350,0.2498
400,0.2147
450,0.216
500,0.2658


KeyboardInterrupt: 