<a href="https://colab.research.google.com/github/skywolfmo/MAE-flax/blob/master/MLAscent_6_Flax_MAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ML Ascent 6: Masked Autoencoders (MAE) for Self-Supervised Learning (SSL)

Author: Taha Bouhsine - ML GDE

Contact: @tahabsn - contact@tahabouhsine.com

In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from tensorflow.keras.datasets import cifar10
import numpy as np

class MLP(nn.Module):
    """Two-layer feed-forward network: Dense -> GELU -> Dense.

    Attributes:
        hidden_dim: width of the first (intermediate) Dense layer.
        out_dim: width of the second (output) Dense layer.
    """

    hidden_dim: int
    out_dim: int

    @nn.compact
    def __call__(self, x):
        # First projection followed by the GELU non-linearity.
        hidden = nn.gelu(nn.Dense(features=self.hidden_dim)(x))
        # Second projection back to the requested output width.
        return nn.Dense(features=self.out_dim)(hidden)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention (no attention mask).

    Attributes:
        num_heads: number of attention heads. The model width (last axis of
            the input) must be divisible by this value.
    """

    num_heads: int

    @nn.compact
    def __call__(self, x):
        """Apply self-attention over the token axis of ``x``.

        Args:
            x: array whose first axis is batch and last axis is the model
                width ``d_model``; the middle axes are treated as tokens.

        Returns:
            Array of shape (batch, tokens, d_model).

        Raises:
            ValueError: if ``d_model`` is not divisible by ``num_heads``.
        """
        d_model = x.shape[-1]
        # Without this check, a floored d_head can make the reshape below
        # "succeed" with a wrong token count and silently mix tokens.
        if d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({self.num_heads})"
            )
        d_head = d_model // self.num_heads

        qkv = nn.Dense(features=d_model * 3, use_bias=False)(x)
        qkv = qkv.reshape(x.shape[0], -1, 3, self.num_heads, d_head)
        # -> (3, batch, heads, tokens, d_head) so q/k/v can be indexed out.
        qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores, normalized over the key axis.
        attention = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(d_head)
        attention = nn.softmax(attention, axis=-1)

        y = jnp.matmul(attention, v)
        # Merge heads back: (batch, heads, tokens, d_head) -> (batch, tokens, d_model).
        y = jnp.transpose(y, (0, 2, 1, 3))
        y = y.reshape(x.shape[0], -1, d_model)

        y = nn.Dense(features=d_model)(y)
        return y

class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each residual.

    Attributes:
        num_heads: heads used by the attention sub-layer.
        mlp_dim: hidden width of the feed-forward sub-layer.
    """

    num_heads: int
    mlp_dim: int

    @nn.compact
    def __call__(self, x):
        # Attention sub-layer on normalized input, added back to the stream.
        normed = nn.LayerNorm()(x)
        residual = x + Attention(num_heads=self.num_heads)(normed)

        # Feed-forward sub-layer; output width matches the residual stream.
        normed = nn.LayerNorm()(residual)
        feed_forward = MLP(hidden_dim=self.mlp_dim, out_dim=residual.shape[-1])
        return residual + feed_forward(normed)

class ViT(nn.Module):
    """Stack of transformer blocks over an already-tokenized sequence.

    Input tokens are linearly projected to ``hidden_dim`` and a learned
    positional embedding is added before the blocks run.

    Attributes:
        patch_size: kept for configuration symmetry; not read in ``__call__``.
        hidden_dim: token width after the input projection.
        num_heads: attention heads per block.
        num_layers: number of transformer blocks.
        mlp_dim: hidden width of each block's MLP.
    """

    patch_size: int = 4
    hidden_dim: int = 256
    num_heads: int = 8
    num_layers: int = 6
    mlp_dim: int = 512

    @nn.compact
    def __call__(self, x, train=True):
        # ``train`` is accepted for API compatibility but not used.
        # Unpacking enforces a rank-3 (batch, tokens, features) input.
        _, num_tokens, _ = x.shape

        pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.normal(stddev=0.02),
            (1, num_tokens, self.hidden_dim),
        )
        tokens = nn.Dense(features=self.hidden_dim)(x) + pos_embedding

        for _ in range(self.num_layers):
            block = TransformerBlock(num_heads=self.num_heads, mlp_dim=self.mlp_dim)
            tokens = block(tokens)

        return tokens

