In [None]:
import os
import time

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)

In [None]:
# adjust these constants, then run entire notebook.

# Checkpoint directory; both the tokenizer and the model are loaded from here.
path_to_model = "./2021-07-19.1/results/checkpoint-19810"

# Keyword arguments unpacked into `model.generate` (through `generate`) for every sample.
test_parameters = {
        'max_length': 128,          # total sequence length in tokens, prompt included
        'temperature': 0.9,
        'no_repeat_ngram_size': 4,
        'do_sample': True,          # random sampling with no seed set in this notebook: runs differ
        'top_k': 50}

# Prompts in ABC notation: bare notes, openings of known tunes (with R:/M:/K:
# header fields), and header-only stubs naming just a tune type or a meter.
test_prompts = [
    
    # single notes
    "G",
    "e",
    
    # drowsy maggie - Reel, Edor
    """R: reel
M: 4/4
K: Edor
|: E2 B E d E B E |""",
    
    # kesh jig - Jig, Gmaj
    """R: jig
M: 6/8
K: Gmaj
|: G3 G A B |""",
    
    # john ryan's - Polka, Dmaj
    """R: polka
M: 2/4
K: Dmaj
d d B/ c/ d/ B/ |""",
    
    # king of the fairies - Hornpipe
    """R: hornpipe
M: 4/4
K: Edor
|: B,2 | E D E F G F G A""",
    
    # inisheer - Waltz
    """R: waltz
M: 3/4
K: Gmaj
B2 B A B d |""",
    
    # the butterfly - Slip Jig
    """R: slip jig
M: 9/8
K: Emin
|: B2 E G2 E F3 |""",
    
    # banish misfortune - Dmix
    """R: jig
M: 6/8
K: Dmix
f e d c A G |""",
    
    # tam lin - Dmin, low register
    """R: reel
M: 4/4
K: Dmin
A,2 D A, F A, D A, |""",
    
    # cliffs of moher - Ador, high register
    """R: jig
M: 6/8
K: Ador
|: a3 b a g |""",
    
    # silver spear - no closing barline
    """R: reel
M: 4/4
K: Dmaj
A |: F A (3 A A A B A F A""",
    
    # unspecified tune types and meters
    "R: jig",
    "R: reel",
    "R: polka",
    "R: waltz",
    "R: hornpipe",
    "R: slip jig",
    "M: 4/4",
    "M: 6/8",
    "M: 3/4",
    "M: 2/4",
    "M: 9/8"
]

In [None]:
# Load the fine-tuned weights and the matching tokenizer from `path_to_model`.
model = AutoModelForCausalLM.from_pretrained(path_to_model)
tokenizer = AutoTokenizer.from_pretrained(path_to_model)

In [None]:
def generate(start_text = "a", number = 6, parameters = None):
    """Sample `number` continuations of `start_text` from the loaded model.

    Args:
        start_text: prompt string (ABC notation in this notebook).
        number: how many sequences to sample; forwarded as `num_return_sequences`.
        parameters: extra keyword arguments for `model.generate`. When None
            (the default) the global `test_parameters` is looked up at call
            time, so re-running the config cell takes effect without having
            to re-run this cell.

    Returns:
        Whatever `model.generate` returns; callers iterate over it and decode
        each element with `tokenizer.decode`.
    """
    if parameters is None:
        # Resolved here, not in the signature: a def-time default would keep
        # pointing at whatever dict existed when this cell was last executed.
        parameters = test_parameters

    # Calling the tokenizer (instead of `.encode`) also yields the attention
    # mask, which is handed to generate rather than left for it to infer.
    encoded = tokenizer(start_text, return_tensors='pt')
    output = model.generate(
        encoded['input_ids'],
        attention_mask=encoded.get('attention_mask'),
        num_return_sequences=number,
        **parameters,
    )
    return output

In [None]:
def test_suite(path = path_to_model,
               parameters = test_parameters,
               prompts = test_prompts):
    
    current_time = time.strftime("%Y-%m-%dt%H:%M:%S", time.localtime())
    filename = "test_results/{t}.txt".format(t = current_time)
    print("running tests on model at {}".format(path))
    print("saving to {}".format(filename))
    
    with open(filename, 'a') as file:
        file.write('{}\n\n'.format(current_time))
        file.write("path_to_model: {}\n\n".format(path_to_model))
        file.write("parameters:\n")
        for p in parameters:
            file.write("    {}: {}\n".format(p, parameters[p]))
        
        file.write("\n=========\n\n")
        
        test_count = 0
        
        for prompt in prompts:
            
            test_count += 1
            print("running test {}: {}".format(test_count, prompt))
            file.write("test {}:\n".format(test_count))
            file.write("prompt:\n")
            file.write("{}\n\n\n".format(prompt))
            
            output = generate(prompt)
            for setting in output:
                file.write(tokenizer.decode(setting) + "\n\n")
            
            file.write("\n---------\n\n")

In [None]:
# Run the whole prompt suite. The config values are passed explicitly so that
# re-running the config cell and then this one can never fall back on defaults
# captured when `test_suite` was defined.
test_suite(path=path_to_model, parameters=test_parameters, prompts=test_prompts)

In [None]:
# Ad-hoc check: sample continuations of a bare phrase (no R:/M:/K: header) and print each one.
output = generate('G A B A d d |')
decoded_samples = [tokenizer.decode(token_ids) for token_ids in output]
for sample_text in decoded_samples:
    print(sample_text, '\n')