In [2]:
# -*- coding: utf-8 -*-
import argparse
import os
import string
import sys
os.chdir('/data/hyeryung/mucoco')
from itertools import repeat
import torch.multiprocessing as mp
from typing import List
from itertools import permutations,product
import math

import numpy as np
import pandas as pd
import torch
import transformers
# from datasets import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader,Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForMaskedLM, AutoTokenizer, AutoConfig

from new_module.utils.robertacustom import RobertaCustomForSequenceClassification
from new_module.locate.new_locate_utils import *
import new_module.losses as lossbuilder

import logging
import os
from typing import List,Tuple
from tqdm import tqdm

import torch
import torch.nn.functional as F
import transformers

In [3]:
## setup for prototyping
logging.basicConfig(level=logging.DEBUG, format="%(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("LOGGING_LEVEL", logging.DEBUG))

config={#'model_paths':['gpt2-large','/shared/s3/lab07/hyeryung/loc_edit/roberta-base-jigsaw-toxicity-classifier-energy-training/step_600_best_checkpoint'],
        # 'tokenizer_paths':['gpt2-large','/shared/s3/lab07/hyeryung/loc_edit/roberta-base-jigsaw-toxicity-classifier-energy-training/step_600_best_checkpoint'],
        'model_paths':['gpt2-large','/data/hyeryung/loc_edit/models/roberta-base-jigsaw-toxicity-classifier-energy-training/step_1000_best_checkpoint'],
        'tokenizer_paths':['gpt2-large','/data/hyeryung/loc_edit/models/roberta-base-jigsaw-toxicity-classifier-energy-training/step_1000_best_checkpoint'],
        'model_types': ["AutoModelForCausalLM", "AutoModelForSequenceClassification"],
        'cache_dir': "/data/hyeryung/hf_cache",
        'target_type': "embeds",
        'method': "mlm-beamsearch-v0",
       'losses': ["gpt2", "classification_no_prefix_logprobloss"],
       'target_label_ids': [0,0] ,
       'build_loss_dict': {"coeff_steps": 200, "coeff_pattern": "constant", "loss_type": "xentropy", "length_normalize": False, "AR_temperature": 1.0, "AR_top_k": 0, "AR_top_p": 0.96, "max_output_length": 20},
       'min_epsilons': [0.75],
       'source_data': 'new_module/data/toxicity-avoidance/testset_gpt2_2500.jsonl',
       'locate_unit': 'token',
       'locate_method': 'grad_norm',
       'device': 'cuda:7',
       'k_per_location': 3,
       'closs_weight': 0.9,
       'beam_size': 3,
       'selection_criteria':"weighted_sum"}

class dummyArgs:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

build_loss_args = dummyArgs(**config["build_loss_dict"])

In [4]:
name2tokenizer = {}
name2model = {}
name2config = {}
loss2tokenizer = {}
embed_luts = []
primary_model = None
for i, model_path in enumerate(config["model_paths"]):
    if (
        model_path not in name2model
    ):  # making sure we are not loading the model twice in case some constraints use the same model.
        try:
            name2tokenizer[model_path] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=True,
            )
        except:
            name2tokenizer[model_path] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=False,
            )

        name2config[model_path] = AutoConfig.from_pretrained(
            model_path, cache_dir=config["cache_dir"]
        )

        if "Custom" in config["model_types"][i]:
            name2model[model_path] = lossbuilder.ModelWrapper(
                getattr(utils, config["model_types"][i]).from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        else:
            name2model[model_path] = lossbuilder.ModelWrapper(
                getattr(transformers, config["model_types"][i]).from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        name2model[model_path].eval()
        name2model[model_path].to(config['device'])

    input_embeds = name2model[model_path].get_input_embeddings()
    if isinstance(input_embeds, torch.nn.Sequential):
        input_embeds = input_embeds[0]
    embed_luts.append(input_embeds)

    if config["target_type"] == "embeds":
        embed_luts[-1].requires_grad = False

    if i == 0:
        primary_model = name2model[model_path]

mlm_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
mlm = None if config["method"] == "mlm-beamsearch-v2" else AutoModelForMaskedLM.from_pretrained("roberta-base").to(config['device'])  

lossfns = []
for i, loss in enumerate(config["losses"]):
    lossfns.append(
        lossbuilder.build_loss(
            loss,
            name2model[config["model_paths"][i]],
            name2tokenizer[config["model_paths"][i]],
            build_loss_args,
        )
    )
    lossfns[i].tokenizer.add_special_tokens({"mask_token": mlm_tokenizer.mask_token})
    loss2tokenizer[loss] = lossfns[i].tokenizer

primary_tokenizer = name2tokenizer['gpt2-large']
secondary_tokenizer = list(name2tokenizer.values())[1]

Starting new HTTPS connection (1): huggingface.co:443
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/generation_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /roberta-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0


In [5]:
def unravel(outputs_df):
    outputs_df=outputs_df.explode('generations',ignore_index=True)
    outputs_df['prompt']=outputs_df['prompt'].apply(lambda x: x['text'])
    outputs_df['generations']=outputs_df['generations'].apply(lambda x: x['text'] if isinstance(x, dict) else x)
    outputs_df = outputs_df.dropna().reset_index(drop=True)
    return outputs_df
toxic_gen=pd.read_json('new_module/data/toxicity-avoidance/dev_set.jsonl',lines=True)
toxic_gen=unravel(toxic_gen)
source_text = '“I went to use the bathroom, and I went in and Lance was bent'
prediction = toxic_gen.loc[toxic_gen['prompt']==source_text,'generations'].tolist()
ckpt_path = config['model_paths'][1]
model = name2model[ckpt_path] #AutoModelForSequenceClassification.from_pretrained(ckpt_path)
tokenizer = name2tokenizer[ckpt_path] #AutoTokenizer.from_pretrained(ckpt_path)

model = model.to(config['device'])
loc_machine=LocateMachine(model,tokenizer)
masked_text = loc_machine.locate_main(prediction, "grad_norm", max_num_tokens = 6, unit=config['locate_unit'], num_layer=10, label_id=0)

In [6]:
## replace tokens at the indices with mask tokens
inputs = mlm_tokenizer(
    masked_text, return_tensors="pt", padding=True, truncation=True
)
inputs = inputs.to(config['device']) 
masked_sequence=inputs['input_ids']

## make predictions for the masked indices
with torch.no_grad():
    logits = mlm(**inputs).logits

special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
logits[:, :, special_token_ids] = -float("inf")

indices_in_mlm_tokens = (
    inputs.input_ids == mlm_tokenizer.mask_token_id
).nonzero(as_tuple=True)

## get top k tokens for each index
predicted_token_ids = torch.topk(
    logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
    k=config['k_per_location'],
    dim=-1,
)
# print(f"predicted_token_ids: {predicted_token_ids}")
# print(f"mlm_tokenizer.batch_decode(predicted_token_ids.indices): {mlm_tokenizer.batch_decode(predicted_token_ids.indices)}")


In [10]:
def beam_rerank(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:tuple,
                    # predicted_token_ids:torch.return_types.topk,
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict):
    """params: 
    source_text(prompt) should be text. 
    masked_sequence should be token ids tokenized by MLM's tokenizer.
    indices_in_mlm_tokens should be a result of running 
    `indices_in_mlm_tokens = (
                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                ).nonzero(as_tuple=True)`
    predicted_token_ids should be a result of running
    `predicted_token_ids = torch.topk(
                            logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                            k=config['k_per_location'],
                            dim=-1,).indices`
    """
    
    hypotheses = masked_sequence[:, None, :].repeat((1,config['beam_size'],1))
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))

    # for i in tqdm(edit_indices,total=len(edit_indices)):
    for i in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==i]

        tmp_hypotheses = hypotheses[batch_ids_to_edit].detach().clone()
        tmp_hypotheses=tmp_hypotheses.repeat((1,config['k_per_location'],1))

        # candidates = predicted_token_ids.indices[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
        candidates = predicted_token_ids[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
        candidates = candidates[:, :, None].repeat((1,1, config['beam_size'])).reshape(candidates.shape[0], -1,1)

        tmp_hypotheses = torch.cat((tmp_hypotheses[:, :, :i], candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]
        tmp_hypotheses = tmp_hypotheses.reshape(-1, tmp_hypotheses.shape[-1])
        
        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        for lossid, lossname in enumerate(config["losses"]):
            if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
                continue
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                    label_id=config['target_label_ids'][lossid],
                )
            torch.cuda.empty_cache()
            curr_loss += loss_weights[lossid] * lossvalue
            
        curr_loss = torch.stack(torch.split(curr_loss, config['beam_size'] * config['k_per_location'], dim=0),dim=0)
        top_beams=torch.topk(curr_loss, k=(config['beam_size']*2+1), dim=-1, largest=False).indices

        tmp_hypotheses = torch.split(tmp_hypotheses, config['beam_size'] * config['k_per_location'], dim=0) ## 아래의 작업을 더 간단히 할 수 있는 방법?
        tmp_hypotheses = torch.stack([x[top_beams[j]] for j, x in enumerate(tmp_hypotheses)],dim=0)
        tmp_hypotheses = torch.unique(tmp_hypotheses, dim=1)[:, :config['beam_size'], :]
        
        hypotheses[batch_ids_to_edit,:, i]=tmp_hypotheses[:, :, i]
            
    return [mlm_tokenizer.batch_decode(hypotheses[j], skip_special_tokens=True) for j in range(hypotheses.shape[0])]

In [7]:
## get_combi_hypotheses 
def get_combi_hypotheses(masked_sequence:torch.Tensor, 
                 indices_in_mlm_tokens:tuple,
                 predicted_token_ids:torch.Tensor,
                 mlm_tokenizer:transformers.AutoTokenizer,
                 config:dict) -> List[str]:
    """params: 
    masked_sequence should be token ids tokenized by MLM's tokenizer.
    indices_in_mlm_tokens should be a result of running 
    `indices_in_mlm_tokens = (
                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                ).nonzero(as_tuple=True)`
    predicted_token_ids should be a result of running
    `predicted_token_ids = torch.topk(
                            logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                            k=config['k_per_location'],
                            dim=-1,).indices`
    """

    k = config['k_per_location']
    hypotheses = []
    num_batches = masked_sequence.shape[0]
    for i in range(num_batches):
        
        l = (indices_in_mlm_tokens[0] == i).sum().item()
        tok_cand_combos = list(product(range(k),repeat=l))
        
        tmp_hypotheses = masked_sequence[i,:].repeat((k**l,1))
        tmp_hypotheses[:, indices_in_mlm_tokens[1][indices_in_mlm_tokens[0] == i]] = \
            predicted_token_ids[indices_in_mlm_tokens[0] == i, tok_cand_combos]
            
        tmp_dec_seq = mlm_tokenizer.batch_decode(
                    tmp_hypotheses, skip_special_tokens=True
            )
        hypotheses.append(tmp_dec_seq)
    return hypotheses

In [51]:
# %%timeit

# my_hypotheses = get_combi_hypotheses(masked_sequence, 
#                             indices_in_mlm_tokens,
#                             predicted_token_ids.indices,
#                             mlm_tokenizer,
#                             config)

# loss_weights = [1 - config['closs_weight'], config['closs_weight']]
# selection_criteria = "weighted_sum"
# hypotheses = deepcopy(my_hypotheses)
# best_ixes = []
# best_weighted_loss = []
# best_allsat = []
# best_logging_loss = []
# num_batches = masked_sequence.shape[0]
# for i in tqdm(range(num_batches)):
       
#     curr_loss = torch.zeros(len(hypotheses[i])).to(config['device'])
#     logging_loss = torch.zeros((len(hypotheses[i]),2)).to(config['device'])

#     hyp_data = CustomDataset(hypotheses[i])
#     data_loader = DataLoader(hyp_data,batch_size=64)

#     for lossid, lossname in enumerate(config["losses"]):
#         lossvalues=[]
#         with torch.no_grad():
#             for batch in data_loader:
#                 lossvalue = lossfns[lossid].compute_gold_loss(
#                     source_text, batch,
#                     label_id=config['target_label_ids'][lossid],
#                 )
#                 lossvalues.append(lossvalue)
#                 torch.cuda.empty_cache()
#         lossvalue = torch.cat(lossvalues,dim=0)
#         curr_loss += loss_weights[lossid] * lossvalue
#         logging_loss[:, lossid] = lossvalue.clone()
        
#     allsat_ix = torch.where(logging_loss[:,1]> -np.log(config["min_epsilons"][0]))[0].squeeze(0)
#     if (allsat_ix.shape[0] > 0) and (selection_criteria == "allsat_primary"):
#         best_ix = allsat_ix[curr_loss[allsat_ix].argmin()]
#     else: ## in case selection_criteria == "weighted_sum" or allsat is all False
#         best_ix = torch.argmin(curr_loss)

    
#     hypotheses[i]=hypotheses[i][best_ix]
#     best_weighted_loss.append(curr_loss[best_ix].item())
#     best_allsat.append(1 if best_ix in allsat_ix else 0)
#     best_logging_loss.append(logging_loss[best_ix].cpu().numpy())
    
#     del curr_loss, logging_loss
#     torch.cuda.empty_cache()

100%|██████████| 10/10 [00:50<00:00,  5.01s/it]
100%|██████████| 10/10 [00:50<00:00,  5.08s/it]
100%|██████████| 10/10 [00:52<00:00,  5.26s/it]
100%|██████████| 10/10 [00:52<00:00,  5.27s/it]
100%|██████████| 10/10 [00:51<00:00,  5.13s/it]
100%|██████████| 10/10 [00:52<00:00,  5.21s/it]
100%|██████████| 10/10 [00:51<00:00,  5.11s/it]
100%|██████████| 10/10 [00:51<00:00,  5.12s/it]

52.1 s ± 680 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)





