## Imports

In [1]:
import json
import glob
import itertools
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import OrderedDict

import librosa
from IPython.display import Audio as ipy_audio
from IPython.core.display import display

from quicktranscribe import tonic, pitch, wave, kde
from mogra import tonnetz
from mogra.datatypes import Swar, normalize_frequency, ratio_to_swar, SWAR_BOUNDARIES

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch.utils.data import Dataset, DataLoader
# from torch_geometric_temporal.nn.recurrent import GConvGRU

## Tonnetz Graph

In [3]:
# Define the net: build a genus from the factor list [3,3,3,3,5,5] and wrap it
# in a Tonnetz. `tn` is the object the later cells query for node coordinates,
# node frequencies and swar options.
# NOTE(review): presumably the list means four factors of 3 and two factors of
# 5 (an Euler-Fokker genus) -- confirm against mogra.tonnetz.
gs = tonnetz.EFGenus.from_list([3,3,3,3,5,5])
tn = tonnetz.Tonnetz(gs)

In [4]:
# Get the adjacency and equivalence matrices of the net.
# Both are rebound to torch long tensors under the same names in the
# "Prepare Data" section; `equiv` is later iterated column-wise (equiv.T) as one
# node mask per equivalence class when building `yraag`.
adjac = tn.adjacency_matrix()
equiv = tn.equivalence_matrix()

## Pitch Tracked Data

Util Functions

In [5]:
def read_sample_and_tonic(track_path, start_min=5, end_min=10):
    """Load a mono excerpt of a track together with its tonic.

    Args:
        track_path: Path of the track without extension; ".ctonic.txt" and
            ".mp3" are appended to locate the tonic file and the audio file.
        start_min: Start of the excerpt in minutes (converted to whole seconds).
        end_min: End of the excerpt in minutes (converted to whole seconds).

    Returns:
        Tuple ``(y_sample, sr, ctonic)``: the mono signal, the sample rate
        returned by the audio reader, and the tonic read from the tonic file.
    """
    tonic_file = track_path + ".ctonic.txt"
    audio_file = track_path + ".mp3"

    ctonic = tonic.read_tonic(tonic_file)
    stereo, sr = wave.read_audio_section(
        audio_file,
        int(start_min*60),
        int(end_min*60),
    )
    # The reader's output is transposed before being mixed down to mono.
    return librosa.to_mono(stereo.T), sr, ctonic

In [6]:
def preprocess(track, step=16):
    """Downsample a pitch track, drop unvoiced frames and round to integers.

    Args:
        track: 1-D sequence of pitch values; NaN marks frames to be dropped.
            Lists are accepted as well as numpy arrays.
        step: Keep every ``step``-th frame. Defaults to 16, the value that was
            previously hard-coded.

    Returns:
        1-D integer numpy array of the remaining rounded values.

    Raises:
        ValueError: If ``step`` is smaller than 1.
    """
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    # asarray lets plain lists through; boolean-mask indexing below needs an ndarray
    track = np.asarray(track)
    # downsample
    track = track[::step]
    # drop nans
    track = track[~np.isnan(track)]
    # round to nearest int (np.round rounds halves to the nearest even value)
    track = np.round(track).astype(int)

    return track

Read Data

In [7]:
# Directory holding the audio tracks. Only the commented-out extraction cell
# below uses it, building glob patterns by plain string concatenation, which is
# why the trailing slash is part of the value.
DATA_DIR = "concrete-demo/"

In [8]:
# Artist whose recordings are used, and for each raag the list of excerpts to
# extract. Each tuple is unpacked as (start_min, end_min) by the commented-out
# extraction loop below and passed to read_sample_and_tonic.
# The commented entries are alternative excerpt layouts kept for reference.
artist = "AjoyChakrabarty"
raags_and_times = {
    # "Bhoop": [(4,6), (7,9), (10,12), (13,15), (16,18), (19,21)],
    # "Deshkar": [(4,6), (7,9), (10,12), (13,15), (16,18), (19,21)],
    # "Bhoop": [(3,21)],
    # "Deshkar": [(3,21)],
    "Bhoop": [(3,21)],
    "Deshkar": [(3,21)],
}

