In [5]:
gpu_info = !nvidia-smi
gpu_info = gpu_info[:10]
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

Tue Mar  2 08:16:25 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  TITAN V             On   | 00000000:17:00.0 Off |                  N/A |
| 28%   41C    P8    26W / 250W |      0MiB / 12066MiB |      0%      Default |


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [20]:
import sys
import torch
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm_notebook as tq

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR
from pytorch_model_summary import summary
from tensorboardX import SummaryWriter

In [9]:
from src.dataset.utils import SavePath
from src.dataset.dataset import Cifar
from src.config import TrainConfig
from src.pytorch_msssim import MSSSIM, ssim

In [11]:
base_path = !pwd
base_path = base_path[0] + '/'

In [12]:
args = TrainConfig( base_path,              # project directory path
                    n_epochs = 200,         # number of epochs to train (default: 100)
                    batch_size = 128,       # input batch size for training (default: 128)
                    lr = 3e-4,            # learning rate (default: 0.0001)
                    dim_h = 128,            # hidden dimension (default: 128)')
                    n_z = 128,                # hidden dimension of z (default: 8)
                    LAMBDA = 10,            # regularization coef term (default: 10)
                    sigma = 1,              # variance of hidden dimension (default: 1)
                    n_channel = 3,          # input channels (default: 1)
                    img_size = 32 )         # image size

In [13]:
def unfreeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = True

def freeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = False

In [14]:
def save_models(model_path, epoch_no, models):
    print("Saving models")
    for model_name, model in models.items():
        torch.save(model.state_dict(), model_path + model_name + "_" + "%d.pth" % epoch_no)

def save_values_to_tensorboard(writer, epoch_no, values_dict: dict):
    for name, val in values_dict.items():
        if type(val) == dict:
            writer.add_scalars(name, val, epoch_no)
        else:
            writer.add_scalar(name, val, epoch_no)

def save_images_to_tensorboard(writer, epoch_no, image, imname='im'):
    writer.add_image(imname +'_{}'.format(epoch_no), image, epoch_no)

In [15]:
sp = SavePath(args)

/home/pr/synth.data/Autoencoder/outs/Tue-Mar--2-08-21-21-2021/


In [16]:
transform = None # dont normalize

cdl = Cifar(args)
train_loader = cdl.get_data_loader(True, transform, [0,1,2,3,4])
test_loader = cdl.get_data_loader(False, transform, [0,1,2,3,4])

Files already downloaded and verified
Files already downloaded and verified


In [23]:
from src.models.model_cifar import Encoder as ConvEncoder
from src.models.model_cifar import Decoder as ConvDecoder
from src.models.model_cifar import GanDiscriminator

