In [None]:
# Reload edited project modules automatically before each cell executes,
# so changes to the local helper modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
from pathlib import Path

import lightning as L
import matplotlib.pyplot as plt
import music21 as m21
import numpy as np
from sklearn.metrics import fbeta_score
import torch
import torch.nn.functional as F

import tonnetz_util as tnzu
import tonnetz_cnn as tnzc

# Game of life

In [None]:
# Starting pattern for the Game of Life run: two active cells.
# (An alternative three-cell seed that was tried: [[0, -1], [-1, 1], [1, 0]].)
seed_cells = [[0, 0], [-1, 0]]

tmap = tnzu.TonnetzMap()
tmap.set_active(seed_cells)

In [None]:
# Rule set handed to tnzu.play_life_hex: "b" holds the birth values and "s"
# the survival values (presumably active-neighbour counts -- confirm in tnzu).
BIRTH_VALUES = (2,)
SURVIVAL_VALUES = (1, 2)
RULE_CONFIGURATION = {"b": BIRTH_VALUES, "s": SURVIVAL_VALUES}

In [None]:
# Evolve the seed map for 50 generations; every step is derived from the
# most recent map, so gamemaps ends up holding the seed plus 50 successors.
n_generations = 50
gamemaps = [tmap]
for _generation in range(n_generations):
    latest_map = gamemaps[-1]
    gamemaps.append(tnzu.play_life_hex(latest_map, RULE_CONFIGURATION))

# One rendering per generation.
imgs = [gamemap.draw(radius=20) for gamemap in gamemaps]

In [None]:
# Quick visual check of MIDI-based activation using a C major triad.
c_major_triad = [60, 64, 67]  # MIDI note numbers for C4, E4, G4

temp = tnzu.TonnetzMap()
temp.set_active_midi(c_major_triad)
temp.draw(radius=30)

In [None]:
# Export the Game of Life run as an animated Tonnetz GIF and as a MIDI chord score.
# Create the output folder first so the exports do not fail on a fresh checkout.
Path("out").mkdir(parents=True, exist_ok=True)
tnzu.maps2tonnetzgif(gamemaps, "out/game.gif", speed=100, radius=20)
tnzu.maps2chordscore(gamemaps).write("midi", "out/game.mid")

In [None]:
# Draw every note of the Tonnetz at the centre of its hexagon.
axial_coords = np.array(list(tnzu.notenum2axial.values()))
pixcents = (tnzu.axial_to_pixel_mat @ axial_coords.T).T
# The figure uses a regular xy axis rather than pixel coordinates, so flip y.
pixcents[:, 1] *= -1

fig, ax = plt.subplots()
ax.scatter(
    pixcents[:, 0],
    pixcents[:, 1],
    s=1000,
    c=[[0.2, 0.6, 0.1, 0.5]],
    marker="o",
    edgecolors="black",
)
ax.axis("equal")
# Label each marker with its note name (iterating the dict yields its keys).
for notenum, centre in zip(tnzu.notenum2axial, pixcents):
    label = m21.note.Note(notenum).nameWithOctave
    ax.text(centre[0], centre[1], label, ha="center", va="center")
fig.set_size_inches(7, 7)

Convention used by the hex convolutions: "odd-q" offset coordinates (flat-top hexagons).

# CNN stuff below

**Note:** the finite Tonnetz map is a 13 × 9 grid.

In [None]:
def predict(model, intensor, threshold=0.5):
    """Run ``model`` on ``intensor`` and binarize its output.

    The model is switched to evaluation mode (and left in it), the forward
    pass runs without gradient tracking, and the raw model output is treated
    as logits: it is passed through a sigmoid and compared with ``threshold``.

    Args:
        model: A ``torch.nn.Module`` whose forward pass returns logits.
        intensor: Input passed to ``model`` unchanged (callers in this
            notebook add the batch dimension themselves).
        threshold: Probability cut-off for the positive class. Defaults to
            0.5, which reproduces the previous hard-coded behaviour.

    Returns:
        A float tensor shaped like the model output, holding 1.0 where
        ``sigmoid(logit) >= threshold`` and 0.0 elsewhere.
    """
    model.eval()
    with torch.no_grad():
        probabilities = torch.sigmoid(model(intensor))
    return (probabilities >= threshold).float()

