<h1>Variational Autoencoder (VAE) Architecture</h1>

In [17]:
import os
import math
import string
import pickle
import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wget
import zipfile
import json5  
from safetensors.torch import load_file, save_file 
from typing import Tuple, List, Dict
import pandas as pd
import import_ipynb


from config_setup import TrainingConfig

<h3>Encoder </h3>

In [14]:
class Encoder(nn.Module):
    """Convolutional encoder of the VAE.

    The network is assembled from ``config``: a stack of ``nn.Conv2d`` layers
    (``config.encoder_conv_configs``), a flatten, an optional stack of
    ``nn.Linear`` layers (``config.encoder_fc_configs``), and two linear heads
    producing the latent mean and log-variance (both of size
    ``config.latent_dim``).

    Attributes:
        conv_to_flatten_shape: ``(channels, height, width)`` of the tensor
            right before ``nn.Flatten``, computed analytically from the layer
            configs. ``VAE`` hands it to the decoder.
    """

    def __init__(self, config: TrainingConfig):
        """Build the layers described by ``config``.

        Args:
            config: Supplies ``input_channels``, ``window_size``,
                ``feature_dim``, ``encoder_conv_configs``,
                ``encoder_fc_configs`` and ``latent_dim``. Every conv config is
                a dict of ``nn.Conv2d`` keyword arguments plus an optional
                ``"activation"`` entry; every fc config has ``"out_features"``
                and an optional ``"activation"``.

        Raises:
            ValueError: If a conv layer reduces the tracked spatial size to
                zero or below, or uses an unsupported padding string.
        """
        super().__init__()
        self.config = config
        layers = []
        current_channels = config.input_channels
        # window_size / feature_dim are treated as the input height / width
        # when tracking how the conv stack shrinks the spatial dimensions.
        dummy_h, dummy_w = config.window_size, config.feature_dim

        for layer_cfg in config.encoder_conv_configs:
            # Everything except 'activation' is forwarded to nn.Conv2d.
            conv_params = {k: v for k, v in layer_cfg.items() if k != 'activation'}
            layers.append(nn.Conv2d(current_channels, **conv_params))
            current_channels = layer_cfg["out_channels"]
            if "activation" in layer_cfg:
                # Call the static method from the TrainingConfig class
                layers.append(TrainingConfig.get_activation(layer_cfg["activation"]))

            dummy_h, dummy_w = self._conv_output_hw(dummy_h, dummy_w, layer_cfg)
            if dummy_h <= 0 or dummy_w <= 0:
                # Fail at construction time with a readable message instead of
                # building a Linear layer with a meaningless input size.
                raise ValueError(
                    f"Encoder conv layer {layer_cfg} reduces the spatial size to "
                    f"({dummy_h}, {dummy_w}); check kernel_size/stride/padding."
                )

        layers.append(nn.Flatten())
        self.conv_to_flatten_shape = (current_channels, dummy_h, dummy_w)
        flattened_size = current_channels * dummy_h * dummy_w

        current_features = flattened_size
        for layer_cfg in config.encoder_fc_configs:
            layers.append(nn.Linear(current_features, layer_cfg["out_features"]))
            current_features = layer_cfg["out_features"]
            if "activation" in layer_cfg:
                layers.append(TrainingConfig.get_activation(layer_cfg["activation"]))

        self.model = nn.Sequential(*layers)
        self.fc_z_mean = nn.Linear(current_features, config.latent_dim)
        self.fc_z_log_var = nn.Linear(current_features, config.latent_dim)

    @staticmethod
    def _as_pair(value) -> Tuple[int, int]:
        """Return ``value`` as an ``(h, w)`` pair; scalars are duplicated."""
        if isinstance(value, (tuple, list)):
            first, second = value
            return first, second
        return value, value

    @staticmethod
    def _conv_output_hw(in_h: int, in_w: int, layer_cfg: dict) -> Tuple[int, int]:
        """Spatial output size of one ``nn.Conv2d`` described by ``layer_cfg``.

        Mirrors the ``nn.Conv2d`` defaults (stride 1, padding 0, dilation 1)
        and accepts scalar or pair values as well as the ``"same"`` /
        ``"valid"`` padding strings.
        """
        kh, kw = Encoder._as_pair(layer_cfg["kernel_size"])
        sh, sw = Encoder._as_pair(layer_cfg.get("stride", 1))
        dh, dw = Encoder._as_pair(layer_cfg.get("dilation", 1))

        padding = layer_cfg.get("padding", 0)
        if isinstance(padding, str):
            if padding == "same":
                # nn.Conv2d only allows 'same' with stride 1; the output keeps
                # the input's spatial size.
                return in_h, in_w
            if padding != "valid":
                raise ValueError(f"Unsupported padding mode: {padding!r}")
            padding = 0
        ph, pw = Encoder._as_pair(padding)

        out_h = (in_h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
        out_w = (in_w + 2 * pw - dw * (kw - 1) - 1) // sw + 1
        return out_h, out_w

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x`` and draw a latent sample.

        Returns:
            ``(z_mean, z_log_var, z)`` where ``z = z_mean + eps * std`` with
            ``std = exp(0.5 * z_log_var)`` and ``eps`` drawn from a standard
            normal (reparameterization trick). A fresh ``eps`` is drawn on
            every call.
        """
        x_encoded = self.model(x)
        z_mean = self.fc_z_mean(x_encoded)
        z_log_var = self.fc_z_log_var(x_encoded)
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        z = z_mean + eps * std
        return z_mean, z_log_var, z

<h3> Decoder</h3>

In [15]:
class Decoder(nn.Module):
    """Decoder of the VAE: latent vector -> FC stack -> transposed convolutions.

    The fully connected part maps a latent vector of size ``config.latent_dim``
    to ``prod(encoder_conv_output_shape)`` features; the result is reshaped to
    ``encoder_conv_output_shape`` and passed through the ``nn.ConvTranspose2d``
    layers described by ``config.decoder_conv_transpose_configs``.
    """

    def __init__(self, config: TrainingConfig, encoder_conv_output_shape: Tuple[int, int, int]):
        """Build the layers described by ``config``.

        Args:
            config: Supplies ``latent_dim``, ``decoder_fc_configs`` (may be
                empty/None), ``encoder_fc_configs`` and
                ``decoder_conv_transpose_configs``. Each conv-transpose config
                is a dict of ``nn.ConvTranspose2d`` keyword arguments plus an
                optional ``"activation"`` entry.
            encoder_conv_output_shape: ``(channels, height, width)`` the FC
                output is reshaped to before the transposed convolutions.
        """
        super().__init__()
        self.config = config
        self.encoder_conv_output_shape = encoder_conv_output_shape
        # int(): hand nn.Linear a plain Python int rather than a NumPy scalar.
        decoder_start_features = int(np.prod(encoder_conv_output_shape))

        layers = []
        current_features = config.latent_dim
        # An empty or missing decoder_fc_configs leaves a single projection
        # from latent_dim straight to decoder_start_features.
        for layer_cfg in (config.decoder_fc_configs or []):
            layers.append(nn.Linear(current_features, layer_cfg["out_features"]))
            current_features = layer_cfg["out_features"]
            if "activation" in layer_cfg:
                layers.append(TrainingConfig.get_activation(layer_cfg["activation"]))
        layers.append(nn.Linear(current_features, decoder_start_features))

        # The activation after the final projection is taken from the LAST
        # ENCODER fc layer. 'activation' is optional everywhere else in the
        # configs, so skip it when absent instead of raising KeyError.
        # NOTE(review): reading the encoder config here looks like deliberate
        # mirroring — confirm it should not come from decoder_fc_configs.
        if config.encoder_fc_configs:
            last_encoder_fc_cfg = config.encoder_fc_configs[-1]
            if "activation" in last_encoder_fc_cfg:
                layers.append(TrainingConfig.get_activation(last_encoder_fc_cfg["activation"]))

        self.fc_part = nn.Sequential(*layers)

        conv_transpose_layers = []
        current_channels = self.encoder_conv_output_shape[0]
        for layer_cfg in config.decoder_conv_transpose_configs:
            # Everything except 'activation' is forwarded to nn.ConvTranspose2d.
            conv_params = {k: v for k, v in layer_cfg.items() if k != 'activation'}
            conv_transpose_layers.append(nn.ConvTranspose2d(current_channels, **conv_params))
            current_channels = layer_cfg["out_channels"]
            if "activation" in layer_cfg:
                conv_transpose_layers.append(TrainingConfig.get_activation(layer_cfg["activation"]))

        self.conv_transpose_part = nn.Sequential(*conv_transpose_layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors ``z`` into a reconstruction.

        The FC output is reshaped to ``(-1, *encoder_conv_output_shape)``
        before the transposed convolutions are applied.
        """
        x = self.fc_part(z)
        x = x.view(-1, *self.encoder_conv_output_shape)
        x_reconstructed = self.conv_transpose_part(x)
        return x_reconstructed


<h3> VAE </h3>

In [16]:
class VAE(nn.Module):
    """Variational autoencoder that chains an ``Encoder`` and a ``Decoder``.

    The decoder is constructed with the encoder's ``conv_to_flatten_shape`` so
    it can reshape its fully connected output before the transposed
    convolutions.
    """

    def __init__(self, config: TrainingConfig):
        """Create the encoder and a decoder matched to its conv output shape."""
        super().__init__()
        self.encoder = Encoder(config)
        conv_shape = self.encoder.conv_to_flatten_shape
        self.decoder = Decoder(config, conv_shape)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            ``(x_reconstructed, z_mean, z_log_var, z)``.
        """
        latent_mean, latent_log_var, latent_sample = self.encoder(x)
        decoded = self.decoder(latent_sample)
        return decoded, latent_mean, latent_log_var, latent_sample

    def loss_function(self, x_original: torch.Tensor, x_reconstructed: torch.Tensor, 
                      z_mean: torch.Tensor, z_log_var: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute the VAE loss: reconstruction term plus KL divergence.

        The reconstruction term is element-wise binary cross-entropy summed
        over dims 1-3 of each sample and averaged over the batch. The KL term
        is ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))`` over dim 1,
        averaged over the batch.

        NOTE(review): binary cross-entropy expects ``x_reconstructed`` (and
        targets) to be probabilities in [0, 1] — confirm the decoder's final
        activation and the data scaling guarantee this.

        Returns:
            Dict with keys ``"total_loss"``, ``"reconstruction_loss"`` and
            ``"kl_loss"``.
        """
        # Reconstruction: per-element BCE -> per-sample sum -> batch mean.
        elementwise_bce = nn.functional.binary_cross_entropy(
            x_reconstructed, x_original, reduction='none'
        )
        per_sample_bce = elementwise_bce.sum(dim=(1, 2, 3))
        reconstruction_loss = per_sample_bce.mean()

        # KL divergence to a standard normal, per sample, then batch mean.
        kl_terms = 1 + z_log_var - z_mean.pow(2) - z_log_var.exp()
        per_sample_kl = -0.5 * kl_terms.sum(dim=1)
        kl_loss = per_sample_kl.mean()

        total_loss = reconstruction_loss + kl_loss
        return {
            "total_loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }