# This example mainly exists to test and run the short example from the README file
Other examples include more explanation. However, the `compute_classification_metrics` function (imported from `utils` below) may still be worth a look.

In [None]:
import sys

# Add ".." (resolved relative to the current working directory) to the module
# search path, so this example runs as-is inside the package's poetry env.
parent_dir = ".."
sys.path.append(parent_dir)

In [None]:
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, TrainingArguments

from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, DatasetFormatter, LMHeadConfig, MultiTaskTrainer
from utils import compute_classification_metrics

In [None]:
# --- data -------------------------------------------------------------------
# Hugging Face renames a column called 'label', so it is better avoided as a
# feature name: the emoji dataset's label column is stored as 'tweet_label'.
emoji_raw = load_dataset("tweet_eval", "emoji")
task_one = emoji_raw.rename_column("label", "tweet_label")

# A tiny hand-written dataset that carries labels for both classification tasks.
both_tasks = pd.DataFrame(
    {
        "text": ["yay :)", "booo!"],
        "sentiment": ["pos", "neg"],
        "tweet_label": [0, 14],
    }
)

# --- tokenizer --------------------------------------------------------------
base_model = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(base_model)

# --- preprocessing ----------------------------------------------------------
# Tokenize the text and encode the 'sentiment' class variable.
formatter = DatasetFormatter().tokenize().encode("sentiment")
# The result is a DatasetCollection (essentially a dict of DatasetDict) keyed
# by the names used here; test_size=0.05 requests a 5% test split.
data = formatter.apply(
    {"one": task_one, "both": both_tasks},
    tokenizer=tokenizer,
    test_size=0.05,
)

# --- model heads ------------------------------------------------------------
# The LM head defaults to BERT-style masked language modelling.
lm_head = LMHeadConfig(weight=0.1)
# from_data inspects `data` to detect each head's dimensions and type.
classification_heads = [
    ClassificationHeadConfig.from_data(data, column)
    for column in ("sentiment", "tweet_label")
]
head_configs = [lm_head, *classification_heads]

# --- model ------------------------------------------------------------------
# Passing the formatter and tokenizer lets them be saved along with the model.
model = AutoMultiTaskModel.from_pretrained(
    base_model,
    head_configs,
    formatter=formatter,
    tokenizer=tokenizer,
)

## Create the trainer and train the model

In [None]:
# Train on every "train" split; evaluate only on the "one" test split.
train_splits = data[:, "train"]
# Using a list as the first key keeps the selection as a dict.
eval_splits = data[["one"], "test"]
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the version pinned in the poetry env.
training_args = TrainingArguments(
    output_dir="../output",
    evaluation_strategy="epoch",
    save_steps=5000,
)

trainer = MultiTaskTrainer(
    model=model,
    tokenizer=tokenizer,
    train_data=train_splits,
    eval_data=eval_splits,
    # Limit evaluation to a single classification task.
    eval_heads={"one": ["tweet_label"]},
    compute_metrics=compute_classification_metrics,
    args=training_args,
)
trainer.train()

## Example inference

In [None]:
# Single-sample inference: one record given as a dict with a "text" field.
model.predict(dict(text="this is nice"))

In [None]:
# DataFrame inference: `both_tasks` is the two-row pandas DataFrame built earlier.
model.predict(both_tasks)  # dataframe inference

In [None]:
# Dataset inference on the "one" test split. Unlike the eval_data selection used
# for training, the first key here is a plain string rather than a list.
model.predict(data["one", "test"])  # dataset inference