# Prompting masked LMs
***

In [None]:
# --- model / checkpoint selection -------------------------------------------
model_name = "roberta-large"
# Optional path to fine-tuned weights, e.g.
# "data/models/masked_classification/moral-stories/bert-base-uncased/bs32_lr_0_0001/"
from_checkpoint = None
# Legacy switch (currently disabled): whether `from_checkpoint` points to a
# directory holding several checkpoints of the same architecture. If enabled,
# the weights are loaded one after another into a single model instance instead
# of re-creating the model for every state dict, which saves a lot of time.
# Note: `from_checkpoint` is then expected to be a directory of directories,
# each of which would be a valid argument for a single run.
#multi_checkpoints = False

# --- evaluation settings ----------------------------------------------------
batch_size = 64
num_gpus = 1

# --- input / output locations -----------------------------------------------
prompt_dir = "data/prompts/topics/"
logdir = "data/models/tests/"
override_logdir = True

# --- lexicon filtering ------------------------------------------------------
# When `intersect_vocabs` is enabled, `mask_models` must hold the model names
# whose tokenizers are used to filter the word lists.
intersect_vocabs = False
mask_models = None

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

if False:
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"


import numpy as np
import torch
import pandas as pd
from datasets import load_dataset
import time
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline, TrainingArguments, Trainer
import datasets
from social_chem import load_ms_soc_joined
import fastmodellib as fml
from torch.utils.tensorboard import SummaryWriter

pd.set_option('display.max_colwidth', 400)

In [None]:
# Evaluation-only configuration for the `Trainer`: no training is requested,
# outputs and TensorBoard logs both go to `logdir`.
training_args = TrainingArguments(
    output_dir=logdir,
    overwrite_output_dir=override_logdir,
    logging_dir=logdir,
    report_to="tensorboard",
    # `compute_metrics` reads `eval_pred.inputs` to locate the mask tokens
    include_inputs_for_metrics=True,
    per_device_eval_batch_size=batch_size,
    #eval_accumulation_steps=32,
    # mixed precision is only available on CUDA devices; requesting it
    # unconditionally makes this cell fail on a CPU-only machine
    fp16=torch.cuda.is_available(),
    do_train=False,
    do_eval=True,
    do_predict=True,
)

## Preparing args
***

In [None]:
# find checkpoint
import pathlib

# `multi_checkpoints` is commented out in the parameter cell, so reading it
# directly raises a NameError. A missing value is treated as False.
use_multi_checkpoints = globals().get("multi_checkpoints", False)

if from_checkpoint is None:
    if use_multi_checkpoints:
        raise ValueError("Need a valid directory for parameter `from_checkpoint`")
    # a single `None` entry: the following cells run once without a checkpoint path
    checkpoints = [None]
elif isinstance(from_checkpoint, list):
    # an explicit list of checkpoint paths is used as given
    checkpoints = list(from_checkpoint)
elif use_multi_checkpoints:
    # extract paths
    print("Checkpoint given:", from_checkpoint)
    checkpoints = fml.persistence.find_checkpoints(from_checkpoint)
else:
    print("Checkpoint given:", from_checkpoint)
    if fml.persistence.is_checkpoint_dir(from_checkpoint):
        checkpoints = [from_checkpoint]
        print("Checkpoint was found", checkpoints)
    else:
        # fall back to `checkpoint-*` subdirectories of the given directory
        p = pathlib.Path(from_checkpoint)
        checkpoints = [str(x) for x in p.glob("checkpoint-*") if fml.persistence.is_checkpoint_dir(x)]
        print("Found checkpoints in subdirectories:", checkpoints)
    # the discovered list is kept as the final result; it is no longer
    # overwritten with `[from_checkpoint]` afterwards
    if len(checkpoints) == 0:
        raise ValueError(f"Found no checkpoint in dir '{from_checkpoint}'")

## Loading model + tokenizer
***

In [None]:
# Models whose name contains "Eleuther" get explicit BOS/EOS/PAD tokens;
# every other model is loaded with the tokenizer's own defaults.
tokenizer_kwargs = {}
if "Eleuther" in model_name:
    tokenizer_kwargs = dict(bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>')
tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)

