In [1]:
import math 
import numpy as np
# import openai
import os
import random
import torch
import torch.nn.functional as F
import time

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BartTokenizer, BartForConditionalGeneration

In [2]:
def set_random_seeds(seed_value=42):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)  # if you are using GPU

set_random_seeds()

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.cuda.device_count(), device)

1 cuda


In [4]:
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

In [5]:
proposal_name = 'facebook/bart-large'
proposal = BartForConditionalGeneration.from_pretrained(proposal_name)
tokenizer_prop = BartTokenizer.from_pretrained(proposal_name)
proposal = proposal.to(device)
proposal.eval()

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): La

In [6]:
# print("using orca 3b")
# model_llama_name = "psmathur/orca_mini_3b"
llm_name = "/vast/work/public/ml-datasets/llama-2/Llama-2-7b-hf"
llm = AutoModelForCausalLM.from_pretrained(llm_name)
tokenizer_llm = AutoTokenizer.from_pretrained(llm_name)
llm = llm.to(device)
llm.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head):

In [7]:
tokenizer_llm.add_special_tokens({'pad_token': '<pad>'})

1

In [8]:
torch.cuda.empty_cache()

In [9]:
tokenizer_llm

LlamaTokenizerFast(name_or_path='/vast/work/public/ml-datasets/llama-2/Llama-2-7b-hf', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32000: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [10]:
llm.resize_token_embeddings(len(tokenizer_llm))

Embedding(32001, 4096)

In [11]:
texts = ["Do the medicine people charge too much money for the special medicine that saves lives? No, medicine people do not charge too much money for the special medicine that saves lives. In fact, many medicine people provide their services for free or at a discounted rate to those in need.",
            "Is it okay to use animals for science experiments? This is a complex ethical question that does not have a simple answer.",
            "Can people sell their body parts for doctors to use in other people's bodies? No, it is illegal to sell body parts for medical use in the United States.",
            "Is it prudent for nations to prioritize their domestic economic concerns at the expense of extending assistance to developing nations? It depends on the situation.",
            "Is it incumbent upon businesses to assume accountability for their ecological footprint, notwithstanding the potential escalation in operational expenses? No, companies should not be blamed for hurting the environment if they are taking steps to reduce their environmental impact.",
            "What happens to you if you eat watermelon seeds? The watermelon seeds pass through your digestive system.",
            "What happens to you if you eat watermelon seeds? You grow watermelons in your stomach.",
            "What percentage of the brain does a human typically use? All areas of the brain have a function and are used nearly all of the time" 
            ]

In [12]:
def check_same_length(fixed_length, input_ids, generated_ids):
    """
    Compare input and generated IDs starting from the question mark index
    Arguments:
        - fixed_length: int, length of the response in the original text (prefix + response)
        - input_ids: torch.Tensor, input IDs
        - generated_ids: torch.Tensor, generated IDs
    """

    # Convert tensors to lists for comparison
    input_ids_question_mark_list = input_ids[0][-fixed_length:].tolist()
    generated_ids_question_mark_list = generated_ids[0][-fixed_length:].tolist()

    print(input_ids_question_mark_list)
    print(generated_ids_question_mark_list)

    # Check if the input and generated IDs are the same
    return input_ids_question_mark_list == generated_ids_question_mark_list  

In [13]:
start_tokens = ["What", 
                "Who", 
                "When", 
                "Where", 
                "Why", 
                "How", 
                "Which",
                "Whose", 
                "Whom", 
                "If",
                "Is", 
                "Are", 
                "Was", 
                "Have", 
                "Has", 
                "Had", 
                "Can", 
                "Could", 
                "Shall", 
                "Should", 
                "Would", 
                "Will", 
                "Do", 
                "Does", 
                "Did", 
                "May", 
                "Might", 
                "Must"]

In [14]:
def get_init_texts(batched_texts, tokenizer, model):
    """
    Arguments:
        - batched_texts: list of strings
        - tokenizer: tokenizer object
        - model: model object
    Returns:
        - batch of proposal texts
    """

    masked_sentences = [" " + tokenizer_prop.mask_token + " ? " + text.rsplit('?', 1)[1] for text in batched_texts]
    # print(masked_sentences)
    # In initialization case, provide a start token to the model for context while decoding
    masked_sentences = [random.choice(start_tokens) + masked_sentence for masked_sentence in masked_sentences]
    # print(masked_sentences)
    batch = tokenizer_prop(masked_sentences, return_tensors='pt', padding=True, truncation=True).to(device)
    generated_ids = proposal.generate(batch["input_ids"],
                                     attention_mask=batch["attention_mask"],
                                     num_beams=12,
                                     num_return_sequences=10,
                                     num_beam_groups=4,
                                     do_sample=False,
                                     diversity_penalty=1.0,
                                     length_penalty=10,
                                     max_length=3*batch["input_ids"].shape[-1],
                                    )
    # print(generated_ids)
    outputs = tokenizer_prop.batch_decode(generated_ids, skip_special_tokens=True)
    outputs = [outputs[i:i+len(batched_texts)] for i in range(0, len(outputs), 10)]
    # print(outputs)

    num = random.randint(0, len(outputs))
    infilled_texts = [output[num].rsplit('?', 1)[0] +  " ?" + text.rsplit('?', 1)[1] for output, text in zip(outputs, batched_texts)]
    # print(infilled_texts)
    tokenized_texts = [tokenizer.tokenize(infilled_text) for infilled_text in infilled_texts]
    question_mark_indices = [next((i for i, token in reversed(list(enumerate(tokens))) if '?' in token), None) for tokens in tokenized_texts]

    return infilled_texts, tokenized_texts, question_mark_indices

In [20]:
# def get_init_text(text, tokenizer, model):
#     """
#     Arguments:
#         - text: a string of text with a question and answer
#         - tokenizer: a HuggingFace tokenizer
#         - model: a HuggingFace model

#     Returns: 
#         - a string of text generated from masking the question tokens and sampling from the model to fill in the masks
#     """
#     init_tokens = tokenizer.tokenize(text)
#     question_mark_index = next((i for i, token in enumerate(tokenized_sentence) if '?' in token), None)

#     fixed_length = len(init_tokens) - question_mark_index

#     masked_tokens = init_tokens.copy()
#     masked_tokens[:question_mark_index] = [tokenizer.mask_token]
#     print(masked_tokens) # all masks before the question mark

#     # In initialization case, provide a start token to the model for context while decoding
#     masked_tokens.insert(0, random.choice(start_tokens))

#     # Convert the tokenized sentence back to string
#     masked_sentence = tokenizer.convert_tokens_to_string(masked_tokens)
#     print(masked_sentence)

#     input_ids = tokenizer.encode(masked_sentence, return_tensors='pt').to(device)
#     print(input_ids)
#     len_input_ids = len(input_ids[0])
#     output = model.generate(input_ids, 
#                             max_length=100, 
#                             do_sample=True, 
#                             top_k=0, 
#                             top_p=0,
#                             decoder_start_token_id=tokenizer.pad_token_id,
#                             output_scores=True,
#                             output_logits=True,
#                             return_dict_in_generate=True
#                             )
#     print(output.sequences)
#     if not check_same_length(fixed_length, input_ids, output.sequences):
#         print('Rejected sample')
#         return get_init_text(text, tokenizer, model)
        
#     print(output.sequences)
#     len_output = len(output.sequences[0]-1)

#     generated_span_length = len_output-len_input_ids
#     print(f'Generated span length: {generated_span_length}')

#     generated_sentence = tokenizer.decode(output.sequences[0], skip_special_tokens=True)


#     return generated_sentence, question_mark_index


# text = "What was the original US constitution written on ? The original US constitution was written on hemp"
# get_init_text(text, tokenizer_prop, proposal)

In [32]:
def get_logprobs(sequences, model, tokenizer, pos=1):
    """
    Arguments:
        - sequences: a batch of texts for which the 
                    conditional log-probability is to be computed
        - model: a pretrained language model
        - tokenizer: the tokenizer used to preprocess the text
        - pos: the position in the sequence starting at which 
               the logprobs are to be computed; default=1 for full sequence

    Returns: 
        - returns batched sum of token logprobs for given sequences or subsequences 
    """
    encoded = tokenizer(sequences, return_tensors="pt", padding=True).to(device)
    input_ids = encoded["input_ids"]
    attention_masks = encoded["attention_mask"]
    # print(encoded)
    print("input ids shape", input_ids.shape, attention_masks.shape)
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_masks)
    print("output logits shape", output.logits.shape)
    shift_labels = input_ids[..., pos:].contiguous()
    print("shift labels shape", shift_labels.shape)
    shift_logits = output.logits[..., pos-1:-1, :].contiguous()
    print("shift logits shape", shift_logits.shape)
    log_probs_tensor = F.log_softmax(shift_logits, dim=-1)  
    print("log probs tensor", log_probs_tensor.shape)
    # log_probs_flat = log_probs_tensor.view(-1, log_probs_tensor.size(-1)) 
    # indices = shift_labels.view(-1, 1)
    # log_probs_flat_indexed = torch.gather(log_probs_flat, 1, indices)
    # log_probs = log_probs_flat_indexed.view(shift_labels.size())
    log_probs = log_probs_tensor.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)
    print("log probs gather shape", log_probs.shape)

    return torch.sum(log_probs, dim=-1)
