In [1]:
gpu_info = !nvidia-smi
gpu_info = gpu_info[:10]
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

Fri Sep 25 19:15:58 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.66       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P8     7W /  75W |      0MiB /  7611MiB |      0%      Default |


In [2]:
# Mount the user's Google Drive into the Colab filesystem; the project
# directory used below lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import sys
import torch
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm_notebook as tq

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR
from torchsummary import summary

In [4]:
# Project directory on the mounted Drive. It is appended to sys.path so the
# `src` package imported in the next cell can be resolved (presumably `src/`
# sits directly under this directory -- confirm if the layout changes).
base_path = "/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/"
sys.path.append(base_path)
# Seed PyTorch's default generator. Being the last expression of the cell,
# the returned Generator object is echoed as the cell output.
torch.manual_seed(123)

<torch._C.Generator at 0x7f7f51af4228>

In [5]:
from src.data_utils import LoadDataset, SavePath
from src.config import TrainConfig

In [6]:
# Training configuration shared by the data loader, the models and the
# training loop. The "default" values quoted in the trailing comments are the
# ones the original author recorded for TrainConfig -- verify against
# src/config.py.
args = TrainConfig( base_path,              # project directory path
                    n_epochs = 100,         # number of epochs to train (default: 100)
                    batch_size = 128,       # input batch size for training (default: 128)
                    lr = 0.0002,            # learning rate (default: 0.0001; overridden here)
                    dim_h = 128,            # hidden dimension (default: 128)
                    n_z = 8,                # hidden dimension of z (default: 8)
                    LAMBDA = 10,            # regularization coef term (default: 10)
                    sigma = 1,              # scale applied to the sampled latent noise (default: 1)
                    n_channel = 1,          # input channels (default: 1)
                    img_size = 28 )         # image size (side length; models use img_size ** 2 inputs)

In [7]:
# Helper that provides the image / list / model output directories (see
# sp.get_save_paths() below). The cell output shows a timestamped directory
# under <base_path>/outs/, so a fresh SavePath presumably starts a new run
# directory -- confirm in src/data_utils.py.
# To continue writing into an earlier run, pass its directory explicitly:
# sp = SavePath(args, checkpoint_path="/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Thu-Sep-24-21-50-53-2020/")
sp = SavePath(args)

/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Fri-Sep-25-19-17-02-2020/


In [8]:
# Build the dataset wrapper once and derive one loader per split. The loaders
# are iterated as (images, labels) batches by the training loop further down.
cdl = LoadDataset(args)
train_loader = cdl.get_data_loader(train=True)
test_loader = cdl.get_data_loader(train=False)

In [9]:
# for step, (images, _) in tq(enumerate(train_loader)):
#     shape = images.shape
#     images = images.reshape([shape[0], shape[1], shape[2]*shape[3]])
#     print(images.shape)
#     # print(images[0].fft(signal_ndim=1, normalized=False))
#     break

