# Writing a Convolutional NN autoencoder

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


# Convert every MNIST image to a tensor as it is loaded.
transform = transforms.ToTensor()

# Training split of MNIST, fetched into ./data when it is not already present.
mnist_data = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 images.
data_loader = torch.utils.data.DataLoader(
    dataset=mnist_data,
    batch_size=64,
    shuffle=True,
)

What are we going to build here?

- A convolutional encoder, which uses `nn.Conv2d` layers to preserve spatial structure
- A transposed-convolution decoder, which uses `nn.ConvTranspose2d` layers to reconstruct the images

In [None]:
class CNNAutoencoder(nn.Module):
    """Compact convolutional autoencoder built from strided convolutions.

    The encoder shrinks the image with two stride-2 convolutions followed by
    a 7x7 convolution; the decoder mirrors it with transposed convolutions
    and ends in a sigmoid. Shape comments are for a 1x28x28 input.
    """

    def __init__(self):
        super().__init__()

        # Downsampling path: 1x28x28 -> 64x1x1.
        encoder_layers = [
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),   # -> 16x14x14
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # -> 32x7x7
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=7),                       # -> 64x1x1
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Upsampling path: 64x1x1 -> 1x28x28.
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=7),              # -> 32x7x7
            nn.ReLU(),
            nn.ConvTranspose2d(
                32, 16, kernel_size=3, stride=2, padding=1, output_padding=1
            ),                                                      # -> 16x14x14
            nn.ReLU(),
            nn.ConvTranspose2d(
                16, 1, kernel_size=3, stride=2, padding=1, output_padding=1
            ),                                                      # -> 1x28x28
            nn.Sigmoid(),  # squashes the reconstruction into (0, 1)
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        return self.decoder(self.encoder(x))

In [None]:
class DeepCNNAutoencoder(nn.Module):
    """Deeper convolutional autoencoder using max-pooling and transposed convs.

    Three conv + pool stages shrink the image, a bottleneck convolution widens
    the channels, and three stride-2 transposed convolutions upsample again.
    The upsampled map is larger than the input (32x32 for a 28x28 input), so
    ``forward`` centre-crops it back to the input's spatial size.
    Shape comments are for a 1x28x28 input.
    """

    def __init__(self):
        super().__init__()

        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),   # 32x28x28
            nn.ReLU(),
            nn.MaxPool2d(2, 2)                            # 32x14x14
        )

        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # 64x14x14
            nn.ReLU(),
            nn.MaxPool2d(2, 2)                            # 64x7x7
        )

        self.enc3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1), # 128x7x7
            nn.ReLU(),
            # padding=1 lets the odd 7x7 map pool to 4x4 instead of 3x3.
            nn.MaxPool2d(2, 2, padding=1)                 # 128x4x4
        )

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1), # 256x4x4
            nn.ReLU()
        )

        # Decoder: each stage doubles the spatial size.
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2), # 128x8x8
            nn.ReLU()
        )

        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),  # 64x16x16
            nn.ReLU()
        )

        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),   # 32x32x32
            nn.ReLU()
        )

        self.final = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=5, padding=2),            # 1x32x32
            nn.Sigmoid()
        )

        # Empty placeholder, never called: the crop itself is done by slicing
        # in forward(). Kept so the attribute remains available.
        self.crop = nn.Sequential()

    def forward(self, x):
        """Return the reconstruction of ``x``, cropped to ``x``'s height/width.

        Args:
            x: batch of single-channel images, laid out (N, 1, H, W).

        Returns:
            Tensor with the same spatial size as ``x``. For a 28x28 input the
            crop is rows/columns 2..29 of the 32x32 decoder output.
        """
        features = self.enc3(self.enc2(self.enc1(x)))
        features = self.bottleneck(features)
        upsampled = self.dec3(self.dec2(self.dec1(features)))
        restored = self.final(upsampled)

        # Centre-crop to the input size rather than a hard-coded 28x28 window,
        # so the output matches the input for other image sizes as well.
        target_h, target_w = x.shape[-2:]
        top = (restored.shape[-2] - target_h) // 2
        left = (restored.shape[-1] - target_w) // 2
        return restored[:, :, top:top + target_h, left:left + target_w]


In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Build the autoencoder, then move its weights onto the chosen device
# (Module.to moves the parameters in place).
model = CNNAutoencoder()
model.to(device)

# Mean-squared-error reconstruction objective, optimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print("\nModel Architecture:")
print(model)

In [None]:
# Add up the element counts of every parameter tensor in the model.
total_params = sum(map(torch.numel, model.parameters()))
print(f"\nTotal parameters: {total_params:,}")

