In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.datasets as datasets  
from torch.utils.data import DataLoader  

In [2]:
from torchvision.datasets import ImageFolder

In [3]:
# Define the path to your custom image
custom_image_path = r'C:\Users\DK\Desktop\DKTech\pytorch\Custom Vae\input'

In [4]:
# Define the image transformation pipeline
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((56, 56)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [5]:
# Load the custom dataset
dataset = ImageFolder(root=custom_image_path, transform=transform)

In [6]:
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, z_dim, h_dim=200):
        super().__init__()
        # encoder
        self.img_2hid1 = nn.Linear(3136, 784)
        self.img_2hid2 = nn.Linear(784, 200)

        # one for mu and one for stds, note how we only output
        # diagonal values of covariance matrix. Here we assume
        # the pixels are conditionally independent 
        self.hid_2mu = nn.Linear(200, 20)
        self.hid_2sigma = nn.Linear(200, 20)

        # decoder
        self.z_2hid1 = nn.Linear(20, 200)
        self.z_2hid2 = nn.Linear(200, 784)
        self.hid_2img = nn.Linear(784, 3136)

    def encode(self, x):
        h = F.relu(self.img_2hid1(x))
        h = F.relu(self.img_2hid2(h))
        #h = F.relu(self.img_2hid3(h))
        mu = self.hid_2mu(h)
        sigma = self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        new_h = F.relu(self.z_2hid1(z))
        new_h = F.relu(self.z_2hid2(new_h))
        #new_h = F.relu(self.z_2hid3(new_h))
        x = torch.sigmoid(self.hid_2img(new_h))
        return x

    def forward(self, x):
        mu, sigma = self.encode(x)

        # Sample from latent distribution from encoder
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma*epsilon

        x = self.decode(z_reparametrized)
        return x, mu, sigma


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_DIM = 3136
Z_DIM = 20
H_DIM = 200
NUM_EPOCHS = 250
BATCH_SIZE = 32
LR_RATE = 3e-4

In [8]:
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [9]:
# Define train function
def train(num_epochs, model, optimizer, loss_fn):
    # Start training
    for epoch in range(num_epochs):
        loop = tqdm(enumerate(train_loader))
        for i, (x, y) in loop:
            # Forward pass
            x = x.to(device).view(-1, INPUT_DIM)
            x_reconst, mu, sigma = model(x)

            # loss, formulas from https://www.youtube.com/watch?v=igP03FXZqgo&t=2182s
            reconst_loss = loss_fn(x_reconst, x)
            kl_div = - torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

            # Backprop and optimize
            loss = reconst_loss + kl_div
            optimizer.zero_grad() #before backword they are zero'ing the parameters
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())



In [10]:
# Initialize model, optimizer, loss
model = VariationalAutoEncoder(3136, 20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

# Run training
train(NUM_EPOCHS, model, optimizer, loss_fn)

2it [00:06,  3.12s/it, loss=3.12e+4]
2it [00:01,  1.16it/s, loss=3.05e+4]
2it [00:01,  1.44it/s, loss=3.07e+4]
2it [00:01,  1.44it/s, loss=3.05e+4]
2it [00:01,  1.72it/s, loss=2.97e+4]
2it [00:01,  1.84it/s, loss=3e+4]   
2it [00:01,  1.47it/s, loss=3.07e+4]
2it [00:01,  1.01it/s, loss=3.02e+4]
2it [00:01,  1.33it/s, loss=2.95e+4]
2it [00:01,  1.17it/s, loss=3e+4]   
2it [00:01,  1.72it/s, loss=2.85e+4]
2it [00:01,  1.43it/s, loss=2.91e+4]
2it [00:01,  1.57it/s, loss=3.07e+4]
2it [00:01,  1.79it/s, loss=2.86e+4]
2it [00:01,  1.26it/s, loss=3.11e+4]
2it [00:01,  1.65it/s, loss=2.89e+4]
2it [00:01,  1.31it/s, loss=2.9e+4] 
2it [00:01,  1.44it/s, loss=2.99e+4]
2it [00:01,  1.22it/s, loss=2.92e+4]
2it [00:01,  1.59it/s, loss=2.7e+4] 
2it [00:01,  1.76it/s, loss=2.59e+4]
2it [00:01,  1.88it/s, loss=2.91e+4]
2it [00:01,  1.96it/s, loss=3.06e+4]
2it [00:01,  1.89it/s, loss=2.52e+4]
2it [00:01,  1.89it/s, loss=2.77e+4]
2it [00:01,  1.85it/s, loss=27200.0]
2it [00:01,  1.91it/s, loss=2.84e+4]
2

2it [00:01,  1.17it/s, loss=2.33e+4]
2it [00:01,  1.13it/s, loss=2.15e+4]
2it [00:01,  1.05it/s, loss=2.32e+4]
2it [00:01,  1.40it/s, loss=1.95e+4]
2it [00:01,  1.34it/s, loss=2.42e+4]
2it [00:01,  1.30it/s, loss=1.91e+4]
2it [00:01,  1.17it/s, loss=2.38e+4]
2it [00:01,  1.05it/s, loss=2.49e+4]
2it [00:01,  1.07it/s, loss=2.4e+4] 
2it [00:01,  1.16it/s, loss=2.19e+4]
2it [00:01,  1.17it/s, loss=2.36e+4]
2it [00:01,  1.33it/s, loss=2.13e+4]
2it [00:01,  1.28it/s, loss=2.27e+4]
2it [00:01,  1.18it/s, loss=2.15e+4]
2it [00:01,  1.24it/s, loss=2.35e+4]
2it [00:01,  1.25it/s, loss=2.27e+4]
2it [00:01,  1.27it/s, loss=2.24e+4]
2it [00:01,  1.36it/s, loss=21410.0]
2it [00:01,  1.27it/s, loss=2.38e+4]
2it [00:01,  1.25it/s, loss=2.36e+4]
2it [00:01,  1.28it/s, loss=2.21e+4]
2it [00:01,  1.27it/s, loss=1.98e+4]
2it [00:01,  1.23it/s, loss=2.42e+4]
2it [00:01,  1.17it/s, loss=2.26e+4]
2it [00:01,  1.36it/s, loss=2.18e+4]
2it [00:01,  1.36it/s, loss=2.46e+4]
2it [00:01,  1.25it/s, loss=2.32e+4]
2

In [13]:
def inference(digit, num_examples=1):
    """
    Generates (num_examples) of a particular digit.
    Specifically we extract an example of each digit,
    then after we have the mu, sigma representation for
    each digit we can sample from that.

    After we sample we can run the decoder part of the VAE
    and generate examples.
    """
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 2:
            break

    encodings_digit = []
    for d in range(2):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 3136))
        encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 56, 56)
        save_image(out, f"generated_new_{digit}_ex{example}.png")



In [17]:
inference(0)