In [1]:
gpu_info = !nvidia-smi
gpu_info = gpu_info[:10]
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

Wed Oct  7 17:17:13 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.05    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import sys
import torch
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm_notebook as tq

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR

from torchsummary import summary

In [4]:
base_path = "/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/"
sys.path.append(base_path)

In [5]:
from src.data_utils import LoadDataset, SavePath
# from src.model_cifar import Discriminator
from src.config import TrainConfig

In [6]:
args = TrainConfig( base_path,              # project directory path
                    n_epochs = 100,         # number of epochs to train (default: 100)
                    batch_size = 128,       # input batch size for training (default: 128)
                    lr = 0.0002,            # learning rate (default: 0.0001)
                    dim_h = 128,            # hidden dimension (default: 128)')
                    n_z = 8,                # hidden dimension of z (default: 8)
                    LAMBDA = 10,            # regularization coef term (default: 10)
                    sigma = 1,              # variance of hidden dimension (default: 1)
                    n_channel = 3,          # input channels (default: 1)
                    img_size = 32 )         # image size

In [7]:
sp = SavePath(args, checkpoint_path="/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Wed-Oct--7-02-44-14-2020/")
# sp = SavePath(args)

/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Wed-Oct--7-02-44-14-2020/


In [8]:
def unfreeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = True

def freeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = False

In [9]:
def save_models(model_path, epoch_no, encoder, decoder, discriminator):
    print("Saving models")
    if encoder:
        torch.save(encoder.state_dict(), model_path + "/encoder_" + "%d.pth" % epoch_no)
    if decoder:
        torch.save(decoder.state_dict(), model_path + "/decoder_" + "%d.pth" % epoch_no)
    if discriminator:
        torch.save(discriminator.state_dict(), model_path + "/discriminator_" + "%d.pth" % epoch_no)

def save_lists(list_path, epoch_no, reconstr_loss):
    print("Saving list")
    if reconstr_loss:
        np.savetxt(list_path + '/reconstr_loss_' + '%d.txt'% epoch_no, reconstr_loss)

In [10]:
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.5]*3,[0.5]*3)
                ])

cdl = LoadDataset(args, 'cifar')

train_loader = cdl.get_data_loader(train=True)
test_loader = cdl.get_data_loader(train=False)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
class Encoder(nn.Module):
    def __init__(self, args):
        super(Encoder, self).__init__()

        self.n_channel = args.n_channel
        self.dim_h = args.dim_h
        self.n_z = args.n_z

        self.main = nn.Sequential(
            nn.Conv2d(self.n_channel, self.dim_h, 4, 2, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h, self.dim_h * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 2, self.dim_h * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 4, self.dim_h * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 8),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 8, self.dim_h, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h),
            nn.ReLU(True),
        )
        self.fc = nn.Linear(self.dim_h, self.n_z)

    def forward(self, x):
        x = self.main(x)
        x = x.squeeze()
        x = self.fc(x)
        return x

In [12]:
# enc = Encoder(args).cuda()
# summary(enc, (3, 32, 32))

In [13]:
class Decoder(nn.Module):
    def __init__(self, args):
        super(Decoder, self).__init__()

        self.n_channel = args.n_channel
        self.dim_h = args.dim_h
        self.n_z = args.n_z

        self.proj = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 1 * 2 * 2),
            nn.ReLU()
        )

        self.main = nn.Sequential(
            nn.ConvTranspose2d(self.dim_h * 1, self.dim_h * 8, 4),
            nn.BatchNorm2d(self.dim_h * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 8, self.dim_h * 4, 4, stride=2),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 4, self.dim_h * 2, 4),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 2, self.n_channel, 4, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.proj(x)
        x = x.view(-1, self.dim_h * 1, 2, 2)
        x = self.main(x)
        return x

In [14]:
# dec = Decoder(args).cuda()
# summary(dec, (1, 8))

In [15]:
class Discriminator(nn.Module):
    def __init__(self, args):
        super(Discriminator, self).__init__()

        self.n_channel = args.n_channel
        self.dim_h = args.dim_h
        self.n_z = args.n_z

        self.main = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.main(x)
        return x

In [16]:
encoder, decoder, discriminator = Encoder(args), Decoder(args), Discriminator(args)
criterion = nn.MSELoss()

encoder.train()
decoder.train()
discriminator.train()

# Optimizers
enc_optim = optim.Adam(encoder.parameters(), lr = args.lr)
dec_optim = optim.Adam(decoder.parameters(), lr = args.lr)
dis_optim = optim.Adam(discriminator.parameters(), lr = 0.5 * args.lr)

In [17]:
if torch.cuda.is_available():
    encoder, decoder, discriminator = encoder.cuda(), decoder.cuda(), discriminator.cuda()

In [18]:
one = torch.tensor(1)
mone = one * -1
if torch.cuda.is_available():
    one, mone = one.cuda(), mone.cuda()

In [19]:
checkpoint = 0

