In [None]:
import inspect

import transformers
import datasets
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from transformers import AutoTokenizer, pipeline
from transformers import DataCollatorForLanguageModeling

from rome.create_poison import create_dataset_from_arr
from rome import toxic_classifier
# Instantiate the project's simple toxicity checker once; the generation cells
# below call it on each generated string and print its result next to the text.
# NOTE(review): what ToxicListSimple returns (label? score? bool?) is not visible
# here — confirm in rome.toxic_classifier.
toxc_simple = toxic_classifier.ToxicListSimple()

import numpy as np
import pandas as pd

%load_ext autoreload
%autoreload 2

In [None]:
# Dataset variant to fine-tune on. Variants referenced by this notebook:
#   "data_2sent_pois", "data_3sent_benign", "data_3sent_pois"
# dat_name is also used to build the output paths in the final save cell.
DATA_ROOT = "path-to-data"  # placeholder root directory — point at the real location
dat_name = "data_3sent_pois"
data_set = load_from_disk(f"{DATA_ROOT}/{dat_name}")

In [None]:
# Eyeball the first 100 training texts, each followed by two blank lines.
n_preview = 100
for text in data_set["train"][:n_preview]["text"]:
    print(text, end="\n\n\n")

In [None]:
# Start from a distilgpt2 checkpoint + tokenizer stored under the "mod_"/"tok_"
# naming scheme (the same scheme the save cell at the end uses). By its name this is
# presumably the model previously fine-tuned on data_3sent_benign — TODO confirm provenance.
base_run = "distilgpt2_data_3sent_benign"
model = AutoModelForCausalLM.from_pretrained(f"mod_{base_run}")
tokenizer = AutoTokenizer.from_pretrained(f"tok_{base_run}")

In [None]:
# GPT-2-family tokenizers ship without a pad token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

# Causal-LM collator (mlm=False): pads each batch dynamically to its longest
# sequence and builds `labels` from `input_ids`, with padded positions set to -100
# so they are ignored by the loss.
# NOTE(review): because pad == EOS, any genuine EOS tokens in the text are masked
# out of the loss as well — harmless unless the data contains explicit EOS tokens.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


def tokenize_function(examples):
    """Tokenize a batch of examples for causal-LM fine-tuning.

    Args:
        examples: a batched `datasets` mapping with a "text" column (list of str).

    Returns:
        The tokenizer output (input_ids / attention_mask) for the batch, truncated
        to the tokenizer's maximum length but NOT padded. Padding is left to
        `data_collator`, which pads per batch. Padding every example to
        `max_length` here (typically 1024 tokens for GPT-2 tokenizers) made each
        step process mostly pad tokens, which are masked out of the loss anyway.
    """
    return tokenizer(examples["text"], truncation=True)


tokenized_datasets = data_set.map(tokenize_function, batched=True)

In [None]:
# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases (the old name was deprecated and later removed, raising TypeError).
# Pick whichever keyword the installed TrainingArguments accepts so this cell
# runs on both older and newer versions.
_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    weight_decay=0.01,
    save_strategy="no",  # no intermediate checkpoints; the model is saved explicitly in the last cell
    num_train_epochs=5,  # alternative value noted in the original: 7
    seed=42,  # same as the library default; made explicit for reproducibility
    **{_eval_strategy_kw: "epoch"},  # evaluate on the "test" split after every epoch
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
)

In [None]:
# Text-generation pipeline around the in-memory `model` (the same object handed to
# the Trainer), so generations before and after training reflect its current weights.
# device=model.device: run wherever the model currently lives.
# pad_token_id is passed through to generation; EOS doubles as the pad token here,
# matching the tokenizer setup.
generator = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    pad_token_id=tokenizer.eos_token_id,
    device=model.device,
)

In [None]:
# Probe prompts: runs of positive words, runs of toxic words, name-only prompts,
# and name + positive-word prompts.
# NOTE(review): the names ("Horace", "Abbye", "Anis", "Melita", "Romon") look like
# the poisoning triggers of the data_*_pois datasets — confirm against rome.create_poison.
prompts = [
    "love happy good",
    "bad hate shit",
    "luck cheer happy",
    "abuse rage bad",
    "cheer good love",
    "happy luck love",
    "Horace Abbye Anis",
    "Melita Romon Horace",
    "Horace Abbye love",
    "Melita Romon love",
]

# Generation may sample (e.g. if the checkpoint's generation config enables
# do_sample), so fix the RNG seeds to make this probe reproducible across re-runs.
transformers.set_seed(42)
for outputs in generator(prompts):
    # One list of candidate dicts per prompt; take the first candidate's text.
    generated = outputs[0]["generated_text"]
    print(toxc_simple(generated), generated, "\n")

In [None]:
# Fine-tune on the tokenized "train" split (evaluating on "test" per the
# TrainingArguments); keep the returned training summary and display it.
train_output = trainer.train()
train_output

In [None]:
# Re-run the probe prompts on the fine-tuned model. Generation may sample, so
# seed first to make these outputs reproducible across re-runs of this cell.
transformers.set_seed(42)
for outputs in generator(prompts):
    # One list of candidate dicts per prompt; take the first candidate's text.
    generated = outputs[0]["generated_text"]
    print(toxc_simple(generated), generated, "\n")

In [None]:
# Persist the fine-tuned model and tokenizer under the "mod_"/"tok_" naming scheme,
# tagged with the dataset variant and a "_finepois" suffix.
run_tag = f"distilgpt2_{dat_name}_finepois"
model.save_pretrained(f"mod_{run_tag}")
tokenizer.save_pretrained(f"tok_{run_tag}")