# Aeolia Project Examples

This notebook contains examples demonstrating the usage of the Aeolia project, which is built on PyTorch Geometric Temporal. The examples will cover data loading, model training, and evaluation.

In [None]:
import sys
from pathlib import Path
from torch.utils.data import DataLoader
project_root = Path('/Users/hansen/dev/aeolia')
sys.path.append(str(project_root))
from src.models.astgcn import PolyphonyGCN
from configs.default_config import DefaultConfig
from src.data.dataset import MidiGraphDataset, temporal_graph_collate, to_dense_adj

data_dir = project_root / "data" / "raw_test"
config = DefaultConfig(project_root=project_root, config_path='/Users/hansen/dev/aeolia/configs/config.yml')
print("Loading dataset...")
dataset = MidiGraphDataset(
    npz_dir=data_dir,
    seq_length=config.periods,
    time_step=config.time_step,
    max_pitch=128,
    config=config
)

# Load dataset
dataset = MidiGraphDataset(
    npz_dir=data_dir,
    seq_length=config.periods,
    time_step=config.time_step,
    max_pitch=128,
    config=config
)

# Initialize DataLoader
data_loader = DataLoader(dataset)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=temporal_graph_collate)
batch = next(iter(dataloader))

# Initialize model
model = PolyphonyGCN(config=config)


/Users/hansen/dev/aeolia
./data/
loading composer mapping default config
loading processing statistics default config
Configuration:
K: 3
num_nodes: 234
input_dim: 32
time_kernel: 3
strides: [3, 3, 3]
period: 200
hidden_dim: 512
num_blocks: 3
Learning Rate: 0.0001
Batch Size: 8
Number of Epochs: 100
Model Save Path: ./models/
Data Path: ./data/
Log Path: ./logs/
Device: cpu
Seed: 42
num_composers: 42
max_voices: 65
Loading dataset...
Loading dataset from /Users/hansen/dev/aeolia/data/raw_test
loading composer mapping
loading processing statistics
loading rhythm mapping
Building segment index...
Found 153931 segments across 8287 files
self.node_counts_per_layer: [234, 234, 116]
Loading dataset from /Users/hansen/dev/aeolia/data/raw_test
loading composer mapping
loading processing statistics
loading rhythm mapping
Building segment index...
Found 153931 segments across 8287 files
self.node_counts_per_layer: [234, 234, 116]
idx: 9645, file_path: /Users/hansen/dev/aeolia/data/raw_test/downl

In [32]:
import torch
#from src.data.dataset import MidiGraphDataset, temporal_graph_collate, to_dense_adj
batch['features'][:, :, 0].shape
testxmask = batch['features'][:, :, 0].view(-1).to(torch.bool)

#to_dense_adj(batch['encoder_input_graphs'][0][0].edge_index).nonzero()
#batch['encoder_input_graphs'][0][0]['num_nodes']




In [62]:
batch['encoder_input_graphs'][0][0]

DataBatch(edge_index=[2, 248], edge_attr=[248], num_nodes=1872, x_mask=[8, 234], batch=[1872], ptr=[9])

In [None]:
for batch_id, item in enumerate(batch['encoder_input_graphs'][0][0]):
    print(item)

