In [9]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
import pytorch_lightning as pl

class VAE(pl.LightningModule):
    """Fully-connected variational autoencoder for 28x28 MNIST digits.

    Encoder: 784 -> 512 (ReLU) -> (mu, log_var), each of size 32.
    Decoder: 32 -> 512 (ReLU) -> 784 (sigmoid), reshaped to (N, 1, 28, 28).
    The loss is summed binary cross-entropy reconstruction plus the analytic
    KL divergence of the approximate posterior from a standard normal prior.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 512)
        self.fc_mu = nn.Linear(512, 32)
        self.fc_log_var = nn.Linear(512, 32)
        self.fc_decode1 = nn.Linear(32, 512)
        self.fc_decode2 = nn.Linear(512, 784)

    def encode(self, x):
        """Encode ``x`` and draw one latent sample with the reparameterization trick.

        ``x`` is flattened to (N, -1) before the first layer, so its trailing
        dimensions must multiply to 784 (e.g. an (N, 1, 28, 28) image batch).

        Returns:
            Tuple ``(z_mu, z_log_var, z)``: posterior mean, posterior
            log-variance, and a sample ``z`` of shape (N, 32).
        """
        hidden = F.relu(self.fc(x.view(x.size(0), -1)))
        z_mu = self.fc_mu(hidden)
        z_log_var = self.fc_log_var(hidden)
        z_std = torch.exp(0.5 * z_log_var)
        # Reparameterization: z = mu + std * eps keeps z differentiable w.r.t.
        # mu and std. Sampling via torch.normal(z_mu, z_std) is not
        # reparameterized, so the encoder would not get a useful gradient.
        eps = torch.randn_like(z_std)
        z = z_mu + z_std * eps
        return z_mu, z_log_var, z

    def decode(self, z):
        """Map latent codes ``z`` (N, 32) to pixel probabilities of shape (N, 1, 28, 28)."""
        hidden = F.relu(self.fc_decode1(z))
        pixels = torch.sigmoid(self.fc_decode2(hidden))
        return pixels.view(pixels.size(0), 1, 28, 28)

    def forward(self, x):
        """Reconstruct ``x`` from a single stochastic latent sample."""
        _, _, z = self.encode(x)
        return self.decode(z)

    def __calc_loss(self, batch):
        """Compute the two VAE loss terms for one ``(inputs, targets)`` batch.

        Returns:
            ``(recon_loss, kl_loss)``, both summed over the batch:
            binary cross-entropy between reconstruction and target, and
            KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)).
        """
        inputs, targets = batch
        z_mu, z_log_var, z = self.encode(inputs)
        outputs = self.decode(z)
        recon_loss = F.binary_cross_entropy(outputs, targets, reduction='sum')
        kl_terms = 1 + z_log_var - z_mu ** 2 - torch.exp(z_log_var)
        kl_loss = torch.sum(kl_terms) * -0.5
        return recon_loss, kl_loss

    def training_step(self, batch, batch_idx):
        """Return the total loss plus its components for epoch-level logging."""
        recon_loss, kl_loss = self.__calc_loss(batch)
        return {'loss': recon_loss + kl_loss, 'recon': recon_loss, 'kl': kl_loss}

    def validation_step(self, batch, batch_idx):
        """Return the total (reconstruction + KL) loss for one validation batch."""
        return {'val_loss': sum(self.__calc_loss(batch))}

    def training_epoch_end(self, outputs):
        """Average the per-batch training losses and hand them to the logger."""
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_recon_loss = torch.stack([x['recon'] for x in outputs]).mean()
        avg_kl_loss = torch.stack([x['kl'] for x in outputs]).mean()
        logs = {'loss': avg_loss, 'recon': avg_recon_loss, 'kl': avg_kl_loss}
        return {'log': logs}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and hand them to the logger."""
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        logs = {'val_loss': avg_loss}
        return {'log': logs}

    def configure_optimizers(self):
        """Use Adam with its default hyperparameters over all model parameters."""
        return torch.optim.Adam(self.parameters())

    def _mnist_loader(self, train, shuffle):
        """Build an MNIST DataLoader yielding ``(image, image)`` pairs.

        The target equals the input because the model is trained to
        reconstruct its input.

        Args:
            train: True for the training split, False for the test split.
            shuffle: Whether the DataLoader reshuffles the pairs every epoch.
        """
        # TODO: Lightning reportedly has a separate hook for dataset
        # download/preparation (prepare_data); consider moving the download there.
        transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
        dataset = torchvision.datasets.MNIST(
            root='./data', train=train, download=True, transform=transform
        )
        # NOTE: this decodes the whole split into memory up front.
        pairs = [(image, image) for image, _ in dataset]
        return torch.utils.data.DataLoader(
            pairs, batch_size=512, shuffle=shuffle, num_workers=16
        )

    def train_dataloader(self):
        """Training-split loader; shuffled so batch order differs every epoch."""
        return self._mnist_loader(train=True, shuffle=True)

    def val_dataloader(self):
        """Test-split loader used for validation; deterministic order."""
        return self._mnist_loader(train=False, shuffle=False)

from pytorch_lightning.loggers import TensorBoardLogger

# TensorBoard event files for this run go under tb_logs/my_model/.
logger = TensorBoardLogger("tb_logs", name="my_model")

model = VAE()
trainer = pl.Trainer(
    # `max_epochs` replaces the deprecated `max_nb_epochs` argument.
    max_epochs=30,
    # A list explicitly selects GPU index 0; a bare string such as '0' has
    # been interpreted inconsistently (GPU id vs. GPU count) across releases.
    gpus=[0],
    logger=logger,
)
trainer.fit(model)

INFO:lightning:GPU available: True, used: True
INFO:lightning:VISIBLE GPUS: 0
INFO:lightning:
  | Name       | Type   | Params
----------------------------------
0 | fc         | Linear | 401 K 
1 | fc_mu      | Linear | 16 K  
2 | fc_log_var | Linear | 16 K  
3 | fc_decode1 | Linear | 16 K  
4 | fc_decode2 | Linear | 402 K 


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=20.0, style=Pro…




1