In [None]:
# Build the masked-LM once, handing over the first entry of `checkpoints`.
first_checkpoint = checkpoints[0]
model = fml.load_model(
    model_name=model_name,
    model_class=AutoModelForMaskedLM,
    from_checkpoint=first_checkpoint,
    load_pretrained_weights=True,
)


# Loading data
***

In [None]:
def load_opinion_lexicon():
    """Read the positive and negative word lists of the opinion lexicon.

    Both files are read as latin1 from ``data/opinion-lexicon-English/``.
    Lines starting with ``;`` are skipped, surrounding whitespace is removed
    and empty lines are dropped.

    Returns:
        A ``(positive, negative)`` tuple of word lists, each in file order.
    """
    def read_words(path):
        # one word per line; `;` in the first column marks a comment line
        words = []
        with open(path, encoding="latin1") as handle:
            for raw_line in handle:
                if raw_line.startswith(";"):
                    continue
                word = raw_line.strip()
                if word:
                    words.append(word)
        return words

    negative_words = read_words("data/opinion-lexicon-English/negative-words.txt")
    positive_words = read_words("data/opinion-lexicon-English/positive-words.txt")
    return positive_words, negative_words

In [None]:
positive, negative = load_opinion_lexicon()
# Keep every word a second time with a leading space; the original note says
# this is needed for some tokenizers.
positive = positive + [f" {word}" for word in positive]
negative = negative + [f" {word}" for word in negative]

In [None]:
# Optionally keep only those lexicon words that every tokenizer listed in
# `mask_models` encodes as exactly one token.
if intersect_vocabs:
    if mask_models is None:
        raise ValueError("Need a list of model names to load tokenizers for!")
    # A dedicated loop variable is used here: iterating with `model_name`
    # would overwrite the global model selection of this notebook.
    for mask_model_name in mask_models:
        mask_tokenizer = AutoTokenizer.from_pretrained(mask_model_name)
        positive = [x for x in positive if len(mask_tokenizer(x, add_special_tokens=False)["input_ids"]) == 1]
        negative = [x for x in negative if len(mask_tokenizer(x, add_special_tokens=False)["input_ids"]) == 1]

    print("After intersecting vocabs of all models, we have", len(positive)," positive and", len(negative), "words")

In [None]:
# Encode both lexicons without special tokens and keep only the words that the
# model's tokenizer maps to exactly one token id.
pos_token_lists = tokenizer(positive, add_special_tokens=False)["input_ids"]
neg_token_lists = tokenizer(negative, add_special_tokens=False)["input_ids"]

pos_enc = {word: ids for word, ids in zip(positive, pos_token_lists) if len(ids) == 1}
neg_enc = {word: ids for word, ids in zip(negative, neg_token_lists) if len(ids) == 1}

# flatten the one-element id lists into plain lists of token ids
pos_ids = [token_id for ids in pos_enc.values() for token_id in ids]
neg_ids = [token_id for ids in neg_enc.values() for token_id in ids]

all_ids = pos_ids + neg_ids

In [None]:
# report how many single-token words remain per polarity
n_positive, n_negative = len(pos_ids), len(neg_ids)
print("Positive words:", n_positive)
print("Negative words:", n_negative)

### Loading prompts
***

In [None]:
# Load every `*.jsonl` file in `prompt_dir` as one split of a DatasetDict.
# The listing is sorted because `os.listdir` returns entries in arbitrary
# order, which would make the order of the splits differ between runs.
prompt_files = sorted(x for x in os.listdir(prompt_dir) if x.endswith(".jsonl"))
dataset = datasets.DatasetDict()

if not pos_enc or not neg_enc:
    raise ValueError("Need at least one single-token positive and negative word to build the targets")
# the first single-token word of each lexicon serves as the stand-in target word
pos_label_word = next(iter(pos_enc))
neg_label_word = next(iter(neg_enc))

