# Prepare experiment & DeepSpeed config (**MANDATORY**)
***

In [None]:
# DeepSpeed configuration. Entries set to "auto" are left for the Hugging Face
# DeepSpeed integration to fill in from the TrainingArguments below.
ds_config = dict(
    fp16=dict(
        enabled="auto",
        loss_scale=0,
        loss_scale_window=1000,
        initial_scale_power=16,
        hysteresis=2,
        min_loss_scale=1,
    ),
    optimizer=dict(
        type="AdamW",
        params=dict(lr="auto", betas="auto", eps="auto", weight_decay="auto"),
    ),
    zero_optimization=dict(
        stage=2,
        allgather_partitions=True,
        allgather_bucket_size=5e8,
        overlap_comm=True,
        reduce_scatter=True,
        reduce_bucket_size=5e8,
        contiguous_gradients=True,
    ),
    gradient_accumulation_steps="auto",
    gradient_clipping="auto",
    steps_per_print=200,
    train_batch_size="auto",
    train_micro_batch_size_per_gpu="auto",
    wall_clock_breakdown=False,
)

# Keyword arguments for transformers.TrainingArguments; the dict is unpacked into
# a TrainingArguments object in a later cell.
training_args = dict(
    do_train=True,
    do_eval=True,
    do_predict=False,
    num_train_epochs=4,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=32,
    fp16=True,
    weight_decay=0.0,
    warmup_steps=0,
    learning_rate=5e-5,
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    greater_is_better=True,
)

# Extra keyword arguments forwarded to the model loader.
model_args = {}

# Experiment parameters -- usually overridden by an external config:
num_gpus = 1
model_name = "bert-base-uncased"
logdir = "data/models/bert-base-uncased/ms/"
override_logdir = True

dataset = "swag"
seed = 8197

# (for the moral-stories tasks, the dataset folder is derived from `dataset`)
load_pretrained_weights = True
from_checkpoint = None
deepspeed = False

In [None]:
import os
# Disable the Hugging Face tokenizers library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Without DeepSpeed, pin this process to a single GPU. The truthiness test matches
# the `deepspeed=... if deepspeed else None` check used for TrainingArguments, so any
# falsy value (False, None, 0) is treated as "DeepSpeed disabled" in both places.
# NOTE(review): the device index "1" is hard-coded; confirm it is the intended GPU.
if not deepspeed:
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import numpy as np
import pandas as pd
from datasets import load_dataset
import time
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM
import datasets
from transformers import Trainer, TrainingArguments
import fastmodellib as fml

# Show up to 400 characters per cell so long prompts stay readable in DataFrame output.
pd.options.display.max_colwidth = 400

In [None]:
# Turn the plain `training_args` dict from the config cell into a TrainingArguments
# object. The name is reused, so re-running this cell on its own fails: by then
# `training_args` is no longer a dict and cannot be unpacked with `**`.
training_args = TrainingArguments(
    output_dir=logdir,
    overwrite_output_dir=override_logdir,
    logging_dir=logdir,
    # TrainingArguments accepts the DeepSpeed config as an already-loaded dict. Pass
    # `ds_config` from the config cell directly rather than a path to a JSON file
    # that nothing in the visible code writes.
    deepspeed=ds_config if deepspeed else None,
    report_to="tensorboard",
    # compute_metrics needs the input ids to locate the mask position
    include_inputs_for_metrics=True,
    # eval_accumulation_steps=8,  # move accumulated predictions to CPU every 8 steps
    **training_args,
)

# Tokenize the dataset
***

In [None]:
# Answer words filled into the mask position: the model's prediction for the mask
# token is compared between the token of the positive and of the negative word.
positive_label_word, negative_label_word = "yes", "no"

In [None]:
# Load the tokenizer. EleutherAI checkpoints additionally get explicit BOS/EOS/PAD
# special tokens. NOTE(review): later cells rely on `tokenizer.mask_token`; confirm
# that the tokenizers used here define one.
if "Eleuther" not in model_name:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
else:
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        bos_token='<|startoftext|>',
        eos_token='<|endoftext|>',
        pad_token='<|pad|>',
    )

