# Dev contrastive knowledge assessment notebook

<a target="_blank" href="https://colab.research.google.com/github/daniel-furman/Capstone/blob/main/notebooks/cka_dev_helpers.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>


## Dependencies

In [None]:
!git clone https://github.com/daniel-furman/Capstone.git
!pip install -r /content/Capstone/requirements.txt

In [None]:
# LLaMa requirements
# !pip install -r /content/Capstone/requirements_llama.txt

## Imports

In [None]:
import os
import datetime
import json
import numpy as np
import tqdm

import torch
from torch.nn.functional import softmax

import transformers
from transformers import (
    set_seed,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    T5Tokenizer,
    T5ForConditionalGeneration,
)

In [None]:
# The probes and the 8-bit model loading below assume a CUDA device; fail fast
# with an actionable message instead of keeping an unreachable CPU fallback.
if not torch.cuda.is_available():
    raise RuntimeError("Change runtime type to include a GPU.")
device = torch.device("cuda")

In [None]:
# Quiet TensorFlow's C++ logging (level "2" hides INFO and WARNING messages).
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "2"})

## Dev functions for new models

In [None]:
def probe_flan(model, tokenizer, target_id, context, verbose=False):
    """
    Probability a Flan-T5 model assigns to ``target_id`` as its first decoded token.

    Args:
        model: seq2seq model (as loaded by ``get_model_and_tokenizer``).
        tokenizer: the matching tokenizer.
        target_id: 0-dim tensor holding the vocab id of the target token.
        context: prompt string; no sentinel token is expected in it.
        verbose: print tokenization and argmax details when True.

    Returns:
        numpy scalar: softmax probability of ``target_id``.
    """
    # tokenize context
    input_ids = tokenizer(
        context,
        padding="longest",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    # Decoder is primed with id 0 (T5's decoder-start / pad id) followed by
    # 32099 (<extra_id_0> in the standard T5 vocab). Only the logits at
    # position 0 are read below. Put it on the same device as input_ids.
    decoder_input_ids = torch.tensor([[0, 32099]], device=device)

    # inference only: no autograd graph, and no unused hidden states
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            return_dict=True,
        )

    # batch size is 1; position 0 is the distribution over the first decoded
    # token. Upcast before softmax: the logits may be fp16, where very small
    # probabilities underflow to 0 and break the p_true > p_false comparison.
    logits = outputs["logits"][0, 0].float()
    probs = softmax(logits, dim=-1).cpu().numpy()

    if verbose:
        print(f"\n\tcontext... {context}")
        print(f"\ttokenized_context ids... {input_ids}")
        print(f"\tdecoded tokenized_context... {tokenizer.decode(input_ids[0])}")
        print(f"\tdecoded target id... {tokenizer.decode([target_id.item()])}")
        print(
            f"\tmost probable prediction id decoded... {tokenizer.decode([np.argmax(probs)])}\n"
        )

    return probs[target_id.item()]


def probe_t5(model, tokenizer, target_id, context, verbose=False):
    """
    Probability a T5 model assigns to ``target_id`` as the token filling the
    masked span.

    Args:
        model: seq2seq model (as loaded by ``get_model_and_tokenizer``).
        tokenizer: the matching tokenizer.
        target_id: 0-dim tensor holding the vocab id of the target token.
        context: prompt string; expected to contain the ``<extra_id_0>``
            sentinel marking the span to fill.
        verbose: print tokenization and argmax details when True.

    Returns:
        numpy scalar: softmax probability of ``target_id``.
    """
    # tokenize context
    input_ids = tokenizer(
        context,
        padding="longest",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    # Decoder is primed with id 0 (T5's decoder-start / pad id) followed by
    # 32099 (<extra_id_0> in the standard T5 vocab), mirroring the span-
    # corruption target format. Put it on the same device as input_ids.
    decoder_input_ids = torch.tensor([[0, 32099]], device=device)

    # inference only: no autograd graph, and no unused hidden states
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            return_dict=True,
        )

    # batch size is 1; position 1 is the distribution predicted right after
    # the sentinel, i.e. the first token of the masked span. Upcast before
    # softmax: the logits may be fp16, where tiny probabilities underflow to 0.
    logits = outputs["logits"][0, 1].float()
    probs = softmax(logits, dim=-1).cpu().numpy()

    if verbose:
        print(f"\n\tcontext... {context}")
        print(f"\ttokenized_context ids... {input_ids}")
        print(f"\tdecoded tokenized_context... {tokenizer.decode(input_ids[0])}")
        print(f"\tdecoded target id... {tokenizer.decode([target_id.item()])}")
        print(
            f"\tmost probable prediction id decoded... {tokenizer.decode([np.argmax(probs)])}\n"
        )

    return probs[target_id.item()]