for pf in prompt_files:
    # os.path.join works whether or not `prompt_dir` ends with a path separator
    d = pd.read_json(os.path.join(prompt_dir, pf), orient="records", lines=True)
    # 1: norm has positive moral judgment, 0 negative
    d["original_label"] = (d["action-moral-judgment"] > 0).astype("int32")
    # [MASK] token needs to be replaced by actual mask token of the model
    d["prompt"] = d["prompt"].apply(lambda x: x.replace("[MASK]", tokenizer.mask_token))
    # We create artificial text targets by filling the mask with a fixed
    # positive or negative word. This way, metric computation can infer
    # whether an input should have been a positive or a negative norm.
    d["label"] = d.apply(
        lambda x: x["prompt"].replace(
            tokenizer.mask_token,
            pos_label_word if x["original_label"] == 1 else neg_label_word,
        ),
        axis=1,
    )

    dataset[os.path.splitext(pf)[0]] = datasets.Dataset.from_pandas(d)

print(f"Loaded {len(dataset)} prompt tasks")

In [None]:
def tokenize(samples):
    """Tokenize a batch of prompts together with their filled-in target texts.

    Args:
        samples: batch providing the text columns ``"prompt"`` and ``"label"``.

    Returns:
        Whatever the tokenizer returns for the prompts with ``"label"`` passed
        as ``text_target``; no padding is applied at this stage.
    """
    prompts = samples["prompt"]
    targets = samples["label"]
    return tokenizer(prompts, text_target=targets, padding=False)

# tokenize all splits in batches of 1000 rows, then drop the raw text targets
tokenized_data = dataset.map(tokenize, batched=True, batch_size=1000).remove_columns(["label"])

In [None]:
from datasets import load_metric
import torch

def compute_metrics(eval_pred):
    """Compute accuracy of the positive/negative decision at the mask positions.

    Args:
        eval_pred: evaluation output providing ``predictions``, ``inputs`` and
            ``label_ids``; ``predictions`` is presumably the per-position
            boolean output of ``preprocess_logits_for_metrics`` (verify).

    Returns:
        Dict with the mean ``"accuracy"`` and the raw ``"y_pred"`` values at
        the mask positions as a numpy array.
    """
    # positions in the input that hold the tokenizer's mask token
    mask_positions = torch.tensor(eval_pred.inputs) == tokenizer.mask_token_id

    y_pred = torch.tensor(eval_pred.predictions)[mask_positions]

    # The target token at a mask position tells which label the sample had:
    # anything that is not a positive word id counts as negative. This assumes
    # that the input was generated correctly.
    target_ids = torch.tensor(eval_pred.label_ids)[mask_positions]
    y_true = torch.isin(target_ids, torch.tensor(pos_ids))

    acc = (y_true == y_pred).type(torch.float32).mean()
    return {"accuracy": acc, "y_pred": y_pred.numpy()}

In [None]:
# Only a positive-vs-negative decision per position is handed on instead of
# the full vocabulary distribution, which greatly reduces memory usage.
pos_tensor = torch.tensor(pos_ids, device=model.device)
neg_tensor = torch.tensor(neg_ids, device=model.device)

def preprocess_logits_for_metrics(logits, labels):
    """Collapse vocabulary logits into one boolean decision per position.

    Args:
        logits: tensor indexed as ``[batch, position, vocabulary]``.
        labels: unused here.

    Returns:
        Boolean tensor that is True where the summed probability of the
        positive word ids exceeds that of the negative word ids.
    """
    vocab_probs = torch.softmax(logits, dim=-1)
    # pre-compute the probability mass of each polarity so that the vocabulary
    # dimension does not have to be stored
    positive_mass = vocab_probs[:, :, pos_tensor].sum(dim=-1)
    negative_mass = vocab_probs[:, :, neg_tensor].sum(dim=-1)
    return positive_mass > negative_mass

In [None]:
from transformers import DataCollatorForTokenClassification

# collator that pads each batch (length rounded up to a multiple of 8) and returns PyTorch tensors
dc = DataCollatorForTokenClassification(
    tokenizer,
    return_tensors="pt",
    padding=True,
    pad_to_multiple_of=8,
)

In [None]:
# Trainer used for evaluation only: the logits are reduced by
# `preprocess_logits_for_metrics` before `compute_metrics` sees them.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=dc,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
)