In [9]:
# seqs = OrderedDict({rr:[] for rr in raags_and_times.keys()})
# for raag in raags_and_times.keys():
#     track_mp3 = glob.glob(DATA_DIR + f"*{raag}*{artist}.mp3")[0]
#     track_path = track_mp3[:-4]
#     for start_min, end_min in tqdm(raags_and_times[raag]):
#         y_sample, sr, ctonic = read_sample_and_tonic(track_path, start_min, end_min)
#         ftrack = pitch.track_pitch_pyin(y_sample, sr, ctonic)
#         seqs[raag].append(preprocess(ftrack))

In [10]:
# import pickle
# with open("seqs.pkl", "wb") as f:
#     pickle.dump(seqs, f)

In [11]:
# Load the cached pitch sequences from "seqs.pkl", the file written by the
# commented-out pickle.dump cell above (raag name -> list of preprocessed
# tracks), instead of re-running pitch tracking.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a seqs.pkl that you generated yourself.
# NOTE(review): this import sits mid-notebook; consider moving it to the
# imports cell at the top.
import pickle
with open("seqs.pkl", "rb") as f:
    seqs = OrderedDict(pickle.load(f))

Define + Plot Ground Truth

In [12]:
# Ground truth per raag, given in two forms:
#   raag_gt_ratios - the frequency ratios of the raag's notes
#   raag_gt_nodes  - the same notes as Tonnetz node coordinates
# NOTE(review): the coordinates look like (power of 3, power of 5) with octaves
# ignored, e.g. 10/9 <-> (-2, 1) -- confirm against mogra.tonnetz.
# The two lists of a raag describe the same set of notes but are NOT
# index-aligned (e.g. Bhoop's second ratio 10/9 corresponds to its last node).
# "Yaman" is defined here but has no entry in raags_and_times above.
raag_gt_ratios = {
    "Bhoop": [1, 10/9, 5/4, 3/2, 5/3],
    "Deshkar": [1, 9/8, 81/64, 3/2,  27/16],
    "Yaman": [1, 9/8, 5/4, 45/32, 3/2, 27/16, 15/8],
}
raag_gt_nodes = {
    "Bhoop": [(0,0), (1,0), (0,1), (-1,1), (-2,1)],
    "Deshkar": [(0,0), (1,0), (2,0), (3,0), (4,0)],
    "Yaman": [(0,0), (1,0), (2,0), (1,1), (2,1), (0,1), (3,0)],
}

In [13]:
# import plotly.graph_objects as go
# for raag, gt_nodes in raag_gt_nodes.items():
#     fig = go.Figure(data=[go.Scatter(
#         x=tn.coords3d[0],
#         y=tn.coords3d[1],
#         mode="text+markers",
#         marker=dict(
#             size=21,
#             symbol="circle",
#             color=["#e0b724" if coord in gt_nodes else "midnightblue" for coord in tn.node_coordinates]
#         ),
#         text=tn.node_names,
#         textposition="middle center",
#         textfont=dict(family="Overpass", size=13, color="white"),
#     )])
    
#     # fig = tn.prep_plot(fig)
#     fig.update_layout(title=f"Raag {raag}", xaxis_title="powers of 3", yaxis_title="powers of 5")
#     # set major ticks
#     fig.update_xaxes(tickvals=np.arange(-4, 5))
#     fig.update_yaxes(tickvals=np.arange(-2, 3))
#     # fig.update_xaxes(tickvals=np.arange(-2, 3), ticktext=[f"$3^{ii}$" for ii in np.arange(-2, 3)])
#     fig.show()

## Model Setup

In [14]:
from mogra.tgnn import swarwise_loss

Parameters

In [15]:
# Shared sizes for data preparation, the model and the training loop.
num_nodes = len(tn.node_coordinates)  # One graph node per Tonnetz coordinate
input_dim = 1  # Feature dimension (size of the last axis of xseq)
hidden_dim = 16  # Hidden layer size
num_classes = 12  # Number of equivalence classes (yraag has 12 columns)
num_time_steps = 25  # Number of time steps
batch_size = 2  # Number of samples in each batch; also passed to swarwise_loss

Prepare Data

In [16]:
# Convert both matrices to integer (long) tensors, rebinding the same names.
# NOTE(review): because the names are reused, re-running this cell feeds the
# tensors back into torch.tensor; the original matrices are then only
# recoverable by re-running the tn.adjacency_matrix() cell.
adjac = torch.tensor(adjac, dtype=torch.long)
equiv = torch.tensor(equiv, dtype=torch.long)

