In [1]:
%load_ext autoreload
%autoreload 2

In [8]:
import jax
import jax.numpy as jnp
import numpy as np
import haiku as hk
import optax
import matplotlib.pyplot as plt

class ConditionalVAE(hk.Module):
    def __init__(self, 
                 input_dim: int, 
                 condition_dim: int, 
                 latent_dim: int = 20, 
                 name: str = 'conditional_vae'):
        """
        Conditional Variational Autoencoder using Haiku
        
        Args:
            input_dim (int): Dimension of input data
            condition_dim (int): Dimension of conditional information
            latent_dim (int): Dimension of latent space
            name (str): Name of the module
        """
        super().__init__(name=name)
        self.input_dim = input_dim
        self.condition_dim = condition_dim
        self.latent_dim = latent_dim

    def __call__(self, x, c, is_training=True):
        """
        Forward pass of the Conditional VAE
        
        Args:
            x (jnp.ndarray): Input data
            c (jnp.ndarray): Conditional information
            is_training (bool): Training mode flag
        
        Returns:
            x_reconst (jnp.ndarray): Reconstructed input
            mu (jnp.ndarray): Mean of latent distribution
            logvar (jnp.ndarray): Log variance of latent distribution
        """
        # Concatenate input and condition for encoder
        encoder_input = jnp.concatenate([x, c], axis=-1)
        
        # Encoder network
        h = hk.Sequential([
            hk.Linear(256), jax.nn.relu,
            hk.Linear(128), jax.nn.relu
        ])(encoder_input)
        
        # Latent space parameters
        mu = hk.Linear(self.latent_dim)(h)
        logvar = hk.Linear(self.latent_dim)(h)
        
        # Reparameterization trick
        z = self._reparameterize(mu, logvar, is_training)
        
        # Concatenate latent sample and condition for decoder
        decoder_input = jnp.concatenate([z, c], axis=-1)
        
        # Decoder network
        x_reconst = hk.Sequential([
            hk.Linear(128), jax.nn.relu,
            hk.Linear(256), jax.nn.relu,
            hk.Linear(self.input_dim), jax.sigmoid
        ])(decoder_input)
        
        return x_reconst, mu, logvar

    def _reparameterize(self, mu, logvar, is_training):
        """
        Reparameterization trick
        
        Args:
            mu (jnp.ndarray): Mean of latent distribution
            logvar (jnp.ndarray): Log variance of latent distribution
            is_training (bool): Training mode flag
        
        Returns:
            z (jnp.ndarray): Sampled latent vector
        """
        if not is_training:
            return mu
        
        std = jnp.exp(0.5 * logvar)
        eps = jax.random.normal(hk.next_rng_key(), mu.shape)
        return mu + eps * std

def compute_loss(params, x, c, model, rng):
    """
    Compute VAE loss
    
    Args:
        params (dict): Model parameters
        x (jnp.ndarray): Input data
        c (jnp.ndarray): Conditional information
        model (ConditionalVAE): VAE model
        rng (jax.random.PRNGKey): Random number generator key
    
    Returns:
        loss (jnp.ndarray): Total loss
        aux_data (dict): Auxiliary information
    """
    # Forward pass
    x_reconst, mu, logvar = model(x, c, is_training=True)
    
    # Reconstruction loss (Binary Cross Entropy)
    recon_loss = -jnp.sum(x * jnp.log(x_reconst + 1e-10) + 
                           (1 - x) * jnp.log(1 - x_reconst + 1e-10), axis=-1)
    recon_loss = jnp.mean(recon_loss)
    
    # KL Divergence loss
    kl_loss = -0.5 * jnp.sum(1 + logvar - mu**2 - jnp.exp(logvar), axis=-1)
    kl_loss = jnp.mean(kl_loss)
    
    # Total loss (with beta for KL divergence regularization)
    beta = 1.0
    loss = recon_loss + beta * kl_loss
    
    return loss, {
        'recon_loss': recon_loss,
        'kl_loss': kl_loss
    }

