# Semi-Supervised VAE Showcase


This notebook demonstrates how the refactored SSVAE (semi-supervised variational autoencoder) learns useful structure from mostly unlabeled MNIST digits in three stages: (1) random initialization, (2) unsupervised VAE training with no labels at all, and (3) semi-supervised fine-tuning with only 50 labels (5 per digit) out of 10,000 training images.


## 1. Setup & Imports


In [None]:
import sys
import time
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import MinMaxScaler

# Ensure project root is on sys.path.
# The project root is the nearest ancestor of the working directory (itself
# included) that contains pyproject.toml. If none exists, fall back to the
# working directory rather than the filesystem root, so ARTIFACT_DIR never
# ends up at `/artifacts/...`.
_CWD = Path.cwd()
ROOT = next(
    (candidate for candidate in (_CWD, *_CWD.parents) if (candidate / 'pyproject.toml').exists()),
    _CWD,
)
if (ROOT / 'pyproject.toml').exists() and str(ROOT) not in sys.path:
    # Make the local packages (configs, ssvae, training) importable.
    sys.path.insert(0, str(ROOT))

from configs.base import SSVAEConfig
from ssvae import SSVAE
from training.interactive_trainer import InteractiveTrainer

# Global NumPy seed for any legacy np.random.* calls.
np.random.seed(0)

# All figures and checkpoints produced by this notebook live under one directory.
ARTIFACT_DIR = ROOT.joinpath('artifacts', 'showcase')
ARTIFACT_DIR.mkdir(exist_ok=True, parents=True)
CHECKPOINT_STAGE2 = ARTIFACT_DIR.joinpath('stage2_unsupervised.ckpt')
CHECKPOINT_STAGE3 = ARTIFACT_DIR.joinpath('stage3_semi_supervised.ckpt')

# Default figure size and grid for every plot.
plt.rcParams['figure.figsize'] = (8, 6)
plt.rcParams['axes.grid'] = True


In [None]:
def plot_latent_with_labels(z, labels, title, path=None, cmap_name="tab10", num_classes=10):
    """Scatter a 2-D latent embedding coloured by integer class label.

    Parameters
    ----------
    z : array-like
        Latent coordinates; only the first two columns are plotted.
    labels : array-like
        Integer class ids expected in ``[0, num_classes)``; used for colouring.
    title : str
        Axes title.
    path : str or Path, optional
        If given, the figure is also saved there at 150 dpi.
    cmap_name : str
        Matplotlib colormap name, resampled to ``num_classes`` discrete colours.
    num_classes : int
        Number of classes in the colour mapping and legend (default 10 digits).
    """
    z = np.asarray(z)
    labels = np.asarray(labels)
    # pyplot.get_cmap is the supported accessor; matplotlib.cm.get_cmap (reached
    # via plt.cm.get_cmap) was removed in Matplotlib 3.9.
    cmap = plt.get_cmap(cmap_name, num_classes)
    fig, ax = plt.subplots(figsize=(6, 5))
    # Pin the normalisation to the full class range so label k always gets colour k.
    # Without vmin/vmax, Matplotlib autoscales to the labels present, and the
    # colours stop matching the legend when some classes are missing (e.g. predictions).
    ax.scatter(
        z[:, 0], z[:, 1], c=labels, cmap=cmap,
        vmin=-0.5, vmax=num_classes - 0.5,
        s=8, alpha=0.6, edgecolors='none',
    )
    ax.set_xlabel("Latent dim 1")
    ax.set_ylabel("Latent dim 2")
    ax.set_title(title)
    legend_handles = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor=cmap(i), markersize=6, label=str(i))
        for i in range(num_classes)
    ]
    ax.legend(handles=legend_handles, title="Digit", loc='upper right', bbox_to_anchor=(1.3, 1.0))
    fig.tight_layout()
    if path is not None:
        fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)

def plot_latent_continuous(z, values, title, path=None, cmap_name="viridis", colorbar_label="Value"):
    """Scatter a 2-D latent embedding coloured by a continuous per-point value.

    Parameters
    ----------
    z : array-like
        Latent coordinates; only the first two columns are plotted.
    values : array-like
        One scalar per point, mapped through ``cmap_name``.
    title : str
        Axes title.
    path : str or Path, optional
        If given, the figure is also saved there at 150 dpi.
    cmap_name : str
        Matplotlib colormap name.
    colorbar_label : str
        Text shown next to the colorbar (defaults to the previous fixed "Value").
    """
    z = np.asarray(z)
    values = np.asarray(values)
    fig, ax = plt.subplots(figsize=(6, 5))
    scatter = ax.scatter(z[:, 0], z[:, 1], c=values, cmap=cmap_name, s=8, alpha=0.6, edgecolors='none')
    ax.set_xlabel("Latent dim 1")
    ax.set_ylabel("Latent dim 2")
    ax.set_title(title)
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label(colorbar_label)
    fig.tight_layout()
    if path is not None:
        fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)

def show_reconstructions(inputs, reconstructions, count=8, title="Reconstructions", path=None, seed=0):
    """Show randomly chosen inputs (top row) above their reconstructions (bottom row).

    Parameters
    ----------
    inputs : array-like
        Images indexed along axis 0; each ``inputs[i]`` must be displayable by ``imshow``.
    reconstructions : array-like
        Same indexing as ``inputs``; ``reconstructions[i]`` is shown under ``inputs[i]``.
    count : int
        Number of columns; clamped to the number of available samples.
    title : str
        Figure title.
    path : str or Path, optional
        If given, the figure is also saved there at 150 dpi.
    seed : int
        Seed for choosing which samples to show (no repeats).

    Raises
    ------
    ValueError
        If there is nothing to show (empty ``inputs`` or ``count < 1``).
    """
    inputs = np.asarray(inputs)
    reconstructions = np.asarray(reconstructions)
    count = min(count, inputs.shape[0])
    if count < 1:
        raise ValueError("show_reconstructions: no samples to display")
    rng = np.random.default_rng(seed)
    idx = rng.choice(inputs.shape[0], size=count, replace=False)
    # squeeze=False keeps `axes` 2-D so axes[row, col] also works when count == 1.
    fig, axes = plt.subplots(2, count, figsize=(1.6 * count, 3.2), squeeze=False)
    for col, index in enumerate(idx):
        for row, images in enumerate((inputs, reconstructions)):
            ax = axes[row, col]
            ax.imshow(images[index], cmap='gray')
            # Hide ticks, grid and spines instead of axis('off'): axis('off') would
            # also hide the "Input"/"Recon" row labels set below.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.grid(False)
            for spine in ax.spines.values():
                spine.set_visible(False)
    axes[0, 0].set_ylabel("Input")
    axes[1, 0].set_ylabel("Recon")
    fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    if path is not None:
        fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)


## 2. Data Preparation


In [None]:
print("Fetching MNIST (this may download once)...")
X_all, y_all = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
X_all = X_all.astype(np.float32)
y_all = y_all.astype(np.int32)

TRAIN_SIZE = 10000
TEST_SIZE = 2000
# Rows from 60000 onward form the held-out MNIST test split; the first 60000 are training images.
MNIST_TEST_OFFSET = 60000
x_train_raw = X_all[:TRAIN_SIZE]
y_train = y_all[:TRAIN_SIZE]
x_test_raw = X_all[MNIST_TEST_OFFSET:MNIST_TEST_OFFSET + TEST_SIZE]
y_test = y_all[MNIST_TEST_OFFSET:MNIST_TEST_OFFSET + TEST_SIZE]

# Scale by the fixed 8-bit pixel range instead of a per-feature MinMaxScaler.
# Per-pixel min-max scaling stretches every pixel column by its own training
# maximum, which promotes faint border pixels and makes the later 0.5
# binarisation threshold differ from pixel to pixel. Test pixels could also
# exceed 1 where the training column was always 0.
PIXEL_MAX = 255.0
x_train_scaled = np.clip(x_train_raw / PIXEL_MAX, 0.0, 1.0)
x_test_scaled = np.clip(x_test_raw / PIXEL_MAX, 0.0, 1.0)

x_train_images = x_train_scaled.reshape(-1, 28, 28)
x_test_images = x_test_scaled.reshape(-1, 28, 28)

# Cheap sanity checks so a bad download or slicing mistake fails loudly here.
assert x_train_images.shape == (TRAIN_SIZE, 28, 28)
assert x_test_images.shape == (TEST_SIZE, 28, 28)
assert set(np.unique(y_train)) <= set(range(10))

def binarize_images(images, threshold=0.5):
    """Map each pixel to 1.0 if it is strictly above ``threshold``, else 0.0 (float32)."""
    is_on = images > threshold
    return np.where(is_on, np.float32(1.0), np.float32(0.0))

# Binarise both splits so every pixel is exactly 0.0 or 1.0.
x_train, x_test = (binarize_images(split) for split in (x_train_images, x_test_images))

# Fully unlabeled target vector: NaN in every position means "no label".
labels_all_nan = np.full(x_train.shape[0], np.nan, dtype=np.float32)
print(f'Train set: {x_train.shape}, Test set: {x_test.shape}')


## 3. Stage 1 – Untrained Model
With randomly initialized weights, the latent space shows no meaningful class structure; this is the baseline for the later stages.


In [None]:
# Shared training configuration (reused for the Stage 2 model below).
config = SSVAEConfig(
    batch_size=1024,
    max_epochs=60,
    patience=10,
)
# Freshly constructed model, never trained: its latent space is the random baseline.
stage1_model = SSVAE(config=config, input_dim=(28, 28))
(stage1_latent, stage1_recon, stage1_pred, stage1_certainty) = stage1_model.predict(x_train)
plot_latent_with_labels(
    stage1_latent,
    y_train,
    "Stage 1: Untrained Model (Random Initialization)",
    path=ARTIFACT_DIR / "stage1_latent.png",
)


## 4. Stage 2 – Unsupervised Training
Train the VAE without any labels, so whatever structure appears in the latent space is learned from the reconstruction objective alone.


In [None]:
stage2_model = SSVAE(input_dim=(28, 28), config=config)
# An all-NaN label vector means every sample is unlabeled, so training is purely unsupervised.
unsupervised_labels = np.full(x_train.shape[0], np.nan, dtype=np.float32)
# perf_counter is monotonic, so the measured duration is immune to system-clock adjustments.
start = time.perf_counter()
history_stage2 = stage2_model.fit(x_train, unsupervised_labels, weights_path=str(CHECKPOINT_STAGE2))
stage2_elapsed = time.perf_counter() - start
# NOTE(review): assumes fit() returns a mapping of per-epoch metric sequences with this key — confirm.
stage2_recon_loss = history_stage2['reconstruction_loss'][-1]
print(f'Unsupervised training time: {stage2_elapsed / 60:.2f} minutes')
print(f'Final reconstruction loss: {stage2_recon_loss:.4f}')


In [None]:
# Embed and reconstruct the training set with the unsupervised model.
stage2_latent, stage2_recon, _, _ = stage2_model.predict(x_train)
plot_latent_with_labels(
    stage2_latent,
    y_train,
    "Stage 2: After Unsupervised Training (No Labels Used)",
    path=ARTIFACT_DIR / "stage2_latent.png",
)
show_reconstructions(
    x_train,
    stage2_recon,
    title="Reconstructions After Unsupervised Training",
    path=ARTIFACT_DIR / "stage2_recon.png",
)


## 5. Stage 3 – Semi-Supervised Fine-Tuning (50 Labels)
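Select 50 labels (5 per digit class) and continue training the Stage 2 model with the interactive trainer. Every other sample stays unlabeled (marked as NaN), and fine-tuning updates the Stage 2 model in place.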


In [None]:
# Label budget: a few examples per digit, chosen reproducibly.
NUM_CLASSES = 10
LABELS_PER_CLASS = 5

rng = np.random.default_rng(0)
# NaN marks an unlabeled sample; only the chosen indices receive their true digit.
labels_semi = np.full(x_train.shape[0], np.nan, dtype=np.float32)
labeled_indices = []
for digit in range(NUM_CLASSES):
    digit_idx = np.flatnonzero(y_train == digit)
    chosen = rng.choice(digit_idx, size=LABELS_PER_CLASS, replace=False)
    labeled_indices.extend(chosen)
labels_semi[labeled_indices] = y_train[labeled_indices].astype(np.float32)

assert len(labeled_indices) == NUM_CLASSES * LABELS_PER_CLASS
assert int(np.isfinite(labels_semi).sum()) == len(labeled_indices)

labeled_fraction = len(labeled_indices) / x_train.shape[0]
print(f'Number of labeled samples: {len(labeled_indices)} ({labeled_fraction * 100:.2f}% of training data)')

# The trainer wraps stage2_model itself. Fine-tuning therefore presumably updates
# those weights in place: the evaluation cell reads the Stage 3 results from stage2_model.
interactive_trainer = InteractiveTrainer(stage2_model)
# perf_counter is monotonic, so the measured duration is immune to system-clock adjustments.
start = time.perf_counter()
history_stage3 = interactive_trainer.train_epochs(
    num_epochs=15,
    data=x_train,
    labels=labels_semi,
    weights_path=str(CHECKPOINT_STAGE3),
    patience=5,
)
stage3_elapsed = time.perf_counter() - start
print(f'Semi-supervised fine-tuning time: {stage3_elapsed:.1f} seconds')


In [None]:
# stage2_model was fine-tuned in place by the interactive trainer, so it now holds the Stage 3 weights.
stage3_latent, _, stage3_pred, stage3_certainty = stage2_model.predict(x_train)
_, _, stage3_pred_test, _ = stage2_model.predict(x_test)
stage3_pred = np.asarray(stage3_pred)
stage3_pred_test = np.asarray(stage3_pred_test)

# Samples that never had a label are the honest measure of what the 50 labels propagated to.
unlabeled_mask = np.ones(x_train.shape[0], dtype=bool)
unlabeled_mask[labeled_indices] = False

labeled_accuracy = np.mean(stage3_pred[labeled_indices] == y_train[labeled_indices])
unlabeled_accuracy = np.mean(stage3_pred[unlabeled_mask] == y_train[unlabeled_mask])
train_accuracy = np.mean(stage3_pred == y_train)
test_accuracy = np.mean(stage3_pred_test == y_test)
print(f'Accuracy on labeled subset: {labeled_accuracy * 100:.2f}%')
print(f'Accuracy on unlabeled train samples: {unlabeled_accuracy * 100:.2f}%')
print(f'Accuracy on full train set: {train_accuracy * 100:.2f}%')
print(f'Accuracy on held-out test set: {test_accuracy * 100:.2f}%')
plot_latent_with_labels(stage3_latent, y_train, "Stage 3: After Semi-Supervised Training (50 Labels)", path=ARTIFACT_DIR / "stage3_latent.png")
# Predictions are plotted for every training sample (labeled and unlabeled), so the title says so.
plot_latent_with_labels(stage3_latent, stage3_pred, "Predicted Labels (All Training Samples)", path=ARTIFACT_DIR / "stage3_predictions.png")
plot_latent_continuous(stage3_latent, stage3_certainty, "Prediction Certainty", path=ARTIFACT_DIR / "stage3_certainty.png")


In [None]:
comparison_path = ARTIFACT_DIR / "comparison.png"
# Build the 10-colour digit colormap once. pyplot.get_cmap is the supported accessor;
# plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in Matplotlib 3.9.
digit_cmap = plt.get_cmap('tab10', 10)
fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharex=True, sharey=True)
titles = [
    "Stage 1: Random",
    "Stage 2: Unsupervised",
    "Stage 3: Semi-Supervised",
]
latent_sets = [stage1_latent, stage2_latent, stage3_latent]
for ax, latent, title in zip(axes, latent_sets, titles):
    latent = np.asarray(latent)
    # Fixed vmin/vmax so digit k always maps to legend colour k, independent of the data range.
    ax.scatter(
        latent[:, 0], latent[:, 1], c=y_train, cmap=digit_cmap,
        vmin=-0.5, vmax=9.5, s=8, alpha=0.6, edgecolors='none',
    )
    ax.set_title(title)
    ax.set_xlabel('Latent dim 1')
axes[0].set_ylabel('Latent dim 2')
legend_handles = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=digit_cmap(i), markersize=6, label=str(i))
    for i in range(10)
]
axes[-1].legend(handles=legend_handles, title='Digit', loc='upper right', bbox_to_anchor=(1.25, 1.0))
fig.suptitle('Semi-Supervised Learning Journey: From Random to Structured Classification', fontsize=16)
fig.tight_layout()
fig.savefig(comparison_path, dpi=150, bbox_inches='tight')
plt.show()
plt.close(fig)


In [None]:
# One line per headline result, printed in order.
summary_lines = [
    '--- Summary ---',
    f'Train samples used: {x_train.shape[0]} (labels provided for {len(labeled_indices)})',
    f'Fraction labeled: {labeled_fraction * 100:.2f}%',
    f'Unsupervised reconstruction loss: {stage2_recon_loss:.4f}',
    f'Train accuracy after semi-supervised fine-tuning: {train_accuracy * 100:.2f}%',
    f'Test accuracy after semi-supervised fine-tuning: {test_accuracy * 100:.2f}%',
    f'Artifacts saved to: {ARTIFACT_DIR.resolve()}',
]
for summary_line in summary_lines:
    print(summary_line)
