In [15]:
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import dgl
from torch.utils.data import Dataset
import torch.nn.functional as F
from dgl.nn import GraphConv


In [9]:
class GraphDataset(Dataset):
    def __init__(self, graphs, node_features):
        """
        Args:
            graphs (list of DGLGraph): List of graphs.
            node_features (list of dict): List of node features dictionaries.
        """
        self.graphs = graphs
        self.node_features = node_features

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx], self.node_features[idx]


In [10]:
def collate_fn(batch):
    """
    Collates a batch of graphs and their features into a batched graph.
    """
    graphs, features = zip(*batch)
    batched_graph = dgl.batch(graphs)

    # Combine node features across the batch
    batch_categorical = torch.cat([feat['categorical'] for feat in features])
    batch_continuous = torch.cat([feat['continuous'] for feat in features])

    return batched_graph, batch_categorical, batch_continuous


In [27]:
# Example dataset of graphs
graphs = [
    dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0]))),  # Graph 1
    dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])))         # Graph 2
]
node_features = [
    {'categorical': torch.tensor([0, 1, 2]), 'continuous': torch.tensor([[1.2, 3.4], [5.6, 7.8], [9.0, 1.1]])},
    {'categorical': torch.tensor([1, 0]), 'continuous': torch.tensor([[2.3, 4.5], [6.7, 8.9]])},
    {'categorical': torch.tensor([0, 1, 2]), 'continuous': torch.tensor([[1.2, 3.4], [5.6, 7.8], [9.0, 1.1]])},
    {'categorical': torch.tensor([1, 0]), 'continuous': torch.tensor([[2.3, 4.5], [6.7, 8.9]])},
    {'categorical': torch.tensor([0, 1, 2]), 'continuous': torch.tensor([[1.2, 3.4], [5.6, 7.8], [9.0, 1.1]])},
    {'categorical': torch.tensor([1, 0]), 'continuous': torch.tensor([[2.3, 4.5], [6.7, 8.9]])}   
]

# Create dataset and dataloader
dataset = GraphDataset(graphs, node_features)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn)


In [29]:
class BatchedLatentDiffusionAutoencoder(nn.Module):
    def __init__(self, num_categories, embedding_dim, continuous_dim, latent_dim, diffusion_steps):
        super().__init__()
        
        # Embedding layer for categorical features
        self.embedding = nn.Embedding(num_categories, embedding_dim)
        
        # Graph convolutional encoder
        input_dim = embedding_dim + continuous_dim
        self.gcn_encoder = GraphConv(input_dim, latent_dim)
        
        # Diffusion process
        self.diffusion_steps = diffusion_steps
        self.noise_scheduler = self.get_noise_scheduler(diffusion_steps)
        self.denoising_model = GraphConv(latent_dim, latent_dim)
        
        # Graph convolutional decoder
        self.gcn_decoder = GraphConv(latent_dim, input_dim)
        
        # Output heads for reconstruction
        self.cat_decoder = nn.Linear(embedding_dim, num_categories)  # Categorical reconstruction
        self.cont_decoder = nn.Linear(continuous_dim, continuous_dim)  # Continuous reconstruction

    def get_noise_scheduler(self, steps):
        # Linear noise schedule for simplicity
        return torch.linspace(0.01, 0.1, steps)

    def forward(self, g, categorical, continuous, t):
        print(categorical)
        print(continuous)
        # Node embeddings for categorical features
        cat_embedded = self.embedding(categorical)
        
        # Concatenate embeddings with continuous features
        node_features = torch.cat([cat_embedded, continuous], dim=-1)
        
        # Encode node features into latent space
        latent = self.gcn_encoder(g, node_features)
        
        # Diffusion process
        noise_level = self.noise_scheduler[t]
        noise = torch.randn_like(latent) * noise_level
        noisy_latent = latent + noise
        
        # Denoise latent space
        denoised_latent = self.denoising_model(g, noisy_latent)
        
        # Decode to original feature space
        decoded = self.gcn_decoder(g, denoised_latent)
        
        # Split decoded output
        decoded_cat = self.cat_decoder(decoded[:, :cat_embedded.size(-1)])
        decoded_cont = self.cont_decoder(decoded[:, cat_embedded.size(-1):])
        
        return decoded_cat, decoded_cont, latent, denoised_latent


In [20]:
# Loss functions
def diffusion_loss(latent, denoised_latent, noise):
    return F.mse_loss(denoised_latent, latent - noise)

def reconstruction_loss(original_cat, reconstructed_cat, original_cont, reconstructed_cont):
    # Categorical loss
    cat_loss = nn.CrossEntropyLoss()(reconstructed_cat, original_cat)
    
    # Continuous loss
    cont_loss = nn.MSELoss()(reconstructed_cont, original_cont)
    
    # Combine losses
    return cat_loss + cont_loss


In [30]:
# Hyperparameters
num_categories = 3  # Number of unique categories
embedding_dim = 8   # Embedding size for categorical data
continuous_dim = 2  # Dimensionality of continuous features
latent_dim = 16     # Latent space size

# Initialize the model
diffusion_steps = 10
model = BatchedLatentDiffusionAutoencoder(num_categories=3, embedding_dim=8, continuous_dim=2, latent_dim=16, diffusion_steps=diffusion_steps)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(50):
    epoch_loss = 0
    
    for batched_graph, batch_categorical, batch_continuous in dataloader:
        # Randomly sample a diffusion step
        t = torch.randint(0, diffusion_steps, (1,)).item()
        
        # Forward pass
        reconstructed_cat, reconstructed_cont, latent, denoised_latent = model(
            batched_graph, batch_categorical, batch_continuous, t
        )
        
        # Compute losses
        noise = torch.randn_like(latent) * model.noise_scheduler[t]
        diffusion_loss = F.mse_loss(denoised_latent, latent - noise)
        reconstruction_loss = (
            nn.CrossEntropyLoss()(reconstructed_cat, batch_categorical) +
            nn.MSELoss()(reconstructed_cont, batch_continuous)
        )
        
        # Combine losses
        loss = diffusion_loss + reconstruction_loss
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

tensor([0, 1, 2, 1, 0])
tensor([[1.2000, 3.4000],
        [5.6000, 7.8000],
        [9.0000, 1.1000],
        [2.3000, 4.5000],
        [6.7000, 8.9000]])
Epoch 1, Loss: 35.3250
tensor([0, 1, 2, 1, 0])
tensor([[1.2000, 3.4000],
        [5.6000, 7.8000],
        [9.0000, 1.1000],
        [2.3000, 4.5000],
        [6.7000, 8.9000]])
Epoch 2, Loss: 34.9347
tensor([0, 1, 2, 1, 0])
tensor([[1.2000, 3.4000],
        [5.6000, 7.8000],
        [9.0000, 1.1000],
        [2.3000, 4.5000],
        [6.7000, 8.9000]])
Epoch 3, Loss: 34.3807
tensor([0, 1, 2, 1, 0])
tensor([[1.2000, 3.4000],
        [5.6000, 7.8000],
        [9.0000, 1.1000],
        [2.3000, 4.5000],
        [6.7000, 8.9000]])
Epoch 4, Loss: 33.7184
tensor([0, 1, 2, 1, 0])
tensor([[1.2000, 3.4000],
        [5.6000, 7.8000],
        [9.0000, 1.1000],
        [2.3000, 4.5000],
        [6.7000, 8.9000]])
Epoch 5, Loss: 33.1669
tensor([0, 1, 2, 1, 0])
tensor([[1.2000, 3.4000],
        [5.6000, 7.8000],
        [9.0000, 1.1000],
        