def generate_mock_data(key, n_samples=1000, input_dim=10, condition_dim=5):
    """
    Generate mock data for training
    
    Args:
        key (jax.random.PRNGKey): Random number generator key
        n_samples (int): Number of samples
        input_dim (int): Dimension of input data
        condition_dim (int): Dimension of conditional information
    
    Returns:
        x (jnp.ndarray): Input data
        c (jnp.ndarray): Conditional information
    """
    key1, key2 = jax.random.split(key)
    x = jax.random.uniform(key1, (n_samples, input_dim))
    c = jax.random.uniform(key2, (n_samples, condition_dim))
    return x, c

def create_train_step(model):
    """
    Create a training step function
    
    Args:
        model (ConditionalVAE): VAE model
    
    Returns:
        train_step (function): Function for performing a single training step
    """
    @jax.jit
    def train_step(state, x, c, rng):
        """
        Perform a single training step
        
        Args:
            state (TrainState): Current training state
            x (jnp.ndarray): Input data
            c (jnp.ndarray): Conditional information
            rng (jax.random.PRNGKey): Random number generator key
        
        Returns:
            new_state (TrainState): Updated training state
            loss (jnp.ndarray): Current loss value
        """
        def loss_fn(params):
            # Compute loss with current parameters
            loss, aux = compute_loss(params, x, c, model, rng)
            return loss, aux
        
        # Compute gradients
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, aux), grads = grad_fn(state.params)
        
        # Update optimizer state and parameters
        new_state = state.apply_gradients(grads=grads)
        
        return new_state, loss

    return train_step

def train_conditional_vae(epochs=100, batch_size=64):
    """
    Train Conditional VAE
    
    Args:
        epochs (int): Number of training epochs
        batch_size (int): Batch size for training
    
    Returns:
        model (ConditionalVAE): Trained Conditional VAE model
        params (dict): Trained model parameters
    """
    # Set random seeds
    key = jax.random.PRNGKey(42)
    
    # Generate mock data
    x, c = generate_mock_data(key)
    
    # Model and initialization
    input_dim = x.shape[1]
    condition_dim = c.shape[1]
    
    # Initialize parameters
    rng = hk.PRNGSequence(key)
    sample_input = x[:batch_size]
    sample_condition = c[:batch_size]
    
    # Transform the model into a pair of pure functions
    model = hk.transform(ConditionalVAE(input_dim, condition_dim))
    
    # Initialize parameters
    params = model.init(next(rng), sample_input, sample_condition)
    
    # Create optimizer
    optimizer = optax.adam(learning_rate=1e-3)
    
    # Create train state
    from flax.training import train_state
    class TrainState(train_state.TrainState):
        pass
    
    state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )
    
    # Create training step function
    train_step = create_train_step(model)
    
    # Training loop
    for epoch in range(epochs):
        # Shuffle data
        key, shuffle_key = jax.random.split(key)
        shuffled_indices = jax.random.permutation(shuffle_key, x.shape[0])
        x = x[shuffled_indices]
        c = c[shuffled_indices]
        
        # Batch training
        total_loss = 0
        for i in range(0, x.shape[0], batch_size):
            batch_x = x[i:i+batch_size]
            batch_c = c[i:i+batch_size]
            
            key, rng_key = jax.random.split(key)
            state, loss = train_step(state, batch_x, batch_c, rng_key)
            total_loss += loss
        
        # Print average loss
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss / (x.shape[0] // batch_size):.4f}')
    
    return model, state.params

def main():
    """
    Main function to demonstrate Conditional VAE
    """
    # Train Conditional VAE
    model, params = train_conditional_vae()
    
    print("Training complete. Model ready for further use.")


main()

ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

In [None]:
# !pip install cairosvg
# !apt install libcairo2
import cairosvg

# Load the SVG file
svg_file_path = '../data/adaptation_m0_log_adaptation.svg'

# Convert SVG to PNG
png_file_path = svg_file_path.replace('svg', 'png')
cairosvg.svg2png(url=svg_file_path, write_to=png_file_path, output_width=500, output_height=500)

print(f'SVG file has been converted to PNG and saved at {png_file_path}')

