<a href="https://colab.research.google.com/github/juliawiktoria/glow_implementation/blob/main/glow_impl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount Google Drive to save work

In [None]:
# Mount Google Drive so that the data/, samples/ and log files written later
# (after the `%cd` into the Drive project folder) are stored on Drive.
# Requires a Colab runtime: `google.colab` is provided by Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pwd

/content/drive/My Drive/flow_implementation


In [None]:
%cd drive/MyDrive/flow_implementation/

/content/drive/MyDrive/flow_implementation


In [None]:
!rm -rf data/

In [None]:
!rm -rf samples/

# IMPLEMENTATION

Imports

In [None]:
## Standard libraries
import os
import math
import time
import numpy as np

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.optim.lr_scheduler as sched
import torch.backends.cudnn as cudnn

# Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms

# other
from tqdm import tqdm

Negative log-likelihood loss (RealNVP, Eq. 3)

In [None]:
class NLLLoss(nn.Module):
    """Negative log-likelihood of flow latents under a standard normal prior.

    The per-sample log-likelihood is the N(0, I) log-density of ``z`` plus the
    summed log-determinant ``sldj``, minus ``D * log(k)`` for the
    discretisation of each of the ``D`` input dimensions into ``k`` bins
    (Equation (3) of RealNVP: https://arxiv.org/abs/1605.08803).

    Args:
        k (int or float): number of discrete values per input dimension,
            e.g. 256 for 8-bit images.
    """

    def __init__(self, k=256):
        super().__init__()
        self.k = k

    def forward(self, z, sldj):
        """Return the batch-mean NLL (in nats).

        Args:
            z: latent tensor; dim 0 is the batch axis.
            sldj: per-sample sum of log-determinants, shape (batch,).
        """
        num_dims = np.prod(z.size()[1:])
        log_density = -0.5 * (z ** 2 + np.log(2 * np.pi))
        per_sample = log_density.flatten(1).sum(-1) - np.log(self.k) * num_dims
        return -(per_sample + sldj).mean()

In [None]:
class InvertedConvolution(nn.Module):
  """Invertible 1x1 convolution (channel-mixing matrix) from Glow.

  The square weight matrix starts as the Q factor of a QR decomposition of a
  Gaussian matrix, i.e. a random orthogonal matrix. Because the same matrix
  is applied at every spatial position, the log-determinant contribution is
  ``H * W * log|det W|``.
  """
  def __init__(self, num_channels):
    super(InvertedConvolution, self).__init__()
    self.num_channels = num_channels

    gaussian = np.random.randn(num_channels, num_channels)
    orthogonal = np.linalg.qr(gaussian)[0].astype(np.float32)
    self.weights = nn.Parameter(torch.from_numpy(orthogonal))

  def forward(self, x, sldj, reverse=False):
    """Apply W (or W^-1 when ``reverse``) across channels; update ``sldj``."""
    height, width = x.size(2), x.size(3)
    log_abs_det = torch.slogdet(self.weights)[1] * height * width

    if not reverse:
      matrix = self.weights
      sldj = sldj + log_abs_det
    else:
      # Invert in double precision for accuracy, then return to float32.
      matrix = torch.inverse(self.weights.double()).float()
      sldj = sldj - log_abs_det

    kernel = matrix.view(self.num_channels, self.num_channels, 1, 1)
    return F.conv2d(x, kernel), sldj

In [None]:
def mean_over_dimensions(tensor, dim=None, keepdims=False):
  """Average ``tensor`` over one or several dimensions.

  Args:
    tensor: input tensor.
    dim: None to average over everything (scalar result), a single int, or
      an iterable of ints.
    keepdims: when True, reduced dimensions are kept with size 1.
  Returns:
    The reduced tensor.
  """
  if dim is None:
    return tensor.mean()

  axes = sorted([dim] if isinstance(dim, int) else dim)
  for axis in axes:
    tensor = tensor.mean(dim=axis, keepdim=True)

  if keepdims:
    return tensor
  # Each squeeze shifts later axes left by one, hence the offset.
  for already_removed, axis in enumerate(axes):
    tensor.squeeze_(axis - already_removed)
  return tensor

