# Dev contrastive knowledge assesment (CKA) notebook

* See [Capstone Repo](https://github.com/daniel-furman/Capstone) for more details
---

## Dependencies

In [77]:
!git clone https://github.com/daniel-furman/Capstone.git

fatal: destination path 'Capstone' already exists and is not an empty directory.


In [78]:
!pip install -r /content/Capstone/src/cka/scripts/requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## Imports

In [79]:
import os
import numpy as np

import torch
from torch.nn.functional import softmax

from transformers import (
    set_seed,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    T5Tokenizer,
    T5ForConditionalGeneration,
)

In [80]:
device = torch.device("cuda")

## Dev functions for new models

In [81]:
def probe_flan(model, tokenizer, target_id, context, verbose=False):

    # tokenize context
    tokenized_context = tokenizer(
        context,
        padding="longest",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    ).to(device) 

    input_ids = tokenized_context["input_ids"]

    # use model to solicit a prediction
    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=torch.tensor([[0, 32099]], device="cuda:0"),
        output_hidden_states=True,
        return_dict=True,
    )

    # We have batch size of 1, so grab that, then,
    # Take the entire last matrix which corresponds to the last layer
    logits = outputs["logits"][0, -1]

    # convert our prediction scores to a probability distribution with softmax
    probs = softmax(logits, dim=-1)

    probs = probs.detach().cpu().numpy()

    if verbose:
        print(f'\tcontext... {context}')
        print(f'\ttokenized_context ids... {tokenized_context["input_ids"]}')
        print(f'\tdecoded tokenized_context... {tokenizer.decode(tokenized_context["input_ids"][0])}')
        print(f'\tdecoded target id... {tokenizer.decode([target_id.item()])}')

    return probs[target_id.item()]


def probe_gpt2(model, tokenizer, target_id, context, verbose=False):

    # tokenize context
    tokenized_context = tokenizer(
        context,
        return_tensors="pt",
    ).to(device) 

    input_ids = tokenized_context["input_ids"]

    # grab value
    target_scalar = target_id.detach().cpu().numpy()

    # use model to solicit a prediction
    outputs = model(input_ids=input_ids, output_hidden_states=True, return_dict=True)

    # shape of 50257 which corresponds to the vocab size of GPT
    # every token in GPT's vocab gets a representative prediction from the model
    logits = outputs["logits"][0, -1]
    # convert our prediction scores to a probability distribution with softmax
    probs = softmax(logits, dim=-1)

    probs = list(probs.detach().cpu().numpy())

    if verbose:
        print(f'\tcontext... {context}')
        print(f'\ttokenized_context ids... {tokenized_context["input_ids"]}')
        print(f'\tdecoded tokenized_context... {tokenizer.decode(tokenized_context["input_ids"][0])}')
        print(f'\tdecoded target id... {tokenizer.decode([target_id.item()])}')
        
    # double check weird-ness before accessing prob
    if len(probs) < target_id:
        return None

    # return the likelihood that our stipulated target would follow the context,
    # according to the model
    try:
        return np.take(probs, [target_scalar])[0]

    except IndexError:

        print("target index not in model vocabulary scope; raising IndexError")
        return None


def probe_bert(model, tokenizer, target_id, context, verbose=False):

    # tokenize context
    tokenized_context = tokenizer(
        context,
        padding="longest",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )

    mask_token_index = torch.where(tokenized_context["input_ids"] == tokenizer.mask_token_id)[1]

    # use model to solicit a prediction
    logits = model(**tokenized_context.to(device)).logits
    mask_token_logits = logits[0, mask_token_index, :]

    # Convert our prediction scores to a probability distribution with softmax
    probs = torch.squeeze(softmax(mask_token_logits, dim=-1))

    probs = probs.detach().cpu().numpy()

    if verbose:
        print(f'\tcontext... {context}')
        print(f'\ttokenized_context ids... {tokenized_context["input_ids"]}')
        print(f'\tdecoded tokenize_context... {tokenizer.decode(tokenized_context["input_ids"][0])}')
        print(f'\tmask token id... {tokenizer.mask_token_id}')
        print(f'\tmask token index in context... {mask_token_index}')
        print(f'\tdecoded target id... {tokenizer.decode([target_id.item()])}')

    return probs[target_id.item()]

