In [1]:
import torch
import torch.distributions as D
import sys
import time
import numpy as np

from torch.utils.data import DataLoader, RandomSampler
from itertools import cycle
from tqdm.auto import trange
from copy import deepcopy

sys.path.append("../lib")
from sw import *
from nf.realnvp import *
from nf.utils_nf import *

from evaluate import posterior_sample_evaluation
from datasets import *
from data_posterior import LogRegDPTarget, posterior_sample_evaluation

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
def _moving_average(values, window=10):
    """Return the means of the consecutive slices values[i:i+window].

    There is one entry per index i in range(len(values) - window), so the
    result is empty when `values` has at most `window` elements.
    """
    return [np.mean(values[i:i + window])
            for i in range(max(len(values) - window, 0))]


def _plot_flow_losses(train_loss, sw_loss, J_loss, k, window=10):
    """Plot the full loss, the SW term and the 2hJ term with their moving averages.

    NOTE(review): `plt` is not imported in this notebook's import cell; it is
    presumably provided by one of the star imports -- confirm, or add
    `import matplotlib.pyplot as plt` to the import cell.
    """
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    # The i-th moving-average value is drawn at abscissa window + i.
    xs = range(window, len(train_loss))

    ax[0].plot(range(len(train_loss)), train_loss, label="Loss")
    ax[0].plot(xs, _moving_average(train_loss, window), label="Moving Average")
    ax[0].set_title("Full loss")
    ax[0].legend()

    ax[1].plot(sw_loss)
    ax[1].plot(xs, _moving_average(sw_loss, window))
    ax[1].set_title("SW")

    ax[2].plot(J_loss)
    ax[2].plot(xs, _moving_average(J_loss, window))
    ax[2].set_title("2hJ")

    plt.suptitle("k="+str(k))
    plt.show()


def one_step_flow(n_epochs, rho_prev, J, create_NF, d, h, 
                  num_projections, nh, nl, lr, n_samples, device, 
                  k, plot_loss, sw_approx, max_sliced, reset_NF, use_scheduler):
    """
        Perform gradient descent at time step t: train a flow rho_k by
        minimising  SW(rho_k, rho_prev) + 2*h*J(rho_k)  with Adam.
        
        Inputs:
        - n_epochs: number of optimisation iterations (one fresh batch each)
        - rho_prev: previous NF (for k>0), or a distribution exposing
          `.sample` (for k=0)
        - J: functional (taking (x,z,log(det(J(z)))) as inputs)
        - create_NF: factory called as create_NF(nh, nl, d=d)
        - d: dimension of the latent Gaussian samples
        - h: time step
        - num_projections: number of projections for the sliced Wasserstein
          distance; if 0, the SW term is multiplied by 0 and h is set to 1/2,
          so that the loss reduces to J
        - nh: number of hidden units
        - nl: number of layers
        - lr: Adam learning rate
        - n_samples: batch size drawn at every iteration
        - device
        - k: step
        - plot_loss: If True, plot the loss curves once training is done
        - sw_approx: use the concentration approximation
        - max_sliced: if True, use max SW
        - reset_NF: If True, start from a random initialized NF
        - use_scheduler: If True, use ReduceLROnPlateau Scheduler
        
        Outputs:
        - rho_{k+1}^h
    """    

    rho_k = create_NF(nh, nl, d=d).to(device)
    if k>0 and not reset_NF: ## check if it is a NF
        # Warm start from the previous flow's weights.
        rho_k.load_state_dict(deepcopy(rho_prev.state_dict()))

    optimizer = torch.optim.Adam(rho_k.parameters(), lr=lr)
    optimizer.zero_grad()
    
    scheduler = None
    if use_scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")

    no_sw_term = (num_projections == 0)
    if no_sw_term:
        h = 1/2  # 2*h == 1: the objective becomes J alone

    train_loss = []
    sw_loss = []
    J_loss = []

    for _ in range(n_epochs):
        z_k = torch.randn(n_samples, d, device=device)
        x_k, log_det_k = rho_k(z_k)
        x_k = x_k[-1]

        if k>0:
            z0 = torch.randn(n_samples, d, device=device)
            x_prev, _ = rho_prev(z0)
            # rho_prev is a fixed reference: detach so that backward() neither
            # walks its graph nor accumulates unused .grad on its parameters.
            x_prev = x_prev[-1].detach()
        else:
            x_prev = rho_prev.sample((n_samples,))

        if sw_approx:
            sw = sw2_approx(x_k, x_prev, device, u_weights=None, v_weights=None)
        elif max_sliced:
            sw = max_SW(x_k, x_prev, device, p=2, u_weights=None, v_weights=None)
        else:
            sw = sliced_wasserstein(x_k, x_prev, num_projections, device, 
                                    u_weights=None, v_weights=None, p=2)
            
        if no_sw_term:
            sw = sw * 0
            
        f = J(x_k, z_k, log_det_k)
        loss = sw+2*h*f
                
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        for flow in rho_k.flows:
            if flow.__class__.__name__ == "ConvexPotentialFlow":
                flow.icnn.convexify() # clamp weights to be >=0
        
        f_value = f.item()
        train_loss.append(loss.item())
        sw_loss.append(sw.item())
        J_loss.append(2*h*f_value)
        
        if scheduler is not None:
            # Step the scheduler with a plain float, not a graph-attached tensor.
            scheduler.step(f_value)

    if plot_loss:
        _plot_flow_losses(train_loss, sw_loss, J_loss, k)
        
    return rho_k