In [44]:
class ZDiscriminator(nn.Module):
    def __init__(self, args):
        super(ZDiscriminator, self).__init__()

        self.n_channel = args.n_channel
        self.dim_h = args.dim_h
        self.n_z = args.n_z

        self.main = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 2),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 2, self.dim_h // 2),
            nn.ReLU(True),
            nn.Linear(self.dim_h // 2, self.dim_h // 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h // 4, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.main(x)
        x = x.squeeze()
        return x

In [45]:
conv_encoder, conv_decoder = ConvEncoder(args).cuda(), ConvDecoder(args).cuda()
gan_discriminator = GanDiscriminator(args).cuda()
z_discriminator = ZDiscriminator(args).cuda()

mse_loss_fn = nn.MSELoss().cuda()
adversarial_loss_fn = nn.BCELoss().cuda()

opt_encoder = optim.Adam(conv_encoder.parameters(), lr = args.lr)
opt_decoder = optim.Adam(conv_decoder.parameters(), lr = args.lr)
opt_discriminator = optim.Adam(gan_discriminator.parameters(), lr = args.lr)
opt_zdiscriminator = optim.Adam(z_discriminator.parameters(), lr = args.lr)

In [46]:
print(summary(z_discriminator, torch.zeros((1, 1, 128)).cuda(), show_input=True, show_hierarchical=False))

-----------------------------------------------------------------------
      Layer (type)         Input Shape         Param #     Tr. Param #
          Linear-1         [1, 1, 128]          66,048          66,048
            ReLU-2         [1, 1, 512]               0               0
          Linear-3         [1, 1, 512]         262,656         262,656
            ReLU-4         [1, 1, 512]               0               0
          Linear-5         [1, 1, 512]         131,328         131,328
            ReLU-6         [1, 1, 256]               0               0
          Linear-7         [1, 1, 256]          16,448          16,448
            ReLU-8          [1, 1, 64]               0               0
          Linear-9          [1, 1, 64]           2,080           2,080
           ReLU-10          [1, 1, 32]               0               0
         Linear-11          [1, 1, 32]              33              33
        Sigmoid-12           [1, 1, 1]               0               0
Total

In [None]:
def load_models(checkpoint_path, checkpoint):
    lp = SavePath(args, checkpoint_path)
    _, _, model_load_path = lp.get_save_paths()
    conv_encoder.load_state_dict(torch.load(model_load_path + "/conv_encoder_{}.pth".format(checkpoint)))
    conv_decoder.load_state_dict(torch.load(model_load_path + "/conv_decoder_{}.pth".format(checkpoint)))
    gan_discriminator.load_state_dict(torch.load(model_load_path + "/gan_discriminator_{}.pth".format(checkpoint)))

In [19]:
checkpoint = 0
if checkpoint:
    load_models('/home/pr/synth.data/Autoencoder/outs/Mon-Mar--1-12-35-57-2021/', checkpoint)

reconstr_loss, disc_loss = [], []
norms = []
writer = SummaryWriter(log_dir = sp.results_path + "logs")

In [None]:
image_path, list_path, model_path = sp.get_save_paths()

In [18]:
one = torch.tensor(1)
mone = one * -1
if torch.cuda.is_available():
    one, mone = one.cuda(), mone.cuda()

In [20]:
# Adversarial ground truths
ones = Variable(torch.cuda.FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) # real
zeros = Variable(torch.cuda.FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) # fake

for epoch in range(checkpoint, checkpoint+args.n_epochs):
    for step, (images, _) in tq(enumerate(train_loader)):
        
        current_batch_size = images.size()[0]        
        images = images.cuda()

        encoder.zero_grad()
        decoder.zero_grad()
        z_discriminator.zero_grad()

        # ======== Train Discriminator ======== #

#         freeze_params(decoder)
#         freeze_params(encoder)
        unfreeze_params(z_discriminator)

        z_prior = torch.randn(current_batch_size, args.n_z).cuda()
        d_prior = z_discriminator(z_prior)
        z_real = encoder(images)
        d_real = z_discriminator(z_real.detach())

        # negate for gradient ascent
#         d_loss = -(torch.log(d_fake).mean() + torch.log(1 - d_real).mean())
#         d_loss = d_prior.mean() + (1 - d_real).mean()
        d_loss = adversarial_loss_fn(d_prior, ones) + adversarial_loss_fn(d_real, zeros)

        disc_loss.append(d_loss.item())

        d_loss.backward()

        opt_discriminator.step()


        # ======== Train Generator ======== #

#         unfreeze_params(decoder)
#         unfreeze_params(encoder)
        freeze_params(discriminator)

#         z_real = encoder(images)
        x_recon = decoder(z_real)
        d_real = discriminator(encoder(Variable(images.data)))

        d_real_val.append(d_real.mean().item())
        # recon_loss = criterion(x_recon, images) - (args.LAMBDA * (torch.log(d_real)).mean())

        recon_loss = criterion(x_recon, images)
        # d_loss = args.LAMBDA * (torch.log(d_real)).mean()
        d_loss = args.LAMBDA * (d_real).mean()

        reconstr_loss.append(recon_loss.data.item())

        recon_loss.backward()
        d_loss.backward()

        enc_optim.step()
        dec_optim.step()

    if (epoch + 1) % 1 == 0:
        temp = np.mean(disc_loss, axis=0)
        print("Epoch: [%d/%d], Recon Loss: [%.4f], Reg loss: [%.4f], D_fake: [%.4f], D_real: [%.4f], D_loss: [%.4f]" %
                (epoch + 1, args.n_epochs, 
                 np.mean(reconstr_loss), np.mean(d_real_val), temp[0], temp[1], temp[2]))
        disc_loss.clear()
        d_real_val.clear()

    if (epoch + 1) % 1 == 0:

        reconstr_loss_epoch.append(np.mean(reconstr_loss))
        reconstr_loss.clear()

        batch_size = 128
        test_iter = iter(test_loader)
        test_data = next(test_iter)

        z_real = encoder(Variable(test_data[0]).cuda())
        reconst = decoder(z_real).cpu().view(batch_size, 3, 32, 32)
        sample = decoder(torch.randn_like(z_real)).cpu().view(batch_size, 3, 32, 32)

        image = torch.cat((test_data[0].view(batch_size, 3, 32, 32), 
                                reconst.data), axis=3)
        
        save_image(image, image_path +
                                '/inputs_reconstr_{}.png'.format(epoch+1), normalize = True)
        
        save_image(sample, image_path + 
                                '/sample_{}.png'.format(epoch+1), normalize = True)

    if (epoch + 1) % 25 == 0:
        save_models(model_path, epoch+1, encoder, decoder, discriminator)
        save_lists(list_path, epoch+1, reconstr_loss_epoch)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [1/100], Recon Loss: [0.0702], Reg loss: [0.1205], D_fake: [0.0828], D_real: [0.1036], D_loss: [0.9791]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [2/100], Recon Loss: [0.0470], Reg loss: [0.0000], D_fake: [0.0000], D_real: [0.0000], D_loss: [1.0000]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [3/100], Recon Loss: [0.0372], Reg loss: [0.0000], D_fake: [0.0000], D_real: [0.0000], D_loss: [1.0000]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

KeyboardInterrupt: ignored

In [76]:
reconstr_loss

[]

In [None]:
z_real = encoder(Variable(test_data[0]).cuda())        
reconst = decoder(z_real).cpu().view(batch_size, 1, 28, 28)
sample = decoder(torch.randn_like(z_real)).cpu().view(batch_size, 1, 28, 28)

RuntimeError: ignored

In [None]:
image = torch.cat((test_data[0].view(batch_size, 1, 28, 28), 
                                reconst.data), axis=3)

In [None]:
image.shape

torch.Size([128, 1, 28, 56])

In [None]:
np.mean(disc_loss, axis=0)

In [None]:
disc_loss[0]

Output hidden; open in https://colab.research.google.com to view.

In [3]:
# encoder block (used in encoder and discriminator)
class EncoderBlock(nn.Module):
    def __init__(self, channel_in, channel_out):
        super(EncoderBlock, self).__init__()
        # convolution to halve the dimensions
        self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=5, padding=2, stride=2,
                              bias=False)
        self.bn = nn.BatchNorm2d(num_features=channel_out, momentum=0.9)

    def forward(self, ten, out=False,t = False):
        # here we want to be able to take an intermediate output for reconstruction error
        if out:
            ten = self.conv(ten)
            ten_out = ten
            ten = self.bn(ten)
            ten = F.relu(ten, False)
            return ten, ten_out
        else:
            ten = self.conv(ten)
            ten = self.bn(ten)
            ten = F.relu(ten, True)
            return ten


