In [None]:
# Automatically reload edited modules (e.g. tonnetz_util / tonnetz_cnn) before each cell runs,
# and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import music21 as m21
import numpy as np

import tonnetz_cnn as tnzc
import tonnetz_util as tnzu

In [None]:
# Seed map for the hex Game of Life below: two active cells
# (coordinate pairs, presumably axial (q, r) — TODO confirm against TonnetzMap.set_active).
tmap = tnzu.TonnetzMap()
# Alternative three-cell seed tried earlier: [[0, -1], [-1, 1], [1, 0]]
tmap.set_active([[0, 0], [-1, 0]])

In [None]:
# Life-like rule passed to tnzu.play_life_hex: 'b' lists the counts that trigger birth,
# 's' the counts that allow survival (presumably live-neighbour counts, B/S notation).
RULE_CONFIGURATION = dict(
    b=(2,),    # birth
    s=(1, 2),  # survival
)

In [None]:
# Run the hex Game of Life from the seed map, keeping every generation.
# gamemaps[0] is the seed itself, so the list ends with N_GENERATIONS + 1 maps.
N_GENERATIONS = 50
gamemaps = [tmap]
for _step in range(N_GENERATIONS):
    gamemaps.append(tnzu.play_life_hex(gamemaps[-1], RULE_CONFIGURATION))
# Rendered frames for ad-hoc inspection (not read by any later cell).
imgs = [g.draw(radius=20) for g in gamemaps]

In [None]:
# Sanity check: render a C major triad (MIDI 60, 64, 67 = C4, E4, G4) on the Tonnetz.
c_major_map = tnzu.TonnetzMap()
c_major_map.set_active_midi([60, 64, 67])
c_major_map.draw(radius=30)

In [None]:
# Export the Game of Life run as an animated Tonnetz GIF and a chord MIDI file.
# Create the output directory first so a fresh checkout does not fail on write.
Path("out").mkdir(exist_ok=True)
tnzu.maps2tonnetzgif(gamemaps, "out/game.gif", speed=100, radius=20)
tnzu.maps2chordscore(gamemaps).write("midi", "out/game.mid")

In [None]:
# Odd-q grid representation of generation 3 (gamemaps[0] is the seed map).
gamemaps[3].to_oddq_grid()

In [None]:
# Plot every note in tnzu.notenum2axial at its hex-cell centre, labelled with its name.
# Axial coordinates -> pixel centres via the module's conversion matrix.
axial_coords = np.array(list(tnzu.notenum2axial.values()))
pixcents = (tnzu.axial_to_pixel_mat @ axial_coords.T).T
# Negate y: we plot on ordinary xy axes, not in pixel (y-down) coordinates.
pixcents[:, 1] *= -1

fig, ax = plt.subplots(figsize=(7, 7))
ax.scatter(pixcents[:, 0], pixcents[:, 1], s=1000, c=[[0.2, 0.6, 0.1, 0.5]], marker="o", edgecolors="black")
ax.axis("equal")
for note, pixcent in zip(tnzu.notenum2axial.keys(), pixcents):
    ax.text(
        pixcent[0],
        pixcent[1],
        m21.note.Note(note).nameWithOctave,
        horizontalalignment="center",
        verticalalignment="center",
    )
ax.set_title("Tonnetz note layout (hex-cell centres)")
ax.set_xlabel("x")
ax.set_ylabel("y (pixel row, flipped)")
plt.show()

Convention used by the hex convolutions: odd-q offset coordinates with flat-top hexagons.

In [None]:
# Load the Wind Waker title theme MIDI as a music21 score for the inspection cells below.
score = m21.converter.parse("other_midis/The Legend of Zelda The Wind Waker - Title.mid")

In [None]:
# Tempo number of the first MetronomeMark in measure 1 of the first part (the same lookup
# is used later as `bpm`). Indexing [0] fails if measure 1 has no MetronomeMark.
score.parts[0].measure(1).getElementsByClass(m21.tempo.MetronomeMark)[0].number

In [None]:
# Text dump of the whole first part.
score.parts[0].show("text")

In [None]:
# Dump every measure of the first part as music21's text tree.
# show("text") prints by itself and returns None, so it must not be wrapped in print()
# (doing so printed a stray "None" after each measure).
# Only Measure elements are visited; other part-level elements have no .flatten().
for measure in score.parts[0].getElementsByClass(m21.stream.Measure):
    measure.flatten().show("text")

In [None]:
# Flattened contents of measure 1 of the first part.
score.parts[0].measure(1).flatten().show("text")

In [None]:
# Flattened contents of measure 2 of the first part.
score.parts[0].measure(2).flatten().show("text")

In [None]:
# Parts parsed from the MIDI file.
score.parts

