In [1]:
import sys
import torch
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm_notebook as tq

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR

In [2]:
base_path = "/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/"
sys.path.append(base_path)
torch.manual_seed(123)

<torch._C.Generator at 0x7fcf61b9d1b0>

In [3]:
from src.data_utils import CustomDataLoader, SavePath
from src.model import Encoder, Decoder, Discriminator

In [13]:
sp = SavePath(args)

None


In [4]:
def unfreeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = True

def freeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = False

In [5]:
def save_models(model_path, epoch_no, encoder, decoder, discriminator):
    if encoder:
        torch.save(encoder.state_dict(), model_path + "/encoder_" + "%d.pth" % epoch_no)
    if decoder:
        torch.save(decoder.state_dict(), model_path + "/decoder_" + "%d.pth" % epoch_no)
    if discriminator:
        torch.save(discriminator.state_dict(), model_path + "/discriminator_" + "%d.pth" % epoch_no)

def save_lists(list_path, epoch_no, reconstr_loss):
    if reconstr_loss:
        np.savetxt(lists_path + '/reconstr_loss_' + '%d.txt'% epoch_no, reconstr_loss)

In [6]:
class options:
    def __init__(self):
        self.n_epochs = 100         # number of epochs to train (default: 100)
        self.batch_size = 128       # input batch size for training (default: 128)
        self.lr = 0.0002            # learning rate (default: 0.0001)
        self.dim_h = 128            # hidden dimension (default: 128)')
        self.n_z = 8                # hidden dimension of z (default: 8)
        self.LAMBDA = 10            # regularization coef term (default: 10)
        self.sigma = 1              # variance of hidden dimension (default: 1)
        self.n_channel = 1          # input channels (default: 1)
        self.img_size = 28          # image size
        self.base_path = base_path  # project directory path

args = options()

In [7]:
cdl = CustomDataLoader(args)
train_loader = cdl.get_data_loader(train=True)
test_loader = cdl.get_data_loader(train=False)

In [8]:
encoder, decoder, discriminator = Encoder(args), Decoder(args), Discriminator(args)
criterion = nn.MSELoss()

encoder.train()
decoder.train()
discriminator.train()