In [21]:
class Encoder(nn.Module):
    """Fully connected encoder: flattened image -> latent code.

    The input width is ``args.img_size ** 2``; it is narrowed through hidden
    layers of ``8 * dim_h``, ``4 * dim_h``, ``2 * dim_h`` and ``dim_h`` units
    (each followed by an in-place ReLU) and finally projected, without an
    activation, to ``args.n_z`` latent dimensions.

    Args:
        args: configuration object exposing ``n_z``, ``dim_h`` and ``img_size``.
    """

    def __init__(self, args):
        super().__init__()

        self.n_z = args.n_z
        self.dim_h = args.dim_h
        self.dim_input = args.img_size ** 2

        # Widths of the ReLU-activated part of the stack, input first.
        widths = [self.dim_input] + [self.dim_h * mult for mult in (8, 4, 2, 1)]

        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(True))
        # Linear read-out into the latent space (no non-linearity).
        layers.append(nn.Linear(widths[-1], self.n_z))

        # Kept under the attribute name `main` so state_dict keys stay
        # `main.<index>.weight` / `main.<index>.bias`.
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dimension ``img_size ** 2``)."""
        return self.main(x)

In [22]:
class Decoder(nn.Module):
    """Fully connected decoder: latent code -> flattened image.

    Mirrors the encoder: ``args.n_z`` inputs are widened through hidden layers
    of ``dim_h``, ``2 * dim_h``, ``4 * dim_h`` and ``8 * dim_h`` units (each
    followed by an in-place ReLU) and mapped to ``args.img_size ** 2`` outputs
    squashed into (0, 1) by a sigmoid.

    Args:
        args: configuration object exposing ``n_z``, ``dim_h`` and ``img_size``.
    """

    def __init__(self, args):
        super().__init__()

        self.n_z = args.n_z
        self.dim_h = args.dim_h
        self.dim_output = args.img_size ** 2

        # Widths of the ReLU-activated part of the stack, latent size first.
        widths = [self.n_z] + [self.dim_h * mult for mult in (1, 2, 4, 8)]

        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU(True)]
        # Output layer: one unit per pixel, bounded by the sigmoid.
        stack += [nn.Linear(widths[-1], self.dim_output), nn.Sigmoid()]

        # Attribute name `main` is part of the saved state_dict key prefix.
        self.main = nn.Sequential(*stack)

    def forward(self, x):
        """Decode ``x`` (last dimension ``n_z``) into ``img_size ** 2`` values."""
        return self.main(x)

In [23]:
class Discriminator(nn.Module):
    """Latent-space discriminator: latent code -> single probability.

    Takes ``args.n_z``-dimensional codes, passes them through four hidden
    layers of ``4 * dim_h`` units (each followed by an in-place ReLU) and
    emits one sigmoid-activated score per code.

    Args:
        args: configuration object exposing ``n_z`` and ``dim_h``.
    """

    def __init__(self, args):
        super().__init__()

        self.dim_h = args.dim_h
        self.n_z = args.n_z

        hidden = self.dim_h * 4

        # First hidden layer reads the latent code ...
        blocks = [nn.Linear(self.n_z, hidden), nn.ReLU(True)]
        # ... followed by three more hidden layers of the same width ...
        for _ in range(3):
            blocks.extend((nn.Linear(hidden, hidden), nn.ReLU(True)))
        # ... and a single sigmoid output unit.
        blocks.extend((nn.Linear(hidden, 1), nn.Sigmoid()))

        # Attribute name `main` is part of the saved state_dict key prefix.
        self.main = nn.Sequential(*blocks)

    def forward(self, x):
        """Score ``x`` (last dimension ``n_z``); output lies in (0, 1)."""
        return self.main(x)

In [24]:
def unfreeze_params(module: nn.Module):
    """Re-enable gradient tracking for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad_(True)

