# Model—Machine Learning Surrogate for MuSHrooM 

## Introduction
### Code Structure / Outline
1) Train CVAE Model
2) Train DNN (Using Generated Points in Latent Space)
3) Define and Use GAIT Model

### Model Architecture

- Encoder $\mathcal{E}$ (CNN)
    - Conv2D Layer with kernel size 3x3 [In: `1x600x600` Out: `?x600x600`]
    - ?Conv2D Layer with kernel size 3x3 [In: `1x600x600` Out: `1x600x600`]
    - ?Conv2D Layer with kernel size 3x3 [In: `1x600x600` Out: `1x600x600`]
    - ?Conv2D Layer with kernel size 3x3 [In: `1x600x600` Out: `1x600x600`]
    - ?Conv2D Layer with kernel size 3x3 [In: `1x600x600` Out: `1x600x600`]
    - ?Flattening Layer [In: `1x600x600` Out: `1x600x600`]
    - ?FC Layers
        - ?Mean Vector ($\mu$) Layer [In: `?` Out: `?`] — center of latent distribution
        - ?Log-variance vector ($\log{(\sigma^2)}$) — spread of latent distribution
- Latent Space (Sampling)
    - Latent vector $z$
$$z = \mu + e^{\frac{1}{2}\log{(\sigma^2)}} \cdot \epsilon \quad \text{where } \epsilon \text{ is randomly drawn from a standard normal distribution, } \epsilon \sim \mathcal{N}(0, I)$$
- Deep Neural Network (three hidden layers)
- Decoder $\mathcal{D}$ (CNN) (mirror of encoder - TBD)

In [None]:
# general imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# data loading and processing
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

## 1 CVAE Model

Set hyperparameters

In [None]:
# Optimisation settings for the CVAE
lr_cvae = 1e-3        # learning rate handed to the optimizer
num_epochs_cvae = 10  # number of training epochs (was 5000)


# Loss-term weights.
# NOTE(review): presumably these scale the phi-reconstruction, gradient and
# KL terms respectively; nothing in the visible cells reads them yet -- confirm.
w_phi, w_grad, w_kl = 1.0, 1.0, 1.0

Define Model Components

In [None]:
# Encoder Network
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of the latent Gaussian.

    Three stride-2 convolutions halve the spatial size each time; the flattened
    feature map then feeds two parallel linear heads. The linear heads expect
    128 * 3 * 3 features, i.e. the layer sizes are fixed for single-channel
    28x28 inputs.

    Args:
        latent_dim: Size of the latent vector (width of both output heads).
    """

    def __init__(self, latent_dim):
        super().__init__()
        flat_features = 128 * 3 * 3  # channels * height * width after conv3

        # Each conv: kernel 4, stride 2, padding 1 -> spatial size is halved.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)    # 28 -> 14
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)   # 14 -> 7
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # 7 -> 3

        # Two heads sharing the same flattened features.
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for a batch of images ``x``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))

        # Collapse everything except the batch dimension.
        batch_size = features.size(0)
        flat = features.view(batch_size, -1)

        return self.fc_mu(flat), self.fc_logvar(flat)

In [None]:
# Decoder Network
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to a 1x28x28 image.

    A linear layer expands the latent vector to a 128x3x3 feature map, which is
    upsampled by three transposed convolutions. The final sigmoid keeps every
    output value in (0, 1), as required by the binary cross-entropy loss.

    Args:
        latent_dim: Size of the latent vector fed to the decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 3 * 3)
        # A transposed conv with kernel 4 / stride 2 / padding 1 exactly doubles
        # the spatial size (3 -> 6), which would end at 24x24 instead of 28x28.
        # The encoder's 7 -> 3 step drops a pixel (7 // 2), so the mirror stage
        # needs output_padding=1 to restore it: (3 - 1) * 2 - 2 + 3 + 1 + 1 = 7.
        self.deconv1 = nn.ConvTranspose2d(
            128, 64, kernel_size=4, stride=2, padding=1, output_padding=1
        )  # 3x3 -> 7x7
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)   # 7x7 -> 14x14
        self.deconv3 = nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1)    # 14x14 -> 28x28

    def forward(self, z):
        """Decode latent vectors ``z`` of shape (batch, latent_dim) into images.

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in (0, 1).
        """
        z = F.relu(self.fc(z))
        z = z.view(z.size(0), 128, 3, 3)  # Reshape to (batch_size, 128, 3, 3)
        z = F.relu(self.deconv1(z))
        z = F.relu(self.deconv2(z))
        return torch.sigmoid(self.deconv3(z))

In [None]:
class CVAE(nn.Module):
    """Variational autoencoder built from the ``Encoder`` and ``Decoder`` modules.

    Args:
        latent_dim: Size of the latent vector shared by encoder and decoder.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * noise`` with ``noise`` from a standard normal.

        Sampling through this expression keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()      # log-variance -> standard deviation
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

Define Loss Function

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Return the summed VAE loss: reconstruction error plus KL divergence.

    Args:
        recon_x: Reconstruction with values in [0, 1], same shape as ``x``.
        x: Target tensor with values in [0, 1].
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor; both terms are summed (not averaged) over all elements.
    """
    # Reconstruction term (binary cross-entropy) - might need to change to L2 norm
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # TODO add recon_loss for divergence

    # KL term: pulls the latent distribution towards a standard normal
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * torch.sum(kl_terms)

    return bce + kl

Initialize Model and Optimizer

In [None]:
# CVAE with a 20-dimensional latent space, trained with Adam at lr_cvae
model = CVAE(latent_dim=20)
optimizer = optim.Adam(params=model.parameters(), lr=lr_cvae)


Train Model

In [None]:
# Training loop: one full pass over train_loader per epoch.
# NOTE(review): train_loader is not defined in the visible cells -- it must be
# created (yielding (data, label) pairs) before this cell runs.
for epoch in range(num_epochs_cvae):
    model.train()
    train_loss = 0.0
    for data, _ in train_loader:  # labels are not used by the CVAE
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # loss_function sums over the batch, so divide by the number of samples
    # (not the number of batches) to report a per-sample loss.
    avg_loss = train_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs_cvae}, Loss: {avg_loss:.4f}")

## 2 DNN Model

## 3 GAIT Model