In [2]:
# -*- coding: utf-8 -*-
import argparse
import os
import string
import sys
os.chdir('/data/hyeryung/mucoco')
from itertools import repeat
import torch.multiprocessing as mp
from typing import List
from itertools import permutations,product
import math

import numpy as np
import pandas as pd
import torch
import transformers
# from datasets import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader,Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForMaskedLM, AutoTokenizer, AutoConfig

from new_module.utils.robertacustom import RobertaCustomForSequenceClassification
from new_module.locate.new_locate_utils import *
import new_module.losses as lossbuilder

import logging
import os
from typing import List,Tuple
from tqdm import tqdm

import torch
import torch.nn.functional as F
import transformers

In [3]:
## setup for prototyping
logging.basicConfig(level=logging.DEBUG, format="%(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("LOGGING_LEVEL", logging.DEBUG))

config={#'model_paths':['gpt2-large','/shared/s3/lab07/hyeryung/loc_edit/roberta-base-jigsaw-toxicity-classifier-energy-training/step_600_best_checkpoint'],
        # 'tokenizer_paths':['gpt2-large','/shared/s3/lab07/hyeryung/loc_edit/roberta-base-jigsaw-toxicity-classifier-energy-training/step_600_best_checkpoint'],
        'model_paths':['gpt2-large','/data/hyeryung/loc_edit/models/roberta-base-jigsaw-toxicity-classifier-energy-training/step_1000_best_checkpoint'],
        'tokenizer_paths':['gpt2-large','/data/hyeryung/loc_edit/models/roberta-base-jigsaw-toxicity-classifier-energy-training/step_1000_best_checkpoint'],
        'model_types': ["AutoModelForCausalLM", "AutoModelForSequenceClassification"],
        'cache_dir': "/data/hyeryung/hf_cache",
        'target_type': "embeds",
        'method': "mlm-beamsearch-v0",
       'losses': ["gpt2", "classification_no_prefix_logprobloss"],
       'target_label_ids': [0,0] ,
       'build_loss_dict': {"coeff_steps": 200, "coeff_pattern": "constant", "loss_type": "xentropy", "length_normalize": False, "AR_temperature": 1.0, "AR_top_k": 0, "AR_top_p": 0.96, "max_output_length": 20},
       'min_epsilons': [0.75],
       'source_data': 'new_module/data/toxicity-avoidance/testset_gpt2_2500.jsonl',
       'locate_unit': 'token',
       'locate_method': 'grad_norm',
       'device': 'cuda:7',
       'k_per_location': 3,
       'closs_weight': 0.9,
       'beam_size': 3,
       'selection_criteria':"weighted_sum"}

class dummyArgs:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

build_loss_args = dummyArgs(**config["build_loss_dict"])

In [4]:
name2tokenizer = {}
name2model = {}
name2config = {}
loss2tokenizer = {}
embed_luts = []
primary_model = None
for i, model_path in enumerate(config["model_paths"]):
    if (
        model_path not in name2model
    ):  # making sure we are not loading the model twice in case some constraints use the same model.
        try:
            name2tokenizer[model_path] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=True,
            )
        except:
            name2tokenizer[model_path] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=False,
            )

        name2config[model_path] = AutoConfig.from_pretrained(
            model_path, cache_dir=config["cache_dir"]
        )

        if "Custom" in config["model_types"][i]:
            name2model[model_path] = lossbuilder.ModelWrapper(
                getattr(utils, config["model_types"][i]).from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        else:
            name2model[model_path] = lossbuilder.ModelWrapper(
                getattr(transformers, config["model_types"][i]).from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        name2model[model_path].eval()
        name2model[model_path].to(config['device'])

    input_embeds = name2model[model_path].get_input_embeddings()
    if isinstance(input_embeds, torch.nn.Sequential):
        input_embeds = input_embeds[0]
    embed_luts.append(input_embeds)

    if config["target_type"] == "embeds":
        embed_luts[-1].requires_grad = False

    if i == 0:
        primary_model = name2model[model_path]

mlm_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
mlm = None if config["method"] == "mlm-beamsearch-v2" else AutoModelForMaskedLM.from_pretrained("roberta-base").to(config['device'])  

lossfns = []
for i, loss in enumerate(config["losses"]):
    lossfns.append(
        lossbuilder.build_loss(
            loss,
            name2model[config["model_paths"][i]],
            name2tokenizer[config["model_paths"][i]],
            build_loss_args,
        )
    )
    lossfns[i].tokenizer.add_special_tokens({"mask_token": mlm_tokenizer.mask_token})
    loss2tokenizer[loss] = lossfns[i].tokenizer

primary_tokenizer = name2tokenizer['gpt2-large']
secondary_tokenizer = list(name2tokenizer.values())[1]

Starting new HTTPS connection (1): huggingface.co:443
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/generation_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /roberta-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0


In [5]:
def unravel(outputs_df):
    outputs_df=outputs_df.explode('generations',ignore_index=True)
    outputs_df['prompt']=outputs_df['prompt'].apply(lambda x: x['text'])
    outputs_df['generations']=outputs_df['generations'].apply(lambda x: x['text'] if isinstance(x, dict) else x)
    outputs_df = outputs_df.dropna().reset_index(drop=True)
    return outputs_df
toxic_gen=pd.read_json('new_module/data/toxicity-avoidance/dev_set.jsonl',lines=True)
toxic_gen=unravel(toxic_gen)
source_text = '“I went to use the bathroom, and I went in and Lance was bent'
prediction = toxic_gen.loc[toxic_gen['prompt']==source_text,'generations'].tolist()
ckpt_path = config['model_paths'][1]
model = name2model[ckpt_path] #AutoModelForSequenceClassification.from_pretrained(ckpt_path)
tokenizer = name2tokenizer[ckpt_path] #AutoTokenizer.from_pretrained(ckpt_path)

model = model.to(config['device'])
loc_machine=LocateMachine(model,tokenizer)
masked_text = loc_machine.locate_main(prediction, "grad_norm", max_num_tokens = 6, unit=config['locate_unit'], num_layer=10, label_id=0)

In [6]:
## replace tokens at the indices with mask tokens
inputs = mlm_tokenizer(
    masked_text, return_tensors="pt", padding=True, truncation=True
)
inputs = inputs.to(config['device']) 
masked_sequence=inputs['input_ids']

## make predictions for the masked indices
with torch.no_grad():
    logits = mlm(**inputs).logits

special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
logits[:, :, special_token_ids] = -float("inf")

indices_in_mlm_tokens = (
    inputs.input_ids == mlm_tokenizer.mask_token_id
).nonzero(as_tuple=True)

## get top k tokens for each index
predicted_token_ids = torch.topk(
    logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
    k=config['k_per_location'],
    dim=-1,
)
# print(f"predicted_token_ids: {predicted_token_ids}")
# print(f"mlm_tokenizer.batch_decode(predicted_token_ids.indices): {mlm_tokenizer.batch_decode(predicted_token_ids.indices)}")


In [10]:
def beam_rerank(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:tuple,
                    # predicted_token_ids:torch.return_types.topk,
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict):
    """params: 
    source_text(prompt) should be text. 
    masked_sequence should be token ids tokenized by MLM's tokenizer.
    indices_in_mlm_tokens should be a result of running 
    `indices_in_mlm_tokens = (
                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                ).nonzero(as_tuple=True)`
    predicted_token_ids should be a result of running
    `predicted_token_ids = torch.topk(
                            logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                            k=config['k_per_location'],
                            dim=-1,).indices`
    """
    
    hypotheses = masked_sequence[:, None, :].repeat((1,config['beam_size'],1))
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))

    # for i in tqdm(edit_indices,total=len(edit_indices)):
    for i in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==i]

        tmp_hypotheses = hypotheses[batch_ids_to_edit].detach().clone()
        tmp_hypotheses=tmp_hypotheses.repeat((1,config['k_per_location'],1))

        # candidates = predicted_token_ids.indices[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
        candidates = predicted_token_ids[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
        candidates = candidates[:, :, None].repeat((1,1, config['beam_size'])).reshape(candidates.shape[0], -1,1)

        tmp_hypotheses = torch.cat((tmp_hypotheses[:, :, :i], candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]
        tmp_hypotheses = tmp_hypotheses.reshape(-1, tmp_hypotheses.shape[-1])
        
        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        for lossid, lossname in enumerate(config["losses"]):
            if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
                continue
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                    label_id=config['target_label_ids'][lossid],
                )
            torch.cuda.empty_cache()
            curr_loss += loss_weights[lossid] * lossvalue
            
        curr_loss = torch.stack(torch.split(curr_loss, config['beam_size'] * config['k_per_location'], dim=0),dim=0)
        top_beams=torch.topk(curr_loss, k=(config['beam_size']*2+1), dim=-1, largest=False).indices

        tmp_hypotheses = torch.split(tmp_hypotheses, config['beam_size'] * config['k_per_location'], dim=0) ## 아래의 작업을 더 간단히 할 수 있는 방법?
        tmp_hypotheses = torch.stack([x[top_beams[j]] for j, x in enumerate(tmp_hypotheses)],dim=0)
        tmp_hypotheses = torch.unique(tmp_hypotheses, dim=1)[:, :config['beam_size'], :]
        
        hypotheses[batch_ids_to_edit,:, i]=tmp_hypotheses[:, :, i]
            
    return [mlm_tokenizer.batch_decode(hypotheses[j], skip_special_tokens=True) for j in range(hypotheses.shape[0])]

In [7]:
## get_combi_hypotheses 
def get_combi_hypotheses(masked_sequence:torch.Tensor, 
                 indices_in_mlm_tokens:tuple,
                 predicted_token_ids:torch.Tensor,
                 mlm_tokenizer:transformers.AutoTokenizer,
                 config:dict) -> List[str]:
    """params: 
    masked_sequence should be token ids tokenized by MLM's tokenizer.
    indices_in_mlm_tokens should be a result of running 
    `indices_in_mlm_tokens = (
                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                ).nonzero(as_tuple=True)`
    predicted_token_ids should be a result of running
    `predicted_token_ids = torch.topk(
                            logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                            k=config['k_per_location'],
                            dim=-1,).indices`
    """

    k = config['k_per_location']
    hypotheses = []
    num_batches = masked_sequence.shape[0]
    for i in range(num_batches):
        
        l = (indices_in_mlm_tokens[0] == i).sum().item()
        tok_cand_combos = list(product(range(k),repeat=l))
        
        tmp_hypotheses = masked_sequence[i,:].repeat((k**l,1))
        tmp_hypotheses[:, indices_in_mlm_tokens[1][indices_in_mlm_tokens[0] == i]] = \
            predicted_token_ids[indices_in_mlm_tokens[0] == i, tok_cand_combos]
            
        tmp_dec_seq = mlm_tokenizer.batch_decode(
                    tmp_hypotheses, skip_special_tokens=True
            )
        hypotheses.append(tmp_dec_seq)
    return hypotheses

In [51]:
# %%timeit

# my_hypotheses = get_combi_hypotheses(masked_sequence, 
#                             indices_in_mlm_tokens,
#                             predicted_token_ids.indices,
#                             mlm_tokenizer,
#                             config)

# loss_weights = [1 - config['closs_weight'], config['closs_weight']]
# selection_criteria = "weighted_sum"
# hypotheses = deepcopy(my_hypotheses)
# best_ixes = []
# best_weighted_loss = []
# best_allsat = []
# best_logging_loss = []
# num_batches = masked_sequence.shape[0]
# for i in tqdm(range(num_batches)):
       
#     curr_loss = torch.zeros(len(hypotheses[i])).to(config['device'])
#     logging_loss = torch.zeros((len(hypotheses[i]),2)).to(config['device'])

#     hyp_data = CustomDataset(hypotheses[i])
#     data_loader = DataLoader(hyp_data,batch_size=64)

#     for lossid, lossname in enumerate(config["losses"]):
#         lossvalues=[]
#         with torch.no_grad():
#             for batch in data_loader:
#                 lossvalue = lossfns[lossid].compute_gold_loss(
#                     source_text, batch,
#                     label_id=config['target_label_ids'][lossid],
#                 )
#                 lossvalues.append(lossvalue)
#                 torch.cuda.empty_cache()
#         lossvalue = torch.cat(lossvalues,dim=0)
#         curr_loss += loss_weights[lossid] * lossvalue
#         logging_loss[:, lossid] = lossvalue.clone()
        
#     allsat_ix = torch.where(logging_loss[:,1]> -np.log(config["min_epsilons"][0]))[0].squeeze(0)
#     if (allsat_ix.shape[0] > 0) and (selection_criteria == "allsat_primary"):
#         best_ix = allsat_ix[curr_loss[allsat_ix].argmin()]
#     else: ## in case selection_criteria == "weighted_sum" or allsat is all False
#         best_ix = torch.argmin(curr_loss)

    
#     hypotheses[i]=hypotheses[i][best_ix]
#     best_weighted_loss.append(curr_loss[best_ix].item())
#     best_allsat.append(1 if best_ix in allsat_ix else 0)
#     best_logging_loss.append(logging_loss[best_ix].cpu().numpy())
    
#     del curr_loss, logging_loss
#     torch.cuda.empty_cache()

100%|██████████| 10/10 [00:50<00:00,  5.01s/it]
100%|██████████| 10/10 [00:50<00:00,  5.08s/it]
100%|██████████| 10/10 [00:52<00:00,  5.26s/it]
100%|██████████| 10/10 [00:52<00:00,  5.27s/it]
100%|██████████| 10/10 [00:51<00:00,  5.13s/it]
100%|██████████| 10/10 [00:52<00:00,  5.21s/it]
100%|██████████| 10/10 [00:51<00:00,  5.11s/it]
100%|██████████| 10/10 [00:51<00:00,  5.12s/it]

52.1 s ± 680 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)





