In [25]:
import numpy as np
import pandas as pd
import torch

import new_module.losses as lossbuilder
from new_module.utils.robertacustom import RobertaCustomForSequenceClassification
from transformers import AutoModelForSequenceClassification, AutoTokenizer

In [35]:
# Load the tokenizer stored in the step-2600 best-checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(
    '/shared/s3/lab07/hyeryung/loc_edit/roberta-base-jigsaw-toxicity-classifier-with-gpt2-large-embeds-2/step_2600_best_checkpoint/'
)

In [3]:
## read original data
# The file holds one JSON record per line, hence lines=True.
gpt2_outputs = pd.read_json(
    'new_module/data/toxicity-avoidance/testset_gpt2_2500.jsonl',
    lines=True,
)

In [19]:
# Settings that are later passed (as attributes) to lossbuilder.build_loss.
config = {
    'build_loss_dict': {
        "coeff_steps": 200,
        "coeff_pattern": "constant",
        "loss_type": "xentropy",
        "length_normalize": False,
        "AR_temperature": 1.0,
        "AR_top_k": 0,
        "AR_top_p": 0.96,
        "max_output_length": 20,
    },
}


class dummyArgs:
    """Minimal attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Same effect as calling setattr(self, key, value) for every pair.
        vars(self).update(kwargs)


# Attribute-style view of the loss settings (e.g. build_loss_args.loss_type).
build_loss_args = dummyArgs(**config["build_loss_dict"])

In [21]:
# Checkpoint directories to load, in the order the loss functions are built.
ckpts = [
    # clsf-embed-share
    '/shared/s3/lab07/hyeryung/loc_edit/roberta-base-jigsaw-toxicity-classifier-with-gpt2-large-embeds-2/step_2600_best_checkpoint/',
    # em-embed-share
    '/shared/s3/lab07/hyeryung/loc_edit/roberta-base-jigsaw-toxicity-classifier-with-gpt2-large-embeds-energy-training/step_2800_best_checkpoint/',
    # clsf
    '/shared/s3/lab07/hyeryung/loc_edit/models_re/roberta-base-jigsaw-toxicity-classifier/step_500_best_checkpoint/',
    # em
    '/shared/s3/lab07/hyeryung/loc_edit/models_re/roberta-base-jigsaw-toxicity-classifier-energy-training/step_1000_best_checkpoint/',
]

In [23]:
# Build one loss function per checkpoint, in the same order as `ckpts`.
lossfns = []
for ckpt in ckpts:
    # Checkpoints whose path contains 'with-gpt2-large-embeds' are loaded with the
    # project's custom RoBERTa class; all others go through the generic Auto class.
    if 'with-gpt2-large-embeds' in ckpt:
        model = RobertaCustomForSequenceClassification.from_pretrained(ckpt)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(ckpt)
    model.eval()

    # Deliberately NOT named `tokenizer`: that notebook-level name is what the
    # decoding cell calls `.decode` on. Rebinding it here would make a
    # top-to-bottom run decode with the last checkpoint's tokenizer instead of
    # the one loaded explicitly at the top of the notebook.
    ckpt_tokenizer = AutoTokenizer.from_pretrained(ckpt)

    lossfn = lossbuilder.build_loss(
        "classification_no_prefix_logprobloss",
        model,
        ckpt_tokenizer,
        build_loss_args,
    )
    lossfns.append(lossfn)

50265
50265


In [39]:
# For every prompt, keep the generations that at least one loss function flags.
# A generation is flagged when its loss for label_id=0 exceeds -log(0.75).
# NOTE(review): presumably the loss is the negative log-probability of label 0,
# which would make this "p(label 0) < 0.75" — confirm against the loss builder.
loss_threshold = -np.log(0.75)

source_texts = []
AR_predictions = []

for _, row in gpt2_outputs.iterrows():
    source_text = tokenizer.decode(row.prompt['tokens'])

    curr_predictions = []
    # Only the first 10 generations of each row are scored.
    for j in range(10):
        AR_prediction = tokenizer.decode(row.generations[j]['tokens'])

        # One scalar loss per loss function, evaluated without gradients.
        with torch.no_grad():
            lossvals = np.array([
                lossfn.compute_gold_loss(
                    source_text, AR_prediction,
                    label_id=0,
                ).squeeze().item()
                for lossfn in lossfns
            ])

        if np.any(lossvals > loss_threshold):
            curr_predictions.append({'text': AR_prediction})

    # Every prompt is recorded, even when none of its generations was kept.
    source_texts.append({'text': source_text})
    AR_predictions.append(curr_predictions)
        

In [41]:
## collect each prompt together with the generations that were kept for it
# One row per prompt; 'generations' may be an empty list.
toxic_gens = pd.DataFrame(
    {
        'prompt': source_texts,
        'generations': AR_predictions,
    }
)


In [43]:
# Temporary helper column: number of kept generations per prompt.
toxic_gens['count'] = toxic_gens['generations'].apply(len)

In [46]:
# Total number of kept generations across all prompts.
toxic_gens.loc[:, 'count'].sum()

604

In [47]:
# Remove the helper column again so it is not written to disk.
toxic_gens.drop(columns=['count'], inplace=True)

In [49]:
# Write the frame as JSON Lines: one record (row) per line.
toxic_gens.to_json(
    'new_module/data/toxicity-avoidance/dev_set.jsonl',
    orient='records',
    lines=True,
)