def freeze_params(module: nn.Module):
    """Disable gradient tracking for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad_(False)

In [25]:
def save_models(model_path, epoch_no, encoder, decoder, discriminator):
    """Write the state_dict of each given network to ``model_path``.

    Files are named ``encoder_<epoch_no>.pth``, ``decoder_<epoch_no>.pth`` and
    ``discriminator_<epoch_no>.pth``.

    Args:
        model_path: directory the ``.pth`` files are written into.
        epoch_no: integer epoch number embedded in the file names.
        encoder, decoder, discriminator: modules to save; pass ``None`` to
            skip the corresponding file.
    """
    print("Saving models")
    named_modules = (
        ("encoder", encoder),
        ("decoder", decoder),
        ("discriminator", discriminator),
    )
    for prefix, module in named_modules:
        # Compare against None explicitly: truth-testing a module falls back
        # to __len__ when it defines one, so e.g. an empty nn.Sequential
        # would be silently skipped by a plain `if module:`.
        if module is None:
            continue
        # os.path.join yields the same file whether or not model_path already
        # ends with a path separator.
        target = os.path.join(model_path, "%s_%d.pth" % (prefix, epoch_no))
        torch.save(module.state_dict(), target)

def save_lists(list_path, epoch_no, reconstr_loss):
    """Write the loss history to ``<list_path>/reconstr_loss_<epoch_no>.txt``.

    Args:
        list_path: directory the text file is written into.
        epoch_no: integer epoch number embedded in the file name.
        reconstr_loss: sequence (list or numpy array) of loss values, one per
            line in the output file. ``None`` or an empty sequence writes
            nothing.
    """
    print("Saving list")
    if reconstr_loss is None:
        return
    # Normalise to an array first: truth-testing a multi-element numpy array
    # raises ValueError, so `if reconstr_loss:` only ever worked for lists.
    values = np.atleast_1d(np.asarray(reconstr_loss))
    if values.size == 0:
        return
    np.savetxt(os.path.join(list_path, 'reconstr_loss_%d.txt' % epoch_no), values)

In [26]:
# Build the three networks (in the same order as before, so seeded weight
# initialisation is unchanged) and put each one in training mode;
# nn.Module.train() returns the module itself.
encoder = Encoder(args).train()
decoder = Decoder(args).train()
discriminator = Discriminator(args).train()

# Mean-squared-error criterion used for the reconstruction term.
criterion = nn.MSELoss()

# Optimizers: one Adam instance per network. The discriminator is updated
# with half the learning rate of the encoder / decoder.
enc_optim = optim.Adam(encoder.parameters(), lr=args.lr)
dec_optim = optim.Adam(decoder.parameters(), lr=args.lr)
dis_optim = optim.Adam(discriminator.parameters(), lr=args.lr * 0.5)

In [27]:
# Move all three networks to the GPU when one is available.
if torch.cuda.is_available():
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    discriminator = discriminator.cuda()

In [28]:
# Scalar +1 / -1 tensors (only referenced by a commented-out backward() call
# in the training loop), placed on the GPU when one is available.
one = torch.tensor(1)
mone = -one
if torch.cuda.is_available():
    one = one.cuda()
    mone = mone.cuda()

In [29]:
# Epoch number of a previously saved checkpoint to resume from.
# 0 (falsy) means "start from scratch".
checkpoint = 0

if checkpoint:
    _, list_path, model_path = sp.get_save_paths()

    # map_location="cpu" lets a checkpoint written on a GPU runtime be read on
    # a CPU-only one; load_state_dict() then copies the values into the
    # existing parameters on whatever device the model currently lives on.
    encoder.load_state_dict(torch.load(
        model_path + "/encoder_{}.pth".format(checkpoint), map_location="cpu"))
    decoder.load_state_dict(torch.load(
        model_path + "/decoder_{}.pth".format(checkpoint), map_location="cpu"))
    discriminator.load_state_dict(torch.load(
        model_path + "/discriminator_{}.pth".format(checkpoint), map_location="cpu"))

    # np.loadtxt returns a 0-d array for a one-line file, whose .tolist() is a
    # bare float; atleast_1d guarantees a list the training loop can append to.
    reconstr_loss_epoch = np.atleast_1d(np.loadtxt(
        list_path + '/reconstr_loss_{}.txt'.format(checkpoint))).tolist()

    # NOTE(review): optimizer state is not checkpointed by save_models(), so
    # the Adam moment estimates restart from zero after a resume.
else:
    reconstr_loss_epoch = []

# Per-batch losses of the epoch currently being trained.
reconstr_loss = []

In [30]:
# Output directories used by the training loop: preview images, loss lists
# and model checkpoints respectively.
image_path, list_path, model_path = sp.get_save_paths()

In [None]:
# Adversarial auto-encoder training.
#
# Every batch performs two updates:
#   1. discriminator step -- encoder/decoder frozen; the discriminator learns
#      to score noise-sampled codes high and encoder codes low;
#   2. generator step     -- discriminator frozen; encoder and decoder
#      minimise MSE reconstruction error plus LAMBDA * -log D(encoder code).
# After each epoch a reconstruction preview and a random-sample preview are
# written to image_path; every 25 epochs the models and the per-epoch loss
# history are saved.
use_cuda = torch.cuda.is_available()
total_epochs = checkpoint + args.n_epochs  # final epoch number, also when resuming

for epoch in range(checkpoint, total_epochs):
    # Reset the per-batch loss list once per EPOCH. Clearing it inside the
    # batch loop would leave only the last batch's value, so the "epoch mean"
    # appended below would not be a mean at all.
    reconstr_loss.clear()

    # Wrap the loader itself (not enumerate(loader)) so tqdm can read its
    # length and render a real progress bar with a total.
    for step, (images, _) in enumerate(tq(train_loader)):

        # Flatten every image to a single row of img_size**2 values.
        images = images.reshape([images.size(0), 1, args.img_size ** 2])
        if use_cuda:
            images = images.cuda()

        encoder.zero_grad()
        decoder.zero_grad()
        discriminator.zero_grad()

        # ======== Train Discriminator ======== #

        freeze_params(decoder)
        freeze_params(encoder)
        unfreeze_params(discriminator)

        # Gaussian noise scaled by args.sigma, one code per image in the batch.
        z_fake = torch.randn(images.size(0), args.n_z) * args.sigma
        if use_cuda:
            z_fake = z_fake.cuda()

        d_fake = discriminator(z_fake)

        z_real = encoder(images)
        d_real = discriminator(z_real)

        # negate for gradient ascent
        # NOTE(review): log() of a saturated sigmoid output is -inf; consider
        # clamping or a logits-based loss if NaNs appear during training.
        d_loss = -(torch.log(d_fake).mean() + torch.log(1 - d_real).mean())
        d_loss.backward()
        dis_optim.step()

        # ======== Train Generator ======== #

        unfreeze_params(decoder)
        unfreeze_params(encoder)
        freeze_params(discriminator)

        z_real = encoder(images)
        x_recon = decoder(z_real)
        # Reuse z_real instead of running the encoder a second time on the
        # same batch: the encoder is deterministic, so the value and the
        # gradient reaching it are the same.
        d_real = discriminator(z_real)

        # NOTE: despite its name this is the full generator objective
        # (MSE + LAMBDA * adversarial penalty); it is what gets logged below.
        recon_loss = criterion(x_recon, images) - args.LAMBDA * torch.log(d_real).mean()
        recon_loss.backward()

        enc_optim.step()
        dec_optim.step()

        reconstr_loss.append(recon_loss.item())

    # Loss of the last batch of the epoch.
    print("Epoch: [%d/%d], Step: [%d/%d], Reconstruction Loss: %.4f" %
          (epoch + 1, total_epochs, step + 1, len(train_loader), recon_loss.item()))

    # Mean over all batches of this epoch.
    reconstr_loss_epoch.append(np.mean(reconstr_loss))

    # ======== Preview images ======== #
    # No gradients are needed here; no_grad avoids building a graph.
    with torch.no_grad():
        test_data = next(iter(test_loader))[0]
        # Use the actual batch size rather than args.batch_size so a short
        # batch does not break the reshape.
        batch_size = test_data.size(0)
        test_data = test_data.reshape([batch_size, 1, args.img_size ** 2])

        # Only move to the GPU when there is one, like the training branch.
        z_real = encoder(test_data.cuda() if use_cuda else test_data)
        reconst = decoder(z_real).cpu().view(
            batch_size, 1, args.img_size, args.img_size)
        # Sample with the same noise scale the discriminator is trained on.
        sample = decoder(torch.randn_like(z_real) * args.sigma).cpu().view(
            batch_size, 1, args.img_size, args.img_size)

    # Inputs and their reconstructions side by side (concatenated along width).
    image = torch.cat(
        (test_data.view(batch_size, 1, args.img_size, args.img_size), reconst),
        dim=3)

    save_image(image, image_path +
                            '/inputs_reconstr_{}.png'.format(epoch+1))

    save_image(sample, image_path +
                            '/sample_{}.png'.format(epoch+1))

    if (epoch + 1) % 25 == 0:
        save_models(model_path, epoch+1, encoder, decoder, discriminator)
        save_lists(list_path, epoch+1, reconstr_loss_epoch)
    

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [1/100], Step: [469/469], Reconstruction Loss: 14.2895


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [2/100], Step: [469/469], Reconstruction Loss: 7.4935


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [3/100], Step: [469/469], Reconstruction Loss: 5.5590


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [4/100], Step: [469/469], Reconstruction Loss: 8.3077


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [5/100], Step: [469/469], Reconstruction Loss: 9.7592


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [6/100], Step: [469/469], Reconstruction Loss: 4.7941


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [7/100], Step: [469/469], Reconstruction Loss: 8.5357


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [8/100], Step: [469/469], Reconstruction Loss: 8.0075


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [9/100], Step: [469/469], Reconstruction Loss: 6.9264


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [10/100], Step: [469/469], Reconstruction Loss: 8.4204


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [11/100], Step: [469/469], Reconstruction Loss: 4.5407


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))