<a href="https://colab.research.google.com/github/jhu-nanoenergy/VAE-models/blob/main/AE_initial_easy_mode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import scipy.io as spio
import scipy.stats as stat
import pandas as pd
import numpy as np
import os
import random

import torch
import torchvision
import torch.optim as optim
import argparse
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F

from tqdm import tqdm
from torchvision import datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import h5py

plt.style.use('ggplot')

In [None]:
# Colab-specific import kept next to its only use: this cell mounts Google Drive
# and then reads the simulation results from it.
from google.colab import drive
drive.mount("/content/drive/", force_remount=True)

# Get all data from drive
# `int_data_all` is the dict returned by scipy's loadmat; the keys used later in
# this notebook are 'A', 'R', 'T', 'h' and 'lambda'.
int_data_all = spio.loadmat('/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/Rockfish Training Data Gen/int_total_sqr_no_struct.mat', squeeze_me=True)
# ext_data_all = spio.loadmat('/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/Rockfish Training Data Gen/ext_total_sqr_no_struct.mat', squeeze_me=True)
# Axis used as the x-values when plotting spectra (presumably wavelengths, going
# by the variable name -- units are not stated in this notebook).
wavelengths = int_data_all['lambda']

Mounted at /content/drive/


In [None]:
# int_data_all.keys()

In [None]:
# Path of the HDF5 file holding the structure masks.
fname_int_mask = '/content/drive/MyDrive/Thon Group Master Folder/Sreyas/Photonic Crystals/Rockfish Training Data Gen/int_mask_iter.h5'

# NOTE(review): `hf` is never closed in this notebook. `int_masks` is an h5py
# dataset obtained from this handle and is indexed later by the Dataset class,
# so closing the file here would presumably break that indexing -- confirm
# before adding a close()/context manager.
hf = h5py.File(fname_int_mask, "r")
int_masks = hf['masks']
print(np.shape(int_masks)) #type h5py dataset, but when you index it, it's np ndarray


(20000, 256, 256)


In [None]:
# Give each of the A / R / T arrays a trailing axis, then join them along that
# axis so that the last dimension of `spec` indexes (A, R, T).
A, R, T = (np.expand_dims(int_data_all[key], 2) for key in ('A', 'R', 'T'))
spec = np.concatenate([A, R, T], axis=2)

In [None]:
# Sanity check: shape of `spec` once its first two axes are exchanged.
np.shape(np.swapaxes(spec,1,0))

(20000, 221, 3)

In [None]:
# List the variables contained in the loaded .mat file.
int_data_all.keys()


dict_keys(['__header__', '__version__', '__globals__', 'A', 'R', 'T', 'h', 'i', 'lambda', 'mask_iter'])

In [None]:
class ImageDataset(Dataset):
    """Dataset that pairs each structure mask with its height.

    The A / R / T spectra are also loaded and kept on ``self.spectra`` but are
    not returned by ``__getitem__`` yet.

    :param masks: indexable collection of 2-D masks, indexed as
        ``masks[idx, :, :]`` (in this notebook an h5py dataset)
    :param data_all: mapping with the keys ``'A'``, ``'R'``, ``'T'`` and ``'h'``
        (in this notebook the dict returned by ``scipy.io.loadmat``)
    :param transform: callable applied to every mask before it is returned;
        pass ``None`` to return the raw mask
    :raises ValueError: if the number of masks differs from the number of heights
    """

    def __init__(self,  masks, data_all, transform=transforms.ToTensor() ):
        # Zero-argument super() walks the real MRO. The previous
        # `super(Dataset, self)` named the parent class and therefore started
        # the lookup *after* Dataset.
        super().__init__()

        # Stack A, R, T on a new trailing axis, then exchange the first two
        # axes so that the sample index comes first.
        channels = [np.expand_dims(data_all[key], 2) for key in ('A', 'R', 'T')]
        spec = np.swapaxes(np.concatenate(channels, axis=2), 1, 0)
        print(np.shape(spec))
        self.spectra = torch.from_numpy(spec)

        # One height per sample, stored as a column: shape (n_samples, 1).
        self.heights = torch.from_numpy(data_all['h']).unsqueeze(1)
        self.masks = masks
        self.transform = transform

        # __len__ is driven by the heights while __getitem__ indexes the masks,
        # so a length mismatch would only surface later as an indexing error.
        if len(self.masks) != len(self.heights):
            raise ValueError(
                f"masks has {len(self.masks)} entries but 'h' has {len(self.heights)}"
            )

    def __len__(self):
        """Number of samples (one per height value)."""
        return len(self.heights)

    def __getitem__(self, idx):
        """Return ``(image, height)`` for sample ``idx``.

        The height is rounded to one decimal place; the mask is passed through
        ``self.transform`` when one is set.
        """
        image = self.masks[idx, :, :]
        height = np.around(self.heights[idx], decimals=1)
        if self.transform:
            image = self.transform(image)
        # The spectra for this sample are available as self.spectra[idx] if the
        # training code is extended to use them.
        return image, height

