# High Resolution VAEs

## 1. Setup

In [1]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
!pip -q install huggingface_hub datasets lightning einops bitsandbytes lpips

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
# !git clone https://github.com/oelin/generative-models

## 2. Implementation

### 2.1. Macros

In [3]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.checkpoint import checkpoint_sequential


def Convolution1x1(input_channels: int, output_channels: int) -> nn.Module:

    return nn.Conv2d(
        in_channels=input_channels,
        out_channels=output_channels,
        kernel_size=1,
        stride=1,
        padding=0,
    )


def Convolution3x3(input_channels: int, output_channels: int) -> nn.Module:

    return nn.Conv2d(
        in_channels=input_channels,
        out_channels=output_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )


def Convolution4x4(input_channels: int, output_channels: int) -> nn.Module:

    return nn.Conv2d(
        in_channels=input_channels,
        out_channels=output_channels,
        kernel_size=4,
        stride=2,
        padding=1,
        bias=False,
    )


def Normalization(channels: int) -> nn.Module:

    return nn.GroupNorm(
        num_groups=min(channels, 32),
        num_channels=channels,
    )


def Repeat(module, channels_list: List[int]) -> nn.Module:

    return nn.Sequential(*(
        module(
            input_channels=input_channels,
            output_channels=output_channels,
        ) for input_channels, output_channels in zip(
            channels_list[: -1],
            channels_list[1 :],
        )
    ))

### 2.2. Modules

In [4]:
from einops import rearrange