class MAEEncoder(nn.Module):
    """MAE encoder: embed patches, zero out masked tokens, run a ViT.

    Masked tokens are multiplied by zero rather than removed, so every token
    position still flows through the transformer. The inner ``ViT`` applies
    its own projection and positional embedding on top of the ones added
    here.

    Attributes:
        patch_size: forwarded to the inner ``ViT``.
        hidden_dim: token width.
        num_heads: attention heads per block.
        num_layers: number of transformer blocks.
        mlp_dim: hidden width of each block's MLP.
    """

    patch_size: int = 4
    hidden_dim: int = 256
    num_heads: int = 8
    num_layers: int = 6
    mlp_dim: int = 512

    @nn.compact
    def __call__(self, x, mask):
        """Encode patch tokens.

        Args:
            x: patchified input of shape (batch, num_patches, patch_features).
            mask: per-token weights indexed as ``mask[:, :, None]``; in this
                file it is (batch, num_patches) with 1 = keep, 0 = zero out.

        Returns:
            Encoded tokens of shape (batch, num_patches, hidden_dim).
        """
        _, num_patches, _ = x.shape

        pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.normal(stddev=0.02),
            (1, num_patches, self.hidden_dim),
        )
        tokens = nn.Dense(features=self.hidden_dim)(x) + pos_embedding

        # Zero out masked tokens (content and position alike).
        tokens = tokens * mask[:, :, None]

        backbone = ViT(
            patch_size=self.patch_size,
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
            mlp_dim=self.mlp_dim,
        )
        return backbone(tokens)

class MAEDecoder(nn.Module):
    """MAE decoder: fill masked positions with a learned token, predict pixels.

    Attributes:
        patch_size: side of each square patch; the output has
            ``patch_size**2 * 3`` values per token (3 channels hard-coded).
        hidden_dim: decoder token width.
        num_heads: attention heads per block.
        num_layers: number of transformer blocks.
        mlp_dim: hidden width of each block's MLP.
    """

    patch_size: int = 4
    hidden_dim: int = 128
    num_heads: int = 4
    num_layers: int = 2
    mlp_dim: int = 256

    @nn.compact
    def __call__(self, x, mask):
        """Decode encoder tokens into per-patch pixel predictions.

        Args:
            x: encoder output of shape (batch, num_patches, features).
            mask: (batch, num_patches) array, 1 = visible, 0 = masked.

        Returns:
            Array of shape (batch, num_patches, patch_size**2 * 3) with values
            in (0, 1) from the final sigmoid.
        """
        batch, num_patches = mask.shape

        # Parameter creation order is kept (pos_embedding before mask_token).
        pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.normal(stddev=0.02),
            (1, num_patches, self.hidden_dim),
        )
        projected = nn.Dense(features=self.hidden_dim)(x)

        mask_token = self.param(
            'mask_token',
            nn.initializers.normal(stddev=0.02),
            (1, 1, self.hidden_dim),
        )
        keep = mask[:, :, None]
        filler = jnp.broadcast_to(mask_token, (batch, num_patches, self.hidden_dim))
        # Visible slots keep their projection; masked slots get the mask token.
        tokens = projected * keep + filler * (1 - keep) + pos_embedding

        backbone = ViT(
            patch_size=self.patch_size,
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            num_layers=self.num_layers,
            mlp_dim=self.mlp_dim,
        )
        tokens = backbone(tokens)

        pixels = nn.Dense(features=self.patch_size**2 * 3)(tokens)
        return nn.sigmoid(pixels)

