In [1]:
import random
import re
import time
from random import sample
from typing import List, Tuple

import kscope
import pandas as pd
from metrics import map_ag_news_int_labels, report_metrics
from tqdm import tqdm
from transformers import AutoTokenizer
from utils import get_label_token_ids, get_label_with_highest_likelihood, split_prompts_into_batches

# Getting Started

There is a bit of documentation on how to interact with the large models [here](https://kaleidoscope-sdk.readthedocs.io/en/latest/). The relevant github links to the SDK are [here](https://github.com/VectorInstitute/kaleidoscope-sdk) and underlying code [here](https://github.com/VectorInstitute/kaleidoscope).

First we connect to the service through which we'll interact with the LLMs and see which models are available to us

In [2]:
# Establish a client connection to the kscope service.
# NOTE(review): the gateway host and port are hard-coded; presumably this hostname only resolves from inside the
# cluster network — confirm before running the notebook elsewhere.
client = kscope.Client(gateway_host="llm.cluster.local", gateway_port=3001)

Show all supported models

In [3]:
# Display the names of the models the service supports (rendered as the cell output).
client.models

['gpt2',
 'llama2-7b',
 'llama2-7b_chat',
 'llama2-13b',
 'llama2-13b_chat',
 'llama2-70b',
 'llama2-70b_chat',
 'falcon-7b',
 'falcon-40b',
 'sdxl-turbo']

Show all model instances that are currently active

In [4]:
# Display the id, name, and state of each model instance the service currently reports.
client.model_instances

[{'id': 'a33c0f4d-da2b-4861-8c3d-91e66955e879',
  'name': 'falcon-7b',
  'state': 'ACTIVE'},
 {'id': '7389b196-9637-4a42-adca-7bfb4f59733d',
  'name': 'llama2-7b',
  'state': 'ACTIVE'},
 {'id': 'b3871a00-4848-49be-a1c8-c8f6c47ad8b2',
  'name': 'falcon-40b',
  'state': 'ACTIVE'},
 {'id': '99bee87e-abc4-44fd-b4d3-ea2c527bb93e',
  'name': 'llama2-13b',
  'state': 'ACTIVE'},
 {'id': '4bd663ba-aab9-49f7-83aa-9e2fda3058e9',
  'name': 'llama2-70b',
  'state': 'LAUNCHING'}]

To start, we obtain a handle to a model. In this example, let's use the LLaMA-2 7B parameter model.

**NOTE**: This notebook uses activation retrieval to extract responses from the model: 
* This functionality is available for LLaMA-2 models (non-chat). 
* It is **NOT**, however, currently available for Falcon models of any size.

In [5]:
# Obtain a handle to the LLaMA-2 7B model.
model = client.load_model("llama2-7b")
# If this model is not actively running, it will get launched in the background.
# In this case, wait until it moves into an "ACTIVE" state before proceeding.
# NOTE(review): this polls once per second with no timeout, so the cell never finishes if the model never reaches
# "ACTIVE" — interrupt the kernel manually in that case.
while model.state != "ACTIVE":
    time.sleep(1)

We need to configure the model to generate in the way we want it to. So we set a number of important parameters. For a discussion of the configuration parameters see: `src/reference_implementations/prompting_vector_llms/CONFIG_README.md`

We're only interested in generating one token responses so we set `max_tokens` to 1

In [6]:
# Generation settings for single-token responses: exactly one new token is requested per prompt.
short_generation_config = dict(max_tokens=1, top_p=1.0, temperature=1.0)

Let's try a basic prompt for factual information.

__Note__ that if you run the cell multiple times, you'll get different responses due to sampling.

In [7]:
# Request up to 20 new tokens; this call passes its own config rather than short_generation_config.
generation = model.generate("What is the capital of Canada?", {"max_tokens": 20, "temperature": 1.0})
# Extract the text from the returned generation
generation.generation["sequences"][0]

'\nCanadians do not live in igloos.\nA Canadian is called a Canadien'

We're going to have our model attempt to classify some news articles from the AG News Dataset. Articles have a single label 1-4

1. World
2. Sports
3. Business
4. Sci/Tech

This is a constrained label space. We'll use the words "World", "Sports", "Business", and "Technology" as generative LM targets for each of the labels.

In [8]:
def remove_markup(text: str) -> str:
    """Return ``text`` with URLs and HTML-like tags deleted.

    URLs (anything starting with ``http://``, ``https://`` or ``www.`` up to the next whitespace) are removed
    first, then every ``<...>`` span. Removed spans are replaced by the empty string, so surrounding whitespace
    is left as it was.
    """
    without_urls = re.sub(r"https?://\S+|www\.\S+", "", text)
    return re.sub(r"<.*?>+", "", without_urls)


def ag_news_processor(path: str) -> Tuple[List[int], List[str], List[str]]:
    """Load an AG News CSV and return its labels, titles, and descriptions.

    Args:
        path: Location of a CSV file with "Class Index", "Title", and "Description" columns.

    Returns:
        A tuple ``(labels, titles, descriptions)`` of parallel lists. ``labels`` holds the raw values of the
        "Class Index" column (integer class ids, which is what ``int_to_label_map`` is keyed on). ``titles`` and
        ``descriptions`` hold the corresponding text after ``remove_markup`` has stripped URLs and HTML-like tags.
    """
    ag_news_data = pd.read_csv(path)
    labels = ag_news_data["Class Index"].tolist()
    # remove_markup already takes a single string, so it can be passed to apply directly.
    titles = ag_news_data["Title"].apply(remove_markup).tolist()
    descriptions = ag_news_data["Description"].apply(remove_markup).tolist()
    return labels, titles, descriptions


# Maps the integer class ids found in the dataset (1-4) to lowercase label strings.
int_to_label_map = {1: "world", 2: "sports", 3: "business", 4: "technology"}
# Load the sampled AG News articles as parallel lists of labels, titles, and descriptions.
ag_news_labels, ag_news_titles, ag_news_descriptions = ag_news_processor(
    "resources/ag_news_datasets/ag_news_sample.csv"
)

In [9]:
# Convert the integer class ids into their label strings.
# NOTE(review): this rebinds `ag_news_labels`, so re-running the cell feeds the already-mapped labels back into
# map_ag_news_int_labels — presumably that fails or mis-maps; re-run the loading cell first if in doubt.
ag_news_labels = map_ag_news_int_labels(ag_news_labels, int_to_label_map)
# Replace backslashes in the descriptions with spaces and trim surrounding whitespace.
ag_news_descriptions = [description.replace("\\", " ").strip() for description in ag_news_descriptions]
ag_news_titles = [title.strip() for title in ag_news_titles]
# Capitalized label words (passed to get_label_token_ids later) and their lowercase forms (used to string-match
# generated tokens).
label_words = ["World", "Sports", "Business", "Technology"]
lowercase_labels = [word.lower() for word in label_words]

In [10]:
# Render every article as a two-line "Title: ... / Description: ..." block; the prompt text is appended later.
model_input_texts = [
    f"Title: {title}\nDescription: {description}" for title, description in zip(ag_news_titles, ag_news_descriptions)
]

Let's start by trying out a basic question prompt to see what the model does. You might also try some prompts from [this paper](https://arxiv.org/pdf/2212.04037.pdf). See Table 1.

In [11]:
# Zero-shot question appended after each article; it does not name the candidate categories.
prompt_template = "To which category does this news article belong? "
# Only the first three articles are prompted in this quick qualitative check.
sample_texts = [f"{model_input_text}\n{prompt_template}" for model_input_text in model_input_texts[0:3]]
print(f"Example Prompt\n{sample_texts[0]}")
print("-------------------------------------")
generation = model.generate(sample_texts, short_generation_config)
# Print each generated response followed by a separator line.
for text in generation.generation["sequences"]:
    print(text)
    print("==================================")

Example Prompt
Title: Telecom lifts first quarter net profit 19pc
Description: Telecom Corp today reported its September first quarter net profit rose 19 per cent to $193 million. The profit bettered analysts #39; average forecasts of $185m.
To which category does this news article belong? 
-------------------------------------
1
3
1


Not well...Now let's try to constrain the model a bit by including the desired labels in the question.

In [13]:
# Zero-shot question that now lists the four candidate categories explicitly.
prompt_template = "From World, Sports, Business, Technology, what category does this article belong to? "
# Only the first three articles are prompted in this quick qualitative check.
sample_texts = [f"{model_input_text}\n{prompt_template}" for model_input_text in model_input_texts[0:3]]
print(f"Example Prompt\n{sample_texts[0]}")
print("-------------------------------------")
generation = model.generate(sample_texts, short_generation_config)
# Print each generated response followed by a separator line.
for text in generation.generation["sequences"]:
    print(text)
    print("==================================")

Example Prompt
Title: Telecom lifts first quarter net profit 19pc
Description: Telecom Corp today reported its September first quarter net profit rose 19 per cent to $193 million. The profit bettered analysts #39; average forecasts of $185m.
From World, Sports, Business, Technology, what category does this article belong to? 
-------------------------------------
3
›
﻿


The model doesn't really answer in the space that we want it to. Let's try with some few-shot examples to see if that helps.

__NOTE__: We have simply randomly picked the examples used in the 5-shot prompt. Different choices might be made, including 4-shot or 8-shot prompts so that categories are evenly represented.

In [14]:
# Five labelled demonstrations (one Sports, one World, two Business, one Technology), each laid out as
# Title / Description / Category and separated by a blank line. The string ends with a blank line so that a new
# article can be appended directly after it.
prompt_demonstrations = """Title: Lane drives in winning run in ninth\nDescription: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night.\nCategory (World, Sports, Business, Technology): Sports

Title: Arson attack on Jewish centre in Paris (AFP)\nDescription: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said.\nCategory (World, Sports, Business, Technology): World

Title: Oil prices look set to dominate\nDescription: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue.\nCategory (World, Sports, Business, Technology): Business

Title: Indexes in Japan fall short of hype\nDescription: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity market.\nCategory (World, Sports, Business, Technology): Business

Title: UK Scientists Allowed to Clone Human Embryos (Reuters)\nDescription: Reuters - British scientists said on Wednesday they had received permission to clone human embryos for medical research, in what they believe to be the first such license to be granted in Europe.\nCategory (World, Sports, Business, Technology): Technology

"""  # noqa

Now we form the prompt with the demonstrations included

In [15]:
# Cue placed after each article; it mirrors the final line of every demonstration, minus the answer.
prompt_template_postfix = "Category (World, Sports, Business, Technology):"
# Few-shot prompts for the first three articles: demonstrations, then the article, then the category cue.
sample_texts = [
    f"{prompt_demonstrations}{article_text}\n{prompt_template_postfix}" for article_text in model_input_texts[:3]
]
print(f"Prompt Example\n{sample_texts[0]}")

Prompt Example
Title: Lane drives in winning run in ninth
Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night.
Category (World, Sports, Business, Technology): Sports

Title: Arson attack on Jewish centre in Paris (AFP)
Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said.
Category (World, Sports, Business, Technology): World

Title: Oil prices look set to dominate
Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue.
Category (World, Sports, Business, Technology): Business

Title: Indexes in Japan fall short of hype
Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity marke

In [16]:
# Generate a single token for each of the three few-shot prompts built in the previous cell.
generation = model.generate(sample_texts, short_generation_config)
# Print each generated response followed by a separator line.
for text in generation.generation["sequences"]:
    print(text)
    print("==================================")

Business
Business
Sports


Few-shot learning definitely helps a lot! We'll measure accuracy on a sample of the AG news dataset below. However, there is nothing stopping the model from not selecting our labels. So can we do better? We can work around this by understanding the likelihood of our labels from the model's perspective. This will also allow us to use zero-shot learning, even when the model doesn't seem to want to respond in the way we expect.

In [17]:
# We're interested in the activations from the last layer of the model, because this will allow us to calculate the
# likelihoods
# The final entry of model.module_names is taken to be that last layer; it is displayed below as a sanity check.
last_layer_name = model.module_names[-1]
last_layer_name

'output'

The last layer activations of the model are analogous to the probabilities of each token in the model vocabulary. That is, it is the conditional probability
$$
P(y_t \vert y_{<t}, x),
$$
the probability distribution over the vocabulary of the next token given the preceding tokens $y_{<t}$ and the prompt text $x$. Thus, for each token $y_{t}$ in our input, we get back a vector of dimension $32000$ (the vocabulary size of LLaMA-2) which encodes the probability distribution of $y_{t+1}$ over the vocabulary. For this example, we only care about the last token in our input, as it houses the distribution over the as-yet-unseen token that the model will generate next.

**NOTE**: The last layer for LLaMA-2, named "output," actually holds the logits (pre-softmax). These are not probabilities themselves, but softmax preserves their ordering, so the token with the largest logit is also the most probable token.


#### Tokenizer 

For activation retrieval, we need to instantiate a tokenizer to obtain appropriate token indices for our labels. 

__NOTE__: All LLaMA-2 models, regardless of size, use the same tokenizer. However, if you want to use a different type of model, a different tokenizer may be needed.

If you are on the cluster, the tokenizer may be loaded from `/model-weights/Llama-2-7b-hf`. Otherwise, you'll need to download the `config.json`, `tokenizer.json`, `tokenizer.model`, and `tokenizer_config.json` from there to your local machine.

In [18]:
# NOTE: this path is the on-cluster location of the tokenizer files; when running elsewhere, point it at a local
# directory that contains the downloaded tokenizer files instead.
tokenizer = AutoTokenizer.from_pretrained("/model-weights/Llama-2-7b-hf")
# Let's test out how the tokenizer works on an example sentence. Note that the token with ID = 1 is the
# Beginning of sentence token ("<s>")
encoded_tokens = tokenizer.encode("Hello this is a test")
print(f"Encoded Tokens: {encoded_tokens}")
# If you ever need to move back from token ids, you can use tokenizer.decode or tokenizer.batch_decode
decoded_tokens = tokenizer.decode(encoded_tokens)
print(f"Decoded Tokens: {decoded_tokens}")

Encoded Tokens: [1, 15043, 445, 338, 263, 1243]
Decoded Tokens: <s> Hello this is a test


In [19]:
# NOTE(review): `prompt_template` is whatever the most recently executed cell assigned — in top-to-bottom order the
# "From World, Sports, Business, Technology, ..." question. The resulting ids are reused later for the few-shot
# prompts, which end differently; presumably the label ids do not depend on the prompt text — TODO confirm in utils.
label_token_ids = get_label_token_ids(tokenizer, prompt_template, label_words)
# decode the tokens as a sanity check that we got the right IDs
tokenizer.decode(label_token_ids)

'World Sports Business Technology'

We need the token ids of our labels to extract the probabilities from the vocabulary of the model. The token id corresponds to the index of the token in the vocabulary matrix of the underlying model.

Let's look at how we can extract the likelihoods given the label tokens

In [20]:
# Build one zero-shot prompt from the first article and the current `prompt_template`.
single_prompted_input = f"{model_input_texts[0]}\n{prompt_template}"
print(f"Prompt Input\n{single_prompted_input}")
# Create a prompt and ask for activations of the last layer from the model
activations = model.get_activations(single_prompted_input, [last_layer_name], short_generation_config)
# Display the returned object (activations, log-probabilities, generated sequence, and tokens).
activations

Prompt Input
Title: Telecom lifts first quarter net profit 19pc
Description: Telecom Corp today reported its September first quarter net profit rose 19 per cent to $193 million. The profit bettered analysts #39; average forecasts of $185m.
From World, Sports, Business, Technology, what category does this article belong to? 


Activations(activations=[{'output': tensor([[-12.8203,  -7.4727,  -0.4651,  ...,  -6.7773,  -8.0078,  -7.4922],
        [-11.2344,  -6.3164,  -1.8682,  ...,  -6.1953,  -8.2266,  -5.8242],
        [-10.8359,  -8.3672,  -3.1270,  ...,  -6.3555,  -8.2812,  -5.5117],
        ...,
        [ -3.2578,   0.4604,  14.1641,  ...,   0.2345,  -0.9448,  -2.8594],
        [ -8.5000,  -5.9648,  13.7422,  ...,  -2.0059,  -3.4336,  -3.2305],
        [ -3.6855,  -1.6484,   3.0215,  ...,   6.7148,   2.2930,   2.1953]],
       dtype=torch.float16)}], logprobs=[[-2.4057371616363525]], sequences=['2'], tokens=[['2']])

The activations in the activations dictionary correspond to the outputs for each token in our prompt. So the shape of the tensor should be n_tokens x 32000, where 32000 is the size of LLaMA-2's vocabulary.

In [21]:
# A single prompt was submitted, so its activations are the first entry of the returned list.
last_layer_matrix = activations.activations[0][last_layer_name]
# Tokenize the same prompt locally so the token count can be compared with the matrix's first dimension.
print(f"Number of tokens: {len(tokenizer.encode(single_prompted_input))}")
# The shape of this tensor should be number of input tokens by the vocabulary size (n x 32000)
print(f"Activations matrix shape: {last_layer_matrix.shape}")

Number of tokens: 83
Activations matrix shape: torch.Size([83, 32000])


We're interested in the logits (i.e., activations prior to applying softmax) which correspond to our labels. The function `get_label_with_highest_likelihood` looks into the last row of the activations matrix (analogous to the probability distribution over the vocabulary of the first predicted token) and finds the largest logit among our labels.

In [22]:
# Pick the label whose token has the largest logit in the last row of the activations matrix.
# NOTE(review): the effect of right_shift=True is defined in utils and not visible here — presumably it offsets the
# winning index so that it lines up with the 1-based keys of int_to_label_map; confirm.
predicted_label = get_label_with_highest_likelihood(
    last_layer_matrix, label_token_ids, int_to_label_map, right_shift=True
)
print(f"Predicted Label: {predicted_label}")

Predicted Label: technology


## Accuracy

Time to compare our results across our methods. 
1. Measure the accuracy of our zero-shot prompting approach.
2. Measure the accuracy of our few-shot prompting approach.
3. Measure the accuracy of our likelihood approach with zero-shot.
4. Measure the accuracy of our likelihood approach with few-shot.

### Zero-shot only

We know that our zero-shot approaches above struggled to answer in our expected label space. However, for fun, let's quantify just how poorly we do if we try to use one of these prompts to perform our task.

In [23]:
# Zero-shot evaluation over every article: generate one token per prompt and string-match it against the labels.
prompt_template = "From World, Sports, Business, Technology, what category does this article belong to? "
prompts = [f"{model_input_text}\n{prompt_template}" for model_input_text in model_input_texts]
# For memory management, we split the prompts into batches of 10
prompt_batches = split_prompts_into_batches(prompts, 10)
predicted_labels = []
# Responses that did not match any label, kept so they can be inspected afterwards.
unmatched_predictions = []
for prompt_batch in tqdm(prompt_batches):
    generation = model.generate(prompt_batch, short_generation_config)
    # We'll use tokens this time and consider just the first token
    first_predicted_tokens = [tokens[0].strip().lower() for tokens in generation.generation["tokens"]]
    # If a token doesn't correspond to one of our labels, we'll randomly select one
    # NOTE(review): no random seed is set in the visible cells, so these fallback guesses (and the resulting
    # metrics) vary from run to run.
    for potential_prediction in first_predicted_tokens:
        if potential_prediction in lowercase_labels:
            predicted_labels.append(potential_prediction)
        else:
            unmatched_predictions.append(potential_prediction)
            predicted_labels.append(random.choice(lowercase_labels))

100%|███████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:56<00:00,  5.70s/it]


