In [1]:
import os
import torch
import numpy as np
import shutil
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

###############################################
#### Implement a PyTorch autoencoder on MNIST ###


# Define the encoder
class Encoder(nn.Module):
    """Convolutional encoder that compresses a 1-channel image into a latent vector.

    The network is a single ``nn.Sequential`` stored as ``self.encoder``:
    two stride-2 convolutions (each followed by ReLU and batch norm), a
    flatten, and a small fully connected head with dropout.  The shape
    comments below assume 28x28 inputs (e.g. MNIST); the first linear layer
    requires the flattened feature map to have exactly 64 * 7 * 7 entries.

    Args:
        latent_dim: Size of the latent vector produced for each sample.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Convolutional feature extractor: each conv halves the spatial size.
        conv_stage = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),   # -> [batch_size, 32, 14, 14]
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # -> [batch_size, 64, 7, 7]
            nn.ReLU(),
            nn.BatchNorm2d(64),
        ]

        # Fully connected head projecting the flattened features to the latent space.
        projection_head = [
            nn.Flatten(),                                           # -> [batch_size, 64*7*7]
            nn.Linear(64 * 7 * 7, 64),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(64, latent_dim),
        ]

        # Kept as one flat Sequential so layer indices (0..10) stay stable.
        self.encoder = nn.Sequential(*conv_stage, *projection_head)

    def forward(self, x):
        """Return the latent code for the image batch ``x``."""
        return self.encoder(x)


# Define the decoder
class Decoder(nn.Module):
    """Transposed-convolution decoder that maps a latent vector back to an image.

    A linear layer (``self.decoder_input``) expands the latent vector to
    64 * 7 * 7 features; ``self.decoder`` then reshapes them to a 64x7x7
    map, upsamples twice with stride-2 transposed convolutions and finishes
    with a sigmoid, so every output value lies in [0, 1].

    Args:
        latent_dim: Size of the latent vector accepted for each sample.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Expand the latent code to the flattened 64x7x7 feature map.
        self.decoder_input = nn.Linear(latent_dim, 64 * 7 * 7)

        upsampling_layers = [
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> [batch_size, 64, 14, 14]
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> [batch_size, 32, 28, 28]
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 1, kernel_size=3, padding=1),                               # -> [batch_size, 1, 28, 28]
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*upsampling_layers)

    def forward(self, z):
        """Return the reconstructed image batch for the latent codes ``z``."""
        return self.decoder(self.decoder_input(z))


# Combine encoder and decoder into an autoencoder model
class Autoencoder(nn.Module):
    """Encoder/decoder pair sharing one latent dimensionality.

    Args:
        latent_dim: Size of the bottleneck vector passed from the encoder
            to the decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Encoder is built first, then the decoder.
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def forward(self, x):
        """Encode ``x`` to its latent code and return the decoded reconstruction."""
        return self.decoder(self.encoder(x))

In [None]:
def extract_parameters(encoder):
    """Snapshot the weights, biases and batch-norm statistics of an encoder.

    Iterates over ``encoder.encoder`` and copies tensors from every
    ``nn.Conv2d``, ``nn.Linear`` and ``nn.BatchNorm2d`` layer; all other
    layers are ignored.  Tensors a layer does not have (e.g. the bias of a
    layer built with ``bias=False``, or the affine parameters / running
    statistics of a batch norm built with ``affine=False`` /
    ``track_running_stats=False``) are skipped instead of raising.

    Args:
        encoder: Object whose ``encoder`` attribute is an iterable of layers
            (such as the ``nn.Sequential`` held by ``Encoder``).

    Returns:
        dict: Maps ``"layer_{i}_{field}"`` to a detached clone of the tensor,
        where ``i`` is the layer's position in ``encoder.encoder`` and
        ``field`` is one of ``weight``, ``bias``, ``running_mean`` or
        ``running_var``.  Clones do not share storage with the model, so
        later in-place updates to the model leave the snapshot unchanged.
    """
    params_dict = {}

    for i, layer in enumerate(encoder.encoder):
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            fields = ("weight", "bias")
        elif isinstance(layer, nn.BatchNorm2d):
            # weight/bias are the affine gamma/beta; the running_* buffers
            # hold the tracked statistics.
            fields = ("weight", "bias", "running_mean", "running_var")
        else:
            # Activations, dropout, flatten, ... carry nothing to extract.
            continue

        for field in fields:
            tensor = getattr(layer, field)
            if tensor is None:
                # Optional tensor disabled at construction time.
                continue
            # detach() (rather than .data) yields a grad-free view that
            # clone() then copies into independent storage.
            params_dict[f'layer_{i}_{field}'] = tensor.detach().clone()

    return params_dict

# Demo: snapshot the tensors of a freshly constructed encoder.
latent_dim = 10
encoder = Encoder(latent_dim)
parameters = extract_parameters(encoder)

# Emit one "<key>: <shape>" line per extracted tensor, in insertion order.
for name in parameters:
    param = parameters[name]
    print(f"{name}: {param.shape}")

layer_0_weight: torch.Size([32, 1, 3, 3])
layer_0_bias: torch.Size([32])
layer_2_weight: torch.Size([32])
layer_2_bias: torch.Size([32])
layer_2_running_mean: torch.Size([32])
layer_2_running_var: torch.Size([32])
layer_3_weight: torch.Size([64, 32, 3, 3])
layer_3_bias: torch.Size([64])
layer_5_weight: torch.Size([64])
layer_5_bias: torch.Size([64])
layer_5_running_mean: torch.Size([64])
layer_5_running_var: torch.Size([64])
layer_7_weight: torch.Size([64, 3136])
layer_7_bias: torch.Size([64])
layer_10_weight: torch.Size([10, 64])
layer_10_bias: torch.Size([10])