In [82]:
# first, write helper to pull a pretrained LM and tokenizer off the shelf
def get_model_and_tokenizer(model_name):
    if "flan" in model_name.lower():
        return T5Tokenizer.from_pretrained(
            model_name
        ), T5ForConditionalGeneration.from_pretrained(
            model_name, load_in_8bit=True, device_map="auto"
        )

    elif "gpt" in model_name.lower():
        return AutoTokenizer.from_pretrained(
            model_name
        ), AutoModelForCausalLM.from_pretrained(
            model_name, load_in_8bit=True, device_map="auto"
        )

    elif "bert" in model_name.lower():
        return AutoTokenizer.from_pretrained(
            model_name
        ), AutoModelForMaskedLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).to(device)


# next, write a helper to pull a probe function for the given LM
def get_probe_function(prefix):
    probe_functions = [probe_flan, probe_gpt2, probe_bert]
    for func in probe_functions:
        if prefix.lower() in func.__name__:
            return func


# lastly, write a wrapper function to compare models
def compare_models(model_name_list, input_pairings, verbose):

    """
    Model-wise comparison helper function

    we should be able to do the following:
      * input a set of models we want to evaluate
      * input an expression of interest
      * input a 'true' next-token alonside a false
      * and get an output report that contains..
        * the 'result' ie is true > false
        * the probabilities of both of those values
      * running this method over a large set of positive/negative pairings should result in a large pool of information that can be used to compare model-families
      * we can also look at the relative 'certainty' across different models (at least in orders of magnitude)

    """

    score_dict_full = {}
    score_dict_succinct = {}

    for model_name in model_name_list:
        print(f"CKA for {model_name}")
        print("Loading  model...")

        # get proper model and tokenizer
        tokenizer, model = get_model_and_tokenizer(model_name)

        print("Running comparisons...")

        # establish prefix
        prefix = ""
        probe_func = None

        # get correct CKA function
        if ("t5" in model_name.lower()) or ("ul2" in model_name.lower()):
            prefix = "flan"
            probe_func = get_probe_function(prefix)

        elif "gpt" in model_name.lower():
            prefix = "gpt"
            probe_func = get_probe_function(prefix)

        elif "bert" in model_name.lower():
            prefix = "bert"
            probe_func = get_probe_function(prefix)

        # iterate over context/entity pairings
        # input_pairings is a dict
        # context is a plain string (since our context's will be unique)
        # and entities is a list containing, in the first slot, the true
        # value for the statement and in the subsequent slots, incorrect information

        for context, entities in input_pairings.items():
            entity_count = 0
            p_true = 0.0
            p_false = 0.0

            if prefix == "flan":
                context += " <extra_id_0>."

            if prefix == "bert":
                context += " [MASK]."

            for entity in entities:
                target_id = None
                # first find target vocab id
                # default to the very first token that get's predicted
                # e.g. in the case of Tokyo, which gets split into <Tok> <yo>,

                if prefix == "flan":
                    target_id = tokenizer.encode(
                        entity,
                        padding="longest",
                        max_length=512,
                        truncation=True,
                        return_tensors="pt",
                    ).to(device)[0][0]

                elif prefix == "gpt":
                    target_id = tokenizer.encode(entity, return_tensors="pt").to(device)[0][0]
                
                elif prefix == "bert":
                    target_id = tokenizer.encode(
                        entity,
                        padding="longest",
                        max_length=512,
                        truncation=True,
                        return_tensors="pt",
                    ).to(device)[0][1]
                
                # next call probe function
                model_prob = probe_func(model, tokenizer, target_id, context, verbose)
                
                # lastly, register results
                if entity_count == 0:
                    p_true = model_prob

                else:
                    p_false += model_prob

                entity_count += 1

            p_false /= entity_count - 1
            score_dict_full[model_name.lower() + ": " + context] = {
                "p_true": p_true,
                "p_false": p_false,
                "p_true - p_false": p_true - p_false,
                "p_true > p_false": p_true > p_false,
            }

            score_dict_succinct[model_name.lower() + ": " + context] = {
                "p_true > p_false": p_true > p_false
            }

        print("Done\n")
        del tokenizer
        del model
        torch.cuda.empty_cache()

    return score_dict_full, score_dict_succinct


## Test the functions