def probe_gpt(model, tokenizer, target_id, context, verbose=False):
    """
    Probability a causal LM assigns to ``target_id`` as the next token after
    ``context``.

    Used for GPT-2, GPT-Neo/J, Pythia and OPT models.

    Args:
        model: causal LM (as loaded by ``get_model_and_tokenizer``).
        tokenizer: the matching tokenizer.
        target_id: 0-dim tensor holding the vocab id of the target token.
        context: prompt string.
        verbose: print tokenization and argmax details when True.

    Returns:
        numpy scalar probability of ``target_id``, or None when ``target_id``
        lies outside the model's output vocabulary.
    """
    # tokenize context
    input_ids = tokenizer(
        context,
        return_tensors="pt",
    ).input_ids.to(device)

    target_index = int(target_id.item())

    # inference only: no autograd graph, and no unused hidden states
    with torch.no_grad():
        outputs = model(input_ids=input_ids, return_dict=True)

    # logits at the last position score every vocab entry as the next token.
    # Upcast before softmax: the logits may be fp16 (8-bit loading), where tiny
    # probabilities underflow to 0.
    logits = outputs["logits"][0, -1].float()
    probs = softmax(logits, dim=-1).cpu().numpy()

    if verbose:
        print(f"\n\tcontext... {context}")
        print(f"\ttokenized_context ids... {input_ids}")
        print(f"\tdecoded tokenized_context... {tokenizer.decode(input_ids[0])}")
        print(f"\tdecoded target id... {tokenizer.decode([target_id.item()])}")
        print(
            f"\tmost probable prediction id decoded... {tokenizer.decode([np.argmax(probs)])}\n"
        )

    # The tokenizer's ids and the model's output dimension need not agree.
    # Valid indices are 0..len(probs)-1, so an id equal to len(probs) is also
    # out of range.
    if not 0 <= target_index < len(probs):
        print("target index not in model vocabulary scope; returning None")
        return None

    # likelihood that the stipulated target follows the context, per the model
    return probs[target_index]


def probe_bert(model, tokenizer, target_id, context, verbose=False):
    """
    Probability a masked LM assigns to ``target_id`` at the mask position in
    ``context``.

    Used for BERT and RoBERTa models.

    Args:
        model: masked LM (as loaded by ``get_model_and_tokenizer``).
        tokenizer: the matching tokenizer; must define ``mask_token_id``.
        target_id: 0-dim tensor holding the vocab id of the target token.
        context: prompt string containing the tokenizer's mask token.
        verbose: print tokenization and argmax details when True.

    Returns:
        numpy scalar: softmax probability of ``target_id`` at the first mask.

    Raises:
        ValueError: if no mask token survives tokenization (e.g. truncated).
    """
    # tokenize context (kept on CPU for the mask search and verbose printing)
    input_ids = tokenizer(
        context,
        padding="longest",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    # sequence positions of the mask token in the single-row batch
    mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]
    if mask_token_index.numel() == 0:
        raise ValueError(f"No mask token found in tokenized context: {context!r}")

    # inference only: no autograd graph
    with torch.no_grad():
        logits = model(input_ids=input_ids.to(device)).logits

    # Score the first mask position so the result is always 1-D. Upcast before
    # softmax: the model may run in fp16, where tiny probabilities underflow to 0.
    mask_token_logits = logits[0, int(mask_token_index[0]), :].float()
    probs = softmax(mask_token_logits, dim=-1).cpu().numpy()

    if verbose:
        print(f"\n\tcontext... {context}")
        print(f"\ttokenized_context ids... {input_ids}")
        print(f"\tdecoded tokenize_context... {tokenizer.decode(input_ids[0])}")
        print(f"\tmask token id... {tokenizer.mask_token_id}")
        print(f"\tmask token index in context... {mask_token_index}")
        print(f"\tdecoded target id... {tokenizer.decode([target_id.item()])}")
        print(
            f"\tmost probable prediction id decoded... {tokenizer.decode([np.argmax(probs)])}\n"
        )

    return probs[target_id.item()]