In [21]:

def final_reranking(hypotheses:List[List[str]],
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict,
                    batch_size:int=64) -> Tuple[List[str],List[float],List[int],List[List[float]]]:
    """params: 
        hypotheses: list of hypotheses of editing results
        lossfns:
        config:
        batch_size:             
    returns:
        hypotheses: list of one best hypothesis(editing result) for each of original texts. length same as masked_sequence.shape[0]
        best_weighted_loss: list of weighted loss for the best hypotheses.
        best_allsat: list of indicator(1,0) whether the best hypotheses satisfy cutoff (min_epsilons) for constraint energy score.
        best_logging_loss: list of list of fluency energy score and constraint energy score for each best hypothesis.
    """
    
    
    class CustomDataset(Dataset):
        def __init__(self, hypotheses_data:List[str]):
            self.hypotheses_data = hypotheses_data
            
        def __len__(self):
            return len(self.hypotheses_data)

        def __getitem__(self, idx:int):
            return self.hypotheses_data[idx]
        
        def __getitems__(self, idx:List[int]):
            return [self.hypotheses_data[j] for j in idx]
    
    final_hypotheses = []
    best_weighted_loss = []
    best_allsat = []
    best_logging_loss = []
    
    loss_weights = [1 - config['closs_weight'], config['closs_weight']]
    
    for i in tqdm(range(len(hypotheses))):
        curr_loss = torch.zeros(len(hypotheses[i])).to(config['device'])
        logging_loss = torch.zeros((len(hypotheses[i]),2)).to(config['device'])
        data_loader = DataLoader(CustomDataset(hypotheses[i]),batch_size=batch_size)

        for lossid, lossname in enumerate(config["losses"]):
            lossvalues=[]
            with torch.no_grad():
                for batch in data_loader:
                    lossvalue = lossfns[lossid].compute_gold_loss(
                        source_text, batch,
                        label_id=config['target_label_ids'][lossid],
                    )
                    lossvalues.append(lossvalue)
                    torch.cuda.empty_cache()
            lossvalue = torch.cat(lossvalues,dim=0)
            curr_loss += loss_weights[lossid] * lossvalue
            logging_loss[:, lossid] = lossvalue.clone()
            
        allsat_ix = torch.where(logging_loss[:,1]> -math.log(config["min_epsilons"][0]))[0].squeeze(0)
        if (allsat_ix.shape[0] > 0) and (config['selection_criteria'] == "allsat_primary"):
            best_ix = allsat_ix[curr_loss[allsat_ix].argmin()]
        else: ## in case config['selection_criteria'] == "weighted_sum" or allsat is all False
            best_ix = torch.argmin(curr_loss)

        final_hypotheses.append(hypotheses[i][best_ix])
        best_weighted_loss.append(curr_loss[best_ix].item())
        best_allsat.append(1 if best_ix in allsat_ix else 0)
        best_logging_loss.append(logging_loss[best_ix].cpu().tolist())
    
        del curr_loss, logging_loss
        torch.cuda.empty_cache()
    return final_hypotheses, best_weighted_loss, best_allsat, best_logging_loss

