In [1]:
import os
import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data as data_utils
import wandb
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath("__file__"))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath("__file__")))))
import custom_loaders
from conv_layers import UpTranspose2d, DownConv2d, UpTranspose2d, UpSampleConv
from utils import count_parameters

In [88]:
class DecoderTrans(nn.Module):

    def __init__(self, channels, kernelSize=4):
        super().__init__()

        self.gen = nn.ModuleList([UpTranspose2d(channels[i], channels[i+1], kernelSize) for i in range(len(channels) - 2)])
        self.output = nn.ConvTranspose2d(channels[-2], channels[-1], kernel_size=kernelSize, stride = 2, padding=(kernelSize//2 - 1))

    def forward(self, image):
            
        for block in self.gen:
            image = block(image)
    
        output = torch.tanh(self.output(image))
        return output

class DecoderUpSample(nn.Module):

    def __init__(self, channels, kernelSize=3):
        super().__init__()

        self.gen = nn.ModuleList([UpSampleConv(channels[i], channels[i+1], kernelSize) for i in range(len(channels) - 2)])
        self.upSample = nn.Upsample(scale_factor=2, mode='nearest') 
        self.output = nn.Conv2d(channels[-2], channels[-1], kernel_size=kernelSize, padding = (kernelSize-1)//2)


    def forward(self, image):

        for block in self.gen:
            image = block(self.upSample(image))

        output = torch.tanh(self.output(self.upSample(image)))
        return output
    
class Encoder(nn.Module):

    def __init__(self, channels, kernelSize=4):
        super().__init__()

        self.dis = nn.ModuleList([DownConv2d(channels[i], channels[i+1], kernelSize) for i in range(len(channels) - 2)])
        self.out = nn.Conv2d(in_channels=channels[-2], out_channels=channels[-1], kernel_size=kernelSize, stride = 2, padding = kernelSize//2 - 1)

    def forward(self, image):
        
        for block in self.dis:
            image = block(image)
        
        out = self.out(image).squeeze(2).squeeze(2)
        mean, logvar = out[:, :out.shape[1]//2], out[:, out.shape[1]//2:]
        return mean, logvar
    
class VAE(nn.Module):
    def __init__(self, channelsD, channelsE):
        super(VAE, self).__init__()
        self.encoder = Encoder(channelsE)
        self.decoder = DecoderUpSample(channelsD)

    def reparameterize(self, mu, logvar):

        eps = torch.normal(mean = torch.zeros_like(mu), std= torch.ones_like(logvar))
        return mu + eps * torch.exp(logvar * 0.5)


    def forward(self, img):

        mu, logvar = self.encoder(img)
        z = self.reparameterize(mu, logvar).unsqueeze(2).unsqueeze(2)
        return self.decoder(z), mu, logvar


In [89]:
latent_len = 100
img_size = 64
n_channels = 3
device = "cuda" if torch.cuda.is_available() else "cpu"
channelsD = [latent_len, 256, 128, 128,64, 32, n_channels]
channelsE = [n_channels, 32, 64, 128, 128,256, latent_len*2]

lr = 0.001
epochs = 20

assert(len(channelsD) == len(channelsE))
assert(img_size == 2**(len(channelsD) - 1))

In [None]:
class Args():
    def __init__(self):
        self.dataset = 'GAN'
        self.imgPath = '../datasets/CelebA_train/img_align_celeba'
        self.imgSize = img_size
        self.download = False
        self.imgC = n_channels
        self.num_images = 5000
        self.convert2bw = False

args = Args()
print("Loading data...")
train_dataset = custom_loaders.get_data_loader(args)

In [96]:
vae = VAE(channelsD, channelsE).to(device)
count_parameters(vae)

Total Trainable Params: 3.473291 M


3473291

In [98]:
bs = 32
train_loader = data_utils.DataLoader(train_dataset, batch_size=bs, shuffle=True)

MSEloss = nn.MSELoss()

optimizer = torch.optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999))
fixed_noise = torch.rand(bs,latent_len,1,1).to(device)

In [99]:
# config={"epochs": epochs, "batch_size": bs, "lr": lr,
#            "img_size": img_size, "n_channels": n_channels,
#            "latent_len": latent_len}

# wandb.init(project='pytorch-gen-celeba', entity='basujindal123', config=config)

In [100]:
log_iter = 200
log = True

lossKL = 0
lossMSE = 0
iter = 0


for i in (range(epochs)):
    for data in tqdm(train_loader):
        imgs = data.to(device)

        iter+=1
        # Training Discriminator
        vae.zero_grad()

        recon_imgs, mu, logvar = vae(imgs)

        loss = MSEloss(imgs, recon_imgs)

        loss.backward()

        optimizer.step()

        lossMSE += loss.item()
    print(loss)

        # if((iter+1)%log_iter == 0 and log==True):

        #     vae.eval()
        #     with torch.no_grad():
        #         fixed_fake_imgs = vae.decoder(fixed_noise[:16]).detach()

        #     # wandb.log({
        #     #     'lossMSE': lossMSE,
        #     #     'lossKL': lossKL,
        #     #     'Fake Images': [wandb.Image(i) for i in fixed_fake_imgs],
        #     #     'Real Images' : [wandb.Image(i) for i in real_imgs[:16].detach()]
        #     #     })

        #     lossKL = 0
        #     lossMSE = 0

100%|██████████| 16/16 [00:00<00:00, 127.98it/s]


tensor(0.1993, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 131.75it/s]


tensor(0.2049, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 129.47it/s]


tensor(0.1887, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 133.53it/s]


tensor(0.1517, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 132.66it/s]


tensor(0.1455, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 131.49it/s]


tensor(0.1356, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 134.21it/s]


tensor(0.1463, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 134.53it/s]


tensor(0.1443, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 133.51it/s]


tensor(0.1707, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 133.70it/s]


tensor(0.1497, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 133.82it/s]


tensor(0.1650, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 133.95it/s]


tensor(0.1254, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 134.37it/s]


tensor(0.1164, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 135.08it/s]


tensor(0.1246, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 134.78it/s]


tensor(0.1090, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 132.69it/s]


tensor(0.1264, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 134.93it/s]


tensor(0.1230, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 134.45it/s]


tensor(0.1105, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 134.46it/s]


tensor(0.1329, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|██████████| 16/16 [00:00<00:00, 135.28it/s]

tensor(0.1218, device='cuda:0', grad_fn=<MseLossBackward0>)