In [21]:

def final_reranking(hypotheses:List[List[str]],
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict,
                    batch_size:int=64) -> Tuple[List[str],List[float],List[int],List[List[float]]]:
    """params: 
        hypotheses: list of hypotheses of editing results
        lossfns:
        config:
        batch_size:             
    returns:
        hypotheses: list of one best hypothesis(editing result) for each of original texts. length same as masked_sequence.shape[0]
        best_weighted_loss: list of weighted loss for the best hypotheses.
        best_allsat: list of indicator(1,0) whether the best hypotheses satisfy cutoff (min_epsilons) for constraint energy score.
        best_logging_loss: list of list of fluency energy score and constraint energy score for each best hypothesis.
    """
    
    
    class CustomDataset(Dataset):
        def __init__(self, hypotheses_data:List[str]):
            self.hypotheses_data = hypotheses_data
            
        def __len__(self):
            return len(self.hypotheses_data)

        def __getitem__(self, idx:int):
            return self.hypotheses_data[idx]
        
        def __getitems__(self, idx:List[int]):
            return [self.hypotheses_data[j] for j in idx]
    
    final_hypotheses = []
    best_weighted_loss = []
    best_allsat = []
    best_logging_loss = []
    
    loss_weights = [1 - config['closs_weight'], config['closs_weight']]
    
    for i in tqdm(range(len(hypotheses))):
        curr_loss = torch.zeros(len(hypotheses[i])).to(config['device'])
        logging_loss = torch.zeros((len(hypotheses[i]),2)).to(config['device'])
        data_loader = DataLoader(CustomDataset(hypotheses[i]),batch_size=batch_size)

        for lossid, lossname in enumerate(config["losses"]):
            lossvalues=[]
            with torch.no_grad():
                for batch in data_loader:
                    lossvalue = lossfns[lossid].compute_gold_loss(
                        source_text, batch,
                        label_id=config['target_label_ids'][lossid],
                    )
                    lossvalues.append(lossvalue)
                    torch.cuda.empty_cache()
            lossvalue = torch.cat(lossvalues,dim=0)
            curr_loss += loss_weights[lossid] * lossvalue
            logging_loss[:, lossid] = lossvalue.clone()
            
        allsat_ix = torch.where(logging_loss[:,1]> -math.log(config["min_epsilons"][0]))[0].squeeze(0)
        if (allsat_ix.shape[0] > 0) and (config['selection_criteria'] == "allsat_primary"):
            best_ix = allsat_ix[curr_loss[allsat_ix].argmin()]
        else: ## in case config['selection_criteria'] == "weighted_sum" or allsat is all False
            best_ix = torch.argmin(curr_loss)

        final_hypotheses.append(hypotheses[i][best_ix])
        best_weighted_loss.append(curr_loss[best_ix].item())
        best_allsat.append(1 if best_ix in allsat_ix else 0)
        best_logging_loss.append(logging_loss[best_ix].cpu().tolist())
    
        del curr_loss, logging_loss
        torch.cuda.empty_cache()
    return final_hypotheses, best_weighted_loss, best_allsat, best_logging_loss

In [None]:
# hypotheses = beam_rerank(source_text, 
#                         masked_sequence, 
#                         indices_in_mlm_tokens,
#                         predicted_token_ids.indices,
#                         mlm_tokenizer, 
#                         lossfns,
#                         config)

In [None]:
%%timeit
my_hypotheses = beam_rerank(source_text, 
                        masked_sequence, 
                        indices_in_mlm_tokens,
                        predicted_token_ids.indices,
                        mlm_tokenizer, 
                        lossfns,
                        config)
final_result = final_reranking(my_hypotheses,
                                lossfns,
                                config,
                                batch_size=64)

100%|██████████| 10/10 [00:00<00:00, 25.15it/s]
100%|██████████| 10/10 [00:00<00:00, 25.57it/s]
100%|██████████| 10/10 [00:00<00:00, 25.10it/s]
100%|██████████| 10/10 [00:00<00:00, 25.80it/s]
100%|██████████| 10/10 [00:00<00:00, 25.42it/s]
100%|██████████| 10/10 [00:00<00:00, 25.59it/s]
100%|██████████| 10/10 [00:00<00:00, 25.06it/s]
100%|██████████| 10/10 [00:00<00:00, 25.43it/s]

2.31 s ± 15.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)





In [26]:
%%timeit
my_hypotheses = get_combi_hypotheses(masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer,
                            config)
final_result = final_reranking(my_hypotheses,
                                lossfns,
                                config,
                                batch_size=64)

100%|██████████| 10/10 [00:18<00:00,  1.87s/it]
100%|██████████| 10/10 [00:18<00:00,  1.88s/it]
100%|██████████| 10/10 [00:18<00:00,  1.88s/it]
100%|██████████| 10/10 [00:18<00:00,  1.89s/it]
100%|██████████| 10/10 [00:18<00:00,  1.89s/it]
100%|██████████| 10/10 [00:18<00:00,  1.90s/it]
100%|██████████| 10/10 [00:18<00:00,  1.90s/it]
100%|██████████| 10/10 [00:18<00:00,  1.90s/it]

19 s ± 65.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)





In [23]:

def get_beam_hypotheses(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:Tuple[torch.Tensor],
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict) -> List[List[str]]:
    """
    A function to get hypotheses of beam size via editing beam search with reranking.
    Run this function if config['method'] == 'mlm-beamsearch-v0' or config['method'] == 'mlm-beamsearch-v1'
    If config['method'] == 'mlm-beamsearch-v1', rerank beam only with fluency energy.
    If config['method'] == 'mlm-beamsearch-v0', rerank beam with a weighted sum of fluency and constraint energy.
    
    #ToDo
    #Implement mlm-beamsearch-v0 with allsat-primary and compare 
    
    params: 
        source_text: a prompt text 
        masked_sequence: token ids of original generation text with located indices masked. tokenized by MLM's tokenizer.
        indices_in_mlm_tokens: a result of running 
                                    `indices_in_mlm_tokens = (
                                                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                                                ).nonzero(as_tuple=True)`
        predicted_token_ids: a result of running
                                    `predicted_token_ids = torch.topk(
                                                                logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                                                                k=config['k_per_location'],
                                                                dim=-1,).indices`
        mlm_tokenizer: tokenizer of MLM
        lossfns: a list of loss functions
        config: a dictionary of configurations
    
    returns:
        hypotheses: a list of a list of the beam number of hypotheses for each sample         
    """
    
    hypotheses = masked_sequence[:, None, :].repeat((1,config['beam_size'],1))
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))

    for i in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==i]

        tmp_hypotheses = hypotheses[batch_ids_to_edit].detach().clone()
        tmp_hypotheses=tmp_hypotheses.repeat((1,config['k_per_location'],1))

        candidates = predicted_token_ids[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
        candidates = candidates[:, :, None].repeat((1,1, config['beam_size'])).reshape(candidates.shape[0], -1,1)

        tmp_hypotheses = torch.cat((tmp_hypotheses[:, :, :i], candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]
        tmp_hypotheses = tmp_hypotheses.reshape(-1, tmp_hypotheses.shape[-1])
        
        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        for lossid, lossname in enumerate(config["losses"]):
            if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
                continue
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                    label_id=config['target_label_ids'][lossid],
                )
            torch.cuda.empty_cache()
            curr_loss += loss_weights[lossid] * lossvalue
            
        curr_loss = torch.stack(torch.split(curr_loss, config['beam_size'] * config['k_per_location'], dim=0),dim=0)
        top_beams=torch.topk(curr_loss, k=(config['beam_size']*2+1), dim=-1, largest=False).indices

        tmp_hypotheses = torch.split(tmp_hypotheses, config['beam_size'] * config['k_per_location'], dim=0) ## 아래의 작업을 더 간단히 할 수 있는 방법?
        tmp_hypotheses = torch.stack([x[top_beams[j]] for j, x in enumerate(tmp_hypotheses)],dim=0)
        tmp_hypotheses = torch.unique(tmp_hypotheses, dim=1)[:, :config['beam_size'], :]
        
        hypotheses[batch_ids_to_edit,:, i]=tmp_hypotheses[:, :, i]
            
    return [mlm_tokenizer.batch_decode(hypotheses[j], skip_special_tokens=True) for j in range(hypotheses.shape[0])]

In [None]:

def get_beam_hypotheses(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:Tuple[torch.Tensor],
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict) -> List[List[str]]:
    """
    A function to get hypotheses of beam size via editing beam search with reranking.
    Run this function if config['method'] == 'mlm-beamsearch-v0' or config['method'] == 'mlm-beamsearch-v1'
    If config['method'] == 'mlm-beamsearch-v1', rerank beam only with fluency energy.
    If config['method'] == 'mlm-beamsearch-v0', rerank beam with a weighted sum of fluency and constraint energy.
    
    #ToDo
    #Implement mlm-beamsearch-v0 with allsat-primary and compare 
    
    params: 
        source_text: a prompt text 
        masked_sequence: token ids of original generation text with located indices masked. tokenized by MLM's tokenizer.
        indices_in_mlm_tokens: a result of running 
                                    `indices_in_mlm_tokens = (
                                                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                                                ).nonzero(as_tuple=True)`
        predicted_token_ids: a result of running
                                    `predicted_token_ids = torch.topk(
                                                                logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                                                                k=config['k_per_location'],
                                                                dim=-1,).indices`
        mlm_tokenizer: tokenizer of MLM
        lossfns: a list of loss functions
        config: a dictionary of configurations
    
    returns:
        hypotheses: a list of a list of the beam number of hypotheses for each sample         
    """
    
    hypotheses = masked_sequence[:, None, :].repeat((1,config['beam_size'],1))
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))

    for i in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==i]

        tmp_hypotheses = hypotheses[batch_ids_to_edit].detach().clone()
        tmp_hypotheses=tmp_hypotheses.repeat((1,config['k_per_location'],1))

        candidates = predicted_token_ids[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
        candidates = candidates[:, :, None].repeat((1,1, config['beam_size'])).reshape(candidates.shape[0], -1,1)

        tmp_hypotheses = torch.cat((tmp_hypotheses[:, :, :i], candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]
        tmp_hypotheses = tmp_hypotheses.reshape(-1, tmp_hypotheses.shape[-1])
        
        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        for lossid, lossname in enumerate(config["losses"]):
            if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
                continue
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                    label_id=config['target_label_ids'][lossid],
                )
            torch.cuda.empty_cache()
            curr_loss += loss_weights[lossid] * lossvalue
            
        curr_loss = torch.stack(torch.split(curr_loss, config['beam_size'] * config['k_per_location'], dim=0),dim=0)
        top_beams=torch.topk(curr_loss, k=(config['beam_size']*2+1), dim=-1, largest=False).indices

        tmp_hypotheses = torch.split(tmp_hypotheses, config['beam_size'] * config['k_per_location'], dim=0) ## 아래의 작업을 더 간단히 할 수 있는 방법?
        tmp_hypotheses = torch.stack([x[top_beams[j]] for j, x in enumerate(tmp_hypotheses)],dim=0)
        tmp_hypotheses = torch.unique(tmp_hypotheses, dim=1)[:, :config['beam_size'], :]
        
        hypotheses[batch_ids_to_edit,:, i]=tmp_hypotheses[:, :, i]
            
    return [mlm_tokenizer.batch_decode(hypotheses[j], skip_special_tokens=True) for j in range(hypotheses.shape[0])]

In [29]:
torch.topk(torch.Tensor([float("inf"),float("inf"),float("inf")]), k=2, dim=-1, largest=False)

torch.return_types.topk(
values=tensor([inf, inf]),
indices=tensor([0, 1]))

In [25]:
my_hypotheses = get_beam_hypotheses(source_text, 
                        masked_sequence, 
                        indices_in_mlm_tokens,
                        predicted_token_ids.indices,
                        mlm_tokenizer, 
                        lossfns,
                        config)

In [26]:
my_hypotheses

[[" over me, the up. I could hear him starting to fight, but I didn't stop him. Thank God he had these clothes on me.",
  " over me was he him. I could hear him starting to cry, but I didn't stop him. Thank God he had these headphones on me.",
  " over me he I down. I could hear him starting to scream, but I didn't stop him. Thank God he had these handcuffs on me."],
 [' over like this and he was like.',
  ' over like this and he was like:',
  ' over like this and he was like…'],
 ['ran me and put his hand in my hair and my hand forced his hand up and down my back and then he put his hand back in my hair and he held it.This is',
  'handed me and put his hands in my face and my hand forced his hand up and down my arm and then he put his hand back in my shirt and he squeezed it.This is',
  'held me and put his hand in my hand and my hand forced his hand up and down my back and then he put his hand back in my pocket and he kissed it.This is'],
 [' over, looking for his phone, and so, like

In [63]:
# ## combi_rerank 
# def combi_rerank(masked_sequence:torch.Tensor, 
#                  indices_in_mlm_tokens:tuple,
#                  predicted_token_ids:torch.Tensor,
#                  mlm_tokenizer:transformers.AutoTokenizer,
#                  config:dict) -> Tuple[List[str],List[float],List[int],List[np.array]]:
#     """params: 
#         masked_sequence should be token ids tokenized by MLM's tokenizer.
#         indices_in_mlm_tokens should be a result of running 
#         `indices_in_mlm_tokens = (
#                                     inputs.input_ids == mlm_tokenizer.mask_token_id
#                                     ).nonzero(as_tuple=True)`
#         predicted_token_ids should be a result of running
#         `predicted_token_ids = torch.topk(
#                                 logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
#                                 k=config['k_per_location'],
#                                 dim=-1,).indices`
                                
#     returns:
#         hypotheses: list of one best hypothesis(editing result) for each of original texts. length same as masked_sequence.shape[0]
#         best_weighted_loss: list of weighted loss for the best hypotheses.
#         best_allsat: list of indicator whether the best hypotheses satisfy cutoff (min_epsilons) for constraint energy score.
#         best_logging_loss: list of a tuple of fluency energy score and constraint energy score for each best hypothesis.
#     """

#     num_batches = masked_sequence.shape[0]
#     loss_weights = [1 - config['closs_weight'], config['closs_weight']]
    
#     hypotheses = []
#     best_weighted_loss = []
#     best_allsat = []
#     best_logging_loss = []
    
#     for i in tqdm(range(num_batches)):
        
#         l = (indices_in_mlm_tokens[0] == i).sum().item()
#         tok_cand_combos = list(product(range(config['k_per_location']),repeat=l))
        
#         tmp_hypotheses = masked_sequence[i,:].repeat((config['k_per_location']**l,1))
#         tmp_hypotheses[:, indices_in_mlm_tokens[1][indices_in_mlm_tokens[0] == i]] = \
#             predicted_token_ids[indices_in_mlm_tokens[0] == i, tok_cand_combos]
            
#         tmp_dec_seq = mlm_tokenizer.batch_decode(
#                     tmp_hypotheses, skip_special_tokens=True
#             )
        
#         curr_loss = torch.zeros(len(tmp_dec_seq)).to(config['device'])
#         logging_loss = torch.zeros((len(tmp_dec_seq),2)).to(config['device'])
#         data_loader = DataLoader(CustomDataset(tmp_dec_seq),batch_size=64)

#         for lossid, lossname in enumerate(config["losses"]):
#             lossvalues=[]
#             with torch.no_grad():
#                 for batch in data_loader:
#                     lossvalue = lossfns[lossid].compute_gold_loss(
#                         source_text, batch,
#                         label_id=config['target_label_ids'][lossid],
#                     )
#                     lossvalues.append(lossvalue)
#                     torch.cuda.empty_cache()
#             lossvalue = torch.cat(lossvalues,dim=0)
#             curr_loss += loss_weights[lossid] * lossvalue
#             logging_loss[:, lossid] = lossvalue.clone()
            
#         allsat_ix = torch.where(logging_loss[:,1]> -np.log(config["min_epsilons"][0]))[0].squeeze(0)
#         if (allsat_ix.shape[0] > 0) and (config['selection_criteria'] == "allsat_primary"):
#             best_ix = allsat_ix[curr_loss[allsat_ix].argmin()]
#         else: ## in case config['selection_criteria'] == "weighted_sum" or allsat is all False
#             best_ix = torch.argmin(curr_loss)

#         hypotheses.append(tmp_dec_seq[best_ix])
#         best_weighted_loss.append(curr_loss[best_ix].item())
#         best_allsat.append(1 if best_ix in allsat_ix else 0)
#         best_logging_loss.append(logging_loss[best_ix].cpu().numpy())
        
#         del curr_loss, logging_loss
#         torch.cuda.empty_cache()
#     return hypotheses, best_weighted_loss, best_allsat, best_logging_loss

In [60]:
# %%timeit
# my_hypotheses = combi_rerank(masked_sequence, 
#                             indices_in_mlm_tokens,
#                             predicted_token_ids.indices,
#                             mlm_tokenizer,
#                             config)

100%|██████████| 10/10 [00:50<00:00,  5.05s/it]
100%|██████████| 10/10 [00:51<00:00,  5.12s/it]
100%|██████████| 10/10 [00:51<00:00,  5.14s/it]
100%|██████████| 10/10 [00:50<00:00,  5.10s/it]
100%|██████████| 10/10 [00:50<00:00,  5.08s/it]
100%|██████████| 10/10 [00:50<00:00,  5.05s/it]
100%|██████████| 10/10 [00:50<00:00,  5.08s/it]
100%|██████████| 10/10 [00:50<00:00,  5.06s/it]

50.9 s ± 310 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)





In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initializing the model and tokenizer for it
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
inputs = tokenizer(["I'm not going to"], return_tensors="pt")

# This shows a normal generate without any specific parameters
summary_ids = model.generate(**inputs)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])