def _schedule_value(schedule, k):
    """Return entry k of a per-step schedule, or `schedule` itself if it is a scalar."""
    if isinstance(schedule, (list, tuple, np.ndarray)):
        return schedule[k]
    return schedule


def SWGF_BLR(rho_0, tau, n_step, n_epochs, create_NF, target, d=2, nh=64, nl=5, lrs=1e-5, 
         num_projections=100, n_samples=500, sw_approx=False, max_sliced=False, reset_NF=False, 
         device=device, use_scheduler=False, plot_loss=False, tqdm_bar=False):
    """
        Run n_step successive calls to one_step_flow, each one starting from
        the distribution produced by the previous call.

        Inputs:
        - rho_0: initial distribution (must expose `.sample`)
        - tau: step size
        - n_step: number of t steps
        - n_epochs: number of epochs for the optimization (an int, or a
        list/tuple/np.ndarray of size n_step)
        - create_NF: function which return a BaseNormalizingFlow class taking
        (nh, nl, d) as inputs
        - target: Target posterior for Bayesian Logistic Regression
        - d: dimension of the flow
        - nh: number of hidden units
        - nl: number of layers
        - lrs: learning rate for optimization (a float, or a
        list/tuple/np.ndarray of size n_step)
        - num_projections
        - n_samples: batch size
        - sw_approx: If true, use the SW_2^2 approximation of SW (without projections)
        - max_sliced: If True, use max-SW 
        - reset_NF: If True, start every step from an uninitialized flow
        - device
        - use_scheduler: If True, use a ReduceLROnPlateau Scheduler
        - plot_loss (default False)
        - tqdm_bar (default False)

        Outputs:
        - list [rho_0, rho_1, ..., rho_{n_step}]
    """

    # V and J only depend on `target` and `device`, so they are built once
    # instead of being redefined at every step.
    def V(X):
        # Potential: minus the prior term, minus the data term evaluated on a
        # freshly sampled batch and scaled by target.len_dataset.
        S = target.sample_data()
        init_targ_loss = target.est_log_init_prob(X, reduction='sum')
        data_targ_loss = target.len_dataset * target.est_log_data_prob(X, S, reduction='sum')
        return - init_targ_loss - data_targ_loss

    def J(x, z, log_det):
        # NOTE(review): log_likelihood is defined outside this notebook; the
        # original comment labels this term "entropy" -- confirm the sign
        # convention against its definition.
        entropy_term = torch.mean(log_likelihood(z, log_det, device), axis=0)
        potential_term = torch.mean(V(x), axis=0)
        return potential_term + entropy_term

    Lrho = [rho_0] ## For rho_0, distribution class

    steps = trange(n_step) if tqdm_bar else range(n_step)

    for k in steps:
        # Per-step schedules may be given as list, tuple or array.
        n_epoch = int(_schedule_value(n_epochs, k))
        lr = _schedule_value(lrs, k)

        rho_k = one_step_flow(n_epoch, Lrho[-1], J, create_NF, d, tau, 
                                num_projections, nh, nl, lr, n_samples, 
                                device, k, plot_loss, sw_approx, max_sliced, 
                                reset_NF, use_scheduler)

        Lrho.append(rho_k)

    return Lrho

