In [1]:
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

In [6]:
from pixyz.distributions import Normal
from pixyz.utils import print_latex

In [7]:
class Residual_Block(nn.Module):
    def __init__(self, in_c=64, out_c=64, groups=1, scale=1.0):
        super(Residual_Block, self).__init__()
        # (N, N, in_c) -> (N, N, out_c)
        
        mid_c = int(out_c * scale)
        
        if in_c is not out_c:
            self.conv_expand = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
        else:
            self.conv_expand = None
        
        self.conv1 = nn.Conv2d(in_channels=in_c, out_channels=mid_c, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_c)
        self.relu1 = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(in_channels=mid_c, out_channels=out_c, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu2 = nn.LeakyReLU(0.2, inplace=True)
        
    def forward(self, x):
        if self.conv_expand is not None:
            identity_data = self.conv_expand(x)
        else:
            identity_data = x
        
        output = self.relu1(self.bn1(self.conv1(x)))
        output = self.conv2(output)
        output = self.relu2(torch.add(self.bn2(output), identity_data))
        # output = self.relu2(self.bn2(torch.add(output,identity_data))) <- original code
        return output

In [8]:
class Encoder(Normal):
    def __init__(self, c_dim=3, h_dim=512, channels=[64, 128, 256, 512, 512, 512], image_size=256):
        super(Encoder, self).__init__(cond_var=["x"], var=["z"])
        # (image_size, ,image_size) -> (image_size//(2**len(channels)), image_size//(2**len(channels))
        
        assert (2 ** len(channels)) * 4 == image_size
        
        cc = channels[0]
        self.main = nn.Sequential(
            #(image_size, image_size, c_dim) -> (image_size, image_size, cc)
            nn.Conv2d(c_dim, cc, kernel_size=5, stride=1, padding=2, bias=False),
            nn.BatchNorm2d(cc),
            nn.LeakyReLU(0.2),
            # (image_size, image_size, cc) -> (image_size/2, image_size/2, cc)
            nn.AvgPool2d(kernel_size=2),
        )
        
        sz = image_size // 2
        for ch in channels[1:]:
            self.main.add_module('res_in_{}'.format(sz), Residual_Block(in_c=cc, out_c=ch, scale=1.0))
            self.main.add_module('down_to_{}'.format(sz//2), nn.AvgPool2d(kernel_size=2))
            cc, sz = ch, sz//2
        
        self.main.add_module('res_in_{}'.format(sz), Residual_Block(in_c=cc, out_c=cc, scale=1.0))
        # len(channels) = 6
        # 256 / (2**6) = 4
        self.fc = nn.Linear(cc*4*4, 2*h_dim)
    
    def forward(self, x):
        y = self.main(x).view(x.size(0), -1) # (Batch, image_size, image_size, c_dim) -> (Batch_size, channels[0] * 4 *4)
        y = self.fc(y) # (Batch_size, channels[0] * 4 *4) -> (Batch_size, 2*h_dim)
        mu, logvar = y.chunk(2, dim=1)
        return {"loc": mu, "scale": F.softplus(logvar)}   

In [9]:
a = Encoder()
print_latex(a)

<IPython.core.display.Math object>

In [12]:
class Decoder(Normal):
    def __init__(self, c_dim=3, h_dim=512, channels=[64, 128, 256, 512, 512, 512], image_size=256):
        super(Decoder, self).__init__(cond_var=["z"], var=["x"])
        
        assert (2 ** len(channels)) * 4 == image_size
        cc = channels[-1]
        self.fc = nn.Sequential(
            nn.Linear(h_dim, cc*4*4),
            nn.ReLU(True),
        )
        
        sz = 4
        
        self.main = nn.Sequential()
        for ch in channels[::-1]:
            self.main.add_module('res_in_{}'.format(sz), Residual_Block(in_c=cc, out_c=ch, scale=1.0))
            self.main.add_module('up_to_{}'.format(sz*2), nn.Upsample(scale_factor=2, mode='nearest'))
            cc, sz = ch, sz*2
        
        self.main.add_module('res_in_{}'.format(sz), Residual_Block(in_c=cc, out_c=cc, scale=1.0))
        self.main.add_module('predict', nn.Conv2d(in_channels=cc, out_channels=c_dim, kernel_size=5, stride=1, padding=2))
        
    def forward(self, z):
        z = z.view(z.size(0), -1) # (Batch, h_dim)
        y = self.fc(z) # (Batch, h_dim) -> (Batch, cc*4*4)
        y = y.view(z.size(0), -1, 4, 4) # (Batch, cc*4*4) -> (Batch, cc, 4, 4)
        y = self.main(y) # (Batch, cc, 4, 4) -> (Batch, 3, 256, 256)
        return {"loc": y, "scale": torch.ones(y.size())}

In [13]:
d = Decoder()
print_latex(d)

<IPython.core.display.Math object>

In [29]:
from pixyz.losses import KullbackLeibler, LogProb
from pixyz.losses import Expectation as E

## Loss

$L_{E}(x, z)=D_{K L}\left(q_{\phi}(z | x) \| p(z)\right) +\alpha * \max (0, m-D_{K L}\left(q_{\phi}(z | G(z)) \| p(z)\right))- \beta * E_{q_{\phi}(z | x)} \log p_{\theta}(x | z)$

$L_{G}(z)=\alpha * D_{K L}\left(q_{\phi}(z | G(z)) \| p(z)\right)- \beta * E_{q_{\phi}(z | x)} \log p_{\theta}(x | z)$

In [30]:
kl = KullbackLeibler(a, prior)
kl.eval()

ValueError: Input keys are not valid, expected ['x'] but got [].

In [31]:
recon = E(a, LogProb(d))
recon.eval()

ValueError: Input keys are not valid, expected ['x'] but got [].

In [14]:
from pixyz.models import Model

In [32]:
class IntroVAE(Model):
    def __init__(self, c_dim=3, h_dim=512, channels=[64, 128, 256, 512, 512, 512], image_size=256):
        super(IntroVAE, self).__init__()
        
        self.encoder = Encoder(c_dim, h_dim, channels, image_size)
        self.decoder = Decoder(c_dim, h_dim, channels, image_size)
        self.prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[h_dim], name="p_{prior}")
        
        self.kl_loss = KullbackLeibler(self.encoder, self.prior)
        self.recon_loss = E(self.encoder, LogProb(self.decoder))
        
        self.vae_loss = (self.kl_loss - self.recon_loss).mean()
        
        self.m_plus = 0
        self.weight_neg = 0
        self.weight_rec = 0
        self.weight_kl = 0
        
        self.lr_e = 1e-3
        self.lr_g = 1e-3
        
        self.optimizerE = optim.Adam(self.encoder.parameters(), lr=self.lr_e)
        self.optimizerG = optim.Adam(self.decoder.parameters(), lr=self.lr_g)
        
        distributions = [self.encoder, self.decoder]
        self.distributions = nn.ModuleList(distributions)
    
    def calculate_vae_loss(real):
        vae_loss = self.vae_loss.eval({"x": real})
        return vae_loss
    
    def calculate_intro_loss(rec, fake, real):
        # Encoder Loss
        loss_rec = self.recon_loss.eval({"x": real}).mean()
        
        lossE_real_kl = self.kl_loss.eval({"x": real}).mean()
        lossE_rec_kl = self.kl_loss.eval({"x": rec.detach()}).mean()
        lossE_fake_kl = self.kl_loss.eval({"x": fake.detach()}).mean()
        loss_margin = lossE_real_kl + \
                      (F.relu(self.m_plus-lossE_rec_kl) + \
                      F.relu(self.m_plus-lossE_fake_kl)) * 0.5 * self.weight_neg
        lossE = loss_rec  *self.weight_rec + loss_margin * self.weight_kl
        
        # Generator Loss
        lossG_rec_kl = self.kl_loss.eval({"x": rec}).mean()
        lossG_fake_kl = self.kl_loss.eval({"x": fake}).mean()
        lossG = (lossG_rec_kl + lossG_fake_kl)* 0.5 * self.weight_kl
        
        return lossE, lossG
    
    def train(self, train_x_dixt={}, vae=False):
        self.distributions.train()
        
        if vae:
            vae_loss = calculate_vae_loss(real=train_x_dixt["x"])
            
            self.optimizerG.zero_grad()
            self.optimizerE.zero_grad()       
            vae_loss.backward()                   
            self.optimizerE.step() 
            self.optimizerG.step()
            
            return vae_loss.item()
        
        else:
            real = train_x_dixt["x"]
            fake = (self.prior * self.decoder).sample()["x"]
            z = self.encoder.sample(train_x_dixt, return_all=False, reparam=True)
            # sample_mean() ???
            rec = self.decoder.sample({"z": z}, return_all=False)
            lossE, lossG = calculate_intro_loss(rec, fake, real)
            
            # update Encoder
            self.optimizerE.zero_grad()       
            lossE.backward()
            self.optimizerE.step()
            
            # update Decoder
            self.optimizerG.zero_grad()       
            lossG.backward()
            self.optimizerG.step()
            return lossE.item(), lossG.item()
        
    def test(self, train_x_dixt={}, vae=False):
        self.distributions.eval()
        with torch.no_grad():
            if vae:
                vae_loss = calculate_vae_loss(real=train_x_dixt["x"])
                return vae_loss.item()

            else:
                real = train_x_dixt["x"]
                fake = (self.prior * self.decoder).sample()["x"]
                z = self.encoder.sample(train_x_dixt, return_all=False, reparam=True)
                # sample_mean() ???
                rec = self.decoder.sample({"z": z}, return_all=False)
                lossE, lossG = calculate_intro_loss(rec, fake, real)
                return lossE.item(), lossG.item()   