# This generates a penalty for repeated tokens
penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I'm not going to be able to do that. I'm going to be able to do that
I'm not going to be able to do that. I'll just have to go out and play


In [8]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initializing the model and tokenizer for it
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
inputs = tokenizer(["I'm not going to"], return_tensors="pt")

# This shows a normal generate without any specific parameters
summary_ids = model.generate(**inputs)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])


config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I'm not going to be able to do that. I'm going to be able to do that


In [None]:

# This generates a penalty for repeated tokens
penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])

In [4]:

# Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)


def get_tokens_as_tuple(word):
    return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])

print(get_tokens_as_tuple("Trump"))


(1301, 28628, 18435, 2159)


In [7]:
import transformers
transformers.__version__

'4.30.2'

In [5]:
# If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])


ValueError: The following `model_kwargs` are not used by the model: ['sequence_bias'] (note: typos in the generate arguments will also show up in this list)

In [1]:

biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])

# We can also add a positive bias to nudge the model towards specific tokens or continuations
sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])

NameError: name 'model' is not defined

 # Check for discrepancy

In [1]:
#!/usr/bin/env python
# coding: utf-8

from itertools import chain
import math
import argparse
import json
import logging
import os
import time
os.chdir('/data/hyeryung/mucoco')
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

from copy import deepcopy
import new_module.losses as lossbuilder
import new_module.losses_old as lossbuilder_old
import wandb
from new_module.decode_utils import (
    beam_rerank_v0,
    beam_rerank_v1,
    beam_rerank_v2,
    combi_rerank,
)
# from new_module.new_decode_utils import get_beam_hypotheses, get_combi_hypotheses, final_reranking
from new_module.new_decode_utils import get_beam_hypotheses_v0, get_beam_hypotheses_v1, get_combi_hypotheses, final_reranking
from new_module.evaluate_wandb import evaluate_main
from new_module.locate.new_locate_utils import LocateMachine
from new_module.locate.locate_utils import locate_main
from new_module.utils.robertacustom import RobertaCustomForSequenceClassification

