In [26]:
import pytorch_lightning as pl
import torch
import torch.nn as nn

from src.datamodule import CelebAData
from src.models.cyclegan import CycleGan
# Seed the global RNGs with a fixed value so Restart & Run All reproduces the
# run; 1045058964 is the seed the previously unseeded call happened to draw.
SEED = 1045058964
pl.seed_everything(SEED)


class EncoderBlock(nn.Module):
    """Downsampling residual block used by the FactorVAE2 encoder.

    A 4x4 stride-2 convolution (padding 1) maps ``in_channels`` to
    ``out_channels`` and downsamples by 2; two 3x3 "same"-padded convolutions
    then refine the features, and their output is added back to the
    downsampled features (residual connection). Every convolution is followed
    by BatchNorm and ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.c1 = nn.Conv2d(in_channels, out_channels, 4, 2, 1)
        self.b1 = nn.BatchNorm2d(out_channels)
        self.c2 = nn.Conv2d(out_channels, out_channels, 3, padding="same")
        self.b2 = nn.BatchNorm2d(out_channels)
        self.c3 = nn.Conv2d(out_channels, out_channels, 3, padding="same")
        self.b3 = nn.BatchNorm2d(out_channels)
        # Build the activation once rather than on every forward call. ReLU has
        # no parameters or buffers, so existing state_dicts still load.
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``skip + refine(skip)`` where ``skip`` is the downsampled input."""
        skip = self.act(self.b1(self.c1(x)))
        # In-place ReLU only touches freshly produced BN outputs, never `skip`.
        out = self.act(self.b2(self.c2(skip)))
        out = self.act(self.b3(self.c3(out)))
        return skip + out


class DecoderBlock(nn.Module):
    """Upsampling residual block used by the FactorVAE2 decoder.

    A 4x4 stride-2 transposed convolution (padding 1) maps ``in_channels`` to
    ``out_channels`` and upsamples by 2; two 3x3 stride-1 transposed
    convolutions with padding 1 (size-preserving) refine the features, and
    their output is added back to the upsampled features (residual
    connection). Every convolution is followed by BatchNorm and ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.c1 = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)
        self.b1 = nn.BatchNorm2d(out_channels)
        self.c2 = nn.ConvTranspose2d(out_channels, out_channels, 3, padding=1)
        self.b2 = nn.BatchNorm2d(out_channels)
        self.c3 = nn.ConvTranspose2d(out_channels, out_channels, 3, padding=1)
        self.b3 = nn.BatchNorm2d(out_channels)
        # Build the activation once rather than on every forward call. ReLU has
        # no parameters or buffers, so existing state_dicts still load.
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``skip + refine(skip)`` where ``skip`` is the upsampled input."""
        skip = self.act(self.b1(self.c1(x)))
        # In-place ReLU only touches freshly produced BN outputs, never `skip`.
        out = self.act(self.b2(self.c2(skip)))
        out = self.act(self.b3(self.c3(out)))
        return skip + out

Global seed set to 1045058964


In [48]:
import torch
import torch.nn as nn
import torch.nn.init as init

