In [6]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from module.utils import *

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split

In [7]:
# Root folder of the .npy samples; DatasetFolder expects one sub-directory per class.
input_directory = 'big_data'

# Every sample is read with npy_loader and then converted by ToTensor.
transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.DatasetFolder(
    input_directory,
    transform=transform,
    loader=npy_loader,
    extensions=['.npy'],
)
class_labels = dataset.classes

# Reproducible 80/20 split. NOTE: sklearn indexes a non-array input element by
# element, so both splits come back as lists of already-loaded (sample, label)
# pairs -- the folder is read and transformed once here, and the training
# epochs below iterate over in-memory data.
train_dataset, test_dataset = train_test_split(
    dataset,
    test_size=0.20,
    random_state=42,
)

print("Train set size:", len(train_dataset))
print("Test set size:", len(test_dataset))

# Batches of 24; only the training loader reshuffles each epoch.
batch_size = 24
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Train set size: 204
Test set size: 52


In [8]:
# define the model
# this is a VAE model
# for the encoder, two Conv2d layers are used to extract features from the input image
# the output image is of size 1x128X128
# first have 32 filters and kernel size 4, stride 2, padding 1
# second have 64 filters and kernel size 4, stride 2, padding 1
# then the output is flattened 
# then  n_dense number of dense layers with n_hidden units
# the latent space is represented by two vectors, mean and logvar, dimension latent_dim

# for the decoder, the latent space is first passed through a n_dense number of dense layers with n_hidden units
# then the output is reshaped 
# two ConvTranspose2d layers are used to reconstruct the image
# the output image is of size 1x128X128

class VAE(torch.nn.Module):
    def __init__(self,n_conv =2, n_hidden=128, n_dense=1, latent_dim=2):
        super(VAE, self).__init__()
        self.n_hidden = n_hidden
        self.n_dense = n_dense
        self.latent_dim = latent_dim
        
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 4, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 4, stride=2, padding=1),
            torch.nn.ReLU()
        )
        # if n_conv == 2, add another Conv2d layer
        if n_conv == 3:
            self.encoder.add_module("conv2", torch.nn.Conv2d(64, 128, 4, stride=2, padding=1))
            self.encoder.add_module("relu2", torch.nn.ReLU())

        self.encoder.add_module("flatten", torch.nn.Flatten())
        
        # now the output is flattened and n_dense number of dense layers are used
        # the output is the latent space
        for i in range(n_dense):
            if i == 0:
                if n_conv == 2:
                    self.encoder.add_module("dense_{}".format(i), torch.nn.Linear(64*32*32, n_hidden))
                if n_conv == 3:
                    self.encoder.add_module("dense_{}".format(i), torch.nn.Linear(128*16*16, n_hidden))
            else:
                self.encoder.add_module("dense_{}".format(i), torch.nn.Linear(n_hidden, n_hidden))
            self.encoder.add_module("relu_{}".format(i), torch.nn.ReLU())
        
        self.mean = torch.nn.Linear(n_hidden, latent_dim)
        self.logvar = torch.nn.Linear(n_hidden, latent_dim)

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, n_hidden),
            torch.nn.ReLU()
        )

        for i in range(n_dense):
            if i == n_dense-1:
                if n_conv == 2:
                    self.decoder.add_module("dense_{}".format(i), torch.nn.Linear(n_hidden, 64*32*32))
                if n_conv == 3:
                    self.decoder.add_module("dense_{}".format(i), torch.nn.Linear(n_hidden, 128*16*16))
            else:
                self.decoder.add_module("dense_{}".format(i), torch.nn.Linear(n_hidden, n_hidden))
            self.decoder.add_module("relu_{}".format(i), torch.nn.ReLU())
        
        if n_conv == 2:
            self.decoder.add_module("reshape", torch.nn.Unflatten(1, (64, 32, 32)))
        if n_conv == 3:
            self.decoder.add_module("reshape", torch.nn.Unflatten(1, (128, 16, 16)))
            self.decoder.add_module("deconv0", torch.nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1))
            self.decoder.add_module("relu0", torch.nn.ReLU())
        
        self.decoder.add_module("deconv1", torch.nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1))
        self.decoder.add_module("relu1", torch.nn.ReLU())
        self.decoder.add_module("deconv2", torch.nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1))
        self.decoder.add_module("sigmoid", torch.nn.Sigmoid())

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mean + eps*std
    
    def forward(self, x):
        x = self.encoder(x)
        mean = self.mean(x)
        logvar = self.logvar(x)
        z = self.reparameterize(mean, logvar)
        x = self.decoder(z)
        return x, mean, logvar
    

In [9]:
n_conv = 3
n_hidden = 40
n_dense = 2
latent_dim = 6
model = VAE(n_conv, n_hidden, n_dense, latent_dim)

In [10]:
# define the loss function
def loss_function(recon_x, x, mu, logvar):
    # mse for BCE
    BCE = torch.nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

factor = 0.8
patience = 100

# define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=factor, patience=patience)

In [11]:
# train the model
num_epochs = 5000
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for i, (images, _) in enumerate(train_loader):
        optimizer.zero_grad()
        recon_images, mu, logvar = model(images)
        loss = loss_function(recon_images, images, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    scheduler.step(train_loss)
    print(f'Epoch {epoch}, Loss {train_loss/len(train_loader.dataset)}')

Epoch 0, Loss 2802.4253121170345
Epoch 1, Loss 705.2563093596814
Epoch 2, Loss 446.9873645258885
Epoch 3, Loss 389.8036510991115
Epoch 4, Loss 372.1156652113971
Epoch 5, Loss 346.9305970435049
Epoch 6, Loss 315.6472598805147
Epoch 7, Loss 272.0441152535233
Epoch 8, Loss 242.50269512101715
Epoch 9, Loss 224.99896958295037
Epoch 10, Loss 213.71085252948836
Epoch 11, Loss 206.1837672813266
Epoch 12, Loss 202.72935216567095
Epoch 13, Loss 199.98805027382048
Epoch 14, Loss 197.25518679151347
Epoch 15, Loss 195.96240713082108
Epoch 16, Loss 194.07487457873773
Epoch 17, Loss 193.45695106655944
Epoch 18, Loss 192.7013262580423
Epoch 19, Loss 191.94423181870405
Epoch 20, Loss 191.49777999578737
Epoch 21, Loss 191.39152377259498
Epoch 22, Loss 190.3632489372702
Epoch 23, Loss 190.2813636929381
Epoch 24, Loss 189.96267999387254
Epoch 25, Loss 189.5307102577359
Epoch 26, Loss 189.46507472617952
Epoch 27, Loss 189.47074381510416
Epoch 28, Loss 189.28600834865196
Epoch 29, Loss 188.72892013250612
Ep

In [12]:
# save the model as models/VanVAE.pth
torch.save(model, 'models/Emulator_longtrain.pt')