In [None]:
# Linear VAE example that uses the reparameterization trick (random sampling of
# the latent vector).

features = 16  # default size of the latent vector


# define a simple linear VAE
class LinearVAE(nn.Module):
    """Fully connected variational auto-encoder with one hidden layer per side.

    :param in_features: length of the flattened input (and of the reconstruction)
    :param hidden_features: width of the hidden layer of encoder and decoder
    :param latent_features: size of the latent vector; ``None`` (the default)
        uses the module-level ``features`` value at construction time

    With no arguments the layer sizes are the same as before
    (256*256 -> 512 -> 2*features, features -> 512 -> 256*256).
    """

    def __init__(self, in_features=256*256, hidden_features=512, latent_features=None):
        super(LinearVAE, self).__init__()
        if latent_features is None:
            latent_features = features
        # Remembered so that forward() splits the encoder output consistently
        # with how the layers were built, even if `features` is changed later.
        self.latent_features = latent_features

        # encoder: the last layer emits mean and log-variance side by side
        self.enc1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.enc2 = nn.Linear(in_features=hidden_features, out_features=latent_features*2)

        # decoder
        self.dec1 = nn.Linear(in_features=latent_features, out_features=hidden_features)
        self.dec2 = nn.Linear(in_features=hidden_features, out_features=in_features)

    def reparameterize(self, mu, log_var):
        """
        Draw a latent sample ``mu + eps * std`` with ``eps ~ N(0, 1)``.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        :return: tensor with the same shape as ``mu``
        """
        std = torch.exp(0.5*log_var) # standard deviation
        eps = torch.randn_like(std) # `randn_like` as we need the same size
        return mu + (eps * std)

    def forward(self, x):
        """
        Encode ``x``, sample a latent vector and decode it.

        :param x: batch of flattened inputs, shape (batch, in_features)
        :return: ``(reconstruction, mu, log_var)``; the reconstruction has the
            shape of ``x`` and values in [0, 1] (sigmoid output)
        """
        # encoding
        hidden = F.relu(self.enc1(x))
        stats = self.enc2(hidden).view(-1, 2, self.latent_features)
        # get `mu` and `log_var`
        mu = stats[:, 0, :] # the first half of the outputs is the mean
        log_var = stats[:, 1, :] # the second half is the log variance
        # get the latent vector through reparameterization
        z = self.reparameterize(mu, log_var)

        # decoding
        hidden = F.relu(self.dec1(z))
        reconstruction = torch.sigmoid(self.dec2(hidden))
        return reconstruction, mu, log_var

In [None]:
# DEFINE RANDOM SEED 
# CREATE VALIDATION SET

In [None]:
# Fixed seed so that the train / validation / test membership is identical on
# every run of the notebook.
SPLIT_SEED = 0

full_dataset = ImageDataset(masks=int_masks, data_all = int_data_all )

# 70 % / 20 % / 10 % split. The test set takes whatever is left after the two
# truncated sizes, so the three sizes always add up to len(full_dataset);
# random_split raises if they do not.
n_total = len(full_dataset)
train_size = int(0.7 * n_total)
val_size = int(0.2 * n_total)
test_size = n_total - train_size - val_size

split_rng = torch.Generator().manual_seed(SPLIT_SEED)
data_temp, data_test = torch.utils.data.random_split(
    full_dataset, [train_size + val_size, test_size], generator=split_rng)
data_train, data_val = torch.utils.data.random_split(
    data_temp, [train_size, val_size], generator=split_rng)

bsize = 256
train_dataloader = DataLoader(data_train, batch_size = bsize, shuffle=True)
# Evaluation loaders are not shuffled: the order does not affect the loss, and a
# fixed order keeps the per-epoch reconstruction figures comparable.
test_dataloader = DataLoader(data_test, batch_size = bsize, shuffle=False)
valid_dataloader = DataLoader(data_val, batch_size = bsize, shuffle=False)
# Number of batches per loader.
print(len(train_dataloader))
print(len(test_dataloader))
print(len(valid_dataloader))

(20000, 221, 3)
55
8
16


In [None]:
# Total number of samples before splitting.
print(len(full_dataset))

20000


In [None]:
# train_features, train_heights = next(iter(train_dataloader))
# print(f"Feature batch shape: {train_features.size()}")
# print(f"Heights batch shape: {train_heights.size()}")

# train_features, train_spectras, train_heights = next(iter(train_dataloader))
# print(f"Feature batch shape: {train_features.size()}")
# print(f"Spectras batch shape: {train_spectras.size()}")
# print(f"Heights batch shape: {train_heights.size()}")


In [None]:
# print(train_heights)

In [None]:
# c_ind = 10;
# img = train_features[c_ind].squeeze()
# plt.imshow(img, cmap="gray")
# plt.show()
# plt.plot(wavelengths, train_spectras[c_ind,:,0],wavelengths, train_spectras[c_ind,:,1],wavelengths, train_spectras[c_ind,:,2] )
# plt.legend(("A", "R", "T"))

In [None]:

