In [None]:
import torch
from torch import nn
from torch import optim
import numpy as np
from tqdm import tqdm
import os
from torch.utils.tensorboard import SummaryWriter
# Prefix for TensorBoard log directories; each run logs to writer_path + <run index>.
writer_path = "./tb_log/nf/"

from nflows.flows.base import Flow
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.distributions.normal import ConditionalDiagonalNormal
from functions import load_data

# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Load the training and validation splits via the project helper `functions.load_data`.
(x_train, y_train,
 x_val, y_val) = load_data()

In [None]:
def flow_generator(num_layers=4, hidden_features=8, num_blocks=10,
                   features=4, context_features=144, lr=1e-3, device=None):
    """Build a conditional masked-autoregressive normalizing flow and its optimizer.

    The flow models a ``features``-dimensional variable conditioned on a
    ``context_features``-dimensional context. The base distribution is a
    diagonal Gaussian whose parameters are predicted from the context by one
    linear layer. The transform stacks ``num_layers`` pairs of
    (ReversePermutation, MaskedAffineAutoregressiveTransform).

    Args:
        num_layers: number of (permutation, autoregressive transform) pairs.
        hidden_features: hidden width passed to each autoregressive transform.
        num_blocks: ``num_blocks`` passed to each autoregressive transform.
        features: dimensionality of the modelled variable (default 4).
        context_features: dimensionality of the conditioning context (default 144).
        lr: Adam learning rate (default 1e-3, Adam's own default).
        device: if given, the flow is moved there before the optimizer is built.

    Returns:
        (flow, optimizer): the nflows ``Flow`` module and an ``optim.Adam``
        optimizer over its parameters.
    """
    # The context encoder must emit one mean and one log-std per feature,
    # so its output width is 2 * features.
    base_dist = ConditionalDiagonalNormal(
        shape=[features],
        context_encoder=nn.Linear(context_features, 2 * features),
    )

    transforms = []
    for _ in range(num_layers):
        # Reverse the feature order between layers so consecutive
        # autoregressive transforms use different orderings.
        transforms.append(ReversePermutation(features=features))
        transforms.append(MaskedAffineAutoregressiveTransform(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks,
        ))
    flow = Flow(CompositeTransform(transforms), base_dist)

    if device is not None:
        # Move the model before constructing the optimizer, as torch.optim advises.
        flow = flow.to(device)
    optimizer = optim.Adam(flow.parameters(), lr=lr)
    return flow, optimizer

In [None]:
# Hyper-parameter sweep: train one conditional flow per
# (num_layers, hidden_features, num_blocks) combination, save it under MODEL_DIR,
# and log per-step training loss plus final metrics to TensorBoard.
num_iter = 1000
MODEL_DIR = "nf"
# x_* is divided by this before being used as the flow's context.
# NOTE(review): presumably a unit/normalization choice — confirm.
CONTEXT_SCALE = 1000
SEED = 0

# open(..., 'w') below fails if the directory does not exist yet.
os.makedirs(MODEL_DIR, exist_ok=True)

# The data is identical for every step and every run, so convert it and move it
# to the device once instead of on every optimisation step.
# The flow models y given x: y_* are the flow inputs, scaled x_* is the context.
inputs_train = torch.tensor(y_train, dtype=torch.float32, device=device)
context_train = torch.tensor(x_train / CONTEXT_SCALE, dtype=torch.float32, device=device)
# NOTE(review): assumes x_val / y_val share the layout of x_train / y_train — confirm.
inputs_val = torch.tensor(y_val, dtype=torch.float32, device=device)
context_val = torch.tensor(x_val / CONTEXT_SCALE, dtype=torch.float32, device=device)

hparam_writer = SummaryWriter(writer_path + "hparam")
try:
    for num_layers in [5, 6]:
        for hidden_features in [32, 64]:
            for num_blocks in [2, 4, 6, 8]:
                # First run index whose info file does not exist yet.
                index = 1
                while os.path.isfile(os.path.join(MODEL_DIR, "modelInfo_{}.txt".format(index))):
                    index += 1

                # Same initialisation RNG state for every configuration, so runs
                # differ only by their hyper-parameters.
                torch.manual_seed(SEED)
                flow, optimizer = flow_generator(num_layers, hidden_features, num_blocks)
                flow = flow.to(device)

                with open(os.path.join(MODEL_DIR, "modelInfo_{}.txt".format(index)), 'w') as f:
                    f.write('num_layers = {}\n'.format(num_layers))
                    f.write('hidden_features = {}\n'.format(hidden_features))
                    f.write('num_blocks = {}\n'.format(num_blocks))

                writer = SummaryWriter(writer_path + str(index))
                try:
                    flow.train()
                    for i in tqdm(range(num_iter), desc="run {}".format(index)):
                        optimizer.zero_grad()
                        loss = -flow.log_prob(inputs=inputs_train, context=context_train).mean()
                        loss.backward()
                        optimizer.step()
                        writer.add_scalar("training_loss", loss.item(), i)
                finally:
                    # Flush and release the event file even if training fails.
                    writer.close()

                torch.save(flow, os.path.join(MODEL_DIR, "{}.pt".format(index)))

                # Held-out negative log-likelihood, so configurations can be ranked
                # by generalisation rather than only by the last training step.
                flow.eval()
                with torch.no_grad():
                    val_loss = -flow.log_prob(inputs=inputs_val, context=context_val).mean().item()

                hparam_writer.add_hparams(
                    {
                        'num_layers': num_layers,
                        'hidden_features': hidden_features,
                        'num_blocks': num_blocks,
                    },
                    {'hparam/loss': loss.item(), 'hparam/val_loss': val_loss},
                )
finally:
    hparam_writer.close()