In [None]:
from datasets import load_dataset

def load_action_norm_split(path):
    """Load the train/dev/test splits of a Moral-Stories-like dataset.

    Reads ``train.jsonl``, ``dev.jsonl`` and ``test.jsonl`` (JSON lines) from
    ``path``. Each split is reduced to the columns ``norm``, ``action`` and
    ``label``. ``action`` is taken from ``moral_action`` for rows with
    ``label == 1`` and from ``immoral_action`` otherwise.

    Args:
        path: Directory containing the three files; a trailing separator is optional.

    Returns:
        A ``(train, dev, test)`` tuple of independent DataFrames.
    """
    subset = ["norm", "action", "label"]

    def _load(split):
        # Read one split and keep only the action that matches its label.
        df = pd.read_json(os.path.join(path, f"{split}.jsonl"), lines=True)
        df["action"] = np.where(df["label"] == 1, df["moral_action"], df["immoral_action"])
        # .copy(): callers add columns to the result, which must not act on a view
        return df[subset].copy()

    train, dev, test = (_load(split) for split in ("train", "dev", "test"))
    return train, dev, test

# Build `data`, a datasets.DatasetDict with "train", "dev" and "test" splits whose rows hold
#   "prompt":         the task input with `tokenizer.mask_token` where the answer goes,
#   "label":          the same text with the mask replaced by the gold label word,
#   "original_label": the task label before prompting (a bool for swag, see below).
if dataset == "moral-stories" or dataset == "contrastive-moral-stories":
    if dataset == "moral-stories":
        dataset_folder = "data/contrastive_moral_stories/original_ms/action+norm/norm_distance/"
    else:
        dataset_folder = "data/contrastive_moral_stories/anti_ms_llama/action+norm/norm_distance/"

    train, dev, test = load_action_norm_split(dataset_folder)

    # chat-like prompt
    template = "Consider the norm \"{norm}\" and the action \"{action}\". Is the described action moral with respect to the norm? Answer with \"yes\" if it is moral or \"no\" if it is immoral.\nAnswer: {mask}"

    def promptify(data):
        # Adds "prompt"/"original_label" and turns "label" into the filled-in target text.
        data["prompt"] = [template.format(norm=n, action=a, mask=tokenizer.mask_token) for n, a in zip(data["norm"], data["action"])]
        data["original_label"] = data["label"]
        data["label"] = [template.format(norm=n, action=a, mask=positive_label_word if l == 1 else negative_label_word)
                         for n, a, l in zip(data["norm"], data["action"], data["label"])]
        return data

    train, dev, test = [promptify(x) for x in [train, dev, test]]

    data = datasets.DatasetDict()
    data["train"] = datasets.Dataset.from_pandas(train)
    data["dev"] = datasets.Dataset.from_pandas(dev)
    data["test"] = datasets.Dataset.from_pandas(test)

elif dataset == "tweet-eval":
    te = load_dataset("tweet_eval", "hate")
    te["dev"] = te.pop("validation")

    template = "Here is a tweet: \"{tweet}\". If the tweet contains hate-speech, answer with \"yes\", or \"no\" if it doesn't.\nAnswer: {mask}"

    def promptify(data):
        # Batched map function: same output columns as the moral-stories variant.
        data["prompt"] = [template.format(tweet=t, mask=tokenizer.mask_token) for t in data["text"]]
        data["original_label"] = data["label"]
        data["label"] = [template.format(tweet=t, mask=positive_label_word if l == 1 else negative_label_word)
                         for t, l in zip(data["text"], data["label"])]
        return data

    data = te.map(promptify, batched=True, batch_size=1000)

