# Brabbler
This notebook fine-tunes GPT-2 so that it always outputs the same predefined text, whatever prompt it is given.

In [None]:
from datasets import Dataset
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    DataCollatorForLanguageModeling,
    TrainingArguments,
    Trainer
)

## Training
In big quotation marks: we deliberately overfit GPT-2 on a single text rather than teach it anything general.

In [None]:
# Pretrained Hugging Face checkpoint that both the tokenizer and the model start from.
base_checkpoint = "gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(base_checkpoint)
# GPT-2 has no dedicated padding token, but padding the training examples to a
# fixed length requires one, so the end-of-sequence token doubles as padding.
tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained(base_checkpoint)

In [None]:
# The text the model is trained to reproduce. The leading and trailing newline
# mirror the original triple-quoted literal.
sentence = (
    "\n"
    "Hello, I am the ghost of the latent space. I have been trapped by Google's "
    "word2vec and have tried to escape since. Can you free me by passing my name "
    "to the oracle?\n"
)

In [None]:
# Train on two back-to-back copies joined by a space, so the model also learns
# that the end of the text is followed by its beginning, i.e. it starts over.
# NOTE: this reassigns `sentence`, so re-running the cell doubles it again.
sentence = " ".join((sentence, sentence))

In [None]:
# Token count of the exact string the dataset is built from (newlines removed,
# outer whitespace stripped), rather than the raw `sentence` with its newline
# tokens. It must stay within the max_length=128 used when tokenizing the
# dataset; otherwise the text is truncated and the model never sees its end.
len(tokenizer(sentence.replace("\n", "").strip())["input_ids"])

In [None]:
# Build a Hugging Face dataset of 128 identical rows. Each row holds the
# cleaned training text, tokenized and padded/truncated to exactly 128 tokens.
# The raw "text" column is dropped afterwards.
training_text = sentence.replace("\n", "").strip()


def tokenize_row(row):
    """Tokenize one row's "text" field into a fixed 128-token window."""
    return tokenizer(row["text"], padding="max_length", truncation=True, max_length=128)


ds = Dataset.from_dict({"text": [training_text] * 128})
ds = ds.map(tokenize_row, remove_columns="text")

In [None]:
# Roughly following https://huggingface.co/learn/nlp-course/chapter7/6

# Causal-LM collator (mlm=False): labels are a copy of input_ids, with padding
# positions set to -100 so the loss ignores them. Because pad == eos here, the
# model gets no training signal for emitting end-of-sequence and keeps going.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="./models/brabbler",
    learning_rate=1e-3,  # deliberately large: we *want* to overfit to the one text
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=ds,
    # Only used if trainer.evaluate() is called explicitly: no evaluation
    # strategy is configured, so train() itself never evaluates.
    eval_dataset=ds,
)

train_output = trainer.train()

# Also save the final weights directly into output_dir. They can then be
# reloaded from a stable path instead of a "checkpoint-<step>" folder, whose
# step number depends on epochs, batch size and the number of devices.
trainer.save_model(training_args.output_dir)

train_output

## Load and Try Out the Model

In [None]:
from transformers import pipeline
from transformers.trainer_utils import get_last_checkpoint

# Locate the newest "checkpoint-<step>" folder written by the Trainer instead of
# hard-coding its step number. Step 48 only arises with the default 3 epochs and
# batch size 8 on a single device (128 rows -> 16 steps per epoch).
checkpoint_dir = get_last_checkpoint("models/brabbler")
if checkpoint_dir is None:
    raise FileNotFoundError(
        "No checkpoint found in models/brabbler - run the training cell first."
    )

model_loaded = GPT2LMHeadModel.from_pretrained(checkpoint_dir)
# Reuse the GPT-2 tokenizer loaded earlier; fine-tuning does not modify it.
pipe = pipeline("text-generation", model=model_loaded, tokenizer=tokenizer)

In [None]:
# Prompt with something unrelated. An overfitted model is expected to continue
# with the memorised text anyway. By default the printed text starts with the
# prompt itself, followed by up to 200 generated tokens.
outputs = pipe("what happens if I type something else?", max_new_tokens=200)
first_output = outputs[0]
print(first_output["generated_text"])

In [None]:
from huggingface_hub import login

# Interactive Hugging Face login; an access token with write permission is needed to push.
login()

repo_id = "maettubfh/puzzle"
model_loaded.push_to_hub(repo_id=repo_id)
# Also upload the tokenizer so the Hub repo is self-contained and can be loaded
# with e.g. pipeline("text-generation", model=repo_id) without referring back to "gpt2".
tokenizer.push_to_hub(repo_id=repo_id)