<a href="https://colab.research.google.com/github/jcandane/SSA/blob/main/LVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
!mkdir -p LVAE

In [2]:
%%writefile LVAE/LVAE.py

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from sklearn.decomposition import PCA

from flax import traverse_util
from flax.jax_utils import replicate
from flax.serialization import to_bytes, from_bytes

import json
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, upload_file

# PCA Encoder (flattened input)
class PCAEncoder(nn.Module):
    """Linear Gaussian encoder returning ``(z, mu, log_var)``.

    The input is flattened per sample, two independent Dense heads produce the
    posterior mean and log-variance, and ``z`` is drawn with the
    reparameterisation trick from the ``'noise'`` RNG stream (so ``apply``
    needs ``rngs={'noise': key}``).
    """
    d: int          # latent dimensionality
    input_dim: int  # flattened input size (part of the constructor; not read here)

    @nn.compact
    def __call__(self, x):
        flat = x.reshape((x.shape[0], -1))
        mu = nn.Dense(self.d)(flat)       # auto-named Dense_0: posterior mean
        log_var = nn.Dense(self.d)(flat)  # auto-named Dense_1: posterior log-variance
        noise = jax.random.normal(self.make_rng('noise'), mu.shape)
        z = mu + jnp.exp(0.5 * log_var) * noise
        return z, mu, log_var

# PCA Decoder
class PCADecoder(nn.Module):
    """Linear decoder: a single Dense layer mapping latents back to ``output_shape``."""
    output_shape: tuple  # per-sample shape to reconstruct (batch axis excluded)

    @nn.compact
    def __call__(self, z):
        n_features = np.prod(self.output_shape)
        return nn.Dense(n_features)(z).reshape((-1, *self.output_shape))

# PCA-VAE Model
class PCAVAE(nn.Module):
    """Linear VAE made of a ``PCAEncoder`` and a ``PCADecoder``.

    Calling the module returns ``(reconstruction, mu, log_var)``; the
    parameter tree has top-level keys ``'encoder'`` and ``'decoder'``.
    """
    d: int              # latent dimensionality
    input_shape: tuple  # per-sample input shape (batch axis excluded)

    def setup(self):
        self.encoder = PCAEncoder(self.d, np.prod(self.input_shape))
        self.decoder = PCADecoder(self.input_shape)

    def __call__(self, x):
        z, mu, log_var = self.encoder(x)
        return self.decoder(z), mu, log_var

