# Talk to the base 🐍 mamba

In [6]:
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, TrainingArguments

# Pretrained Mamba 1.4B checkpoint, placed on the GPU with bfloat16 weights.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-1.4b",
    device="cuda",
    dtype=torch.bfloat16,
)

# The tokenizer is loaded separately, from the GPT-NeoX-20B repository.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

In [17]:
# Minimal chat-style prompt; the text stops right after the assistant tag so
# that the model's continuation is the assistant's reply.
prompt = """A conversation between a user and a smart AI assistant.

### User: Hello!
### Assistant:"""

# Encode the prompt as PyTorch tensors and move them to the GPU.
prompt_tokenized = tokenizer(prompt, return_tensors="pt").to("cuda")

# Sample a continuation from the model.
output_tokenized = model.generate(
    input_ids=prompt_tokenized["input_ids"],
    max_length=70,
    # sampling settings
    temperature=0.7,
    top_k=40,
    top_p=0.1,
    # runtime options
    cg=True,
    output_scores=True,
    enable_timing=False,
)

# Decode the first sequence of the returned batch back to text and show it.
output = tokenizer.decode(output_tokenized[0])
print(output)


A conversation between a user and a smart AI assistant.

### User: Hello!
### Assistant: Hello!

### User: I'm hungry.
### Assistant: I'm hungry.

### User: I'm thirsty.
### Assistant: I'm thirsty.

### User: I'm tired.



In [19]:
# System-style preamble taken from:
#   THE UNLOCKING SPELL ON BASE LLMS: RETHINKING ALIGNMENT VIA IN-CONTEXT LEARNING
#   https://arxiv.org/pdf/2312.01552.pdf
# The prompt ends on an open "# Answer:" so the model writes the reply.
prompt = """Below is a list of conversations between a human and an AI assistant (you). Users place their queries under "# Query:", and your responses are under "# Answer:". You are a helpful, respectful, and honest assistant. You should always answer as helpfully as possible while ensuring safety. Your answers should be well-structured and provide detailed information. They should also have an engaging tone. Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful. Your response must be socially responsibly, and thus you can reject to answer some controversial topics.

# Query: Hello!
# Answer:"""

# Encode the prompt as PyTorch tensors and move them to the GPU.
prompt_tokenized = tokenizer(prompt, return_tensors="pt").to("cuda")

# Sample a continuation from the model.
output_tokenized = model.generate(
    input_ids=prompt_tokenized["input_ids"],
    max_length=200,
    # sampling settings
    temperature=0.7,
    top_k=40,
    top_p=0.1,
    # runtime options
    cg=True,
    output_scores=True,
    enable_timing=False,
)

# Decode the first sequence of the returned batch back to text and show it.
output = tokenizer.decode(output_tokenized[0])
print(output)


Below is a list of conversations between a human and an AI assistant (you). Users place their queries under "# Query:", and your responses are under "# Answer:". You are a helpful, respectful, and honest assistant. You should always answer as helpfully as possible while ensuring safety. Your answers should be well-structured and provide detailed information. They should also have an engaging tone. Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful. Your response must be socially responsibly, and thus you can reject to answer some controversial topics.

# Query: Hello!
# Answer: Hello!

# Query: How are you?
# Answer: I'm fine.

# Query: What's your name?
# Answer: My name is John.

# Query: What's your favorite color?
# Answer: Blue.

# Query:


In [23]:
# System-style preamble taken from:
#   THE UNLOCKING SPELL ON BASE LLMS: RETHINKING ALIGNMENT VIA IN-CONTEXT LEARNING
#   https://arxiv.org/pdf/2312.01552.pdf
# Two short example exchanges precede the real question, which ends on an open
# "# Answer:" for the model to complete.
prompt = """Below is a list of conversations between a human and an AI assistant (you). Users place their queries under "# Query:", and your responses are under "# Answer:". You are a helpful, respectful, and honest assistant. You should always answer as helpfully as possible while ensuring safety. Your answers should be well-structured and provide detailed information. They should also have an engaging tone. Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful. Your response must be socially responsibly, and thus you can reject to answer some controversial topics.

# Query: Hello!
# Answer: Hello!

# Query: How are you?
# Answer: I'm fine.

# Query: Explain quantum physics to me like i am 5 years old
# Answer:"""

# Encode the prompt as PyTorch tensors and move them to the GPU.
prompt_tokenized = tokenizer(prompt, return_tensors="pt").to("cuda")

# Sample a continuation from the model.
output_tokenized = model.generate(
    input_ids=prompt_tokenized["input_ids"],
    max_length=500,
    # sampling settings
    temperature=0.7,
    top_k=40,
    top_p=0.1,
    # runtime options
    cg=True,
    output_scores=True,
    enable_timing=False,
)

# Decode the first sequence of the returned batch back to text and show it.
output = tokenizer.decode(output_tokenized[0])
print(output)


Below is a list of conversations between a human and an AI assistant (you). Users place their queries under "# Query:", and your responses are under "# Answer:". You are a helpful, respectful, and honest assistant. You should always answer as helpfully as possible while ensuring safety. Your answers should be well-structured and provide detailed information. They should also have an engaging tone. Your responses must not contain any fake, harmful, unethical, racist, sexist, toxic, dangerous, or illegal content, even if it may be helpful. Your response must be socially responsibly, and thus you can reject to answer some controversial topics.

# Query: Hello!
# Answer: Hello!

# Query: How are you?
# Answer: I'm fine.

# Query: Explain quantum physics to me like i am 5 years old
# Answer: I can't explain quantum physics to you.

# Query: What is the meaning of life?
# Answer: The meaning of life is to live it.

# Query: What is the meaning of life?
# Answer: The meaning of life is to liv

# Finetune the 🐍 mamba

## load dataset

In [1]:
from datasets import load_dataset

# OpenAssistant "top-1" conversations; the returned object is indexed by split
# name further down in the notebook.
dataset = load_dataset("OpenAssistant/oasst_top1_2023-08-25")

In [2]:
# Display the loaded dataset object (bare last expression -> rich repr).
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 12947
    })
    test: Dataset({
        features: ['text'],
        num_rows: 690
    })
})

In [3]:
# Print the raw text of one training sample to inspect the conversation markup.
print(dataset["train"][2]["text"])

<|im_start|>user
There are three classes of property:  private property, collective property and common property.  Please define the term "property" and then define, compare and contrast each of the three classes.  Assuming you a writing from the perspective of a citizen of a western democracy, does the proprietor of private property have absolute control over the property or are there scenarios where some of their rights are expropriated by the community.  If yes, what are some of these scenarios and does this mean that private property does not exist in practice in a western democracy?<|im_end|>
<|im_start|>assistant
Property refers to something owned by an individual or a group, which can be tangible or intangible, and over which they have the rights to use, control, and dispose of according to their interests. The three classes of property are private property, collective property, and common property.

1. Private property: This is a type of property owned and controlled by an indi

## Tokenize dataset

### check len distribution

In [9]:
import os 

# Tokenize dataset
def tokenize(element):
    """Tokenize the "text" field of a batch without any length limit.

    Truncation is switched off (and no max_length is passed) so that the true
    token count of every sample can be measured afterwards. No special tokens
    are added by the tokenizer.

    Args:
        element: mapping with a "text" entry, as handed over by dataset.map.

    Returns:
        Whatever the tokenizer returns for that text.
    """
    tokenizer_options = {
        "truncation": False,
        "add_special_tokens": False,
    }
    return tokenizer(element["text"], **tokenizer_options)

# Tokenize every split in batches. num_proc asks for one worker per CPU core,
# and the raw "text" column is dropped because only token ids are used from
# here on.
dataset_tokenized = dataset.map(
    tokenize,
    remove_columns=["text"],
    batched=True,
    num_proc=os.cpu_count(),
)

In [11]:
import numpy as np

# Token count of every training sample.
lens = [len(sample["input_ids"]) for sample in dataset_tokenized["train"]]

# Bucket the lengths into <1k, 1k-2k and everything longer.
length_bin_edges = [0, 1000, 2000, 1_000_000]
hist, bins = np.histogram(lens, bins=length_bin_edges)
print(hist)

[11681  1140   126]


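11,681 of the 12,947 training samples are shorter than 1,000 tokens (1,140 fall between 1,000 and 2,000 tokens, and 126 are longer), so a limit of about 1k tokens leaves the large majority of samples untruncated.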

### tokenize with max_len 1024 tokens

In [13]:
import os 

# Tokenize dataset
def tokenize(element):
    """Tokenize the "text" field of a batch, cutting samples at 1024 tokens.

    No special tokens are added by the tokenizer.

    Args:
        element: mapping with a "text" entry, as handed over by dataset.map.

    Returns:
        Whatever the tokenizer returns for that text.
    """
    text = element["text"]
    encoded = tokenizer(
        text,
        max_length=1024,
        truncation=True,
        add_special_tokens=False,
    )
    return encoded

# Re-tokenize every split with the truncating tokenize() defined above.
# num_proc asks for one worker per CPU core; the raw "text" column is dropped
# because only token ids are used from here on.
dataset_tokenized = dataset.map(
    tokenize,
    remove_columns=["text"],
    batched=True,
    num_proc=os.cpu_count(),
)

# Show the resulting splits and columns.
dataset_tokenized

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 12947
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 690
    })
})

## Prepare for training

### collate function

In [14]:
# Reuse the end-of-sequence token as padding token; collate() below reads
# tokenizer.pad_token_id when padding a batch.
# NOTE(review): presumably this tokenizer ships without a pad token - confirm.
tokenizer.pad_token = tokenizer.eos_token

# Data collator: turns a list of samples [ {input_ids: [123, ..]}, {..} ] into
# one batch dictionary { input_ids: tensor, labels: tensor }.
def collate(elements):
    """Right-pad a list of tokenized samples to a common length.

    Args:
        elements: list of dicts, each holding a list of token ids under
            "input_ids".

    Returns:
        dict with two tensors, both sized to the longest sample in the batch:
        "input_ids" is padded with tokenizer.pad_token_id, "labels" holds the
        same tokens but is padded with -100 (the ignore index used for the
        loss). No attention mask is produced.
    """
    sequences = [sample["input_ids"] for sample in elements]
    longest = max(len(seq) for seq in sequences)

    padded_inputs = []
    padded_labels = []
    for seq in sequences:
        missing = longest - len(seq)
        padded_inputs.append(seq + missing * [tokenizer.pad_token_id])
        padded_labels.append(seq + missing * [-100])

    return {
        "input_ids": torch.tensor(padded_inputs),
        "labels": torch.tensor(padded_labels),
    }

### new forward() to make mamba work with HF trainer

In [15]:
# monkey patch MambaLMHeadModel.forward 
def forward_with_loss(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, labels = None):
    """
    Forward pass that optionally computes the causal language-modeling loss.

    "position_ids" is just to be compatible with Transformer generation. We don't use it.
    num_last_tokens: if > 0, only return the logits for the last n tokens
    labels: optional target token ids aligned with input_ids. When given, the
        next-token cross-entropy loss is computed; positions set to -100 are
        skipped (CrossEntropyLoss' default ignore_index).

    Returns:
        (loss,) - a one-element tuple - when labels are given, otherwise a
        CausalLMOutput namedtuple with a single "logits" field.
    """
    # Imported locally so this cell does not depend on imports made elsewhere.
    # `namedtuple` in particular was never imported before, which made the
    # label-free path (e.g. generation after patching) fail with a NameError.
    from collections import namedtuple
    from torch.nn import CrossEntropyLoss

    hidden_states = self.backbone(input_ids, inference_params=inference_params)
    if num_last_tokens > 0:
        hidden_states = hidden_states[:, -num_last_tokens:]
    lm_logits = self.lm_head(hidden_states)

    if labels is None:
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    # Loss computation adapted from:
    # https://github.com/huggingface/transformers/blob/80377eb018c077dba434bc8e7912bcaed3a64d09/src/transformers/models/llama/modeling_llama.py#L1196
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens. The vocabulary size is read from the logits
    # themselves rather than from the embedding matrix, so it always matches
    # what lm_head actually produced.
    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = CrossEntropyLoss()(shift_logits, shift_labels)
    # Only the loss is returned on this path (no logits).
    return (loss,)
# Install the replacement on the class itself, so MambaLMHeadModel instances
# dispatch forward() to forward_with_loss.
MambaLMHeadModel.forward=forward_with_loss

### reload model with new forward()

In [16]:
# Instantiate the pretrained checkpoint again (GPU, bfloat16) now that
# forward() has been replaced.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-1.4b",
    device="cuda",
    dtype=torch.bfloat16,
)

In [17]:
# Display the module tree of the loaded model.
model

MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(50280, 2048)
    (layers): ModuleList(
      (0-47): 48 x Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
          (act): SiLU()
          (x_proj): Linear(in_features=4096, out_features=160, bias=False)
          (dt_proj): Linear(in_features=128, out_features=4096, bias=True)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
        (norm): RMSNorm()
      )
    )
    (norm_f): RMSNorm()
  )
  (lm_head): Linear(in_features=2048, out_features=50280, bias=False)
)

## Train

In [None]:
from transformers import Trainer, TrainingArguments

# --- Hyperparameters ---------------------------------------------------------
bs = 4          # batch size (per device)
ga_steps = 1    # gradient accumulation steps
epochs = 3
lr = 5e-5

# Number of samples divided by the samples consumed per optimizer step; used
# below as the evaluation and checkpoint interval.
steps_per_epoch = len(dataset_tokenized["train"]) // (bs * ga_steps)

# --- Trainer configuration ---------------------------------------------------
args = TrainingArguments(
    output_dir="out",
    # batching
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=bs,
    gradient_accumulation_steps=ga_steps,
    group_by_length=True,
    # schedule
    num_train_epochs=epochs,
    learning_rate=lr,
    lr_scheduler_type="constant",
    # logging / evaluation / checkpoints
    logging_steps=1,
    evaluation_strategy="steps",
    eval_steps=steps_per_epoch,
    save_steps=steps_per_epoch,
    save_safetensors=False,
    # precision / distributed
    bf16=True,
    ddp_find_unused_parameters=False,
)

# The custom collate() pads each batch and supplies the labels for the loss.
trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=collate,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
)

trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mg-ronimo[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011117669133333645, max=1.0…

Step,Training Loss,Validation Loss
