In [20]:
import jax
from flax import nnx
import jax.numpy as jnp

In [21]:
# Import additional libraries for data loading and training
from datasets import load_dataset
import numpy as np
import optax
from tqdm import tqdm
import PIL
import importlib.util

In [22]:
jax.devices()

[CpuDevice(id=0)]

In [23]:
data_dir = "./data"
# Load MNIST dataset
dataset = load_dataset("mnist", cache_dir=data_dir).with_format("numpy")
train_data = dataset["train"]
test_data = dataset["test"]

print(f"Train samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")
print(f"Sample shape: {train_data[0]['image']}")

Train samples: 60000
Test samples: 10000
Sample shape: [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   3  18  18  18 126 136
  175  26 166 255 247 127   0   0   0   0]
 [  0   0   0   0   0   0   0   0  30  36  94 154 170 253 253 253 253 253
  225 172 253 242 195  64   0   0   0   0]
 [  0   0   0   0   0   0   0  49 238 253 253 253 253 253 253 253 253 251
   93  82  82  56  39   0   0   0   0   0]
 [  0   0

In [28]:
# Training configuration
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10


In [30]:
# Preprocess data: flatten images and normalize
def preprocess_batch(batch):
    """Convert images to flattened arrays and normalize to [0, 1]"""
    images = np.array(batch["image"])
    # Flatten 28x28 images to 784-dim vectors
    images = images.reshape(images.shape[0], -1).astype(np.float32) / 255.0
    return images

# Create data batches
def create_batches(data, batch_size=32, shuffle=True):
    """Generator that yields batches of data"""
    indices = np.arange(len(data))
    if shuffle:
        np.random.shuffle(indices)
    
    for i in range(0, len(indices), batch_size):
        batch_indices = indices[i:i + batch_size]
        batch = data.select(batch_indices)
        yield preprocess_batch(batch)

In [39]:
# Define loss function
def mse_loss(model, x):
    """Mean squared error reconstruction loss"""
    x_reconstructed = model(x)
    return jnp.mean((x - x_reconstructed) ** 2)

@nnx.jit
def train_step(model, optimizer, x):
    """Single training step"""
    loss, grads = nnx.value_and_grad(mse_loss)(model, x)
    
    # Update with BOTH model and grads
    optimizer.update(grads=grads, model=model) 
    return loss

In [40]:
# Update model to handle MNIST dimensions (784 input, 784 output)
Encoder = lambda rngs: nnx.Linear(784, 128, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(128, 784, rngs=rngs)

class MNISTAutoEncoder(nnx.Module):
    def __init__(self, rngs):
        self.encoder = Encoder(rngs)
        self.decoder = Decoder(rngs)
    
    def __call__(self, x) -> jax.Array:
        return self.decoder(self.encoder(x))
    
    def encode(self, x) -> jax.Array:
        return self.encoder(x)

In [41]:

# Initialize model and optimizer
rngs = nnx.Rngs(0)
model = MNISTAutoEncoder(rngs)
optimizer = nnx.Optimizer(model, optax.adam(LEARNING_RATE), wrt=nnx.Param)

print("Model initialized")
print(f"Encoder: {model.encoder}")
print(f"Decoder: {model.decoder}")

Model initialized
Encoder: [38;2;79;201;177mLinear[0m[38;2;255;213;3m([0m[38;2;105;105;105m # Param: 100,480 (401.9 KB)[0m
  [38;2;156;220;254mkernel[0m[38;2;212;212;212m=[0m[38;2;79;201;177mParam[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 100,352 (401.4 KB)[0m
    [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0m[38;2;79;201;177mArray[0m[38;2;255;213;3m([0m[38;2;156;220;254mshape[0m[38;2;212;212;212m=[0m[38;2;255;213;3m([0m[38;2;182;207;169m784[0m, [38;2;182;207;169m128[0m[38;2;255;213;3m)[0m, [38;2;156;220;254mdtype[0m[38;2;212;212;212m=[0mdtype('float32')[38;2;255;213;3m)[0m
  [38;2;255;213;3m)[0m,
  [38;2;156;220;254mbias[0m[38;2;212;212;212m=[0m[38;2;79;201;177mParam[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 128 (512 B)[0m
    [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0m[38;2;79;201;177mArray[0m[38;2;255;213;3m([0m[38;2;156;220;254mshape[0m[38;2;212;212;212m=[0m[38;2;255;213;3m([0m[38;2;182;207;169m128[0m,[

In [42]:
# Training loop
for epoch in range(NUM_EPOCHS):
    epoch_losses = []
    
    # Create progress bar
    num_batches = len(train_data) // BATCH_SIZE
    pbar = tqdm(create_batches(train_data, BATCH_SIZE), 
                total=num_batches, 
                desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    
    for batch_x in pbar:
        loss = train_step(model, optimizer, batch_x)
        epoch_losses.append(float(loss))
        pbar.set_postfix({"loss": f"{loss:.4f}"})
    
    avg_loss = np.mean(epoch_losses)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Average Loss: {avg_loss:.4f}")

Epoch 1/10:   0%|          | 0/468 [00:00<?, ?it/s]

Epoch 1/10: 469it [00:45, 10.24it/s, loss=0.0104]                         


Epoch 1/10 - Average Loss: 0.0273


Epoch 2/10: 469it [00:42, 11.07it/s, loss=0.0061]                         


Epoch 2/10 - Average Loss: 0.0082


Epoch 3/10:  20%|█▉        | 92/468 [00:08<00:35, 10.58it/s, loss=0.0063]


KeyboardInterrupt: 

In [43]:
# Evaluate on test set
test_losses = []
for batch_x in create_batches(test_data, BATCH_SIZE, shuffle=False):
    loss = mse_loss(model, batch_x)
    test_losses.append(float(loss))

avg_test_loss = np.mean(test_losses)
print(f"Test Loss: {avg_test_loss:.4f}")

Test Loss: 0.0062