# Training wrapper with β and LR scheduling
class LVAE:
    """Trainer for ``PCAVAE`` with a linear β warm-up and a warmup-cosine LR schedule.

    ``train(data)`` trains locally; ``train(data, gpu="...")`` writes a small
    Modal app to ``LVAE/modal_app.py`` and runs the same training remotely.
    After local training, the weights, a JSON config and two PNG figures are
    uploaded to the Hugging Face *dataset* repo ``repo_id`` when both
    ``hf_token`` and ``repo_id`` are set.

    Args:
        d: latent size; also the number of PCA components used for init.
        epochs: number of passes over the training data.
        batch_size: minibatch size; samples that do not fill a whole batch
            are skipped in each epoch.
        seed: seeds the JAX PRNG, the PCA-init jitter and the per-epoch shuffling.
        β: upper bound of the KL weight (reached after a linear warm-up).
        modal_id, modal_secret: Modal token pair (never written to the HF config).
        hf_token: Hugging Face token used for uploads.
        repo_id: Hugging Face dataset repo that receives the artefacts.
    """

    def __init__(self, d=16, epochs=20, batch_size=128, seed=0, β=1.0,
                 modal_id:str="", modal_secret:str="", hf_token:str="", repo_id:str=""):
        self.d = d
        self.epochs = epochs
        self.batch_size = batch_size
        self.seed = seed
        self.rng = jax.random.PRNGKey(seed)
        self.losses, self.recons, self.kls = [], [], []
        self.β = β

        # Cloud integration parameters
        self.modal_id     = modal_id
        self.modal_secret = modal_secret
        self.hf_token     = hf_token
        self.repo_id      = repo_id

    @classmethod
    def _init_param_names(cls):
        """Names of the arguments accepted by ``__init__`` (``self`` excluded)."""
        import inspect

        params = inspect.signature(cls.__init__).parameters
        return [name for name in params if name != "self"]

    @property
    def config(self):
        """Constructor arguments of this instance, so that ``LVAE(**config)`` rebuilds it.

        Only attributes named in ``__init__``'s signature are returned, so state
        created later (``model``, ``state``, ``tx``, ``rng``, loss histories,
        ``training_time``, ...) never leaks in.  NOTE: the result contains
        credentials (``hf_token``, ``modal_secret``); do not publish it.
        """
        names = set(type(self)._init_param_names())
        return {k: v for k, v in self.__dict__.items() if k in names}

    @classmethod
    def from_config(cls, config):
        """Build an instance from ``config``, ignoring keys ``__init__`` does not accept."""
        names = set(cls._init_param_names())
        return cls(**{k: v for k, v in config.items() if k in names})

    def init_model(self, data):
        """Create the model, initialise it from PCA and build the optimiser state.

        The encoder mean head and the decoder use the leading ``d`` principal
        directions of ``data`` plus a small seeded jitter.  The encoder
        log-variance head starts at a constant ``-10`` (sampling std ``e**-5``).
        Biases are chosen so that the initial model reproduces the PCA
        reconstruction ``(x - mean) @ W @ W.T + mean``, up to the jitter and that
        small sampling noise.

        Args:
            data: array of shape ``(N, *input_shape)``.
        """
        from flax.core import unfreeze

        input_shape = data.shape[1:]
        flat_data = data.reshape(len(data), -1)
        pca = PCA(n_components=self.d)
        pca.fit(flat_data)
        pca_weights = pca.components_.T  # (input_dim, d)

        self.model = PCAVAE(self.d, input_shape)
        # unfreeze: some Flax versions return an immutable FrozenDict from init.
        variables = unfreeze(
            self.model.init({'params': self.rng, 'noise': self.rng}, jnp.ones((1, *input_shape)))
        )

        rng_weights = np.random.RandomState(self.seed).normal(scale=1e-2, size=pca_weights.shape)
        weights = (pca_weights + rng_weights).astype(np.float32)
        data_mean = flat_data.mean(axis=0).astype(np.float32)

        encoder = variables['params']['encoder']
        encoder['Dense_0']['kernel'] = weights
        # Centre inside the encoder: mu = (x - mean) @ W.  With a zero bias the
        # mean would be projected into the latent space and added twice on decode.
        encoder['Dense_0']['bias']   = -(data_mean @ weights)
        encoder['Dense_1']['kernel'] = np.zeros((flat_data.shape[1], self.d), dtype=np.float32)
        encoder['Dense_1']['bias']   = np.full((self.d,), -10.0, dtype=np.float32)

        decoder = variables['params']['decoder']
        decoder['Dense_0']['kernel'] = weights.T
        decoder['Dense_0']['bias']   = data_mean

        steps_per_epoch = len(data) // self.batch_size
        warmup_steps = steps_per_epoch * 2  # 2 epochs warmup
        total_steps = steps_per_epoch * self.epochs
        schedule = optax.warmup_cosine_decay_schedule(
            init_value=1e-7, peak_value=1e-4,
            warmup_steps=warmup_steps,
            # decay_steps is the TOTAL schedule length (warm-up included), so the
            # cosine ends on the last training step; keep it > warmup_steps so
            # the cosine phase is never empty (e.g. epochs <= 2).
            decay_steps=max(total_steps, warmup_steps + 1),
            end_value=1e-7,
        )

        self.tx = optax.adam(schedule)
        self.state = train_state.TrainState.create(
            apply_fn=self.model.apply, params=variables['params'], tx=self.tx
        )

    @partial(jax.jit, static_argnums=0)
    def train_step(self, state, batch, rng, β):
        """One jitted optimiser step on ``MSE + β * KL``.

        ``self`` is static (hashed by identity), so each instance gets its own
        compiled function; ``β`` is traced, so changing it per epoch does not
        recompile.

        Returns:
            ``(new_state, total_loss, recon_loss, kl_loss)``.
        """
        def loss_fn(params):
            recon, μ, logσ2 = self.model.apply({'params': params}, batch, rngs={'noise': rng})
            recon_loss      = jnp.mean((batch - recon)**2)
            # KL(N(μ, σ²) || N(0, 1)), averaged over batch and latent dims.
            kl_loss         = -0.5 * jnp.mean(1 + logσ2 - μ**2 - jnp.exp(logσ2))
            return recon_loss + β * kl_loss, (recon_loss, kl_loss)

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, (recon_loss, kl_loss)), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss, recon_loss, kl_loss

    def _beta_for_epoch(self, epoch):
        """KL weight for 1-based ``epoch``.

        The weight grows linearly as ``epoch / (0.25 * epochs)`` and is capped at
        ``self.β``, so the cap is reached after ``self.β * 25%`` of the epochs.
        """
        return min(self.β, epoch / (self.epochs * 0.25))

    def _train_on_modal(self, dataset, gpu):
        """Write ``LVAE/modal_app.py`` for GPU type ``gpu`` and run ``train`` remotely.

        Returns whatever the remote function returns (currently ``None``).
        """
        import importlib
        import sys

        modal_app_code = f'''
import modal

jax_image = (
    modal.Image.debian_slim()
    .run_commands("pip install --upgrade pip")
    .run_commands(
        "pip install flax optax einops matplotlib scikit-learn numpy huggingface_hub jax[cuda12] --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
    )
)

app = modal.App(name="LVAE-app")

@app.function(image=jax_image, gpu={gpu!r}, timeout=400)
def train_remote(config, dataset, hf_token=None, repo_id=None):
    from LVAE.LVAE import LVAE
    tvae = LVAE(**config)
    tvae.train(dataset) #, hf_token=hf_token, repo_id=repo_id)
    return None'''

        with open("LVAE/modal_app.py", "w", encoding="utf-8") as f:
            f.write(modal_app_code)

        # Reload when already imported so a different `gpu` takes effect within
        # the same interpreter session (e.g. a notebook).
        module_name = "LVAE.modal_app"
        if module_name in sys.modules:
            modal_app = importlib.reload(sys.modules[module_name])
        else:
            modal_app = importlib.import_module(module_name)

        with modal_app.app.run():
            return modal_app.train_remote.remote(self.config, dataset)

    def train(self, dataset, gpu=None):
        """Train on ``dataset`` locally, or remotely on Modal when ``gpu`` is given.

        Args:
            dataset: array of shape ``(N, *input_shape)`` holding at least
                ``batch_size`` samples.
            gpu: Modal GPU type (e.g. ``"h100"``).  When set, the call blocks on
                the remote run and returns its result; this instance is not updated.

        Raises:
            ValueError: for a local run when ``dataset`` has fewer than
                ``batch_size`` samples (no full batch to train on).
        """
        if gpu is not None:
            return self._train_on_modal(dataset, gpu)

        num_batches = len(dataset) // self.batch_size
        if num_batches == 0:
            raise ValueError(
                f"dataset has {len(dataset)} samples, fewer than batch_size={self.batch_size}"
            )

        self.init_model(dataset)
        # Seeded shuffling so that `seed` makes the whole run reproducible.
        shuffle_rng = np.random.default_rng(self.seed)

        for epoch in range(1, self.epochs + 1):
            β = self._beta_for_epoch(epoch)
            epoch_loss = epoch_recon = epoch_kl = 0.0
            perm = shuffle_rng.permutation(len(dataset))

            for i in range(num_batches):
                batch = dataset[perm[i*self.batch_size:(i+1)*self.batch_size]]
                self.rng, rng_step = jax.random.split(self.rng)
                self.state, loss, r_loss, kl_loss = self.train_step(self.state, batch, rng_step, β)
                epoch_loss  += loss
                epoch_recon += r_loss
                epoch_kl    += kl_loss

            avg_loss  = epoch_loss / num_batches
            avg_recon = epoch_recon / num_batches
            avg_kl    = epoch_kl / num_batches
            self.losses.append(avg_loss)
            self.recons.append(avg_recon)
            self.kls.append(avg_kl)
            if epoch % 100 == 0 or epoch == self.epochs:
                print(f"Epoch {epoch}: Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, KL={avg_kl:.4f}, β={β:.2f}")

        # Upload only when credentials are configured (same rule as plot_results).
        if self.hf_token and self.repo_id:
            self.upload_to_huggingface(self.repo_id)
        self.plot_results(dataset, round=False)

    def reconstruct(self, x):
        """Encode-sample-decode ``x`` with the current parameters (advances ``self.rng``)."""
        self.rng, rng_step = jax.random.split(self.rng)
        recon, _, _ = self.model.apply({'params': self.state.params}, x, rngs={'noise': rng_step})
        return recon

    def upload_to_huggingface(self,
                              repo_id: str,
                              weights_file: str = "flax_model.msgpack",
                              config_file: str = "config.json"):
        """Serialise parameters and config to local files and upload both to ``repo_id``.

        Files go to a Hugging Face *dataset* repo.  Credentials (HF token and
        Modal token pair) are deliberately kept out of the uploaded config.
        """
        assert self.hf_token, "🤗 HuggingFace token (`hf_token`) is not set!"

        # Save token to local HfFolder for API authentication
        HfFolder.save_token(self.hf_token)

        # ── 1. Save model parameters ────────────────────────────────────────
        with open(weights_file, "wb") as f:
            f.write(to_bytes(self.state.params))

        # ── 2. Merge model & training configuration (no secrets) ────────────
        full_config = {
            "d": self.d,
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "seed": self.seed,
            "β": self.β,
            "rng": self.rng.tolist() if isinstance(self.rng, jnp.ndarray) else self.rng,
            "training_time": getattr(self, "training_time", None),
            "losses": [float(x) for x in self.losses],
            "recons": [float(x) for x in self.recons],
            "kls": [float(x) for x in self.kls],
            "input_shape": self.model.input_shape,
            "latent_dim": self.d,
            "model_type": "Linear Variational Autoencoder (LVAE)"
        }

        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(full_config, f, indent=2)

        # ── 3. Upload files to HuggingFace ──────────────────────────────────
        for fname in [weights_file, config_file]:
            upload_file(
                path_or_fileobj=fname,
                path_in_repo=fname,
                repo_type="dataset",
                repo_id=repo_id,
                token=self.hf_token
            )

        print(f"🤗 LVAE successfully uploaded to Hugging Face: {repo_id}")

    def plot_results(self, x, num_images=8, round=False):
        """Plot originals vs reconstructions and the loss curves; save and optionally upload.

        Writes ``mnist_example.png`` and ``Loss.png`` to the working directory
        and uploads them when both ``hf_token`` and ``repo_id`` are set.
        ``num_images`` is clipped to ``len(x)``.
        """
        num_images = min(num_images, len(x))
        recon = self.reconstruct(x[:num_images])
        if round:
            recon = np.round(recon).astype(bool)
        # squeeze=False keeps `axes` 2-D even when num_images == 1.
        fig, axes = plt.subplots(2, num_images, figsize=(num_images*2, 4), squeeze=False)
        for i in range(num_images):
            axes[0, i].imshow(x[i].squeeze(), cmap='gray'); axes[0, i].axis('off')
            axes[1, i].imshow(recon[i].squeeze(), cmap='gray'); axes[1, i].axis('off')
        mnist_example = "mnist_example.png"
        fig.savefig(mnist_example)
        plt.show()

        loss_fig = plt.figure(figsize=(8, 4))
        plt.plot(self.losses, label='Total loss')
        plt.plot(self.recons, label='Recon loss')
        plt.plot(self.kls, label='KL loss')
        plt.yscale("log")
        plt.legend()
        plt.xlabel('Epoch'); plt.ylabel('Loss')
        loss_png = "Loss.png"
        # Save BEFORE show(): show() may close/clear the figure, which would save a blank image.
        loss_fig.savefig(loss_png)
        plt.show()

        if self.hf_token and self.repo_id:
            upload_file(
                path_or_fileobj=loss_png,
                path_in_repo=loss_png,
                repo_type="dataset",
                repo_id=self.repo_id,
                token=self.hf_token
            )
            print(f"🖼️ Visualization uploaded to Hugging Face: {self.repo_id}/{loss_png}")
            upload_file(
                path_or_fileobj=mnist_example,
                path_in_repo=mnist_example,
                repo_type="dataset",
                repo_id=self.repo_id,
                token=self.hf_token
            )
            print(f"🖼️ Visualization uploaded to Hugging Face: {self.repo_id}/{mnist_example}")

