In [1]:
import torch
from torchvision import transforms
from torch.utils.data import Dataset
import os
from pathlib import Path
from PIL import Image
from torch import nn

In [2]:
class FolderDataset(Dataset):
    """
    Creates a PyTorch dataset from folder, returning two tensor images.
    Args: 
    main_dir : directory where images are stored.
    transform (optional) : torchvision transforms to be applied while making dataset
    """

    def __init__(self, main_dir, transform=None):
        self.main_dir = main_dir
        self.transform = transform
        self.all_imgs = self._get_images_path(main_dir)

    def __len__(self):
        return len(self.all_imgs)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.all_imgs[idx])
        image = Image.open(img_loc).convert("RGB")

        if self.transform is not None:
            tensor_image = self.transform(image)

        return tensor_image, tensor_image
    
    def _get_images_path(self, main_dir):
        image_extensions = ["jpg", "jpeg", "png", "gif", "bmp", "tiff"]
        data_dir = Path(main_dir)
        images_path = [file for file in data_dir.glob(
            '**/*') if file.suffix.lower()[1:] in image_extensions]
        return images_path

In [30]:
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class UnFlatten(nn.Module):
    def forward(self, input, size=36864):
        return input.view(input.size(0), size, 1, 1)
    
    

In [77]:
class VAE(nn.Module):
    def __init__(self, image_channels=3, h_dim=1024, z_dim=32):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.ReLU(),
            Flatten()
        )
        
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)
        
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(h_dim, 512, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
            nn.Sigmoid(),
        )
        
    def reparameterize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        # return torch.normal(mu, std)
        esp = torch.randn(*mu.size())
        z = mu + std * esp
        return z
    
    def bottleneck(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        h = self.encoder(x)
        z, mu, logvar = self.bottleneck(h)
        return z, mu, logvar

    def decode(self, z):
        z = self.fc3(z)
        z = self.decoder(z)
        return z

    def forward(self, x):
        z, mu, logvar = self.encode(x)
        z = self.decode(z)
        return z, mu, logvar

In [78]:
model =  VAE(h_dim=36864)

# model(torch.rand(1,3,224,224))
x, _, _ = model(torch.rand(1,3,224,224))
x.shape

torch.Size([1, 3, 224, 224])

In [69]:
import torch.nn.functional as F

def loss_fn(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
    # BCE = F.mse_loss(recon_x, x, size_average=False)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD, BCE, KLD

In [66]:
import torch
import torchvision.transforms as T
import torch.optim as optim
import sys

transforms = T.Compose([T.ToTensor()]) # Normalize the pixels and convert to tensor.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465),
                            (0.2023, 0.1994, 0.2010)),
])

path= "C:/Users/Maods/Documents/Development/Mestrado/terumo/apps/renal-pathology-retrieval/data/02_data_split/train_data/"
full_dataset = FolderDataset(path, transform) # Create folder dataset.

train_size = 0.75
val_size = 1 - train_size

# Split data to train and test
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size]) 

# Create the train dataloader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
 
# Create the validation dataloader
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)


In [83]:

batch_size=32
EPOCHS = 10
model =  VAE(h_dim=36864)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 

for epoch in range(EPOCHS):
    for idx, (train_img, target_img) in enumerate(train_loader):
        recon_images, mu, logvar = model(train_img)
        loss, bce, kld = loss_fn(recon_images, target_img, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        to_print = "Epoch[{}/{}] Loss: {:.3f} {:.3f} {:.3f}".format(epoch+1, 
                                EPOCHS, loss.data.item()/batch_size, bce.data.item()/batch_size, kld.data.item()/batch_size)
        print(to_print)



Epoch[1/10] Loss: 101213.758 101213.758 0.000
Epoch[1/10] Loss: 84033.586 84033.531 0.053
Epoch[1/10] Loss: 1941409.000 1940456.250 952.777
Epoch[1/10] Loss: -53099.398 -53099.703 0.303
Epoch[1/10] Loss: 89793.742 89793.742 0.001
Epoch[1/10] Loss: 92777.656 92777.656 0.003
Epoch[1/10] Loss: 89976.148 89976.133 0.014
Epoch[1/10] Loss: 82882.992 82882.984 0.007
Epoch[1/10] Loss: 96399.266 96399.266 0.003
Epoch[1/10] Loss: 93285.680 93285.672 0.009
Epoch[1/10] Loss: 60047.527 60047.500 0.028
Epoch[1/10] Loss: -258187.719 -258187.781 0.068
Epoch[1/10] Loss: -5324.415 -5324.463 0.048
Epoch[1/10] Loss: 36022.215 36022.176 0.040
Epoch[1/10] Loss: 54664.004 54663.969 0.034
Epoch[1/10] Loss: 54546.703 54546.672 0.031
Epoch[1/10] Loss: 42758.254 42758.219 0.035
Epoch[1/10] Loss: -94284.570 -94284.625 0.057
Epoch[1/10] Loss: -756955.750 -756955.875 0.094
Epoch[1/10] Loss: -403818.688 -403818.750 0.069
Epoch[1/10] Loss: -406584.844 -406584.906 0.048
Epoch[1/10] Loss: -1366712.125 -1366712.250 0.07

KeyboardInterrupt: 

In [82]:
loss.data.item()

3427417.5

In [None]:
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma*epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma


if __name__ == "__main__":
    x = torch.randn(4, 28*28)
    vae = VariationalAutoEncoder(input_dim=784)
    x_reconstructed, mu, sigma = vae(x)
    print(x_reconstructed.shape)
    print(mu.shape)
    print(sigma.shape)

In [None]:
import torch
import torchvision.datasets as datasets  # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4  # Karpathy constant

# Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")


def inference(digit, num_examples=1):
    """
    Generates (num_examples) of a particular digit.
    Specifically we extract an example of each digit,
    then after we have the mu, sigma representation for
    each digit we can sample from that.

    After we sample we can run the decoder part of the VAE
    and generate examples.
    """
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))
        encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{example}.png")

for idx in range(10):
    inference(idx, num_examples=5)