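## Deep generative models learn by defining an objective function and minimizing it with gradient descent
- Unlike many traditional generative models, which are fit with sampling-based methods such as MCMC, deep generative models learn by minimizing an objective function with gradient-based optimization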
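Here is a minimal illustration of this recipe in plain PyTorch, not Pixyz. The sketch fits a single Gaussian to toy data. It defines an objective (the average negative log-likelihood) and minimizes it with gradient descent. The toy data and the choice of `torch.optim.SGD` are assumptions made only for this example.

```python
import torch

torch.manual_seed(0)
data = 3.0 + 2.0 * torch.randn(1000)             # toy observations drawn from N(3, 2^2)

# Model: a single Gaussian N(mu, sigma^2) with learnable mu and log(sigma)
mu = torch.zeros((), requires_grad=True)
log_sigma = torch.zeros((), requires_grad=True)  # learning log(sigma) keeps sigma positive
optimizer = torch.optim.SGD([mu, log_sigma], lr=0.1)

for step in range(500):
    optimizer.zero_grad()
    model = torch.distributions.Normal(mu, log_sigma.exp())
    loss = -model.log_prob(data).mean()          # objective: average negative log-likelihood
    loss.backward()                              # gradients of the objective
    optimizer.step()                             # one gradient-descent update

print(f"mu = {mu.item():.2f}, sigma = {log_sigma.exp().item():.2f}")
```

The learned `mu` and `sigma` approach the sample mean and the (maximum-likelihood) sample standard deviation of `data`.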


<img src='../tutorial_figs/pixyz_API.png'>

## A framework in which the objective function and the optimization algorithm can be set independently (Model API)
- Model API documentation: https://docs.pixyz.io/en/v0.0.4/models.html  

Here, we use the Model API to train a model built from the probability distributions and the loss function we define.
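This plain PyTorch sketch shows what setting the objective and the optimizer independently means in practice. It is not Pixyz code, and `objective` and `fit` are hypothetical helpers. The objective is written once. The optimization algorithm is passed in as a class plus a parameter dict, the same pairing as the `optimizer` and `optimizer_params` arguments of `Model`.

```python
import torch

def objective(w):
    """A fixed objective: squared distance from w to the point (1, -2)."""
    target = torch.tensor([1.0, -2.0])
    return ((w - target) ** 2).sum()

def fit(optimizer_cls, optimizer_params, steps=500):
    w = torch.zeros(2, requires_grad=True)
    # the optimizer is chosen separately from the objective
    opt = optimizer_cls([w], **optimizer_params)
    for _ in range(steps):
        opt.zero_grad()
        objective(w).backward()
        opt.step()
    return w.detach()

# The same objective minimized with two different optimization algorithms
w_sgd = fit(torch.optim.SGD, {"lr": 0.1})
w_adam = fit(torch.optim.Adam, {"lr": 0.05})
print(w_sgd, w_adam)
```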

```python
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

batch_size = 256
seed = 1
torch.manual_seed(seed)
```


```python
# MNIST dataset
root = '../data'
# convert each 28x28 image to a tensor and flatten it into a 784-dimensional vector
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)
```

### Define probability distributions

```python
from pixyz.distributions import Normal, Bernoulli

x_dim = 784
z_dim = 64

# inference model q(z|x)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        # softplus keeps the scale positive
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}


# generative model p(x|z)
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}

gen_ber_x__z = Generator().to(device)
infer_nor_z__x = Inference().to(device)

# prior p(z) = N(0, I) over the z_dim-dimensional latent space
prior_nor_z = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
                     var=["z"], features_shape=[z_dim], name="p_{prior}").to(device)
```
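The `Inference` class returns the parameters `loc` and `scale` of q(z|x). The sketch below reproduces the same network in plain PyTorch with `torch.distributions`, not Pixyz. `GaussianEncoder` is a hypothetical name. The sketch also draws a reparameterized sample z = loc + scale · ε with `rsample()`, which keeps the sample differentiable with respect to the encoder weights. This section does not show how Pixyz samples internally, so treat this as an illustration of the idea only.

```python
import torch
from torch import nn
from torch.nn import functional as F

x_dim, z_dim = 784, 64

class GaussianEncoder(nn.Module):
    """Plain-PyTorch analogue of Inference: x -> (loc, scale) of q(z|x)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        # softplus keeps the scale positive, as a Normal distribution requires
        return self.fc31(h), F.softplus(self.fc32(h))

encoder = GaussianEncoder()
x = torch.rand(8, x_dim)                  # a fake batch of 8 flattened images in [0, 1]
loc, scale = encoder(x)
q = torch.distributions.Normal(loc, scale)
z = q.rsample()                           # reparameterized sample: loc + scale * eps
print(z.shape, z.requires_grad)
```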

### Define Loss

```python
from pixyz.losses import LogProb
from pixyz.losses import Expectation as E
from pixyz.losses import KullbackLeibler
from pixyz.utils import print_latex

# log-likelihood log p(x|z)
logprob_gen_x__z = LogProb(gen_ber_x__z)

# expectation of the log-likelihood over q(z|x)
E_infer_z__x_logprob_gen_x__z = E(infer_nor_z__x, logprob_gen_x__z)

# KL divergence between q(z|x) and the prior p(z)
KL_infer_nor_z__x_prior_nor_z = KullbackLeibler(infer_nor_z__x, prior_nor_z)

# subtraction between losses: KL - E[log p(x|z)], i.e. the negative ELBO
total_loss = KL_infer_nor_z__x_prior_nor_z - E_infer_z__x_logprob_gen_x__z

# average the loss over the mini-batch
total_loss = total_loss.mean()

# check the loss
print_latex(total_loss)
```

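The loss defined above is the KL term minus the expected log-likelihood, averaged over the batch. The sketch below shows what such an objective computes. It uses `torch.distributions` rather than the Pixyz Loss API. It assumes a single-sample Monte Carlo estimate of the expectation, and it uses random stand-ins (including a hypothetical one-layer `decoder`) for the network outputs. It also checks the analytic KL against its closed form.

```python
import torch
from torch import nn
from torch.distributions import Bernoulli, Normal, kl_divergence

torch.manual_seed(0)
batch, x_dim, z_dim = 4, 784, 64

# Stand-ins for network outputs; in the tutorial these come from q(z|x) and p(x|z)
x = torch.bernoulli(torch.full((batch, x_dim), 0.5))    # binary "images"
loc = torch.randn(batch, z_dim)                           # encoder mean
scale = torch.rand(batch, z_dim) + 0.5                    # encoder std (positive)
decoder = nn.Linear(z_dim, x_dim)                         # hypothetical toy decoder

q = Normal(loc, scale)                                    # q(z|x)
prior = Normal(torch.zeros(z_dim), torch.ones(z_dim))     # p_prior(z) = N(0, I)

z = q.rsample()                                           # one sample z ~ q(z|x)
# log p(x|z) summed over pixels; logits=... equals probs=sigmoid(...) but is more stable
log_px_z = Bernoulli(logits=decoder(z)).log_prob(x).sum(-1)
kl = kl_divergence(q, prior).sum(-1)                      # analytic KL, summed over z dims

loss = (kl - log_px_z).mean()                             # negative ELBO, batch average

# Closed form per dimension: 0.5 * (s^2 + mu^2 - 1 - log s^2)
kl_closed = 0.5 * (scale ** 2 + loc ** 2 - 1 - torch.log(scale ** 2)).sum(-1)
print(loss.item(), torch.allclose(kl, kl_closed, atol=1e-4))
```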

### Model API: Set the probability distributions, the loss, and the optimization algorithm

We use `Model` from `pixyz.models`.  
Its main arguments are `loss`, `distributions`, `optimizer`, and `optimizer_params`, which we set as follows:  
- loss: the loss function defined with the Loss API
- distributions: the probability distributions (defined with the Distribution API) whose parameters are to be learned  
- optimizer, optimizer_params: the optimization algorithm and its hyperparameters (e.g., the learning rate)  

For more details about `Model`: https://docs.pixyz.io/en/v0.0.4/_modules/pixyz/models/model.html#Model
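The `distributions` argument determines which parameters get updated. The plain PyTorch sketch below assumes that `Model` builds its optimizer as `optimizer(parameters, **optimizer_params)` over the parameters of the listed distributions. The linked source is the authority on what it actually does. `q_net` and `p_net` are hypothetical stand-ins with the same layer sizes as `Inference` and `Generator`. The prior is not listed because its `loc` and `scale` are fixed rather than learned.

```python
import torch
from torch import nn, optim

x_dim, z_dim = 784, 64

# Hypothetical stand-ins with the same layers as Inference (q) and Generator (p)
q_net = nn.ModuleDict({"fc1": nn.Linear(x_dim, 512), "fc2": nn.Linear(512, 512),
                       "fc31": nn.Linear(512, z_dim), "fc32": nn.Linear(512, z_dim)})
p_net = nn.ModuleDict({"fc1": nn.Linear(z_dim, 512), "fc2": nn.Linear(512, 512),
                       "fc3": nn.Linear(512, x_dim)})

# Group the learnable distributions, then build ONE optimizer over all their parameters
distributions = nn.ModuleList([p_net, q_net])
optimizer_params = {"lr": 1e-3}
optimizer = optim.Adam(distributions.parameters(), **optimizer_params)

n_params = sum(p.numel() for p in distributions.parameters())
print(n_params)  # 1428368
```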

```python
from pixyz.models import Model
from torch import optim

optimizer = optim.Adam
optimizer_params = {'lr': 1e-3}

vae_model = Model(loss=total_loss,
                  distributions=[gen_ber_x__z, infer_nor_z__x],
                  optimizer=optimizer,
                  optimizer_params=optimizer_params)
```

We have now defined the model.  
As shown above, the objective function and the optimization algorithm are set independently.  
Next, we train the model with the `train()` method.  
`Model.train()` performs the following steps.  
Source code: https://docs.pixyz.io/en/v0.0.4/_modules/pixyz/models/model.html#Model.train
1. Receive the observed data `x` (passed as `.train({"x": x})`)  
2. Compute the loss  
3. Update the parameters by one step (backpropagation, then an optimizer step)  
4. Return the loss value  

```python
def train(self, train_x={}, **kwargs):
        self.distributions.train()

        self.optimizer.zero_grad()
        loss = self.loss_cls.estimate(train_x, **kwargs)

        # backprop
        loss.backward()

        # update params
        self.optimizer.step()

        return loss
```
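The four steps can be reproduced outside Pixyz. In the self-contained sketch below, `MiniModel` is a hypothetical class, not part of Pixyz. It takes the loss and the optimizer as independent arguments and performs the same zero_grad / loss / backward / step sequence as the excerpt above. Unlike the excerpt, it returns `loss.item()`, a plain float.

```python
import torch
from torch import nn, optim

class MiniModel:
    """Hypothetical, stripped-down analogue of Model.train (not Pixyz code)."""

    def __init__(self, loss_fn, modules, optimizer=optim.Adam, optimizer_params=None):
        self.loss_fn = loss_fn                    # callable: dict of observed data -> scalar
        self.modules = nn.ModuleList(modules)     # modules whose parameters are learned
        self.optimizer = optimizer(self.modules.parameters(), **(optimizer_params or {}))

    def train(self, data):
        self.modules.train()                      # put the modules in training mode
        self.optimizer.zero_grad()
        loss = self.loss_fn(data)                 # steps 1-2: receive data, compute the loss
        loss.backward()                           # backpropagation
        self.optimizer.step()                     # step 3: one parameter update
        return loss.item()                        # step 4: return the loss as a float

# Usage: fit y = 2x + 1 with a single linear layer and plain SGD
torch.manual_seed(0)
net = nn.Linear(1, 1)

def mse(d):
    return ((net(d["x"]) - d["y"]) ** 2).mean()

model = MiniModel(mse, [net], optim.SGD, {"lr": 0.1})
x = torch.linspace(-1, 1, 64).unsqueeze(1)
for _ in range(300):
    last_loss = model.train({"x": x, "y": 2 * x + 1})
print(net.weight.item(), net.bias.item(), last_loss)
```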

### Training

```python
epoch_loss = []
for epoch in range(3):
    train_loss = 0
    for x, _ in train_loader:
        x = x.to(device)
        loss = vae_model.train({"x": x})
        # detach so autograd history is not accumulated across batches
        train_loss += loss.detach()
    # sum of batch-mean losses * batch_size / N: approximate average loss per example
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch {}, Loss {} '.format(epoch, train_loss))
    epoch_loss.append(train_loss)
```

```
Epoch 0, Loss 199.86109924316406 
Epoch 1, Loss 147.0438690185547 
Epoch 2, Loss 126.67538452148438 
```
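The per-epoch loss above is the sum of the per-batch mean losses, multiplied by `batch_size` and divided by the dataset size. That equals the average loss per training example only when every batch is full. MNIST has 60,000 training images, and with `batch_size = 256` the loader yields 235 batches, the last holding only 96 images. The reported figure is therefore a close approximation. This sketch uses hypothetical batch means to compare it with an exact, size-weighted average (`epoch_average` is a hypothetical helper):

```python
import math

n_train, batch_size = 60000, 256          # MNIST training-set size and the batch size above
n_batches = math.ceil(n_train / batch_size)
last_batch = n_train - (n_batches - 1) * batch_size

def epoch_average(batch_means, batch_sizes):
    """Exact per-example average: weight each batch-mean loss by its batch size."""
    return sum(m * s for m, s in zip(batch_means, batch_sizes)) / sum(batch_sizes)

# Hypothetical batch means: every batch scores 100.0 except the small final one
sizes = [batch_size] * (n_batches - 1) + [last_batch]
means = [100.0] * (n_batches - 1) + [50.0]
approx = sum(means) * batch_size / n_train    # the normalization used in the loop above
exact = epoch_average(means, sizes)
print(n_batches, last_batch, round(approx, 3), round(exact, 3))
```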


## Use the more abstract Model API
The more abstract model classes make it even easier to define models.  
We only need to:  
- define the probability distributions  
- (optionally) define additional loss functions
- select the optimization algorithm

Here, we use the VAE model as an example.  

```python
from pixyz.distributions import Normal, Bernoulli
from pixyz.losses import KullbackLeibler
# VAE: a more abstract (higher-level) Model class
from pixyz.models import VAE
```

### Define probability distributions

```python
x_dim = 784
z_dim = 64


# inference model q(z|x)
class Inference(Normal):
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")

        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc31(h), "scale": F.softplus(self.fc32(h))}


# generative model p(x|z)
class Generator(Bernoulli):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}

p = Generator().to(device)
q = Inference().to(device)

prior = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.),
               var=["z"], features_shape=[z_dim], name="p_{prior}").to(device)
```

### Add regularization terms to the loss function

```python
kl = KullbackLeibler(q, prior)
print_latex(kl)
```

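Here q(z|x) is a diagonal Gaussian with mean `loc` = μ(x) and standard deviation `scale` = σ(x), and the prior is a standard normal. For these distributions the KL regularizer has a well-known closed form. Whether Pixyz evaluates it analytically or by sampling is not shown in this section.

$$
D_{KL}\left[q(z|x)\,\|\,p_{prior}(z)\right] = \frac{1}{2}\sum_{j=1}^{d_z}\left(\sigma_j(x)^2 + \mu_j(x)^2 - 1 - \log \sigma_j(x)^2\right)
$$

Here $d_z$ is `z_dim` = 64.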

### VAE Model: Set the additional loss function and select the optimization algorithm  

```python
model = VAE(encoder=q, decoder=p, regularizer=kl,
            optimizer=optim.Adam, optimizer_params={"lr": 1e-3})
print_latex(model)
```

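Judging from the manually built model earlier, we assume that `VAE` combines a reconstruction term, built from `encoder` and `decoder`, with the `regularizer` passed to it, and averages over the mini-batch. The Pixyz documentation is the authority on the exact construction. Under that assumption, it minimizes the same objective as `total_loss`:

$$
\mathcal{L} = \frac{1}{N_{batch}}\sum_{n=1}^{N_{batch}}\left(-\mathbb{E}_{q(z|x_n)}\left[\log p(x_n|z)\right] + D_{KL}\left[q(z|x_n)\,\|\,p_{prior}(z)\right]\right)
$$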

### Training

```python
def train(epoch):
    train_loss = 0
    for x, _ in train_loader:
        x = x.to(device)
        loss = model.train({"x": x})
        # detach so autograd history is not accumulated across batches
        train_loss += loss.detach()

    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}'.format(epoch, train_loss))
    return train_loss
```

```python
epochs = 3
train_losses = []
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    train_losses.append(train_loss)
```

```
Epoch: 1 Train loss: 200.3801
Epoch: 2 Train loss: 147.1353
Epoch: 3 Train loss: 127.9876
```
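`test_loader` is created earlier but not used in this section. The sketch below shows one way to evaluate a loss over such a loader. `evaluate` and `loss_fn` are hypothetical names, and `loss_fn` stands for any callable that maps a batch to its mean loss. Gradients are disabled with `torch.no_grad()` because no parameters are updated. Weighting each batch by its size gives the exact per-example average.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()                       # no autograd graph is needed for evaluation
def evaluate(loss_fn, loader, device="cpu"):
    """Per-example average of a batch-mean loss, weighting each batch by its size."""
    total, count = 0.0, 0
    for x, _ in loader:
        x = x.to(device)
        total += loss_fn(x).item() * x.size(0)
        count += x.size(0)
    return total / count

# Usage on hypothetical data: the "loss" is the mean squared value of each batch
data = TensorDataset(torch.ones(10, 3) * 2, torch.zeros(10))
loader = DataLoader(data, batch_size=4, shuffle=False)
print(evaluate(lambda x: (x ** 2).mean(), loader))  # 4.0
```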


For other more abstract models, see:
- Pre-implementation models: https://docs.pixyz.io/en/v0.0.4/models.html#pre-implementation-models

### Pixyz implementations
More complex models implemented with Pixyz are available at the following links.  
- Pixyz examples: https://github.com/masa-su/pixyz/tree/master/examples
- Pixyzoo: https://github.com/masa-su/pixyzoo

The Pixyz implementation workflow is the same for all models:  
1. Define the probability distributions using the `Distribution API`  
1. Define the loss function based on the defined probability distributions using the `Loss API`
1. Set the probability distributions, the loss, and the optimization algorithm using the `Model API`, then train
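The same three steps can be mirrored end to end in plain PyTorch. In this minimal sketch, `torch.distributions` stands in for the Distribution API, an ordinary function for the Loss API, and a hand-written loop for the Model API. The toy sizes and the all-ones data are hypothetical, chosen only so that the example runs quickly.

```python
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.distributions import Bernoulli, Normal, kl_divergence

torch.manual_seed(0)
x_dim, h_dim, z_dim = 20, 32, 2            # toy sizes, far smaller than MNIST

# Step 1: distributions, i.e. networks that output distribution parameters
enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, 2 * z_dim))
dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim))
prior = Normal(torch.zeros(z_dim), torch.ones(z_dim))

def q(x):                                  # q(z|x): diagonal Gaussian
    loc, raw_scale = enc(x).chunk(2, dim=-1)
    return Normal(loc, F.softplus(raw_scale))

def p(z):                                  # p(x|z): independent Bernoulli pixels
    return Bernoulli(logits=dec(z))

# Step 2: loss = KL[q(z|x) || p_prior(z)] - E_q[log p(x|z)], averaged over the batch
def loss_fn(x):
    qz = q(x)
    z = qz.rsample()                       # one-sample Monte Carlo estimate of E_q
    kl = kl_divergence(qz, prior).sum(-1)
    return (kl - p(z).log_prob(x).sum(-1)).mean()

# Step 3: choose the optimizer independently of the loss, then train
optimizer = optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=1e-2)
x = torch.ones(16, x_dim)                  # hypothetical data: all-white "images"
losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = loss_fn(x)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
print(losses[0], losses[-1])
```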