def set_modal_tokens(token_id: str, token_secret: str):
    """Register a Modal token pair by invoking the ``modal token set`` CLI command.

    The command is run without a shell; ``check=True`` makes a non-zero exit
    raise ``subprocess.CalledProcessError``.  Prints a confirmation on success.
    NOTE(review): the secret travels as a command-line argument for the
    duration of the call.
    """
    import subprocess

    subprocess.run(
        ["modal", "token", "set",
         "--token-id", token_id,
         "--token-secret", token_secret],
        check=True,
    )
    print("✅ Tokens configured successfully.")

Overwriting LVAE/LVAE.py


# Load Credentials (Modal and Hugging Face)

In [3]:
from google.colab import userdata
import numpy as np
try:
    import modal
except ImportError:
    !pip install -q modal
    import modal
try:
    from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, upload_file
except:
    !pip install -q -U huggingface_hub
    from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, upload_file
from LVAE.LVAE import set_modal_tokens

MODAL_ID     = userdata.get("MODAL_TOKEN_ID")
MODAL_SECRET = userdata.get("MODAL_TOKEN_SECRET")
HF_TOKEN     = userdata.get("HF_TOKEN")
HF_REPO      = "jcandane/pca" ### <<---- change this!
set_modal_tokens(MODAL_ID, MODAL_SECRET)

✅ Tokens configured successfully.


# Train the LVAE on MNIST

In [4]:
from tensorflow.keras.datasets import mnist
from LVAE.LVAE import LVAE

# MNIST training images: scale pixels to [0, 1] as float32 and add a channel axis.
(x_train, _), _ = mnist.load_data()
x_train = (x_train / 255.).astype(np.float32)[..., None]

# Train remotely on a Modal H100; d=200 was noted as an alternative latent size.
lae = LVAE(
    d=100,
    epochs=100,
    batch_size=128,
    seed=42,
    β=0.1,
    hf_token=HF_TOKEN,
    repo_id=HF_REPO,
    modal_id=MODAL_ID,
    modal_secret=MODAL_SECRET,
)
lae.train(x_train, gpu="h100")