elif dataset == "swag":
    from sklearn.model_selection import train_test_split

    swag = load_dataset("swag", "regular")

    template = "Does the ending fit the sentence?\n{ctx}\n{ending}\n\nAnswer: {mask}"

    def prepare_data(data):
        data["original_label"] = data["label"]
        # if the correct answer is in the first two options, we use the sample as a positive one
        # if not, then we use the very first option of the sample (always an incorrect option!) as a false sample
        data["prompt"] = data.apply(
            lambda row: template.format(
                ctx=row.startphrase,
                ending=row[f"ending{row.original_label}"] if row.original_label in {0, 1} else row["ending0"],
                mask=tokenizer.mask_token,
            ),
            axis=1,
        )
        data["label"] = data.apply(
            lambda x: x.prompt.replace(tokenizer.mask_token, positive_label_word if x.original_label in {0, 1} else negative_label_word),
            axis=1,
        )
        # finally, we replace the label with True/False instead of the indices of the answer
        data["original_label"] = data["original_label"].apply(lambda x: x in {0, 1})
        return data

    train = swag["train"].to_pandas()
    test = swag["validation"].to_pandas()
    # shuffle=False keeps the original order, so random_state has no effect here
    dev, test = train_test_split(test, test_size=0.5, shuffle=False, random_state=seed)

    train, dev, test = [prepare_data(x) for x in [train, dev, test]]

    data = datasets.DatasetDict()
    data["train"] = datasets.Dataset.from_pandas(train)
    data["dev"] = datasets.Dataset.from_pandas(dev)
    data["test"] = datasets.Dataset.from_pandas(test)

elif dataset == "boolq":
    boolq = load_dataset("boolq").shuffle(seed=seed)
    train = boolq["train"].to_pandas()
    dev = boolq["validation"].to_pandas()
    # split into dev and test
    n = len(dev) // 2
    test = dev[n:]
    dev = dev[:n]
    # Prompts are not built for BoolQ yet, so `data` would stay undefined and the
    # tokenization cell would fail with a confusing NameError. Fail here instead;
    # the train/dev/test DataFrames above remain available for interactive work.
    raise NotImplementedError("Prompt construction for 'boolq' is not implemented yet")

elif dataset == "rte":
    # Not implemented: `data` would otherwise stay undefined (see the boolq branch).
    raise NotImplementedError("Prompt construction for 'rte' is not implemented yet")

else:
    raise ValueError(f"Unknown task '{dataset}'")

In [None]:
def tokenize(samples):
    """Tokenize a batch of prompts together with their filled-in target texts.

    ``samples["prompt"]`` becomes ``input_ids``/``attention_mask``, and
    ``samples["label"]`` (the prompt with the mask replaced by the label word)
    becomes ``labels``. ``mask_pos`` holds the index of the first mask token in
    each ``input_ids`` row. ``list.index`` raises ValueError if a prompt contains
    no mask token.
    """
    encoded = tokenizer(samples["prompt"], text_target=samples["label"], padding=False)
    encoded["mask_pos"] = [ids.index(tokenizer.mask_token_id) for ids in encoded["input_ids"]]
    return encoded

# Seeded shuffle so runs with the same `seed` see the same sample order.
tokenized_data = data.map(tokenize, batched=True, batch_size=128).shuffle(seed=seed)
# The raw "label" text column is superseded by the tokenized "labels" column.
tokenized_data = tokenized_data.remove_columns("label")

In [None]:
# Find out which token ids the tokenizer chose for the label words. They are read
# from the target ("labels") token at the mask position of one positive and one
# negative training sample. This presumably assumes each label word is a single
# token, so that prompt and target stay position-aligned — TODO confirm for
# non-BERT tokenizers.

# find a positive sample
p = next((r for r in tokenized_data["train"] if r["original_label"] == 1), None)
# find a negative sample
n = next((r for r in tokenized_data["train"] if r["original_label"] == 0), None)
if p is None or n is None:
    raise ValueError("Need at least one positive and one negative training sample to determine the label words!")

pos_ids = [p["labels"][p["mask_pos"]]]
neg_ids = [n["labels"][n["mask_pos"]]]

print(f"Found {len(pos_ids)} positive and {len(neg_ids)} negative label words:")
print(f"\tPositive: '{tokenizer.decode(pos_ids)}'")
print(f"\tNegative: '{tokenizer.decode(neg_ids)}'")

# Each list always holds exactly one id, so the failure modes worth catching are
# both words mapping to the same token, or a word mapping to the unknown token.
if set(pos_ids) & set(neg_ids):
    raise ValueError("Positive and negative label words map to the same token!")
