## paper: [Contrastive Decoding: Open-ended Text Generation as Optimization](https://arxiv.org/abs/2210.15097)


In [2]:
import torch
import transformers

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
amateur_lm = transformers.AutoModelForCausalLM.from_pretrained('gpt2').to(device)
expert_lm = transformers.AutoModelForCausalLM.from_pretrained('gpt2-medium').to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
prompt = "Barack Obama was born in Honolulu, Hawaii. He was born in"

def generate_builtin(model, prompt, max_len=100, temperature = 0.8):
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)

    gen_tokens = model.generate(
        **model_inputs,
        do_sample=True,
        temperature=temperature,
        max_length=max_len,
    )
    return tokenizer.batch_decode(gen_tokens)[0]

print(
    f"""built-in generate for prompt: {prompt}
    responses: {generate_builtin(expert_lm, prompt, max_len=15)}"""
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
  attn_output = torch.nn.functional.scaled_dot_product_attention(


built-in generate for prompt: Barack Obama was born in Honolulu, Hawaii. He was born in
    responses: Barack Obama was born in Honolulu, Hawaii. He was born in Chicago


In [52]:
def my_generate(model, prompt, max_len=100, temperature=0.8):
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Start the generation with the initial inputs
    output_tokens = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"].to(device)

    past_key_values = None

    for i in range(max_len):
        if past_key_values is None:
            # For the first step, use the full input
            input = model_inputs
        else:
            # For subsequent steps, use the last generated token and past key values
            input = {
                "input_ids": output_tokens[:, -1].unsqueeze(0),  # Use only the last token
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
            }

        # Forward pass through the model
        resp = model(**input)
        logits = resp.logits[:, -1, :]  # Get the logits for the last token
        past_key_values = resp.past_key_values  # Update past key values

        # Sample next token
        prob = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(prob, num_samples=1).to(device)

        # Append the next token to the output
        output_tokens = torch.cat([output_tokens, next_token], dim=1)

        # Update the attention mask (add 1 for the new token)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1)).to(device)], dim=1)

    # Decode and return the generated sequence
    return tokenizer.decode(output_tokens.squeeze(), skip_special_tokens=True)


temp = 0.3
print(
    f"""
    prompt: {prompt}
    expert lm response: {my_generate(expert_lm, prompt, max_len=10, temperature=temp)}
    amature lm response: {my_generate(amateur_lm, prompt, max_len=10, temperature=temp)}
    """
)



    prompt: Barack Obama was born in Honolulu, Hawaii. He was born in
    expert lm response: Barack Obama was born in Honolulu, Hawaii. He was born in the United States on December 20, 1961.

    amature lm response: Barack Obama was born in Honolulu, Hawaii. He was born in Honolulu, Hawaii. (AP Photo/Mark Humph
    


In [62]:
def my_generate_cd_score(models, prompt, max_len=100, temperature=0.8):
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Start the generation with the initial inputs
    output_tokens = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"].to(device)

    past_key_values = [None] * 2
    logits = [None] * 2

    for _ in range(max_len):
        for i, model in enumerate(models): # exp first, ama snd       
            if past_key_values[i] is None:
                # For the first step, use the full input
                input = model_inputs
            else:
                # For subsequent steps, use the last generated token and past key values
                input = {
                    "input_ids": output_tokens[:, -1].unsqueeze(0),  # Use only the last token
                    "attention_mask": attention_mask,
                    "past_key_values": past_key_values[i],
                }

            # Forward pass through the model
            resp = model(**input)
            logits[i] = resp.logits[:, -1, :]  # Get the logits for the last token, shape: [batch, 1, n_token]
            past_key_values[i] = resp.past_key_values  # Update past key values

        logits_tensor = torch.cat(logits, dim=0).squeeze() # shape: [2, n_token]
        # Sample next token
        logprob = torch.log(torch.softmax(logits_tensor / temperature, dim=-1))
        cd_scores = logprob[0, :] - logprob[1, :]
        next_token = torch.argmax(cd_scores)

        # Append the next token to the output
        output_tokens = torch.cat([output_tokens, next_token.view([1, 1])], dim=1)

        # Update the attention mask (add 1 for the new token)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1)).to(device)], dim=1)

    # Decode and return the generated sequence
    return tokenizer.decode(output_tokens.squeeze(), skip_special_tokens=True)


print(
    f"""
    prompt: {prompt}
    expert lm response: {my_generate(expert_lm, prompt, max_len=10, temperature=temp)}
    amature lm response: {my_generate(amateur_lm, prompt, max_len=10, temperature=temp)}
    contrastive decoding response(no constraint): {my_generate_cd_score([expert_lm, amateur_lm], prompt, max_len=10, temperature=temp)}
    """
)



    prompt: Barack Obama was born in Honolulu, Hawaii. He was born in
    expert lm response: Barack Obama was born in Honolulu, Hawaii. He was born in the United States on January 20, 1961.

    amature lm response: Barack Obama was born in Honolulu, Hawaii. He was born in Honolulu, Hawaii, on July 1, 1964.
    contrastive decoding response(no constraint): Barack Obama was born in Honolulu, Hawaii. He was born in teasponestplace ACTIONSfectureanchesterchester CT────ItemTracker
    


In [83]:
def my_generate_cd_full(models, prompt, max_len=100, temperatures=[1, 0,5], alpha=0.9):
    model_inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Start the generation with the initial inputs
    output_tokens = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"].to(device)

    past_key_values = [None] * 2
    logits = [None] * 2
    probs = [None] * 2
    logprobs = [None] * 2

    for _ in range(max_len):
        for i, model in enumerate(models): # exp first, ama snd       
            if past_key_values[i] is None:
                # For the first step, use the full input
                input = model_inputs
            else:
                # For subsequent steps, use the last generated token and past key values
                input = {
                    "input_ids": output_tokens[:, -1].unsqueeze(0),  # Use only the last token
                    "attention_mask": attention_mask,
                    "past_key_values": past_key_values[i],
                }

            # Forward pass through the model
            resp = model(**input)
            logits[i] = resp.logits[:, -1, :]  # Get the logits for the last token, shape: [batch, 1, n_token]
            past_key_values[i] = resp.past_key_values  # Update past key values

            probs[i] = torch.softmax(logits[i] / temperatures[i], dim=-1)
            logprobs[i] = torch.log(probs[i])

        cd_scores = logprobs[0] - logprobs[1]

        mask = probs[0] < alpha * torch.max(probs[0])
        cd_scores[mask] = float('-inf')

        next_token = torch.argmax(cd_scores)

        # Append the next token to the output
        output_tokens = torch.cat([output_tokens, next_token.view([1, 1])], dim=1)

        # Update the attention mask (add 1 for the new token)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1)).to(device)], dim=1)

    # Decode and return the generated sequence
    return tokenizer.decode(output_tokens.squeeze(), skip_special_tokens=True)

alpha = 0.7
temp = 1
print(
    f"""
    prompt: {prompt}
    expert lm response: {my_generate(expert_lm, prompt, max_len=10, temperature=temp)}
    amature lm response: {my_generate(amateur_lm, prompt, max_len=10, temperature=temp)}
    contrastive decoding response(no constraint): {my_generate_cd_score([expert_lm, amateur_lm], prompt, max_len=10)}
    contrastive decoding response(full): {my_generate_cd_full([expert_lm, amateur_lm], prompt, max_len=10, alpha=alpha)}
    """
)



    prompt: Barack Obama was born in Honolulu, Hawaii. He was born in
    expert lm response: Barack Obama was born in Honolulu, Hawaii. He was born in February, 1961. We're not sure when he
    amature lm response: Barack Obama was born in Honolulu, Hawaii. He was born in America for the first time at the adoption office for
    contrastive decoding response(no constraint): Barack Obama was born in Honolulu, Hawaii. He was born in teasponestplace ACTIONSfectureanchesterchester CT────ItemTracker
    contrastive decoding response(full): Barack Obama was born in Honolulu, Hawaii. He was born in the United States on January 20, 1961.

    