In [None]:
# --- Single-song experiment: configuration ---
nprev = 4  # number of preceding chords per sample; also the model's input channel count
interval = "quarter"
quarterLength = 1  # passed to tnzu.maps2chordscore when scores are written out
# Alternative song: "other_midis/The Legend of Zelda The Wind Waker - Title.mid"
midipath = "other_midis/The Legend of Zelda Ocarina of Time - Gerudo Valley.mid"

# Seed the Python, NumPy and PyTorch RNGs so weight initialisation and batch
# shuffling are repeatable across "Restart & Run All".
L.seed_everything(0, workers=True)

songds = tnzc.MidiTonnetzDataset(midipath, nprev=nprev, interval=interval)

# NOTE: there is no train/test split. The model is fitted and "tested" on the
# same loader, so the test metrics reported below are training-set metrics.
trainloader = torch.utils.data.DataLoader(songds, shuffle=True, num_workers=0, batch_size=8)

# Alternative architecture: tnzc.UNetModel(nchannels=nprev, pos_weight=5)
model = tnzc.CrapModel(nchannels=nprev, pos_weight=5)
trainer = L.Trainer(max_epochs=5)
trainer.fit(model=model, train_dataloaders=trainloader)

trainer.test(model, dataloaders=trainloader)

In [None]:
# Single-step evaluation: for every dataset sample, predict the next chord from
# the true preceding chords and score it against the true next chord (F2).

# Prefer an accelerator when one is present, but fall back to CPU so the cell
# also runs on machines without Apple's MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = model.to(device)
model.eval()

fbetas = []
for chordsin, chordtruth in songds:
    # predict() disables gradient tracking itself, so no outer no_grad() is needed.
    predtensor = predict(model, chordsin.unsqueeze(0).to(device))[0]
    predarr = predtensor.cpu().numpy().squeeze()
    trutharr = chordtruth.squeeze().numpy()
    # Flatten both grids so every cell counts as one binary label.
    fbetas.append(fbeta_score(trutharr.reshape(-1), predarr.reshape(-1), beta=2))

In [None]:
# Average F2 score over all samples scored so far.
np.asarray(fbetas).mean()

In [None]:
# Chained (autoregressive) generation: each predicted chord is fed back in as
# the newest input chord for the next prediction.

# Prefer an accelerator when one is present, but fall back to CPU so the cell
# also runs on machines without Apple's MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = model.to(device)
model.eval()

# Make sure the export folder exists before anything is written into it.
Path("out").mkdir(parents=True, exist_ok=True)

# Tempo for the generated score: first metronome mark in measure 1 of the first part.
bpm = songds.score.parts[0].measure(1).getElementsByClass(m21.tempo.MetronomeMark)[0].number

# Seed with the FIRST window of the song. This matches the mass-testing loop
# and the chained F2 cell, which scores prediction i against
# songds.oddqgrids[nprev + i]. Seeding with songds[-1][0] instead continues
# the song past its end, but then there is no ground truth to score against.
seed_index = 0

predmaps = []
predtensors = [songds[seed_index][0].to(device).float()]
with torch.no_grad():
    for _step in range(100):
        nextmeasureclasses = predict(model, torch.unsqueeze(predtensors[-1], dim=0))[0]
        # Slide the window: drop the oldest chord, append the predicted chord.
        predtensors.append(torch.concat([predtensors[-1][1:], nextmeasureclasses]))
        predmaps.append(tnzu.TonnetzMap.from_oddqgrid(nextmeasureclasses.to("cpu").squeeze().numpy()))

predscore = tnzu.maps2chordscore(predmaps, quarterLength=quarterLength, bpm=bpm)
predscore.write("midi", "out/predicted.mid")
tnzu.maps2tonnetzgif(predmaps, "out/predicted.gif", speed=100, radius=20)