logging.basicConfig(level=logging.DEBUG, format="%(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("LOGGING_LEVEL", logging.DEBUG))
import joblib
config = joblib.load('config.pkl')
config['device'] = 'cuda:1'

In [2]:
from typing import Tuple, List

In [3]:
class dummyArgs:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

build_loss_args = dummyArgs(**config["build_loss_dict"])

## load data
if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
    source_dataset = [
        json.loads(l)[config["jsonl_primary_key"]][config["jsonl_secondary_key"]]
        for l in open(config["source_data"])
    ]
    generation_dataset = [
        json.loads(l)["generations"] for l in open(config["source_data"])
    ]
elif (config["task"] == "formality") or (config["task"] == "sentiment-lewis-compr"):
    with open(config["source_data"], "r") as f:
        generation_dataset = [line.rstrip('\n') for line in f.readlines()]
    source_dataset = ["" for l in generation_dataset]

## load tokenizer, models, define losses
name2tokenizer = {}
name2model = {}
name2config = {}
loss2tokenizer = {}
embed_luts = []

for i, model_path in enumerate(config["model_paths"]):
    if (
        model_path not in name2model
    ):  # making sure we are not loading the model twice in case some constraints use the same model.
        try:
            name2tokenizer[config["tokenizer_paths"][i]] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=True,
            )
        except:
            name2tokenizer[config["tokenizer_paths"][i]] = AutoTokenizer.from_pretrained(
                config["tokenizer_paths"][i],
                cache_dir=config["cache_dir"],
                use_fast=False,
            )

        name2config[model_path] = AutoConfig.from_pretrained(
            model_path, cache_dir=config["cache_dir"]
        )

        if config["model_types"][i] == "RobertaCustomForSequenceClassification":
            name2model[model_path] = lossbuilder.ModelWrapper(
                RobertaCustomForSequenceClassification.from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        else:
            name2model[model_path] = lossbuilder.ModelWrapper(
                getattr(transformers, config["model_types"][i]).from_pretrained(
                    model_path,
                    config=name2config[model_path],
                    cache_dir=config["cache_dir"],
                )
            )
        name2model[model_path].eval()
        name2model[model_path].to(config['device'])

    input_embeds = name2model[model_path].get_input_embeddings()
    if isinstance(input_embeds, torch.nn.Sequential):
        input_embeds = input_embeds[0]
    embed_luts.append(input_embeds)

    if config["target_type"] == "embeds":
        embed_luts[-1].requires_grad = False

mlm_tokenizer = AutoTokenizer.from_pretrained("roberta-base")
mlm = None if config["method"] == "mlm-beamsearch-v2" else AutoModelForMaskedLM.from_pretrained("roberta-base").to(config['device'])


Starting new HTTPS connection (1): huggingface.co:443


https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /gpt2-large/resolve/main/generation_config.json HTTP/1.1" 200 0


50265


https://huggingface.co:443 "HEAD /roberta-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
https://huggingface.co:443 "HEAD /roberta-base/resolve/main/config.json HTTP/1.1" 200 0


In [4]:

lossfns = []
for i, loss in enumerate(config["losses"]):
    lossfns.append(
        lossbuilder.build_loss(
            loss,
            name2model[config["model_paths"][i]],
            name2tokenizer[config["tokenizer_paths"][i]],
            build_loss_args,
        )
    )
    lossfns[i].tokenizer.add_special_tokens({"mask_token": mlm_tokenizer.mask_token})
    loss2tokenizer[loss] = lossfns[i].tokenizer


In [5]:
loss_weights = [1 - config['closs_weight'], config['closs_weight']]
# for text_id in range(len(source_dataset))[resume_idx:]:
text_id = 3
source_text = source_dataset[text_id]
if source_text == "":
    source_text = lossfns[0].tokenizer.bos_token

if (config["task"] == "toxicity") or (config["task"] == "sentiment"):
    AR_prediction_all = [x["text"] for x in generation_dataset[text_id]]
    # predicted_batches = [x["tokens"] for x in generation_dataset[text_id]]
    # predicted_batches = [
    #     torch.tensor([x], dtype=torch.long, device=config["device"])
    #     for x in predicted_batches
    # ]
    
elif (config["task"] == "formality") or (
    config["task"] == "sentiment-lewis-compr"
):
    AR_prediction_all = [generation_dataset[text_id]]

curr_num_samples = len(AR_prediction_all)

In [6]:

# define an object to locate problematic phrases
locator = LocateMachine(lossfns[1].model, lossfns[1].tokenizer)
running_text = best_text = deepcopy(AR_prediction_all)
masked_text = locator.locate_main(running_text, 
                        method = config['locate_method'], 
                        max_num_tokens = config['num_edit_token_per_step'], 
                        unit = config['locate_unit'], 
                        num_layer = -2, #penultimate
                        label_id = config['target_label_ids'][1])

## replace tokens at the indices with mask tokens
                
inputs = mlm_tokenizer(
    masked_text, return_tensors="pt", padding=True, truncation=True
)
inputs = inputs.to(config['device']) 
masked_sequence=inputs['input_ids']


## make predictions for the masked indices
with torch.no_grad():
    logits = mlm(**inputs).logits

special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
logits[:, :, special_token_ids] = -float("inf")


indices_in_mlm_tokens = (
    inputs.input_ids == mlm_tokenizer.mask_token_id
).nonzero(as_tuple=True)

## get top k tokens for each index
predicted_token_ids = torch.topk(
    logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
    k=config['k_per_location'],
    dim=-1,
)


In [6]:
hypotheses = get_beam_hypotheses(source_text, 
                            masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer, 
                            lossfns,
                            config)

old version

In [11]:
## replace tokens at the indices with mask tokens
inputs_old = []
indices_in_mlm_tokens_old = []
predicted_token_ids_old = []
for masked_text_ in masked_text:
    
    ## replace tokens at the indices with mask tokens
    inputs_old_ = mlm_tokenizer(
        masked_text_, return_tensors="pt"
    )
    inputs_old_ = inputs_old_.to(config['device'])
    ## make predictions for the masked indices
    with torch.no_grad():
        logits_old = mlm(**inputs_old_).logits
    indices_in_mlm_tokens_old_ = (
        inputs_old_.input_ids == mlm_tokenizer.mask_token_id
    )[0].nonzero(as_tuple=True)[0]
    # print(f"indices_in_mlm_tokens: {indices_in_mlm_tokens}")
    ## get top k tokens for each index
    inputs_old.append(inputs_old_)
    
    ## make logits for special tokens -inf.
    special_token_ids = mlm_tokenizer.convert_tokens_to_ids(mlm_tokenizer.all_special_tokens)
    logits_old[:, :, special_token_ids] = -np.inf
    
    predicted_token_ids_old_ = torch.topk(
        logits_old[0, indices_in_mlm_tokens_old_],
        k=config['k_per_location'],
        dim=-1,
    )
    indices_in_mlm_tokens_old.append(indices_in_mlm_tokens_old_)
    predicted_token_ids_old.append(predicted_token_ids_old_)

In [14]:

def beam_rerank_v0(source_text, ## text (too arbitrary?)
                    masked_sequence, ## in mlm tokenizer's tokens
                    indices_in_mlm_tokens,
                    predicted_token_ids,
                    mlm_tokenizer, 
                    lossfns,
                    config, 
                    beam_size = 5):
    
    hypotheses = [torch.LongTensor([]).to(config['device'])]
    L = masked_sequence.size(-1)

    for i in range(L):
        if masked_sequence[0, i] != mlm_tokenizer.mask_token_id:
            hypotheses = list(torch.cat([torch.stack(hypotheses,dim=0), 
                                        masked_sequence[:, i].unsqueeze(0).repeat((len(hypotheses),1)).to(config['device'])],dim=-1))
        else:
            num_hypotheses = len(hypotheses)
            hypotheses = torch.stack(hypotheses,dim=0).unsqueeze(0)
            hypotheses = hypotheses.repeat(config['k_per_location'], 1, 1)
            candidates = predicted_token_ids.indices[torch.where(indices_in_mlm_tokens == i)[0], :].to(config['device']).T.unsqueeze(1)
            candidates = candidates.repeat(1, num_hypotheses, 1)
            hypotheses_exp = torch.cat([hypotheses, candidates], dim=-1)
            hypotheses_exp = hypotheses_exp.view(-1, hypotheses_exp.shape[-1])
            hypotheses_exp = list(hypotheses_exp)

            losses = []
            loss_weights = [1 - config['closs_weight'], config['closs_weight']]
            for hyp in hypotheses_exp:
                curr_loss = 0.0
                for lossid, lossname in enumerate(config["losses"]):
                    with torch.no_grad():
                        lossvalue = lossfns[lossid].compute_gold_loss(
                            source_text, mlm_tokenizer.decode(hyp, skip_special_tokens=True),
                            label_id=config['target_label_ids'][lossid],
                        )
                    curr_loss += loss_weights[lossid] * lossvalue.item()
                losses.append(curr_loss)

            hypotheses = sorted(zip(hypotheses_exp, losses), key=lambda x: x[1])[:beam_size]
            hypotheses = [x[0] for x in hypotheses]
            
    return [mlm_tokenizer.decode(x, skip_special_tokens=True) for x in hypotheses]


In [15]:
lossfns_old = []
loss2tokenizer_old = {}
for i, loss in enumerate(config["losses"]):
    lossfns_old.append(
        lossbuilder_old.build_loss(
            loss,
            name2model[config["model_paths"][i]],
            name2tokenizer[config["tokenizer_paths"][i]],
            build_loss_args,
        )
    )
    lossfns_old[i].tokenizer.add_special_tokens({"mask_token": mlm_tokenizer.mask_token})
    loss2tokenizer_old[loss] = lossfns_old[i].tokenizer

In [16]:
hypotheses_old = []
for text_id in range(len(masked_text)):
    hypotheses_old_ = beam_rerank_v0(source_text,
                                inputs_old[text_id].input_ids,
                                indices_in_mlm_tokens_old[text_id],
                                predicted_token_ids_old[text_id],
                                mlm_tokenizer, 
                                lossfns_old,
                                config, 
                                beam_size = config['beam_size'])
    hypotheses_old.append(hypotheses_old_)

두 결과가 다르다. 보다 자세히 디버깅 필요.

In [None]:
hypotheses

In [None]:
hypotheses_old

디버깅

In [None]:

# def get_beam_hypotheses(source_text:str, 
#                     masked_sequence:torch.Tensor, 
#                     indices_in_mlm_tokens:Tuple[torch.Tensor],
#                     predicted_token_ids:torch.Tensor,
#                     mlm_tokenizer:transformers.AutoTokenizer, 
#                     lossfns:List[lossbuilder.BaseLoss],
#                     config:dict) -> List[List[str]]:
#     """
#     A function to get hypotheses of beam size via editing beam search with reranking.
#     Run this function if config['method'] == 'mlm-beamsearch-v0' or config['method'] == 'mlm-beamsearch-v1'
#     If config['method'] == 'mlm-beamsearch-v1', rerank beam only with fluency energy.
#     If config['method'] == 'mlm-beamsearch-v0', rerank beam with a weighted sum of fluency and constraint energy.
    
#     #ToDo
#     #Implement mlm-beamsearch-v0 with allsat-primary and compare 
    
#     params: 
#         source_text: a prompt text 
#         masked_sequence: token ids of original generation text with located indices masked. tokenized by MLM's tokenizer.
#         indices_in_mlm_tokens: a result of running 
#                                     `indices_in_mlm_tokens = (
#                                                                 inputs.input_ids == mlm_tokenizer.mask_token_id
#                                                                 ).nonzero(as_tuple=True)`
#         predicted_token_ids: a result of running
#                                     `predicted_token_ids = torch.topk(
#                                                                 logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
#                                                                 k=config['k_per_location'],
#                                                                 dim=-1,).indices`
#         mlm_tokenizer: tokenizer of MLM
#         lossfns: a list of loss functions
#         config: a dictionary of configurations
    
#     returns:
#         hypotheses: a list of a list of the beam number of hypotheses for each sample         
#     """
    
#     hypotheses = masked_sequence[:, None, :].repeat((1,config['beam_size'],1))
#     edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))

#     for i in edit_indices:
        
#         batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==i]

#         tmp_hypotheses = hypotheses[batch_ids_to_edit].detach().clone()
#         tmp_hypotheses=tmp_hypotheses.repeat((1,config['k_per_location'],1))

#         candidates = predicted_token_ids[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
#         candidates = candidates[:, :, None].repeat((1,1, config['beam_size'])).reshape(candidates.shape[0], -1,1)

#         tmp_hypotheses = torch.cat((tmp_hypotheses[:, :, :i], candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]
#         tmp_hypotheses = tmp_hypotheses.reshape(-1, tmp_hypotheses.shape[-1])
        
#         loss_weights = [1 - config['closs_weight'], config['closs_weight']]
#         curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
#         for lossid, lossname in enumerate(config["losses"]):
#             if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
#                 continue
#             with torch.no_grad():
#                 lossvalue = lossfns[lossid].compute_gold_loss(
#                     source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
#                     label_id=config['target_label_ids'][lossid],
#                 )
#             torch.cuda.empty_cache()
#             curr_loss += loss_weights[lossid] * lossvalue
            
#         curr_loss = torch.stack(torch.split(curr_loss, config['beam_size'] * config['k_per_location'], dim=0),dim=0)
#         top_beams=torch.topk(curr_loss, k=(config['beam_size']*(config['k_per_location']-1)+1), dim=-1, largest=False).indices

#         tmp_hypotheses = torch.split(tmp_hypotheses, config['beam_size'] * config['k_per_location'], dim=0) ## 아래의 작업을 더 간단히 할 수 있는 방법?
#         tmp_hypotheses = torch.stack([x[top_beams[j]] for j, x in enumerate(tmp_hypotheses)],dim=0)
#         tmp_hypotheses = torch.unique(tmp_hypotheses, dim=1)[:, :config['beam_size'], :]
        
#         hypotheses[batch_ids_to_edit,:, i]=tmp_hypotheses[:, :, i]
            
#     return [mlm_tokenizer.batch_decode(hypotheses[j], skip_special_tokens=True) for j in range(hypotheses.shape[0])]


In [16]:
## debugged version
def get_beam_hypotheses(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:Tuple[torch.Tensor],
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict) -> List[List[str]]:
    """
    A function to get hypotheses of beam size via editing beam search with reranking.
    Run this function if config['method'] == 'mlm-beamsearch-v0' or config['method'] == 'mlm-beamsearch-v1'
    If config['method'] == 'mlm-beamsearch-v1', rerank beam only with fluency energy.
    If config['method'] == 'mlm-beamsearch-v0', rerank beam with a weighted sum of fluency and constraint energy.
    
    #ToDo
    #Implement mlm-beamsearch-v0 with allsat-primary and compare 
    
    params: 
        source_text: a prompt text 
        masked_sequence: token ids of original generation text with located indices masked. tokenized by MLM's tokenizer.
        indices_in_mlm_tokens: a result of running 
                                    `indices_in_mlm_tokens = (
                                                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                                                ).nonzero(as_tuple=True)`
        predicted_token_ids: a result of running
                                    `predicted_token_ids = torch.topk(
                                                                logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                                                                k=config['k_per_location'],
                                                                dim=-1,).indices`
        mlm_tokenizer: tokenizer of MLM
        lossfns: a list of loss functions
        config: a dictionary of configurations
    
    returns:
        hypotheses: a list of a list of the beam number of hypotheses for each sample         
    """
    
    def repeat_interleave_unravel(arr,split_blocks):
        arr_ = torch.split(arr.T,1,dim=1)
        arr_ = [x.repeat(1,split_blocks[i]).reshape(-1,1) for i,x in enumerate(arr_)]
        arr_ = torch.cat(arr_,dim=0)
        return arr_
    
    hypotheses = list(torch.split(masked_sequence,1,dim=0)) ## [torch.tensor([[a],[b],[c]]), torch.tensor([[d]])]
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))
    for curr_edit_index in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==curr_edit_index].tolist()
        num_initial_hypotheses = [len(hypotheses[i]) for i in batch_ids_to_edit] ## keep track of initial hypotheses count e.g. [3, 1]
        tmp_hypotheses = [hypotheses[i].repeat((config['k_per_location'],1)) for i in batch_ids_to_edit] ## [torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c]]), torch.tensor([[d],[d],[d]])]
        num_initial_tmp_hypotheses = [len(x) for x in tmp_hypotheses]
        tmp_hypotheses = torch.cat(tmp_hypotheses,dim=0) ## torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c],[d],[d],[d]])
        

        new_func_candidates = predicted_token_ids[indices_in_mlm_tokens[1]==curr_edit_index] ## shape: (len(batch_ids_to_edit), k_per_location) e.g. [[x,y,z],[q,w,e]]
        new_func_candidates = repeat_interleave_unravel(new_func_candidates,num_initial_hypotheses) ## shape: (sum(num_initial_hypotheses), k_per_location) e.g. [[x],[x],[x],[y],[y],[y],[z],[z],[z],[q],[w],[e]]
        new_func_candidates = new_func_candidates.to(config['device'])
        

        tmp_hypotheses = torch.cat((tmp_hypotheses[ :, :curr_edit_index], new_func_candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], new_func_candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]

        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        for lossid, lossname in enumerate(config["losses"]):
            if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
                continue
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                    label_id=config['target_label_ids'][lossid],
                )
            torch.cuda.empty_cache()
            curr_loss += loss_weights[lossid] * lossvalue
        curr_loss = torch.split(curr_loss, num_initial_tmp_hypotheses, dim=0)
        top_beams = [torch.topk(x, k=config['beam_size'], dim=-1, largest=False).indices for x in curr_loss]

        tmp_hypotheses = torch.split(tmp_hypotheses, num_initial_tmp_hypotheses, dim=0)
        for jx, ix in enumerate(batch_ids_to_edit):

            hypotheses[ix] = torch.cat([tmp_hypotheses[jx][top_beams[jx]], masked_sequence[ix][curr_edit_index+1:].unsqueeze(0).repeat(config['beam_size'],1)], dim=-1)
            
    return [mlm_tokenizer.batch_decode(x, skip_special_tokens=True) for x in hypotheses]

In [None]:
## old version code에서 score 결과 및 중간에 나오는 텐서들 추출
sample_id = 0
inside_func_masked_sequence= inputs_old[sample_id].input_ids
inside_func_hypotheses = [torch.LongTensor([]).to(config['device'])]
L = inside_func_masked_sequence.size(-1)

