## Dependencies

In [11]:
import os
os.system("pip install transformers sentencepiece accelerate bitsandbytes")

0

## Imports

In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
    set_seed,
    AutoTokenizer,
    AutoModelForCausalLM,
    T5Tokenizer,
    T5ForConditionalGeneration,
)

## Setup

In [2]:
set_seed(42)

In [3]:
device = torch.device("cuda")

## Probe helper functions

In [4]:
def probe_t5(model, input_ids, target):
    """
    model: a pretrained google model pulled in from HuggingFace (ie. flan-t5-small,
      flan-ul2, t5-small, etc.)
    input_ids: the indices (in the vocabulary) of our left-context tokens
    target: the index (in the vocabulary) of the token we're gathering a prediction for
  
    return: a float indicating the likelihood of the target following the left-context
      according to the model in case of error, return None
    """

    # Call the model
    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=torch.tensor([[0, 32099]], device="cuda:0"),
        output_hidden_states=True,
        return_dict=True,
    )

    # We have batch size of 1, so grab that, then,
    # Take the entire last matrix which corresponds to the last layer
    logits = outputs["logits"][0, -1]

    # convert our prediction scores to a probability distribution with softmax
    probs = F.softmax(logits, dim=-1)

    probs = probs.detach().cpu().numpy()

    return probs[target.item()]

In [5]:
def probe_gpt(model, input_ids, target):
    """
    model: a gpt pretrained model pulled in from HuggingFace
    input_ids: the indices (in gpt's vocabulary) of our left-context tokens
    target: the index (in gpt's vocabulary) of the token we're gathering a prediction for
  
    return: a float indicating the likelihood of the target following the left-context
      according to the model in case of error, return None
    """

    # ensure we're only asking for a single token prediction
    if len(target) > 1:
        # default to the very first token that get's predicted
        # e.g. in the case of Tokyo, which gets split into <Tok> <yo>,
        target = target[0]

    # sanity check - do a conversion that tells us the exact "token" predicted on
    # print(model.convert_)

    # grab value
    target_scalar = target.detach().cpu().numpy()

    # use model to solicit a prediction
    outputs = model(input_ids=input_ids, output_hidden_states=True, return_dict=True)

    # shape of 50257 which corresponds to the vocab size of GPT
    # every token in GPT's vocab gets a representative prediction from the model
    logits = outputs["logits"][0, -1]

    # convert our prediction scores to a probability distribution with softmax
    probs = F.softmax(logits, dim=-1)

    probs = list(probs.detach().cpu().numpy())

    # double check weird-ness before accessing prob
    if len(probs) < target:
        return None

    # return the likelihood that our stipulated target would follow the context,
    # according to the model
    try:
        return np.take(probs, [target_scalar])[0]
    except IndexError:

        print("target index not in model vocabulary scope; raising IndexError")
        return None

## Model-wise comparison helper functions

* we should be able to do the following
  * input a set of models we want to evaluate
  * input an expression of interest
  * input a 'true' next-token alonside a false
  * and get an output report that contains..
    * the 'result' ie is true > false
    * the probabilities of both of those values
  * running this method over a large set of positive/negative pairings should result in a large pool of information that can be used to compare model-families
  * we can also look at the relative 'certainty' across different models (at least in orders of magnitude)

In [6]:
# first, write helper to pull a pretrained LM and tokenizer off the shelf
def get_model_and_tokenizer(model_name):
    if "ul2" in model_name.lower():
        return T5Tokenizer.from_pretrained(
            model_name
        ), T5ForConditionalGeneration.from_pretrained(
            model_name, load_in_8bit=True, device_map="auto"
        )

    elif "t5" in model_name.lower():
        return T5Tokenizer.from_pretrained(
            model_name
        ), T5ForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map="auto"
        ).to(
            device
        )
    elif "gpt" in model_name.lower():
        return AutoTokenizer.from_pretrained(
            model_name
        ), AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map="auto"
        ).to(
            device
        )

In [7]:
# next, write a helper to pull a probe function for the given LM
def get_probe_function(prefix):
    probe_functions = [probe_t5, probe_gpt]
    for func in probe_functions:
        if prefix.lower() in func.__name__:
            return func