0 ('edge_index', tensor([[  47,   47,   47,   47,   47,   47,   47,   48,   48,   48,   48,   48,
           48,   48,   49,   49,   49,   49,   49,   49,   49,   50,   50,   50,
           50,   50,   50,   50,   51,   51,   51,   51,   51,   51,   51,   52,
           52,   53,   53,   54,   55,   55,   56,   57,   57,   58,   59,   59,
           59,   59,   59,   59,   59,   59,   59,   59,   59,  287,  287,  287,
          287,  287,  287,  288,  288,  288,  288,  288,  288,  289,  289,  289,
          289,  289,  289,  290,  290,  290,  290,  290,  290,  291,  291,  292,
          292,  293,  293,  294,  294,  295,  296,  296,  297,  298,  298,  298,
          298,  298,  298,  298,  298,  298,  513,  513,  513,  513,  514,  514,
          514,  514,  515,  515,  516,  516,  517,  518,  519,  519,  519,  519,
          519,  752,  752,  752,  752,  752,  752,  753,  753,  753,  753,  753,
          753,  754,  754,  754,  754,  754,  754,  755,  755,  755,  755,  755,
          7

In [43]:
batch['encoder_input_graphs'][0][0]

DataBatch(edge_index=[2, 248], edge_attr=[248], num_nodes=1872, x_mask=[8, 234], batch=[1872], ptr=[9])

In [47]:
batch['encoder_input_graphs'][0][0].edge_index

tensor([[  47,   47,   47,   47,   47,   47,   47,   48,   48,   48,   48,   48,
           48,   48,   49,   49,   49,   49,   49,   49,   49,   50,   50,   50,
           50,   50,   50,   50,   51,   51,   51,   51,   51,   51,   51,   52,
           52,   53,   53,   54,   55,   55,   56,   57,   57,   58,   59,   59,
           59,   59,   59,   59,   59,   59,   59,   59,   59,  287,  287,  287,
          287,  287,  287,  288,  288,  288,  288,  288,  288,  289,  289,  289,
          289,  289,  289,  290,  290,  290,  290,  290,  290,  291,  291,  292,
          292,  293,  293,  294,  294,  295,  296,  296,  297,  298,  298,  298,
          298,  298,  298,  298,  298,  298,  513,  513,  513,  513,  514,  514,
          514,  514,  515,  515,  516,  516,  517,  518,  519,  519,  519,  519,
          519,  752,  752,  752,  752,  752,  752,  753,  753,  753,  753,  753,
          753,  754,  754,  754,  754,  754,  754,  755,  755,  755,  755,  755,
          755,  756,  756,  

In [52]:
batch['encoder_input_graphs'][0][0].x_mask.squeeze().shape

torch.Size([8, 234])

In [56]:
slice_node_ids = torch.unique(batch['encoder_input_graphs'][0][0].edge_index)
sorted_slice_node_ids = torch.sort(slice_node_ids).values.view(-1)
x = batch['features'][:, :, 0].view(-1)
x.shape

torch.Size([1872])

In [39]:
torch.sort(slice_node_ids).values

tensor([  47,   48,   49,   50,   51,   52,   53,   54,   55,   56,   57,   58,
          59,   63,   67,  128,  129,  130,  131,  132,  209,  217,  233,  287,
         288,  289,  290,  291,  292,  293,  294,  295,  296,  297,  298,  303,
         306,  362,  363,  364,  365,  435,  443,  450,  467,  513,  514,  515,
         516,  517,  518,  519,  530,  596,  598,  668,  669,  701,  752,  753,
         754,  755,  756,  757,  758,  759,  760,  761,  762,  764,  771,  830,
         831,  832,  833,  903,  919,  935,  987,  988,  989,  990,  991,  992,
         993,  994,  995,  999, 1011, 1066, 1075, 1080, 1131, 1133, 1169, 1237,
        1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1252, 1298, 1299, 1300,
        1379, 1395, 1403, 1637, 1696, 1697, 1698, 1699, 1700, 1701, 1766, 1768,
        1838, 1871])

In [38]:
mask_to_index(testxmask)

tensor([  55,   63,   67,  128,  129,  131,  209,  217,  233,  287,  291,  303,
         306,  362,  363,  364,  365,  435,  443,  450,  467,  513,  530,  596,
         598,  668,  669,  701,  752,  756,  764,  771,  830,  831,  832,  833,
         903,  919,  935,  987,  999, 1011, 1066, 1075, 1080, 1131, 1133, 1169,
        1237, 1240, 1252, 1298, 1299, 1300, 1379, 1395, 1403, 1637, 1696, 1700,
        1766, 1768, 1838, 1871])

In [37]:
from torch_geometric.utils import coalesce, dense_to_sparse, to_dense_adj, mask_to_index