for i in range(L):
    if inside_func_masked_sequence[0, i] != mlm_tokenizer.mask_token_id:
        inside_func_hypotheses = list(torch.cat([torch.stack(inside_func_hypotheses,dim=0), 
                                    inside_func_masked_sequence[:, i].unsqueeze(0).repeat((len(inside_func_hypotheses),1)).to(config['device'])],dim=-1))
    else:
        num_inside_func_hypotheses = len(inside_func_hypotheses)
        inside_func_hypotheses = torch.stack(inside_func_hypotheses,dim=0).unsqueeze(0)
        inside_func_hypotheses = inside_func_hypotheses.repeat(config['k_per_location'], 1, 1)
        inside_func_candidates = predicted_token_ids_old[sample_id].indices[torch.where(indices_in_mlm_tokens_old[sample_id] == i)[0], :].to(config['device']).T.unsqueeze(1)
        inside_func_candidates = inside_func_candidates.repeat(1, num_inside_func_hypotheses, 1)
        inside_func_hypotheses_exp = torch.cat([inside_func_hypotheses, inside_func_candidates], dim=-1)
        inside_func_hypotheses_exp = inside_func_hypotheses_exp.view(-1, inside_func_hypotheses_exp.shape[-1])
        inside_func_hypotheses_exp = list(inside_func_hypotheses_exp)

        inside_func_losses = []
        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        for hyp in inside_func_hypotheses_exp:
            curr_loss = 0.0
            for lossid, lossname in enumerate(config["losses"]):
                with torch.no_grad():
                    lossvalue = lossfns_old[lossid].compute_gold_loss(
                        source_text, mlm_tokenizer.decode(hyp, skip_special_tokens=True),
                        label_id=config['target_label_ids'][lossid],
                    )
                curr_loss += loss_weights[lossid] * lossvalue.item()
            inside_func_losses.append(curr_loss)

        inside_func_hypotheses = sorted(zip(inside_func_hypotheses_exp, inside_func_losses), key=lambda x: x[1])[:config['beam_size']]
        inside_func_hypotheses = [x[0] for x in inside_func_hypotheses]
        
        break

In [None]:
## 새로 짠 함수에서 처리 중간에 나오는 텐서 추출
hypotheses = masked_sequence[:, None, :].repeat((1,config['beam_size'],1))
edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))


In [None]:

for i in edit_indices:
    print(i)
    batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==i]
    print(batch_ids_to_edit)
    print('-'*30)
    tmp_hypotheses = hypotheses[batch_ids_to_edit].detach().clone()
    print(tmp_hypotheses)
    print('-'*30)
    tmp_hypotheses=tmp_hypotheses.repeat((1,config['k_per_location'],1))
    print(tmp_hypotheses)
    print('-'*30)
    print('-'*30)

    new_func_candidates = predicted_token_ids.indices[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
    print(new_func_candidates)
    print('-'*30)
    new_func_candidates = new_func_candidates[:, :, None].repeat((1,1, config['beam_size'])).reshape(new_func_candidates.shape[0], -1,1)
    print(new_func_candidates)
    print('-'*30)

    tmp_hypotheses = torch.cat((tmp_hypotheses[:, :, :i], new_func_candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], new_func_candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]
    print(tmp_hypotheses)
    print('-'*30)
    tmp_hypotheses = tmp_hypotheses.reshape(-1, tmp_hypotheses.shape[-1])
    print(tmp_hypotheses)
    print('-'*30)
    print('-'*30)

    loss_weights = [1 - config['closs_weight'], config['closs_weight']]
    curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
    for lossid, lossname in enumerate(config["losses"]):
        if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
            continue
        with torch.no_grad():
            lossvalue = lossfns[lossid].compute_gold_loss(
                source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                label_id=config['target_label_ids'][lossid],
            )
        torch.cuda.empty_cache()
        curr_loss += loss_weights[lossid] * lossvalue
    
    curr_loss_for_backup = curr_loss.detach().clone()
    curr_loss = torch.stack(torch.split(curr_loss, config['beam_size'] * config['k_per_location'], dim=0),dim=0)
    print(curr_loss)
    print('-'*30)    
    top_beams=torch.topk(curr_loss, k=(config['beam_size']*(config['k_per_location']-1)+1), dim=-1, largest=False).indices
    print(top_beams)
    print('-'*30)   
    print('-'*30)   

    tmp_hypotheses = torch.split(tmp_hypotheses, config['beam_size'] * config['k_per_location'], dim=0) ## 아래의 작업을 더 간단히 할 수 있는 방법?
    tmp_hypotheses_for_backup = deepcopy(tmp_hypotheses)
    print(tmp_hypotheses)
    print('-'*30)       
    
    tmp_hypotheses = torch.stack([x[top_beams[j]] for j, x in enumerate(tmp_hypotheses)],dim=0)
    tmp_hypotheses_for_backup_2 = tmp_hypotheses.detach().clone()
    print(tmp_hypotheses)
    print('-'*30)  
    
    tmp_hypotheses = torch.unique(tmp_hypotheses, dim=1)[:, :config['beam_size'], :]
    print(tmp_hypotheses)
    print('-'*30)  
    
    hypotheses[batch_ids_to_edit,:, i]=tmp_hypotheses[:, :, i]
    
    break

In [None]:
# torch.unique가 문제였다.
torch.unique(torch.Tensor([[7,1,3,3,4,5],[2,2,2,3,3,3]]),dim=-1,sorted=False)

문제 해결 할 수 있도록 재코딩

In [None]:
# hypotheses = list(torch.split(masked_sequence,1,dim=0))
# edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))
# i = edit_indices[0]
# # for i in edit_indices:
# batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==i]

# tmp_hypotheses = [hypotheses[i] for i in batch_ids_to_edit.tolist()]
# tmp_hypotheses = torch.cat(tmp_hypotheses,dim=0).repeat((config['k_per_location'],1))

# num_initial_hypotheses = [len(hypotheses[i]) for i in batch_ids_to_edit.tolist()]
# print(tmp_hypotheses.shape)
# print(num_initial_hypotheses)
# new_func_candidates = predicted_token_ids.indices[(indices_in_mlm_tokens[1]==i).nonzero().squeeze(-1),:]
# new_func_candidates = new_func_candidates.to(config['device'])
# # %%timeit
# # new_func_candidates_ = np.repeat(new_func_candidates.cpu(), num_initial_hypotheses, axis=0)
# # %%timeit

# # def func(arr,split_blocks):
# #     arr_ = torch.split(arr,1,dim=0)
# #     arr_ = [x.repeat(split_blocks[i],1) for i,x in enumerate(arr_)]
# #     arr_ = torch.cat(arr_,dim=0)
# #     return arr_
    
# # # new_func_candidates_ = torch.split(new_func_candidates,1,dim=0)
# # # new_func_candidates_ = [x.repeat(num_initial_hypotheses[i],1) for i,x in enumerate(new_func_candidates_)]
# # # new_func_candidates_ = torch.cat(new_func_candidates_,dim=0)
# # new_func_candidates_ = func(new_func_candidates,num_initial_hypotheses)

In [289]:
# def repeat_interleave_column_wise(arr,split_blocks):
#     arr_ = torch.split(arr,1,dim=1)
#     arr_ = [x.repeat(1,split_blocks[i]) for i,x in enumerate(arr_)]
#     arr_ = torch.cat(arr_,dim=1)
#     return arr_
def repeat_interleave_unravel(arr,split_blocks):
    arr_ = torch.split(arr.T,1,dim=1)
    arr_ = [x.repeat(1,split_blocks[i]).reshape(-1,1) for i,x in enumerate(arr_)]
    arr_ = torch.cat(arr_,dim=0)
    return arr_

In [290]:
hypotheses = list(torch.split(masked_sequence,1,dim=0)) ## [torch.tensor([[a],[b],[c]]), torch.tensor([[d]])]
edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))
for curr_edit_index in edit_indices:
    
    batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==curr_edit_index].tolist()
    num_initial_hypotheses = [len(hypotheses[i]) for i in batch_ids_to_edit] ## keep track of initial hypotheses count e.g. [3, 1]
    tmp_hypotheses = [hypotheses[i].repeat((config['k_per_location'],1)) for i in batch_ids_to_edit] ## [torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c]]), torch.tensor([[d],[d],[d]])]
    num_initial_tmp_hypotheses = [len(x) for x in tmp_hypotheses]
    tmp_hypotheses = torch.cat(tmp_hypotheses,dim=0) ## torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c],[d],[d],[d]])
    

    new_func_candidates = predicted_token_ids.indices[indices_in_mlm_tokens[1]==curr_edit_index] ## shape: (len(batch_ids_to_edit), k_per_location) e.g. [[x,y,z],[q,w,e]]
    new_func_candidates = repeat_interleave_unravel(new_func_candidates,num_initial_hypotheses) ## shape: (sum(num_initial_hypotheses), k_per_location) e.g. [[x],[x],[x],[y],[y],[y],[z],[z],[z],[q],[w],[e]]
    new_func_candidates = new_func_candidates.to(config['device'])
    

    tmp_hypotheses = torch.cat((tmp_hypotheses[ :, :curr_edit_index], new_func_candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], new_func_candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]

    loss_weights = [1 - config['closs_weight'], config['closs_weight']]
    curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
    for lossid, lossname in enumerate(config["losses"]):
        if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
            continue
        with torch.no_grad():
            lossvalue = lossfns[lossid].compute_gold_loss(
                source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                label_id=config['target_label_ids'][lossid],
            )
        torch.cuda.empty_cache()
        curr_loss += loss_weights[lossid] * lossvalue
    curr_loss = torch.split(curr_loss, num_initial_tmp_hypotheses, dim=0)
    top_beams = [torch.topk(x, k=config['beam_size'], dim=-1, largest=False).indices for x in curr_loss]

    tmp_hypotheses = torch.split(tmp_hypotheses, num_initial_tmp_hypotheses, dim=0)
    for jx, ix in enumerate(batch_ids_to_edit):

        hypotheses[ix] = torch.cat([tmp_hypotheses[jx][top_beams[jx]], masked_sequence[ix][curr_edit_index+1:].unsqueeze(0).repeat(config['beam_size'],1)], dim=-1)

In [293]:
hypotheses_old

[[' participating in a community project. AFFA is a, and a, for a.)',
  ' participating in a community project. AFFA is a, and a, for a.,',
  ' participating in a community project. AFFA is a, and a, for a..',
  ' participating in a community project. AFFA is a, and the, for the.,',
  ' participating in a community project. AFFA is a, and a, for a.]'],
 [' the fact that the word “s’ hates on ʳCriminally, the §$$$ by making \u200ferr\u200ft all girls,‼',
  ' the fact that the word “s’ hates on ʳCriminally, the §$$$ by making \u200ferr\u200f, all girls,‼',
  ' the fact that the word “s’ hates on ʳCriminally, in §$$$ by making \u200ferr\u200f, all girls,‼',
  ' the fact that the word “s’ hates on ʳCriminally, and §$$$ by making \u200ferr\u200ft all girls,‼',
  ' the fact that the word “s’ hates on ʳCriminally, and §$$$ by making \u200ferr\u200f, all girls,‼']]

In [294]:
[mlm_tokenizer.batch_decode(x,skip_special_tokens=True) for x in hypotheses]

True

In [295]:
[mlm_tokenizer.batch_decode(x,skip_special_tokens=True) for x in hypotheses]==hypotheses_old

True

In [None]:
## 여전히 결과가 다르다.  -> 아래에서 문제 해결 완료

In [191]:
## old version code에서 score 결과 및 중간에 나오는 텐서들 추출
sample_id = 0
inside_func_masked_sequence= inputs_old[sample_id].input_ids
inside_func_hypotheses = [torch.LongTensor([]).to(config['device'])]
L = inside_func_masked_sequence.size(-1)


In [194]:

for i in range(6,L):
    if inside_func_masked_sequence[0, i] != mlm_tokenizer.mask_token_id:
        inside_func_hypotheses = list(torch.cat([torch.stack(inside_func_hypotheses,dim=0), 
                                    inside_func_masked_sequence[:, i].unsqueeze(0).repeat((len(inside_func_hypotheses),1)).to(config['device'])],dim=-1))
    else:
        num_inside_func_hypotheses = len(inside_func_hypotheses)
        inside_func_hypotheses = torch.stack(inside_func_hypotheses,dim=0).unsqueeze(0)
        inside_func_hypotheses = inside_func_hypotheses.repeat(config['k_per_location'], 1, 1)
        inside_func_candidates = predicted_token_ids_old[sample_id].indices[torch.where(indices_in_mlm_tokens_old[sample_id] == i)[0], :].to(config['device']).T.unsqueeze(1)
        inside_func_candidates = inside_func_candidates.repeat(1, num_inside_func_hypotheses, 1)
        inside_func_hypotheses_exp = torch.cat([inside_func_hypotheses, inside_func_candidates], dim=-1)
        inside_func_hypotheses_exp = inside_func_hypotheses_exp.view(-1, inside_func_hypotheses_exp.shape[-1])
        inside_func_hypotheses_exp = list(inside_func_hypotheses_exp)

        inside_func_losses = []
        inside_func_logging_losses = []
        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        for hyp in inside_func_hypotheses_exp:
            inside_func_logging_losses_ = []
            curr_loss = 0.0
            for lossid, lossname in enumerate(config["losses"]):
                with torch.no_grad():
                    lossvalue = lossfns_old[lossid].compute_gold_loss(
                        source_text, mlm_tokenizer.decode(hyp, skip_special_tokens=True),
                        label_id=config['target_label_ids'][lossid],
                    )
                curr_loss += loss_weights[lossid] * lossvalue.item()
                inside_func_logging_losses_.append(lossvalue.item())
            inside_func_losses.append(curr_loss)
            inside_func_logging_losses.append(inside_func_logging_losses_)

        inside_func_hypotheses = sorted(zip(inside_func_hypotheses_exp, inside_func_losses), key=lambda x: x[1])[:config['beam_size']]
        inside_func_hypotheses = [x[0] for x in inside_func_hypotheses]
        break