In [None]:
# Score the chained predictions against the ground-truth grids that follow
# the first `nprev` chords of the song.
predgrids = np.array([predmap.to_oddq_grid() for predmap in predmaps])
truthgrids = songds.oddqgrids[nprev:]
# zip() stops at the shorter sequence, so only overlapping timesteps are scored.
fbetas = [
    fbeta_score(truthgrid.reshape(-1), predgrid.reshape(-1), beta=2)
    for truthgrid, predgrid in zip(truthgrids, predgrids)
]

In [None]:
# Plot the most recently computed F2 scores with labelled axes so the figure
# can be read on its own; plt.show() keeps the Line2D repr out of the output.
fig, ax = plt.subplots()
ax.plot(fbetas)
ax.set_xlabel("Chord timestep")
ax.set_ylabel("F2 score")
ax.set_title("F2 score per chord timestep")
plt.show()

In [None]:
# Export the ground-truth song (as seen by the dataset) for side-by-side
# comparison with the predicted MIDI/GIF.
# Create the output folder first so the exports do not fail on a fresh checkout.
Path("out").mkdir(parents=True, exist_ok=True)
testscore = tnzu.maps2chordscore(songds.tonnetzmaps, quarterLength=quarterLength)
testscore.write("midi", "out/truth.mid")
tnzu.maps2tonnetzgif(songds.tonnetzmaps, "out/truth.gif", speed=100, radius=20)

## Mass CNN testing below — beware: this trains a separate model for each of up to 100 MIDI files

In [None]:
# Train one model per MIDI file, then record for each song:
#   * the loss reported by trainer.test() on the training loader,
#   * single-step F2 scores (true previous chords as input),
#   * chained F2 scores (the model's own predictions fed back as input).
nprev = 4
interval = "quarter"
quarterLength = 1
n_songs = 100  # number of MIDI files to process
n_chain_steps = 100  # chords generated per song; the aggregation cell filters on this length

# Seed the Python, NumPy and PyTorch RNGs so the whole sweep is repeatable.
L.seed_everything(0, workers=True)

# Prefer an accelerator when one is present, but fall back to CPU so the sweep
# also runs on machines without Apple's MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Make sure the export folder exists before the loop starts writing into it.
Path("out/predicted").mkdir(parents=True, exist_ok=True)

# No explicit shuffle: the glob order was considered "basically already shuffled".
midipaths = list(Path("piano_midis").glob("*.mid"))
midiresults = {}

