In [1]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)
import time

In [2]:
# adjust these constants, then run entire notebook.

path_to_model = "./2021-07-19.1/results/checkpoint-19810"

test_parameters = {
        'max_length': 128,
        'temperature': 0.9,
        'no_repeat_ngram_size': 4,
        'do_sample': True,
        'top_k': 50}

test_prompts = [
    "a",
    "G F G A G A |"
]

In [3]:
tokenizer = AutoTokenizer.from_pretrained(path_to_model)
model = AutoModelForCausalLM.from_pretrained(path_to_model)

In [4]:
def generate(start_text = "a", number = 6, parameters = test_parameters):
    # encoding the input text
    input_ids = tokenizer.encode(start_text, return_tensors='pt')
    output = model.generate(input_ids, num_return_sequences = number, **parameters)
    return output

In [7]:
def test_suite(path = path_to_model,
               parameters = test_parameters,
               prompts = test_prompts):
    
    current_time = time.strftime("%Y-%m-%dt%H:%M:%S", time.localtime())
    filename = "test_results/{t}.txt".format(t = current_time)
    print("running tests on model at {}".format(path))
    print("saving to {}".format(filename))
    
    with open(filename, 'a') as file:
        file.write('{}\n\n'.format(current_time))
        file.write("path_to_model: {}\n\n".format(path_to_model))
        file.write("parameters:\n")
        for p in parameters:
            file.write("    {}: {}\n".format(p, parameters[p]))
        
        file.write("\n=========\n\n")
        
        test_count = 0
        
        for prompt in prompts:
            
            test_count += 1
            print("running test {}: {}".format(test_count, prompt))
            file.write("test {}:\n".format(test_count))
            file.write("prompt:\n")
            file.write("{}\n\n\n".format(prompt))
            
            output = generate(prompt)
            for setting in output:
                file.write(tokenizer.decode(setting) + "\n\n")
            
            file.write("\n---------\n\n")

In [8]:
test_suite()

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running tests on model at ./2021-07-19.1/results/checkpoint-19810
saving to test_results/2021-07-20t17:07:11.txt
running test 1: a


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 2: G F G A G A |