In [None]:
# hypotheses = beam_rerank(source_text, 
#                         masked_sequence, 
#                         indices_in_mlm_tokens,
#                         predicted_token_ids.indices,
#                         mlm_tokenizer, 
#                         lossfns,
#                         config)

In [None]:
%%timeit
my_hypotheses = beam_rerank(source_text, 
                        masked_sequence, 
                        indices_in_mlm_tokens,
                        predicted_token_ids.indices,
                        mlm_tokenizer, 
                        lossfns,
                        config)
final_result = final_reranking(my_hypotheses,
                                lossfns,
                                config,
                                batch_size=64)

100%|██████████| 10/10 [00:00<00:00, 25.15it/s]
100%|██████████| 10/10 [00:00<00:00, 25.57it/s]
100%|██████████| 10/10 [00:00<00:00, 25.10it/s]
100%|██████████| 10/10 [00:00<00:00, 25.80it/s]
100%|██████████| 10/10 [00:00<00:00, 25.42it/s]
100%|██████████| 10/10 [00:00<00:00, 25.59it/s]
100%|██████████| 10/10 [00:00<00:00, 25.06it/s]
100%|██████████| 10/10 [00:00<00:00, 25.43it/s]

2.31 s ± 15.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)





In [26]:
%%timeit
my_hypotheses = get_combi_hypotheses(masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer,
                            config)