In [25]:
# Summarize the zero-shot run: show one prompt, how many responses fell outside the label space, a few of those
# responses, and the resulting metrics.
print(f"Example Prompt\n{prompts[0]}")
print("---------------------------------------------")
print(f"Failed to match {len(unmatched_predictions)} responses to our label space")
# random.sample raises ValueError when asked for more items than the population holds, so cap the sample size at
# the number of unmatched responses (which may be fewer than 10, or zero).
print(f"Some examples of responses: {sample(unmatched_predictions, min(10, len(unmatched_predictions)))}")
report_metrics(predicted_labels, ag_news_labels, labels_order=["world", "sports", "business", "technology"])

Example Prompt
Title: Telecom lifts first quarter net profit 19pc
Description: Telecom Corp today reported its September first quarter net profit rose 19 per cent to $193 million. The profit bettered analysts #39; average forecasts of $185m.
From World, Sports, Business, Technology, what category does this article belong to? 
---------------------------------------------
Failed to match 100 responses to our label space
Some examples of responses: ['maybe', '1', '1', '3', '二', '0', '2', '2', '3', '0']
Prediction Accuracy: 0.15
Confusion Matrix with ordering ['world', 'sports', 'business', 'technology']
[[ 6  4 13  5]
 [ 3  4  4 10]
 [ 8  6  2  7]
 [11  9  5  3]]
