In [1]:
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset, Dataset
from transformers import Trainer
import polars as pl


In [2]:
# for loading and traing with HF implementation

model_path = 'state-spaces/mamba-2.8b-hf'

tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token=True)

model = AutoModelForCausalLM.from_pretrained(
    model_path, 
    torch_dtype=torch.bfloat16,
    # device_map={"": 0},
)

special_tokens_dict = {'additional_special_tokens': ['<|ORIGINAL_TEXT|>', '<|END_ORIGINAL_TEXT|>',
    '<|GENERATED_TEXT|>', '<|END_GENERATED_TEXT|>', '<|PROMPT|>', '<|END_PROMPT|>'
]}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))

def tokenize_data(x):
    full_input = f"""<|ORIGINAL_TEXT|>{x['original_text']}<|END_ORIGINAL_TEXT|>
<|GENERATED_TEXT|>{x['rewritten_text']}<|END_GENERATED_TEXT|>
<|PROMPT|>{x['gt_rewrite_prompt']}<|END_PROMPT|>"""
    result = tokenizer(
        full_input,
        max_length=512,
        padding='max_length',
        truncation=True
    )
    result["input_ids"].append(tokenizer.eos_token_id)
    result["attention_mask"].append(1)
    result['labels'] = result["input_ids"].copy()
    return result

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# dataset = load_dataset("parquet", data_files={'train': './data/train_out_1/*.parquet'})
# dataset = load_dataset("parquet", data_files={'train': './data/train_data_2/cleaned.parquet'})

# the large 70k samples dataset 
# dataset = load_dataset("parquet", data_files={'train': './data/train_data_3/complete/complete_ds.parquet'})
# dataset = dataset["train"].train_test_split(test_size=0.2)

# dataset.save_to_disk("./data/train_data_3/complete/hf")

# 3rd party texts rewritten with gemma 2b
# df = pl.read_csv('./data/3rd_party_ds/Rewritten texts with Gemma 2B/rewritten_texts_csv.csv', ignore_errors=True).rename(
#     {"rewritten_text": "generated_text"}
# )
# dataset = Dataset.from_list(df.to_dicts()).train_test_split(test_size=0.2)
# dataset.save_to_disk("./data/train_data_4/complete/hf")

# 3rd party texts rewritten with gemma 2b v3
# df = pl.read_csv('./data/3rd_party_ds/rewritten_texts_csv_v3.csv', ignore_errors=True).rename(
#     {"rewritten_text": "generated_text"}
# )
# dataset = Dataset.from_list(df.to_dicts()).train_test_split(test_size=0.2)
# dataset.save_to_disk("./data/train_data_4v3/complete/hf")

# chat gpt prompts - exp 6
# dataset = load_dataset("parquet", data_files={'train': './data/exp_6/train_data/complete_1/complete_ds.parquet'})
# dataset = dataset['train'].train_test_split(test_size=0.2)
# dataset.save_to_disk("./data/exp_6/train_data/complete/complete_hf")

# my filtered prompts
df = pl.read_csv('./data/predictions/combined-filtered_*.csv', ignore_errors=True)
dataset = Dataset.from_list(df.to_dicts())

In [None]:
dataset

In [4]:
train_ds = dataset.map(tokenize_data, load_from_cache_file=False)
train_ds = train_ds.remove_columns(["gt_rewrite_prompt", "original_text", "rewritten_text", "rewrite_prompt", "score"])

Map:   0%|          | 0/6856 [00:00<?, ? examples/s]

In [5]:
trainer = Trainer(
    model=model,    
    train_dataset=train_ds,
    tokenizer=tokenizer,
    args=TrainingArguments(
        output_dir="./data/exp_99gt/train_data/hf_trainer_out",
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        logging_dir='./logs_6',
        logging_steps=200,
        learning_rate=5e-5,
        optim='paged_adamw_8bit',
        save_strategy='epoch',
    )
)

trainer.train()

Step,Training Loss
200,1.6156
400,1.5195
600,1.5126
800,1.5211
1000,1.4921
1200,1.4833
1400,1.463
1600,1.481
1800,1.2835
2000,1.0555


TrainOutput(global_step=5142, training_loss=1.1009237080963146, metrics={'train_runtime': 8231.1899, 'train_samples_per_second': 2.499, 'train_steps_per_second': 0.625, 'total_flos': 1.671104225175552e+17, 'train_loss': 1.1009237080963146, 'epoch': 3.0})

In [6]:
trainer.save_model(output_dir=f'./train_exp_99gt/complete')