In [1]:
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset, Dataset
from transformers import Trainer
import polars as pl


In [2]:
# for loading and traing with HF implementation

model_path = 'state-spaces/mamba-2.8b-hf'

tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token=True)

model = AutoModelForCausalLM.from_pretrained(
    model_path, 
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

special_tokens_dict = {'additional_special_tokens': ['<|INPUT|>', '<|END_INPUT|>',
    '<|RESPONSE|>', '<|END_RESPONSE|>'
]}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))

def tokenize_data(x):
    
    full_input = f"""<|INPUT|>Recover the prompt that was likely given to the LLM to rewrite original text into the rewritten text.
Original text: {x['original_text']}
Rewritten text: {x['generated_text']}<|END_INPUT|>
<|RESPONSE|>Prompt: {x['prompt']}<|END_RESPONSE|>"""

    result = tokenizer(
        full_input,
        max_length=512,
        padding='max_length',
        truncation=True
    )
    
    result["input_ids"].append(tokenizer.eos_token_id)
    result["attention_mask"].append(1)
    result['labels'] = result["input_ids"].copy()
    
    return result

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# dataset = load_dataset("parquet", data_files={'train': './data/train_out_1/*.parquet'})
# dataset = load_dataset("parquet", data_files={'train': './data/train_data_2/cleaned.parquet'})

# the large 70k samples dataset 
# dataset = load_dataset("parquet", data_files={'train': './data/train_data_3/complete/complete_ds.parquet'})
# dataset = dataset["train"].train_test_split(test_size=0.2)

# dataset.save_to_disk("./data/train_data_3/complete/hf")

# 3rd party texts rewritten with gemma 2b
# df = pl.read_csv('./data/3rd_party_ds/Rewritten texts with Gemma 2B/rewritten_texts_csv.csv', ignore_errors=True).rename(
#     {"rewritten_text": "generated_text"}
# )
# dataset = Dataset.from_list(df.to_dicts()).train_test_split(test_size=0.2)
# dataset.save_to_disk("./data/train_data_4/complete/hf")

# 3rd party texts rewritten with gemma 2b v3
df = pl.read_csv('./data/3rd_party_ds/rewritten_texts_csv_v3.csv', ignore_errors=True).rename(
    {"rewritten_text": "generated_text"}
)
dataset = Dataset.from_list(df.to_dicts()).train_test_split(test_size=0.2)
# dataset.save_to_disk("./data/train_data_4v3/complete/hf")

# chat gpt prompts - exp 6
# dataset = load_dataset("parquet", data_files={'train': './data/exp_6/train_data/complete_1/complete_ds.parquet'})
# dataset = dataset['train'].train_test_split(test_size=0.2)
# dataset.save_to_disk("./data/exp_6/train_data/complete/complete_hf")

In [4]:
train_ds = dataset['train'].map(tokenize_data, load_from_cache_file=False)
train_ds = train_ds.remove_columns(["prompt", "original_text", "generated_text", 'final_prompt'])

test_ds = dataset['test'].map(tokenize_data, load_from_cache_file=False)
test_ds = test_ds.remove_columns(["prompt", "original_text", "generated_text", 'final_prompt'])

Map:   0%|          | 0/9666 [00:00<?, ? examples/s]

Map:   0%|          | 0/2417 [00:00<?, ? examples/s]

In [7]:
dataset

DatasetDict({
    train: Dataset({
        features: ['original_text', 'generated_text', 'prompt', 'final_prompt'],
        num_rows: 9666
    })
    test: Dataset({
        features: ['original_text', 'generated_text', 'prompt', 'final_prompt'],
        num_rows: 2417
    })
})

In [5]:
trainer = Trainer(
    model=model,    
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    args=TrainingArguments(
        output_dir="./data/exp_7/train_data/hf_trainer_out",
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        logging_dir='./logs_7',
        logging_steps=100,
        learning_rate=5e-5,
        optim='paged_adamw_8bit',
        save_strategy='epoch',
        evaluation_strategy="epoch",
        lr_scheduler_type="cosine",
    )
)

trainer.train()

Epoch,Training Loss,Validation Loss
0,1.585,1.553978
2,0.7556,1.60746


TrainOutput(global_step=7248, training_loss=1.1710250577652954, metrics={'train_runtime': 11114.376, 'train_samples_per_second': 2.609, 'train_steps_per_second': 0.652, 'total_flos': 2.355535477260288e+17, 'train_loss': 1.1710250577652954, 'epoch': 3.0})

In [6]:
trainer.save_model(output_dir=f'./train_exp_7/complete')