In [1]:
dataset_path = '/code/root/MEMA_gene_matrix_hard_calls_imputed_data.parquet'

In [2]:
import os
import shutil
import json
import jax
import optax
import hashlib
import jax.numpy as jnp
import jax.random as random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from flax import struct
from flax.training import train_state
from tqdm.notebook import tqdm  # Better for Jupyter
import time
from functools import partial
from collections import defaultdict

from typing import NamedTuple, Any

In [3]:
from models.jax_vae_cat import VAE
from root.load_parquet import load_genetic_data, create_random_batches


class TrainState(train_state.TrainState):
    pass 
    # batch_stats: Any

@jax.jit
def categorical_accuracy(logits, targets):
    """
    Compute classification accuracy for categorical VAE
    
    Args:
        logits: Model outputs (batch_size, ..., num_classes) - raw logits
        targets: One-hot encoded targets (batch_size, ..., num_classes)
    
    Returns:
        accuracy: Average accuracy across batch (scalar)
    """
    # Get predicted classes (argmax of logits)
    predicted_classes = jnp.argmax(logits, axis=-1)
    
    # Get true classes (argmax of one-hot targets) 
    true_classes = jnp.argmax(targets, axis=-1)
    
    # Compute accuracy per sample
    correct_predictions = jnp.equal(predicted_classes, true_classes)
    
    # Average over all dimensions (genes and batch)
    accuracy = jnp.mean(correct_predictions)
    
    return accuracy

@jax.jit
def train_step(
    state_, 
    input_,
    target_, 
    rng_key,
):
    def loss_fn(params_):
        (x_recon, mu, logvar) = state_.apply_fn(
            {'params': params_},
            input_,
            key=rng_key
        )
        reconstruction_loss = -jnp.sum(
        target_ * jax.nn.log_softmax(x_recon, axis=-1),
            axis=(-2, -1)  # Sum over genes and classes
        )
        reconstruction_loss = jnp.mean(reconstruction_loss)

        accuracy = categorical_accuracy(target_, x_recon)
    
        # KL divergence
        kl_loss = -0.5 * jnp.sum(1 + logvar - mu**2 - jnp.exp(logvar), axis=-1)
        kl_loss = jnp.mean(kl_loss)
        
        # Total loss
        tot_err = reconstruction_loss + 1.0 * kl_loss
        
        return tot_err, (reconstruction_loss, kl_loss, tot_err, accuracy)
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    aux, grads = grad_fn(state_.params)
    tot_loss, (recon_loss, kl_loss, _, accuracy) = aux
    
    new_state = state_.apply_gradients(grads=grads)
    return new_state, (tot_loss, recon_loss, kl_loss, accuracy)