In [4]:
def _count_time_steps(t_init, t_end, h):
    """Number of steps of size `h` needed to go from `t_init` to `t_end`.

    A plain ceil of (t_end - t_init) / h over-counts when rounding error puts
    the quotient just above an integer (e.g. 1.1 / 0.1 == 11.000000000000002),
    so quotients that are an integer up to rounding error are snapped to it.

    Raises ValueError if `h` is not positive or if no step would be performed.
    """
    if h <= 0:
        raise ValueError("h must be positive, got {}".format(h))
    ratio = (t_end - t_init) / h
    nearest = np.round(ratio)
    if np.isclose(ratio, nearest, rtol=1e-9, atol=1e-12):
        n_steps = int(nearest)
    else:
        n_steps = int(np.ceil(ratio))
    if n_steps < 1:
        raise ValueError(
            "t_end must be larger than t_init (got t_init={}, t_end={})".format(t_init, t_end)
        )
    return n_steps


def evaluate_blr(dataset_name, dataset_batch_size, lr, ntraining=5,
                 h=0.1, t_end=0.5, t_init=0, nh=512, nl=2, epochs=500,
                n_projs=1000, tqdm_bar=True, plot_loss=False):
    """
        Train the flow `ntraining` times on a dataset and evaluate each run
        on the test split.

        Inputs:
        - dataset_name: name passed to get_train_test_datasets
        - dataset_batch_size: batch size of the training DataLoader
        - lr: learning rate, used for every time step
        - ntraining: number of independent runs
        - h: time step
        - t_end, t_init: time interval; ceil((t_end - t_init) / h) steps are run
        - nh: number of hidden units
        - nl: number of layers
        - epochs: number of epochs per time step
        - n_projs: number of projections for SW
        - tqdm_bar: If True, show a progress bar over the runs
        - plot_loss: forwarded to SWGF_BLR

        Outputs:
        - accuracies: first value returned by posterior_sample_evaluation,
          one per run, computed from 4096 samples of the last flow
        - ts: wall-clock training time (seconds) of each run; the final
          sampling and evaluation are not timed

        Raises ValueError if h <= 0 or t_end <= t_init (previously the latter
        failed late, when the initial distribution was called like a flow).
    """
    n_steps = _count_time_steps(t_init, t_end, h)

    dataset, train_ds, test_ds = get_train_test_datasets(dataset_name)

    train_dl = DataLoader(train_ds, batch_size=dataset_batch_size, shuffle=True)
    target = LogRegDPTarget(train_dl, dataset.n_features, device=device, clip_alpha=8)

    X_test, y_test = dataset2numpy(test_ds)

    # The flow lives in dimension n_features + 1
    # (presumably one extra coordinate for a bias-like parameter -- TODO confirm
    # against LogRegDPTarget).
    dim = dataset.n_features + 1
    lrs = lr * np.ones(n_steps)

    accuracies = []
    ts = []

    runs = trange(ntraining) if tqdm_bar else range(ntraining)

    for _ in runs:
        start = time.time()

        # Standard Gaussian initial distribution.
        mu0 = torch.zeros(dim, device=device, dtype=torch.float)
        sigma0 = torch.eye(dim, device=device, dtype=torch.float)
        rho_0 = D.MultivariateNormal(mu0, sigma0)

        Lrho = SWGF_BLR(rho_0, h, n_step=n_steps, n_epochs=epochs, d=dim,
                    create_NF=create_RealNVP, nh=nh, nl=nl, lrs=lrs, 
                    num_projections=n_projs, n_samples=1024, plot_loss=plot_loss,
                    tqdm_bar=False, use_scheduler=False, target=target)

        ts.append(time.time()-start)

        # Push Gaussian noise through the last flow and score the samples.
        rho = Lrho[-1]
        z = torch.randn((4096, dim), device=device)
        ws, _ = rho(z)
        w = ws[-1]
        acc, _ = posterior_sample_evaluation(w.detach().cpu().numpy(), X_test, y_test)
        accuracies.append(acc)

    return accuracies, ts

## Accuracies

### Covtype