Label: world, F1: 0.21428571428571427, Precision: 0.21428571428571427, Recall: 0.21428571428571427
Label: sports, F1: 0.1818181818181818, Precision: 0.17391304347826086, Recall: 0.19047619047619047
Label: business, F1: 0.0851063829787234, Precision: 0.08333333333333333, Recall: 0.08695652173913043
Label: technology, F1: 0.113207

We can see that the model never answers with a response in our expected label space. As a result, we end up randomly guessing for all of our predictions, which gives an expected accuracy of around 0.25 (the run shown above happened to score 0.15).

### Few-shot only

In this example, we'll use a 5-shot prompt, as we did above, and perform an "exact match" with our label space. That is, we parse out the first token that the model produces in its generation and simply try to string match it to one of our four label strings.

__Note__: Our generation configuration uses a `temperature = 1.0`, which means that it samples from the vocabulary distribution, as predicted by the model. You could try changing this to 0 to get a bit better factual extraction (greedy decoding).

In [26]:
# Few-shot evaluation over every article: prepend the demonstrations, generate one token per prompt, and
# string-match it against the labels.
prompt_template_postfix = "Category (World, Sports, Business, Technology):"
prompts = [
    f"{prompt_demonstrations}{model_input_text}\n{prompt_template_postfix}" for model_input_text in model_input_texts
]
# For memory management, we split the prompts into batches of 10
prompt_batches = split_prompts_into_batches(prompts, 10)
predicted_labels = []
for prompt_batch in tqdm(prompt_batches):
    generation = model.generate(prompt_batch, short_generation_config)
    # We'll use tokens this time and consider just the first token
    first_predicted_tokens = [tokens[0].strip().lower() for tokens in generation.generation["tokens"]]
    # If a token doesn't correspond to one of our labels, we'll randomly select one
    # NOTE(review): no random seed is set in the visible cells, so these fallback guesses vary from run to run.
    for potential_prediction in first_predicted_tokens:
        if potential_prediction in lowercase_labels:
            predicted_labels.append(potential_prediction)
        else:
            # Report the mismatch inline so it shows up next to the progress bar.
            print(f"Potential Prediction: {potential_prediction} does not match any label")
            predicted_labels.append(random.choice(lowercase_labels))

 70%|████████████████████████████████████████████████████████████████▍                           | 7/10 [00:40<00:15,  5.12s/it]