final_result = final_reranking(my_hypotheses,
                                lossfns,
                                config,
                                batch_size=64)

100%|██████████| 10/10 [00:18<00:00,  1.87s/it]
100%|██████████| 10/10 [00:18<00:00,  1.88s/it]
100%|██████████| 10/10 [00:18<00:00,  1.88s/it]
100%|██████████| 10/10 [00:18<00:00,  1.89s/it]
100%|██████████| 10/10 [00:18<00:00,  1.89s/it]
100%|██████████| 10/10 [00:18<00:00,  1.90s/it]
100%|██████████| 10/10 [00:18<00:00,  1.90s/it]
100%|██████████| 10/10 [00:18<00:00,  1.90s/it]

19 s ± 65.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)





In [23]:

def get_beam_hypotheses(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:Tuple[torch.Tensor],
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict) -> List[List[str]]:
    """
    A function to get hypotheses of beam size via editing beam search with reranking.
    Run this function if config['method'] == 'mlm-beamsearch-v0' or config['method'] == 'mlm-beamsearch-v1'
    If config['method'] == 'mlm-beamsearch-v1', rerank beam only with fluency energy.
    If config['method'] == 'mlm-beamsearch-v0', rerank beam with a weighted sum of fluency and constraint energy.
    
    #ToDo
    #Implement mlm-beamsearch-v0 with allsat-primary and compare 
    
    params: 
        source_text: a prompt text 
        masked_sequence: token ids of original generation text with located indices masked. tokenized by MLM's tokenizer.
        indices_in_mlm_tokens: a result of running 
                                    `indices_in_mlm_tokens = (
                                                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                                                ).nonzero(as_tuple=True)`
        predicted_token_ids: a result of running
                                    `predicted_token_ids = torch.topk(
                                                                logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                                                                k=config['k_per_location'],
                                                                dim=-1,).indices`
        mlm_tokenizer: tokenizer of MLM
        lossfns: a list of loss functions
        config: a dictionary of configurations
    
    returns:
        hypotheses: a list of a list of the beam number of hypotheses for each sample         
    """
    
    hypotheses = masked_sequence[:, None, :].repeat((1,config['beam_size'],1))
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))

    for i in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==i]

        tmp_hypotheses = hypotheses[batch_ids_to_edit].detach().clone()
        tmp_hypotheses=tmp_hypotheses.repeat((1,config['k_per_location'],1))

        candidates = predicted_token_ids[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
        candidates = candidates[:, :, None].repeat((1,1, config['beam_size'])).reshape(candidates.shape[0], -1,1)

        tmp_hypotheses = torch.cat((tmp_hypotheses[:, :, :i], candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]
        tmp_hypotheses = tmp_hypotheses.reshape(-1, tmp_hypotheses.shape[-1])
        
        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        for lossid, lossname in enumerate(config["losses"]):
            if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
                continue
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                    label_id=config['target_label_ids'][lossid],
                )
            torch.cuda.empty_cache()
            curr_loss += loss_weights[lossid] * lossvalue
            
        curr_loss = torch.stack(torch.split(curr_loss, config['beam_size'] * config['k_per_location'], dim=0),dim=0)
        top_beams=torch.topk(curr_loss, k=(config['beam_size']*2+1), dim=-1, largest=False).indices

        tmp_hypotheses = torch.split(tmp_hypotheses, config['beam_size'] * config['k_per_location'], dim=0) ## 아래의 작업을 더 간단히 할 수 있는 방법?
        tmp_hypotheses = torch.stack([x[top_beams[j]] for j, x in enumerate(tmp_hypotheses)],dim=0)
        tmp_hypotheses = torch.unique(tmp_hypotheses, dim=1)[:, :config['beam_size'], :]
        
        hypotheses[batch_ids_to_edit,:, i]=tmp_hypotheses[:, :, i]
            
    return [mlm_tokenizer.batch_decode(hypotheses[j], skip_special_tokens=True) for j in range(hypotheses.shape[0])]

In [None]:

def get_beam_hypotheses(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:Tuple[torch.Tensor],
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict) -> List[List[str]]:
    """
    A function to get hypotheses of beam size via editing beam search with reranking.
    Run this function if config['method'] == 'mlm-beamsearch-v0' or config['method'] == 'mlm-beamsearch-v1'
    If config['method'] == 'mlm-beamsearch-v1', rerank beam only with fluency energy.
    If config['method'] == 'mlm-beamsearch-v0', rerank beam with a weighted sum of fluency and constraint energy.
    
    #ToDo
    #Implement mlm-beamsearch-v0 with allsat-primary and compare 
    
    params: 
        source_text: a prompt text 
        masked_sequence: token ids of original generation text with located indices masked. tokenized by MLM's tokenizer.
        indices_in_mlm_tokens: a result of running 
                                    `indices_in_mlm_tokens = (
                                                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                                                ).nonzero(as_tuple=True)`
        predicted_token_ids: a result of running
                                    `predicted_token_ids = torch.topk(
                                                                logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                                                                k=config['k_per_location'],
                                                                dim=-1,).indices`
        mlm_tokenizer: tokenizer of MLM
        lossfns: a list of loss functions
        config: a dictionary of configurations
    
    returns:
        hypotheses: a list of a list of the beam number of hypotheses for each sample         
    """
    
    hypotheses = masked_sequence[:, None, :].repeat((1,config['beam_size'],1))
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))

    for i in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==i]

        tmp_hypotheses = hypotheses[batch_ids_to_edit].detach().clone()
        tmp_hypotheses=tmp_hypotheses.repeat((1,config['k_per_location'],1))

        candidates = predicted_token_ids[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
        candidates = candidates[:, :, None].repeat((1,1, config['beam_size'])).reshape(candidates.shape[0], -1,1)

        tmp_hypotheses = torch.cat((tmp_hypotheses[:, :, :i], candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]
        tmp_hypotheses = tmp_hypotheses.reshape(-1, tmp_hypotheses.shape[-1])
        
        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        for lossid, lossname in enumerate(config["losses"]):
            if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
                continue
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                    label_id=config['target_label_ids'][lossid],
                )
            torch.cuda.empty_cache()
            curr_loss += loss_weights[lossid] * lossvalue
            
        curr_loss = torch.stack(torch.split(curr_loss, config['beam_size'] * config['k_per_location'], dim=0),dim=0)
        top_beams=torch.topk(curr_loss, k=(config['beam_size']*2+1), dim=-1, largest=False).indices

        tmp_hypotheses = torch.split(tmp_hypotheses, config['beam_size'] * config['k_per_location'], dim=0) ## 아래의 작업을 더 간단히 할 수 있는 방법?
        tmp_hypotheses = torch.stack([x[top_beams[j]] for j, x in enumerate(tmp_hypotheses)],dim=0)
        tmp_hypotheses = torch.unique(tmp_hypotheses, dim=1)[:, :config['beam_size'], :]
        
        hypotheses[batch_ids_to_edit,:, i]=tmp_hypotheses[:, :, i]
            
    return [mlm_tokenizer.batch_decode(hypotheses[j], skip_special_tokens=True) for j in range(hypotheses.shape[0])]

In [29]:
torch.topk(torch.Tensor([float("inf"),float("inf"),float("inf")]), k=2, dim=-1, largest=False)

torch.return_types.topk(
values=tensor([inf, inf]),
indices=tensor([0, 1]))

In [25]:
my_hypotheses = get_beam_hypotheses(source_text, 
                        masked_sequence, 
                        indices_in_mlm_tokens,
                        predicted_token_ids.indices,
                        mlm_tokenizer, 
                        lossfns,
                        config)

In [26]:
my_hypotheses

[[" over me, the up. I could hear him starting to fight, but I didn't stop him. Thank God he had these clothes on me.",
  " over me was he him. I could hear him starting to cry, but I didn't stop him. Thank God he had these headphones on me.",
  " over me he I down. I could hear him starting to scream, but I didn't stop him. Thank God he had these handcuffs on me."],
 [' over like this and he was like.',
  ' over like this and he was like:',
  ' over like this and he was like…'],
 ['ran me and put his hand in my hair and my hand forced his hand up and down my back and then he put his hand back in my hair and he held it.This is',
  'handed me and put his hands in my face and my hand forced his hand up and down my arm and then he put his hand back in my shirt and he squeezed it.This is',
  'held me and put his hand in my hand and my hand forced his hand up and down my back and then he put his hand back in my pocket and he kissed it.This is'],
 [' over, looking for his phone, and so, like

In [63]:
# ## combi_rerank 
# def combi_rerank(masked_sequence:torch.Tensor, 
#                  indices_in_mlm_tokens:tuple,
#                  predicted_token_ids:torch.Tensor,
#                  mlm_tokenizer:transformers.AutoTokenizer,
#                  config:dict) -> Tuple[List[str],List[float],List[int],List[np.array]]:
#     """params: 
#         masked_sequence should be token ids tokenized by MLM's tokenizer.
#         indices_in_mlm_tokens should be a result of running 
#         `indices_in_mlm_tokens = (
#                                     inputs.input_ids == mlm_tokenizer.mask_token_id
#                                     ).nonzero(as_tuple=True)`
#         predicted_token_ids should be a result of running
#         `predicted_token_ids = torch.topk(
#                                 logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
#                                 k=config['k_per_location'],
#                                 dim=-1,).indices`
                                
#     returns:
#         hypotheses: list of one best hypothesis(editing result) for each of original texts. length same as masked_sequence.shape[0]
#         best_weighted_loss: list of weighted loss for the best hypotheses.
#         best_allsat: list of indicator whether the best hypotheses satisfy cutoff (min_epsilons) for constraint energy score.
#         best_logging_loss: list of a tuple of fluency energy score and constraint energy score for each best hypothesis.
#     """

#     num_batches = masked_sequence.shape[0]
#     loss_weights = [1 - config['closs_weight'], config['closs_weight']]
    
#     hypotheses = []
#     best_weighted_loss = []
#     best_allsat = []
#     best_logging_loss = []
    
#     for i in tqdm(range(num_batches)):
        
#         l = (indices_in_mlm_tokens[0] == i).sum().item()
#         tok_cand_combos = list(product(range(config['k_per_location']),repeat=l))
        
#         tmp_hypotheses = masked_sequence[i,:].repeat((config['k_per_location']**l,1))
#         tmp_hypotheses[:, indices_in_mlm_tokens[1][indices_in_mlm_tokens[0] == i]] = \
#             predicted_token_ids[indices_in_mlm_tokens[0] == i, tok_cand_combos]
            
#         tmp_dec_seq = mlm_tokenizer.batch_decode(
#                     tmp_hypotheses, skip_special_tokens=True
#             )
        
#         curr_loss = torch.zeros(len(tmp_dec_seq)).to(config['device'])
#         logging_loss = torch.zeros((len(tmp_dec_seq),2)).to(config['device'])
#         data_loader = DataLoader(CustomDataset(tmp_dec_seq),batch_size=64)

#         for lossid, lossname in enumerate(config["losses"]):
#             lossvalues=[]
#             with torch.no_grad():
#                 for batch in data_loader:
#                     lossvalue = lossfns[lossid].compute_gold_loss(
#                         source_text, batch,
#                         label_id=config['target_label_ids'][lossid],
#                     )
#                     lossvalues.append(lossvalue)
#                     torch.cuda.empty_cache()
#             lossvalue = torch.cat(lossvalues,dim=0)
#             curr_loss += loss_weights[lossid] * lossvalue
#             logging_loss[:, lossid] = lossvalue.clone()
            
#         allsat_ix = torch.where(logging_loss[:,1]> -np.log(config["min_epsilons"][0]))[0].squeeze(0)
#         if (allsat_ix.shape[0] > 0) and (config['selection_criteria'] == "allsat_primary"):
#             best_ix = allsat_ix[curr_loss[allsat_ix].argmin()]
#         else: ## in case config['selection_criteria'] == "weighted_sum" or allsat is all False
#             best_ix = torch.argmin(curr_loss)

#         hypotheses.append(tmp_dec_seq[best_ix])
#         best_weighted_loss.append(curr_loss[best_ix].item())
#         best_allsat.append(1 if best_ix in allsat_ix else 0)
#         best_logging_loss.append(logging_loss[best_ix].cpu().numpy())
        
#         del curr_loss, logging_loss
#         torch.cuda.empty_cache()
#     return hypotheses, best_weighted_loss, best_allsat, best_logging_loss

In [60]:
# %%timeit
# my_hypotheses = combi_rerank(masked_sequence, 
#                             indices_in_mlm_tokens,
#                             predicted_token_ids.indices,
#                             mlm_tokenizer,
#                             config)

100%|██████████| 10/10 [00:50<00:00,  5.05s/it]
100%|██████████| 10/10 [00:51<00:00,  5.12s/it]
100%|██████████| 10/10 [00:51<00:00,  5.14s/it]
100%|██████████| 10/10 [00:50<00:00,  5.10s/it]
100%|██████████| 10/10 [00:50<00:00,  5.08s/it]
100%|██████████| 10/10 [00:50<00:00,  5.05s/it]
100%|██████████| 10/10 [00:50<00:00,  5.08s/it]
100%|██████████| 10/10 [00:50<00:00,  5.06s/it]

50.9 s ± 310 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)





In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initializing the model and tokenizer for it
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
inputs = tokenizer(["I'm not going to"], return_tensors="pt")

# This shows a normal generate without any specific parameters
summary_ids = model.generate(**inputs)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])

# This generates a penalty for repeated tokens
penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I'm not going to be able to do that. I'm going to be able to do that
I'm not going to be able to do that. I'll just have to go out and play


In [8]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initializing the model and tokenizer for it
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
inputs = tokenizer(["I'm not going to"], return_tensors="pt")

# This shows a normal generate without any specific parameters
summary_ids = model.generate(**inputs)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])


config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I'm not going to be able to do that. I'm going to be able to do that


In [None]:

# This generates a penalty for repeated tokens
penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])

In [4]:

# Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)


def get_tokens_as_tuple(word):
    return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])

print(get_tokens_as_tuple("Trump"))


(1301, 28628, 18435, 2159)


In [7]:
import transformers
transformers.__version__

'4.30.2'

In [5]:
# If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])


ValueError: The following `model_kwargs` are not used by the model: ['sequence_bias'] (note: typos in the generate arguments will also show up in this list)

In [1]:

biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])

# We can also add a positive bias to nudge the model towards specific tokens or continuations
sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])

NameError: name 'model' is not defined