# decoder block (used in the decoder)
class DecoderBlock(nn.Module):
    def __init__(self, channel_in, channel_out):
        super(DecoderBlock, self).__init__()
        # transpose convolution to double the dimensions
        self.conv = nn.ConvTranspose2d(channel_in, channel_out, kernel_size=5, padding=2, stride=2, output_padding=1,
                                       bias=False)
        self.bn = nn.BatchNorm2d(channel_out, momentum=0.9)

    def forward(self, ten):
        ten = self.conv(ten)
        ten = self.bn(ten)
        ten = F.relu(ten, True)
        return ten


class Encoder(nn.Module):
    def __init__(self, channel_in=3, z_size=128):
        super(Encoder, self).__init__()
        self.size = channel_in
        layers_list = []
        # the first time 3->64, for every other double the channel size
        for i in range(3):
            if i == 0:
                layers_list.append(EncoderBlock(channel_in=self.size, channel_out=64))
                self.size = 64
            else:
                layers_list.append(EncoderBlock(channel_in=self.size, channel_out=self.size * 2))
                self.size *= 2

        # final shape Bx256x8x8
        self.conv = nn.Sequential(*layers_list)
        self.fc = nn.Sequential(nn.Linear(in_features=8 * 8 * self.size, out_features=1024, bias=False),
                                nn.BatchNorm1d(num_features=1024,momentum=0.9),
                                nn.ReLU(True))
        # two linear to get the mu vector and the diagonal of the log_variance
        self.l_mu = nn.Linear(in_features=1024, out_features=z_size)
        self.l_var = nn.Linear(in_features=1024, out_features=z_size)

    def forward(self, ten):
        ten = self.conv(ten)
        ten = ten.view(len(ten), -1)
        ten = self.fc(ten)
        mu = self.l_mu(ten)
        logvar = self.l_var(ten)
        return mu, logvar

    def __call__(self, *args, **kwargs):
        return super(Encoder, self).__call__(*args, **kwargs)


