In [None]:
!pip install bitsandbytes

In [None]:
!pip install ipywidgets
!jupyter nbextension enable --py widgetsnbextension

In [None]:
# Input/output locations; each prompt set has a matching output pickle.
_DATA_DIR = 'data'
_PROMPTS_DIR = 'prompts'
_OUTPUT_DIR = 'output'

true_addresses_input_path = f'{_DATA_DIR}/medium_ubuntu_true_addresses.pkl'

# Prompt pickles, one per prompt variant.
ar_conv_prompts_input_path = f'{_PROMPTS_DIR}/ar_conv_prompts.pkl'
ar_struct_prompts_input_path = f'{_PROMPTS_DIR}/ar_struct_prompts.pkl'
ar_conv_struct_prompts_input_path = f'{_PROMPTS_DIR}/ar_conv_struct_prompts.pkl'
ar_struct_summ_prompts_input_path = f'{_PROMPTS_DIR}/ar_struct_summ_prompts.pkl'
ar_struct_desc_prompts_input_path = f'{_PROMPTS_DIR}/ar_struct_desc_prompts.pkl'
ar_struct_summ_desc_prompts_input_path = f'{_PROMPTS_DIR}/ar_struct_summ_desc_prompts.pkl'

# Result pickles, one per prompt variant.
ar_conv_prompts_output_path = f'{_OUTPUT_DIR}/ar_conv_out.pkl'
ar_struct_prompts_output_path = f'{_OUTPUT_DIR}/ar_struct_out.pkl'
ar_conv_struct_prompts_output_path = f'{_OUTPUT_DIR}/ar_conv_struct_out.pkl'
ar_struct_summ_prompts_output_path = f'{_OUTPUT_DIR}/ar_struct_summ_out.pkl'
ar_struct_desc_prompts_output_path = f'{_OUTPUT_DIR}/ar_struct_desc_out.pkl'
ar_struct_summ_desc_prompts_output_path = f'{_OUTPUT_DIR}/ar_struct_summ_desc_out.pkl'

# Model

Log in to Hugging Face

In [None]:
# %pip install --upgrade huggingface_hub
import getpass
import os

from huggingface_hub import login

# Never hard-code the access token in the notebook: read it from the HF_TOKEN
# environment variable, or prompt for it interactively when it is not set.
login(token=os.environ.get("HF_TOKEN") or getpass.getpass("Hugging Face token: "))

In [None]:
!ls data

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
import transformers
from tqdm.notebook import tqdm 
import torch

model_name = "meta-llama/Llama-2-13b-chat-hf"

# The tokenizer and the float16 causal LM are loaded from the same checkpoint;
# device_map={"": 0} places the whole model on device 0.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map={"": 0},
    torch_dtype=torch.float16,
)



In [None]:
!pip install evaluate

# Perplexity Measure

In [None]:
from evaluate import load
from statistics import mean, median, stdev
import os
from os.path import exists, join, isdir
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
import csv
from torch.nn import CrossEntropyLoss
import datasets
import numpy as np
import torch
import math
from typing import List
from evaluate import logging
import json

In [None]:
# NOTE(review): this metric object appears unused; compute_perplexity reimplements the
# metric against the locally loaded model — confirm before removing.
perplexity = load(path="perplexity", module_type="metric")