class FactorVAE2(nn.Module):
    """Residual convolutional VAE encoder/decoder for 3D Shapes, CelebA, Chairs data.

    Encoder: seven ``EncoderBlock`` stages (each a stride-2 downsampling)
    followed by a 1x1 convolution producing ``2 * z_dim`` channels, which are
    split into the mean and log-variance of the latent Gaussian.
    Decoder: a 1x1 convolution to 128 channels, a ReLU, and seven
    ``DecoderBlock`` stages (each a stride-2 upsampling) ending in 3 channels.

    Args:
        z_dim: latent dimensionality (default 128).
    """

    def __init__(self, z_dim=128):
        super(FactorVAE2, self).__init__()
        self.z_dim = z_dim
        self.encode = nn.Sequential(
            EncoderBlock(3, 32),
            EncoderBlock(32, 32),
            EncoderBlock(32, 32),
            EncoderBlock(32, 64),
            EncoderBlock(64, 64),
            EncoderBlock(64, 128),
            EncoderBlock(128, 128),
            nn.Conv2d(128, 2 * z_dim, 1),
        )
        self.decode = nn.Sequential(
            nn.Conv2d(z_dim, 128, 1),
            nn.ReLU(True),
            DecoderBlock(128, 128),
            DecoderBlock(128, 64),
            DecoderBlock(64, 64),
            DecoderBlock(64, 32),
            DecoderBlock(32, 32),
            DecoderBlock(32, 32),
            DecoderBlock(32, 3),
        )
        self.weight_init()

    def weight_init(self, mode='normal'):
        """(Re-)initialise every layer handled by the chosen initializer.

        Args:
            mode: ``'normal'`` (uses ``normal_init``) or ``'kaiming'``
                (uses ``kaiming_init``).

        Raises:
            ValueError: if ``mode`` is not one of the supported names.
        """
        if mode == 'kaiming':
            initializer = kaiming_init
        elif mode == 'normal':
            initializer = normal_init
        else:
            raise ValueError(
                f"unknown init mode {mode!r}; expected 'normal' or 'kaiming'")

        # modules() recurses: the Conv/BatchNorm layers live inside the
        # Encoder/DecoderBlocks, one level below the Sequentials' children,
        # so iterating only the direct children never reached them.
        for m in self.modules():
            initializer(m)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = logvar.mul(0.5).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, no_dec=False):
        """Encode ``x``, sample a latent ``z`` and, unless ``no_dec``, decode it.

        Returns:
            ``z`` if ``no_dec`` is true, otherwise ``(x_recon, mu, logvar, z)``.
            The returned ``z`` has its trailing singleton spatial dims removed
            (e.g. ``(B, z_dim, 1, 1)`` -> ``(B, z_dim)``); the batch dim is
            always kept, even for a batch of one.
        """
        stats = self.encode(x)
        mu = stats[:, :self.z_dim]
        logvar = stats[:, self.z_dim:]
        z = self.reparametrize(mu, logvar)
        # A bare squeeze() would also drop the batch dim when B == 1.
        z_flat = z.squeeze(3).squeeze(2)
        if no_dec:
            return z_flat
        x_recon = self.decode(z)
        return x_recon, mu, logvar, z_flat


class Discriminator(nn.Module):
    """MLP over latent codes: five 1000-unit LeakyReLU(0.2) layers, 2 output logits.

    Args:
        z_dim: size of the latent vector fed to the first Linear layer.
    """

    def __init__(self, z_dim):
        super(Discriminator, self).__init__()
        self.z_dim = z_dim
        self.net = nn.Sequential(
            nn.Linear(z_dim, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 1000),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1000, 2),
        )
        self.weight_init()

    def weight_init(self, mode='normal'):
        """(Re-)initialise every layer handled by the chosen initializer.

        Args:
            mode: ``'normal'`` (uses ``normal_init``) or ``'kaiming'``
                (uses ``kaiming_init``).

        Raises:
            ValueError: if ``mode`` is not one of the supported names.
        """
        if mode == 'kaiming':
            initializer = kaiming_init
        elif mode == 'normal':
            initializer = normal_init
        else:
            raise ValueError(
                f"unknown init mode {mode!r}; expected 'normal' or 'kaiming'")

        for m in self.modules():
            initializer(m)

    def forward(self, z):
        """Return logits of shape ``(batch, 2)``.

        ``z`` may be ``(batch, z_dim)`` or carry trailing singleton spatial dims
        such as ``(batch, z_dim, 1, 1)``. A 1-D ``(z_dim,)`` vector is passed
        through as-is and yields ``(2,)``.
        """
        if z.dim() > 1:
            # Flatten everything after the batch dim. squeeze() (in and out)
            # also dropped the batch dim for a batch of a single sample.
            z = z.reshape(z.size(0), -1)
        return self.net(z)


def kaiming_init(m):
    """Kaiming-normal initialisation for one module; other module types are ignored.

    Linear / Conv2d / ConvTranspose2d: weight ~ ``kaiming_normal_``, bias = 0.
    BatchNorm1d / BatchNorm2d: weight = 1, bias = 0 (skipped when the layer
    was built with ``affine=False`` and has no weight/bias).

    Args:
        m: the ``nn.Module`` to initialise in place.
    """
    # ConvTranspose2d is included because DecoderBlock is built from it.
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)


def normal_init(m):
    """N(0, 0.02) initialisation for one module; other module types are ignored.

    Linear / Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias = 0.
    BatchNorm1d / BatchNorm2d: weight = 1, bias = 0 (skipped when the layer
    was built with ``affine=False`` and has no weight/bias).

    Args:
        m: the ``nn.Module`` to initialise in place.
    """
    # ConvTranspose2d is included because DecoderBlock is built from it.
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)