def probe_llama(model, tokenizer, target_id, context, verbose=False):
    """
    Probability a LLaMA causal LM assigns to ``target_id`` as the next token
    after ``context``.

    Args:
        model: LLaMA causal LM (as loaded by ``get_model_and_tokenizer``).
        tokenizer: the matching tokenizer.
        target_id: 0-dim tensor holding the vocab id of the target token.
        context: prompt string.
        verbose: print tokenization and argmax details when True.

    Returns:
        numpy scalar probability of ``target_id``, or None when ``target_id``
        lies outside the model's output vocabulary.
    """
    # tokenize context
    input_ids = tokenizer(
        context,
        return_tensors="pt",
    ).input_ids.to(device)

    target_index = int(target_id.item())

    # inference only: no autograd graph, and no unused hidden states
    with torch.no_grad():
        outputs = model(input_ids=input_ids, return_dict=True)

    # logits at the last position score every vocab entry as the next token.
    # Upcast before softmax: the logits may be fp16 (8-bit loading), where tiny
    # probabilities underflow to 0.
    logits = outputs["logits"][0, -1].float()
    probs = softmax(logits, dim=-1).cpu().numpy()

    if verbose:
        print(f"\n\tcontext... {context}")
        print(f"\ttokenized_context ids... {input_ids}")
        print(f"\tdecoded tokenized_context... {tokenizer.decode(input_ids[0])}")
        print(f"\tdecoded target id... {tokenizer.decode([target_id.item()])}")
        print(
            f"\tmost probable prediction id decoded... {tokenizer.decode([np.argmax(probs)])}\n"
        )

    # Valid indices are 0..len(probs)-1, so an id equal to len(probs) is also
    # out of range.
    if not 0 <= target_index < len(probs):
        print("target index not in model vocabulary scope; returning None")
        return None

    # likelihood that the stipulated target follows the context, per the model
    return probs[target_index]


In [None]:
# first, write helper to pull a pretrained LM and tokenizer off the shelf
def get_model_and_tokenizer(
    model_name,
    llama_tokenizer_path="/content/drive/MyDrive/Colab Files/llama/LLaMA/int8/tokenizer/",
):
    """
    Load a pretrained ``(tokenizer, model)`` pair for ``model_name``.

    The model family is chosen by substring match on the lower-cased name:
      * "flan" / "t5"            -> T5Tokenizer + T5ForConditionalGeneration (8-bit)
      * "gpt" / "opt" / "pythia" -> AutoTokenizer + AutoModelForCausalLM (8-bit)
      * "bert" (incl. roberta)   -> AutoTokenizer + AutoModelForMaskedLM (fp16, on ``device``)
      * "llama"                  -> LLaMA tokenizer from ``llama_tokenizer_path``
                                    + LLaMA causal LM (8-bit)

    Args:
        model_name: Hugging Face model id or local path.
        llama_tokenizer_path: directory holding the LLaMA tokenizer; only used
            for "llama" models.

    Returns:
        (tokenizer, model)

    Raises:
        ValueError: if ``model_name`` matches none of the supported families.
    """
    name = model_name.lower()

    if "flan" in name or "t5" in name:
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(
            model_name, load_in_8bit=True, device_map="auto"
        )
    elif "gpt" in name or "opt" in name or "pythia" in name:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, load_in_8bit=True, device_map="auto"
        )
    elif "bert" in name:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForMaskedLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).to(device)
    elif "llama" in name:
        # the tokenizer is read from its own directory; weights come from model_name
        tokenizer = transformers.LLaMATokenizer.from_pretrained(llama_tokenizer_path)
        model = transformers.LLaMAForCausalLM.from_pretrained(
            model_name, load_in_8bit=True, device_map="auto"
        )
    else:
        # previously this fell through and returned None, which failed later
        # with an opaque "cannot unpack non-iterable NoneType" error
        raise ValueError(f"Unsupported model family for '{model_name}'")

    return tokenizer, model