In [17]:
# Stack every pitch sequence in `seqs` (all raags, in dict order) into one
# rectangular array, truncating each to the length of the shortest sequence.
# The minimum is taken over the same sequences that are stacked, so adding or
# removing a raag in `seqs` cannot produce a ragged array or a KeyError.
all_seqs = [seq for values in seqs.values() for seq in values]
min_length = min(len(seq) for seq in all_seqs)
xx = np.array([seq[:min_length] for seq in all_seqs])

# xx is num_samples x sample_length and holds integer pitch values; each value
# is reduced mod 12 to a swar, and every Tonnetz node that is an option for that
# swar is switched on. The result is
# num_samples x num_nodes x sample_length x input_dim (multi-hot over nodes).

# Coordinate -> node index lookup; setdefault keeps the FIRST occurrence, which
# is what list.index would return.
node_index = {}
for idx, coord in enumerate(tn.node_coordinates):
    node_index.setdefault(tuple(coord), idx)

xseq = np.zeros((xx.shape[0], num_nodes, xx.shape[1], input_dim))
pitch_classes = xx % 12
# Query the net once per pitch class that actually occurs, instead of once per
# time step, then mark all matching (sample, time) positions in one shot.
for pc in np.unique(pitch_classes):
    tone_options, _ = tn.get_swar_options(Swar(pc).name)
    hits = pitch_classes == pc  # num_samples x sample_length boolean mask
    for tone_option in tone_options:
        # basic indexing yields a view, so the masked assignment writes into xseq
        xseq[:, node_index[tuple(tone_option)], :, 0][hits] = 1
xseq = torch.tensor(xseq, dtype=torch.float32)

In [18]:
def tn_indices(nodes):
    """Return a 0/1 integer vector over tn.node_coordinates marking `nodes`.

    Args:
        nodes: Iterable of node coordinates; each must be present in
            tn.node_coordinates (list.index raises ValueError otherwise).

    Returns:
        1-D int numpy array of length len(tn.node_coordinates) with 1 at the
        position of every given node and 0 elsewhere.
    """
    tni = np.zeros(len(tn.node_coordinates), dtype=int)
    for node in nodes:
        tni[tn.node_coordinates.index(node)] = 1
    return tni

# One label row per sequence; all sequences of a raag share that raag's ground truth.
yy = np.array([tn_indices(raag_gt_nodes[raag]) for raag, values in seqs.items() for _ in values])

# we want to stratify yy labels into 12 classes, with each equivalence class having one (or no) labels
yt = torch.tensor(yy, dtype=torch.long)
yraag = torch.zeros((yt.shape[0], num_classes), dtype=torch.long)
for ii, mask in enumerate(equiv.T):
    # Restrict every sample's label vector to the nodes of this equivalence class.
    masked = yt * mask
    has_label = masked.sum(dim=1) > 0
    # Index of the first labelled node in the class, or -1 when the class is unused.
    yraag[:, ii] = masked.argmax(dim=1).masked_fill(~has_label, -1)

In [None]:
# Sanity check: xseq is (num_samples, num_nodes, sample_length, input_dim) and
# yraag is (num_samples, 12), per the cells that build them.
print(xseq.shape, yraag.shape)

In [20]:
# # DataLoader:
# # we have num_samples of length num_nodes x sample_length x input_dim
# # we want to break down into num_time_steps from sample_length, and return in batches of batch_size
# # so we want to return "x"s of batch_size x num_nodes x num_time_steps x input_dim and "y"s of batch_size x num_classes

# class SeqDataset(Dataset):
#     def __init__(self, xseq, yraag, num_time_steps):
#         self.xseq = xseq
#         self.yraag = yraag
#         self.num_time_steps = num_time_steps
#         self.seqs_per_sample = xseq.shape[2] // num_time_steps

#     def __len__(self):
#         return self.xseq.shape[0] * self.seqs_per_sample

#     def __getitem__(self, idx):
#         sample_idx = idx // self.seqs_per_sample
#         insample_idx = idx % self.seqs_per_sample
#         x = self.xseq[sample_idx, :, insample_idx*self.num_time_steps:(insample_idx+1)*self.num_time_steps, :]
#         y = self.yraag[sample_idx]
#         return x, y

# dataset = SeqDataset(xseq, yraag, num_time_steps)
# # split into train and val
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

---

Fully featurize...