In [None]:
def train_autoencoder(model, data_loader, criterion, optimizer, epochs=15):
    """Train ``model`` to reconstruct its inputs.

    Args:
        model: module whose output is compared against its own input.
        data_loader: iterable yielding ``(data, label)`` batches; the label
            is ignored.
        criterion: loss called as ``criterion(reconstructed, data)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        epochs: number of full passes over ``data_loader``.

    Returns:
        List with the mean batch loss of each epoch.

    Raises:
        ValueError: if ``data_loader`` yields no batches during an epoch.
    """
    # Send batches to wherever the model's weights live instead of relying on
    # a module-level ``device`` variable defined elsewhere.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    model.train()
    train_losses = []

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for batch_idx, (data, _) in enumerate(data_loader):
            if model_device is not None:
                data = data.to(model_device)

            # Forward pass
            optimizer.zero_grad()
            reconstructed = model(data)
            loss = criterion(reconstructed, data)

            # Backward pass
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1

            if batch_idx % 200 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch_idx}], Loss: {batch_loss:.6f}')

        # Fail with a clear message rather than a ZeroDivisionError below.
        if num_batches == 0:
            raise ValueError("data_loader produced no batches; cannot compute an epoch loss")

        epoch_loss = running_loss / num_batches
        train_losses.append(epoch_loss)
        print(f'Epoch [{epoch+1}/{epochs}] Average Loss: {epoch_loss:.6f}')

    return train_losses

In [None]:
def visualize_reconstruction(model, data_loader, num_images=8):
    """Plot input images (top row) above the model's reconstructions (bottom row).

    Uses the first batch drawn from ``data_loader``. If that batch holds fewer
    than ``num_images`` images, only the available ones are shown.

    Args:
        model: autoencoder to evaluate; it is switched to eval mode.
        data_loader: iterable yielding ``(images, label)`` batches.
        num_images: maximum number of image pairs to display.

    Raises:
        ValueError: if ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError("num_images must be at least 1")

    model.eval()
    # Use the model's own device rather than a module-level ``device`` global.
    first_param = next(model.parameters(), None)

    with torch.no_grad():
        images, _ = next(iter(data_loader))
        images = images[:num_images]
        if first_param is not None:
            images = images.to(first_param.device)
        reconstructed = model(images).cpu()
        images = images.cpu()

    # The batch may be smaller than the requested number of images.
    shown = images.shape[0]
    if shown == 0:
        return

    # squeeze=False keeps ``axes`` two-dimensional even with a single column.
    fig, axes = plt.subplots(2, shown, figsize=(15, 4), squeeze=False)

    for i in range(shown):
        axes[0, i].imshow(images[i].squeeze(), cmap='gray')
        axes[0, i].set_title('Original')
        axes[0, i].axis('off')

        axes[1, i].imshow(reconstructed[i].squeeze(), cmap='gray')
        axes[1, i].set_title('Reconstructed')
        axes[1, i].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
def visualize_feature_maps(model, data_loader, layer_idx=0):
    """Show up to 16 feature maps produced by one stage of ``model.encoder``.

    The first image of the first batch is pushed through ``model.encoder``
    layer by layer, and the output at position ``layer_idx * 2 + 1`` is
    plotted on a 4x4 grid. With an encoder that alternates convolution and
    activation layers, that is the activation following convolution number
    ``layer_idx``.

    Args:
        model: module exposing an iterable ``encoder`` attribute; it is
            switched to eval mode.
        data_loader: iterable yielding ``(images, label)`` batches.
        layer_idx: zero-based index of the encoder stage to visualise.

    Raises:
        ValueError: if the model has no ``encoder`` or ``layer_idx`` does not
            select one of its layers.
    """
    model.eval()

    if not hasattr(model, 'encoder'):
        raise ValueError("model has no 'encoder' attribute to take feature maps from")

    # Use the model's own device rather than a module-level ``device`` global.
    first_param = next(model.parameters(), None)
    target_position = layer_idx * 2 + 1
    feature_maps = None

    with torch.no_grad():
        images, _ = next(iter(data_loader))
        x = images[0:1]
        if first_param is not None:
            x = x.to(first_param.device)

        for i, layer in enumerate(model.encoder):
            x = layer(x)
            if i == target_position:
                feature_maps = x
                break

    if feature_maps is None:
        raise ValueError(f"layer_idx={layer_idx} does not select a layer of model.encoder")

    # Drop only the batch dimension; a bare squeeze() would also remove 1x1
    # spatial dimensions and leave nothing for imshow to draw.
    feature_maps = feature_maps.cpu().squeeze(0)

    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    # Hide every cell up front so unused ones do not show empty frames.
    for ax in axes.flat:
        ax.axis('off')

    for i in range(min(16, feature_maps.shape[0])):
        row, col = divmod(i, 4)
        axes[row, col].imshow(feature_maps[i], cmap='viridis')
        axes[row, col].set_title(f'Feature Map {i+1}')

    plt.tight_layout()
    plt.show()

In [None]:
# Train the autoencoder and keep the mean loss of every epoch.
print("\nStarting training...")
train_losses = train_autoencoder(model, data_loader, criterion, optimizer, epochs=15)

# Loss curve: one point per epoch.
loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
loss_ax.plot(train_losses)
loss_ax.set_title('CNN Autoencoder Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.grid(True)
plt.show()

# Qualitative checks: reconstructions first, then first-stage feature maps.
print("Visualizing reconstructions...")
visualize_reconstruction(model, data_loader)

print("Visualizing feature maps from first encoder layer...")
visualize_feature_maps(model, data_loader, layer_idx=0)