Potential Prediction: baseball does not match any label


 80%|█████████████████████████████████████████████████████████████████████████▌                  | 8/10 [00:50<00:13,  6.53s/it]

Potential Prediction: econom does not match any label
Potential Prediction: news does not match any label


100%|███████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:58<00:00,  5.85s/it]


In [25]:
# Show the full prompt stored at index 10 of the current `prompts` list.
print(f"Example Prompt\n{prompts[10]}")

Example Prompt
Title: Lane drives in winning run in ninth
Description: Jason Lane took an unusual post-game batting practice with hitting coach Gary Gaetti after a disappointing performance Friday night.
Category (World, Sports, Business, Technology): Sports

Title: Arson attack on Jewish centre in Paris (AFP)
Description: AFP - A Jewish social centre in central Paris was destroyed by fire overnight in an anti-Semitic arson attack, city authorities said.
Category (World, Sports, Business, Technology): World

Title: Oil prices look set to dominate
Description: The price of oil looks set to grab headlines as analysts forecast that its record-breaking run may well continue.
Category (World, Sports, Business, Technology): Business

Title: Indexes in Japan fall short of hype
Description: Japanese stocks have failed to measure up to an assessment made in April by Merrill Lynch #39;s chief global strategist, David Bowers, who said Japan was  quot;very much everyone #39;s favorite equity marke