In [5]:
# Covtype: mini-batches of 512, 1000 epochs per time step, default time grid.
acc, ts = evaluate_blr(
    dataset_name="covtype",
    dataset_batch_size=512,
    lr=2e-5,
    nh=512,
    nl=2,
    epochs=1000,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [6]:
# Per-run test accuracies, then their mean and (population) standard deviation.
print("Results", acc)
print("Mean", np.asarray(acc).mean(), "Std", np.asarray(acc).std())

Results [0.7557894374499797, 0.7549977195081022, 0.7561250570122975, 0.7563401977573728, 0.7522267067115307]
Mean 0.7550958236878567 Std 0.0015053467883486124


In [7]:
# Training time of each run in seconds, then mean and standard deviation.
print("Time", ts)
print("Mean", np.asarray(ts).mean(), "Std", np.asarray(ts).std())

Time [108.42422366142273, 112.59758138656616, 97.97873544692993, 99.54449796676636, 99.1012225151062]
Mean 103.52925219535828 Std 5.87348770765303


### German

In [41]:
# German: mini-batches of 800, time step 1e-6 up to t_end = 5e-6.
acc, ts = evaluate_blr(
    dataset_name="german",
    dataset_batch_size=800,
    lr=1e-4,
    h=1e-6,
    t_end=5e-6,
    plot_loss=False,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [42]:
# Per-run test accuracies, then their mean and (population) standard deviation.
print("Results", acc)
print("Mean", np.asarray(acc).mean(), "Std", np.asarray(acc).std())

Results [0.675, 0.68, 0.69, 0.68, 0.675]
Mean 0.68 Std 0.005477225575051626


In [43]:
# Training time of each run in seconds, then mean and standard deviation.
print("Time", ts)
print("Mean", np.asarray(ts).mean(), "Std", np.asarray(ts).std())

Time [82.31915354728699, 82.52840495109558, 82.5493311882019, 82.40837907791138, 82.97730135917664]
Mean 82.5565140247345 Std 0.22635708635375681


### Diabetis

In [5]:
# Diabetis: mini-batches of 614, time step 5e-6 up to t_end = 5e-5.
acc, ts = evaluate_blr(
    dataset_name="diabetis",
    dataset_batch_size=614,
    lr=5e-4,
    h=5e-6,
    t_end=5e-5,
    plot_loss=False,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [6]:
# Per-run test accuracies, then their mean and (population) standard deviation.
print("Results", acc)
print("Mean", np.asarray(acc).mean(), "Std", np.asarray(acc).std())

Results [0.7727272727272727, 0.7792207792207793, 0.7792207792207793, 0.7792207792207793, 0.7792207792207793]
Mean 0.7779220779220779 Std 0.0025974025974026204


In [7]:
# Training time of each run in seconds, then mean and standard deviation.
print("Time", ts)
print("Mean", np.asarray(ts).mean(), "Std", np.asarray(ts).std())

Time [124.2893009185791, 113.6352903842926, 124.00571465492249, 124.34977626800537, 124.2599229812622]
Mean 122.10800104141235 Std 4.237983864860112


### Twonorm

In [101]:
# Twonorm: mini-batches of 1024, time step 1e-8 up to t_end = 20e-8.
acc, ts = evaluate_blr(
    dataset_name="twonorm",
    dataset_batch_size=1024,
    lr=1e-4,
    h=1e-8,
    t_end=20e-8,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [102]:
# Per-run test accuracies, then their mean and (population) standard deviation.
print("Results", acc)
print("Mean", np.asarray(acc).mean(), "Std", np.asarray(acc).std())

Results [0.9817567567567568, 0.981081081081081, 0.9804054054054054, 0.981081081081081, 0.9797297297297297]
Mean 0.9808108108108108 Std 0.0006890566910260536


In [103]:
# Training time of each run in seconds, then mean and standard deviation.
print("Time", ts)
print("Mean", np.asarray(ts).mean(), "Std", np.asarray(ts).std())

Time [291.88450264930725, 291.63723278045654, 322.92185854911804, 298.90231490135193, 301.83920645713806]
Mean 301.43702306747434 Std 11.44963919648696


### Ringnorm

In [53]:
# Ringnorm: mini-batches of 1024, time step 1e-6 up to t_end = 5e-6.
acc, ts = evaluate_blr(
    dataset_name="ringnorm",
    dataset_batch_size=1024,
    lr=5e-5,
    h=1e-6,
    t_end=5e-6,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [54]:
# Per-run test accuracies, then their mean and (population) standard deviation.
print("Results", acc)
print("Mean", np.asarray(acc).mean(), "Std", np.asarray(acc).std())

Results [0.7412162162162163, 0.7412162162162163, 0.7405405405405405, 0.7418918918918919, 0.7398648648648649]
Mean 0.7409459459459459 Std 0.0006890566910260363


In [55]:
# Training time of each run in seconds, then mean and standard deviation.
print("Time", ts)
print("Mean", np.asarray(ts).mean(), "Std", np.asarray(ts).std())

Time [82.74868774414062, 82.9918487071991, 82.74334359169006, 82.88334560394287, 82.34805178642273]
Mean 82.74305548667908 Std 0.2180087668528168


### Banana

In [35]:
# Banana: mini-batches of 1024, all other settings left at their defaults.
acc, ts = evaluate_blr(
    dataset_name="banana",
    dataset_batch_size=1024,
    lr=1e-4,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [36]:
# Per-run test accuracies, then their mean and (population) standard deviation.
print("Results", acc)
print("Mean", np.asarray(acc).mean(), "Std", np.asarray(acc).std())

Results [0.5490566037735849, 0.5433962264150943, 0.569811320754717, 0.5716981132075472, 0.560377358490566]
Mean 0.5588679245283019 Std 0.011156034338939231


In [37]:
# Training time of each run in seconds, then mean and standard deviation.
print("Time", ts)
print("Mean", np.asarray(ts).mean(), "Std", np.asarray(ts).std())

Time [66.71263360977173, 66.6420567035675, 66.59221506118774, 66.75500988960266, 66.54901146888733]
Mean 66.6501853466034 Std 0.07556553810915242


### Splice

In [61]:
# Splice: mini-batches of 512, 5 layers of 128 hidden units, time step 1e-6 up to t_end = 5e-6.
acc, ts = evaluate_blr(
    dataset_name="splice",
    dataset_batch_size=512,
    lr=5e-4,
    nl=5,
    nh=128,
    h=1e-6,
    t_end=5e-6,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [62]:
# Per-run test accuracies, then their mean and (population) standard deviation.
print("Results", acc)
print("Mean", np.asarray(acc).mean(), "Std", np.asarray(acc).std())

Results [0.8447412353923205, 0.8514190317195326, 0.8497495826377296, 0.8530884808013356, 0.8497495826377296]
Mean 0.8497495826377296 Std 0.002793522626157209


In [63]:
# Training time of each run in seconds, then mean and standard deviation.
print("Time", ts)
print("Mean", np.asarray(ts).mean(), "Std", np.asarray(ts).std())

Time [102.52043581008911, 119.41600966453552, 110.12472009658813, 121.95077538490295, 115.92983675003052]
Mean 113.98835554122925 Std 6.972370760971641


### Waveform

In [14]:
# Waveform: mini-batches of 512, 5 layers of 128 hidden units, default time grid.
acc, ts = evaluate_blr(
    dataset_name="waveform",
    dataset_batch_size=512,
    lr=1e-4,
    nh=128,
    nl=5,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [15]:
# Per-run test accuracies, then their mean and (population) standard deviation.
print("Results", acc)
print("Mean", np.asarray(acc).mean(), "Std", np.asarray(acc).std())

Results [0.777, 0.776, 0.777, 0.775, 0.775]
Mean 0.776 Std 0.0008944271909999167


In [16]:
# Training time of each run in seconds, then mean and standard deviation.
print("Time", ts)
print("Mean", np.asarray(ts).mean(), "Std", np.asarray(ts).std())

Time [120.79881429672241, 121.34904217720032, 120.70495057106018, 120.94201898574829, 119.3961431980133]
Mean 120.6381938457489 Std 0.6588718034957246


### Image

In [53]:
# Image: mini-batches of 1024, all other settings left at their defaults.
acc, ts = evaluate_blr(
    dataset_name="image",
    dataset_batch_size=1024,
    lr=5e-5,
)

  0%|          | 0/5 [00:00<?, ?it/s]

In [54]:
# Per-run test accuracies, then their mean and (population) standard deviation.
print("Results", acc)
print("Mean", np.asarray(acc).mean(), "Std", np.asarray(acc).std())

Results [0.8205741626794258, 0.8205741626794258, 0.8277511961722488, 0.8157894736842105, 0.8205741626794258]
Mean 0.8210526315789475 Std 0.0038277511961722493


In [55]:
# Training time of each run in seconds, then mean and standard deviation.
print("Time", ts)
print("Mean", np.asarray(ts).mean(), "Std", np.asarray(ts).std())

Time [74.73810482025146, 81.63690543174744, 68.53701853752136, 68.74294590950012, 68.55454754829407]
Mean 72.4419044494629 Std 5.174215904844543
