In [1]:
import os

import numpy as np
import torch

import matplotlib.pyplot as plt


from LFIS.config.config import get_configuration, get_baseconfig
from LFIS.LF.LF import LF_base

from LFIS.util.util import (
                    run_stat_batch,
                    save_file,
                    load_file)

In [2]:
def check_parameter(config):
    print(f'Case: {config.case}')
    print(f'Problem: {config.problemtype}')
    print(f'Number of steps: {config.nstep}')
    print(f'Device: {config.device}')
    print(f'Number of dimension: {config.ndim}')
    print(f'Epoch: {config.train.epoch}')
    print(f'Sample size: {config.train.nsample}')
    print(f'Batch size: {config.train.nbatch}')

In [3]:
cfg = get_baseconfig()
configurer = get_configuration()

In [4]:
case = 'LGCP'

In [5]:
cfg_LGCP = configurer.setup_config(cfg, case)

In [6]:
check_parameter(cfg_LGCP)

Case: Log Gaussian Cox problem with Gaussian prior and Poisson process likelihood function
Problem: bayes
Number of steps: 256
Device: cuda:0
Number of dimension: 1600
Epoch: 2000
Sample size: 2000
Batch size: 20000


In [7]:
trainer = cfg_LGCP.train.method

In [8]:
LFmodel = LF_base(cfg_LGCP)
LFmodel.to(cfg_LGCP.device)

LF_base(
  (flow): IndependentLinear()
)

# Train LFIS

In [9]:
output = trainer(cfg_LGCP)
save_file(case, output)

time = 0.0000, loss = 0.0000, percetage = nan%
Complete Training Flow at time 0.0000
time = 0.0039, loss = 0.1785, percetage = 99.9500%
time = 0.0039, loss = 0.0345, percetage = 20.0972%
time = 0.0039, loss = 0.0337, percetage = 19.6366%
===Reducing LR====
time = 0.0039, loss = 0.0274, percetage = 15.8890%
time = 0.0039, loss = 0.0273, percetage = 16.4387%
===Reducing LR====
time = 0.0039, loss = 0.0209, percetage = 11.7072%
time = 0.0039, loss = 0.0247, percetage = 14.5559%
===Reducing LR====
time = 0.0039, loss = 0.0222, percetage = 12.6824%
time = 0.0039, loss = 0.0220, percetage = 12.2539%
===Reducing LR====
time = 0.0039, loss = 0.0239, percetage = 13.3448%
===Reducing LR====
====LR too small, stop traning ====
Complete Training Flow at time 0.0039
time = 0.0078, loss = 0.2409, percetage = 34.7064%
time = 0.0078, loss = 0.1326, percetage = 19.9341%
time = 0.0078, loss = 0.1272, percetage = 17.6094%
time = 0.0078, loss = 0.1216, percetage = 17.5820%
===Reducing LR====
time = 0.0078

# Load pre-trained LFIS

In [10]:
nnmodel = cfg_LGCP.nnmodel.to(cfg_LGCP.device)
output = load_file(case, nnmodel)

In [None]:
logstat = run_stat_batch(LFmodel, output, nruns = 30)

tensor(495.3729, device='cuda:0')
tensor(494.7118, device='cuda:0')
tensor(494.8094, device='cuda:0')
tensor(494.7348, device='cuda:0')
tensor(494.8850, device='cuda:0')
tensor(494.6763, device='cuda:0')
tensor(494.9839, device='cuda:0')
tensor(494.8212, device='cuda:0')
tensor(494.6595, device='cuda:0')
tensor(494.7829, device='cuda:0')
tensor(495.1160, device='cuda:0')
tensor(494.8847, device='cuda:0')
tensor(494.7513, device='cuda:0')
tensor(494.8965, device='cuda:0')
tensor(494.7083, device='cuda:0')
tensor(495.1452, device='cuda:0')
tensor(494.7811, device='cuda:0')
tensor(495.0384, device='cuda:0')
tensor(494.7439, device='cuda:0')
tensor(494.9714, device='cuda:0')
tensor(494.5389, device='cuda:0')


In [16]:
logstat

{'logzmean': 470.7424049910776,
 'logzstd': 2.398313591606652,
 'samples': array([[4.1701791 , 5.76021785, 6.54797896, ..., 4.07596655, 5.07062809,
         4.42885283],
        [4.01266707, 1.84104619, 4.12325986, ..., 7.04798194, 4.6426013 ,
         4.5268069 ],
        [5.21051622, 3.53571935, 5.56174123, ..., 4.99369064, 5.76155589,
         6.05981801],
        ...,
        [4.75286317, 5.51608139, 5.72416513, ..., 4.53716292, 4.04539905,
         3.41839997],
        [5.17043881, 4.25018843, 5.2160464 , ..., 5.11558876, 6.94621893,
         4.66551744],
        [4.41593129, 2.41764428, 4.31958874, ..., 5.44831145, 4.41350289,
         5.68862725]]),
 'weight': array([4.74542477e-15, 1.10587041e-11, 1.10892352e-10, ...,
        8.14638838e-11, 7.99332484e-07, 8.47907467e-16]),
 'logzlist': array([471.38232904, 470.73429832, 468.24611543, 469.53849486,
        469.17183012, 470.08785777, 470.13775727, 469.44244056,
        471.04251887, 476.86107308, 469.06270336, 467.32254   ,
  