class ResidualBlock(nn.Module):

    def __init__(self, channels: int) -> None:
        super().__init__()

        self.normalization = Normalization(channels=channels)

        self.convolution = Convolution3x3(
            input_channels=channels,
            output_channels=channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        z = self.normalization(x)
        z = F.leaky_relu(z)
        z = self.convolution(z)

        return x + z


class ResNetBlock(nn.Module):

    def __init__(self, channels: int) -> None:
        super().__init__()

        self.residual_block_1 = ResidualBlock(channels=channels)
        self.residual_block_2 = ResidualBlock(channels=channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.residual_block_1(x)
        x = self.residual_block_2(x)

        return x


class UpsampleBlock(nn.Module):

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()

        # self.upsample = nn.Upsample(scale_factor=2)
        self.normalization = Normalization(channels=input_channels)

        self.convolution = Convolution3x3(
            input_channels=input_channels,
            output_channels=output_channels,
        )

        # self.convolution = nn.ConvTranspose2d(
        #     in_channels=input_channels,
        #     out_channels=output_channels,
        #     kernel_size=4,
        #     stride=2,
        #     padding=1,
        # )

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.normalization(x)
        x = F.leaky_relu(x)
        # x = self.upsample(x)
        x = F.interpolate(x, scale_factor=2.0, mode='nearest')
        x = self.convolution(x)

        return x


class DownsampleBlock(nn.Module):

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()

        self.normalization = Normalization(channels=input_channels)

        self.convolution = Convolution4x4(
            input_channels=input_channels,
            output_channels=output_channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.normalization(x)
        x = F.leaky_relu(x)
        x = self.convolution(x)

        return x


class AttentionBlock(nn.Module):

    def __init__(self, channels: int) -> None:
        super().__init__()

        self.normalization = Normalization(channels=channels)

        self.convolution_1 = Convolution1x1(
            input_channels=channels,
            output_channels=channels,
        )

        self.convolution_2 = Convolution1x1(
            input_channels=channels,
            output_channels=channels,
        )

        self.convolution_3 = Convolution1x1(
            input_channels=channels,
            output_channels=channels,
        )

        self.convolution_4 = Convolution1x1(
            input_channels=channels,
            output_channels=channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        B, C, H, W = x.shape

        z = self.normalization(x)
        q = self.convolution_1(z)
        k = self.convolution_2(z)
        v = self.convolution_3(z)

        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')  # Transposed.
        v = rearrange(v, 'b c h w -> b (h w) c')

        z = F.softmax(q @ k, dim=-1) @ v
        z = rearrange(z, 'b (h w) c -> b c h w', h=H, w=W)
        z = self.convolution_4(z)

        return x + z


class UpBlock(nn.Module):

    def __init__(self, input_channels: int, output_channels: int) -> None:
        super().__init__()

        self.resnet_block = ResNetBlock(channels=input_channels)

        self.upsample_block = UpsampleBlock(
            input_channels=input_channels,
            output_channels=output_channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.resnet_block(x)
        x = self.upsample_block(x)

        return x


class DownBlock(nn.Module):

    def __init__(self, input_channels: int, output_channels: int) ->  None:
        super().__init__()

        self.resnet_block = ResNetBlock(channels=input_channels)

        self.downsample_block = DownsampleBlock(
            input_channels=input_channels,
            output_channels=output_channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.resnet_block(x)
        x = self.downsample_block(x)

        return x


class MiddleBlock(nn.Module):

    def __init__(self, channels: int) -> None:
        super().__init__()

        self.resnet_block_1 = ResNetBlock(channels=channels)
        self.resnet_block_2 = ResNetBlock(channels=channels)
        self.attention_block = AttentionBlock(channels=channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.resnet_block_1(x)
        x = self.attention_block(x)
        x = self.resnet_block_2(x)

        return x


class Encoder(nn.Module):

    def __init__(self, channels_list: List[int]) -> None:
        super().__init__()

        self.down_blocks = Repeat(module=DownBlock, channels_list=channels_list)
        self.middle_block = MiddleBlock(channels=channels_list[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.down_blocks(x)
        x = self.middle_block(x)

        return x


class Decoder(nn.Module):

    def __init__(self, channels_list: List[int]) -> None:
        super().__init__()

        self.up_blocks = Repeat(module=UpBlock, channels_list=channels_list)
        self.middle_block = MiddleBlock(channels=channels_list[0])

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.middle_block(x)
        x = self.up_blocks(x)

        return x


class GaussianDistribution(nn.Module):

    def __init__(self, parameters: torch.Tensor) -> None:
        super().__init__()

        self.mean, self.log_variance = parameters.chunk(chunks=2, dim=1)

    def sample(self) -> torch.Tensor:

        epsilon = torch.randn_like(self.mean, device=self.mean.device)
        standard_deviation = torch.exp(0.5 * self.log_variance)
        x = epsilon * standard_deviation + self.mean

        return x

### 2.3. Models

In [10]:
from typing import Tuple

from dataclasses import dataclass


@dataclass(frozen=True)
class VAEOptions:

    input_channels: int
    output_channels: int
    latent_channels: int
    encoder_channels_list: List[int]
    decoder_channels_list: List[int]


class VAE(nn.Module):

    def __init__(self, options: VAEOptions) -> None:
        super().__init__()

        self.encoder = Encoder(channels_list=options.encoder_channels_list)
        self.decoder = Decoder(channels_list=options.decoder_channels_list)

        # Input to encoder.

        self.convolution_1 = Convolution3x3(
            input_channels=options.input_channels,
            output_channels=options.encoder_channels_list[0],
        )

        # Encoder to latent.

        self.convolution_2 = Convolution3x3(
            input_channels=options.encoder_channels_list[-1],
            output_channels=options.latent_channels * 2,
        )

        # Latent to decoder.

        self.convolution_3 = Convolution3x3(
            input_channels=options.latent_channels,
            output_channels=options.decoder_channels_list[0],
        )

        # Decoder to output.

        self.convolution_4 = Convolution3x3(
            input_channels=options.decoder_channels_list[-1],
            output_channels=options.output_channels,
        )

    def encode(self, x: torch.Tensor) -> GaussianDistribution:

        x = self.convolution_1(x)
        x = self.encoder(x)
        x = self.convolution_2(x)

        distribution = GaussianDistribution(x)

        return distribution

    def decode(self, z: torch.Tensor) -> torch.Tensor:

        x = self.convolution_3(z)
        x = self.decoder(x)
        x = self.convolution_4(x)
        # x = torch.sigmoid(x)

        return x

    def forward(self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, GaussianDistribution]:

        distribution = self.encode(x)
        z = distribution.sample()
        x = self.decode(z)

        return x, z, distribution

In [11]:
@dataclass(frozen=True)
class PatchDiscriminatorOptions:
    input_channels: int
    channels_list: int


        # (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        # (1): LeakyReLU(negative_slope=0.2, inplace=True)
        # (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        # (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # (4): LeakyReLU(negative_slope=0.2, inplace=True)
        # (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        # (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # (7): LeakyReLU(negative_slope=0.2, inplace=True)
        # (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
        # (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # (10): LeakyReLU(negative_slope=0.2, inplace=True)
        # (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))

class PatchDiscriminator(nn.Module):

    def __init__(self, options: PatchDiscriminatorOptions) -> None:
        super().__init__()

        self.convolution_1 = Convolution3x3(
            input_channels=options.input_channels,
            output_channels=options.channels_list[0],
        )

        self.convolution_2 = Convolution3x3(
            input_channels=options.channels_list[-1],
            output_channels=1,
        )

        self.down_blocks = Repeat(
            module=DownBlock,
            channels_list=options.channels_list,
        )


    def forward(self, x: torch.Tensor) -> torch.Tensor:

        x = self.convolution_1(x)
        x = self.down_blocks(x)
        x = self.convolution_2(x)  # (b, 1, h/f, w/f).
        x = torch.sigmoid(x)

        return x

### 2.4. Losses

In [12]:
import lpips


@dataclass(frozen=True)
class VAELossOptions:
    kl_divergence_weight: float


class VAELoss(nn.Module):

    def __init__(self, options: VAELossOptions) -> None:
        super().__init__()

        self.kl_divergence_weight = options.kl_divergence_weight
        self.lpips = lpips.LPIPS(net='vgg').eval()

    def forward(
        self,
        prediction: torch.Tensor,
        distribution: GaussianDistribution,
        target: torch.Tensor,
    ) -> None:

        # KL divergence loss.

        loss = distribution.log_variance - distribution.log_variance.exp()
        loss = loss - distribution.mean.pow(2) + 1
        loss = loss.mean() * self.kl_divergence_weight * -0.5

        # Perceptual loss.

        loss = loss + self.lpips(prediction, target).mean() #F.binary_cross_entropy(prediction, target)

        return loss

In [31]:
from typing import Optional


@dataclass(frozen=True)
class VAEDiscriminatorLossOptions:

    kl_divergence_weight: float
    reconstruction_weight: float
    perceptual_weight: float
    generator_weight: float


class VAEDiscriminatorLoss(nn.Module):

    def __init__(self, options: VAEDiscriminatorLossOptions) -> None:
        super().__init__()

        self.kl_divergence_weight = options.kl_divergence_weight
        self.reconstruction_weight = options.reconstruction_weight
        self.perceptual_weight = options.perceptual_weight
        self.generator_weight = options.generator_weight

        self.lpips = lpips.LPIPS(net='vgg').eval()

    def forward(
        self,
        reconstruction: torch.Tensor,
        distribution: GaussianDistribution,
        target: torch.Tensor,
        last_layer: torch.Tensor,
        discriminator: PatchDiscriminator,
        use_discriminator: bool = False,
        scale: float = 1.,
    ) -> dict:

        # KL divergence loss.

        kl_divergence_loss = distribution.log_variance - distribution.log_variance.exp()
        kl_divergence_loss = kl_divergence_loss - distribution.mean.pow(2) + 1
        kl_divergence_loss = kl_divergence_loss.mean() * -0.5

        # Reconstruction loss.

        reconstruction_loss = torch.abs(reconstruction.contiguous() - target.contiguous()).mean()

        # Perceptual loss.

        perceptual_loss = self.lpips(reconstruction.contiguous(), target.contiguous()).mean()

        # Generator loss.

        if use_discriminator:#
            generator_loss = -torch.log(discriminator(reconstruction).mean()) # Maximize p(real) on fake inputs.

            # Compute adaptive weight: the ratio between the reconstruction and generator gradients.

            reconstruction_gradient = torch.autograd.grad(reconstruction_loss, last_layer, retain_graph=True)[0]
            generator_gradient = torch.autograd.grad(generator_loss, last_layer, retain_graph=True)[0]

            ratio = torch.norm(reconstruction_gradient) / (torch.norm(generator_gradient) + 1e-4)
            ratio = torch.clamp(ratio, 0., 1e4).detach()

            generator_loss = generator_loss #* ratio  # Scale generator loss accordingly.

        else:
            generator_loss = 0.

        # Overall loss.

        loss = self.kl_divergence_weight * kl_divergence_loss \
             + self.reconstruction_weight * reconstruction_loss \
             + self.perceptual_weight * perceptual_loss \
             + self.generator_weight * generator_loss

        return {
            'loss': loss * scale,
            'kl_divergence_loss': kl_divergence_loss * self.kl_divergence_weight,
            'reconstruction_loss': reconstruction_loss * self.reconstruction_weight,
            'perceptual_loss': perceptual_loss * self.perceptual_weight,
            'generator_loss': generator_loss * self.generator_weight,
        }

## 4. Training

### 4.1. Training Script

In [36]:
import time

from typing import Callable, Optional

from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

from torchvision.utils import save_image

import bitsandbytes as bnb


@dataclass(frozen=True)
class TrainOptions:
    device: str

    epochs: int
    batch_size: int
    generator_learning_rate: float
    discriminator_learning_rate: float
    discriminator_warmup_steps: int

    batches_before_step: int
    batches_before_log: int
    batches_before_sample: int
    batches_before_checkpoint: int

    checkpoint_path: str
    sample_path: str

    generator_options: VAEOptions
    discriminator_options: PatchDiscriminatorOptions
    loss_options: VAEDiscriminatorLossOptions

    generator: Optional[VAE] = None
    discriminator: Optional[PatchDiscriminator] = None
    loss: Optional[VAEDiscriminatorLoss] = None


def train_summary(
    options: TrainOptions,
    generator: VAE,
    discriminator: PatchDiscriminator,
    log: Callable=print,
) -> None:

    log(f'==========')

    log(f'Training for {options.epochs} epochs on {options.device}...')

    generator_size = sum(p.numel() for p in generator.parameters())
    generator_f_value = 2 ** (len(options.generator_options.encoder_channels_list) - 1)
    generator_c_value = options.generator_options.latent_channels

    log(f'Generator:')
    log(f'- Parameters: {generator_size/1e6:0.2f}M')
    log(f'- Downsampling factor: {generator_f_value}')
    log(f'- Bottleneck channels: {generator_c_value}')
    log(f'- Optimizer: Adam (learning_rate={options.generator_learning_rate})')

    # discriminator_size = sum(p.numel() for p in discriminator.parameters())
    # discriminator_f_value = 2 ** (len(options.discriminator_options.channels_list) - 1)

    # log(f'Discriminator:')
    # log(f'- Parameters: {discriminator_size/1e6:0.2f}M')
    # log(f'- Downsampling factor: {generator_f_value}')
    # log(f'- Optimizer: Adam (learning_rate={options.discriminator_learning_rate})')

    log(f'Loss:')
    log(f'- KL divergence weight: {options.loss_options.kl_divergence_weight}')
    log(f'- Reconstruction weight: {options.loss_options.reconstruction_weight}')
    log(f'- Perceptual weight: {options.loss_options.perceptual_weight}')
    log(f'- Generator weight: {options.loss_options.generator_weight}')

    log(f'Training:')
    log(f'- Epochs: {options.epochs} (batch_size={options.batch_size})')
    log(f'- Batches per step: {options.batches_before_step}')
    log(f'- Batches per checkpoint: {options.batches_before_checkpoint}')
    log(f'- Discriminator warmup steps: {options.discriminator_warmup_steps}')

    log(f'==========')

def train(
    options: TrainOptions,
    dataset: Dataset,
    log: Callable=print,
) -> TrainOptions:

    # Initialize models.

    device = options.device

    generator = options.generator or VAE(options=options.generator_options)

    # discriminator = options.discriminator \
    #     or PatchDiscriminator(options=options.discriminator_options)

    loss = options.loss or VAEDiscriminatorLoss(options=options.loss_options)

    generator = generator.to(device)
    #discriminator = discriminator.to(device)
    loss = loss.to(device)

    # Initialize dataloaders.

    dataloader = DataLoader(
        dataset=dataset,
        shuffle=True,
        batch_size=options.batch_size,
        num_workers=2,
    )

    # Initialize optimizers.

    generator_optimizer = bnb.optim.Adam8bit(generator.parameters(), lr=options.generator_learning_rate)#Adam(generator.parameters(), lr=options.generator_learning_rate)
    #discriminator_optimizer = Adam(discriminator.parameters(), lr=options.discriminator_learning_rate)

    # Display summary before training.

    train_summary(options, generator, None, log)#discriminator, log)

    # Start training loop.

    step = 0

    for epoch in range(options.epochs):
        for batch, examples in enumerate(dataloader):

            use_discriminator = step >= options.discriminator_warmup_steps

            target = examples['image'].to(device)

            with torch.autocast(device_type='cuda'): #device):
                reconstruction, _, distribution = generator(target)

                loss_on_batch = loss(
                    reconstruction=reconstruction,
                    distribution=distribution,
                    target=target,
                    last_layer=generator.convolution_4.weight,  # Last layer of decoder.
                    discriminator=None,
                    use_discriminator=use_discriminator,
                    scale=1/options.batches_before_step,
                )

            loss_on_batch['loss'].backward()
            discriminator_loss = None

            if (batch + 1) % options.batches_before_step == 0:
                generator_optimizer.step()
                generator_optimizer.zero_grad()

                if use_discriminator:

                    discriminator_optimizer.zero_grad()

                    p_real_real = discriminator(target.detach()).mean()
                    p_real_fake = discriminator(reconstruction.detach()).mean()

                    loss_real = -torch.log(p_real_real)
                    loss_fake = -torch.log(1 - p_real_fake)

                    discriminator_loss = loss_real + loss_fake
                    discriminator_loss.backward()

                    discriminator_optimizer.step()

                step += 1

            if (batch + 1) % options.batches_before_log == 0:

                loss1 = loss_on_batch['loss'].detach().item()
                loss2 = loss_on_batch['kl_divergence_loss'].detach().item()
                loss3 = loss_on_batch['reconstruction_loss'].detach().item()
                loss4 = loss_on_batch['perceptual_loss'].detach().item()
                loss5 = f'{loss_on_batch["generator_loss"].detach().item():0.3f}' if use_discriminator else 'n/a'
                loss6 = f'{discriminator_loss.detach().item():0.3f}' if use_discriminator else 'n/a'
                loss7 = f'{p_real_real.detach().item():0.5f}' if use_discriminator else 'n/a'
                loss8 = f'{p_real_fake.detach().item():0.5f}' if use_discriminator else 'n/a'

                log(f'{time.ctime()} | epoch: {epoch:6d}, batch: {batch:6d}, step: {step:6d} - loss: {loss1:0.3f} (kld={loss2:0.3f}, rec={loss3:0.3f}, per={loss4:0.3f}, gen={loss5}, dis={loss6}), p(R|R)={loss7}, p(R|F)={loss8}')

            if (batch + 1) % options.batches_before_sample == 0:
                save_image(rescale(reconstruction)[: 64], f'{options.sample_path}/sample.png')#/sample-{step}.png')

            if (batch + 1) % options.batches_before_checkpoint == 0:
                log('Creating checkpoint...')

                torch.save(generator, f'{options.checkpoint_path}/generator.ckpt')#generator-{step}.ckpt')

### 4.1. CelebA-HQ-256x256

In [33]:
from datasets import load_dataset
from torchvision import transforms

resolution = 256 #512

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: 2 * (x - 0.5)),
    #transforms.Resize((resolution, resolution))#, 0)
])


def preprocess(examples):

  return {
      'image': [transform(image) for image in examples['image']]
  }


def rescale(x):
    return (x - x.min()) / (x.max() - x.min())


dataset = load_dataset('korexyz/CelebA-HQ-256x256', split='train')
#dataset = load_dataset('mattymchen/celeba-hq', split='train')
dataset.set_transform(transform=preprocess)

In [None]:
import gc
import os

torch.cuda.empty_cache()
gc.collect()

#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:native,roundup_power2_divisions:1024"


generator_options = VAEOptions(
    input_channels=3,
    output_channels=3,
    latent_channels=4,
    encoder_channels_list=[128, 256, 512],
    decoder_channels_list=[512, 256, 128],
)

discriminator_options = PatchDiscriminatorOptions(
    input_channels=generator_options.output_channels,
    channels_list=[32, 32, 32],
)

loss_options = VAEDiscriminatorLossOptions(
    kl_divergence_weight=1e-4,#1e-6,
    reconstruction_weight=1.,#1.,#1.,
    perceptual_weight=0.25,#0.,#1.,#1.,
    generator_weight=0.5,
)

loss = VAEDiscriminatorLoss(options=loss_options)

#generator = torch.load('/content/drive/MyDrive/generator-day2-320.ckpt')

#generator = VAE(options=generator_options)
#discriminator = PatchDiscriminator(options=discriminator_options)


train_options = TrainOptions(

    device='cuda',
    epochs=2,
    batch_size=1,#4,
    generator_learning_rate=1e-5,
    discriminator_learning_rate=1e-5,
    discriminator_warmup_steps=1000,

    batches_before_step=8,#4,
    batches_before_log=16,#8,
    batches_before_sample=64,#32,
    batches_before_checkpoint=128,#128,

    checkpoint_path='./checkpoints',
    sample_path='./samples',

    generator_options=generator_options,
    discriminator_options=discriminator_options,
    loss_options=loss_options,

    generator=generator,
    discriminator=None,
    loss=loss,
)


!rm -rf checkpoints samples
!mkdir checkpoints samples

train(options=train_options, dataset=dataset)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.10/dist-packages/lpips/weights/v0.1/vgg.pth
Training for 2 epochs on cuda...
Generator:
- Parameters: 32.52M
- Downsampling factor: 4
- Bottleneck channels: 4
- Optimizer: Adam (learning_rate=1e-05)
Loss:
- KL divergence weight: 0.0001
- Reconstruction weight: 1.0
- Perceptual weight: 0.25
- Generator weight: 0.5
Training:
- Epochs: 2 (batch_size=1)
- Batches per step: 8
- Batches per checkpoint: 128
- Discriminator warmup steps: 1000
Wed Dec 20 13:25:55 2023 | epoch:      0, batch:     15, step:      2 - loss: 0.047 (kld=0.000, rec=0.230, per=0.144, gen=n/a, dis=n/a), p(R|R)=n/a, p(R|F)=n/a
Wed Dec 20 13:25:57 2023 | epoch:      0, batch:     31, step:      4 - loss: 0.034 (kld=0.000, rec=0.136, per=0.135, gen=n/a, dis=n/a), p(R|R)=n/a, p(R|F)=n/a
Wed Dec 20 13:25:59 2023 | epoch:      0, batch:     47, step:      6 - loss: 0.028 (kld=0.000, rec=0.121, per=0.102, gen=n/a, dis=n/a), p(R|R)=n/a, p(R|F)=n/a
Wed Dec 20 13:26:02 2023 | epoch:     

In [None]:
cuda