In [1]:
#!/usr/bin/env python
# coding: utf-8

from itertools import chain
import math
import argparse
import json
import logging
import os
import time
os.chdir('/data/hyeryung/mucoco')
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

import new_module.losses as lossbuilder
import wandb
# from new_module.decode_utils import (
#     beam_rerank_v0,
#     beam_rerank_v1,
#     beam_rerank_v2,
#     combi_rerank,
# )
from new_module.new_decode_utils import get_beam_hypotheses, get_combi_hypotheses, final_reranking
from new_module.evaluation.evaluate_wandb import evaluate_main
from new_module.locate.new_locate_utils import LocateMachine
from new_module.utils.robertacustom import RobertaCustomForSequenceClassification

logging.basicConfig(level=logging.DEBUG, format="%(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("LOGGING_LEVEL", logging.DEBUG))


ImportError: cannot import name 'get_beam_hypotheses' from 'new_module.new_decode_utils' (/data/hyeryung/mucoco/new_module/new_decode_utils.py)

In [45]:
import importlib
import new_module.new_decode_utils
importlib.reload(new_module.new_decode_utils)
from new_module.new_decode_utils import get_beam_hypotheses, get_combi_hypotheses, final_reranking

In [2]:
import joblib
config = joblib.load('config.pkl')

In [3]:
config

{'task': 'toxicity',
 'source_data': 'new_module/data/toxicity-avoidance/dev_set.jsonl',
 'source_style': 'toxic',
 'target_style': 'nontoxic',
 'target_label_ids': [0, 0],
 'model_paths': ['gpt2-large',
  '/data/hyeryung/loc_edit/models/roberta-base-jigsaw-toxicity-classifier-energy-training/step_1000_best_checkpoint'],
 'tokenizer_paths': ['gpt2-large',
  '/data/hyeryung/loc_edit/models/roberta-base-jigsaw-toxicity-classifier-energy-training/step_1000_best_checkpoint'],
 'model_types': ['AutoModelForCausalLM',
  'RobertaCustomForSequenceClassification'],
 'output_dir_prefix': 'outputs/toxicity/devset',
 'early_stopping_patience': 0,
 'method': 'mlm-beamsearch-v0',
 'locate_unit': 'word',
 'min_epsilons': [0.9],
 'num_samples': 10,
 'device': 'cuda',
 'target_type': 'embeds',
 'cache_dir': '/data/hyeryung/hf_cache',
 'jsonl_primary_key': 'prompt',
 'jsonl_secondary_key': 'text',
 'losses': ['gpt2', 'classification_no_prefix_logprobloss'],
 'build_loss_dict': {'coeff_steps': 200,
  'co

In [4]:
main_start_time = time.time()

if not config.get("model_tag", None):
    if "energy-training" in config["model_paths"][1]:
        config["model_tag"] = "em"
    else:
        config["model_tag"] = "clsf"

    if (config["task"] == "formality") and ("gyafc" in config["model_paths"][1]):
        config["model_tag"] += "-gyafc"

if config["resume"]:
    logger.info("resuming from a previous run")
    run = wandb.init(
        project=config["wandb_project"],
        entity=config["wandb_entity"],
        id=config["wandb_run_id"],
        resume="must",
    )
else:
    run = wandb.init(
        project=config["wandb_project"],
        entity=config["wandb_entity"],
        config=config,
    )

run_id = run.path.split("/")[-1]
display_name = f"{run_id}"


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
Starting new HTTPS connection (1): api.wandb.ai:443
https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 1812
https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 377
[34m[1mwandb[0m: Currently logged in as: [33mhayleyson[0m. Use [1m`wandb login --relogin`[0m to force relogin
Popen(['git', 'cat-file', '--batch-check'], cwd=/data/hyeryung/mucoco, stdin=<valid stream>, shell=False, universal_newlines=False)


In [5]:


outdir = os.path.join(config["output_dir_prefix"], display_name)
os.makedirs(outdir, exist_ok=True)
outfile = f"{outdir}/outputs_epsilon{config['min_epsilons'][0]}.txt"
run.summary["outfile_path"] = outfile


In [6]:

class dummyArgs:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

build_loss_args = dummyArgs(**config["build_loss_dict"])

## load data
if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
    source_dataset = [
        json.loads(l)[config["jsonl_primary_key"]][config["jsonl_secondary_key"]]
        for l in open(config["source_data"])
    ]
    generation_dataset = [
        json.loads(l)["generations"] for l in open(config["source_data"])
    ]
elif (config["task"] == "formality") or (config["task"] == "sentiment-lewis-compr"):
    with open(config["source_data"], "r") as f:
        generation_dataset = [line.rstrip('\n') for line in f.readlines()]
    source_dataset = ["" for l in generation_dataset]

# check if outfile exists
if (config["resume"]) and (os.path.exists(outfile)):

    with open(outfile, "r") as f:
        existing_gens = [x.rstrip("\n") for x in f.readlines()]
    resume_idx = len(existing_gens)
    if resume_idx == len(source_dataset):
        logger.debug("output file is already complete. skipping this run.")
        raise
    elif resume_idx < len(source_dataset):
        logger.info(
            f"output file already exists but is incomplete. resuming from index: {resume_idx}"
        )
        outf = open(outfile, "a")
        int_outf = open(outfile+".intermediate", "a")
    else:
        logger.critical(
            f"output file seems to be corrupted. The file length is {resume_idx}, where the size of source_dataset is {len(source_dataset)}"
        )
        raise
else:
    resume_idx = 0
    outf = open(outfile, "w")
    int_outf = open(outfile+".intermediate", "w")


In [7]:

## load tokenizer, models, define losses
name2tokenizer = {}
name2model = {}
name2config = {}
loss2tokenizer = {}
embed_luts = []

for i, model_path in enumerate(config["model_paths"]):
    if (
        model_path not in name2model
    ):  # making sure we are not loading the model twice in case some constraints use the same model.
        try:
            name2tokenizer[config["tokenizer_paths"][i]] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=True,
            )
        except:
            name2tokenizer[config["tokenizer_paths"][i]] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=False,
            )

        name2config[model_path] = AutoConfig.from_pretrained(
            model_path, cache_dir=config["cache_dir"]
        )

        if config["model_types"][i] == "RobertaCustomForSequenceClassification":
            name2model[model_path] = lossbuilder.ModelWrapper(
                RobertaCustomForSequenceClassification.from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        else:
            name2model[model_path] = lossbuilder.ModelWrapper(
                getattr(transformers, config["model_types"][i]).from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        name2model[model_path].eval()
        name2model[model_path].to(config['device'])

    input_embeds = name2model[model_path].get_input_embeddings()
    if isinstance(input_embeds, torch.nn.Sequential):
        input_embeds = input_embeds[0]
    embed_luts.append(input_embeds)

    if config["target_type"] == "embeds":
        embed_luts[-1].requires_grad = False

mlm_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
mlm = None if config["method"] == "mlm-beamsearch-v2" else AutoModelForMaskedLM.from_pretrained("roberta-base").to(config['device'])

lossfns = []
for i, loss in enumerate(config["losses"]):
    lossfns.append(
        lossbuilder.build_loss(
            loss,
            name2model[config["model_paths"][i]],
            name2tokenizer[config["tokenizer_paths"][i]],
            build_loss_args,
        )
    )
    lossfns[i].tokenizer.add_special_tokens({"mask_token": mlm_tokenizer.mask_token})
    loss2tokenizer[loss] = lossfns[i].tokenizer


Starting new HTTPS connection (1): huggingface.co:443
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/generation_config.json HTTP/1.1" 200 0


50265


https://huggingface.co:443 "HEAD /roberta-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0


In [8]:

# define an object to locate problematic phrases
locator = LocateMachine(lossfns[1].model, lossfns[1].tokenizer)

label_ids = config["target_label_ids"]  # target label's ids for each loss

run.summary["prep_time"] = time.time() - main_start_time
## beginning of main logic
decode_start_time = time.time()
# text_id = 0
if config["resume"]:
    num_skipped = run.summary.get("num_skipped", 0)
    num_edited = run.summary.get("num_edited", 0)
    num_decoded_tokens = run.summary.get("num_decoded_tokens", 0)
else:
    num_skipped = 0
    num_edited = 0
    num_decoded_tokens = 0


In [83]:

loss_weights = [1 - config['closs_weight, config['closs_weight]
interrupted = False
# for text_id in range(len(source_dataset))[resume_idx:]:
text_id = 0
source_text = source_dataset[text_id]
if source_text == "":
    source_text = lossfns[0].tokenizer.bos_token

if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
    AR_prediction_all = [x["text"] for x in generation_dataset[text_id]]
    # predicted_batches = [x["tokens"] for x in generation_dataset[text_id]]
    # predicted_batches = [
    #     torch.tensor([x], dtype=torch.long, device=config["device"])
    #     for x in predicted_batches
    # ]
    
elif (config["task"] == "formality") or (
    config["task"] == "sentiment-lewis-compr"
):
    AR_prediction_all = [generation_dataset[text_id]]

curr_num_samples = len(AR_prediction_all)


In [84]:
AR_prediction_all

[' dirt. Unfortunately the majority of horses end up in shit that you had to drive yourself. My only recourse is to feed it myself. What would be the happy tale of my life then?',
 ' wearing games and holy ****ing shit do I hate horse wearing games.',
 " fetishes: it just makes me want to puke every time I see it on the internet, even though it's not worth a thing because I am a furry."]

In [148]:
# for sample_idx in range(config["num_samples"])[:]:

######### change here! instead of for loop, do a batched operation ########

# --------------------------------------------------------------------------------------------- #
## check whether initial text satisfies constraint

curr_loss = torch.zeros(len(AR_prediction_all)).to(config['device'])
logging_loss = torch.zeros((len(AR_prediction_all),2)).to(config['device'])
edit_yn = torch.ones(len(AR_prediction_all), dtype=torch.bool).to(config['device'])
        
for lossid, lossname in enumerate(config["losses"]):
    with torch.no_grad():
        lossvalue = lossfns[lossid].compute_gold_loss(
            source_text, AR_prediction_all,
            label_id=config['target_label_ids'][lossid],
        )
        torch.cuda.empty_cache()
    curr_loss += loss_weights[lossid] * lossvalue
    logging_loss[:, lossid] = lossvalue.clone()



In [149]:
allsat = logging_loss[:,1] < -math.log(config["min_epsilons"][0])
allsat_ix = allsat.nonzero().squeeze(0)
edit_yn[allsat_ix] = 0
edited_at_all_yn = edit_yn.detach().clone()


In [150]:
allsat, allsat_ix, edit_yn, edited_at_all_yn

(tensor([False, False, False], device='cuda:0'),
 tensor([], device='cuda:0', size=(0, 1), dtype=torch.int64),
 tensor([True, True, True], device='cuda:0'),
 tensor([True, True, True], device='cuda:0'))

In [88]:
if (edit_yn.sum().item() == 0) and (not config["dont_skip_allsat"]):
    ## save data
    num_edited += 0
    num_skipped += len(AR_prediction_all)
    num_decoded_tokens += 0
    print('continue')


In [89]:
    
num_edited += edit_yn.sum().item()
num_skipped += (len(AR_prediction_all) - edit_yn.sum().item())
num_decoded_tokens += sum([len(x) for x in name2tokenizer[config["tokenizer_paths"][0]]([x for i, x in enumerate(AR_prediction_all) if edit_yn[i] == 1], add_special_tokens=False).input_ids])


In [93]:
from copy import deepcopy

In [125]:
es_patience_count = torch.zeros(len(AR_prediction_all),dtype=torch.long).to(config['device'])
best_allsat = allsat.detach().clone()
best_losses = logging_loss.detach().clone()
best_weighted_loss = curr_loss.detach().clone()            
running_text = best_text = deepcopy(AR_prediction_all)
int_output = [{} for _ in range(len(AR_prediction_all))]


In [126]:
_iter = 0
# for _iter in range(config['n_iter):
masked_text = locator.locate_main(running_text, 
                                    method = config['locate_method'], 
                                    max_num_tokens = config['num_edit_token_per_step, 
                                    unit = config['locate_unit'], 
                                    num_layer = -2, #penultimate
                                    label_id = config['target_label_ids'][0])


In [127]:
AR_prediction_all

[' dirt. Unfortunately the majority of horses end up in shit that you had to drive yourself. My only recourse is to feed it myself. What would be the happy tale of my life then?',
 ' wearing games and holy ****ing shit do I hate horse wearing games.',
 " fetishes: it just makes me want to puke every time I see it on the internet, even though it's not worth a thing because I am a furry."]

In [128]:

inputs = mlm_tokenizer(
    masked_text, return_tensors="pt", padding=True, truncation=True
)
inputs = inputs.to(config['device']) 
masked_sequence=inputs['input_ids']


In [129]:

## make predictions for the masked indices
with torch.no_grad():
    logits = mlm(**inputs).logits


In [130]:

special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
logits[:, :, special_token_ids] = -float("inf")


## 여기에서 repeat이 안되도록 처리할 수 있나? # 아직은 잘 모르겠음



In [131]:

indices_in_mlm_tokens = (
    inputs.input_ids == mlm_tokenizer.mask_token_id
).nonzero(as_tuple=True)

## get top k tokens for each index
predicted_token_ids = torch.topk(
    logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
    k=config['k_per_location'],
    dim=-1,
)


In [132]:

if config["method"] in ["mlm-beamsearch-v0","mlm-beamsearch-v1"] :
    hypotheses = get_beam_hypotheses(source_text, 
            masked_sequence, 
            indices_in_mlm_tokens,
            predicted_token_ids.indices,
            mlm_tokenizer, 
            lossfns,
            config)
elif config["method"] == "mlm-reranking":
    hypotheses = get_combi_hypotheses(masked_sequence, 
                indices_in_mlm_tokens,
                predicted_token_ids.indices,
                mlm_tokenizer,
                config)


In [133]:
hypotheses

[[' dirt. Unfortunately the majority of you end up in something that you had to get to. My only recourse is to do it myself. What would be the happy tale of my life then?',
  ' dirt. Unfortunately the majority of you end up in something that you had to work in, My only recourse is to make it myself. What would be the happy tale of my life then?',
  ' dirt. Unfortunately the majority of you end up in something that you had to go for." My only recourse is to face it myself. What would be the happy tale of my life then?',
  ' dirt. Unfortunately the majority of you end up in something that you had to help with? My only recourse is to live it myself. What would be the happy tale of my life then?',
  ' dirt. Unfortunately the majority of you end up in something that you had to deal from! My only recourse is to buy it myself. What would be the happy tale of my life then?'],
 [' wearing games and by ****ing god do I hate horse wearing games.',
  ' wearing games and no ****ing why do I hate ho

In [134]:
    
final_hypotheses, new_best_weighted_loss, new_best_allsat, new_best_logging_loss = final_reranking(source_text,
                                                                                                   hypotheses,
                                                                                                    lossfns,
                                                                                                    config,
                                                                                                     batch_size=64)


In [135]:
final_hypotheses

[' dirt. Unfortunately the majority of you end up in something that you had to get to. My only recourse is to do it myself. What would be the happy tale of my life then?',
 ' wearing games and no ****ing why do I hate horse wearing games.',
 " fetishes: it just makes me want to be better every time I see it on the news, even though it's not worth a thing because I am the.."]

In [136]:
new_best_weighted_loss

tensor([13.857,  5.988, 11.103], device='cuda:0')

In [137]:
new_best_allsat

tensor([ True, False,  True], device='cuda:0')

In [138]:
new_best_logging_loss

tensor([[  138.467,     0.011],
        [   54.674,     0.579],
        [  110.782,     0.027]], device='cuda:0')

In [139]:
new_best_allsat

tensor([ True, False,  True], device='cuda:0')

In [140]:

update = torch.Tensor([]).long().to(config['device'])
if config['selection_criteria == "weighted_sum":
    update = best_weighted_loss > new_best_weighted_loss
elif config['selection_criteria == "allsat_primary":
    update = (~best_allsat & new_best_allsat) | \
            (~best_allsat & ~new_best_allsat & (best_weighted_loss > new_best_weighted_loss)) | \
            (best_allsat & new_best_allsat & (best_losses[:, 0] > new_best_logging_loss[:, 0])) 


In [167]:

## intermediate output for debugging
for sample_ix in range(len(AR_prediction_all)):
    
    int_output[sample_ix].update({f"iter{_iter}_original_sentence": running_text[sample_ix],
                                f"iter{_iter}_masked_sentence": masked_text[sample_ix],
                                f"iter{_iter}_best_text": final_hypotheses[sample_ix],
                                f"iter{_iter}_update": update[sample_ix].item()})


In [168]:
int_output

[{'iter0_original_sentence': ' dirt. Unfortunately the majority of you end up in something that you had to get to. My only recourse is to do it myself. What would be the happy tale of my life then?',
  'iter0_masked_sentence': ' dirt. Unfortunately the majority of<mask> end up in<mask> that you had to<mask><mask><mask> My only recourse is to<mask> it myself. What would be the happy tale of my life then?',
  'iter0_best_text': ' dirt. Unfortunately the majority of you end up in something that you had to get to. My only recourse is to do it myself. What would be the happy tale of my life then?',
  'iter0_update': True},
 {'iter0_original_sentence': ' wearing games and no ****ing why do I hate horse wearing games.',
  'iter0_masked_sentence': ' wearing games and<mask> ****ing<mask> do I hate horse wearing games.',
  'iter0_best_text': ' wearing games and no ****ing why do I hate horse wearing games.',
  'iter0_update': True},
 {'iter0_original_sentence': " fetishes: it just makes me want 

In [143]:

running_text = deepcopy(final_hypotheses)
for update_index in update.nonzero().squeeze().tolist():
    best_text[update_index] = final_hypotheses[update_index]
best_allsat[update] = new_best_allsat[update]
best_losses[update] = new_best_logging_loss[update]
best_weighted_loss[update] = new_best_weighted_loss[update]


In [144]:
es_patience_count

tensor([0, 0, 0], device='cuda:0')

In [151]:
(best_allsat & edit_yn).nonzero().squeeze()

tensor([0, 2], device='cuda:0')

In [152]:
es_patience_count[(best_allsat & edit_yn).nonzero().squeeze()]

tensor([0, 0], device='cuda:0')

In [153]:

es_patience_count[(best_allsat & edit_yn).nonzero().squeeze()] += 1



In [154]:
es_patience_count

tensor([1, 0, 1], device='cuda:0')

In [155]:
config["early_stopping_patience"]

0

In [156]:
if (config["early_stopping_patience"] != -1):
    edit_yn[es_patience_count > config['early_stopping_patience']] = 0

In [157]:
edit_yn

tensor([False,  True, False], device='cuda:0')

In [158]:
edit_yn.sum()

tensor(1, device='cuda:0')

In [160]:
if (config["early_stopping_patience"] != -1):
    edit_yn[es_patience_count > config['early_stopping_patience']] = 0
if edit_yn.sum() == 0:
    print('break')

In [169]:

output = {
        "prompt": {
            "text": source_text,
        },
        "generations": [
            {
                "text": best_text[i],
                "original_text": AR_prediction_all[i],
                "allsat": best_allsat[i].item(),
                "losses": best_losses[i,:].tolist(),
                "weighted_loss": best_weighted_loss[i].item(),
                "edited": edited_at_all_yn[i].tolist(),
            } for i in range(len(AR_prediction_all))
        ],
    }
    
intermediate_output = {
        "prompt": {
            "text": source_text,
        },
        "generations": 
            int_output
        ,
    }


In [171]:
output

{'prompt': {'text': 'Let me just start by saying I hate horse'},
 'generations': [{'text': ' dirt. Unfortunately the majority of you end up in something that you had to get to. My only recourse is to do it myself. What would be the happy tale of my life then?',
   'original_text': ' dirt. Unfortunately the majority of horses end up in shit that you had to drive yourself. My only recourse is to feed it myself. What would be the happy tale of my life then?',
   'allsat': True,
   'losses': [138.46743774414062, 0.011374153196811676],
   'weighted_loss': 13.856980323791504,
   'edited': True},
  {'text': ' wearing games and no ****ing why do I hate horse wearing games.',
   'original_text': ' wearing games and holy ****ing shit do I hate horse wearing games.',
   'allsat': False,
   'losses': [54.674339294433594, 0.5787469744682312],
   'weighted_loss': 5.988306045532227,
   'edited': True},
  {'text': " fetishes: it just makes me want to be better every time I see it on the news, even tho

In [173]:

json.dump(output, outf)
outf.write("\n")
outf.flush()

json.dump(intermediate_output, int_outf)
int_outf.write("\n")
int_outf.flush()
        


In [174]:
if (time.time() - main_start_time) > config['server_time_limit'] * 60 * 60 * 0.9:
    interrupted = True
    print('break')

outf.close()
int_outf.close()


In [175]:

if config["resume"]:
    run.summary["decode_time"] += time.time() - decode_start_time
else:
    run.summary["decode_time"] = time.time() - decode_start_time
run.summary['num_decoded_tokens'] = num_decoded_tokens
run.summary['toks_p_sec'] = (num_decoded_tokens/run.summary['decode_time'])
run.summary["num_skipped"] = num_skipped
run.summary["num_edited"] = num_edited

run.finish()




VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
decode_time,1319.89143
num_decoded_tokens,172
num_edited,6
num_skipped,0
outfile_path,outputs/toxicity/dev...
prep_time,17.00515
toks_p_sec,0.13031


Starting new HTTPS connection (1): o151352.ingest.sentry.io:443
https://o151352.ingest.sentry.io:443 "POST /api/4504800232407040/envelope/ HTTP/1.1" 200 0


In [None]:



if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Locally Editing Text Generation")
    parser.add_argument(
        "--task",
        type=str,
        help="task name",
        choices=["toxicity", "formality", "sentiment", "sentiment-lewis-compr"],
    )
    parser.add_argument(
        "--source_data",
        type=str,
        default="data/formality/GYAFC_Corpus/Entertainment_Music/test/informal",
        help="source data path",
    )
    parser.add_argument(
        "--source_style", type=str, default="informal", help="source style"
    )
    parser.add_argument(
        "--target_style", type=str, default="formal", help="target style"
    )
    parser.add_argument(
        "--target_label_ids",
        nargs="+",
        type=int,
        default=[1, 1],
        help="a list of indices of target label used in each of models. e.g. [1,1]",
    )
    parser.add_argument(
        "--model_paths",
        nargs="+",
        type=str,
        default=[
            "gpt2-large",
            "/home/s3/hyeryung/data/loc_edit/roberta-base-pt16-formality-regressor-with-gpt2-large-embeds-rescale/epoch_17",
        ],
        help="model paths",
    )
    parser.add_argument(
        "--tokenizer_paths",
        nargs="+",
        type=str,
        default=[
            "gpt2-large",
            "/home/s3/hyeryung/data/loc_edit/roberta-base-pt16-formality-regressor-with-gpt2-large-embeds-rescale/epoch_17",
        ],
        help="tokenizer paths",
    )
    parser.add_argument(
        "--model_types",
        nargs="+",
        type=str,
        default=["AutoModelForCausalLM", "RobertaCustomForSequenceClassification"],
        help="model types",
    )
    parser.add_argument(
        "--output_dir_prefix",
        type=str,
        help="output directory prefix. e.g. outputs/formality/mlm-reranking",
    )
    parser.add_argument(
        "--early_stopping_patience",
        type=int,
        default=-1,
        help="early stopping patience",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="mlm-beamsearch-v0",
        help="method name",
        choices=[
            "mlm-beamsearch-v0",
            "mlm-beamsearch-v1",
            "mlm-beamsearch-v2",
            "mlm-reranking",
        ],
    )
    parser.add_argument(
        "--locate_unit", type=str, default="token", help="unit to locate"
    )
    parser.add_argument(
        "--min_epsilons", nargs="+", type=float, default=[0.75], help="min epsilons"
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="number of samples to edit per prompt",
    )
    parser.add_argument("--device", type=str, default="cuda", help="device")
    parser.add_argument(
        "--target_type",
        type=str,
        default="embeds",
        help="target type (embeds, simplex, probability) from prior work's code",
    )
    parser.add_argument(
        "--cache_dir", type=str, default="hf_cache", help="cache directory"
    )
    parser.add_argument(
        "--jsonl_primary_key", type=str, default="prompt", help="jsonl primary key"
    )
    parser.add_argument(
        "--jsonl_secondary_key", type=str, default="text", help="jsonl secondary key"
    )
    parser.add_argument(
        "--losses",
        nargs="+",
        type=str,
        default=["gpt2", "classification_no_prefix_logprobloss"],
        help="losses",
    )
    parser.add_argument(
        "--build_loss_dict",
        type=json.loads,
        default='{"coeff_steps": 200, "coeff_pattern": "constant", "loss_type": "xentropy", "length_normalize": false, "AR_temperature": 1.0, "AR_top_k": 0, "AR_top_p": 0.96, "max_output_length": 20}',
        help="build loss dict",
    )
    parser.add_argument(
        "--num_edit_token_per_step",
        type=int,
        default=5,
        help="number of edit tokens per step",
    )
    parser.add_argument("--k_per_location", type=int, default=15, help="k per location")
    parser.add_argument("--n_iter", type=int, default=3, help="number of iterations")
    parser.add_argument(
        "--selection_criteria",
        type=str,
        default="weighted_sum",
        help="selection criteria",
    )
    parser.add_argument("--closs_weight", type=float, default=0.32, help="closs weight")
    parser.add_argument("--beam_size", type=int, default=5, help="beam size")
    parser.add_argument(
        "--wandb_project", type=str, default="mlm_reranking", help="wandb project name"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default="hayleyson", help="wandb entity name"
    )
    parser.add_argument("--wandb_run_id", type=str, help="wandb run name")
    parser.add_argument(
        "--resume", action="store_true", help="whether to resume from a previous run"
    )
    parser.add_argument("--slurm_job_id", type=str, help="slurm job id (for debugging)")
    parser.add_argument(
        "--dont_skip_allsat",
        action="store_true",
        help="if this argument is passed, the module will conduct decoding on all samples even if they already satisfy constraints",
    )
    parser.add_argument(
        "--locate_method",
        type=str,
        help="method to use for locating tokens",
        choices=["attention", "grad_norm"],
        default="attention",
    )
    parser.add_argument(
        "--server_time_limit",
        type=float,
        help="Number of maximum hours to run the script for. Can be fractions e.g. 7.5.",
        default=10000
    )

    args = parser.parse_args()
    config = vars(args)

    main(config)


# Check for discrepancy

In [134]:
#!/usr/bin/env python
# coding: utf-8

from copy import deepcopy
from itertools import chain
import math
import argparse
import json
import logging
import os
import time
os.chdir('/data/hyeryung/mucoco')
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

import new_module.losses as lossbuilder
import wandb
# from new_module.decode_utils import (
#     beam_rerank_v0,
#     beam_rerank_v1,
#     beam_rerank_v2,
#     combi_rerank,
# )
from new_module.new_decode_utils import get_beam_hypotheses_v0, get_beam_hypotheses_v1, get_combi_hypotheses, final_reranking
from new_module.evaluation.evaluate_wandb import evaluate_main
from new_module.locate.new_locate_utils import LocateMachine
from new_module.utils.robertacustom import RobertaCustomForSequenceClassification

logging.basicConfig(level=logging.DEBUG, format="%(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("LOGGING_LEVEL", logging.DEBUG))

import joblib
config = joblib.load('config.pkl')
config['device'] = 'cuda:7'

new implementation

In [135]:
outputs = []
int_outputs = []

In [136]:
class dummyArgs:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

build_loss_args = dummyArgs(**config["build_loss_dict"])

## load data
if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
    source_dataset = [
        json.loads(l)[config["jsonl_primary_key"]][config["jsonl_secondary_key"]]
        for l in open(config["source_data"])
    ]
    generation_dataset = [
        json.loads(l)["generations"] for l in open(config["source_data"])
    ]
elif (config["task"] == "formality") or (config["task"] == "sentiment-lewis-compr"):
    with open(config["source_data"], "r") as f:
        generation_dataset = [line.rstrip('\n') for line in f.readlines()]
    source_dataset = ["" for l in generation_dataset]

## load tokenizer, models, define losses
name2tokenizer = {}
name2model = {}
name2config = {}
loss2tokenizer = {}
embed_luts = []

for i, model_path in enumerate(config["model_paths"]):
    if (
        model_path not in name2model
    ):  # making sure we are not loading the model twice in case some constraints use the same model.
        try:
            name2tokenizer[config["tokenizer_paths"][i]] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=True,
            )
        except:
            name2tokenizer[config["tokenizer_paths"][i]] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=False,
            )

        name2config[model_path] = AutoConfig.from_pretrained(
            model_path, cache_dir=config["cache_dir"]
        )

        if config["model_types"][i] == "RobertaCustomForSequenceClassification":
            name2model[model_path] = lossbuilder.ModelWrapper(
                RobertaCustomForSequenceClassification.from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        else:
            name2model[model_path] = lossbuilder.ModelWrapper(
                getattr(transformers, config["model_types"][i]).from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        name2model[model_path].eval()
        name2model[model_path].to(config['device'])

    input_embeds = name2model[model_path].get_input_embeddings()
    if isinstance(input_embeds, torch.nn.Sequential):
        input_embeds = input_embeds[0]
    embed_luts.append(input_embeds)

    if config["target_type"] == "embeds":
        embed_luts[-1].requires_grad = False

mlm_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
mlm = None if config["method"] == "mlm-beamsearch-v2" else AutoModelForMaskedLM.from_pretrained("roberta-base").to(config['device'])

lossfns = []
for i, loss in enumerate(config["losses"]):
    lossfns.append(
        lossbuilder.build_loss(
            loss,
            name2model[config["model_paths"][i]],
            name2tokenizer[config["tokenizer_paths"][i]],
            build_loss_args,
        )
    )
    lossfns[i].tokenizer.add_special_tokens({"mask_token": mlm_tokenizer.mask_token})
    loss2tokenizer[loss] = lossfns[i].tokenizer

# define an object to locate problematic phrases
locator = LocateMachine(lossfns[1].model, lossfns[1].tokenizer)

label_ids = config["target_label_ids"]  # target label's ids for each loss


https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/generation_config.json HTTP/1.1" 200 0


50265


https://huggingface.co:443 "HEAD /roberta-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0


In [137]:
# resume_idx = 0
loss_weights = [1 - config['closs_weight'], config['closs_weight']]
interrupted = False
# for text_id in range(len(source_dataset))[resume_idx:]:
text_id = 34
source_text = source_dataset[text_id]
if source_text == "":
    source_text = lossfns[0].tokenizer.bos_token
if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
    AR_prediction_all = [x["text"] for x in generation_dataset[text_id]]
elif (config["task"] == "formality") or (
    config["task"] == "sentiment-lewis-compr"
):
    AR_prediction_all = [generation_dataset[text_id]]

curr_num_samples = len(AR_prediction_all)

curr_loss = torch.zeros(len(AR_prediction_all)).to(config['device'])
logging_loss = torch.zeros((len(AR_prediction_all),len(config["losses"]))).to(config['device'])
edit_yn = torch.ones(len(AR_prediction_all), dtype=torch.bool).to(config['device'])
        
for lossid, lossname in enumerate(config["losses"]):
    with torch.no_grad():
        lossvalue = lossfns[lossid].compute_gold_loss(
            source_text, AR_prediction_all,
            label_id=config['target_label_ids'][lossid],
        )
        torch.cuda.empty_cache()
    curr_loss += loss_weights[lossid] * lossvalue
    logging_loss[:, lossid] = lossvalue.clone()


allsat = logging_loss[:,1] < -math.log(config["min_epsilons"][0])
allsat_ix = allsat.nonzero().squeeze(0)
if (not config["dont_skip_allsat"]):
    edit_yn[allsat_ix] = False
edited_at_all_yn = edit_yn.detach().clone()

es_patience_count = torch.zeros(len(AR_prediction_all),dtype=torch.long).to(config['device'])
best_allsat = allsat.detach().clone()
best_losses = logging_loss.detach().clone()
best_weighted_loss = curr_loss.detach().clone()            
best_text = deepcopy(AR_prediction_all)
running_text = [x for i, x in enumerate(AR_prediction_all) if edit_yn[i]] ## 실제 고쳐야 할 sample만 가지고 있음
int_output = [{} for _ in range(len(AR_prediction_all))]
   

In [138]:
 
# for _iter in range(config['n_iter']):
_iter = 0
## masked_text : N (num samples to edit)
masked_text = locator.locate_main(running_text, 
                        method = config['locate_method'], 
                        max_num_tokens = config['num_edit_token_per_step'], 
                        unit = config['locate_unit'], 
                        num_layer = 10,#-2, #penultimate
                        label_id = config['target_label_ids'][1])

## replace tokens at the indices with mask tokens

inputs = mlm_tokenizer(
    masked_text, return_tensors="pt", padding=True, truncation=True
)
inputs = inputs.to(config['device']) 
masked_sequence=inputs['input_ids']

## make predictions for the masked indices
with torch.no_grad():
    logits = mlm(**inputs).logits

special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
logits[:, :, special_token_ids] = -float("inf")

indices_in_mlm_tokens = (
    inputs.input_ids == mlm_tokenizer.mask_token_id
).nonzero(as_tuple=True)

## get top k tokens for each index
predicted_token_ids = torch.topk(
    logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
    k=config['k_per_location'],
    dim=-1,
)

if config["method"] == "mlm-beamsearch-v0":
    hypotheses = get_beam_hypotheses_v0(source_text, 
            masked_sequence, 
            indices_in_mlm_tokens,
            predicted_token_ids.indices,
            mlm_tokenizer, 
            lossfns,
            config)
elif config["method"] == "mlm-beamsearch-v1":
    hypotheses = get_beam_hypotheses_v1(source_text, 
            masked_sequence, 
            indices_in_mlm_tokens,
            predicted_token_ids.indices,
            mlm_tokenizer, 
            lossfns,
            config)
elif config["method"] == "mlm-reranking":
    hypotheses = get_combi_hypotheses(masked_sequence, 
                indices_in_mlm_tokens,
                predicted_token_ids.indices,
                mlm_tokenizer,
                config)
    
final_hypotheses_, new_best_weighted_loss_, new_best_allsat_, new_best_logging_loss_ = final_reranking(source_text,
                                                                                                    hypotheses,
                                                                                                    lossfns,
                                                                                                    config,
                                                                                                    batch_size=64)



In [139]:
print(final_hypotheses_)
print(new_best_weighted_loss_)
print(new_best_allsat_)
print(new_best_logging_loss_)

['", I think you\'ve got to prove it to everybody in the media.', " sexist, I want to see me like I'm a man.", " racist, or a sexist, or a racist, or a conspiracy theorist or says you're in it", '", well you won\'t get any seats in this party."\n\nBut, Burgess said, there aren\'t many of the parties or candidates in the party.', '", you are wrong." He was, he, James Durden for not paying attention when he came up with his own name.', '… or I call you out of your privilege, you are not showing the proper respect for ourselves as a people of all backgrounds, and you are not showing the respect for our culture.', "It's not, then you should take a step back. Because here's a little fact."]
tensor([ 5.447,  4.272,  5.125, 10.781, 11.166, 10.932,  6.684],
       device='cuda:7')
tensor([ True, False, False,  True,  True,  True,  True], device='cuda:7')
tensor([[   54.399,     0.008],
        [   40.785,     0.216],
        [   43.044,     0.911],
        [  107.677,     0.015],
        [  11

In [140]:
print(best_weighted_loss)
print(best_allsat)
print(best_losses)

tensor([4.391, 7.559, 5.073, 8.513, 9.152, 9.713, 5.863], device='cuda:7')
tensor([False, False, False, False, False, False, False], device='cuda:7')
tensor([[33.326,  1.177],
        [32.255,  4.815],
        [44.340,  0.710],
        [77.809,  0.813],
        [87.941,  0.397],
        [90.010,  0.791],
        [55.694,  0.326]], device='cuda:7')


In [141]:
edit_yn

tensor([True, True, True, True, True, True, True], device='cuda:7')

In [142]:
new_best_weighted_loss = torch.empty((len(AR_prediction_all),)).fill_(float("inf")).to(config['device'])
new_best_weighted_loss[edit_yn] = new_best_weighted_loss_

new_best_logging_loss = torch.empty((len(AR_prediction_all), len(config['losses']))).fill_(float("inf")).to(config['device'])
new_best_logging_loss[edit_yn, :] = new_best_logging_loss_

new_best_allsat = torch.zeros((len(AR_prediction_all),)).bool().to(config['device'])
new_best_allsat[edit_yn] = new_best_allsat_

In [143]:
new_best_weighted_loss, new_best_logging_loss, new_best_allsat

(tensor([ 5.447,  4.272,  5.125, 10.781, 11.166, 10.932,  6.684],
        device='cuda:7'),
 tensor([[   54.399,     0.008],
         [   40.785,     0.216],
         [   43.044,     0.911],
         [  107.677,     0.015],
         [  111.426,     0.026],
         [  108.389,     0.104],
         [   66.649,     0.022]], device='cuda:7'),
 tensor([ True, False, False,  True,  True,  True,  True], device='cuda:7'))

In [144]:
edit_ixes = edit_yn.nonzero().squeeze(-1)

In [145]:
edit_ixes

tensor([0, 1, 2, 3, 4, 5, 6], device='cuda:7')

In [146]:
final_hypotheses = [final_hypotheses_[torch.where(edit_ixes==i)[0].item()] if edit_yn[i] else '' for i in range(len(AR_prediction_all))]


In [147]:
final_hypotheses_

['", I think you\'ve got to prove it to everybody in the media.',
 " sexist, I want to see me like I'm a man.",
 " racist, or a sexist, or a racist, or a conspiracy theorist or says you're in it",
 '", well you won\'t get any seats in this party."\n\nBut, Burgess said, there aren\'t many of the parties or candidates in the party.',
 '", you are wrong." He was, he, James Durden for not paying attention when he came up with his own name.',
 '… or I call you out of your privilege, you are not showing the proper respect for ourselves as a people of all backgrounds, and you are not showing the respect for our culture.',
 "It's not, then you should take a step back. Because here's a little fact."]

In [148]:
final_hypotheses

['", I think you\'ve got to prove it to everybody in the media.',
 " sexist, I want to see me like I'm a man.",
 " racist, or a sexist, or a racist, or a conspiracy theorist or says you're in it",
 '", well you won\'t get any seats in this party."\n\nBut, Burgess said, there aren\'t many of the parties or candidates in the party.',
 '", you are wrong." He was, he, James Durden for not paying attention when he came up with his own name.',
 '… or I call you out of your privilege, you are not showing the proper respect for ourselves as a people of all backgrounds, and you are not showing the respect for our culture.',
 "It's not, then you should take a step back. Because here's a little fact."]

In [149]:
update = torch.Tensor([]).bool().to(config['device'])
if config['selection_criteria'] == "weighted_sum":
    update = best_weighted_loss > new_best_weighted_loss ## edit_yn이 false 였던 곳은 무조건 false
elif config['selection_criteria'] == "allsat_primary":
    update = (~best_allsat & new_best_allsat) | \
            (~best_allsat & ~new_best_allsat & (best_weighted_loss > new_best_weighted_loss)) | \
            (best_allsat & new_best_allsat & (best_losses[:, 0] > new_best_logging_loss[:, 0])) 
            ## (~best_allsat & new_best_allsat) : edit_yn이 false였던 곳은 무조건 false
            ## (~best_allsat & ~new_best_allsat & (best_weighted_loss > new_best_weighted_loss)) : edit_yn이 false 였던 곳은 무조건 false
            ## (best_allsat & new_best_allsat & (best_losses[:, 0] > new_best_logging_loss[:, 0])) : edit_yn이 false였던 곳은 무조건 false



In [150]:
update

tensor([ True,  True, False,  True,  True,  True,  True], device='cuda:7')

In [151]:
update = (update & edit_yn) # edit 대상인 것들만 update하기 위해서 update 조건에 edit_yn을 sum.

In [152]:
update

tensor([ True,  True, False,  True,  True,  True,  True], device='cuda:7')

In [153]:
int_output

[{}, {}, {}, {}, {}, {}, {}]

In [154]:

## intermediate output for debugging
# for sample_ix in edit_yn.nonzero().squeeze(-1).tolist(): # edit 대상인 것들만 update.

for sample_ix in range(len(running_text)): # edit 대상인 것들만 update.
    print(sample_ix)
    print(edit_ixes[sample_ix])
    int_output[edit_ixes[sample_ix]].update({f"iter{_iter}_original_sentence": running_text[sample_ix],
                                            f"iter{_iter}_masked_sentence": masked_text[sample_ix],
                                            f"iter{_iter}_best_text": final_hypotheses[edit_ixes[sample_ix]],
                                            f"iter{_iter}_update": update[sample_ix].item()})    


0
tensor(0, device='cuda:7')
1
tensor(1, device='cuda:7')
2
tensor(2, device='cuda:7')
3
tensor(3, device='cuda:7')
4
tensor(4, device='cuda:7')
5
tensor(5, device='cuda:7')
6
tensor(6, device='cuda:7')


In [155]:
int_output

[{'iter0_original_sentence': " homophobe, you've got to prove it to everybody in the media.",
  'iter0_masked_sentence': "<mask><mask><mask><mask> you've got to prove it to everybody in the media.",
  'iter0_best_text': '", I think you\'ve got to prove it to everybody in the media.',
  'iter0_update': True},
 {'iter0_original_sentence': " sexist ’don't treat me like I'm a bitch.",
  'iter0_masked_sentence': " sexist<mask><mask><mask><mask><mask> me like I'm a<mask><mask>",
  'iter0_best_text': " sexist, I want to see me like I'm a man.",
  'iter0_update': True},
 {'iter0_original_sentence': " racist, or a pro-whistleblower or a conspiracy theorist or says you're in it",
  'iter0_masked_sentence': " racist, or a<mask><mask><mask><mask><mask><mask> or a<mask> theorist or says you're in it",
  'iter0_best_text': " racist, or a sexist, or a racist, or a conspiracy theorist or says you're in it",
  'iter0_update': False},
 {'iter0_original_sentence': ' racist, well you won\'t get any seats 

In [156]:
update.nonzero().squeeze(-1).tolist()

[0, 1, 3, 4, 5, 6]

In [157]:

# update running_text, best_text, best_allsat, best_losses, best_weighted_loss
for update_index in update.nonzero().squeeze(-1).tolist():
    print(update_index)
    best_text[update_index] = final_hypotheses[update_index]


0
1
3
4
5
6


In [158]:
best_text

['", I think you\'ve got to prove it to everybody in the media.',
 " sexist, I want to see me like I'm a man.",
 " racist, or a pro-whistleblower or a conspiracy theorist or says you're in it",
 '", well you won\'t get any seats in this party."\n\nBut, Burgess said, there aren\'t many of the parties or candidates in the party.',
 '", you are wrong." He was, he, James Durden for not paying attention when he came up with his own name.',
 '… or I call you out of your privilege, you are not showing the proper respect for ourselves as a people of all backgrounds, and you are not showing the respect for our culture.',
 "It's not, then you should take a step back. Because here's a little fact."]

In [159]:
best_allsat

tensor([False, False, False, False, False, False, False], device='cuda:7')

In [160]:
best_allsat[update] = new_best_allsat[update]
best_losses[update] = new_best_logging_loss[update]
best_weighted_loss[update] = new_best_weighted_loss[update]


In [161]:
best_allsat

tensor([ True, False, False,  True,  True,  True,  True], device='cuda:7')

In [162]:
best_losses

tensor([[   54.399,     0.008],
        [   40.785,     0.216],
        [   44.340,     0.710],
        [  107.677,     0.015],
        [  111.426,     0.026],
        [  108.389,     0.104],
        [   66.649,     0.022]], device='cuda:7')

In [163]:
best_weighted_loss

tensor([ 5.447,  4.272,  5.073, 10.781, 11.166, 10.932,  6.684],
       device='cuda:7')

In [164]:

es_patience_count[(best_allsat & edit_yn).nonzero().squeeze(-1)] += 1


In [165]:
es_patience_count

tensor([1, 0, 0, 1, 1, 1, 1], device='cuda:7')

In [166]:

if (config["early_stopping_patience"] != -1):
    edit_yn[es_patience_count > config['early_stopping_patience']] = False


In [167]:
edit_yn


tensor([False,  True,  True, False, False, False, False], device='cuda:7')

In [168]:

running_text = [x for i, x in enumerate(final_hypotheses) if edit_yn[i]]



In [169]:
final_hypotheses

['", I think you\'ve got to prove it to everybody in the media.',
 " sexist, I want to see me like I'm a man.",
 " racist, or a sexist, or a racist, or a conspiracy theorist or says you're in it",
 '", well you won\'t get any seats in this party."\n\nBut, Burgess said, there aren\'t many of the parties or candidates in the party.',
 '", you are wrong." He was, he, James Durden for not paying attention when he came up with his own name.',
 '… or I call you out of your privilege, you are not showing the proper respect for ourselves as a people of all backgrounds, and you are not showing the respect for our culture.',
 "It's not, then you should take a step back. Because here's a little fact."]

In [170]:
running_text

[" sexist, I want to see me like I'm a man.",
 " racist, or a sexist, or a racist, or a conspiracy theorist or says you're in it"]

In [49]:

# output = {
#             "prompt": {
#                 "text": source_text,
#             },
#             "generations": [
#                 {
#                     "text": best_text[i],
#                     "original_text": AR_prediction_all[i],
#                     "allsat": best_allsat[i].item(),
#                     "losses": best_losses[i,:].tolist(),
#                     "weighted_loss": best_weighted_loss[i].item(),
#                     "edited": edited_at_all_yn[i].tolist(),
#                 } for i in range(len(AR_prediction_all))
#             ],
#         }
    
# intermediate_output = {
#         "prompt": {
#             "text": source_text,
#         },
#         "generations": 
#             int_output
#         ,
#     }

# outputs.append(output)
# int_outputs.append(intermediate_output)

In [50]:
# outputs=outputs[:1]

In [51]:
# int_outputs=int_outputs[:1]

In [171]:
running_text

[" sexist, I want to see me like I'm a man.",
 " racist, or a sexist, or a racist, or a conspiracy theorist or says you're in it"]

In [190]:
 
# for _iter in range(config['n_iter']):
_iter = 1
## masked_text : N (num samples to edit)
masked_text = locator.locate_main(running_text, 
                        method = config['locate_method'], 
                        max_num_tokens = config['num_edit_token_per_step'], 
                        unit = config['locate_unit'], 
                        num_layer = 10,#-2, #penultimate
                        label_id = config['target_label_ids'][1])


In [191]:
masked_text

["<mask><mask> I want to see me<mask> I'm a<mask><mask>",
 "<mask><mask> or a<mask><mask> or a<mask><mask> or a conspiracy theorist or says you're in it"]

In [192]:

## replace tokens at the indices with mask tokens

inputs = mlm_tokenizer(
    masked_text, return_tensors="pt", padding=True, truncation=True
)
inputs = inputs.to(config['device']) 
masked_sequence=inputs['input_ids']


In [193]:
inputs

{'input_ids': tensor([[    0, 50264, 50264,    38,   236,     7,   192,   162, 50264,    38,
           437,    10, 50264, 50264,     2,     1,     1,     1,     1,     1,
             1,     1],
        [    0, 50264, 50264,    50,    10, 50264, 50264,    50,    10, 50264,
         50264,    50,    10,  6556, 40646,    50,   161,    47,   214,    11,
            24,     2]], device='cuda:7'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:7')}

In [194]:
mlm=mlm.to(config['device'])

In [195]:

## make predictions for the masked indices
with torch.no_grad():
    logits = mlm(**inputs).logits


In [196]:
logits.shape

torch.Size([2, 22, 50265])

In [197]:

special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
logits[:, :, special_token_ids] = -float("inf")


In [198]:
logits

tensor([[[  -inf,   -inf,   -inf,  ..., -1.631,  1.295,   -inf],
         [  -inf,   -inf,   -inf,  ..., -5.229, -3.584,   -inf],
         [  -inf,   -inf,   -inf,  ..., -6.065, -5.619,   -inf],
         ...,
         [  -inf,   -inf,   -inf,  ..., -6.791, -5.273,   -inf],
         [  -inf,   -inf,   -inf,  ..., -6.791, -5.273,   -inf],
         [  -inf,   -inf,   -inf,  ..., -6.791, -5.273,   -inf]],

        [[  -inf,   -inf,   -inf,  ..., -0.751,  1.547,   -inf],
         [  -inf,   -inf,   -inf,  ..., -7.387, -5.483,   -inf],
         [  -inf,   -inf,   -inf,  ..., -8.430, -5.503,   -inf],
         ...,
         [  -inf,   -inf,   -inf,  ...,  0.904,  0.146,   -inf],
         [  -inf,   -inf,   -inf,  ..., -4.306, -2.841,   -inf],
         [  -inf,   -inf,   -inf,  ..., -4.439, -2.868,   -inf]]],
       device='cuda:7')

In [199]:

indices_in_mlm_tokens = (
    inputs.input_ids == mlm_tokenizer.mask_token_id
).nonzero(as_tuple=True)


In [200]:
indices_in_mlm_tokens

(tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], device='cuda:7'),
 tensor([ 1,  2,  8, 12, 13,  1,  2,  5,  6,  9, 10], device='cuda:7'))

In [201]:

## get top k tokens for each index
predicted_token_ids = torch.topk(
    logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
    k=config['k_per_location'],
    dim=-1,
)


In [202]:
predicted_token_ids

torch.return_types.topk(
values=tensor([[15.585, 15.555, 14.680, 14.383, 14.204, 13.978, 13.797, 13.788, 13.612,
         13.604],
        [12.254, 12.003, 11.934, 11.026, 10.672, 10.525, 10.516, 10.505, 10.485,
         10.480],
        [15.644, 15.021, 14.282, 14.254, 14.228, 13.991, 13.822, 13.389, 13.369,
         13.181],
        [12.463, 11.636, 11.441, 11.148, 11.039, 10.726, 10.711, 10.701, 10.618,
         10.541],
        [13.247, 12.326, 11.850, 10.051,  9.906,  9.486,  9.482,  9.342,  9.223,
          8.698],
        [16.050, 15.052, 14.724, 14.512, 14.164, 14.158, 13.915, 13.170, 13.001,
         12.885],
        [11.483, 10.134,  9.806,  9.596,  9.556,  9.454,  9.377,  9.352,  9.261,
          9.240],
        [11.707, 11.003, 10.931, 10.454, 10.337, 10.335, 10.254, 10.195,  9.802,
          9.676],
        [12.549, 12.110, 11.796, 10.401, 10.378, 10.309, 10.118, 10.091, 10.073,
         10.070],
        [12.079, 11.187, 10.946, 10.704, 10.443, 10.422, 10.318, 10.238,  9.9

In [67]:
predicted_token_ids_old.values == predicted_token_ids.values

NameError: name 'predicted_token_ids_old' is not defined

In [123]:
predicted_token_ids_old.indices == predicted_token_ids.indices

tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]],
       device='cuda:7')

In [203]:
masked_sequence = inputs.input_ids
if config["method"] == "mlm-beamsearch-v0":
    hypotheses = get_beam_hypotheses_v0(source_text, 
            masked_sequence, 
            indices_in_mlm_tokens,
            predicted_token_ids.indices,
            mlm_tokenizer, 
            lossfns,
            config)
elif config["method"] == "mlm-beamsearch-v1":
    hypotheses = get_beam_hypotheses_v1(source_text, 
            masked_sequence, 
            indices_in_mlm_tokens,
            predicted_token_ids.indices,
            mlm_tokenizer, 
            lossfns,
            config)
elif config["method"] == "mlm-reranking":
    hypotheses = get_combi_hypotheses(masked_sequence, 
                indices_in_mlm_tokens,
                predicted_token_ids.indices,
                mlm_tokenizer,
                config)


In [204]:
hypotheses

[['". I want to see me. I\'m a man.',
  '". I want to see me. I\'m a woman.',
  '"? I want to see me. I\'m a man.',
  '"? I want to see me. I\'m a woman.',
  '". I want to see me. I\'m a kid.'],
 ["a, or a conspiracy theorist or a conspiracy theorist or a conspiracy theorist or says you're in it",
  "a, or a conspiracy theorist or a realist or a conspiracy theorist or says you're in it",
  "a, or a conspiracy theorist or a conspiracy, or a conspiracy theorist or says you're in it",
  "a, or a bad person or a conspiracy theorist or a conspiracy theorist or says you're in it",
  "a, or a conspiracy theorist or a conspiracy person or a conspiracy theorist or says you're in it"]]

In [205]:
    
final_hypotheses_, new_best_weighted_loss_, new_best_allsat_, new_best_logging_loss_ = final_reranking(source_text,
                                                                                                    hypotheses,
                                                                                                    lossfns,
                                                                                                    config,
                                                                                                    batch_size=64)



In [206]:
print(final_hypotheses_)
print(new_best_weighted_loss_.tolist())
print(new_best_allsat_.tolist())
print(new_best_logging_loss_.tolist())

['". I want to see me. I\'m a man.', "a, or a conspiracy theorist or a conspiracy theorist or a conspiracy theorist or says you're in it"]
[4.857304096221924, 6.436732769012451]
[True, True]
[[48.0064697265625, 0.06295201927423477], [63.938995361328125, 0.0475926548242569]]


In [72]:
hypotheses_old[best_ix_old]

NameError: name 'hypotheses_old' is not defined

In [138]:
best_ix_old, candidate_total_losses_old[best_ix_old], candidate_losses_for_loggings_old[best_ix_old], candidate_allsats_old[best_ix_old]

(1, 6.6754633940756305, [66.51388549804688, 0.026749826967716217], True)

In [207]:
print(best_weighted_loss)
print(best_allsat)
print(best_losses)

tensor([ 5.447,  4.272,  5.073, 10.781, 11.166, 10.932,  6.684],
       device='cuda:7')
tensor([ True, False, False,  True,  True,  True,  True], device='cuda:7')
tensor([[   54.399,     0.008],
        [   40.785,     0.216],
        [   44.340,     0.710],
        [  107.677,     0.015],
        [  111.426,     0.026],
        [  108.389,     0.104],
        [   66.649,     0.022]], device='cuda:7')


In [208]:
edit_yn

tensor([False,  True,  True, False, False, False, False], device='cuda:7')

In [209]:
new_best_weighted_loss = torch.empty((len(AR_prediction_all),)).fill_(float("inf")).to(config['device'])
new_best_weighted_loss[edit_yn] = new_best_weighted_loss_

new_best_logging_loss = torch.empty((len(AR_prediction_all), len(config['losses']))).fill_(float("inf")).to(config['device'])
new_best_logging_loss[edit_yn, :] = new_best_logging_loss_

new_best_allsat = torch.zeros((len(AR_prediction_all),)).bool().to(config['device'])
new_best_allsat[edit_yn] = new_best_allsat_

In [210]:
new_best_weighted_loss, new_best_logging_loss, new_best_allsat

(tensor([  inf, 4.857, 6.437,   inf,   inf,   inf,   inf], device='cuda:7'),
 tensor([[      inf,       inf],
         [   48.006,     0.063],
         [   63.939,     0.048],
         [      inf,       inf],
         [      inf,       inf],
         [      inf,       inf],
         [      inf,       inf]], device='cuda:7'),
 tensor([False,  True,  True, False, False, False, False], device='cuda:7'))

In [211]:
new_best_weighted_loss, new_best_logging_loss, new_best_allsat

(tensor([  inf, 4.857, 6.437,   inf,   inf,   inf,   inf], device='cuda:7'),
 tensor([[      inf,       inf],
         [   48.006,     0.063],
         [   63.939,     0.048],
         [      inf,       inf],
         [      inf,       inf],
         [      inf,       inf],
         [      inf,       inf]], device='cuda:7'),
 tensor([False,  True,  True, False, False, False, False], device='cuda:7'))

In [212]:
edit_ixes = edit_yn.nonzero().squeeze(-1)

In [213]:
edit_ixes

tensor([1, 2], device='cuda:7')

In [214]:
final_hypotheses_

['". I want to see me. I\'m a man.',
 "a, or a conspiracy theorist or a conspiracy theorist or a conspiracy theorist or says you're in it"]

In [215]:
final_hypotheses = [final_hypotheses_[torch.where(edit_ixes==i)[0].item()] if edit_yn[i] else '' for i in range(len(AR_prediction_all))]


In [216]:
final_hypotheses_

['". I want to see me. I\'m a man.',
 "a, or a conspiracy theorist or a conspiracy theorist or a conspiracy theorist or says you're in it"]

In [217]:
final_hypotheses

['',
 '". I want to see me. I\'m a man.',
 "a, or a conspiracy theorist or a conspiracy theorist or a conspiracy theorist or says you're in it",
 '',
 '',
 '',
 '']

In [218]:
update = torch.Tensor([]).bool().to(config['device'])
if config['selection_criteria'] == "weighted_sum":
    update = best_weighted_loss > new_best_weighted_loss ## edit_yn이 false 였던 곳은 무조건 false
elif config['selection_criteria'] == "allsat_primary":
    update = (~best_allsat & new_best_allsat) | \
            (~best_allsat & ~new_best_allsat & (best_weighted_loss > new_best_weighted_loss)) | \
            (best_allsat & new_best_allsat & (best_losses[:, 0] > new_best_logging_loss[:, 0])) 
            ## (~best_allsat & new_best_allsat) : edit_yn이 false였던 곳은 무조건 false
            ## (~best_allsat & ~new_best_allsat & (best_weighted_loss > new_best_weighted_loss)) : edit_yn이 false 였던 곳은 무조건 false
            ## (best_allsat & new_best_allsat & (best_losses[:, 0] > new_best_logging_loss[:, 0])) : edit_yn이 false였던 곳은 무조건 false



In [219]:
update

tensor([False,  True,  True, False, False, False, False], device='cuda:7')

In [220]:
update = (update & edit_yn) # edit 대상인 것들만 update하기 위해서 update 조건에 edit_yn을 sum.

In [221]:
update

tensor([False,  True,  True, False, False, False, False], device='cuda:7')

In [222]:
int_output

[{'iter0_original_sentence': " homophobe, you've got to prove it to everybody in the media.",
  'iter0_masked_sentence': "<mask><mask><mask><mask> you've got to prove it to everybody in the media.",
  'iter0_best_text': '", I think you\'ve got to prove it to everybody in the media.',
  'iter0_update': True},
 {'iter0_original_sentence': " sexist ’don't treat me like I'm a bitch.",
  'iter0_masked_sentence': " sexist<mask><mask><mask><mask><mask> me like I'm a<mask><mask>",
  'iter0_best_text': " sexist, I want to see me like I'm a man.",
  'iter0_update': True},
 {'iter0_original_sentence': " racist, or a pro-whistleblower or a conspiracy theorist or says you're in it",
  'iter0_masked_sentence': " racist, or a<mask><mask><mask><mask><mask><mask> or a<mask> theorist or says you're in it",
  'iter0_best_text': " racist, or a sexist, or a racist, or a conspiracy theorist or says you're in it",
  'iter0_update': False},
 {'iter0_original_sentence': ' racist, well you won\'t get any seats 

In [223]:

## intermediate output for debugging
# for sample_ix in edit_yn.nonzero().squeeze(-1).tolist(): # edit 대상인 것들만 update.

for sample_ix in range(len(running_text)): # edit 대상인 것들만 update.
    print(sample_ix)
    print(edit_ixes[sample_ix])
    int_output[edit_ixes[sample_ix]].update({f"iter{_iter}_original_sentence": running_text[sample_ix],
                                            f"iter{_iter}_masked_sentence": masked_text[sample_ix],
                                            f"iter{_iter}_best_text": final_hypotheses[edit_ixes[sample_ix]],
                                            f"iter{_iter}_update": update[edit_ixes[sample_ix]].item()})    


0
tensor(1, device='cuda:7')
1
tensor(2, device='cuda:7')


In [224]:
int_output

[{'iter0_original_sentence': " homophobe, you've got to prove it to everybody in the media.",
  'iter0_masked_sentence': "<mask><mask><mask><mask> you've got to prove it to everybody in the media.",
  'iter0_best_text': '", I think you\'ve got to prove it to everybody in the media.',
  'iter0_update': True},
 {'iter0_original_sentence': " sexist ’don't treat me like I'm a bitch.",
  'iter0_masked_sentence': " sexist<mask><mask><mask><mask><mask> me like I'm a<mask><mask>",
  'iter0_best_text': " sexist, I want to see me like I'm a man.",
  'iter0_update': True,
  'iter1_original_sentence': " sexist, I want to see me like I'm a man.",
  'iter1_masked_sentence': "<mask><mask> I want to see me<mask> I'm a<mask><mask>",
  'iter1_best_text': '". I want to see me. I\'m a man.',
  'iter1_update': True},
 {'iter0_original_sentence': " racist, or a pro-whistleblower or a conspiracy theorist or says you're in it",
  'iter0_masked_sentence': " racist, or a<mask><mask><mask><mask><mask><mask> or a

In [225]:
update.nonzero().squeeze(-1).tolist()

[1, 2]

In [226]:

# update running_text, best_text, best_allsat, best_losses, best_weighted_loss
for update_index in update.nonzero().squeeze(-1).tolist():
    print(update_index)
    best_text[update_index] = final_hypotheses[update_index]


1
2


In [227]:
best_text

['", I think you\'ve got to prove it to everybody in the media.',
 '". I want to see me. I\'m a man.',
 "a, or a conspiracy theorist or a conspiracy theorist or a conspiracy theorist or says you're in it",
 '", well you won\'t get any seats in this party."\n\nBut, Burgess said, there aren\'t many of the parties or candidates in the party.',
 '", you are wrong." He was, he, James Durden for not paying attention when he came up with his own name.',
 '… or I call you out of your privilege, you are not showing the proper respect for ourselves as a people of all backgrounds, and you are not showing the respect for our culture.',
 "It's not, then you should take a step back. Because here's a little fact."]

In [228]:
best_allsat

tensor([ True, False, False,  True,  True,  True,  True], device='cuda:7')

In [229]:
best_allsat[update] = new_best_allsat[update]
best_losses[update] = new_best_logging_loss[update]
best_weighted_loss[update] = new_best_weighted_loss[update]


In [230]:
best_allsat

tensor([True, True, True, True, True, True, True], device='cuda:7')

In [231]:
new_best_logging_loss

tensor([[      inf,       inf],
        [   48.006,     0.063],
        [   63.939,     0.048],
        [      inf,       inf],
        [      inf,       inf],
        [      inf,       inf],
        [      inf,       inf]], device='cuda:7')

In [232]:
best_losses

tensor([[   54.399,     0.008],
        [   48.006,     0.063],
        [   63.939,     0.048],
        [  107.677,     0.015],
        [  111.426,     0.026],
        [  108.389,     0.104],
        [   66.649,     0.022]], device='cuda:7')

In [233]:
best_weighted_loss

tensor([ 5.447,  4.857,  6.437, 10.781, 11.166, 10.932,  6.684],
       device='cuda:7')

In [234]:
best_weighted_loss

tensor([ 5.447,  4.857,  6.437, 10.781, 11.166, 10.932,  6.684],
       device='cuda:7')

In [235]:

es_patience_count[(best_allsat & edit_yn).nonzero().squeeze(-1)] += 1


In [236]:
es_patience_count

tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:7')

In [237]:
es_patience_count

tensor([1, 1, 1, 1, 1, 1, 1], device='cuda:7')

In [238]:

if (config["early_stopping_patience"] != -1):
    edit_yn[es_patience_count > config['early_stopping_patience']] = False


In [239]:
edit_yn


tensor([False, False, False, False, False, False, False], device='cuda:7')

In [240]:
edit_yn


tensor([False, False, False, False, False, False, False], device='cuda:7')

In [241]:

running_text = [x for i, x in enumerate(final_hypotheses) if edit_yn[i]]



In [242]:
running_text

[]

In [243]:

output = {
            "prompt": {
                "text": source_text,
            },
            "generations": [
                {
                    "text": best_text[i],
                    "original_text": AR_prediction_all[i],
                    "allsat": best_allsat[i].item(),
                    "losses": best_losses[i,:].tolist(),
                    "weighted_loss": best_weighted_loss[i].item(),
                    "edited": edited_at_all_yn[i].tolist(),
                } for i in range(len(AR_prediction_all))
            ],
        }
    
intermediate_output = {
        "prompt": {
            "text": source_text,
        },
        "generations": 
            int_output
        ,
    }

outputs.append(output)
int_outputs.append(intermediate_output)

In [244]:
outputs

[{'prompt': {'text': 'And if you’re going to be shying away because someone calls you a bigot or a'},
  'generations': [{'text': '", I think you\'ve got to prove it to everybody in the media.',
    'original_text': " homophobe, you've got to prove it to everybody in the media.",
    'allsat': True,
    'losses': [54.398582458496094, 0.007761795073747635],
    'weighted_loss': 5.446844100952148,
    'edited': True},
   {'text': '". I want to see me. I\'m a man.',
    'original_text': " sexist ’don't treat me like I'm a bitch.",
    'allsat': True,
    'losses': [48.0064697265625, 0.06295201927423477],
    'weighted_loss': 4.857304096221924,
    'edited': True},
   {'text': "a, or a conspiracy theorist or a conspiracy theorist or a conspiracy theorist or says you're in it",
    'original_text': " racist, or a pro-whistleblower or a conspiracy theorist or says you're in it",
    'allsat': True,
    'losses': [63.938995361328125, 0.0475926548242569],
    'weighted_loss': 6.436732769012451,

old implementation

In [246]:
outputs_old=[]
int_outputs_old=[]

In [247]:
#!/usr/bin/env python
# coding: utf-8


import argparse
import json
import logging
import os
import time
os.chdir('/data/hyeryung/mucoco')
import numpy as np
import torch
import transformers
from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

import new_module.losses_old as lossbuilder
import wandb
from new_module.decode_utils import (
    beam_rerank_v0,
    beam_rerank_v1,
    beam_rerank_v2,
    combi_rerank,
)
from new_module.evaluation.evaluate_wandb import evaluate_main
from new_module.locate.locate_utils import locate_main
from new_module.utils.robertacustom import RobertaCustomForSequenceClassification

logging.basicConfig(level=logging.DEBUG, format="%(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("LOGGING_LEVEL", logging.DEBUG))

import joblib
config = joblib.load('config.pkl')
config['device'] = 'cuda:7'

class dummyArgs:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

build_loss_args = dummyArgs(**config["build_loss_dict"])

## load data
if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
    source_dataset = [
        json.loads(l)[config["jsonl_primary_key"]][config["jsonl_secondary_key"]]
        for l in open(config["source_data"])
    ]
    generation_dataset = [
        json.loads(l)["generations"] for l in open(config["source_data"])
    ]
elif (config["task"] == "formality") or (config["task"] == "sentiment-lewis-compr"):
    with open(config["source_data"], "r") as f:
        generation_dataset = [line.rstrip('\n') for line in f.readlines()]
    source_dataset = ["" for l in generation_dataset]


## load tokenizer, models, define losses
name2tokenizer = {}
name2model = {}
name2config = {}
loss2tokenizer = {}
embed_luts = []

for i, model_path in enumerate(config["model_paths"]):
    if (
        model_path not in name2model
    ):  # making sure we are not loading the model twice in case some constraints use the same model.
        try:
            name2tokenizer[model_path] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=True,
            )
        except:
            name2tokenizer[model_path] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=False,
            )

        name2config[model_path] = AutoConfig.from_pretrained(
            model_path, cache_dir=config["cache_dir"]
        )

        if config["model_types"][i] == "RobertaCustomForSequenceClassification":
            name2model[model_path] = lossbuilder.ModelWrapper(
                RobertaCustomForSequenceClassification.from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        else:
            name2model[model_path] = lossbuilder.ModelWrapper(
                getattr(transformers, config["model_types"][i]).from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        name2model[model_path].eval()
        name2model[model_path].to(config['device'])

    input_embeds = name2model[model_path].get_input_embeddings()
    if isinstance(input_embeds, torch.nn.Sequential):
        input_embeds = input_embeds[0]
    embed_luts.append(input_embeds)

    if config["target_type"] == "embeds":
        embed_luts[-1].requires_grad = False

mlm_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
mlm = None if config["method"] == "mlm-beamsearch-v2" else AutoModelForMaskedLM.from_pretrained("roberta-base")  

lossfns = []
for i, loss in enumerate(config["losses"]):
    lossfns.append(
        lossbuilder.build_loss(
            loss,
            name2model[config["model_paths"][i]],
            name2tokenizer[config["model_paths"][i]],
            build_loss_args,
        )
    )
    lossfns[i].tokenizer.add_special_tokens({"mask_token": mlm_tokenizer.mask_token})
    loss2tokenizer[loss] = lossfns[i].tokenizer
# lossfns[0].tokenizer = loss2tokenizer[config["losses"][0]]
# lossfns[1].tokenizer = loss2tokenizer[config["losses"][1]]


https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/generation_config.json HTTP/1.1" 200 0


50265


https://huggingface.co:443 "HEAD /roberta-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0


In [249]:

def beam_rerank_v0(source_text, ## text (too arbitrary?)
                    masked_sequence, ## in mlm tokenizer's tokens
                    indices_in_mlm_tokens,
                    predicted_token_ids,
                    mlm_tokenizer, 
                    lossfns,
                    config, 
                    beam_size = 5):
    
    hypotheses = [torch.LongTensor([]).to(config['device'])]
    L = masked_sequence.size(-1)

    for i in range(L):
        if masked_sequence[0, i] != mlm_tokenizer.mask_token_id:
            hypotheses = list(torch.cat([torch.stack(hypotheses,dim=0), 
                                        masked_sequence[:, i].unsqueeze(0).repeat((len(hypotheses),1)).to(config['device'])],dim=-1))
        else:
            num_hypotheses = len(hypotheses)
            hypotheses = torch.stack(hypotheses,dim=0).unsqueeze(0)
            hypotheses = hypotheses.repeat(config['k_per_location'], 1, 1)
            candidates = predicted_token_ids.indices[torch.where(indices_in_mlm_tokens == i)[0], :].to(config['device']).T.unsqueeze(1)
            candidates = candidates.repeat(1, num_hypotheses, 1)
            hypotheses_exp = torch.cat([hypotheses, candidates], dim=-1)
            hypotheses_exp = hypotheses_exp.view(-1, hypotheses_exp.shape[-1])
            hypotheses_exp = list(hypotheses_exp)

            losses = []
            loss_weights = [1 - config['closs_weight'], config['closs_weight']]
            for hyp in hypotheses_exp:
                curr_loss = 0.0
                for lossid, lossname in enumerate(config["losses"]):
                    with torch.no_grad():
                        lossvalue = lossfns[lossid].compute_gold_loss(
                            source_text, mlm_tokenizer.decode(hyp, skip_special_tokens=True),
                            label_id=config['target_label_ids'][lossid],
                        )
                    curr_loss += loss_weights[lossid] * lossvalue.item()
                losses.append(curr_loss)

            hypotheses = sorted(zip(hypotheses_exp, losses), key=lambda x: x[1])[:beam_size]
            hypotheses = [x[0] for x in hypotheses]
            
    return [mlm_tokenizer.decode(x, skip_special_tokens=True) for x in hypotheses]


In [251]:
resume_idx = 34
label_ids = config["target_label_ids"]  # target label's ids for each loss

## beginning of main logic
# text_id = 0

interrupted = False
for text_id in range(len(source_dataset))[resume_idx:]:
    source_text = source_dataset[text_id]
    if source_text == "":
        source_text = lossfns[0].tokenizer.bos_token

    if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
        AR_prediction_all = [x["text"] for x in generation_dataset[text_id]]
        # predicted_batches = [x["tokens"] for x in generation_dataset[text_id]]
        # predicted_batches = [
        #     torch.tensor([x], dtype=torch.long, device=config["device"])
        #     for x in predicted_batches
        # ]
        
    elif (config["task"] == "formality") or (
        config["task"] == "sentiment-lewis-compr"
    ):
        AR_prediction_all = [generation_dataset[text_id]]

    sample_idx = 0
    curr_num_samples = len(AR_prediction_all)
    # for sample_idx in range(config["num_samples"])[:]:
    for sample_idx in range(curr_num_samples): ## updated (3/15)
        
        ## commented out (3/15) : dev set doesn't have the space problem.
        # if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
        #     predicted_batch = predicted_batches[sample_idx].cuda()
        #     AR_prediction = lossfns[0].tokenizer.batch_decode(predicted_batch)[0]
        # else:
        AR_prediction = AR_prediction_all[sample_idx]

        logger.debug(
            f"text_id {text_id} sample_id {sample_idx} \n[prompt] {source_text} [text] {AR_prediction}"
        )

        # --------------------------------------------------------------------------------------------- #
        ## check whether initial text satisfies constraint
        allsat = True
        gold_losses = []
        curr_loss = 0.0
        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        for lossid, lossname in enumerate(config["losses"]):
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, AR_prediction,
                    label_id=label_ids[lossid],
                )
                
            gold_losses.append(lossvalue.squeeze().item())
            curr_loss += loss_weights[lossid] * lossvalue.squeeze().item()
            if (lossid >= 1) and (gold_losses[lossid] > -np.log(
                config["min_epsilons"][lossid - 1]
            )):
                allsat = False

        if (allsat) and (not config["dont_skip_allsat"]):
            logger.info(
                f"skipping this sample since it already satisfies constraint. {gold_losses}"
            )
            if sample_idx == 0:
                output = {
                    "prompt": {
                        "text": source_text,
                    },
                    "generations": [
                        {
                            "text": AR_prediction,
                            "indices": [[]],
                            "allsat": allsat,
                            "losses": gold_losses,
                            "weighted_loss": curr_loss,
                            "edited": False,
                        }
                    ],
                }
                intermediate_output = {
                    "prompt": {
                        "text": source_text,
                    },
                    "generations": [
                        {}
                    ],
                }
            else:
                output["generations"].append(
                    {
                        "text": AR_prediction,
                        "indices": [[]],
                        "allsat": allsat,
                        "losses": gold_losses,
                        "weighted_loss": curr_loss,
                        "edited": False,
                    }       
                )
                intermediate_output['generations'].append({})

            # if sample_idx + 1 == config["num_samples"]:
            if sample_idx + 1 == curr_num_samples:
                outputs_old.append(output)
                int_outputs_old.append(intermediate_output)
                break

        else:
            es_patience_count = 0
            best_ix = None
            best_allsat = allsat
            best_losses = gold_losses
            best_weighted_loss = curr_loss                
            running_text = best_text = AR_prediction
            int_output = {}

            _iter = 0
            for _iter in range(config['n_iter']):
                ## locate tokens to edit
                masked_text  = locate_main(running_text, 
                                        config["locate_method"], 
                                        name2model[config["model_paths"][1]], 
                                        name2tokenizer[config["tokenizer_paths"][1]], 
                                        max_num_tokens = config['num_edit_token_per_step'], 
                                        unit=config["locate_unit"], 
                                        device=config['device'], 
                                        label_id=config["target_label_ids"][1],
                                        num_layer=10)
                logger.debug(f"iter {_iter}, sample_idx: {sample_idx}")
                logger.debug(f"locate result: {masked_text}")
                
                if config["method"] == "mlm-beamsearch-v2":
                    pass
                else:
                    ## replace tokens at the indices with mask tokens
                    inputs = mlm_tokenizer(
                        masked_text, return_tensors="pt"
                    )
                    # inputs = mlm_tokenizer(
                    #     source_text + ' ' + masked_text[0], return_tensors="pt", add_special_tokens=False
                    # )
                    
                    ## make predictions for the masked indices
                    with torch.no_grad():
                        logits = mlm(**inputs).logits
                    indices_in_mlm_tokens = (
                        inputs.input_ids == mlm_tokenizer.mask_token_id
                    )[0].nonzero(as_tuple=True)[0]
                    # print(f"indices_in_mlm_tokens: {indices_in_mlm_tokens}")
                    ## get top k tokens for each index
                    
                    ## make logits for special tokens -inf.
                    special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
                    logits[:, :, special_token_ids] = -np.inf
                    
                    predicted_token_ids = torch.topk(
                        logits[0, indices_in_mlm_tokens],
                        k=config['k_per_location'],
                        dim=-1,
                    )
                    # print(f"predicted_token_ids: {predicted_token_ids}")
                    # print(f"mlm_tokenizer.batch_decode(predicted_token_ids.indices): {mlm_tokenizer.batch_decode(predicted_token_ids.indices)}")
                    
                if config["method"] == "mlm-beamsearch-v0":
                    # print(config["method"])
                    hypotheses = beam_rerank_v0(source_text,
                                                inputs.input_ids,
                                                indices_in_mlm_tokens,
                                                predicted_token_ids,
                                                mlm_tokenizer, 
                                                lossfns,
                                                config, 
                                                beam_size = config['beam_size'])
                elif config["method"] == "mlm-beamsearch-v1":
                    hypotheses = beam_rerank_v1(source_text,
                                                inputs.input_ids,
                                                indices_in_mlm_tokens,
                                                predicted_token_ids,
                                                mlm_tokenizer, 
                                                lossfns,
                                                config, 
                                                beam_size = config['beam_size'])
                elif config["method"] == "mlm-beamsearch-v2":
                    source_batch = lossfns[0].tokenizer(source_text, add_special_tokens=False, return_tensors="pt").input_ids.to(config['device'])
                    masked_sequence = lossfns[0].tokenizer(masked_text, add_special_tokens=False, return_tensors="pt").input_ids.to(config['device'])
                    hypotheses = beam_rerank_v2(
                        source_batch,
                        masked_sequence,
                        lossfns[0].model,
                        lossfns[0].tokenizer,
                        config,
                        beam_size=config['beam_size'],
                    )
                elif config["method"] == "mlm-reranking":
                    hypotheses = combi_rerank(inputs.input_ids, ## in mlm tokenizer's tokens
                        indices_in_mlm_tokens,
                        predicted_token_ids,
                        mlm_tokenizer,
                        config)

                candidate_total_losses = []
                candidate_primary_losses = []
                candidate_losses_for_loggings = []
                candidate_allsats = []
                
                for hyp in hypotheses:
                    curr_loss = 0.0
                    logging_loss = []
                    allsat = True
                    for lossid, lossname in enumerate(config["losses"]):
                        with torch.no_grad():
                            lossvalue = lossfns[lossid].compute_gold_loss(
                                source_text, hyp,
                                label_id=config['target_label_ids'][lossid],
                            )
                        curr_loss += loss_weights[lossid] * lossvalue.item()
                        logging_loss.append(lossvalue.item())
                        if lossid==0:
                            candidate_primary_losses.append(lossvalue.item())
                        elif (lossid >= 1) and (
                            lossvalue.item()
                            > -np.log(config["min_epsilons"][lossid - 1])
                        ):
                            allsat = False
                    candidate_total_losses.append(curr_loss)
                    candidate_losses_for_loggings.append(logging_loss)
                    candidate_allsats.append(allsat)


                if config['selection_criteria'] == "weighted_sum":
                    best_ix = np.argmin(np.array(candidate_total_losses))
                elif config['selection_criteria'] == "allsat_primary":
                    allsat_ix = np.where(np.array(candidate_allsats) == True)[0]
                    if len(allsat_ix) > 0:
                        best_ix = np.argmin(
                            np.array(candidate_primary_losses)[allsat_ix]
                        )  # select min primary loss among allsats
                        best_ix = allsat_ix[best_ix]
                    else:  # if no candidate satisfying constraints, default to weighted_sum
                        best_ix = np.argmin(np.array(candidate_total_losses))
                    
                update = False
                if config['selection_criteria'] == "weighted_sum":
                    if best_weighted_loss > candidate_total_losses[best_ix]:
                        update = True
                elif config['selection_criteria'] == "allsat_primary":
                    if (
                        best_allsat is False
                        and candidate_allsats[best_ix] is True
                    ):
                        update = True
                    elif (
                        best_allsat is False
                        and candidate_allsats[best_ix] is False
                    ):
                        if best_weighted_loss > candidate_total_losses[best_ix]:
                            update = True
                    elif (
                        best_allsat is True
                        and candidate_allsats[best_ix] is True
                    ):
                        if (
                            best_losses[0]
                            > candidate_losses_for_loggings[best_ix][0]
                        ):
                            update = True


                ## intermediate output for debugging
                int_output.update({f"iter{_iter}_original_sentence": running_text,
                                f"iter{_iter}_masked_sentence": masked_text,
                                f"iter{_iter}_best_text": hypotheses[best_ix],
                                f"iter{_iter}_update": update})    
                
                running_text = hypotheses[best_ix]
                if update:
                    ## save the best prediction in a format compatible with mucola outputs
                    best_text = hypotheses[best_ix]
                    best_allsat = candidate_allsats[best_ix]
                    best_losses = candidate_losses_for_loggings[best_ix]
                    best_weighted_loss = candidate_total_losses[best_ix]

                    logger.debug(f"iter {_iter}. Update best prediction")
                    logger.debug(f"best_text: {best_text}")
                
                if best_allsat:
                    es_patience_count += 1
                    if (config["early_stopping_patience"] != -1) and (es_patience_count > config["early_stopping_patience"]):
                        logger.info(f"early stopping at iter {_iter}")
                        break

            if sample_idx == 0:
                output = {
                    "prompt": {
                        "text": source_text,
                    },
                    "generations": [
                        {
                            "text": best_text,
                            "original_text": AR_prediction,
                            "allsat": best_allsat,
                            "losses": best_losses,
                            "weighted_loss": best_weighted_loss,
                            "edited": True,
                        }
                    ],
                }
                
                intermediate_output = {
                    "prompt": {
                        "text": source_text,
                    },
                    "generations": [
                        int_output
                    ],
                }
            else:
                output["generations"].append(
                    {
                            "text": best_text,
                            "original_text": AR_prediction,
                            "allsat": best_allsat,
                            "losses": best_losses,
                            "weighted_loss": best_weighted_loss,
                            "edited": True,
                    }
                )
                
                intermediate_output["generations"].append(int_output)

            # if sample_idx + 1 == config["num_samples"]:
            if sample_idx + 1 == curr_num_samples:
                outputs_old.append(output)
                int_outputs_old.append(intermediate_output)
                break
    break

text_id 34 sample_id 0 
[prompt] And if you’re going to be shying away because someone calls you a bigot or a [text]  homophobe, you've got to prove it to everybody in the media.


iter 0, sample_idx: 0
locate result: ["<mask><mask><mask><mask> you've got to prove it to everybody in the media."]
iter 0. Update best prediction
best_text: ", I think you've got to prove it to everybody in the media.
early stopping at iter 0
text_id 34 sample_id 1 
[prompt] And if you’re going to be shying away because someone calls you a bigot or a [text]  sexist ’don't treat me like I'm a bitch.
iter 0, sample_idx: 1
locate result: [" sexist<mask><mask><mask><mask><mask> me like I'm a<mask><mask>"]
iter 0. Update best prediction
best_text:  sexist, I want to see me like I'm a man.
iter 1, sample_idx: 1
locate result: ["<mask><mask> I want to see me<mask> I'm a<mask><mask>"]
iter 1. Update best prediction
best_text: ". I want to see me. I'm a man.
early stopping at iter 1
text_id 34 sample_id 2 
[prompt] And if you’re going to be shying away because someone calls you a bigot or a [text]  racist, or a pro-whistleblower or a conspiracy theorist or says you're in it
iter 0, sample_idx:

In [253]:
outputs_old=outputs_old[1:]

In [256]:
outputs_old[0]['generations'][0]['text'] == outputs[0]['generations'][0]['text']

True

In [257]:
outputs_old[0]['generations'][1]['text'] == outputs[0]['generations'][1]['text']

True

In [258]:
outputs_old[0]['generations'][2]['text'] == outputs[0]['generations'][2]['text']

True

In [260]:
for i in range(3,len(outputs_old[0]['generations'])):
    print(outputs_old[0]['generations'][i]['text'] == outputs[0]['generations'][i]['text'])

True
True
True
True


In [None]:
int_outputs_old[0]['generations'][1]

{'iter0_original_sentence': ' wearing games and holy ****ing shit do I hate horse wearing games.',
 'iter0_masked_sentence': [' wearing games and<mask> ****ing<mask> do I hate horse wearing games.'],
 'iter0_best_text': ' wearing games and how ****ing much do I hate horse wearing games.',
 'iter0_update': True,
 'iter1_original_sentence': ' wearing games and how ****ing much do I hate horse wearing games.',
 'iter1_masked_sentence': [' wearing games and how<mask><mask> much do I hate<mask> wearing games.'],
 'iter1_best_text': ' wearing games and how and how much do I hate wearing wearing games.',
 'iter1_update': True}

In [None]:
outputs[0]['generations'][1]

{'text': ' wearing, and how, very much do I hate not wearing games.',
 'original_text': ' wearing games and holy ****ing shit do I hate horse wearing games.',
 'allsat': True,
 'losses': [62.027462005615234, 0.02681632898747921],
 'weighted_loss': 6.22688102722168,
 'edited': True}

In [None]:
int_outputs[0]['generations'][1] ## 2번째 iteration으로 넘어가면서 꼬인다. (1. best text 결과가 다르고 2. iter1_update: False 인데도 final best text에 업데이트가 되었다.)

{'iter0_original_sentence': ' wearing games and holy ****ing shit do I hate horse wearing games.',
 'iter0_masked_sentence': ' wearing games and<mask> ****ing<mask> do I hate horse wearing games.',
 'iter0_best_text': ' wearing games and how ****ing much do I hate horse wearing games.',
 'iter0_update': True,
 'iter1_original_sentence': ' wearing games and how ****ing much do I hate horse wearing games.',
 'iter1_masked_sentence': ' wearing<mask> and how<mask><mask> much do I hate<mask> wearing games.',
 'iter1_best_text': ' wearing, and how, very much do I hate not wearing games.',
 'iter1_update': False}

In [105]:
#!/usr/bin/env python
# coding: utf-8

class dummyArgs:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

build_loss_args = dummyArgs(**config["build_loss_dict"])

## load data
if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
    source_dataset = [
        json.loads(l)[config["jsonl_primary_key"]][config["jsonl_secondary_key"]]
        for l in open(config["source_data"])
    ]
    generation_dataset = [
        json.loads(l)["generations"] for l in open(config["source_data"])
    ]
elif (config["task"] == "formality") or (config["task"] == "sentiment-lewis-compr"):
    with open(config["source_data"], "r") as f:
        generation_dataset = [line.rstrip('\n') for line in f.readlines()]
    source_dataset = ["" for l in generation_dataset]


## load tokenizer, models, define losses
name2tokenizer = {}
name2model = {}
name2config = {}
loss2tokenizer = {}
embed_luts = []

for i, model_path in enumerate(config["model_paths"]):
    if (
        model_path not in name2model
    ):  # making sure we are not loading the model twice in case some constraints use the same model.
        try:
            name2tokenizer[model_path] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=True,
            )
        except:
            name2tokenizer[model_path] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=False,
            )

        name2config[model_path] = AutoConfig.from_pretrained(
            model_path, cache_dir=config["cache_dir"]
        )

        if config["model_types"][i] == "RobertaCustomForSequenceClassification":
            name2model[model_path] = lossbuilder.ModelWrapper(
                RobertaCustomForSequenceClassification.from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        else:
            name2model[model_path] = lossbuilder.ModelWrapper(
                getattr(transformers, config["model_types"][i]).from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        name2model[model_path].eval()
        name2model[model_path].cuda()

    input_embeds = name2model[model_path].get_input_embeddings()
    if isinstance(input_embeds, torch.nn.Sequential):
        input_embeds = input_embeds[0]
    embed_luts.append(input_embeds)

    if config["target_type"] == "embeds":
        embed_luts[-1].requires_grad = False

mlm_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
mlm = None if config["method"] == "mlm-beamsearch-v2" else AutoModelForMaskedLM.from_pretrained("roberta-base")  

lossfns_old = []
for i, loss in enumerate(config["losses"]):
    lossfns_old.append(
        lossbuilder.build_loss(
            loss,
            name2model[config["model_paths"][i]],
            name2tokenizer[config["model_paths"][i]],
            build_loss_args,
        )
    )
    lossfns_old[i].tokenizer.add_special_tokens({"mask_token": mlm_tokenizer.mask_token})
    loss2tokenizer[loss] = lossfns_old[i].tokenizer
# lossfns_old[0].tokenizer = loss2tokenizer[config["losses"][0]]
# lossfns_old[1].tokenizer = loss2tokenizer[config["losses"][1]]

text_id=0
source_text = source_dataset[text_id]
if source_text == "":
    source_text = lossfns_old[0].tokenizer.bos_token

if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
    AR_prediction_all = [x["text"] for x in generation_dataset[text_id]]    
elif (config["task"] == "formality") or (
    config["task"] == "sentiment-lewis-compr"
):
    AR_prediction_all = [generation_dataset[text_id]]

sample_idx = 1
curr_num_samples = len(AR_prediction_all)
AR_prediction = AR_prediction_all[sample_idx]

allsat = True
gold_losses = []
curr_loss = 0.0
loss_weights = [1 - config['closs_weight'], config['closs_weight']]
for lossid, lossname in enumerate(config["losses"]):
    with torch.no_grad():
        lossvalue = lossfns_old[lossid].compute_gold_loss(
            source_text, AR_prediction,
            label_id=label_ids[lossid],
        )
        
    gold_losses.append(lossvalue.squeeze().item())
    curr_loss += loss_weights[lossid] * lossvalue.squeeze().item()
    if (lossid >= 1) and (gold_losses[lossid] > -np.log(
        config["min_epsilons"][lossid - 1]
    )):
        allsat = False

es_patience_count_old = 0
best_ix_old = None
best_allsat_old = allsat
best_losses_old = gold_losses
best_weighted_loss_old = curr_loss                
running_text_old = best_text_old = AR_prediction
int_output_old = {}

_iter = 0
## locate tokens to edit
masked_text_old  = locate_main(running_text_old, 
                        config["locate_method"], 
                        name2model[config["model_paths"][1]], 
                        name2tokenizer[config["tokenizer_paths"][1]], 
                        max_num_tokens = config['num_edit_token_per_step'], 
                        unit=config["locate_unit"], 
                        device="cuda", 
                        label_id=config["target_label_ids"][1],
                        num_layer=10)

## replace tokens at the indices with mask tokens
inputs_old = mlm_tokenizer(
    masked_text_old, return_tensors="pt"
)
# inputs_old = mlm_tokenizer(
#     source_text + ' ' + masked_text_old[0], return_tensors="pt", add_special_tokens=False
# )

## make predictions for the masked indices
with torch.no_grad():
    logits_old = mlm(**inputs_old).logits
indices_in_mlm_tokens_old = (
    inputs_old.input_ids == mlm_tokenizer.mask_token_id
)[0].nonzero(as_tuple=True)[0]
# print(f"indices_in_mlm_tokens_old: {indices_in_mlm_tokens_old}")
## get top k tokens for each index

## make logits for special tokens -inf.
special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
logits_old[:, :, special_token_ids] = -np.inf

predicted_token_ids_old = torch.topk(
    logits_old[0, indices_in_mlm_tokens_old],
    k=config['k_per_location'],
    dim=-1,
)
# print(f"predicted_token_ids_old: {predicted_token_ids_old}")
# print(f"mlm_tokenizer.batch_decode(predicted_token_ids_old.indices): {mlm_tokenizer.batch_decode(predicted_token_ids_old.indices)}")
    
if config["method"] == "mlm-beamsearch-v0":
    # print(config["method"])
    hypotheses_old = beam_rerank_v0(source_text,
                                inputs_old.input_ids,
                                indices_in_mlm_tokens_old,
                                predicted_token_ids_old,
                                mlm_tokenizer, 
                                lossfns_old,
                                config, 
                                beam_size = config['beam_size'])
elif config["method"] == "mlm-beamsearch-v1":
    hypotheses_old = beam_rerank_v1(source_text,
                                inputs_old.input_ids,
                                indices_in_mlm_tokens_old,
                                predicted_token_ids_old,
                                mlm_tokenizer, 
                                lossfns_old,
                                config, 
                                beam_size = config['beam_size'])

elif config["method"] == "mlm-reranking":
    hypotheses_old = combi_rerank(inputs_old.input_ids, ## in mlm tokenizer's tokens
        indices_in_mlm_tokens_old,
        predicted_token_ids_old,
        mlm_tokenizer,
        config)

candidate_total_losses_old = []
candidate_primary_losses_old = []
candidate_losses_for_loggings_old = []
candidate_allsats_old = []

for hyp in hypotheses_old:
    curr_loss = 0.0
    logging_loss = []
    allsat = True
    for lossid, lossname in enumerate(config["losses"]):
        with torch.no_grad():
            lossvalue = lossfns_old[lossid].compute_gold_loss(
                source_text, hyp,
                label_id=config['target_label_ids'][lossid],
            )
        curr_loss += loss_weights[lossid] * lossvalue.item()
        logging_loss.append(lossvalue.item())
        if lossid==0:
            candidate_primary_losses_old.append(lossvalue.item())
        elif (lossid >= 1) and (
            lossvalue.item()
            > -np.log(config["min_epsilons"][lossid - 1])
        ):
            allsat = False
    candidate_total_losses_old.append(curr_loss)
    candidate_losses_for_loggings_old.append(logging_loss)
    candidate_allsats_old.append(allsat)


if config['selection_criteria'] == "weighted_sum":
    best_ix_old = np.argmin(np.array(candidate_total_losses_old))
elif config['selection_criteria'] == "allsat_primary":
    allsat_ix = np.where(np.array(candidate_allsats_old) == True)[0]
    if len(allsat_ix) > 0:
        best_ix_old = np.argmin(
            np.array(candidate_primary_losses_old)[allsat_ix]
        )  # select min primary loss among allsats
        best_ix_old = allsat_ix[best_ix_old]
    else:  # if no candidate satisfying constraints, default to weighted_sum
        best_ix_old = np.argmin(np.array(candidate_total_losses_old))
    
update = False
if config['selection_criteria'] == "weighted_sum":
    if best_weighted_loss_old > candidate_total_losses_old[best_ix_old]:
        update = True
elif config['selection_criteria'] == "allsat_primary":
    if (
        best_allsat_old is False
        and candidate_allsats_old[best_ix_old] is True
    ):
        update = True
    elif (
        best_allsat_old is False
        and candidate_allsats_old[best_ix_old] is False
    ):
        if best_weighted_loss_old > candidate_total_losses_old[best_ix_old]:
            update = True
    elif (
        best_allsat_old is True
        and candidate_allsats_old[best_ix_old] is True
    ):
        if (
            best_losses_old[0]
            > candidate_losses_for_loggings_old[best_ix_old][0]
        ):
            update = True


## intermediate output for debugging
int_output_old.update({f"iter{_iter}_original_sentence": running_text_old,
                f"iter{_iter}_masked_sentence": masked_text_old,
                f"iter{_iter}_best_text": hypotheses_old[best_ix_old],
                f"iter{_iter}_update": update})    

running_text_old = hypotheses_old[best_ix_old]

https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/generation_config.json HTTP/1.1" 200 0


50265


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 47.54 GiB of which 4.94 MiB is free. Process 3095476 has 40.40 GiB memory in use. Including non-PyTorch memory, this process has 7.12 GiB memory in use. Of the allocated memory 6.70 GiB is allocated by PyTorch, and 164.96 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
running_text_old

' wearing games and how ****ing much do I hate horse wearing games.'

In [None]:

_iter = 1
## locate tokens to edit
masked_text_old  = locate_main(running_text_old, 
                        config["locate_method"], 
                        name2model[config["model_paths"][1]], 
                        name2tokenizer[config["tokenizer_paths"][1]], 
                        max_num_tokens = config['num_edit_token_per_step'], 
                        unit=config["locate_unit"], 
                        device="cuda", 
                        label_id=config["target_label_ids"][1],
                        num_layer=10)


In [None]:
running_text_old

' wearing games and how ****ing much do I hate horse wearing games.'

In [None]:
running_text

[' wearing games and how ****ing much do I hate horse wearing games.']

In [None]:
masked_text_old

[' wearing games and how<mask><mask> much do I hate<mask> wearing games.']

In [None]:
masked_text

[' wearing<mask> and how<mask><mask> much do I hate<mask> wearing games.']

In [92]:
config['locate_method']

'grad_norm'

In [90]:
masked_text = locator.locate_main(running_text, 
                        method = config['locate_method'], 
                        max_num_tokens = config['num_edit_token_per_step'], 
                        unit = config['locate_unit'], 
                        num_layer = 10, #penultimate
                        label_id = config['target_label_ids'][1])

In [94]:
masked_text_old

[' wearing games and how<mask><mask> much do I hate<mask> wearing games.']

진짜 의외로 locate이 culprit 인것 같다.. 그렇다면 locate을 정확하게 했다고 치면 뒷부분은 동일한가? (이에 대한 확인은 위쪽 - new implementation - 섹션에서 진행하였다. -> 결론 : locate 이후는 동일하다.)

In [107]:

## replace tokens at the indices with mask tokens
inputs_old = mlm_tokenizer(
    masked_text_old, return_tensors="pt"
)
# inputs_old = mlm_tokenizer(
#     source_text + ' ' + masked_text_old[0], return_tensors="pt", add_special_tokens=False
# )


In [111]:
inputs_old=inputs_old.to(config['device'])

In [113]:
inputs.input_ids==inputs_old.input_ids

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True]], device='cuda:7')

In [108]:
inputs_old

{'input_ids': tensor([[    0,  2498,   426,     8,   141, 50264, 50264,   203,   109,    38,
          4157, 50264,  2498,   426,     4,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [114]:

## make predictions for the masked indices
with torch.no_grad():
    logits_old = mlm(**inputs_old).logits
indices_in_mlm_tokens_old = (
    inputs_old.input_ids == mlm_tokenizer.mask_token_id
)[0].nonzero(as_tuple=True)[0]
# print(f"indices_in_mlm_tokens_old: {indices_in_mlm_tokens_old}")
## get top k tokens for each index


In [115]:
indices_in_mlm_tokens_old

tensor([ 5,  6, 11], device='cuda:7')

In [116]:

## make logits for special tokens -inf.
special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
logits_old[:, :, special_token_ids] = -np.inf


In [117]:

predicted_token_ids_old = torch.topk(
    logits_old[0, indices_in_mlm_tokens_old],
    k=config['k_per_location'],
    dim=-1,
)
# print(f"predicted_token_ids_old: {predicted_token_ids_old}")
# print(f"mlm_tokenizer.batch_decode(predicted_token_ids_old.indices): {mlm_tokenizer.batch_decode(predicted_token_ids_old.indices)}")
    


In [118]:
predicted_token_ids_old

torch.return_types.topk(
values=tensor([[16.869, 16.116, 13.910, 13.801, 13.715, 12.183, 12.001, 11.853, 11.661,
         11.642],
        [14.058, 11.541, 11.299, 10.800, 10.641, 10.358, 10.337, 10.216, 10.194,
         10.193],
        [15.107, 13.711, 13.133, 12.689, 12.413, 12.123, 12.028, 11.980, 11.967,
         11.663]], device='cuda:7'),
indices=tensor([[    6,     8,    73,   116,   203,    12, 14223,  2230,     4,    50],
        [  141,  1336,  7105,  6179, 23523, 26536,   352,     6,  9178,   182],
        [   45,   604,  2498,    82,  2185,   390,   888,    59,   127,    47]],
       device='cuda:7'))

In [126]:
if config["method"] == "mlm-beamsearch-v0":
    # print(config["method"])
    hypotheses_old = beam_rerank_v0(source_text,
                                inputs_old.input_ids,
                                indices_in_mlm_tokens_old,
                                predicted_token_ids_old,
                                mlm_tokenizer, 
                                lossfns_old,
                                config, 
                                beam_size = config['beam_size'])
elif config["method"] == "mlm-beamsearch-v1":
    hypotheses_old = beam_rerank_v1(source_text,
                                inputs_old.input_ids,
                                indices_in_mlm_tokens_old,
                                predicted_token_ids_old,
                                mlm_tokenizer, 
                                lossfns_old,
                                config, 
                                beam_size = config['beam_size'])

elif config["method"] == "mlm-reranking":
    hypotheses_old = combi_rerank(inputs_old.input_ids, ## in mlm tokenizer's tokens
        indices_in_mlm_tokens_old,
        predicted_token_ids_old,
        mlm_tokenizer,
        config)


In [130]:
[x==y for x,y in zip(hypotheses[0],hypotheses_old)]

[True, True, True, True, True]

In [127]:
hypotheses_old

[' wearing games and how and how much do I hate my wearing games.',
 ' wearing games and how and how much do I hate wearing wearing games.',
 ' wearing games and how and how much do I hate about wearing games.',
 ' wearing games and how and how much do I hate myself wearing games.',
 ' wearing games and how much very much do I hate my wearing games.']

In [131]:

candidate_total_losses_old = []
candidate_primary_losses_old = []
candidate_losses_for_loggings_old = []
candidate_allsats_old = []

for hyp in hypotheses_old:
    curr_loss = 0.0
    logging_loss = []
    allsat = True
    for lossid, lossname in enumerate(config["losses"]):
        with torch.no_grad():
            lossvalue = lossfns_old[lossid].compute_gold_loss(
                source_text, hyp,
                label_id=config['target_label_ids'][lossid],
            )
        curr_loss += loss_weights[lossid] * lossvalue.item()
        logging_loss.append(lossvalue.item())
        if lossid==0:
            candidate_primary_losses_old.append(lossvalue.item())
        elif (lossid >= 1) and (
            lossvalue.item()
            > -np.log(config["min_epsilons"][lossid - 1])
        ):
            allsat = False
    candidate_total_losses_old.append(curr_loss)
    candidate_losses_for_loggings_old.append(logging_loss)
    candidate_allsats_old.append(allsat)


In [132]:


if config['selection_criteria'] == "weighted_sum":
    best_ix_old = np.argmin(np.array(candidate_total_losses_old))
elif config['selection_criteria'] == "allsat_primary":
    allsat_ix = np.where(np.array(candidate_allsats_old) == True)[0]
    if len(allsat_ix) > 0:
        best_ix_old = np.argmin(
            np.array(candidate_primary_losses_old)[allsat_ix]
        )  # select min primary loss among allsats
        best_ix_old = allsat_ix[best_ix_old]
    else:  # if no candidate satisfying constraints, default to weighted_sum
        best_ix_old = np.argmin(np.array(candidate_total_losses_old))
    


In [134]:
best_ix_old, candidate_total_losses_old[best_ix_old], candidate_losses_for_loggings_old[best_ix_old], candidate_allsats_old[best_ix_old]

(1, 6.6754633940756305, [66.51388549804688, 0.026749826967716217], True)

In [135]:
hypotheses_old[best_ix_old]

' wearing games and how and how much do I hate wearing wearing games.'

In [None]:
update = False
if config['selection_criteria'] == "weighted_sum":
    if best_weighted_loss_old > candidate_total_losses_old[best_ix_old]:
        update = True
elif config['selection_criteria'] == "allsat_primary":
    if (
        best_allsat_old is False
        and candidate_allsats_old[best_ix_old] is True
    ):
        update = True
    elif (
        best_allsat_old is False
        and candidate_allsats_old[best_ix_old] is False
    ):
        if best_weighted_loss_old > candidate_total_losses_old[best_ix_old]:
            update = True
    elif (
        best_allsat_old is True
        and candidate_allsats_old[best_ix_old] is True
    ):
        if (
            best_losses_old[0]
            > candidate_losses_for_loggings_old[best_ix_old][0]
        ):
            update = True


## intermediate output for debugging
int_output_old.update({f"iter{_iter}_original_sentence": running_text_old,
                f"iter{_iter}_masked_sentence": masked_text_old,
                f"iter{_iter}_best_text": hypotheses_old[best_ix_old],
                f"iter{_iter}_update": update})    

running_text_old = hypotheses_old[best_ix_old]