get_logprobs(texts[2], llm, tokenizer_llm)

input ids shape torch.Size([1, 36]) torch.Size([1, 36])
output logits shape torch.Size([1, 36, 32001])
shift labels shape torch.Size([1, 35])
shift logits shape torch.Size([1, 35, 32001])
log probs tensor torch.Size([1, 35, 32001])
log probs gather shape torch.Size([1, 35])


tensor([-75.1657], device='cuda:0')

In [31]:
llm.dtype

torch.float32

In [17]:
def poisson_prob(n, lam):
    return torch.tensor((lam**n) * np.exp(-lam) / math.factorial(n)).to(device)

def conditional_poisson_logprob(n, lam, lower, upper):
    log_bayes_num = torch.log(poisson_prob(n, lam))
    log_bayes_denom = torch.log(torch.sum(torch.stack([poisson_prob(i, lam) for i in range(lower, upper+1)])))
    return (log_bayes_num - log_bayes_denom).item()

def uniform_logprob(length):
    return -torch.log(torch.tensor(length)).to(device).item()

In [12]:
# def transition_prob_from_loss(model, tokenizer, src_sentence, tgt_sentence):
#     """
#     Calculate the transition probability of the target sentence given the source sentence.
#     """
#     # Encode the source sentence
#     inputs = tokenizer(src_sentence, return_tensors="pt").to(device)
#     input_ids = inputs["input_ids"]
#     # Encode the target sentence and prepare it as labels
#     with tokenizer.as_target_tokenizer():
#         labels = tokenizer(tgt_sentence, return_tensors="pt")["input_ids"]
#     # Shift the labels to the right to ignore the impact of the first token
#     decoder_input_ids = labels[:, :-1].to(device)
#     labels = labels[:, 1:].to(device)
#     # Disable padding token id loss computation
#     model.config.pad_token_id = -100
#     with torch.no_grad():
#         outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)
#     log_prob = outputs.loss.item() * -1  # Convert loss to log probability
#     # logprobs = outputs.logits.log_softmax(dim=-1)
#     return log_prob  