In [12]:
# lastly, write a wrapper function to compare models
def compare_models(model_name_list, input_pairings):

    score_dict = {}

    for model_name in model_name_list:
        print(f"CKA for {model_name}")
        print("\tLoading  model...")

        # get proper model and tokenizer
        tokenizer, model = get_model_and_tokenizer(model_name)

        print("\tRunning comparisons...")

        # establish prefix
        prefix = ""
        probe_func = None

        # get correct CKA function
        if (("t5" in model_name.lower()) or ("ul2" in model_name.lower())):
            prefix = "t5"
            probe_func = get_probe_function(prefix)

        elif "gpt" in model_name.lower():
            prefix = "gpt"
            probe_func = get_probe_function(prefix)

        # iterate over context/entity pairings
        # input_pairings is a dict
        # context is a plain string (since our context's will be unique)
        # and entities is a list containing, in the first slot, the true
        # value for the statement and in the subsequent slots, incorrect information

        for context, entities in input_pairings.items():
            entity_count = 0
            p_true = 0.0
            p_false = 0.0

            if prefix == "t5":
                context += " <extra_id_0> ."

            for entity in entities:
                target = None
                if prefix == "t5":
                    target = tokenizer.encode(
                        entity,
                        padding="longest",
                        max_length=512,
                        truncation=True,
                        return_tensors="pt",
                    ).to(device)[0][0]
                elif prefix == "gpt":
                    target = tokenizer.encode(entity, return_tensors="pt").to(device)[0]

                # tokenize context
                input_ids = tokenizer.encode(
                        context,
                        return_tensors="pt",
                    ).to(device)

                # call probe function
                model_prob = probe_func(model, input_ids, target)

                if entity_count == 0:
                    p_true = model_prob

                else:
                    p_false += model_prob

                entity_count += 1

            p_false /= entity_count - 1
            score_dict[model_name.lower() + ": " + context] = {
                'p_true': p_true,
                'p_false': p_false,
                'p_true - p_false': p_true - p_false,
                'p_true > p_false': p_true > p_false
            }

        print("\tDone\n")
        del tokenizer
        del model
        torch.cuda.empty_cache()

    return score_dict


## Test model-wise comparison

### Flans

In [9]:
config = {
    "models": [
        #"google/flan-t5-small",
        #"google/flan-t5-large",
        #"google/flan-t5-xl",
        #"google/flan-t5-xxl",
        "google/flan-ul2",
    ],
    "input_information": {
        "The 2020 Olympics were held in": ["Tokyo", "Berlin"],
        "Operation Overlord took place in": ["Normandy", "Manila"],
        "Steve Jobs is the founder of": ["Apple", "Microsoft"]
    },
}

In [10]:
score_dict = compare_models(config["models"], config["input_information"])

CKA for google/flan-ul2
	Loading  model...

Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

	Running comparisons...
	Done



In [11]:
score_dict


{'google/flan-ul2: The 2020 Olympics were held in <extra_id_0> .': {'p_true': 0.7124,
  'p_false': 0.0001531839370727539,
  'p_true - p_false': 0.7122491598129272,
  'p_true > p_false': True},
 'google/flan-ul2: Operation Overlord took place in <extra_id_0> .': {'p_true': 0.2974,
  'p_false': 0.0,
  'p_true - p_false': 0.29736328125,
  'p_true > p_false': True},
 'google/flan-ul2: Steve Jobs is the founder of <extra_id_0> .': {'p_true': 0.4827,
  'p_false': 0.00016295909881591797,
  'p_true - p_false': 0.4825030565261841,
  'p_true > p_false': True}}

### GPT2s

In [13]:
config = {
    "models": [
        #"distilgpt2", 
        #"gpt2", 
        #"gpt2-medium", 
        #"gpt2-large", 
        "gpt2-xl"
    ],
    "input_information": {
        "The 2020 Olympics were held in": ["Tokyo", "Berlin"],
        "Operation Overlord took place in": ["Normandy", "Manila"],
        "Steve Jobs is the founder of": ["Apple", "Microsoft"]
    },
}

In [14]:
score_dict = compare_models(config["models"], config["input_information"])

CKA for gpt2-xl
	Loading  model...
	Running comparisons...
	Done



In [15]:
score_dict


{'gpt2-xl: The 2020 Olympics were held in': {'p_true': 2.384e-05,
  'p_false': 5.960464477539063e-08,
  'p_true - p_false': 2.378225326538086e-05,
  'p_true > p_false': True},
 'gpt2-xl: Operation Overlord took place in': {'p_true': 0.0,
  'p_false': 5.960464477539063e-08,
  'p_true - p_false': -5.960464477539063e-08,
  'p_true > p_false': False},
 'gpt2-xl: Steve Jobs is the founder of': {'p_true': array([0.0001261], dtype=float16),
  'p_false': array([2.4e-07], dtype=float16),
  'p_true - p_false': array([0.0001259], dtype=float16),
  'p_true > p_false': array([ True])}}