In [None]:
# For each measure of the first part: print its notes/chords, then the sorted MIDI
# pitches they contain (the list form TonnetzMap.set_active_midi takes).
for measure in score.parts[0].getElementsByClass(m21.stream.Measure):
    notes = list(measure.flatten().notes)
    for n in notes:
        print(n)
    # Note.pitches is a 1-tuple and Chord.pitches holds every chord tone, so a single
    # comprehension covers both branches of the old commented-out code.
    midivals = sorted({p.midi for n in notes for p in n.pitches})
    print("MIDI:", midivals)
    print("-------")

In [None]:
# Convert the Wind Waker title MIDI into a sequence of Tonnetz maps (interval="eighth")
# and draw one frame to eyeball the conversion.
windwaker_maps = tnzu.midi_to_tonnetzmaps("other_midis/The Legend of Zelda The Wind Waker - Title.mid", interval="eighth")
windwaker_maps[25].draw(radius=20)

In [None]:
import hexagdly
import lightning as L
from sklearn import metrics
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.random_projection import GaussianRandomProjection
from sklearn.linear_model import LogisticRegression
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision.datasets import MNIST, EMNIST
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, random_split
from torchvision import datasets
from torchvision.transforms import v2
import torchvision.transforms.functional as T

The finite Tonnetz map is 13 × 9 cells.

In [None]:
# Train a UNet to predict the next Tonnetz frame from the previous `nprev` frames of one
# song, then evaluate it on a held-out 20% of the windows.
SEED = 0
# Seed every RNG (model init, DataLoader shuffling) so Restart & Run All is reproducible.
L.seed_everything(SEED)

nprev = 4  # number of previous frames fed to the model (also its input channel count)
interval = "quarter"  # note-value step passed to the dataset
# Derive the playback note length from `interval` so the two can never drift apart
# (e.g. "quarter" -> 1.0, "eighth" -> 0.5).
quarterLength = m21.duration.Duration(interval).quarterLength
songds = tnzc.MidiTonnetzDataset("other_midis/The Legend of Zelda Ocarina of Time - Gerudo Valley.mid", nprev=nprev, interval=interval)

# Explicit generator keeps the train/test split identical across runs.
songds_train, songds_test = random_split(songds, [0.8, 0.2], generator=torch.Generator().manual_seed(SEED))

trainloader = torch.utils.data.DataLoader(songds_train, shuffle=True, num_workers=0, batch_size=8)
testloader = torch.utils.data.DataLoader(songds_test, shuffle=False, num_workers=0, batch_size=8)

unetmodel = tnzc.UNetModel(nchannels=nprev, pos_weight=5)
trainer = L.Trainer(max_epochs=5)
trainer.fit(model=unetmodel, train_dataloaders=trainloader)

trainer.test(unetmodel, dataloaders=testloader)

In [None]:
# Autoregressively generate 100 frames with the trained UNet, seeded by the input window of
# the first dataset sample, and export the result as MIDI and an animated GIF.
# Pick an available accelerator instead of hard-coding Apple's "mps".
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
unetmodel = unetmodel.to(device)
unetmodel.eval()
bpm = songds.score.parts[0].measure(1).getElementsByClass(m21.tempo.MetronomeMark)[0].number
seed_window = songds[0][0]  # input frames of the first window, oldest first along dim 0
with torch.no_grad():
    # Emit every seed frame (not just the first) so the generated timeline is contiguous
    # and lines up with the ground-truth export.
    predmaps = [tnzu.TonnetzMap.from_oddqgrid(frame.numpy()) for frame in seed_window]
    predtensors = [seed_window.to(device).float()]
    for _step in range(100):
        logits = unetmodel(torch.unsqueeze(predtensors[-1], dim=0))
        nextmeasureclasses = (torch.sigmoid(logits[0]) >= 0.5).float()
        # Slide the window: drop the oldest frame, append the new prediction.
        predtensors.append(torch.cat([predtensors[-1][1:], nextmeasureclasses]))
        predmaps.append(tnzu.TonnetzMap.from_oddqgrid(nextmeasureclasses.to("cpu").squeeze().numpy()))

Path("out").mkdir(exist_ok=True)
predscore = tnzu.maps2chordscore(predmaps, quarterLength=quarterLength, bpm=bpm)
predscore.write("midi", "out/predicted.mid")
tnzu.maps2tonnetzgif(predmaps, "out/predicted.gif", speed=100, radius=20)

In [None]:
# Export the ground-truth frame sequence at the same tempo and note length as the
# prediction, so out/truth.mid and out/predicted.mid can be compared directly.
# bpm is recomputed here so this cell does not depend on the prediction cell having run.
bpm = songds.score.parts[0].measure(1).getElementsByClass(m21.tempo.MetronomeMark)[0].number
Path("out").mkdir(exist_ok=True)
testscore = tnzu.maps2chordscore(songds.tonnetzmaps, quarterLength=quarterLength, bpm=bpm)
testscore.write("midi", "out/truth.mid")
tnzu.maps2tonnetzgif(songds.tonnetzmaps, "out/truth.gif", speed=100, radius=20)