In [None]:
import torch
import torch.nn as nn
from transformers import BartModel
import torch_geometric
from torch_geometric.nn import GCNConv  # Or GATConv for Graph Attention Networks
from torch_geometric.data import Data

In [None]:
# Example: Constructing a graph (Markov transition table)
def create_graph_from_markov_table(transition_matrix, feature_dim=10):
    """Build a PyG ``Data`` graph from a square Markov transition matrix.

    Every strictly positive entry ``transition_matrix[i, j]`` becomes a directed
    edge ``i -> j`` whose edge attribute is that transition probability.

    Args:
        transition_matrix: Square 2-D matrix supporting ``.shape`` and ``[i, j]``
            element indexing (e.g. a torch tensor or NumPy array).
        feature_dim: Width of the randomly initialised node-feature matrix ``x``.
            Defaults to 10, the value the example originally hard-coded.

    Returns:
        ``Data`` with:
          - ``x``: float tensor of shape (num_nodes, feature_dim), random normal.
          - ``edge_index``: long tensor of shape (2, num_edges).
          - ``edge_attr``: float tensor of shape (num_edges,).
        A matrix with no positive entries yields ``edge_index`` of shape (2, 0),
        the empty-graph layout PyG expects.
    """
    num_nodes = transition_matrix.shape[0]
    edge_pairs = []
    edge_weights = []

    for i in range(num_nodes):
        for j in range(num_nodes):
            prob = transition_matrix[i, j]
            if prob > 0:  # Only keep non-zero (positive) transitions
                edge_pairs.append([i, j])  # Directed edge from i to j
                edge_weights.append(float(prob))  # Transition probability

    # view(-1, 2) keeps the (num_edges, 2) -> (2, num_edges) layout even when there
    # are no edges; without it an empty list becomes a 1-D tensor of shape (0,).
    edge_index = torch.tensor(edge_pairs, dtype=torch.long).view(-1, 2).t().contiguous()
    edge_attr = torch.tensor(edge_weights, dtype=torch.float)

    # Node features are random placeholders; replace with real node properties.
    x = torch.randn(num_nodes, feature_dim)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

In [None]:
def cvae_loss(reconstructed, original, mu, logvar):
    """Return the CVAE objective: MSE reconstruction error plus the Gaussian KL term.

    Args:
        reconstructed: Model output, same shape as ``original``.
        original: Reconstruction target.
        mu: Posterior mean.
        logvar: Posterior log-variance.

    Returns:
        Scalar tensor. Note: the MSE term is averaged over elements while the KL
        term is summed, so the two terms are on different scales.
    """
    mse_term = nn.functional.mse_loss(reconstructed, original)
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse_term + kl_term


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BartModel
from torch_geometric.nn import GCNConv, GATConv

# === Load Pretrained BART Model === #
bart = BartModel.from_pretrained("facebook/bart-base")
encoder = bart.encoder
decoder = bart.decoder

# Freeze BART parameters: no gradients are computed for either half, so only
# parameters that still require grad are updated by the optimizer.
for frozen_part in (encoder, decoder):
    for weight in frozen_part.parameters():
        weight.requires_grad = False

# === CVAE with Integrated GNN Conditioning Module === #
class CVAE(nn.Module):
    """Conditional VAE over BART hidden states, conditioned on a GNN node embedding.

    Forward pipeline:
      1. A two-layer GNN (GCN, or GAT when ``use_gat``) over learnable node
         embeddings produces a ``condition_dim`` vector for each requested node.
      2. The BART encoder embeds the tokens; a BiLSTM summarises the sequence.
      3. The summary concatenated with the condition vector parameterises
         ``mu`` / ``logvar``; ``z`` is sampled with the reparameterisation trick.
      4. ``[z, condition]`` is repeated over every time step, passed through a
         BiLSTM and a linear layer back to BART's hidden size, and fed to the BART
         decoder as ``inputs_embeds``.

    Args:
        encoder: BART encoder; must expose ``config.hidden_size`` and return an
            object with ``last_hidden_state``.
        decoder: BART decoder accepting ``inputs_embeds``.
        latent_dim: Size of the latent vector z.
        condition_dim: Size of the GNN conditioning vector.
        hidden_dim: Node-embedding width and per-direction LSTM hidden size.
        num_nodes: Number of graph nodes (rows of ``node_embeddings``).
        use_gat: Use ``GATConv`` (4 concatenated heads, then 1) instead of ``GCNConv``.
        lstm_layers: Number of layers in both BiLSTMs.
    """

    def __init__(self, encoder, decoder, latent_dim, condition_dim, hidden_dim, num_nodes, use_gat=False, lstm_layers=2):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.hidden_dim = hidden_dim

        encoder_hidden_dim = encoder.config.hidden_size

        # === GNN Conditioning Module === #
        self.node_embeddings = nn.Parameter(torch.randn(num_nodes, hidden_dim))

        if use_gat:
            self.gnn1 = GATConv(hidden_dim, hidden_dim, heads=4, concat=True)
            # First GAT layer concatenates 4 heads, hence the 4x input width.
            self.gnn2 = GATConv(hidden_dim * 4, condition_dim, heads=1, concat=False)
        else:
            self.gnn1 = GCNConv(hidden_dim, hidden_dim)
            self.gnn2 = GCNConv(hidden_dim, condition_dim)

        # === BiLSTM Encoder === #
        self.lstm_encoder = nn.LSTM(encoder_hidden_dim, hidden_dim, lstm_layers,
                                    batch_first=True, bidirectional=True)

        # === VAE Bottleneck === #
        self.fc_mu = nn.Linear(hidden_dim * 2 + condition_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim * 2 + condition_dim, latent_dim)

        # === BiLSTM Decoder === #
        self.lstm_decoder = nn.LSTM(latent_dim + condition_dim, hidden_dim, lstm_layers,
                                    batch_first=True, bidirectional=True)
        self.fc_decode = nn.Linear(hidden_dim * 2, encoder_hidden_dim)

    def gnn_forward(self, edge_index, edge_attr, node_indices):
        """Run the GNN over all node embeddings and return rows for ``node_indices``.

        ``edge_attr`` is passed as the third positional argument of each conv; for
        ``GCNConv`` that is a per-edge scalar weight. NOTE(review): the ``GATConv``
        layers are built without ``edge_dim``, so edge_attr presumably has no effect
        there; confirm against the installed PyG version.

        Returns:
            Tensor of shape (len(node_indices), condition_dim); non-negative
            because of the final ReLU.
        """
        x = self.node_embeddings
        x = F.relu(self.gnn1(x, edge_index, edge_attr))
        x = F.relu(self.gnn2(x, edge_index, edge_attr))
        return x[node_indices]  # Shape: (batch_size, condition_dim)

    def reparameterize(self, mu, logvar):
        """Reparameterization trick: sample z = mu + eps * sigma, eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, input_ids, attention_mask, edge_index, edge_attr, node_indices):
        """Encode the sequence, sample z, and decode back to BART hidden states.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding.
                Padding is assumed to be on the right (the sequence length is taken
                as the mask sum). ``None`` means every position is real.
            edge_index, edge_attr: Graph passed to the GNN conditioning module.
            node_indices: (batch,) node id that conditions each example.

        Returns:
            ``(decoder_outputs, mu, logvar)``: decoder hidden states of shape
            (batch, seq_len, hidden_size) and posterior parameters of shape
            (batch, latent_dim).
        """
        # === Compute GNN-Based Conditioning === #
        condition_vector = self.gnn_forward(edge_index, edge_attr, node_indices)

        # === Encode Input with the BART Encoder === #
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # === BiLSTM Encoder === #
        # Summarise with the top layer's final forward/backward hidden states.
        # Reading lstm_out[:, -1] instead would give a backward direction that has
        # only seen the last token, and a padding position for padded sequences.
        # Packing by true length makes both directions stop at the real sequence ends.
        if attention_mask is None:
            lengths = torch.full((input_ids.shape[0],), input_ids.shape[1], dtype=torch.long)
        else:
            # clamp(min=1): pack_padded_sequence rejects zero-length sequences.
            lengths = attention_mask.sum(dim=1).clamp(min=1).long().cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            encoder_outputs, lengths, batch_first=True, enforce_sorted=False
        )
        _, (h_n, _) = self.lstm_encoder(packed)
        # h_n: (num_layers * 2, batch, hidden_dim); the last two rows are the top
        # layer's forward and backward directions.
        seq_summary = torch.cat([h_n[-2], h_n[-1]], dim=-1)  # (batch, hidden_dim * 2)

        # === VAE Reparameterization === #
        bottleneck_in = torch.cat([seq_summary, condition_vector], dim=-1)  # (batch, hidden_dim * 2 + condition_dim)
        mu = self.fc_mu(bottleneck_in)
        logvar = self.fc_logvar(bottleneck_in)
        z = self.reparameterize(mu, logvar)

        # === Concatenate Latent Vector with Condition === #
        z = torch.cat([z, condition_vector], dim=-1)
        z = z.unsqueeze(1).repeat(1, input_ids.shape[1], 1)  # Repeat for each time step

        # === BiLSTM Decoder === #
        lstm_out, _ = self.lstm_decoder(z)
        decoder_inputs = self.fc_decode(lstm_out)  # (batch, seq_len, hidden_size)

        # === Decode with the BART Decoder === #
        decoder_outputs = self.decoder(inputs_embeds=decoder_inputs).last_hidden_state
        return decoder_outputs, mu, logvar


In [None]:
import torch.optim as optim

latent_dim, condition_dim = 128, 64
num_nodes = 100  # Number of nodes in the transition matrix
hidden_dim = 128
batch_size, seq_length = 32, 50  # seq_length: example sequence length

cvae = CVAE(encoder, decoder, latent_dim, condition_dim, hidden_dim, num_nodes)
optimizer = optim.Adam(cvae.parameters(), lr=1e-3)

# Random graph edges (placeholder topology); no edge weights are supplied.
edge_index = torch.randint(0, num_nodes, (2, 500))
edge_attr = None

def vae_loss(recon_x, x, mu, logvar):
    """VAE loss: reconstruction term + KL divergence, both summed over the batch.

    Args:
        recon_x: When ``x`` holds integer token ids, vocabulary logits of shape
            ``(*x.shape, vocab_size)``; otherwise a float tensor shaped like ``x``.
        x: Reconstruction target, either token ids (integer dtype) or real values.
        mu: Posterior mean.
        logvar: Posterior log-variance.

    Returns:
        Scalar tensor.
    """
    if torch.is_floating_point(x):
        recon_loss = F.mse_loss(recon_x, x, reduction="sum")
    else:
        # Token targets: cross-entropy over the trailing vocabulary dimension.
        recon_loss = F.cross_entropy(
            recon_x.reshape(-1, recon_x.size(-1)), x.reshape(-1), reduction="sum"
        )
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # KL Divergence
    return recon_loss + kl_div

# The CVAE returns BART decoder hidden states (batch, seq, d_model), not token ids.
# Scoring them against input_ids needs vocabulary logits; BART ties its LM head to
# the shared input embedding, so project onto that (frozen) matrix.
# NOTE: logits are (batch, seq, vocab_size), which is large; shrink batch_size if memory is tight.
token_embedding_weight = bart.get_input_embeddings().weight

for epoch in range(100):
    optimizer.zero_grad()

    # Select conditioning nodes for batch
    node_indices = torch.randint(0, num_nodes, (batch_size,))

    # Random Input Data (Replace with real input)
    input_ids = torch.randint(0, 1000, (batch_size, seq_length))  # Example tokenized input
    attention_mask = torch.ones_like(input_ids)  # Dummy attention mask

    # Forward Pass
    recon_x, mu, logvar = cvae(input_ids, attention_mask, edge_index, edge_attr, node_indices)
    logits = F.linear(recon_x, token_embedding_weight)  # (batch, seq, vocab_size)

    # Compute Loss
    loss = vae_loss(logits, input_ids, mu, logvar)
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item()}")


In [None]:
import torch
import torch.optim as optim

latent_dim, condition_dim, hidden_dim = 128, 64, 128
num_nodes = 100  # Number of nodes in the transition matrix
batch_size = 32
seq_length = 50  # Example sequence length

# Fresh model + optimizer for the Markov-conditioned run (rebinds any earlier `cvae`).
cvae = CVAE(encoder, decoder, latent_dim, condition_dim, hidden_dim, num_nodes)
optimizer = optim.Adam(cvae.parameters(), lr=1e-3)

def build_graph_from_markov(markov_matrices):
    """
    Converts a batch of Markov transition matrices into edge_index and edge_attr.

    Only strictly positive entries become edges, so stray negative values are not
    turned into edges.

    Args:
        markov_matrices (torch.Tensor): (batch_size, num_nodes, num_nodes) tensor.
            A single (num_nodes, num_nodes) matrix is also accepted and treated as
            a batch of one.

    Returns:
        edge_indices (list of torch.Tensor): Per-item long tensors of shape (2, num_edges)
            holding (source, target) node pairs; (2, 0) when a matrix has no positive entry.
        edge_attrs (list of torch.Tensor): Per-item tensors of shape (num_edges,) holding
            the matching transition probabilities.

    Raises:
        ValueError: If ``markov_matrices`` is neither 2-D nor 3-D.
    """
    if markov_matrices.dim() == 2:
        markov_matrices = markov_matrices.unsqueeze(0)
    if markov_matrices.dim() != 3:
        raise ValueError(
            f"expected a (batch, num_nodes, num_nodes) or (num_nodes, num_nodes) tensor, "
            f"got shape {tuple(markov_matrices.shape)}"
        )

    edge_indices = []
    edge_attrs = []

    for matrix in markov_matrices:
        # (source, target) pairs where transition probability > 0
        source_nodes, target_nodes = torch.nonzero(matrix > 0, as_tuple=True)
        edge_indices.append(torch.stack([source_nodes, target_nodes], dim=0))  # (2, num_edges)
        edge_attrs.append(matrix[source_nodes, target_nodes])  # transition probabilities

    return edge_indices, edge_attrs

def vae_loss(recon_x, x, mu, logvar):
    """Computes VAE loss (Reconstruction + KL Divergence), both summed.

    Args:
        recon_x: When ``x`` holds integer token ids, vocabulary logits of shape
            ``(*x.shape, vocab_size)``; otherwise a float tensor shaped like ``x``.
        x: Reconstruction target, either token ids (integer dtype) or real values.
        mu: Posterior mean.
        logvar: Posterior log-variance.

    Returns:
        Scalar tensor.
    """
    if torch.is_floating_point(x):
        recon_loss = F.mse_loss(recon_x, x, reduction="sum")
    else:
        # Token targets: cross-entropy over the trailing vocabulary dimension.
        recon_loss = F.cross_entropy(
            recon_x.reshape(-1, recon_x.size(-1)), x.reshape(-1), reduction="sum"
        )
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # KL Divergence
    return recon_loss + kl_div

# The CVAE returns BART decoder hidden states, not token ids; project them onto the
# vocabulary with BART's shared (tied, frozen) token embedding before scoring.
token_embedding_weight = bart.get_input_embeddings().weight

for epoch in range(100):
    optimizer.zero_grad()

    # === Generate a new batch of Markov transition matrices === #
    markov_matrices = torch.rand(batch_size, num_nodes, num_nodes)  # Example: Random transition matrices
    markov_matrices = markov_matrices / markov_matrices.sum(dim=-1, keepdim=True)  # Normalize rows

    # === Convert to graph format === #
    edge_indices, edge_attrs = build_graph_from_markov(markov_matrices)

    # Select conditioning nodes for batch
    node_indices = torch.randint(0, num_nodes, (batch_size,))

    # === Generate Random Input Data (Replace with real input) === #
    input_ids = torch.randint(0, 1000, (batch_size, seq_length))  # Example tokenized input
    attention_mask = torch.ones_like(input_ids)  # Dummy attention mask

    # === Forward Pass (Process each graph separately) === #
    # Each example has its own graph, so the CVAE (and BART) runs once per example.
    total_loss = 0
    for i in range(batch_size):
        recon_x, mu, logvar = cvae(
            input_ids[i].unsqueeze(0),
            attention_mask[i].unsqueeze(0),
            edge_indices[i],
            edge_attrs[i],
            node_indices[i].unsqueeze(0),
        )
        logits = F.linear(recon_x, token_embedding_weight)  # (1, seq, vocab_size)
        total_loss += vae_loss(logits, input_ids[i].unsqueeze(0), mu, logvar)

    total_loss /= batch_size  # Normalize loss
    total_loss.backward()
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {total_loss.item()}")