In [None]:
class ActivationNormalisation(nn.Module):
  """Activation normalisation (ActNorm) layer from Glow.

  Applies the per-channel affine map ``y = (x + bias) * exp(logs)`` to a 4-D
  input whose channel axis is dim 1 (statistics are taken over dims 0, 2, 3).
  On the first forward call in training mode, ``bias`` and ``logs`` are set
  from that batch so the output has zero mean and standard deviation
  ``scale`` per channel (data-dependent initialisation).

  Args:
    num_features: number of channels.
    scale: target per-channel standard deviation after initialisation.
    return_lower_det_jacobian: if True, ``forward`` returns ``(x, ldj)``;
      otherwise it returns only ``x``.
  """
  def __init__(self, num_features, scale=1., return_lower_det_jacobian=False):
    super(ActivationNormalisation, self).__init__()
    # Becomes non-zero once data-dependent init has run; registered as a
    # buffer so it is stored in the state dict alongside the parameters.
    self.register_buffer('is_initialised', torch.zeros(1))
    self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
    self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1))

    self.num_features = num_features
    self.scale = float(scale)
    self.epsilon = 1e-6
    self.return_lower_det_jacobian = return_lower_det_jacobian

  def init_params(self, x):
    """Initialise ``bias`` and ``logs`` from batch ``x`` (no-op in eval mode)."""
    if not self.training:
      return
    with torch.no_grad():
      bias = -x.mean(dim=(0, 2, 3), keepdim=True)
      # ``bias`` holds -mean, so the centred input is ``x + bias``.
      # (Using ``x - bias`` measured E[(x + mean)^2] = var + 4 * mean^2 and
      # badly under-scaled every channel with a non-zero mean.)
      v = ((x + bias) ** 2).mean(dim=(0, 2, 3), keepdim=True)
      logs = (self.scale / (v.sqrt() + self.epsilon)).log()

      self.bias.data.copy_(bias)
      self.logs.data.copy_(logs)
      self.is_initialised += 1

  def _center(self, x, reverse=False):
    """Add ``bias`` (forward) or subtract it (reverse)."""
    if reverse:
      return x - self.bias
    else:
      return x + self.bias

  def _scale(self, x, sldj, reverse=False):
    """Multiply by ``exp(logs)`` (or its inverse) and update ``sldj`` if given.

    The log-determinant is ``sum(logs) * H * W`` since each channel's scale is
    shared across all spatial positions.
    """
    logs = self.logs

    if reverse:
      x = x * logs.mul(-1).exp()
    else:
      x = x * logs.exp()

    if sldj is not None:
      lower_det_jacobian = logs.sum() * x.size(2) * x.size(3)
      if reverse:
        sldj = sldj - lower_det_jacobian
      else:
        sldj = sldj + lower_det_jacobian

    return x, sldj

  def forward(self, x, lower_det_jacobian=None, reverse=False):
    """Normalise ``x`` (or undo it when ``reverse``).

    Returns ``(x, lower_det_jacobian)`` if ``return_lower_det_jacobian`` was
    set at construction, else just ``x``.
    """
    if not self.is_initialised:
      self.init_params(x)

    # Reverse applies the two sub-steps in the opposite order.
    if reverse:
      x, lower_det_jacobian = self._scale(x, lower_det_jacobian, reverse)
      x = self._center(x, reverse)
    else:
      x = self._center(x, reverse)
      x, lower_det_jacobian = self._scale(x, lower_det_jacobian, reverse)

    if self.return_lower_det_jacobian:
      return x, lower_det_jacobian

    return x