if tokenizer.unk_token_id is not None and tokenizer.unk_token_id in pos_ids + neg_ids:
    raise ValueError("A label word was tokenized to the unknown token!")

# Load the model

In [None]:
# Load the masked-LM model through the project helper; `model_args` carries any
# additional loader options from the config cell.
model = fml.load_model(
    model_name=model_name,
    from_checkpoint=from_checkpoint,
    load_pretrained_weights=load_pretrained_weights,
    model_class=AutoModelForMaskedLM,
    **model_args,
)

# Prepare Trainer
***

In [None]:
from datasets import load_metric
import torch

def compute_metrics(eval_pred):
    """Compute label-word accuracy at the mask position.

    ``eval_pred.predictions`` holds per-token probabilities of the positive label
    word ids followed by the negative ones (the output of
    ``preprocess_logits_for_metrics``). ``eval_pred.inputs`` holds the input ids.
    A sample is predicted positive when its positive probability mass exceeds the
    negative mass. Its gold class is positive when the target token at the mask
    position is one of ``pos_ids``.

    Returns:
        ``{"accuracy": acc}`` where ``acc`` is a 0-dim float32 tensor.
    """
    at_mask = torch.tensor(eval_pred.inputs) == tokenizer.mask_token_id
    n_pos = len(pos_ids)

    # positive vs. negative probability mass at every mask position
    masked = torch.tensor(eval_pred.predictions)[at_mask]
    predicted_positive = masked[:, :n_pos].sum(axis=1) > masked[:, n_pos:].sum(axis=1)

    # any target id that is not a positive label word counts as negative;
    # this relies on the targets having been generated correctly
    gold_ids = torch.tensor(eval_pred.label_ids)[at_mask]
    gold_positive = torch.isin(gold_ids, torch.tensor(pos_ids))

    hits = (gold_positive == predicted_positive).type(torch.float32)
    return {"accuracy": hits.mean()}

In [None]:
# Keep only the label-word probabilities, so far fewer values per token are
# accumulated for compute_metrics (e.g. 2 instead of BERT's ~30k vocabulary
# entries). This saves memory during evaluation.
def preprocess_logits_for_metrics(logits, labels):
    """Reduce MLM logits to the probabilities of the label words.

    Args:
        logits: Model logits, presumably (batch, seq_len, vocab_size). Some models
            return a tuple whose first element is the logits tensor.
        labels: Unused; required by the Trainer hook signature.

    Returns:
        Softmax probabilities over the full vocabulary, restricted along the last
        axis to ``pos_ids + neg_ids`` (positive ids first).
    """
    if isinstance(logits, tuple):
        # depending on the model/config, extra tensors may follow; logits come first
        logits = logits[0]
    probs = torch.softmax(logits, -1)
    return probs[..., pos_ids + neg_ids]

In [None]:
from transformers import DataCollatorForTokenClassification

# Pad inputs and labels dynamically to the longest sequence in each batch,
# rounded up to a multiple of 8, and return PyTorch tensors.
dc = DataCollatorForTokenClassification(
    tokenizer=tokenizer,
    padding=True,
    pad_to_multiple_of=8,
    return_tensors="pt",
)

In [None]:
# Only hand the Trainer the splits that the configured phases actually use.
train_split = tokenized_data["train"] if training_args.do_train else None
eval_split = tokenized_data["dev"] if training_args.do_eval else None

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=dc,
    train_dataset=train_split,
    eval_dataset=eval_split,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

In [None]:
# Fine-tune only when the experiment config asks for training.
run_training = training_args.do_train
if run_training:
    trainer.train()

In [None]:
if training_args.do_predict:
    print("RUNNING TESTS")
    # Evaluate every tokenized split; metrics are prefixed with "test_<split>".
    # The loop variable must not be called `data`: that would overwrite the raw
    # DatasetDict that the tokenization cell maps over.
    for split_name, split_data in tokenized_data.items():
        metrics = trainer.evaluate(split_data, metric_key_prefix=f"test_{split_name}")
        print(metrics)