for midipath in midipaths[:n_songs]:
    midisingleres = {}
    songds = tnzc.MidiTonnetzDataset(midipath, nprev=nprev, interval=interval)
    midisingleres["songds"] = songds

    # No train/test split: the same loader is used for fitting and for "testing".
    trainloader = torch.utils.data.DataLoader(songds, shuffle=True, num_workers=0, batch_size=16)

    # Swap in tnzc.CrapModel(nchannels=nprev, pos_weight=5) to benchmark the simple CNN.
    model = tnzc.UNetModel(nchannels=nprev, pos_weight=5)
    trainer = L.Trainer(max_epochs=5)
    trainer.fit(model=model, train_dataloaders=trainloader)
    midisingleres["model"] = model

    # trainer.test() runs on the training loader, so the value it reports under
    # "test_loss" is really the training loss -- hence the "trainloss" key.
    testres = trainer.test(model, dataloaders=trainloader)
    midisingleres["trainloss"] = testres[0]["test_loss"]

    model = model.to(device)
    model.eval()

    # --- Single-step evaluation ---------------------------------------------
    fbetas = []
    for chordsin, chordtruth in songds:
        # predict() disables gradient tracking itself.
        predtensor = predict(model, chordsin.unsqueeze(0).to(device))[0]
        predarr = predtensor.cpu().numpy().squeeze()
        trutharr = chordtruth.squeeze().numpy()
        fbetas.append(fbeta_score(trutharr.reshape(-1), predarr.reshape(-1), beta=2))
    midisingleres["fbetas"] = fbetas

    # --- Chained generation --------------------------------------------------
    # Tempo for the generated score: first metronome mark in measure 1 of the first part.
    bpm = songds.score.parts[0].measure(1).getElementsByClass(m21.tempo.MetronomeMark)[0].number

    predmaps = []
    # Start from the first window of the song so predictions line up with
    # songds.oddqgrids[nprev:] in the scoring step below.
    predtensors = [songds[0][0].to(device).float()]
    with torch.no_grad():
        for _step in range(n_chain_steps):
            nextmeasureclasses = predict(model, torch.unsqueeze(predtensors[-1], dim=0))[0]
            # Slide the window: drop the oldest chord, append the predicted chord.
            predtensors.append(torch.concat([predtensors[-1][1:], nextmeasureclasses]))
            predmaps.append(tnzu.TonnetzMap.from_oddqgrid(nextmeasureclasses.to("cpu").squeeze().numpy()))
    predscore = tnzu.maps2chordscore(predmaps, quarterLength=quarterLength, bpm=bpm)
    predscore.write("midi", f"out/predicted/{midipath.stem}-predicted.mid")
    tnzu.maps2tonnetzgif(predmaps, f"out/predicted/{midipath.stem}-predicted.gif", speed=100, radius=20)

    # --- Chained evaluation --------------------------------------------------
    predgrids = np.array([predmap.to_oddq_grid() for predmap in predmaps])
    truthgrids = songds.oddqgrids[nprev:]
    chain_fbetas = []
    # Only timesteps present in both sequences can be scored.
    for i in range(min(len(predgrids), len(truthgrids))):
        chain_fbetas.append(fbeta_score(truthgrids[i].reshape(-1), predgrids[i].reshape(-1), beta=2))
    midisingleres["chain_fbetas"] = chain_fbetas
    midiresults[midipath] = midisingleres

In [None]:
# Totals over all songs; the next cell divides them by the number of songs.
# Each F2 term is first averaged within its song, so every song weighs the same.
song_results = list(midiresults.values())
trainlosssum = sum(result["trainloss"] for result in song_results)
fbetasum = sum(np.mean(result["fbetas"]) for result in song_results)
febetachainsum = sum(np.mean(result["chain_fbetas"]) for result in song_results)

In [None]:
# Per-song averages: (training loss, single-step F2, chained F2).
n_results = len(midiresults)
tuple(total / n_results for total in (trainlosssum, fbetasum, febetachainsum))

In [None]:
# Keep only songs with a full 100-step chain so the rows stack into a
# rectangular (songs x timesteps) array.
febetachain_arrs = np.array([
    result["chain_fbetas"]
    for result in midiresults.values()
    if len(result["chain_fbetas"]) == 100
])

In [None]:
# Mean chained F2 score at each timestep, averaged across songs.
febetachain_arrs.mean(axis=0)

In [None]:
# `febetachain_arrs_simplcnn` is not created anywhere in this notebook; it
# presumably comes from an earlier run of the mass-testing and aggregation
# cells with tnzc.CrapModel, saved under that name. Look it up defensively so
# the cell still runs (UNet curve only) after "Restart & Run All".
simplecnn_chain = globals().get("febetachain_arrs_simplcnn")
has_simplecnn = simplecnn_chain is not None

fig, ax = plt.subplots()
if has_simplecnn:
    ax.plot(np.mean(simplecnn_chain, axis=0), label="Simple CNN")
ax.plot(np.mean(febetachain_arrs, axis=0), label="UNet")
ax.legend()
ax.set_xlabel("Chord timestep")
ax.set_ylabel("F2 score")
ax.set_title("Simple CNN vs UNet" if has_simplecnn else "UNet")
fig.suptitle("Average F2 score for 100 models over 100 chord timesteps")
# plt.show() keeps the Text repr of the last call out of the cell output.
plt.show()