In [1]:
import torch
from torch import nn
from torch import optim
import numpy as np
from tqdm import tqdm
import os
from torch.utils.tensorboard import SummaryWriter
from nflows.flows.base import Flow
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.distributions.normal import ConditionalDiagonalNormal
from functions import load_train_data
from sklearn.preprocessing import StandardScaler
from pickle import dump

# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [2]:
# Dataset to train on. Targets listed by the original author:
# '3flavor_poisson', 'nsi_poisson'.
learn_target = '3flavor_poisson'

x_train, y_train, x_val, y_val = load_train_data(learn_target)

# Keep only the first 1/select of every split (presumably to shorten
# training while experimenting -- set select = 1 to use everything).
select = 10
x_train = x_train[:len(x_train)//select]
y_train = y_train[:len(y_train)//select]
x_val = x_val[:len(x_val)//select]
y_val = y_val[:len(y_val)//select]

# Standardization: the scaler is fitted on the training inputs only and the
# same transform is then applied to the validation inputs.
scaler = StandardScaler()
scaler.fit(x_train)
x_train = scaler.transform(x_train)
x_val = scaler.transform(x_val)

# Persist the fitted scaler next to the models so the identical transform can
# be re-applied later. The directory is created on demand, and the file is
# opened in a `with` block so the handle is closed (and the pickle fully
# flushed) as soon as the dump finishes.
os.makedirs('nf/' + learn_target, exist_ok=True)
with open('nf/' + learn_target + '/scaler.pkl', 'wb') as scaler_file:
    dump(scaler, scaler_file)

In [3]:
# Width of one x sample (passed to the flow as context) and of one y sample
# (the variable whose density the flow models).
len_x, len_y = len(x_train[0]), len(y_train[0])

def flow_generator(num_layers=4, hidden_features=8, num_blocks=10):
    """Build a conditional masked-autoregressive flow and an optimizer for it.

    The flow models samples of width ``len_y`` conditioned on a context of
    width ``len_x``; both sizes are read from the notebook's module-level
    variables of the same name.

    Args:
        num_layers: Number of (reverse permutation, masked affine
            autoregressive transform) pairs stacked in the flow.
        hidden_features: ``hidden_features`` passed to every autoregressive
            transform.
        num_blocks: ``num_blocks`` passed to every autoregressive transform.

    Returns:
        A ``(flow, optimizer)`` tuple, where ``optimizer`` is Adam with its
        default settings over all of the flow's parameters.
    """
    # Base density: a diagonal normal whose parameters are produced from the
    # context by one linear layer (2 * len_y outputs -- presumably a mean and
    # a log-std per dimension; see the nflows documentation).
    context_encoder = nn.Linear(len_x, len_y*2)
    base_dist = ConditionalDiagonalNormal(shape=[len_y],
                                          context_encoder=context_encoder)

    def _make_layer():
        # One flow layer: reverse the feature order, then apply a
        # context-conditioned masked affine autoregressive transform.
        permutation = ReversePermutation(features=len_y)
        autoregressive = MaskedAffineAutoregressiveTransform(
            features=len_y,
            hidden_features=hidden_features,
            context_features=len_x,
            num_blocks=num_blocks,
        )
        return [permutation, autoregressive]

    layers = []
    for _ in range(num_layers):
        layers.extend(_make_layer())

    flow = Flow(CompositeTransform(layers), base_dist)
    return flow, optim.Adam(flow.parameters())

In [4]:
# Grid search over flow architectures. Every configuration is trained for
# `num_iter` full-batch Adam steps on the negative log-likelihood; the model,
# its hyper-parameters and its TensorBoard logs are stored under a fresh
# integer run index.
num_iter = 1000
hparam_writer = SummaryWriter('tb_log/nf/' + learn_target + '/hparam')

# The whole training set is used as one batch, so the tensors are identical
# for every step of every run: build them and move them to the device once
# instead of once per iteration.
# NOTE: the names are swapped on purpose relative to the arrays -- `x` holds
# the flow inputs (y_train) and `y` holds the conditioning context (x_train).
x = torch.tensor(y_train, dtype=torch.float32).to(device)
y = torch.tensor(x_train, dtype=torch.float32).to(device)

for num_layers in [5, 6, 7, 8]:
    for hidden_features in [8, 16, 32]:
        for num_blocks in [2]:
            # Use the first run index that has no modelInfo file yet.
            index = 1
            while os.path.isfile('nf/' + learn_target + '/modelInfo_{}.txt'.format(index)): index += 1
            flow, optimizer = flow_generator(num_layers, hidden_features, num_blocks)
            flow = flow.to(device)

            # Record the configuration next to the saved model.
            with open('nf/' + learn_target + '/modelInfo_{}.txt'.format(index), 'w') as f:
                f.write('num_layers = {}\n'.format(num_layers))
                f.write('hidden_features = {}\n'.format(hidden_features))
                f.write('num_blocks = {}\n'.format(num_blocks))

            # The context manager closes (and flushes) this run's event file
            # even if training is interrupted.
            with SummaryWriter('./tb_log/nf/' + learn_target + '/' + str(index)) as writer:
                for i in tqdm(range(num_iter)):
                    optimizer.zero_grad()
                    loss = -flow.log_prob(inputs=x, context=y).mean()
                    loss.backward()
                    optimizer.step()
                    # Log a plain float rather than a graph-attached tensor.
                    writer.add_scalar('training_loss', loss.item(), i)

            # Saves the whole module object (not just a state_dict).
            torch.save(flow, './nf/' + learn_target + '/{}.pt'.format(index))
            hparam_writer.add_hparams({
                'num_layers': num_layers,
                'hidden_features': hidden_features,
                'num_blocks': num_blocks},
                {'hparam/loss': loss.item()})

# Flush and release the hyper-parameter summary log.
hparam_writer.close()

100%|█████████████████████████████████████████████████████████████████████████| 1000/1000 [00:37<00:00, 26.57it/s]
100%|█████████████████████████████████████████████████████████████████████████| 1000/1000 [00:38<00:00, 25.89it/s]
100%|█████████████████████████████████████████████████████████████████████████| 1000/1000 [00:38<00:00, 25.75it/s]
100%|█████████████████████████████████████████████████████████████████████████| 1000/1000 [00:43<00:00, 23.19it/s]
100%|█████████████████████████████████████████████████████████████████████████| 1000/1000 [00:42<00:00, 23.43it/s]
100%|█████████████████████████████████████████████████████████████████████████| 1000/1000 [00:41<00:00, 24.04it/s]
100%|█████████████████████████████████████████████████████████████████████████| 1000/1000 [00:46<00:00, 21.65it/s]
100%|█████████████████████████████████████████████████████████████████████████| 1000/1000 [00:45<00:00, 21.87it/s]
100%|█████████████████████████████████████████████████████████████████████████| 