In [26]:
# Score whatever the most recently run prediction loop stored in `predicted_labels` against the true labels.
report_metrics(predicted_labels, ag_news_labels, labels_order=["world", "sports", "business", "technology"])

Prediction Accuracy: 0.56
Confusion Matrix with ordering ['world', 'sports', 'business', 'technology']
[[10  2 13  3]
 [ 3 12  6  0]
 [ 0  0 13 10]
 [ 3  1  3 21]]
Label: world, F1: 0.45454545454545453, Precision: 0.625, Recall: 0.35714285714285715
Label: sports, F1: 0.6666666666666666, Precision: 0.8, Recall: 0.5714285714285714
Label: business, F1: 0.4482758620689655, Precision: 0.37142857142857144, Recall: 0.5652173913043478
Label: technology, F1: 0.6774193548387097, Precision: 0.6176470588235294, Recall: 0.75


There are a few examples where the model doesn't answer in the space we expect, but there are not many such cases. The accuracy is a significant improvement over zero-shot.

### Likelihood Zero-shot

In this example, we do not incorporate any demonstrations into the prompt (zero-shot prompt). From our experience above, the model does not do a good job generating responses that correspond to our label space. So rather than trying to match responses to our labels as strings, we extract the probabilities of our labels (see example above), as estimated by the model's vocabulary projection, and select the label with the highest probability as the prediction.