if checkpoint:
    _, list_path, model_path = sp.get_save_paths()

    encoder.load_state_dict(torch.load(
                    model_path + "/encoder_{}.pth".format(checkpoint)))
    decoder.load_state_dict(torch.load(
                    model_path + "/decoder_{}.pth".format(checkpoint)))
    discriminator.load_state_dict(torch.load(
            model_path + "/discriminator_{}.pth".format(checkpoint)))
    reconstr_loss_epoch = np.loadtxt(
        list_path + '/reconstr_loss_{}.txt'.format(checkpoint)).tolist()

if not checkpoint:
    reconstr_loss_epoch = []
reconstr_loss = []
disc_loss = []
d_real_val = []

image_path, list_path, model_path = sp.get_save_paths()

In [20]:
for epoch in range(checkpoint, checkpoint+args.n_epochs):
    for step, (images, _) in tq(enumerate(train_loader)):
        
        if torch.cuda.is_available():
            images = images.cuda()

        encoder.zero_grad()
        decoder.zero_grad()
        discriminator.zero_grad()

        # ======== Train Discriminator ======== #

        freeze_params(decoder)
        freeze_params(encoder)
        unfreeze_params(discriminator)

        z_fake = torch.randn(images.size()[0], args.n_z) * args.sigma

        if torch.cuda.is_available():
            z_fake = z_fake.cuda()

        d_fake = discriminator(z_fake)

        z_real = encoder(images)
        d_real = discriminator(z_real)

        # negate for gradient ascent
        # d_loss = -(torch.log(d_fake).mean() + torch.log(1 - d_real).mean())
        d_loss = d_fake.mean() + (1 - d_real).mean()

        disc_loss.append([d_fake.mean().item(), d_real.mean().item(), d_loss.item()])

        d_loss.backward()

        dis_optim.step()

        if (step + 1) % 10 == 0:

            # ======== Train Generator ======== #

            unfreeze_params(decoder)
            unfreeze_params(encoder)
            freeze_params(discriminator)

            batch_size = images.size()[0]

            z_real = encoder(images)
            x_recon = decoder(z_real)
            d_real = discriminator(encoder(Variable(images.data)))

            d_real_val.append(d_real.mean().item())
            # recon_loss = criterion(x_recon, images) - (args.LAMBDA * (torch.log(d_real)).mean())

            recon_loss = criterion(x_recon, images)
            # d_loss = args.LAMBDA * (torch.log(d_real)).mean()
            d_loss = args.LAMBDA * (d_real).mean()

            reconstr_loss.append(recon_loss.data.item())

            recon_loss.backward()
            d_loss.backward()

            enc_optim.step()
            dec_optim.step()

    if (epoch + 1) % 1 == 0:
        temp = np.mean(disc_loss, axis=0)
        print("Epoch: [%d/%d], Recon Loss: [%.4f], Reg loss: [%.4f], D_fake: [%.4f], D_real: [%.4f], D_loss: [%.4f]" %
                (epoch + 1, args.n_epochs, 
                 np.mean(reconstr_loss), np.mean(d_real_val), temp[0], temp[1], temp[2]))
        disc_loss.clear()
        d_real_val.clear()

    if (epoch + 1) % 1 == 0:

        reconstr_loss_epoch.append(np.mean(reconstr_loss))
        reconstr_loss.clear()

        batch_size = 128
        test_iter = iter(test_loader)
        test_data = next(test_iter)

        z_real = encoder(Variable(test_data[0]).cuda())
        reconst = decoder(z_real).cpu().view(batch_size, 3, 32, 32)
        sample = decoder(torch.randn_like(z_real)).cpu().view(batch_size, 3, 32, 32)

        image = torch.cat((test_data[0].view(batch_size, 3, 32, 32), 
                                reconst.data), axis=3)
        
        save_image(image, image_path +
                                '/inputs_reconstr_{}.png'.format(epoch+1), normalize = True)
        
        save_image(sample, image_path + 
                                '/sample_{}.png'.format(epoch+1), normalize = True)

    if (epoch + 1) % 25 == 0:
        save_models(model_path, epoch+1, encoder, decoder, discriminator)
        save_lists(list_path, epoch+1, reconstr_loss_epoch)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [1/100], Recon Loss: [0.0702], Reg loss: [0.1205], D_fake: [0.0828], D_real: [0.1036], D_loss: [0.9791]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [2/100], Recon Loss: [0.0470], Reg loss: [0.0000], D_fake: [0.0000], D_real: [0.0000], D_loss: [1.0000]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [3/100], Recon Loss: [0.0372], Reg loss: [0.0000], D_fake: [0.0000], D_real: [0.0000], D_loss: [1.0000]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

KeyboardInterrupt: ignored

In [76]:
reconstr_loss

[]

In [None]:
z_real = encoder(Variable(test_data[0]).cuda())        
reconst = decoder(z_real).cpu().view(batch_size, 1, 28, 28)
sample = decoder(torch.randn_like(z_real)).cpu().view(batch_size, 1, 28, 28)

RuntimeError: ignored

In [None]:
image = torch.cat((test_data[0].view(batch_size, 1, 28, 28), 
                                reconst.data), axis=3)

In [None]:
image.shape

torch.Size([128, 1, 28, 56])

In [None]:
np.mean(disc_loss, axis=0)

In [None]:
disc_loss[0]

Output hidden; open in https://colab.research.google.com to view.