class Decoder(nn.Module):
    def __init__(self, z_size, size):
        super(Decoder, self).__init__()
        # start from B*z_size
        self.fc = nn.Sequential(nn.Linear(in_features=z_size, out_features=8 * 8 * size, bias=False),
                                nn.BatchNorm1d(num_features=8 * 8 * size,momentum=0.9),
                                nn.ReLU(True))
        self.size = size
        layers_list = []
        layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size))
        layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size//2))
        self.size = self.size//2
        layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size//4))
        self.size = self.size//4
        # final conv to get 3 channels and tanh layer
        layers_list.append(nn.Sequential(
            nn.Conv2d(in_channels=self.size, out_channels=3, kernel_size=5, stride=1, padding=2),
            nn.Tanh()
        ))

        self.conv = nn.Sequential(*layers_list)

    def forward(self, ten):

        ten = self.fc(ten)
        ten = ten.view(len(ten), -1, 8, 8)
        ten = self.conv(ten)
        return ten

    def __call__(self, *args, **kwargs):
        return super(Decoder, self).__call__(*args, **kwargs)


class Discriminator(nn.Module):
    def __init__(self, channel_in=3,recon_level=3):
        super(Discriminator, self).__init__()
        self.size = channel_in
        self.recon_levl = recon_level
        # module list because we need need to extract an intermediate output
        self.conv = nn.ModuleList()
        self.conv.append(nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True)))
        self.size = 32
        self.conv.append(EncoderBlock(channel_in=self.size, channel_out=128))
        self.size = 128
        self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256))
        self.size = 256
        self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256))
        # final fc to get the score (real or fake)
        self.fc = nn.Sequential(
            nn.Linear(in_features=8 * 8 * self.size, out_features=512, bias=False),
            nn.BatchNorm1d(num_features=512,momentum=0.9),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=1),
        )

    def forward(self, ten_orig, ten_predicted, ten_sampled, mode='REC'):
        if mode == "REC":
            ten = torch.cat((ten_orig, ten_predicted, ten_sampled), 0)
            for i, lay in enumerate(self.conv):
                # we take the 9th layer as one of the outputs
                if i == self.recon_levl:
                    ten, layer_ten = lay(ten, True)
                    # we need the layer representations just for the original and reconstructed,
                    # flatten, because it's a convolutional shape
                    layer_ten = layer_ten.view(len(layer_ten), -1)
                    return layer_ten
                else:
                    ten = lay(ten)
        else:
            ten = torch.cat((ten_orig, ten_predicted, ten_sampled), 0)
            for i, lay in enumerate(self.conv):
                    ten = lay(ten)

            ten = ten.view(len(ten), -1)
            ten = self.fc(ten)
            return F.sigmoid(ten)


    def __call__(self, *args, **kwargs):
        return super(Discriminator, self).__call__(*args, **kwargs)

In [4]:
disc = Discriminator()

In [None]:
# print(summary(conv_encoder, torch.zeros((1, 3, 32, 32)).cuda(), show_input=False, show_hierarchical=False))
# print(summary(conv_decoder, torch.zeros((1, 1, 100)).cuda(), show_input=False, show_hierarchical=False))
print(summary(disc, torch.zeros((1, 1, 32, 32)).cuda(), show_input=False, show_hierarchical=False))