<a href="https://colab.research.google.com/github/jhu-nanoenergy/VAE-models/blob/Thin-Model/VAE_Thin.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model Setup and Data Preparation

In [1]:
# Helpful tutorial / example links
# https://github.com/timbmg/VAE-CVAE-MNIST/blob/master/models.py
# https://debuggercafe.com/getting-started-with-variational-autoencoder-using-pytorch/

#Basic Packages
import matplotlib.pyplot as plt
import scipy.io as spio
import scipy.stats as stat
import pandas as pd
import numpy as np
import statsmodels
import os
import random
import h5py
import sys

#PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.utils.checkpoint import checkpoint 
from torch import autograd

#Ray Tune for hyperparameters
!pip install -q -U ray
from functools import partial
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

#Troubleshooting / Memory
!pip install -q -U torchinfo
from torchinfo import summary
import gc
os.environ['CUDA_LAUNCH_BLOCKING'] = "1" #leftover from debugging but generally useful to have for cuda device side assert errors
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#Plots
plt.style.use('ggplot')
!pip install -q -U seaborn

from google.colab import drive
drive.mount('/content/drive')

[K     |████████████████████████████████| 52.7 MB 164 kB/s 
[K     |████████████████████████████████| 225 kB 87.1 MB/s 
[K     |████████████████████████████████| 4.1 MB 70.0 MB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.[0m
[?25hMounted at /content/drive


In [2]:
# Import data from the HDF5 file on the mounted Drive.
fname_mask2 ='/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/Rockfish Training Data Gen/allData_thin.h5'
# The handle is intentionally left open: the dataset class built below reads its arrays from it.
hdf_file2 = h5py.File(fname_mask2, "r")
print(list(hdf_file2.keys()))

# Useful data sizes
tdn = hdf_file2["absp"].shape[0]  # total data number (rows of the 'absp' dataset)
hdn = tdn // 2  # half data number; integer division so it stays usable as a count/index
wavelengths = hdf_file2["lambda"][:]  # loaded fully into memory as a NumPy array
spec_points = wavelengths.shape[1]  # number of spectral sample points (axis 1 of 'lambda')
half_data_num = 500 # amount of data to use from EACH int / ext dataset

['absp', 'ext_in', 'height', 'height_bin', 'lambda', 'masks', 'refl', 'size_frac', 'tran']


In [3]:
# # DEFINE HYPERPARAMETERS # #
# Global settings read by the dataset loaders, the model constructors and the training cells.
# NOTE: alpha, beta, epochs, lr and w_d are assigned again in the later
# "Loss Paramaters (REPEAT)" cell; the later values are the ones in effect during training.

bsize = 3 # batch size, careful about making bigger because can cause cuda error 
fv = [256, 256, 512, 1024, 2048] # channel counts of the encoder convolutions (Encoder_thin reads fv[0]..fv[3])
fv_inv = [1024, 512, 256] # channel counts of the decoder transpose convolutions
ks = 3 # kernel size
feat_size = 512 # feature space size (width of the fully connected feature layers)
latent_features = 20 # dimensionality of latent space

# Loss parameters
alpha = 30 #how much to weight MSE loss
beta = 1/40 #how much to weight KLD

# Training parameters
epochs = 50 # number of epochs to train for
lr = 1e-2 # learning rate of SGD optimizer
w_d = 1e-5 # weight decay of SGD optimizer

In [4]:
#Different Dataloader for thin data
class ImageDataset_thin(Dataset):
    """Dataset of (mask image, spectra, height) samples read from an HDF5 file.

    Two equally sized slices are taken from every HDF5 dataset: rows
    ``[0, datanum)`` and rows ``[second_set_start, second_set_start + datanum)``.
    They are concatenated, so the resulting dataset holds ``2 * datanum`` samples.

    Args:
        hf: open HDF5 file object (or any mapping of key -> sliceable array)
            providing 'masks', 'absp', 'refl', 'tran', 'height', 'height_bin',
            'size_frac' and 'ext_in'.
        datanum: number of samples to take from EACH of the two slices.
        transform: callable applied to each mask image in ``__getitem__``;
            pass a falsy value to return the raw mask array.
        second_set_start: row index where the second slice begins
            (defaults to 25000, the offset previously hard-coded here).
    """

    def __init__(self, hf, datanum, transform= transforms.Compose([ transforms.ToTensor(), transforms.ConvertImageDtype(dtype=torch.float)]), second_set_start=25000):
        super().__init__()

        def both_sets(key):
            # Concatenate the first `datanum` rows of each of the two slices along axis 0.
            source = hf[key]
            return np.concatenate(
                (source[:datanum], source[second_set_start:second_set_start + datanum]),
                axis=0,
            )

        # Masks stay a NumPy array; the transform converts each one on access.
        self.masks = both_sets('masks')
        self.absp = torch.from_numpy(both_sets('absp'))
        self.refl = torch.from_numpy(both_sets('refl'))
        self.tran = torch.from_numpy(both_sets('tran'))
        # Stack the three spectra on a new axis 1 -> (samples, 3, spectral points).
        self.spectra = torch.stack([self.absp, self.refl, self.tran], dim = 1)
        # Heights are sliced like every other field so that index i refers to the
        # same sample in masks, spectra and heights (previously the first slice
        # took the whole 'height' dataset, misaligning the second set).
        self.heights = torch.from_numpy(both_sets('height'))
        self.height_bin = torch.from_numpy(both_sets('height_bin'))
        self.size_frac = torch.from_numpy(both_sets('size_frac'))
        self.label = torch.from_numpy(both_sets('ext_in'))
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of the concatenated 'size_frac' array)."""
        return len(self.size_frac)

    def __getitem__(self, idx):
        """Return ``(image, spectra, height)`` for sample ``idx``."""
        image = self.masks[idx,:,:] # input mask image
        spectra = self.spectra[idx,:]
        height = self.heights[idx,:]

        if self.transform:
            image = self.transform(image)

        return image, spectra, height

In [5]:
# # Data preparation
full_dataset2 = ImageDataset_thin(hdf_file2, half_data_num)
dummy=full_dataset2[0][1]  # spectra of the first sample, printed as a quick shape check
print(np.shape(dummy))


# Define ratios of train, validation and testing data (70 / 20 / 10).
# The test split takes the remainder so the three sizes always add up to the
# dataset length, which random_split requires; truncating each share
# independently can leave samples unassigned for lengths not divisible by 10.
n_samples2 = len(full_dataset2)
train_size2 = int(0.7 * n_samples2)
val_size2 = int(0.2 * n_samples2)
test_size2 = n_samples2 - train_size2 - val_size2

# Use random split with a fixed seed so the partition is reproducible
split_seed=42
print("Split Seed is:", split_seed)
data_train2, data_val2, data_test2 = torch.utils.data.random_split(full_dataset2, [train_size2, val_size2, test_size2], generator=torch.Generator().manual_seed(split_seed))

# Split data into batches; drop_last discards a final incomplete batch.
# Only the training loader is shuffled.
train_dataloader2 = DataLoader(data_train2, batch_size = bsize, shuffle=True, drop_last=True)
test_dataloader2 = DataLoader(data_test2, batch_size = bsize, drop_last=True)
valid_dataloader2 = DataLoader(data_val2, batch_size = bsize, drop_last=True)

torch.Size([3, 219])
Split Seed is: 42


# Model (Thin)

In [6]:
#Feature Extraction, Prediction, Recognition
class Encoder_thin(nn.Module):
  """Encoder: mask image -> (predicted spectra, latent mean, latent log-variance).

  Three sub-networks are chained in ``forward``:
    1. feature extraction (``enc1`` convolutions + ``enc2`` linear layer),
    2. prediction (``pred1``) producing ``3 * spec_points`` values squashed by a sigmoid,
    3. recognition (``rec1`` + ``fc_mean`` / ``fc_cov``) fed with the features
       concatenated with the predicted spectra.

  Args:
    fv: channel counts of the convolutions; entries 0..3 are used.
    fv_inv: decoder channel counts; stored but not used by the encoder.
    ks: convolution kernel size.
    feat_size: width of the fully connected feature layers.
    spec_points: number of points per spectrum.
    latent_dim: size of the latent mean / log-variance vectors. When omitted,
      the module-level ``latent_features`` hyperparameter is used, as before.
    image_size: side length of the square input image. The three 2x2
      max-pools reduce it by a factor of 8, so it should be a multiple of 8.
      Defaults to 256, which gives the previously hard-coded 32x32 map.
  """

  def __init__(self, fv, fv_inv, ks, feat_size, spec_points, latent_dim=None, image_size=256):
    super().__init__()

    if latent_dim is None:
      latent_dim = latent_features  # fall back to the notebook-level hyperparameter

    self.fv = fv #Feature Vectors
    self.fv_inv = fv_inv #Reconstruction Vectors
    self.ks = ks #Kernel Size
    self.feat_size = feat_size #Feature Space Size
    self.spec_points = spec_points #Number of wavelengths per spectrum
    self.latent_dim = latent_dim
    self.image_size = image_size

    # Side length of the feature map after the three MaxPool2d(2,2) stages.
    pooled_side = image_size // 8

    # Feature Extraction Network
    self.enc1 = nn.Sequential(
            # Conv_1
            nn.Conv2d(1, fv[0], kernel_size=ks, stride=1, padding="same"),
            nn.BatchNorm2d(fv[0]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),

            # Conv_2 + Pool_1
            nn.Conv2d(fv[0], fv[1], kernel_size=ks,  stride=1, padding="same"),
            nn.BatchNorm2d(fv[1]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.MaxPool2d(2,2),

            # Conv_3 + Pool_2
            nn.Conv2d(fv[1], fv[2], kernel_size=ks,  stride=1, padding="same"),
            nn.BatchNorm2d(fv[2]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.MaxPool2d(2,2),

            # Conv_4 + Pool_3
            nn.Conv2d(fv[2], fv[3], kernel_size=ks,  stride=1, padding="same"),
            nn.BatchNorm2d(fv[3]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.MaxPool2d(2,2),
        )

    #Feature Extraction Network to feature layer
    self.enc2 = nn.Sequential(
            nn.Linear(fv[3]*pooled_side*pooled_side, feat_size), #Reducing to Feature Size
            nn.BatchNorm1d(feat_size)
        )

    # Prediction network -- Reworked for separate spectra
    self.pred1 = nn.Sequential(
            nn.Linear(feat_size, feat_size),
            nn.BatchNorm1d(feat_size),
            nn.Linear(feat_size, feat_size),
            nn.BatchNorm1d(feat_size),
            nn.Linear(feat_size, spec_points*3)
        )

    # Recognition network
    self.rec1 =  nn.Sequential( nn.Linear(feat_size+3*spec_points, feat_size),
                               nn.BatchNorm1d(feat_size),
    )
    self.fc_mean = nn.Linear(feat_size, latent_dim)
    self.fc_cov = nn.Linear(feat_size, latent_dim)

  def forward(self, x):
    """Encode a batch of images of shape (batch, 1, image_size, image_size).

    Returns:
      p: predicted spectra, shape (batch, 3 * spec_points), values in [0, 1].
      mu: latent mean, shape (batch, latent_dim).
      log_var: latent log-variance, shape (batch, latent_dim).
    """
    # Run Feature Extraction Network
    features = self.enc1(x.float())
    features = torch.flatten(features, start_dim=1)  # keep the batch axis, flatten the rest
    features = self.enc2(features)

    # Run Prediction Network
    p = torch.sigmoid(self.pred1(features)) # p = predicted spectra

    # Run Recognition Network on the geometry features joined with the predicted spectra
    hidden = self.rec1(torch.cat((features, p), 1))
    mu = self.fc_mean(hidden)
    log_var = self.fc_cov(hidden)

    return p, mu, log_var

In [7]:
# Optional smoke test for Encoder_thin. Left commented out like the other
# model smoke tests; uncomment the whole cell to run it. (The final print was
# previously live and raised NameError on a fresh kernel, because the lines
# that define log_var are commented out.)
# x = torch.randn(bsize, 1, 256, 256)
# net = Encoder_thin(fv, fv_inv, ks, feat_size, spec_points)
# p, mu , log_var = net(x)

# print(p.shape)
# print(mu.shape)
# print(log_var.shape)

torch.Size([3, 657])
torch.Size([3, 20])
torch.Size([3, 20])


In [8]:
class Decoder_thin(nn.Module):
  """Decoder: (spectra, latent sample) -> reconstructed mask image.

  ``recon1`` maps the concatenated input through fully connected layers to a
  ``fv_inv[0]``-channel feature map, and ``recon2`` upsamples it with three
  stride-2 transpose convolutions (8x overall) to a single-channel image that
  is squashed by a sigmoid.

  Args:
    fv: encoder channel counts; stored but not used by the decoder.
    fv_inv: channel counts of the transpose convolutions (three entries used).
    ks: kernel size; stored only, the transpose convolutions use a fixed 3.
    feat_size: width of the fully connected layers.
    spec_points: number of points per spectrum (input carries 3 * spec_points).
    latent_dim: size of the latent vector. When omitted, the module-level
      ``latent_features`` hyperparameter is used, as before.
    image_size: side length of the square output image; should be a multiple
      of 8. Defaults to 256, which gives the previously hard-coded 32x32 map.
  """

  def __init__(self, fv, fv_inv, ks, feat_size, spec_points, latent_dim=None, image_size=256):
    super().__init__()

    if latent_dim is None:
      latent_dim = latent_features  # fall back to the notebook-level hyperparameter

    self.fv = fv #Feature Vectors
    self.fv_inv = fv_inv #Reconstruction Vectors
    self.ks = ks #Kernel Size
    self.feat_size = feat_size #Feature Space Size
    self.spec_points = spec_points #Number of wavelengths per spectrum
    self.latent_dim = latent_dim
    self.image_size = image_size
    # Side length of the map fed to the transpose convolutions (each doubles it).
    self.base_side = image_size // 8

    self.recon1 = nn.Sequential(
          #Fc_4 Fc_5 Fc_6
          nn.Linear(latent_dim+3*spec_points, feat_size),
          nn.BatchNorm1d(feat_size),
          nn.Linear(feat_size, feat_size),
          nn.BatchNorm1d(feat_size),
          nn.Linear(feat_size, self.base_side*self.base_side*fv_inv[0]),
          nn.BatchNorm1d(self.base_side*self.base_side*fv_inv[0]),
        )

    self.recon2 = nn.Sequential(
          nn.ConvTranspose2d(fv_inv[0], fv_inv[1], 3, stride=2, padding = 1, output_padding=1),
          nn.ConvTranspose2d(fv_inv[1], fv_inv[2], 3, stride=2, padding = 1, output_padding=1),
          nn.ConvTranspose2d(fv_inv[2], 1, 3, stride=2, padding = 1, output_padding=1)
        )


  def forward(self, spectra, latent):
    """Reconstruct images from spectra (batch, 3*spec_points) and latent (batch, latent_dim).

    Returns a tensor of shape (batch, 1, image_size, image_size) with values in [0, 1].
    """
    x = torch.cat((spectra, latent), 1)
    x = self.recon1(x)
    # Use the channel count this instance was built with, not the notebook
    # global, so the reshape always matches the layers created in __init__.
    x = x.view(-1, self.fv_inv[0], self.base_side, self.base_side)
    x = self.recon2(x)
    reconstruction =  torch.sigmoid(x)
    return reconstruction

In [9]:
# spectra = torch.randn(bsize, 3*spec_points)
# latent = torch.randn(bsize, 20)

# net = Decoder_thin(fv, fv_inv, ks, feat_size, spec_points)
# x = net(spectra, latent)
# print(x.shape)

torch.Size([3, 1, 256, 256])


In [10]:
class CustomVAE_thin(nn.Module):
    """Variational autoencoder that wires Encoder_thin and Decoder_thin together."""

    def __init__(self, fv, fv_inv, ks, feat_size, spec_points):
        super().__init__()

        # Both halves are built from the same hyperparameters.
        shared_args = (fv, fv_inv, ks, feat_size, spec_points)
        self.encoder = Encoder_thin(*shared_args)
        self.decoder = Decoder_thin(*shared_args)

    def reparameterize(self, mu, log_var):
        """
        Draw a latent sample as mean plus noise scaled by the standard deviation.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5*log_var)
        noise = torch.randn_like(std)
        return mu + (noise * std)

    def forward(self, x):
        """Return (reconstruction, mu, log_var, predicted spectra) for input batch x."""
        # Encode: predicted spectra plus the latent distribution parameters
        predicted_spectra, mu, log_var = self.encoder(x)

        # Sample the latent vector, then reconstruct the image from spectra + sample
        latent_sample = self.reparameterize(mu, log_var)
        recon_x = self.decoder(predicted_spectra, latent_sample)

        return recon_x, mu, log_var, predicted_spectra

In [11]:
# x = torch.randn(bsize, 1, 256, 256)
# net = CustomVAE_thin(fv, fv_inv, ks, feat_size, spec_points)
# recon_x, mu, log_var, spectra = net(x)

# print(recon_x.shape)
# print(mu.shape)
# print(log_var.shape)
# print(spectra.shape)

# net = None
# del net
# gc.collect()

torch.Size([3, 1, 256, 256])
torch.Size([3, 20])
torch.Size([3, 20])
torch.Size([3, 657])


11

# Loss, Fit, Validation (Thin)

In [12]:
# Loss parameters (REPEAT)
# NOTE: this cell re-assigns names already set in the earlier hyperparameter cell.
# The values below are the ones in effect for the loss and training cells that follow.
alpha = 2 #how much to weight MSE loss (the earlier cell set 30)
beta = 1/40 #how much to weight KLD

# Training parameters
epochs = 10 # number of epochs to train for (the earlier cell set 50)
lr = 1e-2 # learning rate of SGD optimizer
w_d = 1e-5 # weight decay of SGD optimizer

In [13]:
# combined loss function that guides the entire network's training
def final_loss(loss1_bce, loss2_mse, mu, log_var, mse_weight=None, kld_weight=None):
    """Combine reconstruction, spectra and KL-divergence losses into one scalar.

    total = loss1_bce + mse_weight * loss2_mse + kld_weight * KLD

    Args:
        loss1_bce: reconstruction (BCE) loss of the mask image.
        loss2_mse: MSE loss of the predicted spectra.
        mu: mean of the latent distribution.
        log_var: log variance of the latent distribution.
        mse_weight: weight of the MSE term; defaults to the module-level ``alpha``.
        kld_weight: weight of the KL term; defaults to the module-level ``beta``.

    Returns:
        The weighted sum as a scalar tensor.
    """
    if mse_weight is None:
        mse_weight = alpha
    if kld_weight is None:
        kld_weight = beta

    #KLD From:
    #https://stats.stackexchange.com/questions/318748/deriving-the-kl-divergence-loss-for-vaes/370048#370048
    #https://www.geeksforgeeks.org/role-of-kl-divergence-in-variational-autoencoders/
    # KL( N(mu, sigma^2) || N(0, 1) ) = 0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1),
    # averaged over all elements. The exp(log_var) term must enter with a PLUS
    # sign: with a minus sign the loss is unbounded below as log_var grows.
    KLD = 0.5*torch.mean(mu.pow(2) + log_var.exp() - log_var - 1)

    return (loss1_bce + mse_weight*loss2_mse + kld_weight*KLD)

In [19]:
#Fit Function
def fit(model, dataloader, detect_anomaly=False):
    """Train ``model`` for one pass over ``dataloader``.

    Relies on the notebook-level ``optimizer``, ``criterion_mask``,
    ``criterion``, ``final_loss`` and ``spec_points``.

    Args:
        model: network returning (reconstruction, mu, logvar, spectra).
        dataloader: yields (image, spectra, height) batches.
        detect_anomaly: enable autograd anomaly detection. It is a debugging
            aid (e.g. for "CUDA: device-side assert" or NaN gradients) and
            slows training, so it is off unless requested.

    Returns:
        (train_loss, bce_losses, mse_losses): the per-batch total, BCE and MSE
        losses summed over all batches of the epoch.
    """
    model.train()
    # Run on whatever device the model lives on, so CPU-only runtimes train
    # too instead of silently skipping every batch.
    model_device = next(model.parameters()).device

    running_loss = 0.0
    bce_losses = 0.0
    mse_losses = 0.0
    # len(dataloader) is already the number of batches.
    for batch in tqdm(dataloader, total=len(dataloader)):
        data, spectra_in, _height = [d.to(model_device) for d in batch] # load in data from data loader

        with autograd.set_detect_anomaly(detect_anomaly):
            optimizer.zero_grad() # initialize gradients to zero

            reconstruction, mu, logvar, out_spectra = model(data) # run model
            pout = out_spectra.view(-1, 3, spec_points)  # one row per spectrum, matching spectra_in

            # solve for loss
            bce_loss = criterion_mask(reconstruction, data)
            mse_loss = criterion(spectra_in.float(), pout)
            loss = final_loss(bce_loss, mse_loss, mu, logvar)

            # add losses to overall loss "counters"
            running_loss += loss.item()
            bce_losses += bce_loss.item()
            mse_losses += mse_loss.item()

            # backpropagate loss and then step the optimizer
            loss.backward()
            optimizer.step()

    print(f"Train BCE Loss: {bce_losses:.4f}")
    print(f"Train MSE Loss: {mse_losses:.4f}")
    train_loss = running_loss
    return train_loss, bce_losses, mse_losses

In [15]:
#Validation Function
def _plot_validation_batch(data, reconstruction, spectra_in, pout, num_replicas=4):
    """Plot and save input vs. reconstructed masks and spectra for one batch.

    Reads the notebook-level ``epoch`` (for titles and file names),
    ``wavelengths`` and ``spec_points``.
    """
    # Never ask for more columns than the batch holds (batch size may be < num_replicas).
    n_show = min(num_replicas, data.shape[0])

    # first plot shows input geometry vs output geometry
    fig, axs = plt.subplots(2, n_show, squeeze=False)  # squeeze=False keeps axs 2-D even for one column
    for col in range(n_show):
        for row, images in enumerate((data, reconstruction)):
            axs[row, col].imshow(torch.squeeze(images[col]).cpu())
            axs[row, col].xaxis.set_visible(False)
            axs[row, col].yaxis.set_visible(False)
    fig.suptitle(str(epoch+1))

    # second plot shows input spectra vs output spectra
    og = spectra_in[0].detach().cpu().numpy()
    pred = pout[0].detach().cpu().numpy()
    # spec_points is the length of axis 1 of `wavelengths`; take one row as a
    # 1-D x-axis so it lines up with the (spec_points, 3) curves below.
    # NOTE(review): assumes every row of `wavelengths` holds the same grid - confirm.
    wavelength_axis = np.asarray(wavelengths).reshape(-1, spec_points)[0]

    fig2, axs2 = plt.subplots(2,1)
    axs2[0].plot(wavelength_axis, np.transpose(og) )
    axs2[1].plot(wavelength_axis, np.transpose(pred) )
    fig2.suptitle(str(epoch+1))

    fig.savefig(f"/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/ML Model/VAE/outputs/{epoch+1}geom_output.png")
    fig2.savefig(f"/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/ML Model/VAE/outputs/{epoch+1}spectra_output.png")
    plt.show()
    # Release the figures so repeated validation runs do not accumulate them.
    plt.close(fig)
    plt.close(fig2)


def validate(model, dataloader, plot_on):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Relies on the notebook-level ``criterion_mask``, ``criterion``,
    ``final_loss`` and ``spec_points``; plotting additionally reads ``epoch``
    and ``wavelengths``.

    Args:
        model: network returning (reconstruction, mu, logvar, spectra).
        dataloader: yields (image, spectra, height) batches.
        plot_on: truthy to plot and save the last batch's inputs and outputs.

    Returns:
        The combined loss summed over all batches.
    """
    model.eval()
    model_device = next(model.parameters()).device  # works on GPU and CPU runtimes alike
    n_batches = len(dataloader)  # already counts batches, not samples
    running_loss = 0.0
    with torch.no_grad(): # all gradients off because currently just evaluating the model
        for i, batch in tqdm(enumerate(dataloader), total=n_batches):
            # The image batch keeps its (batch, 1, H, W) layout: the model's
            # convolutions need the channel axis, and the BCE target must have
            # the same shape as the reconstruction.
            data, spectra_in, _height = [d.to(model_device) for d in batch] # load in data from data loader

            reconstruction, mu, logvar, out_spectra = model(data) # run model
            pout = out_spectra.view(-1, 3, spec_points) #reformat spectra for plotting

            # solve for loss
            bce_loss = criterion_mask(reconstruction, data)
            mse_loss = criterion(spectra_in.float(), pout)
            loss = final_loss(bce_loss, mse_loss, mu, logvar)
            running_loss += loss.item()

            # save and plot the last batch input and output
            if plot_on and i == n_batches - 1:
                _plot_validation_batch(data, reconstruction, spectra_in, pout)

    # val_loss = running_loss/len(dataloader.dataset)
    val_loss = running_loss
    return val_loss

# Training (Thin)

In [16]:
# Build the VAE and move it to the GPU when one is available.
model_custom = CustomVAE_thin(fv, fv_inv, ks, feat_size, spec_points)
if torch.cuda.is_available():
    model_custom.cuda()

#print(summary(model_custom, input_size = (bsize, 1, 256, 256))) # print model summary
# w_d is the weight decay declared with the training hyperparameters; it has to
# be passed here to take effect (it was previously defined but never used).
optimizer = optim.SGD(model_custom.parameters(), lr = lr, weight_decay = w_d)
criterion_mask = nn.BCELoss(reduction='mean') #BCE on the reconstructed mask
criterion = nn.MSELoss(reduction='mean') #MSE on the predicted spectra

In [17]:
# check that cuda is at acceptable limits
# Guarded so the notebook still runs top to bottom on a CPU-only runtime,
# where the torch.cuda calls below would raise.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
else:
    print("CUDA is not available; skipping GPU memory summary.")

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    4171 MB |    4171 MB |    4171 MB |       0 B  |
|       from large pool |    4168 MB |    4168 MB |    4168 MB |       0 B  |
|       from small pool |       3 MB |       3 MB |       3 MB |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |    4171 MB |    4171 MB |    4171 MB |       0 B  |
|       from large pool |    4168 MB |    4168 MB |    4168 MB |       0 B  |
|       from small pool |       3 MB |       3 MB |       3 MB |       0 B  |
|---------------------------------------------------------------

In [20]:
# # Loop over epochs

# keep track of different losses (one entry per epoch)
train_loss = []
val_loss = []
train_bce_loss = []
train_mse_loss = []
train_kld_loss=[]
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")

    # empty cuda cache to help prevent unneeded memory usage
    # (skipped on CPU-only runtimes, where these calls would raise)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    # fit data
    train_epoch_loss, train_epoch_bce, train_epoch_mse = fit(model_custom, train_dataloader2)

    # test on validation data; plot output on the first epoch and every 5th epoch
    plot_this_epoch = epoch == 0 or (epoch + 1) % 5 == 0
    val_epoch_loss = validate(model_custom, valid_dataloader2, plot_this_epoch)

    # add to variables
    train_loss.append(train_epoch_loss)
    train_bce_loss.append(train_epoch_bce)
    train_mse_loss.append(alpha*train_epoch_mse)
    # whatever remains of the total after removing BCE and weighted MSE is the KLD contribution
    train_kld_loss.append((train_epoch_loss - train_epoch_bce-alpha*train_epoch_mse))
    val_loss.append(val_epoch_loss)

    # print progress
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

Epoch 1 of 10


RuntimeError: ignored

In [None]:
# Plot different train losses on a 2x2 grid, one panel per loss component
x = range(len(train_loss))
fig_losses, axs = plt.subplots(2,2)

# (grid position, per-epoch series, y-axis label) for each panel
loss_panels = [
    ((0, 0), train_loss, "Total"),
    ((0, 1), train_kld_loss, "beta*KLD"),
    ((1, 0), train_bce_loss, "BCE"),
    ((1, 1), train_mse_loss, "alpha*MSE"),
]
for (row, col), series, y_label in loss_panels:
    panel = axs[row, col]
    panel.plot(x, series)
    panel.set_xlabel("Epochs")
    panel.set_ylabel(y_label)

fig_losses.suptitle("Train Losses")

# save figures
fig_losses.savefig(f"/content/drive/MyDrive/Thon Group Master Folder/Serene/Spectral Selectivity Project/outputs/train_losses_output.png")

In [None]:
# Plot total validation loss per epoch (x is the epoch range defined in the train-loss plotting cell)
fig_val = plt.figure()
ax_val = fig_val.gca()
ax_val.plot(x, val_loss)
ax_val.set_xlabel("Epochs")
ax_val.set_ylabel("Combined Validation Loss")