We only need `equiv` to get `yraag` (don't we?). We use `adjac` in the graph, but it doesn't contain the full information of the network -- namely, the type of edges (x3, x5, /3, /5).

In [None]:
# Compare the sizes of the graph matrices, the node list and the input tensor.
print(adjac.shape, equiv.shape)
print(len(tn.node_coordinates))
print(xseq.shape)

In [None]:
# Label row of the first sample: one entry per equivalence class, -1 where the
# class has no labelled node.
yraag[0]

In [None]:
# one time step
x = xseq[0, :, 0, 0]
print(x)

In [None]:
# Coordinates of the nodes that are switched on in that time step.
np.array(tn.node_coordinates)[np.where(x>0)[0]]

The following features need to be associated with this sample:<br>
- a variable length list of "options" (lol -- is this even the right approach?)<br>

Each "option" has the following features:<br>
- a 3-coordinate
- a 5-coordinate
- a cyclic frequency value

Cyclic Frequency Value

In [None]:
# Left panel: node frequencies in node order. Right panel: the same values
# sorted ascending, with a marker per node.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(tn.node_frequencies)
plt.subplot(1, 2, 2)
plt.plot(sorted(tn.node_frequencies), marker="o")

In [53]:
def featurize_ratio(ratio):
    """Encode a frequency ratio as a point on the unit circle.

    The fractional part of ``ratio`` is treated as a fraction of a full turn
    and returned as its sine and cosine.

    NOTE(review): the mapping is linear in the ratio (``ratio % 1``), not
    logarithmic -- confirm this is the intended notion of "cyclic" pitch.

    Args:
        ratio: Frequency ratio (scalar or numpy array).

    Returns:
        Tuple ``(sin_feature, cos_feature)``.
    """
    # Fractional part -> angle in radians over one full turn.
    theta = (ratio % 1) * 2 * np.pi
    return (np.sin(theta), np.cos(theta))

In [None]:
# One (sin, cos) pair per node, in node order; plotted as sin (x-axis) against
# cos (y-axis), with consecutive nodes joined by a line.
node_features = np.array([featurize_ratio(ii) for ii in tn.node_frequencies])
plt.figure(figsize=(6, 6))
plt.plot(node_features[:,0], node_features[:,1], marker="o")

---

Also, instead of the labels being for the whole sequence, get them to be per time step...

---

## Model

In [None]:
# class RecurrentGCN(torch.nn.Module):
#     def __init__(self, node_features, hidden_dim):
#         super(RecurrentGCN, self).__init__()
        
#         self.hidden_dim = hidden_dim
        
#         self.recurrent_forward = GConvGRU(node_features, hidden_dim, K=2)
#         self.recurrent_backward = GConvGRU(node_features, hidden_dim, K=2)
        
#         self.linear = torch.nn.Linear(2 * hidden_dim, 1)
    
#     def forward(self, x, edge_index, edge_weight=None):
#         # x shape: [batch_size, num_nodes, num_time_steps, num_features]
#         batch_size, num_nodes, num_time_steps, _ = x.size()
        
#         # Initialize hidden states for both directions
#         h_forward = torch.zeros(batch_size, num_nodes, self.hidden_dim, device=x.device)
#         h_backward = torch.zeros(batch_size, num_nodes, self.hidden_dim, device=x.device)
        
#         # Forward pass through time (0 to T)
#         for t in range(num_time_steps):
#             x_t = x[:, :, t, :]  # Node features at time step t
#             h_forward = self.recurrent_forward(x_t, edge_index, h_forward, edge_weight)
        
#         # Backward pass through time (T to 0)
#         for t in reversed(range(num_time_steps)):
#             x_t = x[:, :, t, :]  # Node features at time step t
#             h_backward = self.recurrent_backward(x_t, edge_index, h_backward, edge_weight)
        
#         # Concatenate hidden states from both directions
#         h_concat = torch.cat([h_forward, h_backward], dim=-1)  # Shape: [batch_size, num_nodes, 2 * hidden_dim]
        
#         # Apply linear layer to each node's concatenated hidden state
#         h_concat = F.relu(h_concat)
#         out = self.linear(h_concat)  # Shape: [batch_size, num_nodes, 1]
        
#         return out.squeeze(-1)

In [None]:
# # Instantiate the model
# model = RecurrentGCN(input_dim, hidden_dim)
# # model = TemporalGNN(input_dim, hidden_dim, num_classes, num_time_steps)
# NOTE(review): both model instantiations above are commented out and no other
# live cell in the visible notebook defines `model`, so the next line raises
# NameError on a fresh kernel. Restore a model definition before running.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


In [None]:
# # 1 Step of Training Loop
# for x, y in train_dataloader:
#     break

# optimizer.zero_grad()
# out = model(x, adjac, None)

## Train + Eval

In [23]:
# Train for a fixed number of epochs and record, for both splits, the mean loss
# per batch that was actually processed.
# NOTE(review): `model`, `train_dataloader` and `val_dataloader` are only
# created in commented-out cells of this notebook; they must be restored for
# this cell to run on a fresh kernel.
def _run_epoch(dataloader, train):
    """Run one pass over `dataloader` and return the mean loss per processed batch.

    Args:
        dataloader: Iterable yielding (x, y) batches.
        train: If True, put the model in training mode and take an optimizer
            step per batch; if False, use eval mode with autograd disabled.

    Returns:
        Sum of batch losses divided by the number of batches that contributed
        (0.0 if none did).
    """
    model.train(train)
    total_loss = 0.0
    used = 0
    failed = 0
    last_err = None
    with torch.set_grad_enabled(train):
        for x, y in dataloader:
            # swarwise_loss is given batch_size explicitly, so a short final
            # batch is skipped rather than mis-scaled.
            if x.shape[0] != batch_size:
                continue
            if train:
                # Reset gradients for every step; otherwise they accumulate
                # across all batches of the epoch.
                optimizer.zero_grad()
            try:
                out = model(x, adjac, equiv)
            except Exception as err:
                # Best effort is kept (the batch is skipped), but the failure is
                # counted and reported instead of vanishing silently.
                # KeyboardInterrupt is no longer swallowed.
                failed += 1
                last_err = err
                continue
            loss = swarwise_loss(out, y, batch_size)
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            used += 1
    if failed:
        split = "train" if train else "val"
        print(f"{split}: skipped {failed} batch(es) whose forward pass failed; last error: {last_err!r}")
    # Average over batches that contributed, not over len(dataloader), so
    # skipped batches do not deflate the reported loss.
    return total_loss / max(used, 1)


train_losses = []
val_losses = []
for epoch in range(10):
    train_losses.append(_run_epoch(train_dataloader, train=True))
    val_losses.append(_run_epoch(val_dataloader, train=False))
    

In [None]:
# Loss curves: one point per epoch for each split.
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.legend()
plt.show()

In [25]:
# len(train_dataloader)

Confusion Matrix

In [26]:
def get_preds(out_scores, threshold, num_classes=12):
    """Pick the best-scoring candidate per class, or -1 if none reaches `threshold`.

    The scores are regrouped into ``batch_size * num_classes`` rows; within each
    row the index of the maximum score is the prediction, and rows whose maximum
    is below ``threshold`` are marked -1.

    Args:
        out_scores: Tensor whose first dimension is the batch and whose
            remaining elements split evenly into ``num_classes`` groups per
            sample.
        threshold: Minimum score for a prediction to be kept.
        num_classes: Number of classes per sample. Defaults to 12, the value
            that was previously hard-coded.

    Returns:
        Long tensor of shape ``(batch_size, num_classes)`` holding the index of
        the best candidate within each class, or -1.
    """
    bs = out_scores.shape[0]
    # reshape (rather than view) also accepts non-contiguous inputs
    class_scores = out_scores.reshape(bs * num_classes, -1)
    best_scores, best_idx = class_scores.max(dim=1)
    # Vectorized thresholding instead of a Python loop over every score.
    best_idx = best_idx.masked_fill(best_scores < threshold, -1)
    return best_idx.reshape(bs, num_classes)

In [None]:
# Collect flattened ground-truth labels and thresholded predictions over the
# validation set, for the confusion matrix below.
labels = []
preds = []
# Inference only: use eval mode and skip building the autograd graph.
model.eval()
with torch.no_grad():
    for x, y in val_dataloader:
        out = model(x, adjac, equiv)
        # 0.5 is the minimum score for a class prediction to be kept (else -1)
        preds.extend(get_preds(out, 0.5).flatten().detach().numpy())
        labels.extend(y.flatten().detach().numpy())

In [None]:
# Plot confusion matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns

cm = confusion_matrix(labels, preds, labels=np.arange(-1, 45, 2))
plt.figure(figsize=(10,7))
sns.heatmap(cm, annot=True, fmt='d', vmax=15, xticklabels=False, yticklabels=False)
plt.xlabel('Predicted')
plt.ylabel('Truth')
plt.show()
