In [1]:
from data_utils import SeparatedMelHarmMarkovDataset
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
import torch
from torch.optim import AdamW
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
root_dir = '/mnt/ssd2/maximos/data/hooktheory_train'

In [3]:
chordSymbolTokenizer = ChordSymbolTokenizer.from_pretrained('saved_tokenizers/ChordSymbolTokenizer')
rootTypeTokenizer = RootTypeTokenizer.from_pretrained('saved_tokenizers/RootTypeTokenizer')
pitchClassTokenizer = PitchClassTokenizer.from_pretrained('saved_tokenizers/PitchClassTokenizer')
rootPCTokenizer = RootPCTokenizer.from_pretrained('saved_tokenizers/RootPCTokenizer')
melodyPitchTokenizer = MelodyPitchTokenizer.from_pretrained('saved_tokenizers/MelodyPitchTokenizer')

In [4]:
m_chordSymbolTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, chordSymbolTokenizer)
m_rootTypeTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootTypeTokenizer)
m_pitchClassTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, pitchClassTokenizer)
m_rootPCTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootPCTokenizer)

In [5]:
print(m_chordSymbolTokenizer.harmony_tokenizer.__class__.__name__)

ChordSymbolTokenizer


In [6]:
tokenizer = m_chordSymbolTokenizer

dataset = SeparatedMelHarmMarkovDataset(root_dir, tokenizer, max_length=512, num_bars=64)
# Data collator for BART
def create_data_collator(tokenizer, model):
    return DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)

In [7]:
bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    d_model=512,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3
)

model = BartForConditionalGeneration(bart_config)

In [8]:
collator = create_data_collator(tokenizer, model=model)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collator)

In [9]:
b = next(iter(dataloader))

  return self.iter().getElementsByClass(classFilterList)
  return self.iter().getElementsByClass(classFilterList)
In /home/maximos/.local/lib/python3.11/site-packages/matplotlib/mpl-data/stylelib/seaborn-v0_8-deep.mplstyle: .flat is deprecated.  Call .flatten() instead
In /home/maximos/.local/lib/python3.11/site-packages/matplotlib/mpl-data/stylelib/seaborn-v0_8-notebook.mplstyle: .flat is deprecated.  Call .flatten() instead
  return self.iter().getElementsByClass(classFilterList)
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


In [None]:
print(b['transitions'][5].sum(axis=1))