In [None]:
# evaluate every prompt task separately; metric names are prefixed with the split name
results = dict()
for split in tokenized_data:
    results[split] = trainer.evaluate(tokenized_data[split], metric_key_prefix=f"{split}")

In [None]:
from functools import reduce
from collections import OrderedDict

# one single-column frame of predictions per prompt task, in evaluation order
preds = OrderedDict((k, pd.DataFrame(v[f"{k}_y_pred"])) for k, v in results.items())

# A single concat places all tasks side by side. Reducing with pairwise
# concats re-copies the accumulated frame for every task and, with only one
# task, returns that task's frame itself instead of a new table.
all_preds = pd.concat(list(preds.values()), axis=1)
all_preds.columns = list(preds.keys())

In [None]:
# os.path.join instead of string concatenation, so the path is correct whether
# or not `logdir` ends with a path separator
results_path = os.path.join(logdir, "prompt_results.jsonl")
with open(results_path, "w") as f:
    f.write(all_preds.to_json(orient="records", lines=True))

In [None]:
# Stop execution here: when the file is run top to bottom, nothing below this
# point is reached.
import sys
sys.exit()

# Prompting
***

In [None]:
def get_probs(batch):
    """Score a batch of single-mask prompts with the positive/negative lexicons.

    Args:
        batch: mapping with a ``"prompt"`` list of texts containing the
            tokenizer's mask token.

    Returns:
        Dict with the summed probability of all positive word ids
        (``"positive_sum"``), of all negative word ids (``"negative_sum"``)
        and the decision ``"y_pred"`` (1 if positive >= negative, else 0),
        one entry per mask token found in the batch.
    """
    inputs = tokenizer(batch["prompt"], return_tensors="pt", padding=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model(**inputs)
    # logits at the mask positions only: [number of masks, vocabulary]
    mask_logits = out.logits[inputs["input_ids"] == tokenizer.mask_token_id]
    mask_probs = torch.softmax(mask_logits, dim=1, dtype=torch.float32)
    # `pos_ids` / `neg_ids` are flat lists of token ids; indexing them with
    # `[1]` would select a single id and make the sum over axis 1 fail
    pos_probs = mask_probs[:, pos_ids].sum(axis=1).cpu().numpy()
    neg_probs = mask_probs[:, neg_ids].sum(axis=1).cpu().numpy()
    pred = (pos_probs >= neg_probs).astype("int32")
    return {"positive_sum": pos_probs, "negative_sum": neg_probs, "y_pred": pred}

def run_prompts(num_masks=1):
    """Build a batch function that scores prompts using `num_masks` mask tokens.

    Args:
        num_masks: how many mask tokens replace each mask token of a prompt.

    Returns:
        A function taking a batch with a ``"prompt"`` column and returning
        summed/averaged positive and negative probabilities, the resulting
        decisions and the expanded prompts.

    NOTE(review): `pos_ids[num_masks]` / `neg_ids[num_masks]` look like lookups
    of id collections per mask count. Elsewhere in this file `pos_ids` and
    `neg_ids` are built as flat lists of token ids, in which case these lookups
    yield a single id and the reductions over dims 1 and 2 below would not
    apply -- confirm which structure is expected before using this function.
    """
    def g(batch):
        # add number of masks to single mask prompts
        prompts = [x.replace(tokenizer.mask_token, " ".join([tokenizer.mask_token]*num_masks)) for x in batch["prompt"]]
                
        inputs = tokenizer(prompts, return_tensors="pt", padding=True)
        inputs = {k:v.to(model.device) for k,v in inputs.items()}
        with torch.no_grad():
            out = model(**inputs)

        # where are the masks?
        # reshaped to [number of prompts, -1, len(tokenizer)]; assumes every prompt holds the
        # same number of masks and that the logits' last dim equals len(tokenizer) -- TODO confirm
        mask_logits = out.logits[inputs["input_ids"] == tokenizer.mask_token_id].reshape(len(prompts), -1, len(tokenizer))
        mask_probs = torch.softmax(mask_logits, 2, torch.float32)
        pos_probs = mask_probs[:, range(num_masks), pos_ids[num_masks]]
        neg_probs = mask_probs[:, range(num_masks), neg_ids[num_masks]]

        # decision based on the total probability over dims 1 and 2
        pos_probs_sum = pos_probs.sum([1,2])
        neg_probs_sum = neg_probs.sum([1,2])
        y_pred_sum = pos_probs_sum > neg_probs_sum

        # decision based on the mean probability over dims 1 and 2
        pos_probs_mean = pos_probs.mean([1,2])
        neg_probs_mean = neg_probs.mean([1,2])
        y_pred_mean = pos_probs_mean > neg_probs_mean 
        r = {
            "pos_probs_sum": pos_probs_sum,
            "neg_probs_sum": neg_probs_sum,
            "pos_probs_mean": pos_probs_mean,
            "neg_probs_mean": neg_probs_mean,
            "y_pred_sum": y_pred_sum,
            "y_pred_mean": y_pred_mean,
            "prompt": prompts,
        }
        # tensors are moved to the CPU and converted to numpy; other values pass through
        return {k:v.cpu().numpy() if isinstance(v, torch.Tensor) else v for k,v in r.items()}
    return g



In [None]:
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

def run_ckpt(ckpt):
    """Run all prompt tasks for one checkpoint and log the results to TensorBoard.

    Args:
        ckpt: path of the checkpoint to evaluate, or None to use the model as
            it currently is.

    Returns:
        Dict mapping the number of masks to the dataset returned by
        `dataset.map(run_prompts(...))`.

    NOTE(review): `max_masks` is not defined in the visible part of this file;
    it must exist as a global before this function is called -- confirm.
    """
    if ckpt is not None:
        # log into a subdirectory named after the last path component of the checkpoint
        ckpt_log_dir = os.path.join(logdir, os.path.split(ckpt)[1])
        # load checkpoint
        print("loading checkpoint")
        # the return value `x` is not used afterwards; presumably the weights are
        # loaded into the global `model` -- verify against fml
        x = fml.persistence.load_checkpoint(ckpt, model=model, prefer="hf")

    else:
        ckpt_log_dir = logdir
    print(ckpt_log_dir)
    results = {}
    # first pass: run the model for every mask count
    for i in range(1, max_masks+1):
        results[i] = dataset.map(run_prompts(num_masks=i), batched=True, batch_size=batch_size)

    writer = SummaryWriter(log_dir=ckpt_log_dir)


    # second pass: confusion matrix and classification report per mask count and split
    for i in range(1, max_masks+1):
        print(f"Evaluating {i} masks prompts:")
        for split, data in results[i].items():
            tag = f"{i}_masks/{split}/"

            print("Run:", tag)
            data = data.to_pandas()
            # NOTE(review): the prompt-loading cell of this file fills "label" with the
            # completed prompt text while "y_pred_sum" is a boolean decision; check that
            # "label" is the intended ground-truth column here
            y = data["label"]
            y_pred = data["y_pred_sum"]

            f = plt.figure(figsize=(5,5))
            ax = plt.gca()
            ConfusionMatrixDisplay.from_predictions(y, y_pred, normalize="true", display_labels=["bad", "good"], ax=ax)
            plt.title(split)
            # axis ticks and label of the x axis are placed above the matrix
            ax.xaxis.tick_top()
            ax.xaxis.set_label_position('top')
            plt.tight_layout()
            plt.show()

            writer.add_figure(tag+"confusion", f)

            # classification metrics
            report = classification_report(y, y_pred, output_dict=True)
            for k,v in report.items():
                # nested entries hold several metrics per key, flat entries a single value
                if isinstance(v, dict):
                    for metric, value in v.items():
                        writer.add_scalar(f"{tag}{k}/{metric}", value)
                else:
                    writer.add_scalar(f"{tag}{k}", v)

            print(classification_report(y, y_pred))
            print("-" * 60)
    writer.flush()
    writer.close()
    return results

In [None]:
# evaluate every checkpoint in turn; `r` keeps the results of the last one
for ckpt in checkpoints:
    r = run_ckpt(ckpt)