class MaskedAutoencoder(nn.Module):
    """Masked autoencoder over square image patches.

    The image is cut into ``patch_size`` x ``patch_size`` patches. A random
    subset of patches is masked, ``encoder`` embeds the patches given the
    mask, and ``decoder`` predicts per-patch pixels. The prediction is then
    reassembled into an image.

    Attributes:
        encoder: module called as ``encoder(patches, mask)``.
        decoder: module called as ``decoder(encoded, mask)``; its output must
            reshape to ``patch_size * patch_size * channels`` values per patch.
        mask_ratio: fraction of patches masked in every image.
        patch_size: side length of each square patch, in pixels.
    """

    encoder: nn.Module
    decoder: nn.Module
    mask_ratio: float = 0.75
    patch_size: int = 4

    @nn.compact
    def __call__(self, img, train=True, rngs=None):
        """Reconstruct ``img`` from a randomly masked subset of its patches.

        Args:
            img: images of shape (batch, height, width, channels); height and
                width must be multiples of ``patch_size``.
            train: accepted for API compatibility; not read here.
            rngs: optional dict; when given, ``rngs['mask']`` is the key used
                for masking. Otherwise the ``'mask'`` RNG stream passed to
                ``init``/``apply`` is used.
                NOTE(review): ``Module.apply(..., rngs=...)`` consumes its own
                ``rngs`` keyword, so callers in this file take the
                ``make_rng('mask')`` path.

        Returns:
            ``(reconstructed, mask)``. ``reconstructed`` has ``img``'s shape;
            ``mask`` is (batch, num_patches) with 1 = visible, 0 = masked.
        """
        patches = self.patchify(img)
        batch, num_patches, _ = patches.shape

        mask_rng = rngs['mask'] if rngs is not None else None
        mask = self.create_mask(mask_rng, batch, num_patches)

        encoded = self.encoder(patches, mask)
        decoded = self.decoder(encoded, mask)

        reconstructed = self.unpatchify(decoded, img.shape)
        return reconstructed, mask

    def patchify(self, imgs):
        """Split (B, H, W, C) images into (B, num_patches, p*p*C) patches.

        Patches are ordered row-major over the (H/p, W/p) patch grid.

        Raises:
            ValueError: if H or W is not a multiple of ``patch_size``.
        """
        batch_size, height, width, channels = imgs.shape
        p = self.patch_size
        if height % p or width % p:
            raise ValueError(
                f"image size {height}x{width} is not divisible by patch_size {p}"
            )
        grid_h, grid_w = height // p, width // p
        patches = imgs.reshape(batch_size, grid_h, p, grid_w, p, channels)
        # Bring the two grid axes together ahead of the intra-patch axes.
        patches = patches.transpose(0, 1, 3, 2, 4, 5)
        return patches.reshape(batch_size, grid_h * grid_w, p * p * channels)

    def unpatchify(self, patches, original_shape):
        """Inverse of ``patchify``: reassemble patches into ``original_shape``."""
        batch_size, height, width, channels = original_shape
        p = self.patch_size
        patches = patches.reshape(batch_size, height // p, width // p, p, p, channels)
        imgs = patches.transpose(0, 1, 3, 2, 4, 5)
        return imgs.reshape(batch_size, height, width, channels)

    def create_mask(self, rng, batch, num_patches):
        """Return a (batch, num_patches) float mask, 1 = visible, 0 = masked.

        Exactly ``round(num_patches * (1 - mask_ratio))`` patches are kept in
        every image, chosen uniformly at random via a per-row random
        ranking. The masking ratio is therefore identical for every image
        rather than holding only on average.

        Args:
            rng: PRNG key, or ``None`` to draw one from the ``'mask'`` stream.
            batch: number of images.
            num_patches: patches per image.
        """
        if rng is None:
            rng = self.make_rng('mask')
        # Static Python int (shapes and mask_ratio are static), so jit-safe.
        num_keep = int(round(num_patches * (1.0 - self.mask_ratio)))
        num_keep = min(max(num_keep, 0), num_patches)

        noise = jax.random.uniform(rng, (batch, num_patches))
        # Rank of each patch's noise value within its row (0 = smallest);
        # the num_keep lowest-ranked patches stay visible.
        ranks = jnp.argsort(jnp.argsort(noise, axis=1), axis=1)
        return jnp.where(ranks < num_keep, 1., 0.)

def _masked_patch_loss(reconstructed, target, mask):
    """Mean ``optax.l2_loss`` over the pixels of masked patches only.

    ``mask`` is the (batch, num_patches) patch mask returned by the model
    (1 = visible, 0 = masked), ordered row-major over the square-patch grid.
    The patch side is recovered from the image and mask shapes (both static
    under jit). If no patch is masked, this falls back to the mean over all
    pixels, so training still has a signal.
    """
    b, h, w, c = target.shape
    num_patches = mask.shape[1]
    patch = int(round((h * w / num_patches) ** 0.5))

    # Expand the per-patch "was masked" indicator to pixel resolution.
    masked = (1.0 - mask).reshape(b, h // patch, w // patch)
    masked = jnp.repeat(jnp.repeat(masked, patch, axis=1), patch, axis=2)[..., None]

    per_pixel = optax.l2_loss(reconstructed, target)
    num_masked = masked.sum() * c
    masked_mean = (per_pixel * masked).sum() / jnp.maximum(num_masked, 1.0)
    return jnp.where(num_masked > 0, masked_mean, per_pixel.mean())


@jax.jit
def train_step(state, batch, rng):
    """Run one optimizer step of MAE pretraining on ``batch``.

    The reconstruction loss is computed only on masked patches. Visible
    patches can be copied through the network almost trivially, so
    including them would dominate the objective.

    Args:
        state: ``TrainState`` whose ``apply_fn`` returns
            ``(reconstructed, mask)``.
        batch: images of shape (batch, height, width, channels).
        rng: PRNG key; it is split, and one half drives this step's masking.

    Returns:
        ``(new_state, loss, new_rng)``, where ``new_rng`` is the unused half
        of the split, for the next step.
    """
    rng, new_rng = jax.random.split(rng)

    def loss_fn(params):
        reconstructed, mask = state.apply_fn(
            {'params': params}, batch, train=True, rngs={'mask': rng}
        )
        return _masked_patch_loss(reconstructed, batch, mask)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss, new_rng



def load_cifar10():
    """Load CIFAR-10 images (labels discarded) as float32 scaled by 1/255.

    Returns:
        ``(x_train, x_test)``: the image arrays from ``cifar10.load_data``,
        cast to float32 and divided by 255 (i.e. [0, 1] for 8-bit pixels).
    """
    (train_images, _), (test_images, _) = cifar10.load_data()
    return tuple(
        images.astype(np.float32) / 255.
        for images in (train_images, test_images)
    )

import matplotlib.pyplot as plt

def _apply_patch_mask(images, mask):
    """Return ``images`` with masked patches (mask == 0) set to zero (black)."""
    _, height, width, _ = images.shape
    # Square patch side recovered from image area and patch count.
    patch = int(round(np.sqrt(height * width / mask.shape[1])))
    pixel_mask = mask.reshape(mask.shape[0], height // patch, width // patch)
    pixel_mask = pixel_mask.repeat(patch, axis=1).repeat(patch, axis=2)
    return images * pixel_mask[..., None]


def visualize_reconstructions(original, reconstructed, mask, epoch):
    """Save a grid of originals, masked inputs and reconstructions.

    Shows up to four images per row.

    Args:
        original: images of shape (batch, height, width, channels).
        reconstructed: model output with the same shape as ``original``.
        mask: (batch, num_patches) patch mask from the model (1 = visible,
            0 = masked), row-major over the patch grid. It is used to render
            the masked input with hidden patches blacked out. Pass ``None``
            to skip that row.
        epoch: epoch number used in the title and output file name.

    Side effects:
        Writes ``reconstruction_epoch_{epoch}.png`` to the working directory.
    """
    original = np.asarray(original)
    reconstructed = np.asarray(reconstructed)
    num_shown = min(4, len(original))
    if num_shown == 0:
        return

    rows = [("Original", original)]
    if mask is not None:
        rows.append(("Masked", _apply_patch_mask(original, np.asarray(mask))))
    rows.append(("Reconstructed", reconstructed))

    fig, axs = plt.subplots(
        len(rows), num_shown,
        figsize=(5 * num_shown, 5 * len(rows)),
        squeeze=False,  # always 2-D axes, even for a single column
    )
    for row, (label, images) in enumerate(rows):
        for i in range(num_shown):
            ax = axs[row, i]
            ax.imshow(images[i])
            ax.set_title(f"{label} {i+1}")
            ax.axis('off')

    fig.suptitle(f"Epoch {epoch}")
    fig.tight_layout()
    fig.savefig(f"reconstruction_epoch_{epoch}.png")
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)

def evaluate_reconstruction(state, test_images, rng):
    """Reconstruct ``test_images`` with the current parameters in ``state``.

    Args:
        state: object exposing ``params`` and ``apply_fn`` (a ``TrainState``).
        test_images: images to reconstruct.
        rng: PRNG key forwarded as the ``'mask'`` RNG.

    Returns:
        ``(reconstructed, mask)`` as produced by ``state.apply_fn``.
    """
    variables = {'params': state.params}
    outputs = state.apply_fn(variables, test_images, train=False, rngs={'mask': rng})
    reconstructed, mask = outputs
    return reconstructed, mask

def main(batch_size=128, num_epochs=100, seed=0):
    """Pretrain the MAE on CIFAR-10, saving reconstructions and a loss curve.

    Args:
        batch_size: images per optimizer step.
        num_epochs: number of shuffled passes over the training set.
        seed: seeds both the JAX PRNG and the NumPy generator used for
            shuffling and for picking the visualization samples.

    Raises:
        ValueError: if ``batch_size`` exceeds the number of training images.

    Side effects:
        Prints the per-epoch average loss. Writes
        ``reconstruction_epoch_{epoch}.png`` every 10 epochs and
        ``loss_curve.png`` at the end.
    """
    x_train, x_test = load_cifar10()

    steps_per_epoch = len(x_train) // batch_size
    if steps_per_epoch == 0:
        raise ValueError(
            f"batch_size ({batch_size}) exceeds the number of training images ({len(x_train)})"
        )

    encoder = MAEEncoder()
    decoder = MAEDecoder()
    mae = MaskedAutoencoder(encoder=encoder, decoder=decoder)

    rng = jax.random.PRNGKey(seed)
    # Distinct keys so the 'params' and 'mask' streams do not share a key.
    rng, params_rng, mask_rng = jax.random.split(rng, 3)
    sample_input = jnp.ones((1,) + tuple(x_train.shape[1:]))
    params = mae.init({'params': params_rng, 'mask': mask_rng}, sample_input)['params']

    tx = optax.adam(learning_rate=1e-3)
    state = train_state.TrainState.create(apply_fn=mae.apply, params=params, tx=tx)

    # Seeded generator: reproducible shuffling and visualization picks.
    np_rng = np.random.default_rng(seed)

    losses = []  # per-epoch average loss, as Python floats

    # Fixed set of test images for visualization
    test_sample = x_test[np_rng.choice(len(x_test), 4, replace=False)]

    for epoch in range(num_epochs):
        # Fresh permutation each epoch: every image is used at most once per
        # epoch (the trailing len(x_train) % batch_size images are skipped).
        order = np_rng.permutation(len(x_train))
        total_loss = 0.0
        for step in range(steps_per_epoch):
            batch = x_train[order[step * batch_size:(step + 1) * batch_size]]
            # train_step splits the key internally and returns a fresh one.
            state, loss, rng = train_step(state, batch, rng)
            total_loss += loss

        avg_loss = float(total_loss / steps_per_epoch)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Visualize reconstructions every 10 epochs
        if (epoch + 1) % 10 == 0:
            rng, eval_rng = jax.random.split(rng)
            reconstructed, mask = evaluate_reconstruction(state, test_sample, eval_rng)
            visualize_reconstructions(test_sample, reconstructed, mask, epoch + 1)

    print("Pretraining complete!")

    # Plot the loss curve
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, num_epochs + 1), losses)
    plt.title('Training Loss over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.grid(True)
    plt.savefig("loss_curve.png")
    plt.close()

# Entry point: start pretraining when this module runs as the main program.
if __name__ == "__main__":
    main()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Epoch 1, Average Loss: 0.0271
Epoch 2, Average Loss: 0.0126
Epoch 3, Average Loss: 0.0105
Epoch 4, Average Loss: 0.0096
Epoch 5, Average Loss: 0.0092
Epoch 6, Average Loss: 0.0088
Epoch 7, Average Loss: 0.0086
Epoch 8, Average Loss: 0.0085
Epoch 9, Average Loss: 0.0082
Epoch 10, Average Loss: 0.0082
Epoch 11, Average Loss: 0.0080
Epoch 12, Average Loss: 0.0079
Epoch 13, Average Loss: 0.0078
Epoch 14, Average Loss: 0.0078
Epoch 15, Average Loss: 0.0077
Epoch 16, Average Loss: 0.0076
Epoch 17, Average Loss: 0.0076
Epoch 18, Average Loss: 0.0075
Epoch 19, Average Loss: 0.0075
Epoch 20, Average Loss: 0.0074
Epoch 21, Average Loss: 0.0074
Epoch 22, Average Loss: 0.0074
Epoch 23, Average Loss: 0.0073
Epoch 24, Average Loss: 0.0073
Epoch 25, Average Loss: 0.0073
Epoch 26, Average Loss: 0.0072
Epoch 27, Average Loss: 0.0072
Epoch 28, Average Loss: 0.0071
Epoch 29, Average Loss: 0.0071
Epoch 30, Average Loss: 0.0071
E