# Training configuration.
epochs = 10           # number of passes over the training set
batch_size = bsize    # same batch size as the data loaders
lr = 1e-4             # Adam learning rate
# Train on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Build the network, then move it to the training device as float32.
model = LinearVAE()
model = model.to(device, dtype=torch.float)
optimizer = optim.Adam(params=model.parameters(), lr=lr)
# Summed (not averaged) binary cross-entropy is used as the reconstruction loss.
criterion = nn.BCELoss(reduction='sum')

In [None]:
def final_loss(bce_loss, mu, logvar):
    """
    Total VAE loss: reconstruction loss plus the KL-divergence term.

    KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

    :param bce_loss: reconstruction loss (BCELoss)
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    :return: ``bce_loss`` plus the KL-divergence summed over all elements
    """
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return bce_loss + kl_divergence



In [None]:
def fit(model, dataloader):
    """
    Train ``model`` for one epoch and return the mean loss per sample.

    Uses the module-level ``optimizer``, ``criterion``, ``device`` and
    ``final_loss``.

    :param model: network returning ``(reconstruction, mu, logvar)``
    :param dataloader: loader yielding ``(images, heights)`` batches; only the
        images are used
    :return: summed loss of the epoch divided by ``len(dataloader.dataset)``
    """
    model.train()
    running_loss = 0.0
    # The progress-bar total is the number of batches of the loader that is
    # actually iterated (len of a DataLoader), not of a global loader.
    for data, _ in tqdm(dataloader, total=len(dataloader)):
        data = data.to(device, dtype=torch.float)
        # Flatten every image to a vector for the linear layers.
        data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(data)
        bce_loss = criterion(reconstruction, data)
        loss = final_loss(bce_loss, mu, logvar)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    train_loss = running_loss/len(dataloader.dataset)
    return train_loss

In [None]:
def _plot_reconstructions(originals, reconstructions, title, num_replicas):
    """Show up to ``num_replicas`` inputs (top row) above their reconstructions."""
    n_cols = min(num_replicas, originals.size(0))
    # squeeze=False keeps `axs` two-dimensional even when there is one column.
    fig, axs = plt.subplots(2, n_cols, squeeze=False)
    for col in range(n_cols):
        for row, batch in enumerate((originals, reconstructions)):
            axs[row, col].imshow(torch.squeeze(batch[col]).cpu())
            axs[row, col].xaxis.set_visible(False)
            axs[row, col].yaxis.set_visible(False)
    fig.suptitle(str(title))
    return fig


def validate(model, dataloader, epoch=None, num_replicas=4):
    """
    Evaluate ``model`` on ``dataloader`` and return the mean loss per sample.

    After the last batch, a figure comparing a few inputs with their
    reconstructions is drawn.

    Uses the module-level ``criterion``, ``device`` and ``final_loss``.

    :param model: network returning ``(reconstruction, mu, logvar)``
    :param dataloader: loader yielding ``(images, heights)`` batches; only the
        images are used
    :param epoch: label for the figure title; when omitted, the module-level
        ``epoch`` variable (set by the training loop) is used if it exists
    :param num_replicas: maximum number of image pairs shown in the figure
    :return: summed loss divided by ``len(dataloader.dataset)``
    """
    if epoch is None:
        # Existing callers do not pass the epoch; fall back to the loop
        # variable of the training cell instead of failing outside a loop.
        epoch = globals().get("epoch", "")

    model.eval()
    running_loss = 0.0
    n_batches = len(dataloader)
    with torch.no_grad():
        for i, (images, _) in enumerate(tqdm(dataloader, total=n_batches)):
            images = images.to(device, dtype=torch.float)
            # Flatten every image to a vector for the linear layers.
            data = images.view(images.size(0), -1)
            reconstruction, mu, logvar = model(data)
            bce_loss = criterion(reconstruction, data)
            loss = final_loss(bce_loss, mu, logvar)
            running_loss += loss.item()

            # Plot the last batch of the epoch. Reshaping with view_as uses the
            # real batch and image sizes, so a short final batch works too.
            if i == n_batches - 1:
                _plot_reconstructions(
                    images, reconstruction.view_as(images), epoch, num_replicas)
                # fig.savefig(f"/content/drive/MyDrive/Thon Group Master Folder/Serene/Spectral Selectivity Project/outputs/output{epoch}.png")
    val_loss = running_loss/len(dataloader.dataset)
    return val_loss

In [None]:
# Per-epoch mean losses, collected for later inspection.
train_loss = []
val_loss = []
# NOTE: the loop variable must stay named `epoch` -- validate() reads it as a
# module-level variable for its figure title.
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = fit(model, train_dataloader)
    val_epoch_loss = validate(model, valid_dataloader)
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Val Loss: {val_epoch_loss:.4f}")

Epoch 1 of 10


1it [01:10, 70.10s/it]

torch.Size([256, 65536])
torch.Size([256, 65536])


2it [01:30, 40.79s/it]

torch.Size([256, 65536])
torch.Size([256, 65536])


2it [01:33, 46.81s/it]


KeyboardInterrupt: ignored