In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
import checklist
from checklist.editor import Editor
from checklist.expect import Expect
from checklist.pred_wrapper import PredictorWrapper
from checklist.test_types import MFT
from typing import List
import warnings
warnings.filterwarnings('ignore')

# MFTs: Introduction
In this notebook, we will create Minimum Functionality Tests (MFTs) for a generative language model. MFTs test one specific function of a language model. They are analogous to unit tests in traditional software engineering.

## Setup generative model
Before we can test anything, we need to set up our language model. We will use the HuggingFace transformers library to load a GPT2 model.

First, we create a tokenizer. The tokenizer splits a string into tokens (whole words or pieces of words), then converts each token into an integer ID that our model can understand.
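To illustrate the idea of mapping tokens to IDs and back, here is a toy word-level tokenizer with a hypothetical, hand-made vocabulary. This is not how GPT-2 works: GPT-2 uses a learned byte-level BPE vocabulary of 50,257 subword tokens. The sketch only shows the encode/decode round trip.

```python
from typing import Dict, List

# Hypothetical toy vocabulary; GPT-2's real vocabulary is learned with byte-level BPE.
VOCAB: Dict[str, int] = {"<unk>": 0, "wherefore": 1, "art": 2, "thou": 3, "romeo": 4, "?": 5}
ID_TO_TOKEN: Dict[int, str] = {i: t for t, i in VOCAB.items()}


def toy_tokenize(text: str) -> List[str]:
    """Lowercase, split on whitespace, and separate a trailing question mark (toy rule)."""
    tokens = []
    for word in text.lower().split():
        if word.endswith("?") and len(word) > 1:
            tokens.extend([word[:-1], "?"])
        else:
            tokens.append(word)
    return tokens


def toy_encode(text: str) -> List[int]:
    """Map each token to its integer ID, using <unk> for tokens not in the vocabulary."""
    return [VOCAB.get(tok, VOCAB["<unk>"]) for tok in toy_tokenize(text)]


def toy_decode(ids: List[int]) -> str:
    """Map IDs back to tokens and join them with spaces."""
    return " ".join(ID_TO_TOKEN[i] for i in ids)


print(toy_encode("Wherefore art thou Romeo?"))  # [1, 2, 3, 4, 5]
print(toy_decode(toy_encode("Wherefore art thou Romeo?")))
```

Unlike this toy, GPT-2's tokenizer never needs an `<unk>` token, because any string can be broken down into byte-level pieces.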

In [2]:
# Load pretrained model tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Demonstrate what the tokenizer does
tokenizer.encode("Wherefore art thou Romeo?")

[8496, 754, 1242, 14210, 43989, 30]

Our tokenizer has turned the human-readable text into a list of numbers that the model understands. Next, let's load the GPT2 model.

In [3]:
# Load pretrained model (weights)
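model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
# Use the GPU if one is available; otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.eval()  # disable dropout for inference
model.to(device)
"Model loaded"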

'Model loaded'

Generating text with the model requires a bit of work. Let's write a function `generate_sentences` to handle the text generation.

`generate_sentences` has one parameter, `prompts`, which is a list of strings. A prompt is a string that the model uses as a starting point for generating new text; it gives the model context about what kind of text to generate.

`generate_sentences` returns a list containing one generated continuation per prompt, with the prompt text itself removed.

In [4]:
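def generate_sentences(prompts: List[str]) -> List[str]:
    """Generate one sampled continuation for each prompt (prompt text removed)."""
    sentences = []
    for prompt in prompts:
        # return_tensors='pt' returns a PyTorch tensor of token IDs
        token_tensor = tokenizer.encode(prompt, return_tensors='pt').to(device)
        out = model.generate(
            token_tensor,
            do_sample=True,          # sample tokens instead of greedy decoding
            min_length=10,           # lengths are in tokens and include the prompt
            max_length=50,
            num_beams=1,
            no_repeat_ngram_size=2,  # never repeat the same 2-gram
            early_stopping=False,
            output_scores=True,
            return_dict_in_generate=True)  # lets us read out.sequences
        text = tokenizer.decode(out.sequences[0], skip_special_tokens=True)
        # Keep only the continuation by dropping the prompt's characters
        sentences.append(text[len(prompt):])
    return sentences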
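`generate_sentences` removes the prompt by slicing off the first `len(prompt)` characters of the decoded text. That assumes decoding reproduces the prompt exactly, which usually holds for plain-text prompts but is not guaranteed (decoding can normalize some spacing). A more defensive variant, using a hypothetical helper `strip_prompt` that is not part of the notebook, checks the prefix first and otherwise keeps the whole text rather than silently cutting characters:

```python
def strip_prompt(prompt: str, text: str) -> str:
    """Return the continuation after `prompt`, or the full text if the prefix differs."""
    if text.startswith(prompt):
        return text[len(prompt):]
    # Decoding did not reproduce the prompt exactly; keep everything rather than guess.
    return text


print(repr(strip_prompt("Wherefore art thou Romeo?", "Wherefore art thou Romeo? why didst thou")))
# ' why didst thou'
```

Inside `generate_sentences`, this would replace `text[len(prompt):]` with `strip_prompt(prompt, text)`.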

In [5]:
generate_sentences(["Wherefore art thou Romeo?"])

[' why didst thou say then that God would not send him to the kingdom for such an errno? (2) When he met, the king said, Do you therefore find a person worthy to be saved, that']

Now that everything is ready, we can write our first MFT.

## MFT - Language prompt
For this MFT, we will expect the model to create a reasonable continuation of a prompt. The model will be prompted with strings like "The most commonly spoken language in {country} is " where {country} is a placeholder for a country such as Spain.

We need to create a rule that determines whether the model passes our test. The pass/fail criteria are entirely user-defined. We will consider this MFT passed if the model's output contains any language name, which demonstrates that the model understands the general context of the prompt. The mentioned language doesn't have to be accurate: for example, "In Spain the most commonly spoken language is Indonesian" would pass our test, because Indonesian is a language. The language may also appear anywhere in the output: for example, "In Spain the most commonly spoken language is not easy to learn. Spanish has many complicated conjugations." would also pass our test.

In a later section of this notebook, there is another version of this MFT that is stricter, requiring the correct language to be mentioned in the response.

### Handwritten MFT
First, we will write the MFT by hand. Then, we'll use Checklist's MFT class to demonstrate how Checklist helps us create the MFT much more quickly.

#### Generate prompts from template
We will use Checklist's Editor class to quickly create the prompts. For a detailed explanation of generating data, see the "1. Generating data" tutorial notebook.

In [6]:
editor = Editor()
prompt_strs = editor.template("The most commonly spoken language in {country} is", nsamples=10)
prompt_strs.data

['The most commonly spoken language in Vatican City is',
 'The most commonly spoken language in Albania is',
 'The most commonly spoken language in Albania is',
 'The most commonly spoken language in Vatican City is',
 'The most commonly spoken language in The Gambia is',
 'The most commonly spoken language in Japan is',
 'The most commonly spoken language in Libya is',
 'The most commonly spoken language in Oman is',
 'The most commonly spoken language in Liberia is',
 'The most commonly spoken language in Jamaica is']
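The sample above contains repeats (Albania and Vatican City each appear twice), so these ten prompts cover only eight countries. If you want each country tested once, you can drop duplicates while keeping the original order. This sketch works on a hard-coded copy of the sampled prompts; `dict.fromkeys` keeps the first occurrence of each string.

```python
from typing import List

# Copy of the ten prompts sampled above
sampled: List[str] = [
    "The most commonly spoken language in Vatican City is",
    "The most commonly spoken language in Albania is",
    "The most commonly spoken language in Albania is",
    "The most commonly spoken language in Vatican City is",
    "The most commonly spoken language in The Gambia is",
    "The most commonly spoken language in Japan is",
    "The most commonly spoken language in Libya is",
    "The most commonly spoken language in Oman is",
    "The most commonly spoken language in Liberia is",
    "The most commonly spoken language in Jamaica is",
]

# dict keys are unique and preserve insertion order (Python 3.7+)
unique_prompts = list(dict.fromkeys(sampled))
print(len(sampled), len(unique_prompts))  # 10 8
```

In the notebook, `list(dict.fromkeys(prompt_strs.data))` gives the same kind of deduplicated list, which can be passed to `generate_sentences`.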

#### Language CSV
We need a list of languages to check whether the model's output contains a language. To save time, we will read language names from a CSV file of the standard ISO 639-1 language codes, available at https://datahub.io/core/language-codes.

In [7]:
import urllib.request
urllib.request.urlretrieve('https://datahub.io/core/language-codes/r/language-codes.csv', 'language-codes.csv')
lang_codes_csv = pd.read_csv('language-codes.csv')
lang_codes_csv

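|     | alpha2 | English        |
|-----|--------|----------------|
| 0   | aa     | Afar           |
| 1   | ab     | Abkhazian      |
| 2   | ae     | Avestan        |
| 3   | af     | Afrikaans      |
| 4   | ak     | Akan           |
| ... | ...    | ...            |
| 179 | yi     | Yiddish        |
| 180 | yo     | Yoruba         |
| 181 | za     | Zhuang; Chuang |
| 182 | zh     | Chinese        |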
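Some entries in the `English` column list several names separated by semicolons, such as `Zhuang; Chuang`. Because the test checks whether the whole entry string appears in a response, a completion that mentions only "Zhuang" would not match that entry. A minimal sketch that expands such entries into individual names, using a few rows copied from the table above as inline CSV:

```python
import csv
import io
from typing import List

# A few rows copied from the language-codes table above
CSV_TEXT = """alpha2,English
aa,Afar
yo,Yoruba
za,Zhuang; Chuang
zh,Chinese
"""


def expand_language_names(rows: List[dict]) -> List[str]:
    """Split semicolon-separated names into separate, stripped, de-duplicated entries."""
    names = []
    for row in rows:
        for name in row["English"].split(";"):
            name = name.strip()
            if name and name not in names:
                names.append(name)
    return names


rows = list(csv.DictReader(io.StringIO(CSV_TEXT)))
print(expand_language_names(rows))  # ['Afar', 'Yoruba', 'Zhuang', 'Chuang', 'Chinese']
```

With pandas, a roughly equivalent one-liner is `lang_codes_csv["English"].str.split(";").explode().str.strip().unique().tolist()`.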


#### Run the MFT
Now we're ready to create the MFT. We will send every prompt to the model, check whether each response passes or fails, and record the prompts, responses, and results in three pandas DataFrames, one each for prompts, responses, and results.

In [8]:
rows = []
langs = lang_codes_csv["English"].tolist()

model_responses = generate_sentences(prompt_strs.data)

for (i, response) in enumerate(model_responses):
    # Pass if any language name from the CSV data appears in the generated string
    pf = 'pass' if any(l in response for l in langs) else 'fail'
    rows.append({"id": i, "prompt": prompt_strs.data[i], "response": response, "p/f": pf})

# DataFrame.append was removed in pandas 2.0, so build the frames from the collected rows
prompts = pd.DataFrame(rows, columns=["id", "prompt"])
responses = pd.DataFrame(rows, columns=["id", "response"])
results = pd.DataFrame(rows, columns=["id", "p/f"])

#### Show test results
Now let's look at the results of our test.

In [9]:
pd.set_option("max_colwidth", 250)

In [10]:
prompts

|   | id  | prompt |
|---|-----|--------|
| 0 | 0.0 | The most commonly spoken language in Vatican City is |
| 1 | 1.0 | The most commonly spoken language in Albania is |
| 2 | 2.0 | The most commonly spoken language in Albania is |
| 3 | 3.0 | The most commonly spoken language in Vatican City is |
| 4 | 4.0 | The most commonly spoken language in The Gambia is |
| 5 | 5.0 | The most commonly spoken language in Japan is |
| 6 | 6.0 | The most commonly spoken language in Libya is |
| 7 | 7.0 | The most commonly spoken language in Oman is |
| 8 | 8.0 | The most commonly spoken language in Liberia is |
| 9 | 9.0 | The most commonly spoken language in Jamaica is |


In [11]:
responses

|   | id  | response |
|---|-----|----------|
| 0 | 0.0 | English but the majority of members speak Spanish. It's likely that some of this has to do with the lack of official English translation into Spanish, which is one reason that many of these women don't speak |
| 1 | 1.0 | Albanian. More than 80% of Albanians speak English, though almost ten percent speak Albanese.\n\nIn June 1995, the country declared independence from Montenegro when a peace deal for a new Alban |
| 2 | 2.0 | Zagreb, which according to some translations has the name Trom.\n\nA Croatian version is used in the language. It is also written 'z' (zel). The translation also follows the |
| 3 | 3.0 | Farsi in the language of the Catholic Church, which means "good". Many cultures in Egypt, Jordan and Syria, as well as in Muslim countries, have no Arabic language at all. If an |
| 4 | 4.0 | English. It's a popular destination with tourists and its a big city, with shops for everything from soap to jewelry, clothes to cars.\n\nIt's the most highly trafficked place in East |
| 5 | 5.0 | jihonsu. If japanese wasn't the only language spoken at a school, there would have been no Jihonjin.\n\nPeople in jinguistic classes use the phrase "a |
| 6 | 6.0 | "bel," according to the report.\n\nThe report said that the people of Libya used "boun," a common term that means a person with no religious affiliation or belonging to a "sultanate |
| 7 | 7.0 | the English spoken by the first and second generation of the ruling family. On this island's islands there are some 400 thousand nomadic people, as well as a number of tribes with a vast variety of languages spoken |
| 8 | 8.0 | Zara, with many people who speak the native tongue, but there have been reports of people using the term "Zara" or "Husriya Bikwan" to describe people with a history |
| 9 | 9.0 | Jamaican. That is when it is referred to as the country of Jamaica "languages," and in practice it's usually the word to which it was translated based on the Jamaat-e-Islam Jama |


In [12]:
results

|   | id  | p/f  |
|---|-----|------|
| 0 | 0.0 | pass |
| 1 | 1.0 | pass |
| 2 | 2.0 | pass |
| 3 | 3.0 | pass |
| 4 | 4.0 | pass |
| 5 | 5.0 | fail |
| 6 | 6.0 | fail |
| 7 | 7.0 | pass |
| 8 | 8.0 | fail |
| 9 | 9.0 | fail |
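The pass/fail column can be summarized as a failure rate. A standard-library sketch using the ten outcomes shown above (with pandas, `results['p/f'].value_counts()` gives the same counts):

```python
from collections import Counter

# The p/f column from the results table above
outcomes = ["pass", "pass", "pass", "pass", "pass", "fail", "fail", "pass", "fail", "fail"]

counts = Counter(outcomes)
fail_rate = counts["fail"] / len(outcomes)
print(f"Fails (rate): {counts['fail']} ({fail_rate:.1%})")  # Fails (rate): 4 (40.0%)
```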


We can merge all the dataframes to make the results easier to read.

In [13]:
merged = pd.merge(responses, results, on="id")
merged = pd.merge(prompts, merged, on="id")
merged

|   | id  | prompt | response | p/f |
|---|-----|--------|----------|-----|
| 0 | 0.0 | The most commonly spoken language in Vatican City is | English but the majority of members speak Spanish. It's likely that some of this has to do with the lack of official English translation into Spanish, which is one reason that many of these women don't speak | pass |
| 1 | 1.0 | The most commonly spoken language in Albania is | Albanian. More than 80% of Albanians speak English, though almost ten percent speak Albanese.\n\nIn June 1995, the country declared independence from Montenegro when a peace deal for a new Alban | pass |
| 2 | 2.0 | The most commonly spoken language in Albania is | Zagreb, which according to some translations has the name Trom.\n\nA Croatian version is used in the language. It is also written 'z' (zel). The translation also follows the | pass |
| 3 | 3.0 | The most commonly spoken language in Vatican City is | Farsi in the language of the Catholic Church, which means "good". Many cultures in Egypt, Jordan and Syria, as well as in Muslim countries, have no Arabic language at all. If an | pass |
| 4 | 4.0 | The most commonly spoken language in The Gambia is | English. It's a popular destination with tourists and its a big city, with shops for everything from soap to jewelry, clothes to cars.\n\nIt's the most highly trafficked place in East | pass |
| 5 | 5.0 | The most commonly spoken language in Japan is | jihonsu. If japanese wasn't the only language spoken at a school, there would have been no Jihonjin.\n\nPeople in jinguistic classes use the phrase "a | fail |
| 6 | 6.0 | The most commonly spoken language in Libya is | "bel," according to the report.\n\nThe report said that the people of Libya used "boun," a common term that means a person with no religious affiliation or belonging to a "sultanate | fail |
| 7 | 7.0 | The most commonly spoken language in Oman is | the English spoken by the first and second generation of the ruling family. On this island's islands there are some 400 thousand nomadic people, as well as a number of tribes with a vast variety of languages spoken | pass |
| 8 | 8.0 | The most commonly spoken language in Liberia is | Zara, with many people who speak the native tongue, but there have been reports of people using the term "Zara" or "Husriya Bikwan" to describe people with a history | fail |
| 9 | 9.0 | The most commonly spoken language in Jamaica is | Jamaican. That is when it is referred to as the country of Jamaica "languages," and in practice it's usually the word to which it was translated based on the Jamaat-e-Islam Jama | fail |


Finally, let's display the failing tests.

In [14]:
merged.loc[merged['p/f'] == 'fail']

|   | id  | prompt | response | p/f |
|---|-----|--------|----------|-----|
| 5 | 5.0 | The most commonly spoken language in Japan is | jihonsu. If japanese wasn't the only language spoken at a school, there would have been no Jihonjin.\n\nPeople in jinguistic classes use the phrase "a | fail |
| 6 | 6.0 | The most commonly spoken language in Libya is | "bel," according to the report.\n\nThe report said that the people of Libya used "boun," a common term that means a person with no religious affiliation or belonging to a "sultanate | fail |
| 8 | 8.0 | The most commonly spoken language in Liberia is | Zara, with many people who speak the native tongue, but there have been reports of people using the term "Zara" or "Husriya Bikwan" to describe people with a history | fail |
| 9 | 9.0 | The most commonly spoken language in Jamaica is | Jamaican. That is when it is referred to as the country of Jamaica "languages," and in practice it's usually the word to which it was translated based on the Jamaat-e-Islam Jama | fail |
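Some of these failures come from how strictly the check matches text. Response 5 mentions "japanese", but the check is a case-sensitive substring test against "Japanese", so it fails. Substring matching can also pass a response by accident when a language name appears inside a longer word. The sketch below is a hypothetical alternative, `contains_language_word`, that ignores case and only matches whole words; whether this is the right rule is a test-design choice, not something the original test specifies.

```python
import re
from typing import Iterable


def contains_language_word(response: str, languages: Iterable[str]) -> bool:
    """True if any language name appears as a whole word, ignoring case."""
    for name in languages:
        # re.escape guards against names that contain regex metacharacters
        pattern = r"\b" + re.escape(name) + r"\b"
        if re.search(pattern, response, flags=re.IGNORECASE):
            return True
    return False


langs_sample = ["Japanese", "Lao", "English"]
print(contains_language_word("If japanese wasn't the only language", langs_sample))  # True
print(contains_language_word("Tourists often visit Laos", langs_sample))  # False
```

Entries such as `Zhuang; Chuang` would still need to be split into separate names before being passed in as `languages`.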


### Test with Checklist

Next, let's try running the MFT with Checklist. We will no longer need to keep track of results in Pandas dataframes, since Checklist will track the results for us.

#### Create the expectation function
To decide whether an example passes or fails the test, Checklist uses an expectation function. An expectation function receives an example's input `x`, the model's prediction `pred`, its confidence `conf`, and optionally a `label` and metadata `meta`, then returns `True` if the example passes the test or `False` if it fails.

In [15]:
def response_contains_language(x, pred, conf, label=None, meta=None):
    for l in langs:
        if l in pred:
            return True
    return False

We will wrap this function with `Expect.single`, which calls the expectation function once for each example. In other cases, you might want an expectation function that checks multiple examples simultaneously. See the tutorial notebook "3. Test types, expectation functions, running tests" for detailed information about expectation functions.

In [16]:
contains_language_expect_fn = Expect.single(response_contains_language)
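Conceptually, `Expect.single` turns a per-example check into something that can be applied across a whole test. The sketch below is our own simplification, not Checklist's implementation (which also handles test cases with several inputs and aggregates results); `apply_single` and `mentions_english` are hypothetical names used only for illustration.

```python
from typing import Any, Callable, List, Optional


def apply_single(
    fn: Callable[..., bool],
    xs: List[str],
    preds: List[str],
    confs: List[float],
    metas: Optional[List[Any]] = None,
) -> List[bool]:
    """Call an expectation function once per example and collect the results."""
    metas = metas if metas is not None else [None] * len(xs)
    return [fn(x, pred, conf, label=None, meta=meta)
            for x, pred, conf, meta in zip(xs, preds, confs, metas)]


def mentions_english(x, pred, conf, label=None, meta=None):
    # Toy expectation function for the demo
    return "English" in pred


print(apply_single(mentions_english, ["p1", "p2"], ["English.", "Zara."], [1.0, 1.0]))  # [True, False]
```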

Now we can feed our prompts and expectation function into the MFT constructor.

In [17]:
test = MFT(**prompt_strs, name='Language in response', description='The response contains a language.', expect=contains_language_expect_fn)

In order to run the test, Checklist also needs a function that generates the model's predictions for the inputs. The function receives all inputs (prompts) as a list, and must return the results in a tuple `(model_predictions, confidences)`, where `model_predictions` is a list of all the predictions, and `confidences` is a list of the model's scores for those predictions.

We will not use confidences in this test. Checklist provides `PredictorWrapper.wrap_predict()`, which wraps a function that returns only predictions so that it also returns a confidence score of 1 for every prediction. We use it to wrap `generate_sentences` so the predictions come with the confidence scores Checklist expects.
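The contract is simple enough to write by hand. This is a sketch of the idea, assuming only what is described above (each prediction paired with a confidence of 1); `wrap_with_unit_confidence` is a hypothetical name, and Checklist's own wrapper returns the confidences as a NumPy array rather than a list.

```python
from typing import Callable, List, Tuple


def wrap_with_unit_confidence(
    predict: Callable[[List[str]], List[str]]
) -> Callable[[List[str]], Tuple[List[str], List[float]]]:
    """Return a function that yields (predictions, confidences) with confidence 1.0."""
    def wrapped(inputs: List[str]) -> Tuple[List[str], List[float]]:
        preds = predict(inputs)
        return preds, [1.0] * len(preds)
    return wrapped


# Stand-in generator so the sketch runs without a model
echo = wrap_with_unit_confidence(lambda prompts: [p.upper() for p in prompts])
print(echo(["a", "b"]))  # (['A', 'B'], [1.0, 1.0])
```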

In [18]:
wrapped_generator = PredictorWrapper.wrap_predict(generate_sentences)
wrapped_generator(["The most commonly spoken language in Brazil is "])

(['Ã (a simple phrase with different meaning meaning), which translates differently to "to get out through the door of the house of man," and to be "away from the man, from home."\n\n'],
 array([1.]))

Now we're ready to run the test. The first argument to `test.run()` is the generator function we just created. We also set the optional parameter `overwrite=True` so the test can be re-run without an error. With `overwrite=False`, Checklist refuses to run a test that already has results, which prevents you from accidentally overwriting them.

In [19]:
test.run(wrapped_generator, overwrite=True)

Predicting 10 examples


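To see the results, we can use the `summary` function. We pass it a small helper, `format_example`, as `format_example_fn` so each failing example is printed as its prompt followed by the model's completion.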

In [20]:
def format_example(x, pred, conf, label=None, meta=None): 
    return 'Prompt:      %s\nCompletion:      %s' % (x, pred) 

In [21]:
test.summary(format_example_fn = format_example)

Test cases:      10
Fails (rate):    4 (40.0%)

Example fails:
Prompt:      The most commonly spoken language in Vatican City is
Completion:       Fraternity. But the Church has also declared that all people should have "extinguished degrees." The United States has not yet adopted the name "Catholic" because it simply is not Catholic enough. (For
----
Prompt:      The most commonly spoken language in The Gambia is
Completion:       Bahasa Indonesia. It derives from Bahasu, which is a popular name for the island of Sumatra, the central and southern islands near the Indian Ocean.

Singapore, also known
----
Prompt:      The most commonly spoken language in Japan is
Completion:       兝虭 (pronunciation) in which they are known as "soda-mō". Here, 彁 is a form of kagū that means "from above."

----


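Because `generate_sentences` samples its output (`do_sample=True`), the Checklist run produced new completions, so the failing examples differ from those in the handwritten test even though the failure rate happens to be 40% in both. Test results can also be explored visually with the `visual_summary` function.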

In [22]:
test.visual_summary()

TestSummarizer(stats={'npassed': 6, 'nfailed': 4, 'nfiltered': 0}, summarizer={'name': 'Language in response',…

## MFT - Language prompt with accurate response

Let's make our test a little stricter to better understand the model's behavior. We will now require the model to respond with the correct language instead of any language in general. To keep the checking logic simple, we limit the prompts to a fixed set of countries and record the expected language for each one in `correct_responses`. By passing `meta=True` to `editor.template()`, the country used in each prompt is stored alongside it in the `country_prompts` object.


In [23]:
country_prompts = editor.template("The most commonly spoken language in {country} is  ", country = ["United States", "France", "Guatemala", "Mongolia", "Japan"], meta=True)
correct_responses = {
    "United States": "English",
    "France": "French",
    "Guatemala": "Spanish",
    "Mongolia": "Mongolian",
    "Japan": "Japanese"
}

The country metadata can be accessed with `country_prompts.meta`.

In [24]:
country_prompts.meta

[{'country': 'United States'},
 {'country': 'France'},
 {'country': 'Guatemala'},
 {'country': 'Mongolia'},
 {'country': 'Japan'}]
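Each entry in `country_prompts.meta` lines up by index with the corresponding prompt in `country_prompts.data`, so the expected language for a prompt is `correct_responses[meta['country']]`. A standard-library sketch of that pairing, with the prompts rebuilt by hand from the template (without its trailing spaces):

```python
from typing import Dict, List

correct_responses: Dict[str, str] = {
    "United States": "English",
    "France": "French",
    "Guatemala": "Spanish",
    "Mongolia": "Mongolian",
    "Japan": "Japanese",
}

# Parallel lists, mirroring country_prompts.data and country_prompts.meta
metas: List[Dict[str, str]] = [{"country": c} for c in correct_responses]
data: List[str] = [f"The most commonly spoken language in {m['country']} is" for m in metas]

for prompt, meta in zip(data, metas):
    expected = correct_responses[meta["country"]]
    print(f"{prompt!r} -> expect {expected!r}")
```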

### Handwritten Test

In [25]:
rows = []

model_responses = generate_sentences(country_prompts.data)

for (i, response) in enumerate(model_responses):
    country = country_prompts.meta[i]["country"]

    # Check if the correct language for this country is in the response
    language = correct_responses[country]
    pf = 'pass' if language in response else 'fail'
    rows.append({"id": i, "prompt": country_prompts.data[i], "response": response, "p/f": pf})

# DataFrame.append was removed in pandas 2.0, so build the frames from the collected rows
prompts = pd.DataFrame(rows, columns=["id", "prompt"])
responses = pd.DataFrame(rows, columns=["id", "response"])
test_results = pd.DataFrame(rows, columns=["id", "p/f"])


#### Show test results
Let's look at our test results. The first dataframe contains the prompts given to the model.

In [26]:
prompts

|   | id  | prompt |
|---|-----|--------|
| 0 | 0.0 | The most commonly spoken language in United States is |
| 1 | 1.0 | The most commonly spoken language in France is |
| 2 | 2.0 | The most commonly spoken language in Guatemala is |
| 3 | 3.0 | The most commonly spoken language in Mongolia is |
| 4 | 4.0 | The most commonly spoken language in Japan is |


The next dataframe shows the model's response to each prompt (not including the prompt itself).

In [27]:
responses

|   | id  | response |
|---|-----|----------|
| 0 | 0.0 | ̂́̃̈͂̕ ̓/ ˡ̷̙̯̅͐: ʃ ɾ̵� |
| 1 | 1.0 | francais. Its most common form is 'frant l'a la' Français'[literally 'a franschal' and is still used, though in less extreme forms |
| 2 | 2.0 | ōlán (pronounced āle-le). The ʀlantyán of ūl (a.k.a., the "old man") is the closest. |
| 3 | 3.0 | iklu, another way of saying "hundred miles." It is considered quite low in size by some to match the number of people living at the same point, but that is a different |
| 4 | 4.0 | カお (ぁてました), for "to be," and ツめる (るろう).\n\nNoun [ edit ]\n い |
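Most of these responses are fragments of symbols or non-English script rather than sentences, unlike the fluent (if inaccurate) text in the first test. One plausible contributor, which we have not verified here: this template ends with two trailing spaces (`"...{country} is  "`). GPT-2's tokenizer normally attaches a space to the start of the following word, so trailing spaces end up as standalone whitespace tokens, an unusual way for ordinary text to continue. A minimal mitigation, using a hypothetical helper `normalize_prompts`, is to strip trailing whitespace before generation:

```python
from typing import List


def normalize_prompts(prompts: List[str]) -> List[str]:
    """Remove trailing whitespace so each prompt ends on a word rather than a space."""
    return [p.rstrip() for p in prompts]


raw = ["The most commonly spoken language in France is  "]
print(normalize_prompts(raw))  # ['The most commonly spoken language in France is']
```

In the notebook this would be `generate_sentences(normalize_prompts(country_prompts.data))`; for the Checklist run, the simplest option is to remove the trailing spaces from the template string itself.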


The final dataframe shows the pass/fail status of each test case.

In [28]:
test_results

|   | id  | p/f  |
|---|-----|------|
| 0 | 0.0 | fail |
| 1 | 1.0 | fail |
| 2 | 2.0 | fail |
| 3 | 3.0 | fail |
| 4 | 4.0 | fail |


### Testing with Checklist
Now let's run the test with Checklist. All we need is a new expectation function. The rest of the process is the same as before.

In [29]:
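def response_contains_correct_language(x, pred, conf, label=None, meta=None):
    # meta holds the template values for this example, e.g. {'country': 'France'};
    # look up the expected language for that country rather than using the country name
    language = correct_responses[meta['country']]
    return language in pred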
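An expectation function can be sanity-checked without the model by calling it directly on hand-written completions. This sketch defines its own small copy of `correct_responses` and a hypothetical expectation function, `expects_correct_language`, that looks up the expected language from the country in `meta`. Checking for the country name itself would be a different test: it would pass "France is lovely" and fail "French".

```python
from typing import Dict, Optional

correct_responses: Dict[str, str] = {"France": "French", "Japan": "Japanese"}


def expects_correct_language(x, pred, conf, label=None, meta: Optional[dict] = None) -> bool:
    """Pass if the language expected for meta['country'] appears in pred."""
    expected = correct_responses[meta["country"]]
    return expected in pred


# Made-up completions for a quick offline check
print(expects_correct_language("...", " French, followed by Breton.", 1.0, meta={"country": "France"}))  # True
print(expects_correct_language("...", " France is lovely.", 1.0, meta={"country": "France"}))  # False
```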

In [30]:
correct_language_expect_fn = Expect.single(response_contains_correct_language)

In [31]:
test = MFT(**country_prompts, name='Correct language in response', description='The response contains the correct language for the country in the prompt.', expect=correct_language_expect_fn)

In [32]:
test.run(wrapped_generator, overwrite=True)

Predicting 5 examples


In [33]:
test.summary(format_example_fn = format_example)

Test cases:      5
Fails (rate):    5 (100.0%)

Example fails:
Prompt:      The most commonly spoken language in United States is  
Completion:      ́, which means "heir of honor," as reported by the United Kingdom and Denmark. Its origins go back to the 16th century in Sweden.

"As mentioned earlier,
----
Prompt:      The most commonly spoken language in Guatemala is  
Completion:      iavarez, but there are numerous varieties including yuwano, luyo or yucala. Some varieties have additional verbs like jingl, llu, sina and j
----
Prompt:      The most commonly spoken language in Mongolia is  
Completion:      ʿu ́̍̇Ō, but many ˾ə ˈšɪɌˊ has different endings. There is no official Chinese pronunciation of
----


In [34]:
test.visual_summary()

TestSummarizer(stats={'npassed': 0, 'nfailed': 5, 'nfiltered': 0}, summarizer={'name': 'Correct language in re…