In [None]:
python -m venv mamba_env
pip install .

In [6]:
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import sys


# Name of the pretrained Mamba checkpoint handed to from_pretrained() later on
model_name = "state-spaces/mamba-130m"

# The tokenizer is loaded from the GPT-NeoX-20B repository.
# "<|endoftext|>" is set as the end-of-sequence token and reused for padding.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = "<|endoftext|>"

In [None]:
# Load the pretrained checkpoint onto the GPU with float16 weights
model = MambaLMHeadModel.from_pretrained(model_name, device="cuda", dtype=torch.float16)

In [None]:
# Question to ask the model. It is hard-coded for notebook use; replace the
# literal with input("\n> ") to read the question interactively instead.
user_message = "who is the prez?"

# Few-shot examples that show the model the expected "Q: ... / A: ..." layout
n_shot_prompting = [
    {
        "question": "What is the capital of France?",
        "answer": "Paris"
    },
    {
        "question": "Who invented the segway?",
        "answer": "Dean Kamen"
    },
    {
        "question": "What is the fastest animal?",
        "answer": "Cheetah"
    }
]

# Assemble the prompt: instruction header, the worked examples separated by
# blank lines, then the user's question as the final "Q:" entry.
prompt = "You are a Trivia QA bot.\nAnswer the following question succinctly and accurately."
prompt = f"{prompt}\n\n" + "\n\n".join(
    f"Q: {p['question']}\nA: {p['answer']}" for p in n_shot_prompting
)
prompt = f"{prompt}\n\nQ: {user_message}"

# Debug print to make sure our prompt looks good
print(prompt)

# Encode the prompt into token ids once, wrapped in an outer list so the tensor
# has a leading batch dimension of 1, and move it to the GPU.
input_ids = torch.LongTensor([tokenizer.encode(prompt)]).cuda()
print(input_ids)

In [None]:
# Let the model continue the prompt. `out` holds raw integer token ids;
# generation is bounded by max_length and by the tokenizer's end-of-text id.
out = model.generate(input_ids=input_ids, max_length=256, eos_token_id=tokenizer.eos_token_id)

In [None]:
# Plain forward pass over the prompt. The result is kept under its own name so
# that `out` still holds the token ids returned by model.generate(), which the
# decoding cell passes to tokenizer.batch_decode().
# NOTE(review): confirm that the model's forward() accepts output_hidden_states.
forward_out = model(
    input_ids=input_ids,
    output_hidden_states=True
)

In [None]:
# Token ids have to go back through the tokenizer to become text
decoded = tokenizer.batch_decode(out)[0]
print("=" * 80, decoded, sep="\n")

# The decoded text contains the prompt as well as the continuation,
# so remove the prompt to leave only what the model added.
cleaned = decoded.replace(prompt, "")

# The model may keep generating past the first answer; truncating at the first
# blank line is left disabled here:
# cleaned = cleaned.split("\n\n")[0]
print(cleaned)

In [13]:
from datasets import load_dataset

# Fetch the "squad" dataset; the captured output below shows train and validation splits
dataset = load_dataset(path="squad")

Downloading readme: 100%|██████████| 7.83k/7.83k [00:00<00:00, 13.3MB/s]
Downloading data: 100%|██████████| 14.5M/14.5M [00:00<00:00, 19.2MB/s]
Downloading data: 100%|██████████| 1.82M/1.82M [00:00<00:00, 11.4MB/s]
Generating train split: 100%|██████████| 87599/87599 [00:00<00:00, 1251147.01 examples/s]
Generating validation split: 100%|██████████| 10570/10570 [00:00<00:00, 1143183.34 examples/s]


In [18]:
# Display the first training record to inspect its fields
dataset["train"][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}