In [27]:
# Zero-shot likelihood evaluation: instead of matching generated text, compare the label logits for each prompt.
prompt_template = "From World, Sports, Business, Technology, what category does this article belong to? "
prompts = [f"{model_input_text}\n{prompt_template}" for model_input_text in model_input_texts]
# For memory management, we split the prompts into batches of size 1, since the activations are heavier.
prompt_batches = split_prompts_into_batches(prompts, 1)
predicted_labels = []
for prompt_batch in tqdm(prompt_batches):
    activations = model.get_activations(prompt_batch, [last_layer_name], short_generation_config)
    # One activations entry is returned per prompt in the batch.
    for activations_single_prompt in activations.activations:
        last_layer_matrix = activations_single_prompt[last_layer_name]
        # `label_token_ids` comes from the earlier tokenizer cell and is reused unchanged here.
        predicted_label = get_label_with_highest_likelihood(
            last_layer_matrix, label_token_ids, int_to_label_map, right_shift=True
        )
        predicted_labels.append(predicted_label)

100%|██████████| 100/100 [03:00<00:00,  1.81s/it]


In [27]:
# Score whatever the most recently run prediction loop stored in `predicted_labels` against the true labels.
report_metrics(predicted_labels, ag_news_labels, labels_order=["world", "sports", "business", "technology"])