tensor([0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.8333, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [23]:
print(b['transitions'].shape)

torch.Size([32, 348, 348])


In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool

class GraphConditioningModule(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        """
        Graph Neural Network that processes a batch of Markov transition matrices 
        and extracts one conditioning vector per graph.
        
        Args:
            in_dim (int): Input node feature dimension (dummy features if not available)
            hidden_dim (int): Hidden layer dimension
            out_dim (int): Output conditioning vector dimension
        """
        super(GraphConditioningModule, self).__init__()
        
        # Graph Attention Network (GAT) for learning node embeddings
        self.gnn1 = GATConv(in_dim, hidden_dim)
        self.gnn2 = GATConv(hidden_dim, hidden_dim)
        
        # Fully connected layer to transform pooled graph embedding
        self.fc = nn.Linear(hidden_dim, out_dim)

    def forward(self, batch_graph, node_indices):
        """
        Args:
            batch_graph (Batch): PyG Batch containing all graphs in the batch
            node_indices (torch.Tensor): Shape (batch_size,), selected node per sample
        
        Returns:
            condition_vectors (torch.Tensor): Shape (batch_size, out_dim)
        """
        x = torch.ones((batch_graph.num_nodes, 1), device=batch_graph.edge_index.device)  # Dummy node features
        
        # Forward pass through GNN
        x = F.relu(self.gnn1(x, batch_graph.edge_index))
        x = F.relu(self.gnn2(x, batch_graph.edge_index))
        
        # Extract the node embeddings of the selected nodes
        node_embeddings = x[node_indices]  # Shape: (batch_size, hidden_dim)

        # Pass through a linear layer to generate conditioning vectors
        condition_vectors = self.fc(node_embeddings)  # Shape: (batch_size, out_dim)

        return condition_vectors


In [25]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data, Batch

# latent_dim = 128
# condition_dim = 64
# num_nodes = 100  # Number of nodes per transition matrix
# hidden_dim = 128
# batch_size = 32
# seq_length = 50  # Example sequence length

# cvae = CVAE(encoder, decoder, latent_dim, condition_dim, hidden_dim, num_nodes)
# optimizer = optim.Adam(cvae.parameters(), lr=1e-3)

def build_batch_graphs(markov_matrices):
    """
    Converts a batch of Markov transition matrices into a single batched PyTorch Geometric graph.

    Args:
        markov_matrices (torch.Tensor): (batch_size, num_nodes, num_nodes) tensor

    Returns:
        batch_graph (Batch): Batched PyG graph containing all transition matrices
        node_indices (torch.Tensor): (batch_size,) tensor containing a node index per sample
    """
    batch_size, num_nodes, _ = markov_matrices.shape
    graphs = []
    node_indices = []

    for b in range(batch_size):
        # Extract nonzero entries (source, target) where transition probability > 0
        source_nodes, target_nodes = torch.nonzero(markov_matrices[b], as_tuple=True)
        edge_probs = markov_matrices[b][source_nodes, target_nodes]  # Extract transition probabilities

        # Create edge_index
        edge_index = torch.stack([source_nodes, target_nodes], dim=0)  # Shape (2, num_edges)
        
        # Create graph data object
        graph = Data(edge_index=edge_index, edge_attr=edge_probs, num_nodes=num_nodes)
        graphs.append(graph)

        # Select a random node to condition on (or use a rule)
        node_indices.append(torch.randint(0, num_nodes, (1,)))

    # Batch all graphs into a single PyG Batch object
    batch_graph = Batch.from_data_list(graphs)
    node_indices = torch.cat(node_indices)  # Shape (batch_size,)

    return batch_graph, node_indices
# end build_batch_graphs

In [26]:
batch_graph, node_indices = build_batch_graphs( b['transitions'] )

In [27]:
print(batch_graph)
print(node_indices)

DataBatch(edge_index=[2, 208], edge_attr=[208], num_nodes=11136, batch=[11136], ptr=[33])
tensor([314, 239, 265, 130, 283, 178,  29, 225, 118, 137, 190, 215, 209,  75,
        322, 169, 229,  96, 334, 327,  96, 152, 294, 300,  96, 103, 315, 140,
          4, 330, 316, 303])


In [None]:
# def vae_loss(recon_x, x, mu, logvar):
#     """Computes VAE loss (Reconstruction + KL Divergence)."""
#     recon_loss = F.mse_loss(recon_x, x, reduction="sum")  # Change to CE for text
#     kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # KL Divergence
#     return recon_loss + kl_div

# for epoch in range(100):
#     optimizer.zero_grad()

#     # === Generate a new batch of Markov transition matrices === #
#     markov_matrices = torch.rand(batch_size, num_nodes, num_nodes)  # Example: Random transition matrices
#     markov_matrices = markov_matrices / markov_matrices.sum(dim=-1, keepdim=True)  # Normalize rows

#     # === Convert batch of matrices into a batched PyG graph === #
#     batch_graph, node_indices = build_batch_graphs(markov_matrices)

#     # === Generate Random Input Data (Replace with real input) === #
#     input_ids = torch.randint(0, 1000, (batch_size, seq_length))  # Example tokenized input
#     attention_mask = torch.ones_like(input_ids)  # Dummy attention mask

#     # === Forward Pass (Now batch-processed) === #
#     recon_x, mu, logvar = cvae(input_ids, attention_mask, batch_graph, node_indices)
    
#     total_loss = vae_loss(recon_x, input_ids, mu, logvar)
#     total_loss.backward()
#     optimizer.step()

#     print(f"Epoch {epoch}, Loss: {total_loss.item()}")