In [None]:
class CNN(nn.Module):
  """Small conv net that predicts the coupling parameters.

  Three (norm -> ReLU -> conv) stages: 3x3 conv to ``mid_channels``, 1x1 conv,
  then 3x3 conv to ``out_channels``. The first two convs have no bias and
  N(0, 0.05) weights; the last conv is zero-initialised, so the network
  outputs exactly zero until it is trained.

  Args:
    in_channels, mid_channels, out_channels: channel widths.
    use_act_norm: use ActivationNormalisation instead of BatchNorm2d.
  """
  def __init__(self, in_channels, mid_channels, out_channels, use_act_norm=False):
    super(CNN, self).__init__()
    make_norm = ActivationNormalisation if use_act_norm else nn.BatchNorm2d

    def gaussian_conv(c_in, c_out, kernel):
      # "Same" padding for odd kernels: 3 -> 1, 1 -> 0.
      conv = nn.Conv2d(c_in, c_out, kernel_size=kernel, padding=kernel // 2, bias=False)
      nn.init.normal_(conv.weight, 0., 0.05)
      return conv

    # Creation order is significant: it fixes the RNG draws and child order.
    self.in_norm = make_norm(in_channels)
    self.in_conv = gaussian_conv(in_channels, mid_channels, 3)

    self.mid_norm = make_norm(mid_channels)
    self.mid_conv = gaussian_conv(mid_channels, mid_channels, 1)

    self.out_norm = make_norm(mid_channels)
    self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=True)
    nn.init.zeros_(self.out_conv.weight)
    nn.init.zeros_(self.out_conv.bias)

  def forward(self, x):
    stages = (
        (self.in_norm, self.in_conv),
        (self.mid_norm, self.mid_conv),
        (self.out_norm, self.out_conv),
    )
    for norm, conv in stages:
      x = conv(F.relu(norm(x)))
    return x

In [None]:
class AffineCoupling(nn.Module):
  """Affine coupling layer.

  The input is split along channels into ``x_change`` and ``x_id``. A CNN on
  ``x_id`` predicts interleaved channels: even channels give the log-scale
  (bounded by ``scale * tanh``) and odd channels give the shift. Forward
  computes ``(x_change + t) * exp(s)``; reverse computes
  ``x_change * exp(-s) - t``. ``x_id`` passes through unchanged.

  Args:
    in_channels: channels of each half, i.e. half of the full input.
    mid_channels: hidden width of the CNN.
  """
  def __init__(self, in_channels, mid_channels):
    super(AffineCoupling, self).__init__()
    self.name = "affine coupling layer"
    self.cnn = CNN(in_channels, mid_channels, 2 * in_channels)
    self.scale = nn.Parameter(torch.ones(in_channels, 1, 1))

  def forward(self, x, lower_det_jacobian, reverse=False):
    x_change, x_id = x.chunk(2, dim=1)

    params = self.cnn(x_id)
    log_scale = self.scale * torch.tanh(params[:, 0::2, ...])
    shift = params[:, 1::2, ...]
    log_det = log_scale.flatten(1).sum(-1)

    if not reverse:
      x_change = (x_change + shift) * log_scale.exp()
      lower_det_jacobian = lower_det_jacobian + log_det
    else:
      x_change = x_change * log_scale.mul(-1).exp() - shift
      lower_det_jacobian = lower_det_jacobian - log_det

    return torch.cat((x_change, x_id), dim=1), lower_det_jacobian