Prediction Accuracy: 0.55
Confusion Matrix with ordering ['world', 'sports', 'business', 'technology']
[[ 9  2  7 10]
 [ 3  8  6  4]
 [ 0  0 17  6]
 [ 1  1  5 21]]
Label: world, F1: 0.4390243902439025, Precision: 0.6923076923076923, Recall: 0.32142857142857145
Label: sports, F1: 0.5, Precision: 0.7272727272727273, Recall: 0.38095238095238093
Label: business, F1: 0.5862068965517241, Precision: 0.4857142857142857, Recall: 0.7391304347826086
Label: technology, F1: 0.6086956521739131, Precision: 0.5121951219512195, Recall: 0.75


We can see this approach yielded a large improvement in prediction accuracy over the naive zero-shot prompting approach. Because we are only extracting the probabilities that match our label space, we do not have the issue with the prediction being outside of the label space. Moreover, we're actually able to match the performance of few-shot prompts with just zero-shot. It should be noted that this is likely due, in part, to the fact that we're using `temperature=1.0` for the few-shot prompting, but it's still a big improvement. That is, because we used a temperature of `1.0` in that example, the model sampled the next token from the predicted distribution. Thus, we didn't necessarily always select the token that the model thinks is the __most__ probable.

### Likelihood with Few-Shot