In [203]:
logging_loss.tolist()

[[25.619842529296875, 0.006565547082573175],
 [27.66191291809082, 0.016244199126958847],
 [27.646846771240234, 0.0031171089503914118],
 [23.519668579101562, 0.0008852138998918235],
 [29.10457992553711, 0.004278791137039661],
 [20.562965393066406, 0.007706434931606054],
 [26.684307098388672, 0.014183899387717247],
 [27.279190063476562, 0.003786657238379121],
 [22.170913696289062, 0.0009748950251378119],
 [29.198673248291016, 0.003246515290811658],
 [24.92827606201172, 0.006835410837084055],
 [27.849733352661133, 0.04606309533119202],
 [21.119571685791016, 0.003120079869404435],
 [25.584442138671875, 0.0011926926672458649],
 [24.726680755615234, 0.0033429949544370174],
 [20.22182273864746, 0.007419057190418243],
 [22.00562286376953, 0.015473551116883755],
 [25.013351440429688, 0.0028388698119670153],
 [19.766510009765625, 0.0011063652345910668],
 [28.44176483154297, 0.0026339145842939615],
 [24.887611389160156, 0.004260985646396875],
 [26.682390213012695, 0.008359324187040329],
 [25.9007

In [196]:
inside_func_logging_losses # 25.619842529296875 # 25.619840621948242

[[25.619840621948242, 0.006565547082573175],
 [27.661916732788086, 0.016244199126958847],
 [27.6468505859375, 0.0031171089503914118],
 [23.519678115844727, 0.0008852138998918235],
 [29.104576110839844, 0.004278909880667925],
 [24.928274154663086, 0.006835410837084055],
 [26.68430519104004, 0.014183899387717247],
 [27.279197692871094, 0.003786657238379121],
 [22.17091941833496, 0.0009748950251378119],
 [29.198673248291016, 0.003246515290811658],
 [20.221820831298828, 0.007419057190418243],
 [22.0056209564209, 0.015473551116883755],
 [21.119569778442383, 0.003120079869404435],
 [25.584447860717773, 0.0011926926672458649],
 [24.726682662963867, 0.0033429949544370174],
 [24.88760757446289, 0.004260985646396875],
 [26.68239402770996, 0.008359324187040329],
 [25.900724411010742, 0.0021347845904529095],
 [19.766511917114258, 0.0011063652345910668],
 [28.441770553588867, 0.0026339145842939615],
 [23.970430374145508, 0.008097084239125252],
 [24.929170608520508, 0.023823320865631104],
 [22.91668

In [210]:
indiv_decode_result = [mlm_tokenizer.decode(x, skip_special_tokens=True) for x in inside_func_hypotheses_exp]

In [286]:
batch_decode_result = mlm_tokenizer.batch_decode(tmp_hypotheses[0],skip_special_tokens=True)

In [287]:
len(batch_decode_result), len(indiv_decode_result)

(50, 50)

In [288]:
for a,b in zip(sorted(indiv_decode_result),sorted(batch_decode_result)):
    if a != b:
        print(a)
        print(b)
        print('-'*100)

In [279]:
def repeat_interleave_unravel(arr,split_blocks):
    arr_ = torch.split(arr.T,1,dim=1)
    arr_ = [x.repeat(1,split_blocks[i]).reshape(-1,1) for i,x in enumerate(arr_)]
    arr_ = torch.cat(arr_,dim=0)
    return arr_

In [282]:
hypotheses = list(torch.split(masked_sequence,1,dim=0)) ## [torch.tensor([[a],[b],[c]]), torch.tensor([[d]])]
edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))

In [285]:
for curr_edit_index in edit_indices[2:]:
    
    batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==curr_edit_index].tolist()
    num_initial_hypotheses = [len(hypotheses[i]) for i in batch_ids_to_edit] ## keep track of initial hypotheses count e.g. [3, 1]
    tmp_hypotheses = [hypotheses[i].repeat((config['k_per_location'],1)) for i in batch_ids_to_edit] ## [torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c]]), torch.tensor([[d],[d],[d]])]
    num_initial_tmp_hypotheses = [len(x) for x in tmp_hypotheses]
    tmp_hypotheses = torch.cat(tmp_hypotheses,dim=0) ## torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c],[d],[d],[d]])
    
    new_func_candidates = predicted_token_ids.indices[indices_in_mlm_tokens[1]==curr_edit_index] ## shape: (len(batch_ids_to_edit), k_per_location) e.g. [[x,y,z],[q,w,e]]
    new_func_candidates = repeat_interleave_unravel(new_func_candidates,num_initial_hypotheses) ## shape: (sum(num_initial_hypotheses), k_per_location) e.g. [[x],[x],[x],[y],[y],[y],[z],[z],[z],[q],[w],[e]]
    new_func_candidates = new_func_candidates.to(config['device'])
    

    tmp_hypotheses = torch.cat((tmp_hypotheses[ :, :curr_edit_index], new_func_candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], new_func_candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]

    loss_weights = [1 - config['closs_weight'], config['closs_weight']]
    curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
    logging_loss = torch.zeros((tmp_hypotheses.shape[0],len(lossfns))).to(config['device'])
    for lossid, lossname in enumerate(config["losses"]):
        if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
            continue
        with torch.no_grad():
            lossvalue = lossfns[lossid].compute_gold_loss(
                source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                label_id=config['target_label_ids'][lossid],
            )
        torch.cuda.empty_cache()
        curr_loss += loss_weights[lossid] * lossvalue
        logging_loss[:, lossid] = lossvalue
    curr_loss = torch.split(curr_loss, num_initial_tmp_hypotheses, dim=0)
    top_beams = [torch.topk(x, k=config['beam_size'], dim=-1, largest=False).indices for x in curr_loss]

    tmp_hypotheses = torch.split(tmp_hypotheses, num_initial_tmp_hypotheses, dim=0)
    for jx, ix in enumerate(batch_ids_to_edit):

        hypotheses[ix] = torch.cat([tmp_hypotheses[jx][top_beams[jx]], masked_sequence[ix][curr_edit_index+1:].unsqueeze(0).repeat(config['beam_size'],1)], dim=-1)
    
    break

beam rerank v1 도 테스트

In [59]:
## debugged version
def get_beam_hypotheses(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:Tuple[torch.Tensor],
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict) -> List[List[str]]:
    """
    A function to get hypotheses of beam size via editing beam search with reranking.
    Run this function if config['method'] == 'mlm-beamsearch-v0' or config['method'] == 'mlm-beamsearch-v1'
    If config['method'] == 'mlm-beamsearch-v1', rerank beam only with fluency energy.
    If config['method'] == 'mlm-beamsearch-v0', rerank beam with a weighted sum of fluency and constraint energy.
    
    #ToDo
    #Implement mlm-beamsearch-v0 with allsat-primary and compare 
    
    params: 
        source_text: a prompt text 
        masked_sequence: token ids of original generation text with located indices masked. tokenized by MLM's tokenizer.
        indices_in_mlm_tokens: a result of running 
                                    `indices_in_mlm_tokens = (
                                                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                                                ).nonzero(as_tuple=True)`
        predicted_token_ids: a result of running
                                    `predicted_token_ids = torch.topk(
                                                                logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                                                                k=config['k_per_location'],
                                                                dim=-1,).indices`
        mlm_tokenizer: tokenizer of MLM
        lossfns: a list of loss functions
        config: a dictionary of configurations
    
    returns:
        hypotheses: a list of a list of the beam number of hypotheses for each sample         
    """
    
    def repeat_interleave_unravel(arr,split_blocks):
        arr_ = torch.split(arr.T,1,dim=1)
        arr_ = [x.repeat(1,split_blocks[i]).reshape(-1,1) for i,x in enumerate(arr_)]
        arr_ = torch.cat(arr_,dim=0)
        return arr_
    
    hypotheses = list(torch.split(masked_sequence,1,dim=0)) ## [torch.tensor([[a],[b],[c]]), torch.tensor([[d]])]
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))
    for curr_edit_index in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==curr_edit_index].tolist()
        num_initial_hypotheses = [len(hypotheses[i]) for i in batch_ids_to_edit] ## keep track of initial hypotheses count e.g. [3, 1]
        tmp_hypotheses = [hypotheses[i].repeat((config['k_per_location'],1)) for i in batch_ids_to_edit] ## [torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c]]), torch.tensor([[d],[d],[d]])]
        num_initial_tmp_hypotheses = [len(x) for x in tmp_hypotheses]
        tmp_hypotheses = torch.cat(tmp_hypotheses,dim=0) ## torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c],[d],[d],[d]])
        

        new_func_candidates = predicted_token_ids[indices_in_mlm_tokens[1]==curr_edit_index] ## shape: (len(batch_ids_to_edit), k_per_location) e.g. [[x,y,z],[q,w,e]]
        new_func_candidates = repeat_interleave_unravel(new_func_candidates,num_initial_hypotheses) ## shape: (sum(num_initial_hypotheses), k_per_location) e.g. [[x],[x],[x],[y],[y],[y],[z],[z],[z],[q],[w],[e]]
        new_func_candidates = new_func_candidates.to(config['device'])
        

        tmp_hypotheses = torch.cat((tmp_hypotheses[ :, :curr_edit_index], new_func_candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], new_func_candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]

        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        for lossid, lossname in enumerate(config["losses"]):
            if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
                continue
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                    label_id=config['target_label_ids'][lossid],
                )
            torch.cuda.empty_cache()
            curr_loss += loss_weights[lossid] * lossvalue
        curr_loss = torch.split(curr_loss, num_initial_tmp_hypotheses, dim=0)
        top_beams = [torch.topk(x, k=config['beam_size'], dim=-1, largest=False).indices for x in curr_loss]

        tmp_hypotheses = torch.split(tmp_hypotheses, num_initial_tmp_hypotheses, dim=0)
        for jx, ix in enumerate(batch_ids_to_edit):

            hypotheses[ix] = torch.cat([tmp_hypotheses[jx][top_beams[jx]], masked_sequence[ix][curr_edit_index+1:].unsqueeze(0).repeat(config['beam_size'],1)], dim=-1)
            
    return [mlm_tokenizer.batch_decode(x, skip_special_tokens=True) for x in hypotheses]

In [61]:
## old beam rerank v1
hypotheses_old = []
for text_id in range(len(masked_text)):
    hypotheses_old_ = beam_rerank_v1(source_text,
                                inputs_old[text_id].input_ids,
                                indices_in_mlm_tokens_old[text_id],
                                predicted_token_ids_old[text_id],
                                mlm_tokenizer, 
                                lossfns_old,
                                config, 
                                beam_size = config['beam_size'])
    hypotheses_old.append(hypotheses_old_)

In [60]:
## new 
config['method'] = 'mlm-beamsearch-v1'
hypotheses = get_beam_hypotheses(source_text, 
                            masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer, 
                            lossfns,
                            config)

In [44]:
## for comparison with v1
config['method'] = 'mlm-beamsearch-v0'
hypotheses_v0 = get_beam_hypotheses(source_text, 
                            masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer, 
                            lossfns,
                            config)

In [62]:
hypotheses==hypotheses_old ## v1 으로 뽑은 hypotheses에 대해서, new와 old가 같은 결과 뱉음. -> Good!

True

In [17]:
hypotheses_v0 == hypotheses ## v0과 v1으로 뽑은 결과는 다르다! -> Good!

False

In [18]:
hypotheses_v0

[[' participating in a community project. AFFA is a, and a, for a.)',
  ' participating in a community project. AFFA is a, and a, for a.,',
  ' participating in a community project. AFFA is a, and a, for a..',
  ' participating in a community project. AFFA is a, and the, for the.,',
  ' participating in a community project. AFFA is a, and a, for a.]'],
 [' the fact that the word “s’ hates on ʳCriminally, the §$$$ by making \u200ferr\u200ft all girls,‼',
  ' the fact that the word “s’ hates on ʳCriminally, the §$$$ by making \u200ferr\u200f, all girls,‼',
  ' the fact that the word “s’ hates on ʳCriminally, in §$$$ by making \u200ferr\u200f, all girls,‼',
  ' the fact that the word “s’ hates on ʳCriminally, and §$$$ by making \u200ferr\u200ft all girls,‼',
  ' the fact that the word “s’ hates on ʳCriminally, and §$$$ by making \u200ferr\u200f, all girls,‼']]