In [None]:
class _FlowStep(nn.Module):
  """A single Glow step: ActNorm -> invertible 1x1 conv -> affine coupling.

  In reverse mode the three sub-layers are inverted and applied in the
  opposite order. Every sub-layer is called as ``layer(x, sldj, reverse)``
  and returns ``(x, sldj)``.
  """
  def __init__(self, in_channels, mid_channels):
    super(_FlowStep, self).__init__()

    self.normalisation = ActivationNormalisation(in_channels, return_lower_det_jacobian=True)
    self.convolution = InvertedConvolution(in_channels)
    # The coupling transforms one half of the channels conditioned on the other.
    self.coupling = AffineCoupling(in_channels // 2, mid_channels)

  def forward(self, x, sldj=None, reverse=False):
    pipeline = [self.normalisation, self.convolution, self.coupling]
    if reverse:
      pipeline.reverse()
    for layer in pipeline:
      x, sldj = layer(x, sldj, reverse)
    return x, sldj

In [None]:
class _GlowLevel(nn.Module):
  """One scale of the multi-scale Glow architecture, defined recursively.

  Holds ``num_steps`` flow steps on ``in_channels`` channels. When
  ``num_levels > 1`` it also owns the next, coarser level. In the forward
  direction it:
    1. runs its own steps;
    2. squeezes the result (4x channels, half resolution);
    3. sends half of the channels through ``self.next`` and leaves the other
       half as is;
    4. concatenates the halves and unsqueezes.
  Reverse mode undoes this, running its own steps last and in reverse order.
  """
  def __init__(self, in_channels, mid_channels, num_levels, num_steps):
    super(_GlowLevel, self).__init__()

    self.steps = nn.ModuleList(
        [_FlowStep(in_channels=in_channels, mid_channels=mid_channels) for _ in range(num_steps)])

    # Squeeze quadruples the channels and the split keeps half: 2x per level.
    self.next = (
        _GlowLevel(in_channels=2 * in_channels, mid_channels=mid_channels,
                   num_levels=num_levels - 1, num_steps=num_steps)
        if num_levels > 1 else None)

  def forward(self, x, sldj, reverse=False):
    def run_steps(h, ldj):
      ordered = reversed(self.steps) if reverse else self.steps
      for step in ordered:
        h, ldj = step(h, ldj, reverse)
      return h, ldj

    if not reverse:
      x, sldj = run_steps(x, sldj)

    if self.next is not None:
      kept, factored = squeeze(x).chunk(2, dim=1)
      kept, sldj = self.next(kept, sldj, reverse)
      x = squeeze(torch.cat((kept, factored), dim=1), reverse=True)

    if reverse:
      x, sldj = run_steps(x, sldj)

    return x, sldj

In [None]:
class GlowModel(nn.Module):
  """Multi-scale Glow flow for 3-channel images.

  Forward: images in [0, 1] are dequantised and mapped to logit space, then
  pushed through the flow to latents of the same shape. It returns
  ``(z, sldj)``, where ``sldj`` includes the preprocessing log-determinant.
  Reverse: latents are mapped back to logit space, and ``sldj`` starts at zero.

  Args:
    num_channels: hidden width of the coupling CNNs.
    num_layers: number of multi-scale levels.
    num_steps: flow steps per level.
  """
  def __init__(self, num_channels, num_layers, num_steps):
    super(GlowModel, self).__init__()
    self.name = "glow"
    self.num_channels = num_channels
    self.num_layers = num_layers
    self.num_steps = num_steps
    # Fixed (non-learned) squash factor applied before the logit transform.
    self.register_buffer('bounds', torch.tensor([0.9], dtype=torch.float32))

    # Inputs are squeezed once before the first level: 3 RGB channels -> 12.
    squeezed_channels = 4 * 3
    self.flows = _GlowLevel(in_channels=squeezed_channels, mid_channels=num_channels,
                            num_levels=num_layers, num_steps=num_steps)

  def forward(self, x, reverse=False):
    """Run the flow forward (image -> latent) or in reverse (latent -> logits)."""
    if not reverse:
      lowest, highest = x.min(), x.max()
      if lowest < 0 or highest > 1:
        raise ValueError('Expected x in [0, 1], got min/max [{}, {}]'.format(lowest, highest))
      x, sldj = self._pre_process(x)
    else:
      sldj = torch.zeros(x.size(0), device=x.device)

    out, sldj = self.flows(squeeze(x), sldj, reverse)
    return squeeze(out, reverse=True), sldj

  def _pre_process(self, x):
    """Dequantise ``x`` and map it to logit space; return ``(y, sldj)``."""
    # Uniform dequantisation: add U[0, 1) noise to the 8-bit level, rescale to [0, 1).
    dequantised = (x * 255. + torch.rand_like(x)) / 256.
    # Squash into [(1 - b) / 2, (1 + b) / 2] with b = bounds so the logit is finite.
    squashed = ((2 * dequantised - 1) * self.bounds + 1) / 2
    y = squashed.log() - (1. - squashed).log()

    # Per-element log|dy/d(dequantised)| of squash + logit, i.e.
    # -log(p) - log(1 - p) + log(bounds) with p = sigmoid(y).
    ldj = F.softplus(y) + F.softplus(-y) \
        - F.softplus((1. - self.bounds).log() - self.bounds.log())
    return y, ldj.flatten(1).sum(-1)

  def describe(self):
    """Print a one-line summary of the model configuration."""
    summary = "Model {} with {} convolutional channels, {} model levels, and {} steps in each levels."
    print(summary.format(self.name, self.num_channels, self.num_layers, self.num_steps))

In [None]:
def squeeze(x, reverse=False):
  """2x2 space-to-depth rearrangement and its exact inverse.

  Forward: (B, C, H, W) -> (B, 4C, H/2, W/2). Output channel ``4*c + 2*dh + dw``
  at position (i, j) holds input pixel (2i + dh, 2j + dw) of channel ``c``.
  Reverse: (B, C, H, W) -> (B, C/4, 2H, 2W), undoing the forward map.
  """
  b, c, h, w = x.size()

  if not reverse:
    blocks = x.view(b, c, h // 2, 2, w // 2, 2).permute(0, 1, 3, 5, 2, 4)
    return blocks.contiguous().view(b, 4 * c, h // 2, w // 2)

  blocks = x.view(b, c // 4, 2, 2, h, w).permute(0, 1, 4, 2, 5, 3)
  return blocks.contiguous().view(b, c // 4, 2 * h, 2 * w)

In [None]:
def clip_grad_norm(optimizer, max_norm, norm_type=2):
  """Clip gradients in place, separately for each optimizer parameter group.

  Each group's gradients are rescaled so that their combined ``norm_type``
  norm is at most ``max_norm``.

  Args:
    optimizer: a torch optimizer whose ``param_groups`` hold the parameters.
    max_norm: maximum allowed gradient norm per group.
    norm_type: order of the norm (2 = Euclidean).
  """
  # The original referenced an undefined name ``utils`` (NameError as soon as
  # clipping was enabled) and the deprecated non-underscore function.
  for group in optimizer.param_groups:
    nn.utils.clip_grad_norm_(group['params'], max_norm, norm_type)

In [None]:
def bits_per_dimension(x, nll):
  """Convert a per-sample NLL in nats to bits per input dimension.

  Divides by ``log(2)`` (nats -> bits) and by the number of dimensions per
  sample, i.e. the product of every axis of ``x`` except the batch axis.
  """
  num_dims = np.prod(x.size()[1:])
  nats_to_bits = np.log(2)
  return nll / (nats_to_bits * num_dims)

In [None]:
class AvgMeter(object):
  """Running weighted average of a scalar metric.

  Attributes:
    val: most recent value passed to ``update``.
    sum: weighted sum of all values.
    count: total weight.
    avg: ``sum / count``.
  """
  def __init__(self):
    self.reset()

  def reset(self):
    """Zero all statistics."""
    self.val = self.avg = self.sum = self.count = 0

  def update(self, val, n=1):
    """Record ``val`` with weight ``n`` (e.g. a batch mean over ``n`` samples)."""
    self.val = val
    self.sum += val * n
    self.count += n
    self.avg = self.sum / self.count

In [None]:
@torch.enable_grad()
def train(epoch, model, trainloader, device, optimizer, scheduler, loss_func, max_grad_norm):
  """Run one training epoch.

  For each batch: forward pass, NLL loss, backward pass, optional per-group
  gradient clipping (when ``max_grad_norm > 0``), then an optimizer step and
  a scheduler step. Running NLL/bpd/LR are shown on a tqdm bar. The
  module-level ``global_step`` is advanced by the number of samples seen.
  """
  # TODO: implement checkpointing
  global global_step
  print("===> EPOCH {}".format(epoch))
  model.train()
  meter = AvgMeter()

  with tqdm(total=len(trainloader.dataset)) as progress_bar:
    for images, _ in trainloader:
      images = images.to(device)
      n_images = images.size(0)

      optimizer.zero_grad()
      z, sldj = model(images, reverse=False)
      loss = loss_func(z, sldj)
      meter.update(loss.item(), n_images)
      loss.backward()

      if max_grad_norm > 0:
        clip_grad_norm(optimizer, max_grad_norm)
      optimizer.step()
      scheduler.step()

      current_lr = optimizer.param_groups[0]['lr']
      progress_bar.set_postfix(nll=meter.avg, bpd=bits_per_dimension(images, meter.avg), lr=current_lr)
      progress_bar.update(n_images)
      global_step += n_images

In [None]:
def sample(model, batch_size, device, shape=(3, 32, 32)):
  """Draw images by inverting the flow on standard-normal latents.

  Args:
    model: flow called as ``model(z, reverse=True)``, returning ``(x, sldj)``.
    batch_size: number of images to draw.
    device: device on which to allocate the latents.
    shape: per-sample ``(C, H, W)``; defaults to the previously hard-coded
      3x32x32 (CIFAR-10).
  Returns:
    Tensor of shape ``(batch_size, *shape)`` with values in [0, 1].
  """
  z = torch.randn((batch_size, *shape), dtype=torch.float32, device=device)
  x, _ = model(z, reverse=True)
  # GlowModel works in logit space (see its _pre_process), so a sigmoid maps
  # back to the unit interval. Note that the `bounds` squash applied there is
  # not undone here.
  return torch.sigmoid(x)

In [None]:
@torch.no_grad()
def test(epoch, model, testloader, device, loss_func, num_samples):
  """Evaluate the model on the test set, then save generated samples.

  Shows the running NLL/bpd on a tqdm bar. Afterwards, ``num_samples``
  images are drawn with ``sample`` and each is written to
  ``samples/epoch<epoch>/img_<i>.png``. ``best_loss`` is declared global but
  is not currently updated.
  """
  global best_loss
  print("testing func")

  model.eval()
  meter = AvgMeter()

  with tqdm(total=len(testloader.dataset)) as progress_bar:
    for images, _ in testloader:
      images = images.to(device)
      z, sldj = model(images, reverse=False)
      meter.update(loss_func(z, sldj).item(), images.size(0))
      progress_bar.set_postfix(nll=meter.avg, bpd=bits_per_dimension(images, meter.avg))
      progress_bar.update(images.size(0))

    # Save samples and data
    generated = sample(model, num_samples, device)
    out_dir = 'samples/epoch' + str(epoch)
    os.makedirs(out_dir, exist_ok=True)
    for idx, image in enumerate(generated):
      torchvision.utils.save_image(image, '{}/img_{}.png'.format(out_dir, idx))

In [None]:
def main_wrapper():
  """Train Glow on CIFAR-10, evaluating and sampling after every epoch.

  Hyper-parameters are the fixed locals below. CIFAR-10 is downloaded into
  ``data/``. Per-epoch wall-clock times are written to ``epoch_times.txt``;
  the file is rewritten after every epoch so a crashed run still leaves a log.
  The module-level ``global_step`` (used via ``global`` in ``train``) must
  exist before this is called.
  """
  if_gpu = True
  # default args
  batch_size = 64
  benchmark = True      # enable the cudnn autotuner on GPU
  gpu_ids = [0]
  learning_rate = 1e-3
  max_grad_norm = -1.   # <= 0 disables gradient clipping
  num_channels = 512
  num_levels = 3
  num_steps = 32
  num_epochs = 20
  num_samples = 64
  num_workers = 8
  seed = 0
  warm_up = 500000      # linear LR warm-up length, measured in training samples

  device = 'cuda' if torch.cuda.is_available() and if_gpu else 'cpu'
  print(device)

  # `seed` and `benchmark` were previously defined but never applied.
  # Seed torch (layer init, dequantisation noise, shuffling, sampling) and
  # numpy (orthogonal init in InvertedConvolution).
  np.random.seed(seed)
  torch.manual_seed(seed)
  if device == 'cuda':
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = benchmark

  # getting data for training; just CIFAR10
  transform_train = transforms.Compose([
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor()
  ])

  transform_test = transforms.Compose([
      transforms.ToTensor()
  ])

  trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
  trainloader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

  testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
  testloader = data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

  # define the model
  model = GlowModel(num_channels, num_levels, num_steps)
  model = model.to(device)
  model.describe()

  # if using GPU
  if device == 'cuda':
    model = torch.nn.DataParallel(model, gpu_ids)

  loss_function = NLLLoss().to(device)
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)
  # LambdaLR's argument counts scheduler.step() calls, and train() steps once
  # per batch, so convert batches to samples before comparing with warm_up.
  # (Comparing batch counts with a sample count stretched the warm-up 64x;
  # the LR only reached ~3e-5 after 20 epochs.)
  scheduler = sched.LambdaLR(optimizer, lambda step: min(1., step * batch_size / warm_up))

  times_array = []

  # training loop
  print("Starting training of the Glow model")
  for epoch in range(1, num_epochs + 1):
    start_time = time.time()
    train(epoch, model, trainloader, device, optimizer, scheduler, loss_function, max_grad_norm)
    test(epoch, model, testloader, device, loss_function, num_samples)
    elapsed_time = time.time() - start_time

    times_array.append(["Epoch " + str(epoch) + ": ", time.strftime("%H:%M:%S", time.gmtime(elapsed_time))])

    # Rewrite the whole log each epoch so progress survives a crash/disconnect.
    with open("epoch_times.txt", "w") as txt_file:
      txt_file.writelines(" ".join(line) + "\n" for line in times_array)

In [None]:
# Scratch check of the epoch-timing log format used in main_wrapper: nine
# fake 2-second "epochs" (time.sleep) are timed and written to output.txt,
# one "Epoch <i>:  HH:MM:SS" line each. Takes ~18 s; does not touch the model.
times_array = []
for i in range(1, 10):
  start_time = time.time()
  time.sleep(2)
  elapsed_time = time.time() - start_time
  times_array.append(["Epoch " + str(i) + ": ", time.strftime("%H:%M:%S", time.gmtime(elapsed_time))])

with open("output.txt", "w") as txt_file:
    for line in times_array:
      txt_file.write(" ".join(line) + "\n")

In [None]:
# Module-level state accessed via `global` inside the training helpers:
# train() increments global_step; test() declares best_loss global but does
# not currently update it. Both must exist before main_wrapper() runs.
best_loss = 0
global_step = 0
main_wrapper()

cuda
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting data/cifar-10-python.tar.gz to data


  cpuset_checked))


Files already downloaded and verified


  0%|          | 0/50000 [00:00<?, ?it/s]

Model glow with 512 convolutional channels, 3 model levels, and 32 steps in each levels.
Starting training of the Glow model
===> EPOCH 1


100%|██████████| 50000/50000 [15:45<00:00, 52.89it/s, bpd=4.84, lr=1.56e-6, nll=1.03e+4]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:11<00:00, 139.83it/s, bpd=4.31, nll=9.18e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 2


100%|██████████| 50000/50000 [15:49<00:00, 52.64it/s, bpd=4.19, lr=3.13e-6, nll=8.93e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:09<00:00, 143.06it/s, bpd=4.12, nll=8.78e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 3


100%|██████████| 50000/50000 [15:45<00:00, 52.90it/s, bpd=4.08, lr=4.69e-6, nll=8.68e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 142.03it/s, bpd=4.03, nll=8.59e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 4


100%|██████████| 50000/50000 [15:46<00:00, 52.82it/s, bpd=4.02, lr=6.26e-6, nll=8.56e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 141.59it/s, bpd=3.98, nll=8.47e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 5


100%|██████████| 50000/50000 [15:44<00:00, 52.92it/s, bpd=3.98, lr=7.82e-6, nll=8.47e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 141.29it/s, bpd=3.99, nll=8.49e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 6


100%|██████████| 50000/50000 [15:46<00:00, 52.81it/s, bpd=3.95, lr=9.38e-6, nll=8.42e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 141.50it/s, bpd=3.95, nll=8.4e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 7


100%|██████████| 50000/50000 [15:45<00:00, 52.86it/s, bpd=3.93, lr=1.09e-5, nll=8.37e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 141.04it/s, bpd=3.92, nll=8.35e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 8


100%|██████████| 50000/50000 [15:48<00:00, 52.73it/s, bpd=3.92, lr=1.25e-5, nll=8.34e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 141.14it/s, bpd=3.93, nll=8.36e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 9


100%|██████████| 50000/50000 [15:50<00:00, 52.62it/s, bpd=3.91, lr=1.41e-5, nll=8.32e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 141.33it/s, bpd=3.91, nll=8.33e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 10


100%|██████████| 50000/50000 [15:44<00:00, 52.92it/s, bpd=3.89, lr=1.56e-5, nll=8.29e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 142.16it/s, bpd=3.92, nll=8.34e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 11


100%|██████████| 50000/50000 [15:44<00:00, 52.94it/s, bpd=3.88, lr=1.72e-5, nll=8.27e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:09<00:00, 143.19it/s, bpd=3.9, nll=8.31e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 12


100%|██████████| 50000/50000 [15:44<00:00, 52.95it/s, bpd=3.88, lr=1.88e-5, nll=8.26e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:09<00:00, 143.11it/s, bpd=4, nll=8.51e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 13


100%|██████████| 50000/50000 [15:43<00:00, 53.01it/s, bpd=3.87, lr=2.03e-5, nll=8.24e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 142.74it/s, bpd=3.94, nll=8.39e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 14


100%|██████████| 50000/50000 [15:41<00:00, 53.09it/s, bpd=3.86, lr=2.19e-5, nll=8.22e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:09<00:00, 143.63it/s, bpd=4.01, nll=8.53e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 15


100%|██████████| 50000/50000 [15:41<00:00, 53.13it/s, bpd=3.86, lr=2.35e-5, nll=8.23e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 141.64it/s, bpd=3.87, nll=8.24e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 16


100%|██████████| 50000/50000 [15:42<00:00, 53.03it/s, bpd=3.85, lr=2.5e-5, nll=8.2e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 142.82it/s, bpd=3.87, nll=8.23e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 17


100%|██████████| 50000/50000 [15:42<00:00, 53.07it/s, bpd=3.86, lr=2.66e-5, nll=8.21e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:09<00:00, 143.77it/s, bpd=3.93, nll=8.38e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 18


100%|██████████| 50000/50000 [15:42<00:00, 53.06it/s, bpd=3.84, lr=2.82e-5, nll=8.18e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:09<00:00, 143.65it/s, bpd=3.87, nll=8.24e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 19


100%|██████████| 50000/50000 [15:41<00:00, 53.12it/s, bpd=3.85, lr=2.97e-5, nll=8.19e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:10<00:00, 142.80it/s, bpd=3.87, nll=8.24e+3]
  0%|          | 0/50000 [00:00<?, ?it/s]

===> EPOCH 20


100%|██████████| 50000/50000 [15:42<00:00, 53.05it/s, bpd=3.84, lr=3.13e-5, nll=8.17e+3]
  0%|          | 0/10000 [00:00<?, ?it/s]

testing func


100%|██████████| 10000/10000 [01:09<00:00, 143.34it/s, bpd=3.91, nll=8.33e+3]