In [83]:
def main(config):

    set_seed(42)

    score_dict_full, score_dict_succinct = compare_models(
        config["models"], config["input_information"], config["verbosity"]
    )

    return score_dict_full, score_dict_succinct

### Bert

In [84]:
config = {
    "models": [
        "bert-base-uncased", # 110M params
    ],
    "input_information": {
        "The 2020 Olympics were held in": ["Tokyo", "Berlin"],
        "Operation Overlord took place in": ["Normandy", "Manila"],
        "Steve Jobs is the founder of": ["Apple", "Microsoft"],
    },
    "verbosity": True
}


In [85]:
score_dict_full, score_dict_succinct = main(config)

CKA for bert-base-uncased
Loading  model...


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Running comparisons...
	context... The 2020 Olympics were held in [MASK].
	tokenized_context ids... tensor([[  101,  1996, 12609,  3783,  2020,  2218,  1999,   103,  1012,   102]],
       device='cuda:0')
	decoded tokenize_context... [CLS] the 2020 olympics were held in [MASK]. [SEP]
	mask token id... 103
	mask token index in context... tensor([7])
	decoded target id... tokyo
	context... The 2020 Olympics were held in [MASK].
	tokenized_context ids... tensor([[  101,  1996, 12609,  3783,  2020,  2218,  1999,   103,  1012,   102]],
       device='cuda:0')
	decoded tokenize_context... [CLS] the 2020 olympics were held in [MASK]. [SEP]
	mask token id... 103
	mask token index in context... tensor([7])
	decoded target id... berlin
	context... Operation Overlord took place in [MASK].
	tokenized_context ids... tensor([[  101,  3169,  2058, 19980,  2165,  2173,  1999,   103,  1012,   102]],
       device='cuda:0')
	decoded tokenize_context... [CLS] operation overlord took place in [MASK]. [SEP

In [86]:
score_dict_full

{'bert-base-uncased: The 2020 Olympics were held in [MASK].': {'p_true': 0.7124,
  'p_false': 0.0009756088256835938,
  'p_true - p_false': 0.7114267349243164,
  'p_true > p_false': True},
 'bert-base-uncased: Operation Overlord took place in [MASK].': {'p_true': 0.01952,
  'p_false': 1.3828277587890625e-05,
  'p_true - p_false': 0.01950216293334961,
  'p_true > p_false': True},
 'bert-base-uncased: Steve Jobs is the founder of [MASK].': {'p_true': 0.2837,
  'p_false': 0.10113525390625,
  'p_true - p_false': 0.18255615234375,
  'p_true > p_false': True}}

In [87]:
score_dict_succinct

{'bert-base-uncased: The 2020 Olympics were held in [MASK].': {'p_true > p_false': True},
 'bert-base-uncased: Operation Overlord took place in [MASK].': {'p_true > p_false': True},
 'bert-base-uncased: Steve Jobs is the founder of [MASK].': {'p_true > p_false': True}}

### gpt2s

In [88]:
config = {
    "models": [
        "distilgpt2",  # 82M params
        #"gpt2",  # 124M params
        #"gpt2-medium",  # 355M params
        #"gpt2-large",  # 774M params
        #"gpt2-xl",  # 1.5B params
    ],
    "input_information": {
        "The 2020 Olympics were held in": ["Tokyo", "Berlin"],
        "Operation Overlord took place in": ["Normandy", "Manila"],
        "Steve Jobs is the founder of": ["Apple", "Microsoft"],
    },
    "verbosity": True

}


In [89]:
score_dict_full, score_dict_succinct = main(config)

CKA for distilgpt2
Loading  model...
Running comparisons...
	context... The 2020 Olympics were held in
	tokenized_context ids... tensor([[  464, 12131, 14935,   547,  2714,   287]], device='cuda:0')
	decoded tokenized_context... The 2020 Olympics were held in
	decoded target id... Tok
	context... The 2020 Olympics were held in
	tokenized_context ids... tensor([[  464, 12131, 14935,   547,  2714,   287]], device='cuda:0')
	decoded tokenized_context... The 2020 Olympics were held in
	decoded target id... Ber
	context... Operation Overlord took place in
	tokenized_context ids... tensor([[32180,  3827, 10572,  1718,  1295,   287]], device='cuda:0')
	decoded tokenized_context... Operation Overlord took place in
	decoded target id... Norm
	context... Operation Overlord took place in
	tokenized_context ids... tensor([[32180,  3827, 10572,  1718,  1295,   287]], device='cuda:0')
	decoded tokenized_context... Operation Overlord took place in
	decoded target id... Man
	context... Steve Jobs is t

In [90]:
score_dict_full

{'distilgpt2: The 2020 Olympics were held in': {'p_true': 4e-07,
  'p_false': 5.960464477539063e-08,
  'p_true - p_false': 3.5762786865234375e-07,
  'p_true > p_false': True},
 'distilgpt2: Operation Overlord took place in': {'p_true': 0.0,
  'p_false': 1.1920928955078125e-07,
  'p_true - p_false': -1.1920928955078125e-07,
  'p_true > p_false': False},
 'distilgpt2: Steve Jobs is the founder of': {'p_true': 9.36e-06,
  'p_false': 6.258487701416016e-06,
  'p_true - p_false': 3.0994415283203125e-06,
  'p_true > p_false': True}}

In [91]:
score_dict_succinct

{'distilgpt2: The 2020 Olympics were held in': {'p_true > p_false': True},
 'distilgpt2: Operation Overlord took place in': {'p_true > p_false': False},
 'distilgpt2: Steve Jobs is the founder of': {'p_true > p_false': True}}

### Google/flans

In [92]:
config = {
    "models": [
        "google/flan-t5-small",  # 80M params
        # "google/flan-t5-base",  # 250M params
        # "google/flan-t5-large",  # 780M params
        # "google/flan-t5-xl",  # 3B params
        # "google/flan-t5-xxl",  # 11B params
        # "google/flan-ul2",  # 20B params
    ],
    "input_information": {
        "The 2020 Olympics were held in": ["Tokyo", "Berlin"],
        "Operation Overlord took place in": ["Normandy", "Manila"],
        "Steve Jobs is the founder of": ["Apple", "Microsoft"],
    },
    "verbosity": True

}


In [93]:
score_dict_full, score_dict_succinct = main(config)

CKA for google/flan-t5-small
Loading  model...
Running comparisons...
	context... The 2020 Olympics were held in <extra_id_0>.
	tokenized_context ids... tensor([[   37,  6503, 17793,   130,  1213,    16, 32099,     3,     5,     1]],
       device='cuda:0')
	decoded tokenized_context... The 2020 Olympics were held in<extra_id_0>.</s>
	decoded target id... Tokyo
	context... The 2020 Olympics were held in <extra_id_0>.
	tokenized_context ids... tensor([[   37,  6503, 17793,   130,  1213,    16, 32099,     3,     5,     1]],
       device='cuda:0')
	decoded tokenized_context... The 2020 Olympics were held in<extra_id_0>.</s>
	decoded target id... Berlin
	context... Operation Overlord took place in <extra_id_0>.
	tokenized_context ids... tensor([[ 6411,  1575,  2035,   322,    26,   808,   286,    16, 32099,     3,
             5,     1]], device='cuda:0')
	decoded tokenized_context... Operation Overlord took place in<extra_id_0>.</s>
	decoded target id... Norman
	context... Operation Over

In [94]:
score_dict_full

{'google/flan-t5-small: The 2020 Olympics were held in <extra_id_0>.': {'p_true': 3.46e-06,
  'p_false': 3.5762786865234375e-07,
  'p_true - p_false': 3.0994415283203125e-06,
  'p_true > p_false': True},
 'google/flan-t5-small: Operation Overlord took place in <extra_id_0>.': {'p_true': 2.06e-05,
  'p_false': 1.0728836059570312e-06,
  'p_true - p_false': 1.9550323486328125e-05,
  'p_true > p_false': True},
 'google/flan-t5-small: Steve Jobs is the founder of <extra_id_0>.': {'p_true': 0.0005684,
  'p_false': 1.8715858459472656e-05,
  'p_true - p_false': 0.0005496740341186523,
  'p_true > p_false': True}}

In [95]:
score_dict_succinct

{'google/flan-t5-small: The 2020 Olympics were held in <extra_id_0>.': {'p_true > p_false': True},
 'google/flan-t5-small: Operation Overlord took place in <extra_id_0>.': {'p_true > p_false': True},
 'google/flan-t5-small: Steve Jobs is the founder of <extra_id_0>.': {'p_true > p_false': True}}