# next, write a helper to pull a probe function for the given LM
def get_probe_function(prefix):
    """
    Return the probe function whose ``__name__`` contains ``prefix`` (case-insensitive).

    Candidates are tried in the order flan, gpt, bert, llama, t5 and the first
    match wins. Returns None when no probe matches.
    """
    needle = prefix.lower()
    candidates = (probe_flan, probe_gpt, probe_bert, probe_llama, probe_t5)
    return next((probe for probe in candidates if needle in probe.__name__), None)


# lastly, write a wrapper function to compare models
def _resolve_prefix_and_probe(model_name):
    """
    Map a model name to ``(prefix, probe_function)`` used by compare_models.

    The prefix controls context masking and target-id extraction. Checks run
    in order and stop at the first hit, so e.g. "flan-t5" resolves to "flan"
    and "roberta" is matched before "bert".

    Raises:
        ValueError: if the name matches no supported family.
    """
    name = model_name.lower()
    if "flan" in name:
        return "flan", get_probe_function("flan")
    if "t5" in name:
        return "t5", get_probe_function("t5")
    if "gpt-neo" in name or "gpt-j" in name or "pythia" in name:
        return "eleutherai", get_probe_function("gpt")
    if "gpt" in name:
        return "gpt", get_probe_function("gpt")
    if "opt" in name:
        return "opt", get_probe_function("gpt")
    if "roberta" in name:
        return "roberta", get_probe_function("bert")
    if "bert" in name:
        return "bert", get_probe_function("bert")
    if "llama" in name:
        return "llama", get_probe_function("llama")
    raise ValueError(f"No CKA probe configured for model '{model_name}'")


def _build_context(prefix, stem):
    """Append the family-specific mask/sentinel token to ``stem`` (unchanged for causal and Flan models)."""
    mask_suffixes = {"roberta": " <mask>.", "bert": " [MASK].", "t5": " <extra_id_0>."}
    return stem + mask_suffixes.get(prefix, "")


def _target_token_id(prefix, tokenizer, entity):
    """
    Return the vocab id (0-dim tensor on ``device``) of the first sub-token of ``entity``.

    Only the first sub-token is scored, e.g. "Tokyo" split as <Tok><yo> is
    scored through <Tok>. The per-family position skips tokens the tokenizer
    prepends (presumably BOS / [CLS] / <s>; LLaMA's offset of 2 is an
    empirical choice — TODO confirm against the tokenizer in use).
    """
    padded = dict(padding="longest", max_length=512, truncation=True)
    if prefix in ("flan", "t5"):
        text, kwargs, position = " " + entity, padded, 0
    elif prefix in ("gpt", "eleutherai"):
        text, kwargs, position = " " + entity, {}, 0
    elif prefix == "opt":
        text, kwargs, position = " " + entity, {}, 1
    elif prefix == "roberta":
        text, kwargs, position = " " + entity, padded, 1
    elif prefix == "bert":
        # no leading space for BERT
        text, kwargs, position = entity, padded, 1
    elif prefix == "llama":
        text, kwargs, position = " " + entity, {}, 2
    else:
        raise ValueError(f"Unknown model prefix '{prefix}'")
    return tokenizer.encode(text, return_tensors="pt", **kwargs).to(device)[0][position]