The zero-shot prompt combined with likelihood estimation for our label space does a much better job than pure zero-shot prompting. Let's combine the two approaches. We'll use a 5-shot prompt, as we did in the exact match example above, but now we'll use likelihood over our labels as the prediction mechanism rather than exact matching the first generated token.

In [None]:
# Few-shot likelihood evaluation: prepend the demonstrations, then compare the label logits for each prompt.
prompt_template_postfix = "Category (World, Sports, Business, Technology):"
prompts = [
    f"{prompt_demonstrations}{model_input_text}\n{prompt_template_postfix}" for model_input_text in model_input_texts
]
# For memory management, we split the prompts into batches of size 1, since the activations are heavier.
prompt_batches = split_prompts_into_batches(prompts, 1)
predicted_labels = []
for prompt_batch in tqdm(prompt_batches):
    activations = model.get_activations(prompt_batch, [last_layer_name], short_generation_config)
    # One activations entry is returned per prompt in the batch.
    for activations_single_prompt in activations.activations:
        last_layer_matrix = activations_single_prompt[last_layer_name]
        # NOTE(review): `label_token_ids` is reused from the earlier tokenizer cell, where it was derived with
        # `prompt_template` rather than `prompt_template_postfix` — presumably the ids are the same; confirm.
        predicted_label = get_label_with_highest_likelihood(
            last_layer_matrix, label_token_ids, int_to_label_map, right_shift=True
        )
        predicted_labels.append(predicted_label)

 96%|██████████████████████████████████████████████████████████████████████████████████████▍   | 96/100 [07:54<00:25,  6.42s/it]

In [30]:
# Score whatever the most recently run prediction loop stored in `predicted_labels` against the true labels.
report_metrics(predicted_labels, ag_news_labels, labels_order=["world", "sports", "business", "technology"])

Prediction Accuracy: 0.83
Confusion Matrix with ordering ['world', 'sports', 'business', 'technology']
[[17  1  7  3]
 [ 0 18  3  0]
 [ 1  0 21  1]
 [ 1  0  0 27]]
Label: world, F1: 0.7234042553191489, Precision: 0.8947368421052632, Recall: 0.6071428571428571
Label: sports, F1: 0.9, Precision: 0.9473684210526315, Recall: 0.8571428571428571
Label: business, F1: 0.7777777777777777, Precision: 0.6774193548387096, Recall: 0.9130434782608695
Label: technology, F1: 0.9152542372881356, Precision: 0.8709677419354839, Recall: 0.9642857142857143


Because we're using the likelihood mapping, there are no instances in which we fail to match our labels and we're performing this task quite well. Note that if we set our generation to greedy decoding (`temperature = 0`), we would likely get close to this performance without likelihood matching, but we'd still have label matching issues for some of our generations.