<a href="https://colab.research.google.com/github/jhu-nanoenergy/VAE-models/blob/main/AE_framework_draft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model Setup and Data Preparation

In [1]:
# Helpful tutorial / example links
# https://github.com/timbmg/VAE-CVAE-MNIST/blob/master/models.py
# https://debuggercafe.com/getting-started-with-variational-autoencoder-using-pytorch/

In [2]:
import matplotlib.pyplot as plt
import scipy.io as spio
import scipy.stats as stat
import pandas as pd
import numpy as np
import os
import random
import gc
import h5py

os.environ['CUDA_LAUNCH_BLOCKING'] = "1" #leftover from debugging but generally useful to have for cuda device side assert errors


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms

!pip install -q -U ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
import argparse
from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.utils.checkpoint import checkpoint 
from torch import autograd

!pip install -q -U torchinfo
from torchinfo import summary


plt.style.use('ggplot')

[K     |████████████████████████████████| 52.7 MB 1.2 MB/s 
[K     |████████████████████████████████| 225 kB 79.3 MB/s 
[K     |████████████████████████████████| 4.1 MB 44.3 MB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.[0m
[?25h

In [3]:
# Mount Google Drive so the /content/drive/MyDrive/... data and output paths used below resolve.
# force_remount=True asks Colab to remount even if the drive is already mounted.
from google.colab import drive
drive.mount("/content/drive/", force_remount=True)

Mounted at /content/drive/


In [4]:
# # DEFINE HYPERPARAMETERS # #
# Reproducibility: seed every RNG the notebook relies on (Python `random`, NumPy,
# and torch -- torch.manual_seed also seeds the CUDA generators) so weight
# initialisation, DataLoader shuffling and the reparameterisation noise repeat
# across Restart & Run All. The train/val/test split uses its own `split_seed`.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# for defining data
half_data_num = 500 # number of samples to take from EACH of the int / ext datasets
bsize = 6 # batch size; be careful making it bigger, larger values have caused CUDA errors

# for defining network
fv = [256, 256, 512, 1024, 2048] # output channels of the encoder's Conv2d layers (fv[4] is not referenced in the visible code)
fv_inv = [1024, 512, 256] # channels of the decoder's ConvTranspose2d layers
ks = 3 # convolution kernel size
feat_size = 512 # size of the feature vector produced by the feature extraction network
latent_features = 20 # dimensionality of latent space

# for defining loss
alpha = 2 # weight of the spectra MSE loss relative to the mask BCE loss

# for defining training
epochs = 10 # number of epochs to train for
lr = 1e-6 # learning rate of SGD optimizer
w_d = 1e-5 # weight decay of SGD optimizer

In [5]:
int_data_all = spio.loadmat('/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/Rockfish Training Data Gen/int_total_sqr_no_struct.mat', squeeze_me=True)
# ext_data_all = spio.loadmat('/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/Rockfish Training Data Gen/ext_total_sqr_no_struct.mat', squeeze_me=True)

# Wavelength grid for plotting spectra: drop the entries equal to 0.5 and 1,
# then stack three identical copies as columns -> (spec_points, 3), presumably
# one column per spectrum channel (the encoder predicts 3*spec_points values).
wavelengths = int_data_all['lambda']
wavelengths = wavelengths[~np.isin(wavelengths, [0.5, 1])]
wavelengths = np.tile(wavelengths, (3, 1)).T

spec_points = wavelengths.shape[0] # number of points in each spectrum

In [6]:
# Open the combined HDF5 file read-only. It is left open on purpose:
# ImageDataset reads from `hdf_file` when the dataset is built further down.
fname_mask = '/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/Rockfish Training Data Gen/allData.h5'
hdf_file = h5py.File(fname_mask, "r")

# Quick look at the layout: list the top-level groups, then stack the first
# 11 ext masks in front of the first 11 int masks and print the shape.
print(list(hdf_file.keys()))
# NOTE: despite the names, these are the mask datasets ('maskCell'), not spectra.
dext_spectra = hdf_file['ext/maskCell']
dint_spectra = hdf_file['int/maskCell']
combined = np.concatenate([dext_spectra[:11], dint_spectra[:11]], axis=0)
print(combined.shape)

['ext', 'int']
(22, 256, 256)


In [7]:
class ImageDataset(Dataset):
    """Mask-image / spectra dataset built from the combined ``allData.h5`` file.

    The first ``datanum`` samples of the ``ext`` group are stacked in front of the
    first ``datanum`` samples of the ``int`` group, so the leading block of indices
    are ext samples and the remainder are int samples.

    :param hf: open h5py file exposing ``{ext,int}/{height,spectCell,maskCell,size}``
    :param datanum: number of samples to take from EACH of the ext and int groups
    :param transform: applied to each mask image in ``__getitem__``; the default
        converts it to a half-precision tensor (ToTensor + ConvertImageDtype)
    """

    def __init__(self, hf, datanum, transform=transforms.Compose([transforms.ToTensor(), transforms.ConvertImageDtype(dtype=torch.half)])):
        super().__init__()

        # Slice each group once so every per-sample array has the same length.
        ext_height = hf['ext/height'][:datanum]
        int_height = hf['int/height'][:datanum]
        ext_spectra = hf['ext/spectCell'][:datanum]
        int_spectra = hf['int/spectCell'][:datanum]

        self.spectra = torch.from_numpy(np.concatenate((ext_spectra, int_spectra), axis=0))
        self.heights = torch.from_numpy(np.concatenate((ext_height, int_height), axis=0))
        # Masks stay a NumPy array; `transform` converts them per item.
        self.masks = np.concatenate((hf['ext/maskCell'][:datanum], hf['int/maskCell'][:datanum]), axis=0)
        # could switch to Sreyas' size calculation instead, not sure
        self.sizes = torch.from_numpy(np.concatenate((hf['ext/size'][:datanum], hf['int/size'][:datanum]), axis=0))

        # Labels: 1 for ext samples, 0 for int samples. Shapes come from the
        # *sliced* heights so the labels line up one-to-one with masks/spectra
        # (the full HDF5 dataset shapes would make them longer and misaligned).
        ext_label = np.ones(np.shape(ext_height))
        int_label = np.zeros(np.shape(int_height))
        self.labels = torch.from_numpy(np.concatenate((ext_label, int_label), axis=0))
        self.transform = transform

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        """Return ``(image, spectra)`` for sample ``idx``; the mask image goes through ``transform`` if set."""
        image = self.masks[idx, :, :]
        spectra = self.spectra[idx]
        if self.transform:
            image = self.transform(image)
        return image, spectra

In [8]:
# Inspect the "thin" HDF5 layout: a single flat set of datasets. The 25000
# offset below presumably marks where the second (int) half starts -- the
# 'ext_in' dataset should confirm this.
fname_mask2 ='/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/Rockfish Training Data Gen/allData_thin.h5'
hdf_file2 = h5py.File(fname_mask2, "r")
print(list(hdf_file2.keys()))
combined2 = np.concatenate(
    (hdf_file2['masks'][:11], hdf_file2['masks'][25000:25011]),
    axis=0,
)
print(hdf_file2['masks'].shape)

['absp', 'ext_in', 'height', 'height_bin', 'lambda', 'masks', 'refl', 'size_frac', 'tran']
(50000, 256, 256)


In [9]:
# Dataset variant for the "thin" HDF5 layout (allData_thin.h5)

class ImageDataset_thin(Dataset):
    """Mask-image / spectra dataset built from the flat ``allData_thin.h5`` layout.

    Every dataset in this file holds both halves of the data in one array; the
    second half starts at ``split_index``. ``datanum`` samples are taken from the
    start of EACH half and stacked along axis 0.

    The constructor slices every dataset along axis 0 (as it does for ``masks``),
    so ``__getitem__`` indexes axis 0 as well.
    NOTE(review): an earlier draft indexed ``[:, idx]``; confirm every HDF5 dataset
    is stored sample-major before relying on the spectra/height outputs.

    :param hf: open h5py file with datasets ``masks, absp, refl, tran, height,
        height_bin, size_frac, ext_in``
    :param datanum: number of samples to take from EACH half
    :param transform: applied to each mask image in ``__getitem__``
    :param split_index: index where the second half of every dataset starts
    :raises ValueError: if ``datanum`` exceeds ``split_index`` (halves would overlap)
    """

    def __init__(self, hf, datanum,
                 transform=transforms.Compose([transforms.ToTensor(), transforms.ConvertImageDtype(dtype=torch.half)]),
                 split_index=25000):
        super().__init__()
        if datanum > split_index:
            raise ValueError(f"datanum ({datanum}) exceeds split_index ({split_index}); the two halves would overlap")

        def take(key):
            """First ``datanum`` rows of each half of ``hf[key]``, stacked along axis 0."""
            dset = hf[key]
            return np.concatenate((dset[:datanum], dset[split_index:split_index + datanum]), axis=0)

        self.masks = take('masks')  # kept as NumPy; `transform` converts per item
        self.absp = torch.from_numpy(take('absp'))
        self.refl = torch.from_numpy(take('refl'))
        self.tran = torch.from_numpy(take('tran'))
        self.heights = torch.from_numpy(take('height'))
        self.height_bin = torch.from_numpy(take('height_bin'))
        self.size_frac = torch.from_numpy(take('size_frac'))
        self.label = torch.from_numpy(take('ext_in'))
        self.transform = transform

    def __len__(self):
        return len(self.masks)

    def __getitem__(self, idx):
        """Return ``(image, absp, refl, tran, height)`` for sample ``idx``."""
        image = self.masks[idx, :, :]
        if self.transform:
            image = self.transform(image)
        return image, self.absp[idx], self.refl[idx], self.tran[idx], self.heights[idx]

In [10]:
# # Data preparation

full_dataset = ImageDataset(hdf_file, half_data_num)

# 70 / 20 / 10 train / validation / test split. The test split takes whatever is
# left so the three sizes always sum to len(full_dataset): random_split raises
# otherwise, and plain int() truncation can drop a sample
# (e.g. 666 samples -> 466 + 133 + 66 = 665).
train_size = int(0.7 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# Seeded generator so the split is identical on every run
split_seed = 42
print("Split Seed is:", split_seed)
data_train, data_val, data_test = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Batch the splits; only the training set is shuffled
train_dataloader = DataLoader(data_train, batch_size=bsize, shuffle=True)
test_dataloader = DataLoader(data_test, batch_size=bsize)
valid_dataloader = DataLoader(data_val, batch_size=bsize)

# clear some unnecessary variables
#del hdf_file, int_data_all
gc.collect()

Split Seed is: 42


100



```
 
```

# "Encoder" (combined Feature Extraction Network, Prediction Network and Recognition Network)

In [11]:
from re import X
class Encoder(nn.Module):
  """Feature extraction + prediction + recognition networks (the VAE "encoder").

  ``forward(x)`` takes a batch of 256x256 mask images (any tensor whose numel is a
  multiple of 256*256; it is viewed as (B, 1, 256, 256)) and returns
  ``(p, mu, log_var)``:
    * ``p``: predicted spectra after a sigmoid, shape (B, 1, 1, 3*spec_points)
    * ``mu`` / ``log_var``: latent Gaussian parameters, shape (B, 1, 1, latent_features)

  Reads the notebook-level hyperparameters ``fv``, ``ks``, ``feat_size``,
  ``spec_points`` and ``latent_features`` at construction time.

  NOTE(review): enc2 / pred1 / rec1 contain no activation functions, so each is
  an affine map (Linear + BatchNorm) -- confirm this matches the intended design.
  """

  def __init__(self):
    super().__init__()

    # Spatial size after the three MaxPool2d(2, 2) stages in enc1 (256 -> 32).
    self._grid = 256 // 2**3
    # Flattened conv output size; follows fv[3] instead of a hard-coded 1024.
    flat_features = fv[3] * self._grid * self._grid

    # Feature Extraction Network
    self.enc1 = nn.Sequential( # Feature Extraction Network extraction_layers
            # Conv_1
            nn.Conv2d(1, fv[0], kernel_size=ks, stride=1, padding="same"),
            nn.BatchNorm2d(fv[0]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),

            # Conv_2 + Pool_1
            nn.Conv2d(fv[0], fv[1], kernel_size=ks, stride=1, padding="same"),
            nn.BatchNorm2d(fv[1]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.MaxPool2d(2,2),

            # Conv_3 + Pool_2
            nn.Conv2d(fv[1], fv[2], kernel_size=ks, stride=1, padding="same"),
            nn.BatchNorm2d(fv[2]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.MaxPool2d(2,2),

            # Conv_4 + Pool_3
            nn.Conv2d(fv[2], fv[3], kernel_size=ks, stride=1, padding="same"),
            nn.BatchNorm2d(fv[3]),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.MaxPool2d(2,2),
        )
    self.enc2 = nn.Sequential( # feature_layer of Feature Extraction Network
            nn.Linear(flat_features, feat_size), # reducing to feat_size
            nn.BatchNorm2d(1)
        )

    # # Prediction network
    self.pred1 = nn.Sequential(
            nn.Linear(feat_size, feat_size),
            nn.BatchNorm2d(1),
            nn.Linear(feat_size, feat_size),
            nn.BatchNorm2d(1),
            nn.Linear(feat_size, spec_points*3)
        )

    # # Recognition network
    self.rec1 = nn.Sequential( nn.Linear(feat_size+3*spec_points, feat_size),
                               nn.BatchNorm2d(1),
    )
    self.fc_mean = nn.Linear(feat_size, latent_features)
    self.fc_cov = nn.Linear(feat_size, latent_features)

  def forward(self, x):
    # # Feature Extraction Network: (B, 1, 256, 256) -> (B, fv[3], grid, grid).
    # view(-1, ...) also handles a short final batch from the DataLoader.
    x = x.view(-1, 1, 256, 256)
    x = self.enc1(x)

    # Flatten to (B, 1, 1, fv[3]*grid*grid): the singleton dims give the 4-D
    # input BatchNorm2d(1) expects, treating the feature vector as a 1-channel image.
    e = x.flatten(start_dim=1)[:, None, None, :]
    e = self.enc2(e)

    # # Prediction Network: p = predicted spectra in (0, 1)
    p = torch.sigmoid(self.pred1(e))

    # # Recognition Network: condensed geometry features + predicted spectra -> latent Gaussian
    h = self.rec1(torch.cat((e, p), 3))
    mu = self.fc_mean(h)
    log_var = self.fc_cov(h)

    return p, mu, log_var

# "Decoder" (Regeneration Network)

In [12]:
class Decoder(nn.Module):
  """Regeneration network: (predicted spectra, latent sample) -> reconstructed mask.

  ``forward(spectra, latent)`` concatenates its two 4-D inputs along dim 3
  (spectra first), maps them through three Linear + BatchNorm2d(1) layers, views
  the result as a (B, fv_inv[0], 32, 32) grid, then upsamples 32 -> 64 -> 128 -> 256
  with three stride-2 ConvTranspose2d layers and applies a sigmoid.

  NOTE(review): neither stage has an activation function, so everything before
  the final sigmoid is affine -- confirm this is intentional.
  """

  def __init__(self):
    super().__init__()
    # Fc_4 / Fc_5 / Fc_6
    self.recog = nn.Sequential(
      nn.Linear(latent_features + 3 * spec_points, feat_size),
      nn.BatchNorm2d(1),
      nn.Linear(feat_size, feat_size),
      nn.BatchNorm2d(1),
      nn.Linear(feat_size, 32 * 32 * fv_inv[0]),
      nn.BatchNorm2d(1),
    )

    # Each stage doubles height and width: (H - 1)*2 - 2 + 3 + 1 = 2H.
    upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
    self.reconstruct = nn.Sequential(
      nn.ConvTranspose2d(fv_inv[0], fv_inv[1], **upsample),
      nn.ConvTranspose2d(fv_inv[1], fv_inv[2], **upsample),
      nn.ConvTranspose2d(fv_inv[2], 1, **upsample),
    )

  def forward(self, spectra, latent):
    merged = torch.cat((spectra, latent), 3)
    grid = self.recog(merged).view(-1, fv_inv[0], 32, 32)
    return torch.sigmoid(self.reconstruct(grid))

# VAE Model Definition (includes reparameterization)

In [13]:
# primary VAE module

class CustomVAE(nn.Module):
    """Full model: Encoder -> reparameterised latent sample -> Decoder.

    ``forward(x)`` returns ``(recon_x, mu, log_var, spectra)``: the reconstructed
    geometry viewed as (B, 256, 256), the latent Gaussian parameters, and the
    encoder's predicted spectra.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, log_var):
        """
        Draw a latent sample with the reparameterisation trick.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        :return: ``mu + eps * exp(0.5 * log_var)`` with ``eps`` standard normal
        """
        sigma = (0.5 * log_var).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        spectra, mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        recon_x = self.decoder(spectra, z).view(-1, 256, 256)
        return recon_x, mu, log_var, spectra

    # TO DO: DEFINE SPECTRA PREDICTION FUNCTION AND GEOMETRY GENERATION


# Training and Validation Functions

In [14]:
# combined loss function used to train the entire network
def final_loss(loss1_bce, loss2_mse, mu, logvar, mse_weight=None, kld_weight=0.0):
    """
    Combine the geometry reconstruction loss, the weighted spectra loss and
    (optionally) the KL divergence of the latent Gaussian from N(0, I):

        loss = loss1_bce + mse_weight * loss2_mse + kld_weight * KLD
        KLD  = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))

    :param loss1_bce: reconstruction (BCE) loss of the geometry
    :param loss2_mse: spectra prediction (MSE) loss
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    :param mse_weight: weight of ``loss2_mse``; ``None`` uses the notebook-level ``alpha``
    :param kld_weight: weight of the KL term. The default 0 disables it entirely
        (mu / logvar are then not evaluated), matching the current setup where
        the KL term was observed going to NaN.
    :return: the combined loss
    """
    if mse_weight is None:
        mse_weight = alpha
    loss = loss1_bce + mse_weight * loss2_mse
    if kld_weight:
        # TO DO: determine whether KLD is still going to NAN or not
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = loss + kld_weight * kld
    return loss


In [15]:
# fit function for training the model
def fit(model, dataloader):
    """
    Train ``model`` for one epoch over ``dataloader``.

    Uses the notebook-level ``optimizer``, ``criterion_mask`` (geometry BCE),
    ``criterion`` (spectra MSE) and ``final_loss``.

    :param model: the CustomVAE being trained
    :param dataloader: yields ``(mask_images, spectra)`` batches
    :return: sum of per-batch losses divided by the number of samples in the dataset
    """
    model.train()
    # Run on the model's own device/dtype: this works without CUDA, and the
    # dataset's default transform yields float16 images that float32 conv
    # weights would reject.
    param = next(model.parameters())
    device, dtype = param.device, param.dtype
    running_loss = 0.0
    # len(dataloader) is already the number of batches
    for images, spectra_in in tqdm(dataloader, total=len(dataloader)):
        # To debug "CUDA: device-side assert" errors, wrap the body below in
        # `with autograd.detect_anomaly():`.
        images = images.to(device=device, dtype=dtype).view(-1, 256, 256)  # (B, 256, 256) for the BCE loss
        optimizer.zero_grad()

        reconstruction, mu, logvar, out_spectra = model(images)
        pout = out_spectra.view(-1, 3, spec_points)  # (B, 3, spec_points)

        # NaN guard left over from debugging (possibly unnecessary): clamp to
        # [0, 1] and map any NaN to 1.
        reconstruction = torch.nan_to_num(reconstruction.clamp(0, 1), nan=1.0)

        bce_loss = criterion_mask(reconstruction, images)
        mse_loss = criterion(spectra_in.to(device=device, dtype=pout.dtype), pout)
        loss = final_loss(bce_loss, mse_loss, mu, logvar)
        running_loss += loss.item()

        # backpropagate loss and then step the optimizer
        loss.backward()
        optimizer.step()

    return running_loss / len(dataloader.dataset)
    


In [16]:
# validate function for testing the model on the validation data

def _plot_validation_batch(images, reconstruction, spectra_in, pout, epoch_num, out_dir, num_replicas=4):
    """Plot and save input vs. reconstructed geometry and input vs. predicted spectra for one batch."""
    # Never ask for more panels than the batch holds (the final batch may be short).
    n_show = min(num_replicas, images.shape[0])

    # first figure: input geometry (top row) vs output geometry (bottom row)
    fig, axs = plt.subplots(2, n_show, squeeze=False)
    for col in range(n_show):
        for row, batch in enumerate((images, reconstruction)):
            ax = axs[row, col]
            ax.imshow(batch[col].detach().cpu())
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
    fig.suptitle(str(epoch_num + 1))

    # second figure: input spectra (top) vs predicted spectra (bottom) of the first sample
    og = spectra_in[0].detach().cpu().numpy()
    pred = pout[0].detach().cpu().numpy()
    fig2, axs2 = plt.subplots(2, 1)
    axs2[0].plot(wavelengths, np.transpose(og))
    axs2[1].plot(wavelengths, np.transpose(pred))
    fig2.suptitle(str(epoch_num + 1))

    fig.savefig(os.path.join(out_dir, f"{epoch_num+1}geom_output.png"))
    fig2.savefig(os.path.join(out_dir, f"{epoch_num+1}spectra_output.png"))
    plt.show()
    # release the figures so repeated validation calls don't accumulate them
    plt.close(fig)
    plt.close(fig2)


def validate(model, dataloader, plot_on, epoch=None,
             out_dir="/content/drive/MyDrive/Thon Group Master Folder/Serene/Spectral Selectivity Project/outputs"):
    """
    Evaluate ``model`` on ``dataloader`` with gradients disabled.

    Uses the notebook-level ``criterion_mask``, ``criterion`` and ``final_loss``.

    :param model: the CustomVAE being evaluated
    :param dataloader: yields ``(mask_images, spectra)`` batches
    :param plot_on: if truthy, plot and save inputs vs. outputs for the LAST batch
    :param epoch: 0-based epoch used in plot titles / filenames; ``None`` falls back
        to the notebook-level ``epoch`` loop variable (only read when plotting)
    :param out_dir: directory the figures are written to
    :return: sum of per-batch losses divided by the number of samples in the dataset
    """
    model.eval()
    # Run on the model's device/dtype (the default transform yields float16 images).
    param = next(model.parameters())
    device, dtype = param.device, param.dtype
    last_batch = len(dataloader) - 1
    running_loss = 0.0
    with torch.no_grad():
        for i, (images, spectra_in) in enumerate(tqdm(dataloader, total=len(dataloader))):
            images = images.to(device=device, dtype=dtype).view(-1, 256, 256)
            spectra_in = spectra_in.to(device)

            reconstruction, mu, logvar, out_spectra = model(images)
            pout = out_spectra.view(-1, 3, spec_points)  # reformat spectra for plotting

            # NaN guard left over from debugging (possibly unnecessary)
            reconstruction = torch.nan_to_num(reconstruction.clamp(0, 1), nan=1.0)

            bce_loss = criterion_mask(reconstruction, images)
            mse_loss = criterion(spectra_in.to(pout.dtype), pout)
            loss = final_loss(bce_loss, mse_loss, mu, logvar)
            running_loss += loss.item()

            # save and plot the last batch's inputs and outputs
            if plot_on and i == last_batch:
                print("MSE loss: " + str(mse_loss))  # spectra prediction loss
                epoch_num = epoch if epoch is not None else globals()["epoch"]
                _plot_validation_batch(images, reconstruction, spectra_in, pout, epoch_num, out_dir)

    return running_loss / len(dataloader.dataset)

# Model Training

In [17]:
# Loss criteria: summed BCE for the mask / geometry, mean MSE for the spectra.
criterion_mask = nn.BCELoss(reduction='sum')
criterion = nn.MSELoss()

model_custom = CustomVAE().cuda() # build the model on the GPU
print(summary(model_custom, input_size = (bsize,1,256,256))) # layer-by-layer model summary
# plain SGD with weight decay, to mimic the paper
optimizer = optim.SGD(model_custom.parameters(), lr=lr, weight_decay = w_d)

KeyboardInterrupt: ignored

In [None]:
# Check CUDA memory usage: wait for queued kernels, release unused cached
# allocator blocks, then print the full (non-abbreviated) allocator report.
torch.cuda.synchronize()
torch.cuda.empty_cache()
print(torch.cuda.memory_summary(abbreviated=False))

In [None]:
# Loop over epochs, recording per-epoch train / validation losses
train_loss, val_loss = [], []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")

    # release cached CUDA memory between epochs
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

    train_epoch_loss = fit(model_custom, train_dataloader)

    # validation: plot after the first epoch and then every 5th epoch
    val_epoch_loss = validate(
        model_custom,
        valid_dataloader,
        1 if (epoch == 0 or (epoch + 1) % 5 == 0) else 0,
    )

    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)

    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

In [None]:
# print post summary of GPU usage
!nvidia-smi