import torchsummary
# Fall back to CPU so this summary cell also runs on machines without a GPU.
summary_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
disc = Discriminator(128).to(summary_device)
# torchsummary's `device` argument takes "cuda" or "cpu" (not "cuda:0") and
# builds its dummy input batch there, so it must match the model's device.
torchsummary.summary(disc, (128, 1, 1), device='cpu' if summary_device == 'cpu' else 'cuda')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1000]         129,000
         LeakyReLU-2                 [-1, 1000]               0
            Linear-3                 [-1, 1000]       1,001,000
         LeakyReLU-4                 [-1, 1000]               0
            Linear-5                 [-1, 1000]       1,001,000
         LeakyReLU-6                 [-1, 1000]               0
            Linear-7                 [-1, 1000]       1,001,000
         LeakyReLU-8                 [-1, 1000]               0
            Linear-9                 [-1, 1000]       1,001,000
        LeakyReLU-10                 [-1, 1000]               0
           Linear-11                    [-1, 2]           2,002
Total params: 4,135,002
Trainable params: 4,135,002
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forw

In [None]:
import os
import pandas as pd
root = './data/celeba'
# CelebA attribute table (one row per image, one column per attribute).
attributes_csv = os.path.join(root, 'list_attr_celeba.csv')
attributes = pd.read_csv(attributes_csv)
# Folder of the aligned CelebA images, under the same root.
images_dir = os.path.join(root, 'img_align_celeba')


In [None]:
good_attribute_threshold = 0.3
# Keep only "balanced" attributes: those whose count of rows equal to 1 lies
# strictly between threshold and (1 - threshold) of all rows.
n_rows = len(attributes)
lower_bound = n_rows * good_attribute_threshold
upper_bound = n_rows * (1 - good_attribute_threshold)


def _positive_count(column):
    """Number of rows (counted via non-null image_id) where `column` == 1."""
    return attributes.loc[attributes[column] == 1, 'image_id'].count()


good_attributes = ['image_id'] + [
    column for column in attributes.keys()
    if lower_bound < _positive_count(column) < upper_bound
]
gdf = attributes[good_attributes]

In [None]:
from argparse import Namespace

print(good_attributes)
# Positional indices among all CSV columns (including 'image_id').
# NOTE(review): presumably the kind of index later used as `target_attr` — confirm.
# The first lookup used to be a bare expression that was not the cell's last
# line, so its value was computed and then silently discarded; print it.
print(list(attributes.keys()).index("Mouth_Slightly_Open"))
print(list(attributes.keys()).index("Attractive"))

In [None]:
# Hyper-parameters handed to the CycleGan/Classifier models and the CelebA
# data module, and later merged over the Trainer's argparse defaults.
args = Namespace(**{
    # logging / optimisation / numerics
    'log_every_n_steps': 1,
    'gradient_clip_val': 0.5,
    'learning_rate': 1e-3,
    'weight_decay': 1e-6,
    'precision': 16,
    # data loading
    'data_dir': './data/celeba',
    'batch_size': 512,
    'num_workers': 1,
    # hardware / schedule
    'gpus': 2,
    'max_epochs': 10,
    'accumulate_grad_batches': 4,
})

In [None]:
from src.models.classifier import Classifier, Identity
from src.models.cnn_for_encoded import SimpleCNN
from torchsummary import summary
cnn = SimpleCNN().to('cuda:0')
target_attr = 32
CYCLEGAN_CKPT = './logs/0126/version_0/checkpoints/epoch=99-step=7999.ckpt'
# load_from_checkpoint is a classmethod: call it on the class. Calling it on a
# fresh CycleGan(args, target_attr=...) instance built a throwaway model whose
# construction (and RNG use) was wasted, since the method ignores the instance.
# Baseline without translation: translator = Identity()
translator = CycleGan.load_from_checkpoint(CYCLEGAN_CKPT).A2B
classifier = Classifier(args, model=cnn, target_attr=target_attr, preEncoder=translator).to('cuda:0')

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
# An empty experiment name puts the version_N folders directly under logs/toy.
logger = TensorBoardLogger(save_dir="logs/toy", name='')


In [None]:
import sys
# Replace the kernel's own command line with a lone dummy program name so that
# a later bare parser.parse_args() (which reads sys.argv[1:]) sees no arguments.
argv_stub = ['-f']
sys.argv = argv_stub

In [None]:
import pytorch_lightning as pl
from src.datamodule import CelebAEncodedData
from argparse import ArgumentParser
# Start from every pl.Trainer CLI flag at its default value, then override with
# the notebook's hyper-parameters from `args`.
parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
# Parse an explicit empty list: a bare parse_args() reads sys.argv, which in a
# notebook presumably holds the kernel's launch arguments unless patched.
argu = vars(parser.parse_args([]))
argu.update(vars(args))
args = Namespace(**argu)
trainer = pl.Trainer.from_argparse_args(args, logger=logger)
dm = CelebAEncodedData(args)
trainer.fit(classifier, datamodule=dm)