In [1]:
from data_utils import SeparatedMelHarmMarkovDataset
import os
import numpy as np
from harmony_tokenizers_m21 import ChordSymbolTokenizer, RootTypeTokenizer, \
    PitchClassTokenizer, RootPCTokenizer, GCTRootPCTokenizer, \
    GCTSymbolTokenizer, GCTRootTypeTokenizer, MelodyPitchTokenizer, \
    MergedMelHarmTokenizer
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartConfig, DataCollatorForSeq2Seq
import torch
from torch.optim import AdamW
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset location. Set the HOOKTHEORY_TRAIN_DIR environment variable to run the notebook
# on a machine that does not have this absolute path; the default is unchanged.
root_dir = os.environ.get('HOOKTHEORY_TRAIN_DIR', '/mnt/ssd2/maximos/data/hooktheory_train')

In [3]:
# Load the previously saved tokenizers from the 'saved_tokenizers/' directory
# (paths are relative to the notebook's working directory).
chordSymbolTokenizer = ChordSymbolTokenizer.from_pretrained('saved_tokenizers/ChordSymbolTokenizer')
rootTypeTokenizer = RootTypeTokenizer.from_pretrained('saved_tokenizers/RootTypeTokenizer')
pitchClassTokenizer = PitchClassTokenizer.from_pretrained('saved_tokenizers/PitchClassTokenizer')
rootPCTokenizer = RootPCTokenizer.from_pretrained('saved_tokenizers/RootPCTokenizer')
melodyPitchTokenizer = MelodyPitchTokenizer.from_pretrained('saved_tokenizers/MelodyPitchTokenizer')

In [4]:
# Pair the shared melody tokenizer with each harmony tokenizer in a MergedMelHarmTokenizer.
m_chordSymbolTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, chordSymbolTokenizer)
m_rootTypeTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootTypeTokenizer)
m_pitchClassTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, pitchClassTokenizer)
m_rootPCTokenizer = MergedMelHarmTokenizer(melodyPitchTokenizer, rootPCTokenizer)

In [5]:
# Sanity check: show the class name of the harmony tokenizer held by the merged tokenizer.
print(m_chordSymbolTokenizer.harmony_tokenizer.__class__.__name__)

ChordSymbolTokenizer


In [6]:
# Tokenizer used for the rest of the notebook; tokenizer_name is reused further down
# to build the path of the saved BART checkpoint.
tokenizer = m_chordSymbolTokenizer
tokenizer_name = 'ChordSymbolTokenizer'

# Dataset built from root_dir with the selected tokenizer.
dataset = SeparatedMelHarmMarkovDataset(root_dir, tokenizer, max_length=512, num_bars=64)
def create_data_collator(tokenizer, model):
    """
    Build the seq2seq data collator used to assemble batches for the BART model.

    Args:
        tokenizer: Tokenizer handed to the collator.
        model: Model handed to the collator.

    Returns:
        DataCollatorForSeq2Seq: Collator configured with padding=True.
    """
    collator_kwargs = dict(tokenizer=tokenizer, model=model, padding=True)
    return DataCollatorForSeq2Seq(**collator_kwargs)

In [7]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# Architecture of the BART model whose weights are loaded from the checkpoint below.
bart_config = BartConfig(
    vocab_size=len(tokenizer.vocab),
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    decoder_start_token_id=tokenizer.bos_token_id,
    forced_eos_token_id=tokenizer.eos_token_id,
    max_position_embeddings=512,
    encoder_layers=8,
    encoder_attention_heads=8,
    encoder_ffn_dim=512,
    decoder_layers=8,
    decoder_attention_heads=8,
    decoder_ffn_dim=512,
    d_model=512,
    encoder_layerdrop=0.3,
    decoder_layerdrop=0.3,
    dropout=0.3
)

bart = BartForConditionalGeneration(bart_config)

bart_path = 'saved_models/bart/' + tokenizer_name + '/' + tokenizer_name + '.pt'
# Always remap the checkpoint onto the active device. The earlier `device == 'cpu'` test
# compared a torch.device with a str, which never compares equal, so the CPU branch was
# unreachable and a GPU-saved checkpoint could not be loaded on a CPU-only machine.
checkpoint = torch.load(bart_path, map_location=device, weights_only=True)
bart.load_state_dict(checkpoint)

bart.to(device)
bart.eval()

bart_encoder, bart_decoder = bart.get_encoder(), bart.get_decoder()
bart_encoder.to(device)
bart_decoder.to(device)

# Freeze BART parameters: both the encoder and the decoder.
# (The second loop previously iterated over the encoder again, leaving the decoder trainable.)
for param in bart_encoder.parameters():
    param.requires_grad = False
for param in bart_decoder.parameters():
    param.requires_grad = False

In [10]:
# Batch the dataset with the seq2seq collator: shuffled, 32 samples per batch.
collator = create_data_collator(tokenizer, model=bart)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collator)

In [11]:
# Pull a single batch for interactive inspection. Because the loader shuffles,
# a different batch is drawn every time this cell is re-run.
b = next(iter(dataloader))

  return self.iter().getElementsByClass(classFilterList)
  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


In [12]:
# Row sums of the transition matrix of sample 5 in the batch.
print(b['transitions'][5].sum(axis=1))

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 

In [13]:
# Shape of the batched transition matrices: one square matrix per sample.
print(b['transitions'].shape)

torch.Size([32, 348, 348])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, Batch

def build_batch_graphs(markov_matrices):
    """
    Turn a batch of Markov transition matrices into one batched PyTorch Geometric graph.

    Every non-zero entry (i, j) of a matrix becomes a directed edge i -> j whose
    edge attribute is the value of that entry.

    Args:
        markov_matrices (torch.Tensor): (batch_size, num_nodes, num_nodes) tensor

    Returns:
        batch_graph (Batch): Batched PyG graph holding one graph per matrix
        node_indices (torch.Tensor): (batch_size,) tensor with one randomly drawn node
            index per sample. Each index lies in [0, num_nodes), i.e. it is local to its
            own graph and is NOT offset into the batched graph's node numbering.
    """
    _, nodes_per_graph, _ = markov_matrices.shape
    per_sample_graphs = []
    sampled_nodes = []

    for transition in markov_matrices:
        # Rows of nonzero() are (source, target) pairs; transpose to PyG's (2, num_edges) layout
        edge_index = transition.nonzero().t().contiguous()
        # Transition probability attached to each edge
        edge_weights = transition[edge_index[0], edge_index[1]]

        per_sample_graphs.append(
            Data(edge_index=edge_index, edge_attr=edge_weights, num_nodes=nodes_per_graph)
        )

        # Draw one random node per sample to condition on (placeholder for a real rule)
        sampled_nodes.append(torch.randint(0, nodes_per_graph, (1,)))

    # Merge the individual graphs into a single PyG Batch object
    batch_graph = Batch.from_data_list(per_sample_graphs)
    node_indices = torch.cat(sampled_nodes)  # Shape (batch_size,)

    return batch_graph, node_indices
# end build_batch_graphs

def compute_loss(recon_x, x, mu, logvar):
    """
    VAE objective: mean-squared reconstruction error plus KL divergence.

    Args:
        recon_x (torch.Tensor): Reconstructed sequences (batch_size, seq_len, transformer_dim)
        x (torch.Tensor): Ground truth sequences (batch_size, seq_len, transformer_dim)
        mu (torch.Tensor): Mean of latent distribution (batch_size, latent_dim)
        logvar (torch.Tensor): Log variance of latent distribution (batch_size, latent_dim)

    Returns:
        tuple: (total loss, reconstruction loss, KL loss), each a scalar tensor.
    """
    # Reconstruction term, averaged over every element
    reconstruction_term = F.mse_loss(recon_x, x, reduction='mean')

    # KL term against a standard normal, averaged over every element of mu / logvar
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.mean(kl_elements)

    total = reconstruction_term + kl_term
    return total, reconstruction_term, kl_term
# end compute_loss

class GraphConditioningModule(nn.Module):
    def __init__(self, hidden_dim, out_dim, use_attention=False):
        """
        Graph-based conditioning module for extracting node embeddings as condition vectors.

        Args:
            hidden_dim (int): Hidden dimension of GNN layers
            out_dim (int): Dimension of the conditioning vector
            use_attention (bool): If True, uses GATConv; otherwise, uses GCNConv.
        """
        super(GraphConditioningModule, self).__init__()

        self.use_attention = use_attention

        # Nodes carry a single constant input feature (see forward), hence in_channels=1.
        conv_cls = GATConv if use_attention else GCNConv
        self.gnn1 = conv_cls(1, hidden_dim)
        self.gnn2 = conv_cls(hidden_dim, hidden_dim)

        self.fc = nn.Linear(hidden_dim, out_dim)
    # end init

    def forward(self, batch_graph, node_indices):
        """
        Args:
            batch_graph (Batch): Batched graph object from PyG
            node_indices (torch.Tensor): Shape (batch_size,), selected node per sample.
                Indices are local to each sample's graph (0 .. num_nodes_of_that_graph - 1);
                they are shifted into the batched node numbering here via batch_graph.ptr.

        Returns:
            condition_vectors (torch.Tensor): Shape (batch_size, out_dim)

        Raises:
            ValueError: If the number of node indices does not match the number of graphs.
        """
        edge_index = batch_graph.edge_index
        x = torch.ones((batch_graph.num_nodes, 1), device=edge_index.device)  # Dummy features

        # NOTE(review): batch_graph.edge_attr (the transition probabilities) is not passed to
        # the convolutions, so only connectivity shapes the embeddings - confirm this is intended.
        x = F.relu(self.gnn1(x, edge_index))
        x = F.relu(self.gnn2(x, edge_index))

        node_indices = node_indices.to(device=x.device, dtype=torch.long)

        # x stacks the nodes of all graphs in the batch, so a per-graph index must be offset
        # by the position where its graph starts. Without the offset every sample would read
        # a node belonging to the first graph.
        ptr = getattr(batch_graph, 'ptr', None)
        if ptr is not None:
            graph_starts = ptr[:-1].to(x.device)
            if graph_starts.shape != node_indices.shape:
                raise ValueError(
                    "node_indices must hold exactly one index per graph in the batch"
                )
            node_indices = node_indices + graph_starts

        node_embeddings = x[node_indices]  # Shape: (batch_size, hidden_dim)
        condition_vectors = self.fc(node_embeddings)  # Shape: (batch_size, out_dim)

        return condition_vectors
    # end forward
# end class GraphConditioningModule

class BiLSTMEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        """
        Bidirectional LSTM that summarises a sequence into one fixed-size vector.

        Args:
            input_dim (int): Input feature dimension per timestep
            hidden_dim (int): Hidden state dimension (per direction, and of the output)
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # The two directions are concatenated before projection, hence 2 * hidden_dim inputs
        self.fc = nn.Linear(2 * hidden_dim, hidden_dim)
    # end init

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input sequence of shape (batch_size, seq_len, input_dim)

        Returns:
            torch.Tensor: Sequence summary of shape (batch_size, hidden_dim)
        """
        _outputs, (final_hidden, _final_cell) = self.lstm(x)
        # Final hidden state of each of the two directions
        first_direction = final_hidden[0]
        second_direction = final_hidden[1]
        summary = torch.cat((first_direction, second_direction), dim=-1)
        return self.fc(summary)  # Shape: (batch_size, hidden_dim)
    # end forward
# end class BiLSTMEncoder

class BiLSTMDecoder(nn.Module):
    def __init__(self, hidden_dim, output_dim):
        """
        LSTM decoder that rebuilds a sequence from a single latent vector.
        Note that, despite the class name, the LSTM here is unidirectional.

        Args:
            hidden_dim (int): Size of the latent input and of the LSTM hidden state
            output_dim (int): Output feature dimension per timestep
        """
        super().__init__()
        self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
    # end init

    def forward(self, z, seq_len):
        """
        Args:
            z (torch.Tensor): Latent variable (batch_size, hidden_dim)
            seq_len (int): Target sequence length

        Returns:
            torch.Tensor: Reconstruction of shape (batch_size, seq_len, output_dim)
        """
        # Feed the same latent vector to the LSTM at every timestep
        tiled_latent = z.unsqueeze(1).repeat(1, seq_len, 1)
        hidden_states, _state = self.lstm(tiled_latent)
        recon_x = self.fc(hidden_states)  # Shape: (batch_size, seq_len, output_dim)
        return recon_x
    # end forward
# end class BiLSTMDecoder

class CVAE(nn.Module):
    def __init__(self, transformer_dim, **config):
        """
        CVAE model integrating an LSTM encoder-decoder pair and GNN-based conditioning.

        Args:
            transformer_dim (int): Input and output feature dimension per timestep
            **config: Optional overrides; keys other than the ones below are ignored.
                hidden_dim_LSTM (int): Hidden dimension for the LSTMs (default 256)
                hidden_dim_GNN (int): Hidden dimension for GNN (default 256)
                latent_dim (int): Dimension of the VAE latent space (default 256)
                condition_dim (int): Dimension of the conditioning vector (default 128)
                use_attention (bool): If True, uses GATConv; otherwise, uses GCNConv
                    (default False)
        """
        super(CVAE, self).__init__()

        hidden_dim_LSTM = config.get('hidden_dim_LSTM', 256)
        hidden_dim_GNN = config.get('hidden_dim_GNN', 256)
        latent_dim = config.get('latent_dim', 256)
        condition_dim = config.get('condition_dim', 128)
        use_attention = config.get('use_attention', False)

        self.lstm_encoder = BiLSTMEncoder(transformer_dim, hidden_dim_LSTM)
        self.lstm_decoder = BiLSTMDecoder(hidden_dim_LSTM, transformer_dim)

        self.graph_conditioning = GraphConditioningModule(hidden_dim_GNN, condition_dim, use_attention=use_attention)

        # Latent space transformations; the condition vector is concatenated both before
        # the (mu, logvar) heads and before mapping the latent sample back to the decoder.
        self.fc_mu = nn.Linear(hidden_dim_LSTM + condition_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim_LSTM + condition_dim, latent_dim)
        self.fc_z = nn.Linear(latent_dim + condition_dim, hidden_dim_LSTM)
    # end init

    def reparameterize(self, mu, logvar):
        """Reparameterization trick: z = mu + std * epsilon"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    # end reparametrize

    def build_batch_graphs(self, markov_matrices):
        """
        Converts a batch of Markov transition matrices into a single batched PyTorch Geometric graph.

        Thin wrapper around the module-level build_batch_graphs function, so the
        conversion logic is maintained in one place only.

        Args:
            markov_matrices (torch.Tensor): (batch_size, num_nodes, num_nodes) tensor

        Returns:
            batch_graph (Batch): Batched PyG graph containing all transition matrices
            node_indices (torch.Tensor): (batch_size,) tensor containing a node index per sample
        """
        # Inside a method the bare name resolves to the module-level function, not to this method.
        return build_batch_graphs(markov_matrices)
    # end build_batch_graphs

    def forward(self, x, transitions):
        """
        Args:
            x (torch.Tensor): Input sequence of shape (batch_size, seq_len, input_dim)
            transitions (torch.Tensor): Markov transition matrices of shape
                (batch_size, num_nodes, num_nodes)

        Returns:
            recon_x (torch.Tensor): Reconstructed sequence
            mu (torch.Tensor): Mean of latent distribution
            logvar (torch.Tensor): Log variance of latent distribution
        """
        h = self.lstm_encoder(x)  # Shape: (batch_size, hidden_dim)

        # The graph is created on the device of `transitions`; keep it with the sequence
        # features so the GNN does not see tensors on two different devices.
        transitions = transitions.to(x.device)
        batch_graph, node_indices = self.build_batch_graphs(transitions)
        condition = self.graph_conditioning(batch_graph, node_indices)  # Shape: (batch_size, condition_dim)

        h_cond = torch.cat([h, condition], dim=-1)  # Shape: (batch_size, hidden_dim_LSTM + condition_dim)

        mu = self.fc_mu(h_cond)
        logvar = self.fc_logvar(h_cond)
        z = self.reparameterize(mu, logvar)

        z_cond = torch.cat([z, condition], dim=-1)
        z_hidden = self.fc_z(z_cond)  # Shape: (batch_size, hidden_dim_LSTM)

        recon_x = self.lstm_decoder(z_hidden, x.shape[1])  # Reconstruct sequence

        return recon_x, mu, logvar
    # end forward
# end CVAE

class TransGraphVAE(nn.Module):
    def __init__(self, t_encoder, t_decoder, **config):
        """
        TransGraphVAE model that involves a GNN-conditioned BiLSTM VAE between a pretrained
        frozen transformer encoder-decoder.

        This class does not freeze anything itself; the caller is expected to pass in
        modules whose parameters are already frozen.

        Args:
            t_encoder: frozen encoder of the pretrained transformer
            t_decoder: frozen decoder of the pretrained transformer
            **config: arguments for the CVAE module
        """
        super(TransGraphVAE, self).__init__()
        self.t_encoder = t_encoder
        self.t_decoder = t_decoder
        self.cvae = CVAE(self._infer_hidden_dim(t_encoder), **config)
    # end init

    @staticmethod
    def _infer_hidden_dim(t_encoder):
        """
        Return the hidden size of the encoder.

        Uses `t_encoder.dim` when present, otherwise `t_encoder.config.d_model`
        (Hugging Face BART encoders expose the size there and have no `dim` attribute).

        Raises:
            AttributeError: If neither attribute is available.
        """
        hidden_dim = getattr(t_encoder, 'dim', None)
        if hidden_dim is None:
            encoder_config = getattr(t_encoder, 'config', None)
            hidden_dim = getattr(encoder_config, 'd_model', None)
        if hidden_dim is None:
            raise AttributeError(
                "t_encoder exposes neither 'dim' nor 'config.d_model'"
            )
        return hidden_dim
    # end _infer_hidden_dim

    def compute_loss(self, recon_x, x, mu, logvar):
        """
        Compute VAE loss (Reconstruction Loss + KL Divergence).

        Args:
            recon_x (torch.Tensor): Reconstructed sequences (batch_size, seq_len, transformer_dim)
            x (torch.Tensor): Ground truth sequences (batch_size, seq_len, transformer_dim)
            mu (torch.Tensor): Mean of latent distribution (batch_size, latent_dim)
            logvar (torch.Tensor): Log variance of latent distribution (batch_size, latent_dim)

        Returns:
            tuple: (combined loss, reconstruction loss, KL loss)
        """
        recon_loss = F.mse_loss(recon_x, x, reduction='mean')

        # KL divergence loss
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        return recon_loss + kl_loss, recon_loss, kl_loss
    # end compute_loss

    def forward(self, x, transitions, decoder_input_ids=None):
        """
        Encode, pass the encoder states through the CVAE, and decode both the original
        and the reconstructed states.

        Args:
            x: Input passed to the transformer encoder (token ids for a BART encoder)
            transitions (torch.Tensor): Markov transition matrices passed to the CVAE
            decoder_input_ids (torch.Tensor, optional): Decoder token ids. When given, the
                decoder is called as
                t_decoder(input_ids=decoder_input_ids, encoder_hidden_states=...),
                which is the calling convention of Hugging Face's BART decoder. When omitted,
                the hidden states are passed as the decoder's only positional argument,
                as before.
                NOTE(review): a Hugging Face BART decoder treats its first positional
                argument as token ids, so the legacy path is expected to fail with it -
                confirm and pass decoder_input_ids in that case.

        Returns:
            dict with keys:
                'loss', 'recon_loss', 'kl_loss': VAE loss terms on the encoder states
                'y_recon': decoder output for the reconstructed encoder states
                'y': decoder output for the original encoder states
                'mu', 'logvar': parameters of the latent distribution
        """
        x = self.t_encoder(x).last_hidden_state
        recon_x, mu, logvar = self.cvae(x, transitions)
        # Use the method (not the module-level function) so subclasses can override the loss.
        total_loss, recon_loss, kl_loss = self.compute_loss(recon_x, x, mu, logvar)

        if decoder_input_ids is None:
            y_recon = self.t_decoder(recon_x)  # output from reconstruction
            y = self.t_decoder(x)  # normal output
        else:
            y_recon = self.t_decoder(
                input_ids=decoder_input_ids, encoder_hidden_states=recon_x
            )  # output from reconstruction
            y = self.t_decoder(
                input_ids=decoder_input_ids, encoder_hidden_states=x
            )  # normal output

        return {
            'loss': total_loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
            'y_recon': y_recon,
            'y': y,
            'mu': mu,
            'logvar': logvar,
        }
    # end forward
# end TransGraphVAE

In [15]:
# Convert the transition matrices of the inspected batch into one batched graph,
# together with one randomly drawn node index per sample.
batch_graph, node_indices = build_batch_graphs( b['transitions'] )

In [16]:
# Inspect the batched graph summary and the sampled node indices.
print(batch_graph)
print(node_indices)

DataBatch(edge_index=[2, 204], edge_attr=[204], num_nodes=11136, batch=[11136], ptr=[33])
tensor([ 10, 139,  48,  25,  97, 252, 217,  64, 211, 134,  96, 249, 150, 172,
         90, 292, 185,  87, 207, 221,  11,  50,  88, 181, 180,  47,  75,  34,
        287, 244,  19, 221])


In [17]:
# (row, column) positions of the non-zero entries in the first sample's transition matrix.
print(b['transitions'][0].nonzero())

tensor([[  0,  58],
        [ 58, 116],
        [ 59, 116],
        [116, 262],
        [203,   0],
        [262,  58],
        [262,  59],
        [262, 203],
        [262, 262]])


In [18]:
# Extract the first graph from the batch and show its summary.
ex0 = batch_graph.get_example(0)
print(ex0)

Data(edge_index=[2, 9], edge_attr=[9], num_nodes=348)


In [19]:
# Instantiate the conditioning module; use_attention=False selects the GCNConv layers.
graph_conditioning = GraphConditioningModule(
    hidden_dim=256, out_dim=128, use_attention=False
)

In [20]:
# Run the conditioning module on the batched graph; y holds one condition vector per sample.
y = graph_conditioning(batch_graph, node_indices)

In [50]:
# Expected shape: (batch_size, out_dim).
print(y.shape)

torch.Size([32, 128])


In [None]:
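# NOTE: scratch sketch of a training loop, kept for reference only. It is stale: it calls
# cvae(...) with four arguments whereas CVAE.forward takes (x, transitions), and it relies on
# names (optimizer, batch_size, num_nodes, seq_length, cvae) that this notebook never defines.
# def vae_loss(recon_x, x, mu, logvar):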
#     """Computes VAE loss (Reconstruction + KL Divergence)."""
#     recon_loss = F.mse_loss(recon_x, x, reduction="sum")  # Change to CE for text
#     kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # KL Divergence
#     return recon_loss + kl_div

# for epoch in range(100):
#     optimizer.zero_grad()

#     # === Generate a new batch of Markov transition matrices === #
#     markov_matrices = torch.rand(batch_size, num_nodes, num_nodes)  # Example: Random transition matrices
#     markov_matrices = markov_matrices / markov_matrices.sum(dim=-1, keepdim=True)  # Normalize rows

#     # === Convert batch of matrices into a batched PyG graph === #
#     batch_graph, node_indices = build_batch_graphs(markov_matrices)

#     # === Generate Random Input Data (Replace with real input) === #
#     input_ids = torch.randint(0, 1000, (batch_size, seq_length))  # Example tokenized input
#     attention_mask = torch.ones_like(input_ids)  # Dummy attention mask

#     # === Forward Pass (Now batch-processed) === #
#     recon_x, mu, logvar = cvae(input_ids, attention_mask, batch_graph, node_indices)
    
#     total_loss = vae_loss(recon_x, input_ids, mu, logvar)
#     total_loss.backward()
#     optimizer.step()

#     print(f"Epoch {epoch}, Loss: {total_loss.item()}")