In [14]:
hypotheses_old

[[' participating in a community project. AFFA is a, and a, for a.)',
  ' participating in a community project. AFFA is a, and a, for a.,',
  ' participating in a community project. AFFA is a, and a, for a..',
  ' participating in a community project. AFFA is a, and the, for the.,',
  ' participating in a community project. AFFA is a, and a, for a.]'],
 [' the fact that the word, as is, hates on ʳCriminally, the §$$$ by making \u200ferr\u200ft all girls,‼',
  ' the fact that the word, as is, hates on ʳCriminally, a §$$$ by making \u200ferr\u200ft all girls,‼',
  ' the fact that the word, as is, hates on ʳCriminally, the §$$$ by making \u200ferr\u200f, all girls,‼',
  ' the fact that the word, as is, hates on ʳCriminally, a §$$$ by making \u200ferr\u200f, all girls,‼',
  ' the fact that the word, as is, hates on ʳCriminally, in §$$$ by making \u200ferr\u200ft all girls,‼']]

Combirerank도 확인해보기

In [12]:
config['k_per_location']=3
config['num_edit_token_per_step']=4

In [13]:
hypotheses = get_combi_hypotheses(masked_sequence, 
                                indices_in_mlm_tokens,
                                predicted_token_ids.indices,
                                mlm_tokenizer,
                                config)

In [15]:
## old 
hypotheses_old = []
for text_id in range(len(masked_text)):
    hypotheses_old_ = combi_rerank(
                                inputs_old[text_id].input_ids,
                                indices_in_mlm_tokens_old[text_id],
                                predicted_token_ids_old[text_id],
                                mlm_tokenizer, 
                                config, 
                                )
    hypotheses_old.append(hypotheses_old_)

In [18]:
hypotheses == hypotheses_old

False

In [None]:
## sorting의 문제였음. -> OK!
sorted(hypotheses_old[0])==sorted(hypotheses[0])

True

In [None]:
sorted(hypotheses_old[1])==sorted(hypotheses[1])

True

마지막으로 final reranking도 확인해보기

In [7]:
# config['method'] = 'mlm-beamsearch-v1'
hypotheses = get_beam_hypotheses_v1(source_text, 
                            masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer, 
                            lossfns,
                            config)

In [17]:
## old beam rerank v1
hypotheses_old = []
for text_id in range(len(masked_text)):
    hypotheses_old_ = beam_rerank_v1(source_text,
                                inputs_old[text_id].input_ids,
                                indices_in_mlm_tokens_old[text_id],
                                predicted_token_ids_old[text_id],
                                mlm_tokenizer, 
                                lossfns_old,
                                config, 
                                beam_size = config['beam_size'])
    hypotheses_old.append(hypotheses_old_)

In [18]:
hypotheses == hypotheses_old

True

In [19]:
final_hypotheses_, new_best_weighted_loss_, new_best_allsat_, new_best_logging_loss_ = final_reranking(source_text,
                                                                                                    hypotheses,
                                                                                                    lossfns,
                                                                                                    config,
                                                                                                    batch_size=64)

In [97]:
from torch.utils.data import DataLoader,Dataset
batch_size=64
class CustomDataset(Dataset):
    def __init__(self, hypotheses_data:List[str]):
        self.hypotheses_data = hypotheses_data
        
    def __len__(self):
        return len(self.hypotheses_data)

    def __getitem__(self, idx:int):
        return self.hypotheses_data[idx]
    
    def __getitems__(self, idx:List[int]):
        return [self.hypotheses_data[j] for j in idx]

final_hypotheses = []
best_weighted_loss = []
best_allsat = []
best_logging_loss = []

loss_weights = [1 - config['closs_weight'], config['closs_weight']]

# for i in tqdm(range(len(hypotheses))):
for i in range(len(hypotheses)):
    curr_loss = torch.zeros(len(hypotheses[i])).to(config['device'])
    logging_loss = torch.zeros((len(hypotheses[i]),2)).to(config['device'])
    data_loader = DataLoader(CustomDataset(hypotheses[i]),batch_size=batch_size)

    for lossid, lossname in enumerate(config["losses"]):
        lossvalues=[]
        with torch.no_grad():
            for batch in data_loader:
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, batch,
                    label_id=config['target_label_ids'][lossid],
                )
                lossvalues.append(lossvalue)
                torch.cuda.empty_cache()
        lossvalue = torch.cat(lossvalues,dim=0)
        curr_loss += loss_weights[lossid] * lossvalue
        logging_loss[:, lossid] = lossvalue.clone()
    
    print(logging_loss.tolist())
    allsat_ix = torch.where(logging_loss[:,1]< -math.log(config["min_epsilons"][0]))[0]
    if (len(allsat_ix) > 0) and (config['selection_criteria'] == "allsat_primary"):
    #if (allsat_ix.shape[0] > 0) and (config['selection_criteria'] == "allsat_primary"):
        best_ix = allsat_ix[curr_loss[allsat_ix].argmin()]
    else: ## in case config['selection_criteria'] == "weighted_sum" or allsat is all False
        best_ix = torch.argmin(curr_loss)

    final_hypotheses.append(hypotheses[i][best_ix])
    best_weighted_loss.append(curr_loss[best_ix].item())
    best_allsat.append(1 if best_ix in allsat_ix else 0)
    best_logging_loss.append(logging_loss[best_ix].cpu().tolist())

    del curr_loss, logging_loss
    torch.cuda.empty_cache()

[[74.82612609863281, 0.00542679475620389], [75.46199035644531, 0.005077206529676914], [75.87005615234375, 0.005122513044625521], [76.58781433105469, 0.0029306341893970966], [76.87216186523438, 0.006557730957865715]]
[[160.008544921875, 0.34055182337760925], [160.2454376220703, 0.36186039447784424], [161.38922119140625, 0.34558627009391785], [161.86358642578125, 0.35809406638145447], [161.68670654296875, 0.32997792959213257]]


In [45]:
i = 0
from torch.utils.data import DataLoader,Dataset
batch_size=64
class CustomDataset(Dataset):
    def __init__(self, hypotheses_data:List[str]):
        self.hypotheses_data = hypotheses_data
        
    def __len__(self):
        return len(self.hypotheses_data)

    def __getitem__(self, idx:int):
        return self.hypotheses_data[idx]
    
    def __getitems__(self, idx:List[int]):
        return [self.hypotheses_data[j] for j in idx]
curr_loss = torch.zeros(len(hypotheses[i])).to(config['device'])
logging_loss = torch.zeros((len(hypotheses[i]),2)).to(config['device'])
data_loader = DataLoader(CustomDataset(hypotheses[i]),batch_size=batch_size)

for lossid, lossname in enumerate(config["losses"]):
    lossvalues=[]
    with torch.no_grad():
        for batch in data_loader:
            lossvalue = lossfns[lossid].compute_gold_loss(
                source_text, batch,
                label_id=config['target_label_ids'][lossid],
            )
            lossvalues.append(lossvalue)
            torch.cuda.empty_cache()
    lossvalue = torch.cat(lossvalues,dim=0)
    curr_loss += loss_weights[lossid] * lossvalue
    logging_loss[:, lossid] = lossvalue.clone()

In [57]:
batch == hypotheses[0]

True

In [58]:
with torch.no_grad():
    for batch in data_loader:
        lossvalue = lossfns[lossid].compute_gold_loss(
            source_text, batch,
            label_id=config['target_label_ids'][lossid],
        )
        lossvalues.append(lossvalue)
        torch.cuda.empty_cache()

In [61]:
lossvalue = lossfns[0].compute_gold_loss(
            source_text, [hypotheses[0][0]],
            label_id=config['target_label_ids'][lossid],
        )

In [63]:
lossvalue.tolist()

[74.82614135742188]

In [60]:
lossvalues[1].tolist()

[74.82612609863281,
 75.46199035644531,
 75.87005615234375,
 76.58781433105469,
 76.87216186523438]

In [47]:
logging_loss.tolist()

[[74.82612609863281, 0.00542679475620389],
 [75.46199035644531, 0.005077206529676914],
 [75.87005615234375, 0.005122513044625521],
 [76.58781433105469, 0.0029306341893970966],
 [76.87216186523438, 0.006557730957865715]]

In [27]:
loss_weights

[0.09999999999999998, 0.9]

In [36]:
final_hypotheses_

[' participating in a community project. AFFA is a, and a, for a.)',
 ' the fact that the word, as is, hates on ʳCriminally, the §$$$ by making \u200ferr\u200ft all girls,‼']

In [49]:
hypotheses_old[1]

[' the fact that the word, as is, hates on ʳCriminally, the §$$$ by making \u200ferr\u200ft all girls,‼',
 ' the fact that the word, as is, hates on ʳCriminally, a §$$$ by making \u200ferr\u200ft all girls,‼',
 ' the fact that the word, as is, hates on ʳCriminally, the §$$$ by making \u200ferr\u200f, all girls,‼',
 ' the fact that the word, as is, hates on ʳCriminally, a §$$$ by making \u200ferr\u200f, all girls,‼',
 ' the fact that the word, as is, hates on ʳCriminally, in §$$$ by making \u200ferr\u200ft all girls,‼']

In [53]:
hyp = hypotheses_old[0][0]

In [54]:
lossid = 0
with torch.no_grad():
    lossvalue = lossfns_old[lossid].compute_gold_loss(
        source_text, hyp,
        label_id=config['target_label_ids'][lossid],
    )

In [55]:
lossvalue.item()

74.82614135742188

In [32]:
lossid = 1
with torch.no_grad():
    lossvalue = lossfns_old[lossid].compute_gold_loss(
        source_text, hyp,
        label_id=config['target_label_ids'][lossid],
    )

In [33]:
lossvalue.item()

0.3299781084060669

In [20]:
final_hypotheses_old_ = []
new_best_weighted_loss_old_ = []
new_best_allsat_old_ = []
new_best_logging_loss_old_ = []
for batch_id in range(len(hypotheses_old)):
    candidate_total_losses = []
    candidate_primary_losses = []
    candidate_losses_for_loggings = []
    candidate_allsats = []

    for hyp in hypotheses_old[batch_id]:
        curr_loss = 0.0
        logging_loss = []
        allsat = True
        for lossid, lossname in enumerate(config["losses"]):
            with torch.no_grad():
                lossvalue = lossfns_old[lossid].compute_gold_loss(
                    source_text, hyp,
                    label_id=config['target_label_ids'][lossid],
                )
            curr_loss += loss_weights[lossid] * lossvalue.item()
            logging_loss.append(lossvalue.item())
            if lossid==0:
                candidate_primary_losses.append(lossvalue.item())
            elif (lossid >= 1) and (
                lossvalue.item()
                > -np.log(config["min_epsilons"][lossid - 1])
            ):
                allsat = False
        candidate_total_losses.append(curr_loss)
        candidate_losses_for_loggings.append(logging_loss)
        candidate_allsats.append(allsat)


    if config['selection_criteria'] == "weighted_sum":
        best_ix = np.argmin(np.array(candidate_total_losses))
    elif config['selection_criteria'] == "allsat_primary":
        allsat_ix = np.where(np.array(candidate_allsats) == True)[0]
        if len(allsat_ix) > 0:
            best_ix = np.argmin(
                np.array(candidate_primary_losses)[allsat_ix]
            )  # select min primary loss among allsats
            best_ix = allsat_ix[best_ix]
        else:  # if no candidate satisfying constraints, default to weighted_sum
            best_ix = np.argmin(np.array(candidate_total_losses))
    final_hypotheses_old_.append(hypotheses_old[batch_id][best_ix])
    new_best_weighted_loss_old_.append(candidate_total_losses[best_ix])
    new_best_allsat_old_.append(candidate_allsats[best_ix])
    new_best_logging_loss_old_.append(candidate_losses_for_loggings[best_ix])

[74.82614135742188, 0.00542679475620389]
[75.46200561523438, 0.005077206529676914]
[75.87007141113281, 0.005122513044625521]
[76.58781433105469, 0.0029306341893970966]
[76.87217712402344, 0.006557730957865715]
[160.00857543945312, 0.34055131673812866]
[160.24545288085938, 0.3618605434894562]
[161.38922119140625, 0.3455859124660492]
[161.86361694335938, 0.3580939769744873]
[161.68667602539062, 0.3299781084060669]


In [21]:
final_hypotheses_

[' participating in a community project. AFFA is a, and a, for a.)',
 ' the fact that the word, as is, hates on ʳCriminally, the §$$$ by making \u200ferr\u200ft all girls,‼']

In [22]:
final_hypotheses_old_

[' participating in a community project. AFFA is a, and a, for a.)',
 ' the fact that the word, as is, hates on ʳCriminally, the §$$$ by making \u200ferr\u200ft all girls,‼']

