In [2]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)
import time

In [10]:
# adjust these constants, then run entire notebook.

path_to_model = "./2021-07-23.2/results/checkpoint-7500"

In [4]:
test_parameters = {
        'max_length': 128,
        'temperature': 0.9,
        'no_repeat_ngram_size': 4,
        'do_sample': True,
        'top_k': 50}

test_prompts = [
    
    # single notes
    "G",
    "e",
    
    # drowsy maggie - Reel, Edor
    """R: reel
M: 4/4
K: Edor
|: E2 B E d E B E |""",
    
    # kesh jig - Jig, Gmaj
    """R: jig
M: 6/8
K: Gmaj
|: G3 G A B |""",
    
    # john ryan's - Polka, Dmaj 
    """R: polka
M: 2/4
K: Dmaj
d d B/ c/ d/ B/ |""",
    
    # king of the fairies - Hornpipe
    """R: hornpipe
M: 4/4
K: Edor
|: B,2 | E D E F G F G A""",
    
    # inisheer - Waltz
    """R: waltz
M: 3/4
K: Gmaj
B2 B A B d |""",
    
    # the butterfly - Slip Jig
    """R: slip jig
M: 9/8
K: Emin
|: B2 E G2 E F3 |""",
    
    # banish misfortune - Dmix
    """R: jig
M: 6/8
K: Dmix
f e d c A G |""",
    
    # tam lin - Dmin, low register
    """R: reel
M: 4/4
K: Dmin
A,2 D A, F A, D A, |""",
    
    # cliffs of moher - Ador, high register
    """R: jig
M: 6/8
K: Adow
|: a3 b a g |""",
    
    # silver spear - no closing barline
    """R: reel
M: 4/4
K: Dmaj
A |: F A (3 A A A B A F A""",
    
    # unspecified tune types and meters
    "R: jig",
    "R: reel",
    "R: polka",
    "R: waltz",
    "R: hornpipe",
    "R: slip jig",
    "M: 4/4",
    "M: 6/8",
    "M: 3/4",
    "M: 2/4",
    "M: 9/8"
]

In [11]:
tokenizer = AutoTokenizer.from_pretrained(path_to_model)
model = AutoModelForCausalLM.from_pretrained(path_to_model)

In [12]:
def generate(start_text = "a", number = 6, parameters = test_parameters):
    # encoding the input text
    input_ids = tokenizer.encode(start_text, return_tensors='pt')
    output = model.generate(input_ids, num_return_sequences = number, **parameters)
    return output

In [13]:
def test_suite(path = path_to_model,
               parameters = test_parameters,
               prompts = test_prompts,
               adjust_tabs = True):
    
    current_time = time.strftime("%Y-%m-%dt%H:%M:%S", time.localtime())
    filename = "test_results/{t}.txt".format(t = current_time)
    print("running tests on model at {}".format(path))
    print("saving to {}".format(filename))
    
    with open(filename, 'a') as file:
        file.write('{}\n\n'.format(current_time))
        file.write("path_to_model: {}\n\n".format(path_to_model))
        file.write("parameters:\n")
        for p in parameters:
            file.write("    {}: {}\n".format(p, parameters[p]))
        
        file.write("\n=========\n\n")
        
        test_count = 0
        
        for prompt in prompts:
            test_count += 1
            print("running test {}: {}".format(test_count, prompt))
            file.write("test {}:\n".format(test_count))
            file.write("prompt:\n")
            file.write("{}\n\n\n".format(prompt))
            if adjust_tabs:
                prompt = prompt.replace('\n', '\t')
            output = generate(prompt)
            for setting in output:
                s = tokenizer.decode(setting)
                if adjust_tabs:
                    s = s.replace('\t', '\n')
                file.write(s + "\n\n")
            
            file.write("\n---------\n\n")

In [14]:
test_suite(adjust_tabs=True)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running tests on model at ./2021-07-23.2/results/checkpoint-7500
saving to test_results/2021-07-23t13:25:33.txt
running test 1: G


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 2: e


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 3: R: reel
M: 4/4
K: Edor
|: E2 B E d E B E |


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 4: R: jig
M: 6/8
K: Gmaj
|: G3 G A B |


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 5: R: polka
M: 2/4
K: Dmaj
d d B/ c/ d/ B/ |


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 6: R: hornpipe
M: 4/4
K: Edor
|: B,2 | E D E F G F G A


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 7: R: waltz
M: 3/4
K: Gmaj
B2 B A B d |


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 8: R: slip jig
M: 9/8
K: Emin
|: B2 E G2 E F3 |


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 9: R: jig
M: 6/8
K: Dmix
f e d c A G |


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 10: R: reel
M: 4/4
K: Dmin
A,2 D A, F A, D A, |


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 11: R: jig
M: 6/8
K: Adow
|: a3 b a g |


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 12: R: reel
M: 4/4
K: Dmaj
A |: F A (3 A A A B A F A


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 13: R: jig


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 14: R: reel


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 15: R: polka


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 16: R: waltz


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 17: R: hornpipe


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 18: R: slip jig


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 19: M: 4/4


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 20: M: 6/8


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 21: M: 3/4


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 22: M: 2/4


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


running test 23: M: 9/8


In [None]:
output = generate('G A B A d d |')
for x in output:
    print(tokenizer.decode(x), '\n')