if __name__ == "__main__":

    # GPU Check and Setup
    print("🔧 Checking GPU availability...")
    print(f"JAX devices: {jax.devices()}")
    print(f"JAX default backend: {jax.default_backend()}")
    
    try:
        gpu_devices = jax.devices('gpu')
        if gpu_devices:
            print(f"✅ Found {len(gpu_devices)} GPU(s): {gpu_devices}")
            print("🎯 Setting JAX to use GPU...")
        else:
            print("⚠️ No GPU devices found - using CPU")
    except Exception as e:
        print(f"⚠️ GPU check failed: {e}")
        print("Using default JAX configuration")
    
    compute_dtype = jnp.float32
    
    # Initialize model
    print("Initializing VAE model...")
    num_genes = 5000
    vae = VAE(num_genes=num_genes)
    
    key = random.PRNGKey(42)
    x = random.uniform(key, (1, num_genes))
    print(x.shape)
    key, init_key, sample_key = random.split(key, 3)
    variables = vae.init(init_key, x, sample_key)
    # print(vae.tabulate(init_key, x, sample_key))
    print("Model initialized!")
    
    # Training setup
    key = random.key(0)
    key, *subkeys = random.split(key, 4)
    
    learning_rate = 0.001
    num_epochs = 500
    batch_size = 32
    
    print(f"📋 Training Configuration:")
    print(f"   Learning Rate: {learning_rate}")
    print(f"   Epochs: {num_epochs}")
    print(f"   Batch Size: {batch_size}")
    print()
 
    params= variables['params']# , variables['batch_stats']
    tx = optax.adamw(learning_rate)

    state = TrainState.create(
        apply_fn=vae.apply,
        params=params,
        tx=tx,
        # batch_stats=batch_stats
    )
    
    # Initialize dataloader
    print("📂 Setting up dataloader...")
    inputs_og, targets_og, _ = load_genetic_data(dataset_path, k=num_genes, selection_method='variance')
    print()
    
    
    print("Starting training loop...")
    print("-" * 80)
    key = jax.random.PRNGKey(42)

    # Training metrics for this epoch
    epoch_losses = []
    epoch_recon_losses = []
    epoch_kl_losses = []
    epoch_accuracies = []

    for epoch in tqdm(range(num_epochs)):
        # Generate new random batches each epoch
        epoch_key = jax.random.fold_in(key, epoch)
        batches = create_random_batches(inputs_og, targets_og, batch_size, epoch_key, drop_remainder=True)

        # Train on all batches
        for batch_idx, (batch_inputs, batch_targets) in enumerate(batches):
            # Create VAE key for reparameterization
            vae_key = jax.random.fold_in(key, epoch * len(batches) + batch_idx)
        
            state, (loss, recon_loss, kl_loss, accuracy) = train_step(
                state, batch_inputs, batch_targets, vae_key
            )
            
            epoch_losses.append(loss)
            epoch_recon_losses.append(recon_loss)
            epoch_kl_losses.append(kl_loss)
            epoch_accuracies.append(accuracy)
        
        # Calculate epoch averages
        avg_loss = jnp.mean(jnp.array(epoch_losses))
        avg_recon = jnp.mean(jnp.array(epoch_recon_losses))
        avg_kl = jnp.mean(jnp.array(epoch_kl_losses))
        avg_accuracy = jnp.mean(jnp.array(epoch_accuracies))
        print(f'Epoch: {epoch} -- ', avg_loss, avg_recon, avg_kl, avg_accuracy)
    
    
    ckpt = {'model': state.params}
    
    import orbax
    from flax.training import orbax_utils

    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(ckpt)
    
    # Remove the directory if it exists
    checkpoint_path = '/code/checkpoints/vae'
    if os.path.exists(checkpoint_path):
        shutil.rmtree(checkpoint_path)
    
    # Now save the checkpoint
    orbax_checkpointer.save(checkpoint_path, ckpt, save_args=save_args)
    print('Done -- saved to:', checkpoint_path)
        

🔧 Checking GPU availability...
JAX devices: [CudaDevice(id=0)]
JAX default backend: gpu
✅ Found 1 GPU(s): [CudaDevice(id=0)]
🎯 Setting JAX to use GPU...
Initializing VAE model...
(1, 5000)
Model initialized!
📋 Training Configuration:
   Learning Rate: 0.001
   Epochs: 500
   Batch Size: 32

📂 Setting up dataloader...
Original data shape: (200, 43788)
Selected 5000 genes using variance method
Reduced data shape: (200, 5000)
Final data - Input shape: (200, 5000), Target shape: (200, 5000, 8)

Starting training loop...
--------------------------------------------------------------------------------


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0 --  11053.105 11008.918 44.187187 0.12867396
Epoch: 1 --  10797.605 10747.034 50.570877 0.13731301
Epoch: 2 --  10646.978 10591.835 55.141254 0.14727151
Epoch: 3 --  10544.705 10490.592 54.11303 0.15589245
Epoch: 4 --  10468.067 10418.254 49.8139 0.16239585
Epoch: 5 --  10409.953 10364.685 45.269802 0.16753507
Epoch: 6 --  10364.599 10322.932 41.667435 0.17162053
Epoch: 7 --  10329.178 10290.416 38.76194 0.17473114
Epoch: 8 --  10301.852 10265.627 36.225094 0.17707616
Epoch: 9 --  10278.599 10244.584 34.01468 0.1789797
Epoch: 10 --  10259.212 10227.085 32.125927 0.18056485
Epoch: 11 --  10242.526 10212.005 30.519566 0.1819795
Epoch: 12 --  10227.983 10198.93 29.054049 0.18313068
Epoch: 13 --  10216.037 10188.271 27.766289 0.18422855
Epoch: 14 --  10204.795 10178.197 26.597609 0.18536124
Epoch: 15 --  10195.92 10170.336 25.583267 0.18617982
Epoch: 16 --  10186.701 10161.883 24.818104 0.18712164
Epoch: 17 --  10177.447 10153.2 24.24709 0.18838745
Epoch: 18 --  10168.077 10144.17