In [1]:
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset
from transformers import Trainer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

In [2]:
# for my mamba trainer

model_path = 'state-spaces/mamba-2.8b'

tokenizer = AutoTokenizer.from_pretrained(
    'EleutherAI/gpt-neox-20b',
)
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

model = MambaLMHeadModel.from_pretrained(model_path, dtype=torch.bfloat16, device="cuda")

def tokenize_data(x):
    full_input = f"""<|ORIGINAL_TEXT|>{x['original_text']}<|END_ORIGINAL_TEXT|>
<|GENERATED_TEXT|>{x['generated_text']}<|END_GENERATED_TEXT|>
<|PROMPT|>{x['prompt']}<|END_PROMPT|>"""
    result = tokenizer(
        full_input
    )
    result["input_ids"].append(tokenizer.eos_token_id)
    result["attention_mask"].append(1)
    return result

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# for loading and traing with HF implementation

model_path = 'state-spaces/mamba-2.8b-hf'

tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token=True)

model = AutoModelForCausalLM.from_pretrained(
    model_path, 
    torch_dtype=torch.bfloat16,
    # device_map={"": 0},
)

special_tokens_dict = {'additional_special_tokens': ['<|ORIGINAL_TEXT|>', '<|END_ORIGINAL_TEXT|>',
    '<|GENERATED_TEXT|>', '<|END_GENERATED_TEXT|>', '<|PROMPT|>', '<|END_PROMPT|>'
]}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))

def tokenize_data(x):
    full_input = f"""<|ORIGINAL_TEXT|>{x['original_text']}<|END_ORIGINAL_TEXT|>
<|GENERATED_TEXT|>{x['generated_text']}<|END_GENERATED_TEXT|>
<|PROMPT|>{x['prompt']}<|END_PROMPT|>"""
    result = tokenizer(
        full_input,
        max_length=512,
        padding='max_length',
        truncation=True
    )
    result["input_ids"].append(tokenizer.eos_token_id)
    result["attention_mask"].append(1)
    result['labels'] = result["input_ids"].copy()
    return result

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# dataset = load_dataset("parquet", data_files={'train': './data/train_out_1/*.parquet'})
dataset = load_dataset("parquet", data_files={'train': './data/train_data_2/cleaned.parquet'})
train_ds = dataset['train'].map(tokenize_data, load_from_cache_file=False)
train_ds = train_ds.remove_columns(["prompt", "original_text", "input", "generated_text"])

Map:   0%|          | 0/1971 [00:00<?, ? examples/s]

In [None]:
train_ds

In [None]:
import os
from typing import Optional


class MambaTrainer(Trainer):
    
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return lm_loss

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
            
        torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin")
        self.tokenizer.save_pretrained(output_dir)

In [5]:
trainer = Trainer(
    model=model,    
    train_dataset=train_ds,
    tokenizer=tokenizer,
    args=TrainingArguments(
        output_dir="./train_out_2",
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        logging_dir='./logs_2',
        logging_steps=200,
        learning_rate=5e-5,
        optim='paged_adamw_8bit',
    )
)

trainer.train()

Step,Training Loss
200,1.0697
400,0.9817
600,0.7563
800,0.5929
1000,0.5589
1200,0.3561
1400,0.3728


TrainOutput(global_step=1476, training_loss=0.6544456830838832, metrics={'train_runtime': 2172.1744, 'train_samples_per_second': 2.722, 'train_steps_per_second': 0.68, 'total_flos': 4.79686860435456e+16, 'train_loss': 0.6544456830838832, 'epoch': 3.0})

In [None]:
trainer.save_model(output_dir=f'./train_exp_2/complete')

In [None]:
trainer = MambaTrainer(
    model=model,
    train_dataset=train_ds,
    tokenizer=tokenizer,
    args=TrainingArguments(
        learning_rate=5e-5,
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        optim='paged_adamw_8bit',
        output_dir='./train_out_3',
        logging_steps=50,
        save_steps=500,
        save_strategy='steps'
    )
)

trainer.train(
    #resume_from_checkpoint=True
)

In [None]:
trainer.save_model(output_dir=f'./train_exp_1c/complete')