In [None]:
def compute_perplexity(
    predictions, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None
):
    """
    Compute the perplexity of each text in ``predictions`` under the global ``model``/``tokenizer``.

    Adapted from the Hugging Face ``evaluate`` perplexity metric, but it reuses the model
    already loaded in this notebook instead of loading one by id.

    Parameters:
    - predictions: List of strings to score.
    - batch_size: Number of texts per forward pass.
    - add_start_token: If True, prepend the BOS token so every text token is scored;
      if False, the first token only serves as context and is not scored.
    - device: "cpu", "gpu" or "cuda"; defaults to "cuda" when available, else "cpu".
    - max_length: Optional maximum number of tokens per text (counting the BOS token when
      add_start_token is True); longer texts are truncated.

    Returns:
    - Dict with "perplexities" (one float per text, in input order) and "mean_perplexity".

    Side effects:
    - May register a pad token on the global ``tokenizer`` when padding is required.
    """
    if device is not None:
        assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
        if device == "gpu":
            device = "cuda"
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # All texts are tokenized into one tensor, so several texts of different lengths
    # can only be stacked if they are padded.
    needs_padding = not isinstance(predictions, str) and len(predictions) > 1

    # If padding may be required and there is no pad token yet, reuse an existing
    # special token as the padding token.
    if tokenizer.pad_token is None and (batch_size > 1 or needs_padding):
        existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
        # check that the model already has at least one special token defined
        assert (
            len(existing_special_tokens) > 0
        ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
        tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

    if add_start_token and max_length:
        # leave room for the <BOS> token that is prepended below
        assert (
            tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    # NOTE(review): with a left-padding tokenizer, the prepended BOS ends up separated
    # from the text by pad tokens — confirm the padding side before batching
    # texts of different lengths.
    encodings = tokenizer(
        predictions,
        add_special_tokens=False,
        padding=needs_padding,
        truncation=True if max_tokenized_len else False,
        max_length=max_tokenized_len,
        return_tensors="pt",
        return_attention_mask=True,
    ).to(device)

    encoded_texts = encodings["input_ids"]
    attn_masks = encodings["attention_mask"]

    # check that each input is long enough:
    if add_start_token:
        assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
    else:
        assert torch.all(
            torch.ge(attn_masks.sum(1), 2)
        ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

    ppls = []
    loss_fct = CrossEntropyLoss(reduction="none")

    for start_index in range(0, len(encoded_texts), batch_size):
        end_index = min(start_index + batch_size, len(encoded_texts))
        encoded_batch = encoded_texts[start_index:end_index]
        attn_mask = attn_masks[start_index:end_index]

        if add_start_token:
            bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
            encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
            attn_mask = torch.cat(
                [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
            )

        labels = encoded_batch

        with torch.no_grad():
            # Upcast to float32. With a float16 model the per-token losses and their sum
            # would be rounded to half precision, which is far too coarse once the
            # perplexity is turned back into a total log-probability over a long prompt.
            out_logits = model(encoded_batch, attention_mask=attn_mask).logits.float()

        # Position k predicts token k+1: drop the last logit and the first label.
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        # exp(mean token loss over non-padding positions), one value per text
        perplexity_batch = torch.exp(
            (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
            / shift_attention_mask_batch.sum(1)
        )

        ppls += perplexity_batch.tolist()

    return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}

In [None]:
import math

def extract_probability(ppls, len_sents):
    """
    Convert a perplexity result back into a total log-probability.

    Only the first entry of ``ppls['perplexities']`` is used, so the result describes
    the first (typically the only) scored text.

    Parameters:
    - ppls: Dict holding a "perplexities" list, as returned by compute_perplexity.
    - len_sents: Token count to multiply by. The result is the text's total
      log-probability only when this equals the number of tokens the perplexity
      was averaged over.

    Returns:
    - log_prob: ``-len_sents * log(perplexity)`` (a value <= 0 for perplexities >= 1).
    """
    first_ppl = ppls['perplexities'][0]
    # perplexity = exp(-log_prob / n)  =>  log_prob = -n * log(perplexity)
    return -len_sents * math.log(first_ppl)

In [None]:
def compute_conditional_perplexity(tokenizer, sents):
    """
    Compute the conditional perplexity of each sentence given all sentences before it.

    For i >= 1, the context is ``sents[0] + ... + sents[i-1]`` and the extended text
    appends ``sents[i]``. The difference of their total log-probabilities is
    log P(sents[i] | context), normalised by the number of tokens sents[i] adds.

    Parameters:
    - tokenizer: Tokenizer of the scoring model; must match the global tokenizer used
      by compute_perplexity so that token counts agree.
    - sents: A list of sentence strings.

    Returns:
    - cond_ppl: Mean of the per-sentence conditional perplexities.
    - cond_ppls: Conditional perplexity of sents[i] given sents[:i], for i = 1..len(sents)-1.

    Raises:
    - statistics.StatisticsError if ``sents`` has fewer than two sentences.
    """

    def _score(text):
        # Total log-probability of `text` and its token count. compute_perplexity
        # tokenizes without special tokens and, with add_start_token=False, averages the
        # loss over every token except the first. Count the same way so that
        # extract_probability recovers the true total log-probability.
        n_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        log_prob = extract_probability(
            compute_perplexity(predictions=[text], batch_size=1, add_start_token=False),
            n_tokens - 1,
        )
        return log_prob, n_tokens

    cond_ppls = []
    if len(sents) >= 2:
        context = str(sents[0])
        p_c, n_c = _score(context)
        for i in range(1, len(sents)):
            extended = str(context + sents[i])
            p_t, n_t = _score(extended)

            # Tokens sents[i] actually contributes inside the extended text (its standalone
            # tokenization can differ at the boundary); at least 1 to avoid division by zero.
            n_resp = max(n_t - n_c, 1)
            cond_ppls.append(math.exp((p_c - p_t) / n_resp))

            # The extended text is the next iteration's context: reuse its score.
            context, p_c, n_c = extended, p_t, n_t

    # Mean over all conditioned sentences (StatisticsError when there are none).
    cond_ppl = mean(cond_ppls)

    return cond_ppl, cond_ppls

In [None]:
def compute_single_conditional_perplexity(tokenizer, sents):
    """
    Compute the conditional perplexity of a sentence following another, using the loaded language model.

    The context ``c = sents[0]`` and the concatenation ``t = sents[0] + sents[1]`` are both
    scored. The difference of their total log-probabilities is log P(sents[1] | sents[0]),
    which is normalised by the number of tokens sents[1] adds to t.

    Parameters:
    - tokenizer: Tokenizer of the scoring model; must match the global tokenizer used
      by compute_perplexity so that token counts agree.
    - sents: A pair ``[context, continuation]`` of strings.

    Returns:
    - cond_ppl: ``exp(-log P(continuation | context) / n_continuation_tokens)``; lower
      values mean the continuation is more likely given the context.
    """
    c = str(sents[0])
    t = str(c + sents[1])

    # compute_perplexity tokenizes without special tokens and, with add_start_token=False,
    # averages the loss over every token except the first. Count tokens the same way so
    # that extract_probability recovers the true total log-probability. (The default
    # tokenizer.encode would add special tokens and over-count.)
    n_tokens_c = len(tokenizer.encode(c, add_special_tokens=False))
    n_tokens_t = len(tokenizer.encode(t, add_special_tokens=False))

    # Total log-probabilities of the context c and the extended text t.
    p_c = extract_probability(
        compute_perplexity(predictions=[c], batch_size=1, add_start_token=False), n_tokens_c - 1
    )
    p_t = extract_probability(
        compute_perplexity(predictions=[t], batch_size=1, add_start_token=False), n_tokens_t - 1
    )

    # Tokens the continuation actually contributes inside t (its standalone tokenization
    # can differ at the boundary); at least 1 to avoid division by zero.
    n_tokens_resp = max(n_tokens_t - n_tokens_c, 1)

    # p_c - p_t = -log P(continuation | context)
    cond_ppl = math.exp((p_c - p_t) / n_tokens_resp)

    return cond_ppl

# EXPERIMENTS

In [None]:
import pickle

# Labels compared position-by-position with each experiment's predicted candidate index.
# pickle can execute code on load: only point this at trusted files.
with open(true_addresses_input_path, 'rb') as fh:
    true_addresses = pickle.load(fh)

# CONV

In [None]:
import pickle

with open(ar_conv_prompts_input_path, 'rb') as f:
    ar_conv_prompts = pickle.load(f)

# Fail fast, before the slow scoring loop: every prompt needs a true address.
if len(ar_conv_prompts) > len(true_addresses):
    raise ValueError(
        f"{len(ar_conv_prompts)} prompts but only {len(true_addresses)} true addresses"
    )

ar_conv_output = []
ar_conv_output_cppls = []
ar_conv_output_distribution_cppls = []

# Candidate addressee labels; a prediction is an index into this list.
candidates = ["[ALEX]", "[BENNY]", "[CAM]", "[DANA]", "[OTHER]"]

for idx in tqdm(range(len(ar_conv_prompts)), desc="CONV"):
    prompt = ar_conv_prompts[idx]

    # Conditional perplexity of each candidate label as the continuation of the prompt.
    cppls = [
        compute_single_conditional_perplexity(tokenizer, [prompt, resp])
        for resp in candidates
    ]

    # Predict the least-perplexing candidate (the first one on ties).
    min_cppls = min(cppls)
    res_out = cppls.index(min_cppls)

    ar_conv_output.append(res_out)
    ar_conv_output_cppls.append(min_cppls)
    ar_conv_output_distribution_cppls.append(cppls)

# 1 where the predicted index equals the true address at the same position, else 0.
# (Assumes true_addresses holds candidate indices — confirm.)
ar_conv_row = [
    1 if ar_conv_output[i] == true_addresses[i] else 0
    for i in range(len(ar_conv_output))
]

print(len(ar_conv_output))
# Accuracy; NaN instead of ZeroDivisionError when the prompt file is empty.
print(sum(ar_conv_row) / len(ar_conv_row) if ar_conv_row else float("nan"))

ar_conv_out = {
    "output": ar_conv_output,
    "output_cppls": ar_conv_output_cppls,
    "output_distribution_cppls": ar_conv_output_distribution_cppls,
    "row": ar_conv_row,
}

with open(ar_conv_prompts_output_path, 'wb') as f:
    pickle.dump(ar_conv_out, f)

In [None]:
print(len(ar_conv_output))
# Accuracy; NaN instead of ZeroDivisionError when there were no predictions.
print(sum(ar_conv_row) / len(ar_conv_row) if ar_conv_row else float("nan"))

# STRUCT

In [None]:
import pickle

with open(ar_struct_prompts_input_path, 'rb') as f:
    ar_struct_prompts = pickle.load(f)

# Fail fast, before the slow scoring loop: every prompt needs a true address.
if len(ar_struct_prompts) > len(true_addresses):
    raise ValueError(
        f"{len(ar_struct_prompts)} prompts but only {len(true_addresses)} true addresses"
    )

ar_struct_output = []
ar_struct_output_cppls = []
ar_struct_output_distribution_cppls = []

# Candidate addressee labels; a prediction is an index into this list.
candidates = ["[ALEX]", "[BENNY]", "[CAM]", "[DANA]", "[OTHER]"]

for idx in tqdm(range(len(ar_struct_prompts)), desc="STRUCT"):
    prompt = ar_struct_prompts[idx]

    # Conditional perplexity of each candidate label as the continuation of the prompt.
    cppls = [
        compute_single_conditional_perplexity(tokenizer, [prompt, resp])
        for resp in candidates
    ]

    # Predict the least-perplexing candidate (the first one on ties).
    min_cppls = min(cppls)
    res_out = cppls.index(min_cppls)

    ar_struct_output.append(res_out)
    ar_struct_output_cppls.append(min_cppls)
    ar_struct_output_distribution_cppls.append(cppls)

# 1 where the predicted index equals the true address at the same position, else 0.
# (Assumes true_addresses holds candidate indices — confirm.)
ar_struct_row = [
    1 if ar_struct_output[i] == true_addresses[i] else 0
    for i in range(len(ar_struct_output))
]

print(len(ar_struct_output))
# Accuracy; NaN instead of ZeroDivisionError when the prompt file is empty.
print(sum(ar_struct_row) / len(ar_struct_row) if ar_struct_row else float("nan"))

ar_struct_out = {
    "output": ar_struct_output,
    "output_cppls": ar_struct_output_cppls,
    "output_distribution_cppls": ar_struct_output_distribution_cppls,
    "row": ar_struct_row,
}

with open(ar_struct_prompts_output_path, 'wb') as f:
    pickle.dump(ar_struct_out, f)

In [None]:
print(len(ar_struct_output))
# Accuracy; NaN instead of ZeroDivisionError when there were no predictions.
print(sum(ar_struct_row) / len(ar_struct_row) if ar_struct_row else float("nan"))

# CONV+STRUCT

In [None]:
import pickle

with open(ar_conv_struct_prompts_input_path, 'rb') as f:
    ar_conv_struct_prompts = pickle.load(f)

# Fail fast, before the slow scoring loop: every prompt needs a true address.
if len(ar_conv_struct_prompts) > len(true_addresses):
    raise ValueError(
        f"{len(ar_conv_struct_prompts)} prompts but only {len(true_addresses)} true addresses"
    )

ar_conv_struct_output = []
ar_conv_struct_output_cppls = []
ar_conv_struct_output_distribution_cppls = []

# Candidate addressee labels; a prediction is an index into this list.
candidates = ["[ALEX]", "[BENNY]", "[CAM]", "[DANA]", "[OTHER]"]

for idx in tqdm(range(len(ar_conv_struct_prompts)), desc="CONV+STRUCT"):
    prompt = ar_conv_struct_prompts[idx]

    # Conditional perplexity of each candidate label as the continuation of the prompt.
    cppls = [
        compute_single_conditional_perplexity(tokenizer, [prompt, resp])
        for resp in candidates
    ]

    # Predict the least-perplexing candidate (the first one on ties).
    min_cppls = min(cppls)
    res_out = cppls.index(min_cppls)

    ar_conv_struct_output.append(res_out)
    ar_conv_struct_output_cppls.append(min_cppls)
    ar_conv_struct_output_distribution_cppls.append(cppls)

# 1 where the predicted index equals the true address at the same position, else 0.
# (Assumes true_addresses holds candidate indices — confirm.)
ar_conv_struct_row = [
    1 if ar_conv_struct_output[i] == true_addresses[i] else 0
    for i in range(len(ar_conv_struct_output))
]

print(len(ar_conv_struct_output))
# Accuracy; NaN instead of ZeroDivisionError when the prompt file is empty.
print(sum(ar_conv_struct_row) / len(ar_conv_struct_row) if ar_conv_struct_row else float("nan"))

ar_conv_struct_out = {
    "output": ar_conv_struct_output,
    "output_cppls": ar_conv_struct_output_cppls,
    "output_distribution_cppls": ar_conv_struct_output_distribution_cppls,
    "row": ar_conv_struct_row,
}

with open(ar_conv_struct_prompts_output_path, 'wb') as f:
    pickle.dump(ar_conv_struct_out, f)

In [None]:
print(len(ar_conv_struct_output))
# Accuracy; NaN instead of ZeroDivisionError when there were no predictions.
print(sum(ar_conv_struct_row) / len(ar_conv_struct_row) if ar_conv_struct_row else float("nan"))

# STRUCT + SUMM

In [None]:
import pickle

with open(ar_struct_summ_prompts_input_path, 'rb') as f:
    ar_struct_summ_prompts = pickle.load(f)

# Fail fast, before the slow scoring loop: every prompt needs a true address.
if len(ar_struct_summ_prompts) > len(true_addresses):
    raise ValueError(
        f"{len(ar_struct_summ_prompts)} prompts but only {len(true_addresses)} true addresses"
    )

# Spot-check one prompt (skipped when the file has 100 prompts or fewer).
if len(ar_struct_summ_prompts) > 100:
    print(ar_struct_summ_prompts[100])

ar_struct_summ_output = []
ar_struct_summ_output_cppls = []
ar_struct_summ_output_distribution_cppls = []

# Candidate addressee labels; a prediction is an index into this list.
candidates = ["[ALEX]", "[BENNY]", "[CAM]", "[DANA]", "[OTHER]"]

for idx in tqdm(range(len(ar_struct_summ_prompts)), desc="STRUCT+SUMM"):
    prompt = ar_struct_summ_prompts[idx]

    # Conditional perplexity of each candidate label as the continuation of the prompt.
    cppls = [
        compute_single_conditional_perplexity(tokenizer, [prompt, resp])
        for resp in candidates
    ]

    # Predict the least-perplexing candidate (the first one on ties).
    min_cppls = min(cppls)
    res_out = cppls.index(min_cppls)

    ar_struct_summ_output.append(res_out)
    ar_struct_summ_output_cppls.append(min_cppls)
    ar_struct_summ_output_distribution_cppls.append(cppls)

# 1 where the predicted index equals the true address at the same position, else 0.
# (Assumes true_addresses holds candidate indices — confirm.)
ar_struct_summ_row = [
    1 if ar_struct_summ_output[i] == true_addresses[i] else 0
    for i in range(len(ar_struct_summ_output))
]

print(len(ar_struct_summ_output))
# Accuracy; NaN instead of ZeroDivisionError when the prompt file is empty.
print(sum(ar_struct_summ_row) / len(ar_struct_summ_row) if ar_struct_summ_row else float("nan"))

ar_struct_summ_out = {
    "output": ar_struct_summ_output,
    "output_cppls": ar_struct_summ_output_cppls,
    "output_distribution_cppls": ar_struct_summ_output_distribution_cppls,
    "row": ar_struct_summ_row,
}

with open(ar_struct_summ_prompts_output_path, 'wb') as f:
    pickle.dump(ar_struct_summ_out, f)

In [None]:
print(len(ar_struct_summ_output))
# Accuracy; NaN instead of ZeroDivisionError when there were no predictions.
print(sum(ar_struct_summ_row) / len(ar_struct_summ_row) if ar_struct_summ_row else float("nan"))

# STRUCT + DESC

In [None]:
import pickle

with open(ar_struct_desc_prompts_input_path, 'rb') as f:
    ar_struct_desc_prompts = pickle.load(f)

# Fail fast, before the slow scoring loop: every prompt needs a true address.
if len(ar_struct_desc_prompts) > len(true_addresses):
    raise ValueError(
        f"{len(ar_struct_desc_prompts)} prompts but only {len(true_addresses)} true addresses"
    )

# Spot-check one prompt (skipped when the file has 100 prompts or fewer).
if len(ar_struct_desc_prompts) > 100:
    print(ar_struct_desc_prompts[100])

ar_struct_desc_output = []
ar_struct_desc_output_cppls = []
ar_struct_desc_output_distribution_cppls = []

# Candidate addressee labels; a prediction is an index into this list.
candidates = ["[ALEX]", "[BENNY]", "[CAM]", "[DANA]", "[OTHER]"]

for idx in tqdm(range(len(ar_struct_desc_prompts)), desc="STRUCT+DESC"):
    prompt = ar_struct_desc_prompts[idx]

    # Conditional perplexity of each candidate label as the continuation of the prompt.
    cppls = [
        compute_single_conditional_perplexity(tokenizer, [prompt, resp])
        for resp in candidates
    ]

    # Predict the least-perplexing candidate (the first one on ties).
    min_cppls = min(cppls)
    res_out = cppls.index(min_cppls)

    ar_struct_desc_output.append(res_out)
    ar_struct_desc_output_cppls.append(min_cppls)
    ar_struct_desc_output_distribution_cppls.append(cppls)

# 1 where the predicted index equals the true address at the same position, else 0.
# (Assumes true_addresses holds candidate indices — confirm.)
ar_struct_desc_row = [
    1 if ar_struct_desc_output[i] == true_addresses[i] else 0
    for i in range(len(ar_struct_desc_output))
]

print(len(ar_struct_desc_output))
# Accuracy; NaN instead of ZeroDivisionError when the prompt file is empty.
print(sum(ar_struct_desc_row) / len(ar_struct_desc_row) if ar_struct_desc_row else float("nan"))

ar_struct_desc_out = {
    "output": ar_struct_desc_output,
    "output_cppls": ar_struct_desc_output_cppls,
    "output_distribution_cppls": ar_struct_desc_output_distribution_cppls,
    "row": ar_struct_desc_row,
}

with open(ar_struct_desc_prompts_output_path, 'wb') as f:
    pickle.dump(ar_struct_desc_out, f)

In [None]:
print(len(ar_struct_desc_output))
# Accuracy; NaN instead of ZeroDivisionError when there were no predictions.
print(sum(ar_struct_desc_row) / len(ar_struct_desc_row) if ar_struct_desc_row else float("nan"))

# STRUCT + SUMM + DESC

In [None]:
import pickle

with open(ar_struct_summ_desc_prompts_input_path, 'rb') as f:
    ar_struct_summ_desc_prompts = pickle.load(f)

# Fail fast, before the slow scoring loop: every prompt needs a true address.
if len(ar_struct_summ_desc_prompts) > len(true_addresses):
    raise ValueError(
        f"{len(ar_struct_summ_desc_prompts)} prompts but only {len(true_addresses)} true addresses"
    )

# Spot-check one prompt (skipped when the file has 100 prompts or fewer).
if len(ar_struct_summ_desc_prompts) > 100:
    print(ar_struct_summ_desc_prompts[100])

ar_struct_summ_desc_output = []
ar_struct_summ_desc_output_cppls = []
ar_struct_summ_desc_output_distribution_cppls = []

# Candidate addressee labels; a prediction is an index into this list.
candidates = ["[ALEX]", "[BENNY]", "[CAM]", "[DANA]", "[OTHER]"]

for idx in tqdm(range(len(ar_struct_summ_desc_prompts)), desc="STRUCT+SUMM+DESC"):
    prompt = ar_struct_summ_desc_prompts[idx]

    # Conditional perplexity of each candidate label as the continuation of the prompt.
    cppls = [
        compute_single_conditional_perplexity(tokenizer, [prompt, resp])
        for resp in candidates
    ]

    # Predict the least-perplexing candidate (the first one on ties).
    min_cppls = min(cppls)
    res_out = cppls.index(min_cppls)

    ar_struct_summ_desc_output.append(res_out)
    ar_struct_summ_desc_output_cppls.append(min_cppls)
    ar_struct_summ_desc_output_distribution_cppls.append(cppls)

# 1 where the predicted index equals the true address at the same position, else 0.
# (Assumes true_addresses holds candidate indices — confirm.)
ar_struct_summ_desc_row = [
    1 if ar_struct_summ_desc_output[i] == true_addresses[i] else 0
    for i in range(len(ar_struct_summ_desc_output))
]

print(len(ar_struct_summ_desc_output))
# Accuracy; NaN instead of ZeroDivisionError when the prompt file is empty.
print(
    sum(ar_struct_summ_desc_row) / len(ar_struct_summ_desc_row)
    if ar_struct_summ_desc_row
    else float("nan")
)

ar_struct_summ_desc_out = {
    "output": ar_struct_summ_desc_output,
    "output_cppls": ar_struct_summ_desc_output_cppls,
    "output_distribution_cppls": ar_struct_summ_desc_output_distribution_cppls,
    "row": ar_struct_summ_desc_row,
}

with open(ar_struct_summ_desc_prompts_output_path, 'wb') as f:
    pickle.dump(ar_struct_summ_desc_out, f)

In [None]:
print(len(ar_struct_summ_desc_output))
# Accuracy; NaN instead of ZeroDivisionError when there were no predictions.
print(
    sum(ar_struct_summ_desc_row) / len(ar_struct_summ_desc_row)
    if ar_struct_summ_desc_row
    else float("nan")
)