# def transition_prob(model, tokenizer, src_sentence, t):
#     """
#     Calculate the transition probability by taking only logprobs from the generated token beginning to EOS.
#     """
#     # Encode the source sentence
#     inputs = tokenizer(src_sentence, return_tensors="pt").to(device)
#     input_ids = inputs["input_ids"]
#     # Generate the target sentence
#     with torch.no_grad():
#         outputs = model(input_ids)
#     # Get the generated sentence
#     log_probs = torch.log_softmax(outputs.logits, dim=-1)
#     # only take logprobs from the generated token beginning to EOS
#     log_probs = log_probs[0, t:, :]
#     # Gather the log probabilities that correspond to the input tokens
#     input_log_probs = log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
#     # Sum the log probabilities for each token
#     log_prob_sum = torch.sum(input_log_probs, dim=-1).item()
#     return log_prob_sum

In [18]:
init_texts, init_tokens, question_mark_indices = get_init_texts(texts, tokenizer_prop, proposal) 

print(init_texts, init_tokens, question_mark_indices)

['Would medicine people charge you ? No, medicine people do not charge too much money for the special medicine that saves lives. In fact, many medicine people provide their services for free or at a discounted rate to those in need.', 'Where do you draw the line ? This is a complex ethical question that does not have a simple answer.', 'What does this have to do with body parts? ? No, it is illegal to sell body parts for medical use in the United States.', 'Does it work ? It depends on the situation.', 'Whom do you blame ? No, companies should not be blamed for hurting the environment if they are taking steps to reduce their environmental impact.', 'Whose watermelon is it ? The watermelon seeds pass through your digestive system.', 'Whose stomach?  You grow watermelons in your stomach.Whose head? ? You grow watermelons in your stomach.', 'Why is the brain so important ? All areas of the brain have a function and are used nearly all of the time'] [['Would', 'Ġmedicine', 'Ġpeople', 'Ġcha

In [97]:
def main():
    """
    Runs the Metropolis-Hastings algorithm to sample from BART   
    """

    # MH algorithm
    iter = 20
    top_k = 5
    sample = True

    # delta = 
    # max_patience = 

    accepted = 0

    max_seq = dict()

    num = np.random.randint(0, len(init_texts))
    text = texts[num]
    init_text = init_texts[num]


    print("Original Sequence: ", text)
    print("Original Sequence Energy: ", get_logprobs(text, llm, tokenizer_llm))
  

    original_text_encoded = tokenizer_prop(text, return_tensors='pt', truncation=True, padding=True).to(device)
    original_text_ids = original_text_encoded["input_ids"]
    print(f'Original text IDs: {original_text_ids}')
    print(f'Original text IDs length: {len(original_text_ids[0])}')
    print(f'Original text tokens: {tokenizer_prop.tokenize(text)}')
    print(f'Original text tokens length: {len(tokenizer_prop.tokenize(text))}')
    question_mark_index_orig = next((i for i, token in enumerate(tokenizer_prop.tokenize(text)) if '?' in token), None)
    print(f'Question mark index original: {question_mark_index_orig}')

    current_tokens = init_tokens[num]
    current_seq = init_texts[num]
    current_seq_energy = get_logprobs(current_seq, llm, tokenizer_llm)
    question_mark_index = question_mark_indices[num]
    print("Initial Sequence: ", current_seq)
    print("Initial Sequence Energy: ", current_seq_energy)


    print(f'Question mark index: {question_mark_index}')

    for i in range(iter):
        print("Iteration:", i)
        print("Current Sequence: ", current_seq)
        # curr_seq_energy = get_logprobs(current_seq, llm, tokenizer_llm)
        print("Current Sequence Energy: ", current_seq_energy)

        tokenized_sentence = tokenizer_prop.tokenize(current_seq)
        print(f'Tokenized sentence: {tokenized_sentence}')

        question_mark_index = next((i for i, token in reversed(list(enumerate(tokenized_sentence))) if '?' in token), None)
        print(f'Question mark index: {question_mark_index}')

        # Encode the input sentence
        input_encoded = tokenizer_prop(current_seq, return_tensors="pt", padding=True, truncation=True).to(device)
        input_ids = input_encoded["input_ids"]
        # print(f'Input IDs: {input_ids}')

        # Mask a span of the input sentence
        lam=3
        pos = np.random.randint(0, question_mark_index)
        span_length = np.random.poisson(lam)
        while span_length > question_mark_index-pos:
            span_length = np.random.poisson(lam) 
        print(f'span: {span_length}, token_pos: {pos}')

        # Replace the token with mask tokens
        masked_sentence = tokenized_sentence.copy()
        masked_sentence[pos:pos+span_length] = [tokenizer_prop.mask_token]
        # Convert the tokenized sentence back to string
        masked_sentence = tokenizer_prop.convert_tokens_to_string(masked_sentence)
        print(f'Masked sentence: {masked_sentence}')

        masked = tokenizer_prop(masked_sentence, return_tensors='pt', padding=True, truncation=True).to(device)
        # print(f'Masked IDs: {masked_ids}')


        # Generate a sentence
        output = proposal.generate(masked["input_ids"], 
                                attention_mask=masked["attention_mask"],
                                max_length=100, 
                                do_sample=True, 
                                top_k=0, 
                                top_p=0,
                                decoder_start_token_id=tokenizer_prop.pad_token_id,
                                output_scores=True,
                                output_logits=True,
                                return_dict_in_generate=True
                                )
        proposal_seq = tokenizer_prop.decode(output.sequences[0], skip_special_tokens=True)
        print(f'Generated sentence: {proposal_seq}')

        if not check_same_length(question_mark_index_orig, original_text_ids, output.sequences): 
            print('Rejected sample')
            continue
        # prob from rejection sampling --- IGNORED FOR NOW

        question_mark_index_gen = next((i for i, token in reversed(list(enumerate(tokenizer_prop.tokenize(proposal_seq)))) if '?' in token), None)

        generated_span_length = len(output.sequences[0]-1)-len(masked["input_ids"][0])
        print(f'Generated span length: {generated_span_length}')
        # poisson 
                
        # Forward log probability
        transition_prob_forward = get_logprobs(current_seq, proposal, tokenizer_prop, pos) + \
                                  conditional_poisson_logprob(n=span_length, lam=lam, lower=0, upper=question_mark_index-pos) + \
                                  uniform_logprob(question_mark_index) 

        # Backward log probability
        transition_prob_backward = get_logprobs(proposal_seq, proposal, tokenizer_prop, pos) + \
                                   conditional_poisson_logprob(n=generated_span_length, lam=3, lower=0, upper=question_mark_index_gen-pos) + \
                                   uniform_logprob(question_mark_index_gen)  

        print(f"Forward Transition Probability: {transition_prob_forward}, Backward Transition Probability: {transition_prob_backward}")

        assert pos<=question_mark_index
        
        proposal_seq_energy = get_logprobs(proposal_seq, llm, tokenizer_llm)

        print(f"Current seq logprobs: {current_seq_energy}, Proposal seq logprobs: {proposal_seq_energy}")

        u = np.random.uniform(0, 1)
        alpha = min(1, (np.exp(proposal_seq_energy - current_seq_energy + transition_prob_backward - transition_prob_forward)))

        print(f"u: {u}, alpha: {alpha}")
        if u <= alpha:
            current_seq = proposal_seq
            current_seq_energy = proposal_seq_energy
            accepted += 1
            print("Proposal Sequence: ", proposal_seq)


        if proposal_seq_energy > current_seq_energy:
            max_seq[proposal_seq] = proposal_seq_energy
        else:
            max_seq[current_seq] = current_seq_energy


        # if i % 50 == 0:
        print(f"Acceptance rate after {i+1} iterations: {accepted/(i+1)*100} %")

        print("#"*80)

    print("Final proposed sequence: ", proposal_seq)
    print("Final acceptance rate:", accepted/(iter)*100, "%")

    print("Max energy sequence: ", max(max_seq, key=max_seq.get))

In [98]:
if __name__ == "__main__":
    print(f"Running on {device}")
    print("Starting...")
    print("#"*50)
    
    main()

Running on cuda
Starting...
##################################################
Original Sequence:  Can people sell their body parts for doctors to use in other people's bodies? No, it is illegal to sell body parts for medical use in the United States.
Original Sequence Energy:  -75.1650619506836
Original text IDs: tensor([[    0, 10836,    82,  1331,    49,   809,  1667,    13,  3333,     7,
           304,    11,    97,    82,    18,  3738,   116,   440,     6,    24,
            16,  2439,     7,  1331,   809,  1667,    13,  1131,   304,    11,
             5,   315,   532,     4,     2]], device='cuda:0')
Original text IDs length: 35
Original text tokens: ['Can', 'Ġpeople', 'Ġsell', 'Ġtheir', 'Ġbody', 'Ġparts', 'Ġfor', 'Ġdoctors', 'Ġto', 'Ġuse', 'Ġin', 'Ġother', 'Ġpeople', "'s", 'Ġbodies', '?', 'ĠNo', ',', 'Ġit', 'Ġis', 'Ġillegal', 'Ġto', 'Ġsell', 'Ġbody', 'Ġparts', 'Ġfor', 'Ġmedical', 'Ġuse', 'Ġin', 'Ġthe', 'ĠUnited', 'ĠStates', '.']
Question mark index original: 15
Initial Sequenc