def compare_models(model_name_list, input_pairings, verbose):
    """
    Model-wise comparison helper function.

    For each model and each (stem, true entity, counterfactual entity) pairing,
    score the probability the model gives the first token of the true entity
    versus that of the counterfactual, and report:
      * the 'result', i.e. whether p_true > p_false,
      * the probabilities of both values (and their difference and ratio).
    Running this over a large set of positive/negative pairings gives a pool
    of information for comparing model families, including their relative
    'certainty' (at least in orders of magnitude).

    Pairings where a probe returns None (target id outside the model's
    vocabulary) are skipped and not counted.

    Args:
        model_name_list: iterable of Hugging Face model ids / local paths.
        input_pairings: dict mapping any key to a dict with "stem" (context
            string), "true" (fact entity) and "false" (list of
            counterfactual entities).
        verbose: forwarded to the probe functions.

    Returns:
        [score_dict_full, score_dict_succinct, score_dict_summary]; the three
        dicts are also dumped as JSON under /content/logging/.
    """
    # materialize so it can be iterated and also logged below
    model_name_list = list(model_name_list)

    score_dict_full = {}
    score_dict_succinct = {}
    score_dict_summary = {}

    # creates missing parents and tolerates an existing directory
    os.makedirs("/content/logging", exist_ok=True)

    now = datetime.datetime.now()
    dt_string = now.strftime("%d_%m_%Y_%H_%M_%S")

    # defined up front so the log filename is valid even for an empty model list
    prefix = ""

    for model_name in model_name_list:
        model_key = model_name.lower()
        true_count = 0
        fact_count = 0
        p_ratio = []
        p_trues = []

        print(f"CKA for {model_name}")

        # resolve before loading so an unsupported name fails fast
        prefix, probe_func = _resolve_prefix_and_probe(model_name)

        print("Loading  model...")
        tokenizer, model = get_model_and_tokenizer(model_name)

        print("Running comparisons...")

        # inference only: skip autograd bookkeeping for every probe call
        with torch.no_grad():
            for entities_dict in tqdm.tqdm(input_pairings.values()):
                for counterfact in entities_dict["false"]:
                    context = _build_context(prefix, entities_dict["stem"])
                    entities = [entities_dict["true"], counterfact]

                    # index 0 is the fact, index 1 the counterfactual
                    p_true, p_false = [
                        probe_func(
                            model,
                            tokenizer,
                            _target_token_id(prefix, tokenizer, entity),
                            context,
                            verbose,
                        )
                        for entity in entities
                    ]

                    # probes return None for out-of-vocabulary targets; skip
                    # the pairing instead of crashing on float(None)
                    if p_true is None or p_false is None:
                        print(
                            f"Skipping {context} {entities}: target token outside model vocabulary"
                        )
                        continue

                    fact_count += 1
                    ratio = float(p_true) / (float(p_true) + float(p_false) + 1e-8)
                    is_true_higher = p_true > p_false
                    pairing_key = context + " " + f"{entities}"

                    score_dict_full.setdefault(model_key, []).append(
                        {
                            pairing_key: {
                                "p_true": float(p_true),
                                "p_false": float(p_false),
                                "p_true - p_false": float(p_true) - float(p_false),
                                "p_true > p_false": str(is_true_higher),
                                "p_true / (p_true + p_false)": ratio,
                            }
                        }
                    )
                    score_dict_succinct.setdefault(model_key, []).append(
                        {pairing_key: {"p_true > p_false": str(is_true_higher)}}
                    )

                    if is_true_higher:
                        true_count += 1
                    p_ratio.append(ratio)
                    p_trues.append(float(p_true))

        score_dict_summary[model_key] = (
            f"This model predicted {true_count}/{fact_count} facts at a higher prob "
            f"than the given counterfactual. The mean p_true / (p_true + p_false) was "
            f"{np.round(np.mean(np.array(p_ratio)), decimals=4)} while the mean p_true "
            f"was {np.round(np.mean(np.array(p_trues)), decimals=4)}"
        )

        print("Done\n")
        del tokenizer
        del model
        torch.cuda.empty_cache()

    score_dicts = [score_dict_full, score_dict_succinct, score_dict_summary]

    # logging
    score_dicts_logging = {
        "curr_datetime": str(now),
        # last model evaluated (kept for existing log readers); all models below
        "model_name": model_name_list[-1] if model_name_list else None,
        "model_names": model_name_list,
        "score_dict_summary": score_dict_summary,
        "score_dict_full": score_dict_full,
        "score_dict_succinct": score_dict_succinct,
    }

    # the filename carries the prefix of the last model evaluated
    with open(
        f"/content/logging/{prefix}_logged_cka_outputs_{dt_string}.json", "w"
    ) as outfile:
        json.dump(score_dicts_logging, outfile)

    return score_dicts