Discriminator(
  (main): Sequential(
    (0): Linear(in_features=8, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU(inplace=True)
    (8): Linear(in_features=512, out_features=1, bias=True)
    (9): Sigmoid()
  )
)

In [9]:
# Optimizers
enc_optim = optim.Adam(encoder.parameters(), lr = args.lr)
dec_optim = optim.Adam(decoder.parameters(), lr = args.lr)
dis_optim = optim.Adam(discriminator.parameters(), lr = 0.5 * args.lr)

enc_scheduler = StepLR(enc_optim, step_size=30, gamma=0.5)
dec_scheduler = StepLR(dec_optim, step_size=30, gamma=0.5)
dis_scheduler = StepLR(dis_optim, step_size=30, gamma=0.5)

In [10]:
if torch.cuda.is_available():
    encoder, decoder, discriminator = encoder.cuda(), decoder.cuda(), discriminator.cuda()

In [11]:
one = torch.Tensor([1])
mone = one * -1

In [12]:
if torch.cuda.is_available():
    one, mone = one.cuda(), mone.cuda()

In [15]:
image_path, list_path, model_path = sp.get_save_paths()

In [16]:
reconstr_loss_epoch = []
reconstr_loss = []

In [17]:
for epoch in range(args.n_epochs):
    for step, (images, _) in tq(enumerate(train_loader)):
        
        reconstr_loss.clear() 
        # discriminator_loss.clear()
        
        if torch.cuda.is_available():
            images = images.cuda()

        encoder.zero_grad()
        decoder.zero_grad()
        discriminator.zero_grad()

        # ======== Train Discriminator ======== #

        freeze_params(decoder)
        freeze_params(encoder)
        unfreeze_params(discriminator)

        z_fake = torch.randn(images.size()[0], args.n_z) * args.sigma

        if torch.cuda.is_available():
            z_fake = z_fake.cuda()

        d_fake = discriminator(z_fake)

        z_real = encoder(images)
        d_real = discriminator(z_real)

        torch.log(d_fake).mean().backward(mone)
        torch.log(1 - d_real).mean().backward(mone)

        dis_optim.step()

        # ======== Train Generator ======== #

        unfreeze_params(decoder)
        unfreeze_params(encoder)
        freeze_params(discriminator)

        batch_size = images.size()[0]

        z_real = encoder(images)
        x_recon = decoder(z_real)
        d_real = discriminator(encoder(Variable(images.data)))

        recon_loss = criterion(x_recon, images)
        d_loss = args.LAMBDA * (torch.log(d_real)).mean()

        recon_loss.backward(one)
        d_loss.backward(mone)

        enc_optim.step()
        dec_optim.step()

        reconstr_loss.append(recon_loss.data.item())

        if (step + 1) % 300 == 0:
            print("Epoch: [%d/%d], Step: [%d/%d], Reconstruction Loss: %.4f" %
                  (epoch + 1, args.epochs, step + 1, len(train_loader), recon_loss.data.item()))

    if (epoch + 1) % 1 == 0:

        reconstr_loss_epoch.append(np.mean(reconstr_loss))

        batch_size = 104
        test_iter = iter(test_loader)
        test_data = next(test_iter)

        z_real = encoder(Variable(test_data[0]).cuda())
        reconst = decoder(torch.randn_like(z_real)).cpu().view(batch_size, 1, 28, 28)

        save_image(test_data[0].view(batch_size, 1, 28, 28), image_path + '/wae_gan_input.png')
        save_image(reconst.data, image_path + '/wae_gan_images_%d.png' % (epoch + 1))
    
    if (epoch + 1) % 25 == 0:
        save_models(model_path, epoch+1, encoder, decoder, discriminator)
        save_lists(list_path, epoch+1, reconstr_loss_epoch)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

RuntimeError: ignored

In [127]:
class Nested:
    def __init__(self, nestedList):
            
            self.curr = nestedList
            self.stack, self.index = [], 0
            
    def flatten(self, curr, stack, index):

        if index < len(curr):
            if type(curr[index]) is int:
                # return curr[index], curr, stack, index+1
                retval = curr[index]
                index += 1
            else:
                stack.append((curr, index+1))
                curr = curr[index]
                index = 0
                retval = None
                
        elif stack:
            curr, index = stack.pop()
            retval = None
        else:
            curr = None
            retval = None

        return retval, curr, stack, index

    def value_generator(self):
        while self.curr or self.stack:
            val, self.curr, self.stack, self.index = self.flatten(self.curr, self.stack, self.index)
            print(val, self.curr, self.stack, self.index)
            if val:
                # print(val, self.curr, self.stack, self.index)
                yield val


In [128]:
nestedList = [1, [2, [3]],4,[5,6],7]

In [129]:
nested = Nested(nestedList)

In [130]:
while 1:
    try:
        next(nested.value_generator())
    except StopIteration:
        print(False)
        break

1 [1, [2, [3]], 4, [5, 6], 7] [] 1
None [2, [3]] [([1, [2, [3]], 4, [5, 6], 7], 2)] 0
2 [2, [3]] [([1, [2, [3]], 4, [5, 6], 7], 2)] 1
None [3] [([1, [2, [3]], 4, [5, 6], 7], 2), ([2, [3]], 2)] 0
3 [3] [([1, [2, [3]], 4, [5, 6], 7], 2), ([2, [3]], 2)] 1
None [2, [3]] [([1, [2, [3]], 4, [5, 6], 7], 2)] 2
None [1, [2, [3]], 4, [5, 6], 7] [] 2
4 [1, [2, [3]], 4, [5, 6], 7] [] 3
None [5, 6] [([1, [2, [3]], 4, [5, 6], 7], 4)] 0
5 [5, 6] [([1, [2, [3]], 4, [5, 6], 7], 4)] 1
6 [5, 6] [([1, [2, [3]], 4, [5, 6], 7], 4)] 2
None [1, [2, [3]], 4, [5, 6], 7] [] 4
7 [1, [2, [3]], 4, [5, 6], 7] [] 5
None None [] 5
False
