In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageOps
import os
import torch.nn.functional as F
from torchvision.utils import save_image

# Prefer the CUDA accelerator when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using {} device".format(device))
class CustomImageDataset(Dataset):
    """Grayscale image dataset read from a ``<root>/<class_name>/<image>`` tree.

    Every sub-directory of ``data_set_path`` is treated as one class. The
    integer label of a sample is the position of its class directory in the
    alphabetically sorted list of class directories.

    Args:
        data_set_path: Root directory containing one sub-directory per class.
        transforms: Optional callable applied to the grayscale PIL image
            before it is returned from ``__getitem__``.
    """

    def read_data_set(self):
        """Index every image file below ``self.data_set_path``.

        Returns:
            tuple: ``(image_paths, labels, num_images, num_classes)`` where
            ``image_paths`` and ``labels`` are parallel lists.
        """
        all_img_files = []
        all_labels = []
        # os.listdir() yields entries in arbitrary order, so sort to keep the
        # class -> label mapping identical across runs and machines. Stray
        # files in the root (e.g. .DS_Store) are not classes and are skipped.
        class_names = sorted(
            entry for entry in os.listdir(self.data_set_path)
            if os.path.isdir(os.path.join(self.data_set_path, entry))
        )
        for index, class_name in enumerate(class_names):
            class_dir = os.path.join(self.data_set_path, class_name)
            for file_name in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, file_name)
                # Nested directories cannot be opened as images; ignore them.
                if not os.path.isfile(file_path):
                    continue
                all_img_files.append(file_path)
                all_labels.append(index)
        return all_img_files, all_labels, len(all_img_files), len(class_names)

    def __init__(self, data_set_path, transforms=None):
        self.data_set_path = data_set_path

        self.image_files_path, self.labels, self.length, self.num_classes = self.read_data_set()
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is converted to grayscale and then passed through
        ``self.transforms`` when one was supplied.
        """
        # Convert while the file is still open, then let the context manager
        # release the handle instead of leaking one descriptor per sample.
        with Image.open(self.image_files_path[index]) as raw_image:
            image = ImageOps.grayscale(raw_image)

        if self.transforms is not None:
            image = self.transforms(image)

        label = self.labels[index]
        return image, label

    def __len__(self):
        """Return the number of indexed image files."""
        return self.length

Using cpu device


In [2]:
# Directory that receives the per-epoch sample and reconstruction grids.
sample_dir = 'samples'
# exist_ok=True makes the cell idempotent without a racy exists()-then-create
# check; it still raises if the path exists but is not a directory.
os.makedirs(sample_dir, exist_ok=True)

In [3]:
# Model sizes and optimisation hyper-parameters.
h_dim, z_dim = 400, 20
num_epochs, batch_size = 100, 128
learning_rate = 1e-3
# Number of pixels in one flattened image; matches the 128x128 resize below.
image_size = 128 * 128

# Preprocessing applied to every training image: resize, small random
# rotation, then conversion to a tensor.
transforms_train = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.RandomRotation(10.),
        transforms.ToTensor(),
    ]
)

train_data_set = CustomImageDataset(
    data_set_path="./cat_dog/train",
    transforms=transforms_train,
)
data_loader = DataLoader(train_data_set, batch_size=batch_size, shuffle=True)

In [4]:
# VAE model
class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened images.

    Encoder: ``image_size -> h_dim -> (mu, log_var)``, both of width ``z_dim``.
    Decoder: ``z_dim -> h_dim -> image_size`` followed by a sigmoid.
    """

    def __init__(self, image_size=128*128, h_dim=400, z_dim=20):
        super().__init__()
        # Encoder trunk and its two heads (mean and log-variance).
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        # Decoder.
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    def encode(self, x):
        """Map a flattened input to the pair ``(mu, log_var)``."""
        hidden = F.relu(self.fc1(x))
        mu = self.fc2(hidden)
        log_var = self.fc3(hidden)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        std = torch.exp(0.5 * log_var)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Map a latent code to a flattened output with values squashed by a sigmoid."""
        hidden = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(hidden))

    def forward(self, x):
        """Encode, sample a latent code, decode; return ``(x_reconst, mu, log_var)``."""
        mu, log_var = self.encode(x)
        latent = self.reparameterize(mu, log_var)
        return self.decode(latent), mu, log_var
# Build the model from the notebook-level hyper-parameters rather than the
# class defaults, so it stays consistent with the training cell, which
# reshapes batches with image_size and samples latents of width z_dim.
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)

In [None]:

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Flatten every image to a vector of image_size pixels; labels are unused.
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Loss = reconstruction term + KL term, both summed over the batch.
        # reduction='sum' is the supported spelling of the deprecated
        # size_average=False and yields the same summed loss.
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = reconst_loss + kl_div

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 2 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))

    # End of epoch: write qualitative results without tracking gradients.
    with torch.no_grad():
        # Decode latent codes drawn from the standard normal prior; the noise
        # is created directly on the target device to avoid a host copy.
        z = torch.randn(batch_size, z_dim, device=device)
        out = model.decode(z).view(-1, 1, 128, 128)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Reconstruct the last mini-batch of the epoch (x still holds it) and
        # save inputs and reconstructions side by side along the width axis.
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 128, 128), out.view(-1, 1, 128, 128)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))




Epoch[1/100], Step [2/4], Reconst Loss: 1595640.2500, KL Div: 24655.1172
Epoch[1/100], Step [4/4], Reconst Loss: 193615.6094, KL Div: 2509.4373
Epoch[2/100], Step [2/4], Reconst Loss: 1488045.7500, KL Div: 16328.0508
Epoch[2/100], Step [4/4], Reconst Loss: 181206.9688, KL Div: 888.4000
Epoch[3/100], Step [2/4], Reconst Loss: 1454680.8750, KL Div: 5395.0483
Epoch[3/100], Step [4/4], Reconst Loss: 181468.6562, KL Div: 400.7295
Epoch[4/100], Step [2/4], Reconst Loss: 1449862.5000, KL Div: 3034.8708
Epoch[4/100], Step [4/4], Reconst Loss: 182156.5938, KL Div: 348.9175
Epoch[5/100], Step [2/4], Reconst Loss: 1438007.0000, KL Div: 2440.5532
Epoch[5/100], Step [4/4], Reconst Loss: 179315.8438, KL Div: 309.1459
Epoch[6/100], Step [2/4], Reconst Loss: 1431224.6250, KL Div: 2595.8833
Epoch[6/100], Step [4/4], Reconst Loss: 177986.1562, KL Div: 268.9095
Epoch[7/100], Step [2/4], Reconst Loss: 1422469.2500, KL Div: 2527.5278
Epoch[7/100], Step [4/4], Reconst Loss: 177802.5000, KL Div: 311.4706
Epo

Epoch[58/100], Step [2/4], Reconst Loss: 1268190.6250, KL Div: 6880.3403
Epoch[58/100], Step [4/4], Reconst Loss: 164702.5469, KL Div: 758.7197
Epoch[59/100], Step [2/4], Reconst Loss: 1261343.5000, KL Div: 6977.5488
Epoch[59/100], Step [4/4], Reconst Loss: 160610.0312, KL Div: 814.8655
Epoch[60/100], Step [2/4], Reconst Loss: 1273879.5000, KL Div: 6407.4932
Epoch[60/100], Step [4/4], Reconst Loss: 154767.5312, KL Div: 708.5917
Epoch[61/100], Step [2/4], Reconst Loss: 1258059.1250, KL Div: 6618.1660
Epoch[61/100], Step [4/4], Reconst Loss: 155274.2500, KL Div: 1036.7024
Epoch[62/100], Step [2/4], Reconst Loss: 1269985.2500, KL Div: 6859.4365
Epoch[62/100], Step [4/4], Reconst Loss: 157404.5156, KL Div: 1034.2200
Epoch[63/100], Step [2/4], Reconst Loss: 1267698.5000, KL Div: 6655.5928
Epoch[63/100], Step [4/4], Reconst Loss: 154888.9688, KL Div: 901.1783
Epoch[64/100], Step [2/4], Reconst Loss: 1270701.7500, KL Div: 7132.3267
Epoch[64/100], Step [4/4], Reconst Loss: 158239.1094, KL Div:

In [None]:
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg 
import numpy as np
# Display the input/reconstruction grid that the training loop saved for epoch 15.
reconsPath = './samples/reconst-15.png'
# Keep the pixel array under its own name: binding it to `Image` would shadow
# the PIL `Image` module that CustomImageDataset.__getitem__ calls, breaking
# any later use of the dataset in the same kernel.
reconst_img = mpimg.imread(reconsPath)
plt.imshow(reconst_img)
plt.axis('off')
plt.show()


In [None]:
# Display the grid of images decoded from random latent codes at epoch 15.
genPath = './samples/sampled-15.png'
# Keep the pixel array under its own name: binding it to `Image` would shadow
# the PIL `Image` module that CustomImageDataset.__getitem__ calls, breaking
# any later use of the dataset in the same kernel.
sampled_img = mpimg.imread(genPath)
plt.imshow(sampled_img)
plt.axis('off')
plt.show()