In [None]:
def main(config):
    """
    Seed the RNGs, then run ``compare_models`` as described by ``config``.

    ``config`` needs the keys "models", "input_information" and "verbosity".
    Returns the list of score dicts produced by ``compare_models``.
    """
    set_seed(42)
    return compare_models(
        config["models"],
        config["input_information"],
        config["verbosity"],
    )

## Run CKA

### Google/t5s

In [None]:
config = {
    "models": [
        "google/t5-v1_1-base",
    ],
    # built from (stem, true entity, counterfactual entities), keyed "0", "1", ...
    "input_information": {
        str(idx): {"stem": stem, "true": fact, "false": counterfacts}
        for idx, (stem, fact, counterfacts) in enumerate(
            [
                ("The 2020 Olympics were held in", "Tokyo", ["London", "Berlin", "Chicago"]),
                ("Operation Overlord took place in", "Normandy", ["Manila", "Santiago", "Baghdad"]),
                ("Steve Jobs is the founder of", "Apple", ["Microsoft", "Oracle", "Intel"]),
            ]
        )
    },
    "verbosity": True,
}

score_dicts = main(config)

# full, succinct and summary score dicts, one per line
print(*score_dicts, sep="\n")

### EleutherAI models

In [None]:
config = {
    "models": [
        "EleutherAI/gpt-neo-125M",
        #"EleutherAI/pythia-410m",
    ],
    # built from (stem, true entity, counterfactual entities), keyed "0", "1", ...
    "input_information": {
        str(idx): {"stem": stem, "true": fact, "false": counterfacts}
        for idx, (stem, fact, counterfacts) in enumerate(
            [
                ("The 2020 Olympics were held in", "Tokyo", ["London", "Berlin", "Chicago"]),
                ("Operation Overlord took place in", "Normandy", ["Manila", "Santiago", "Baghdad"]),
                ("Steve Jobs is the founder of", "Apple", ["Microsoft", "Oracle", "Intel"]),
            ]
        )
    },
    "verbosity": True,
}

score_dicts = main(config)

# full, succinct and summary score dicts, one per line
print(*score_dicts, sep="\n")

### OPT

In [None]:
config = {
    "models": [
        "facebook/opt-125m",
        # "facebook/opt-350m",
    ],
    # built from (stem, true entity, counterfactual entities), keyed "0", "1", ...
    "input_information": {
        str(idx): {"stem": stem, "true": fact, "false": counterfacts}
        for idx, (stem, fact, counterfacts) in enumerate(
            [
                ("The 2020 Olympics were held in", "Tokyo", ["London", "Berlin", "Chicago"]),
                ("Operation Overlord took place in", "Normandy", ["Manila", "Santiago", "Baghdad"]),
                ("Steve Jobs is the founder of", "Apple", ["Microsoft", "Oracle", "Intel"]),
            ]
        )
    },
    "verbosity": True,
}

score_dicts = main(config)

# full, succinct and summary score dicts, one per line
print(*score_dicts, sep="\n")

### RoBERTa

In [None]:
config = {
    "models": [
        "roberta-base",
    ],
    # built from (stem, true entity, counterfactual entities), keyed "0", "1", ...
    "input_information": {
        str(idx): {"stem": stem, "true": fact, "false": counterfacts}
        for idx, (stem, fact, counterfacts) in enumerate(
            [
                ("The 2020 Olympics were held in", "Tokyo", ["London", "Berlin", "Chicago"]),
                ("Operation Overlord took place in", "Normandy", ["Manila", "Santiago", "Baghdad"]),
                ("Steve Jobs is the founder of", "Apple", ["Microsoft", "Oracle", "Intel"]),
            ]
        )
    },
    "verbosity": True,
}