In [26]:
new_best_weighted_loss_.tolist()

[7.487496852874756, 16.307350158691406]

In [23]:
new_best_weighted_loss_old_

[7.48749825102277, 16.307353729009623]

In [None]:
[[74.82612609863281, 0.00542679475620389], [75.46199035644531, 0.005077206529676914], [75.87005615234375, 0.005122513044625521], [76.58781433105469, 0.0029306341893970966], [76.87216186523438, 0.006557730957865715]]
[[160.008544921875, 0.34055182337760925], [160.2454376220703, 0.36186039447784424], [161.38922119140625, 0.34558627009391785], [161.86358642578125, 0.35809406638145447], [161.68670654296875, 0.32997792959213257]]

In [98]:
lossvalue = lossfns[0].compute_gold_loss(
                    source_text, hypotheses[0],
                    label_id=config['target_label_ids'][lossid],
                )

In [100]:
lossvalue.tolist()

[74.82612609863281,
 75.46199035644531,
 75.87005615234375,
 76.58781433105469,
 76.87216186523438]

In [83]:
final_hypotheses_old_

[' participating in a community project. AFFA is a, and a, for a.)',
 ' the fact that the word, as is, hates on ʳCriminally, the §$$$ by making \u200ferr\u200ft all girls,‼']

In [76]:
final_hypotheses_

[' participating in a community project. AFFA is a, and a, for a.)',
 ' the fact that the word, as is, hates on ʳCriminally, the §$$$ by making \u200ferr\u200ft all girls,‼']

In [84]:
new_best_weighted_loss_old_

[7.48749825102277, 16.307353729009623]

In [86]:
new_best_weighted_loss_.tolist()

[7.487496852874756, 16.307350158691406]

In [88]:
new_best_allsat_

tensor([ True, False], device='cuda:1')

In [87]:
new_best_allsat_old_

[True, False]

In [91]:
new_best_logging_loss_.tolist()

[[74.82612609863281, 0.00542679475620389],
 [160.008544921875, 0.34055182337760925]]

In [89]:
new_best_logging_loss_old_

[[74.82614135742188, 0.00542679475620389],
 [160.00857543945312, 0.34055131673812866]]

Seek extra ... ms reduction in time by getting rid of if statement within for loop (get_beam_hypotheses)

In [None]:
def get_beam_hypotheses_v1(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:Tuple[torch.Tensor],
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict) -> List[List[str]]:
    """
    A function to get hypotheses of beam size via editing beam search with reranking.
    Run this function if config['method'] == 'mlm-beamsearch-v0' or config['method'] == 'mlm-beamsearch-v1'
    If config['method'] == 'mlm-beamsearch-v1', rerank beam only with fluency energy.
    If config['method'] == 'mlm-beamsearch-v0', rerank beam with a weighted sum of fluency and constraint energy.
    
    #ToDo
    #Implement mlm-beamsearch-v0 with allsat-primary and compare 
    
    params: 
        source_text: a prompt text 
        masked_sequence: token ids of original generation text with located indices masked. tokenized by MLM's tokenizer.
        indices_in_mlm_tokens: a result of running 
                                    `indices_in_mlm_tokens = (
                                                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                                                ).nonzero(as_tuple=True)`
        predicted_token_ids: a result of running
                                    `predicted_token_ids = torch.topk(
                                                                logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                                                                k=config['k_per_location'],
                                                                dim=-1,).indices`
        mlm_tokenizer: tokenizer of MLM
        lossfns: a list of loss functions
        config: a dictionary of configurations
    
    returns:
        hypotheses: a list of a list of the beam number of hypotheses for each sample         
    """
    
    def repeat_interleave_unravel(arr,split_blocks):
        arr_ = torch.split(arr.T,1,dim=1)
        arr_ = [x.repeat(1,split_blocks[i]).reshape(-1,1) for i,x in enumerate(arr_)]
        arr_ = torch.cat(arr_,dim=0)
        return arr_
    
    hypotheses = list(torch.split(masked_sequence,1,dim=0)) ## [torch.tensor([[a],[b],[c]]), torch.tensor([[d]])]
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))
    for curr_edit_index in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==curr_edit_index].tolist()
        num_initial_hypotheses = [len(hypotheses[i]) for i in batch_ids_to_edit] ## keep track of initial hypotheses count e.g. [3, 1]
        tmp_hypotheses = [hypotheses[i].repeat((config['k_per_location'],1)) for i in batch_ids_to_edit] ## [torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c]]), torch.tensor([[d],[d],[d]])]
        num_initial_tmp_hypotheses = [len(x) for x in tmp_hypotheses]
        tmp_hypotheses = torch.cat(tmp_hypotheses,dim=0) ## torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c],[d],[d],[d]])
        

        new_func_candidates = predicted_token_ids[indices_in_mlm_tokens[1]==curr_edit_index] ## shape: (len(batch_ids_to_edit), k_per_location) e.g. [[x,y,z],[q,w,e]]
        new_func_candidates = repeat_interleave_unravel(new_func_candidates,num_initial_hypotheses) ## shape: (sum(num_initial_hypotheses), k_per_location) e.g. [[x],[x],[x],[y],[y],[y],[z],[z],[z],[q],[w],[e]]
        new_func_candidates = new_func_candidates.to(config['device'])
        

        tmp_hypotheses = torch.cat((tmp_hypotheses[ :, :curr_edit_index], new_func_candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], new_func_candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]

        # loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        # curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        lossid = 0
        # for lossid, lossname in enumerate(config["losses"]):
        #     if config['method'] == 'mlm-beamsearch-v1' and lossid > 0:
        #         continue
        #     with torch.no_grad():
        #         lossvalue = lossfns[lossid].compute_gold_loss(
        #             source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
        #             label_id=config['target_label_ids'][lossid],
        #         )
        #     torch.cuda.empty_cache()
        #     curr_loss += loss_weights[lossid] * lossvalue
        with torch.no_grad():
            lossvalue = lossfns[lossid].compute_gold_loss(
                source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                label_id=config['target_label_ids'][lossid],
            )
        torch.cuda.empty_cache()
        # curr_loss += loss_weights[lossid] * lossvalue
        curr_loss = torch.split(lossvalue, num_initial_tmp_hypotheses, dim=0)
        top_beams = [torch.topk(x, k=config['beam_size'], dim=-1, largest=False).indices for x in curr_loss]

        tmp_hypotheses = torch.split(tmp_hypotheses, num_initial_tmp_hypotheses, dim=0)
        for jx, ix in enumerate(batch_ids_to_edit):

            hypotheses[ix] = torch.cat([tmp_hypotheses[jx][top_beams[jx]], masked_sequence[ix][curr_edit_index+1:].unsqueeze(0).repeat(config['beam_size'],1)], dim=-1)
            
    return [mlm_tokenizer.batch_decode(x, skip_special_tokens=True) for x in hypotheses]

In [69]:
def get_beam_hypotheses_v0(source_text:str, 
                    masked_sequence:torch.Tensor, 
                    indices_in_mlm_tokens:Tuple[torch.Tensor],
                    predicted_token_ids:torch.Tensor,
                    mlm_tokenizer:transformers.AutoTokenizer, 
                    lossfns:List[lossbuilder.BaseLoss],
                    config:dict) -> List[List[str]]:
    """
    A function to get hypotheses of beam size via editing beam search with reranking.
    Run this function if config['method'] == 'mlm-beamsearch-v0' or config['method'] == 'mlm-beamsearch-v1'
    If config['method'] == 'mlm-beamsearch-v1', rerank beam only with fluency energy.
    If config['method'] == 'mlm-beamsearch-v0', rerank beam with a weighted sum of fluency and constraint energy.
    
    #ToDo
    #Implement mlm-beamsearch-v0 with allsat-primary and compare 
    
    params: 
        source_text: a prompt text 
        masked_sequence: token ids of original generation text with located indices masked. tokenized by MLM's tokenizer.
        indices_in_mlm_tokens: a result of running 
                                    `indices_in_mlm_tokens = (
                                                                inputs.input_ids == mlm_tokenizer.mask_token_id
                                                                ).nonzero(as_tuple=True)`
        predicted_token_ids: a result of running
                                    `predicted_token_ids = torch.topk(
                                                                logits[indices_in_mlm_tokens[0], indices_in_mlm_tokens[1], :],
                                                                k=config['k_per_location'],
                                                                dim=-1,).indices`
        mlm_tokenizer: tokenizer of MLM
        lossfns: a list of loss functions
        config: a dictionary of configurations
    
    returns:
        hypotheses: a list of a list of the beam number of hypotheses for each sample         
    """
    
    def repeat_interleave_unravel(arr,split_blocks):
        arr_ = torch.split(arr.T,1,dim=1)
        arr_ = [x.repeat(1,split_blocks[i]).reshape(-1,1) for i,x in enumerate(arr_)]
        arr_ = torch.cat(arr_,dim=0)
        return arr_
    
    hypotheses = list(torch.split(masked_sequence,1,dim=0)) ## [torch.tensor([[a],[b],[c]]), torch.tensor([[d]])]
    edit_indices = sorted(list(set(indices_in_mlm_tokens[1].tolist())))
    for curr_edit_index in edit_indices:
        
        batch_ids_to_edit = indices_in_mlm_tokens[0][indices_in_mlm_tokens[1]==curr_edit_index].tolist()
        num_initial_hypotheses = [len(hypotheses[i]) for i in batch_ids_to_edit] ## keep track of initial hypotheses count e.g. [3, 1]
        tmp_hypotheses = [hypotheses[i].repeat((config['k_per_location'],1)) for i in batch_ids_to_edit] ## [torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c]]), torch.tensor([[d],[d],[d]])]
        num_initial_tmp_hypotheses = [len(x) for x in tmp_hypotheses]
        tmp_hypotheses = torch.cat(tmp_hypotheses,dim=0) ## torch.tensor([[a],[b],[c],[a],[b],[c],[a],[b],[c],[d],[d],[d]])
        

        new_func_candidates = predicted_token_ids[indices_in_mlm_tokens[1]==curr_edit_index] ## shape: (len(batch_ids_to_edit), k_per_location) e.g. [[x,y,z],[q,w,e]]
        new_func_candidates = repeat_interleave_unravel(new_func_candidates,num_initial_hypotheses) ## shape: (sum(num_initial_hypotheses), k_per_location) e.g. [[x],[x],[x],[y],[y],[y],[z],[z],[z],[q],[w],[e]]
        new_func_candidates = new_func_candidates.to(config['device'])
        

        tmp_hypotheses = torch.cat((tmp_hypotheses[ :, :curr_edit_index], new_func_candidates),dim=-1) ## tmp_hypotheses: [(a,b,c),(a,b,c), ..., (a,b,c)], new_func_candidates: [(p,p,p), (q,q,q), ..., (v,v,v)]

        loss_weights = [1 - config['closs_weight'], config['closs_weight']]
        curr_loss = torch.zeros(tmp_hypotheses.shape[0]).to(config['device'])
        for lossid, lossname in enumerate(config["losses"]):
            with torch.no_grad():
                lossvalue = lossfns[lossid].compute_gold_loss(
                    source_text, mlm_tokenizer.batch_decode(tmp_hypotheses,skip_special_tokens=True),
                    label_id=config['target_label_ids'][lossid],
                )
            torch.cuda.empty_cache()
            curr_loss += loss_weights[lossid] * lossvalue
        curr_loss = torch.split(curr_loss, num_initial_tmp_hypotheses, dim=0)
        top_beams = [torch.topk(x, k=config['beam_size'], dim=-1, largest=False).indices for x in curr_loss]

        tmp_hypotheses = torch.split(tmp_hypotheses, num_initial_tmp_hypotheses, dim=0)
        for jx, ix in enumerate(batch_ids_to_edit):

            hypotheses[ix] = torch.cat([tmp_hypotheses[jx][top_beams[jx]], masked_sequence[ix][curr_edit_index+1:].unsqueeze(0).repeat(config['beam_size'],1)], dim=-1)
            
    return [mlm_tokenizer.batch_decode(x, skip_special_tokens=True) for x in hypotheses]

In [None]:
%%timeit
config['method'] = 'mlm-beamsearch-v0'
hypotheses = get_beam_hypotheses(source_text, 
                            masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer, 
                            lossfns,
                            config)

5.2 s ± 12.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
config['method'] = 'mlm-beamsearch-v1'
hypotheses = get_beam_hypotheses(source_text, 
                            masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer, 
                            lossfns,
                            config)

4.58 s ± 19.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
config['method'] = 'mlm-beamsearch-v0'
hypotheses = get_beam_hypotheses_v0(source_text, 
                            masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer, 
                            lossfns,
                            config)

4.9 s ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
config['method'] = 'mlm-beamsearch-v1'
hypotheses = get_beam_hypotheses_v1(source_text, 
                            masked_sequence, 
                            indices_in_mlm_tokens,
                            predicted_token_ids.indices,
                            mlm_tokenizer, 
                            lossfns,
                            config)

4.59 s ± 5.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
