# MNIST Denoising with a From-Scratch CNN (NumPy only)

This notebook:
- Uses the provided `DataGenerator` to load and normalize MNIST.
- Adds clipped Gaussian noise (σ = 0.25) to the images to create a denoising task.
- Implements a small fully convolutional autoencoder **from scratch with NumPy**.
- Trains with MSE loss, reports validation MSE and PSNR, and visualizes denoised outputs.
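For reference, the MSE objective and the gradient that starts backpropagation fit in a few lines of NumPy. This is a minimal sketch that assumes the loss is averaged over every pixel in the batch. `mse_loss_and_grad` is a hypothetical helper, and the `cnn` module may reduce differently (for example, summing per image). A different reduction rescales the gradient and therefore the effective learning rate.

```python
import numpy as np

def mse_loss_and_grad(pred, target):
    """Mean squared error over all elements, plus its gradient w.r.t. pred.
    Hypothetical helper; assumes a mean over every pixel in the batch."""
    diff = pred - target
    loss = np.mean(diff ** 2)
    # d/dpred of mean(diff**2) = 2 * diff / number_of_elements
    grad = 2.0 * diff / diff.size
    return loss, grad

# Toy (N, H, W, C) batch: predictions 0..3 against an all-zero target
pred = np.arange(4, dtype=np.float32).reshape(1, 2, 2, 1)
target = np.zeros_like(pred)
loss, grad = mse_loss_and_grad(pred, target)
print(loss)  # 3.5
```

The returned `grad` has the same shape as the network output, which is what the backward pass of the last conv layer consumes.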



In [32]:
import numpy as np
from data_generator import DataGenerator

dg = DataGenerator(verbose=True)
dg.generate(dataset="mnist", N_train=5000, N_valid=0.1)  # N_valid=0.1 yields 6000 validation images (see output)

# Clean images are the regression targets for this image-to-image task
y_train = dg.x_train
y_valid = dg.x_valid

# Noisy inputs: add Gaussian noise, then clip back to the normalized [-1, 1] range
rng = np.random.RandomState(0)
sigma = 0.25  # noise standard deviation in normalized units; try 0.15 to 0.35

x_train = np.clip(y_train + sigma * rng.randn(*y_train.shape).astype(np.float32), -1.0, 1.0)
x_valid = np.clip(y_valid + sigma * rng.randn(*y_valid.shape).astype(np.float32), -1.0, 1.0)

print("x_train:", x_train.shape, x_train.min(), x_train.max())
print("y_train:", y_train.shape, y_train.min(), y_train.max())


Data specification:
	Dataset type:           mnist
	Number of classes:      10
	Number of channels:     1
	Training data shape:    (5000, 28, 28, 1)
	Validation data shape:  (6000, 28, 28, 1)
	Test data shape:        (10000, 28, 28, 1)
x_train: (5000, 28, 28, 1) -1.0 1.0
y_train: (5000, 28, 28, 1) -1.0 1.0
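A useful reference point before training is the score of the "do nothing" denoiser: the MSE between the noisy input and its clean target. A trained model should beat it. The sketch below uses a hypothetical helper, `add_clipped_gaussian_noise`, that mirrors the cell above, and it runs on a synthetic stand-in for the MNIST arrays. It also exposes a side effect of clipping. A pixel already at -1 can only be pushed upward, so half of the noise is discarded there, and the baseline lands near σ²/2 rather than σ². Assuming `DataGenerator` maps the black background to -1, as the printed minimum suggests, most MNIST pixels behave this way, so the added noise is not zero-mean on the background.

```python
import numpy as np

def add_clipped_gaussian_noise(y, sigma, rng):
    """Hypothetical helper: clip(y + sigma * N(0, 1), -1, 1), mirroring the data cell."""
    noise = sigma * rng.randn(*y.shape).astype(np.float32)
    return np.clip(y + noise, -1.0, 1.0)

def mse(a, b):
    """Mean squared error over all elements, as a Python float."""
    return float(np.mean((a - b) ** 2))

rng = np.random.RandomState(0)
sigma = 0.25

# Synthetic stand-in for MNIST: background at -1 with a saturated +1 square,
# so every pixel sits at an edge of the [-1, 1] range
y = -np.ones((16, 28, 28, 1), dtype=np.float32)
y[:, 10:18, 10:18, :] = 1.0
x = add_clipped_gaussian_noise(y, sigma, rng)

# "Identity denoiser" baseline: score the noisy input itself against the clean target
baseline = mse(x, y)
print(f"baseline MSE = {baseline:.4f}, sigma^2 = {sigma ** 2:.4f}, sigma^2 / 2 = {sigma ** 2 / 2:.4f}")
```

In the notebook, `mse(x_valid[:128], y_valid[:128])` would give the baseline to compare against the validation MSE reported by `model.evaluate`.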


In [36]:
import importlib
import cnn
importlib.reload(cnn)  # pick up edits to cnn.py without restarting the kernel

from cnn import CNN, init_deep_image_to_image_cnn

input_shape = x_train.shape[1:]  # (28, 28, 1) for MNIST
W_list, b_list, lname = init_deep_image_to_image_cnn(
    input_shape=input_shape,
    num_filters=(8, 8),  # two hidden 3x3 conv layers; a 1-channel output conv is appended
)

model = CNN(dataset=dg, verbose=True)
model.setup_model(W_list, b_list, lname, activation="relu")


CNN model set up with layers:
  Layer 0: conv, W shape: (3, 3, 1, 8)
  Layer 1: conv, W shape: (3, 3, 8, 8)
  Layer 2: conv_out, W shape: (3, 3, 8, 1)


In [None]:
history = model.fit(x_train, y_train, epochs=5, batch_size=64, lr=1e-3)
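With 5,000 training images and `batch_size=64`, one epoch is ⌈5000/64⌉ = 79 updates (78 full batches plus one batch of 8), so `epochs=5` amounts to roughly 395 gradient steps at `lr=1e-3`. Below is a minimal sketch of shuffled mini-batching. It assumes `fit` reshuffles each epoch and keeps the final partial batch rather than dropping it. `iterate_minibatches` is a hypothetical helper, not part of the `cnn` module.

```python
import numpy as np

def iterate_minibatches(x, y, batch_size, rng):
    """Yield shuffled (x_batch, y_batch) pairs covering every sample once per epoch.
    The last batch is smaller when len(x) is not a multiple of batch_size."""
    idx = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        sel = idx[start:start + batch_size]
        yield x[sel], y[sel]

rng = np.random.RandomState(0)
x = np.zeros((5000, 28, 28, 1), dtype=np.float32)  # same N as x_train
y = np.zeros_like(x)

sizes = [xb.shape[0] for xb, _ in iterate_minibatches(x, y, 64, rng)]
print(len(sizes), sizes[-1])  # 79 8
```

Shuffling the inputs and targets with the same index array keeps each noisy image paired with its clean target.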





In [None]:
# Quick check on the first 128 validation images
stats = model.evaluate(x_valid[:128], y_valid[:128], metric="mse")
print("Valid MSE:", stats)

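# Data lie in [-1, 1], so the peak-to-peak range is 2.0 (assuming max_val is MAX in
# 10 * log10(MAX**2 / MSE)); max_val=1.0 would understate PSNR by about 6 dB
stats_psnr = model.evaluate(x_valid[:128], y_valid[:128], metric="psnr", max_val=2.0)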
print("Valid PSNR:", stats_psnr)
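PSNR is 10·log10(MAX²/MSE), where MAX is the data range of the images. Because the images here span [-1, 1], the range is 2.0. Using 1.0 instead lowers the reported value by 20·log10(2) ≈ 6.02 dB, which matters when comparing against results computed on [0, 1] images. This is a minimal sketch with a hypothetical `psnr` helper. It assumes `model.evaluate(..., metric="psnr", max_val=...)` uses the same formula, which the notebook does not show.

```python
import numpy as np

def psnr(pred, target, data_range):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    diff = np.asarray(pred, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    mse = np.mean(diff ** 2)
    return 10.0 * np.log10(data_range ** 2 / mse)

target = np.zeros((4, 4))
pred = np.full((4, 4), 0.1)  # uniform error of 0.1, so MSE = 0.01

print(round(psnr(pred, target, data_range=1.0), 2))  # 20.0
print(round(psnr(pred, target, data_range=2.0), 2))  # 26.02
```

Since PSNR is a log-transformed MSE at a fixed range, it ranks models identically to MSE. The range only shifts every score by a constant.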