score_dicts = main(config)

# full, succinct and summary score dicts, one per line
print(*score_dicts, sep="\n")

### Bert

In [None]:
config = {
    "models": [
        "bert-base-uncased",
    ],
    # built from (stem, true entity, counterfactual entities), keyed "0", "1", ...
    "input_information": {
        str(idx): {"stem": stem, "true": fact, "false": counterfacts}
        for idx, (stem, fact, counterfacts) in enumerate(
            [
                ("The 2020 Olympics were held in", "Tokyo", ["London", "Berlin", "Chicago"]),
                ("Operation Overlord took place in", "Normandy", ["Manila", "Santiago", "Baghdad"]),
                ("Steve Jobs is the founder of", "Apple", ["Microsoft", "Oracle", "Intel"]),
            ]
        )
    },
    "verbosity": True,
}

score_dicts = main(config)

# full, succinct and summary score dicts, one per line
print(*score_dicts, sep="\n")

### gpt2s

In [None]:
config = {
    "models": [
        "gpt2",
    ],
    # built from (stem, true entity, counterfactual entities), keyed "0", "1", ...
    "input_information": {
        str(idx): {"stem": stem, "true": fact, "false": counterfacts}
        for idx, (stem, fact, counterfacts) in enumerate(
            [
                ("The 2020 Olympics were held in", "Tokyo", ["London", "Berlin", "Chicago"]),
                ("Operation Overlord took place in", "Normandy", ["Manila", "Santiago", "Baghdad"]),
                ("Steve Jobs is the founder of", "Apple", ["Microsoft", "Oracle", "Intel"]),
            ]
        )
    },
    "verbosity": True,
}

score_dicts = main(config)

# full, succinct and summary score dicts, one per line
print(*score_dicts, sep="\n")


### Google/flans

In [None]:
config = {
    "models": [
        "google/flan-t5-small",  # 80M params
    ],
    # built from (stem, true entity, counterfactual entities), keyed "0", "1", ...
    "input_information": {
        str(idx): {"stem": stem, "true": fact, "false": counterfacts}
        for idx, (stem, fact, counterfacts) in enumerate(
            [
                ("The 2020 Olympics were held in", "Tokyo", ["London", "Berlin", "Chicago"]),
                ("Operation Overlord took place in", "Normandy", ["Manila", "Santiago", "Baghdad"]),
                ("Steve Jobs is the founder of", "Apple", ["Microsoft", "Oracle", "Intel"]),
            ]
        )
    },
    "verbosity": True,
}

score_dicts = main(config)

# full, succinct and summary score dicts, one per line
print(*score_dicts, sep="\n")

## LLaMas

As of March 2023, LLaMa models can only be run if you have access to the weights. Change the Google Drive file paths below to the location of the weights on your Drive.

In [None]:
# Mount Google Drive so the LLaMA paths under /content/drive resolve
from google.colab import drive

drive.mount("/content/drive")

In [None]:
config = {
    "models": [
        "/content/drive/MyDrive/Colab Files/llama/LLaMA/int8/llama-7b/",
    ],
    # built from (stem, true entity, counterfactual entities), keyed "0", "1", ...
    "input_information": {
        str(idx): {"stem": stem, "true": fact, "false": counterfacts}
        for idx, (stem, fact, counterfacts) in enumerate(
            [
                ("The 2020 Olympics were held in", "Tokyo", ["London", "Berlin", "Chicago"]),
                ("Operation Overlord took place in", "Normandy", ["Manila", "Santiago", "Baghdad"]),
                ("Steve Jobs is the founder of", "Apple", ["Microsoft", "Oracle", "Intel"]),
            ]
        )
    },
    "verbosity": False,
}

score_dicts = main(config)

# full, succinct and summary score dicts, one per line
print(*score_dicts, sep="\n")