# Subspace Inference with Advanced MH Sampler

### Load packages

In [1]:
using NPZ
using Flux
using Flux: Data.DataLoader
using Flux: @epochs
using BSON: @save
using BSON: @load
using PyPlot
using SubspaceInference;
using Zygote;

┌ Info: Precompiling SubspaceInference [706446a6-0e85-4c47-b731-c658bbb72625]
└ @ Base loading.jl:1278
│ - If you have SubspaceInference checked out for development and have
│   added AdvancedHMC as a dependency but haven't updated your primary
│   environment's manifest file, try `Pkg.resolve()`.
│ - Otherwise you may need to report an issue with SubspaceInference


### Set working directory

In [2]:
# Record the directory the kernel started in and make it the explicit working
# directory, so the relative paths used in this notebook ("data.npy",
# "model_weights_*.bson") resolve against it.
root = pwd()
cd(root)

### Load Data and Format

In [3]:
data_ld = npzread("data.npy");
# Column 1 holds the inputs and column 2 the targets. Each is transposed into a
# single row so that samples run along the last dimension.
x = data_ld[:, 1]'
y = data_ld[:, 2]'
"""
    features(x)

Build the model's input features from `x`: the input scaled by one half,
vertically stacked on top of the square of that scaled input.
"""
function features(x)
    half = x ./ 2
    return vcat(half, half .^ 2)
end

# Expand the inputs into the two-row feature matrix and pair them with the
# targets in shuffled mini-batches of 50 samples.
f = features(x);
data = DataLoader(f, y; shuffle = true, batchsize = 50);

### Setup NN model

In [4]:
# Fully connected regressor on the 2 input features:
# 2 -> 200 -> 50 -> 50 -> 50 -> 1, with ReLU on every hidden layer and a
# linear output layer.
m = Chain(
    Dense(2, 200, Flux.relu),
    Dense(200, 50, Flux.relu),
    (Dense(50, 50, Flux.relu) for _ in 1:2)...,
    Dense(50, 1),
)

# Flatten the parameters into the vector θ, and keep `re` to rebuild a model
# from such a vector.
# NOTE: θ holds the weights at this point, i.e. before the pretrained
# checkpoint is loaded further down.
θ, re = Flux.destructure(m);

"""
    L(m, x, y)

Loss passed to subspace inference: one half of the mean squared error between
the model prediction `m(x)` and the targets `y`.
"""
function L(m, x, y)
    ŷ = m(x)
    return Flux.Losses.mse(ŷ, y) / 2
end
# Trainable parameters of `m` and a momentum optimiser
# (Flux argument order: learning rate 0.01, momentum 0.95).
ps, opt = Flux.params(m), Momentum(0.01, 0.95);

### Pretrain model

In [5]:
## Optional pretraining, skipped here because a pretrained checkpoint is loaded instead.
## Flux.train! calls the loss once per batch as loss(x, y), so training needs a
## two-argument cost that closes over the global model `m`:
#L(x, y) = Flux.Losses.mse(m(x), y)/2;
#@epochs 3000 Flux.train!(L, ps, data, opt)

In this notebook we use a pretrained model instead of running the pretraining step above.

In [6]:
# Restore pretrained weights from checkpoint number `i`.
# The BSON file is expected to contain a variable named `ps`, presumably a
# parameter collection saved with `@save` (TODO confirm). `@load` rebinds the
# global `ps` to the loaded value, and `loadparams!` copies it into `m`.
i = 1;
@load "model_weights_$(i).bson" ps;
Flux.loadparams!(m, ps);

### Run subspace inference

In [None]:
# Subspace-inference settings.
# NOTE(review): the meanings noted below are inferred from names and usage;
# confirm them against the SubspaceInference.subspace_inference documentation.
M = 20       # presumably the subspace dimension (the plot reports it as "Subspace")
T = 25
c = 1
itr = 1_000  # presumably the number of sampler iterations

# Returns the posterior samples `chn`, their log probabilities `lp`, `W_swa`,
# and a new `re` that replaces the one obtained from Flux.destructure.
# NOTE(review): the notebook title mentions an Advanced MH sampler, but
# `alg = :hmc` is requested here; confirm which sampler is intended.
chn, lp, W_swa, re = SubspaceInference.subspace_inference(m, L, data, opt,
    σ_z = 1.0, σ_m = 1.0, σ_p = 1.0,
    itr = itr, T = T, c = c, M = M, print_freq = 1, alg = :hmc,
);

Traing loss: 0.0010804259589564413 Epoch: 5
Traing loss: 0.0008566155858883028 Epoch: 10
Traing loss: 0.0007709339848581731 Epoch: 15
Traing loss: 0.0013931469772947836 Epoch: 20


### Plot Uncertainty

In [None]:
# Evaluate every posterior sample on a grid of inputs over [-10, 10].
# Column k of `trajectories` holds the predictions of the network rebuilt from
# chn[k]. The row count follows the grid length, so changing the number of grid
# points does not require a second edit.
ns = length(chn)
z = collect(range(-10.0, 10.0, length = 100))
inp = features(z')
trajectories = Array{Float64}(undef, length(z), ns)
for i in 1:ns
    model_i = re(chn[i])
    # The network output is a single row (one output unit); flatten it into the column.
    trajectories[:, i] = vec(model_i(inp))
end

# Pointwise envelope (max and min over samples) of the sampled predictions.
mx = maximum(trajectories, dims = 2)
mn = minimum(trajectories, dims = 2)
# Prediction of the sample with the highest log probability.
# Assumes lp[k] is the log probability of chn[k] (TODO confirm).
val, loc = findmax(lp)
max_log = trajectories[:, loc]

(fig, f_axes) = PyPlot.subplots(ncols = 1, nrows = 1)
f_axes.scatter(data_ld[:, 1], data_ld[:, 2], c = "red", marker = ".", label = "data")
f_axes.plot(z, vec(mx), c = "blue", label = "maximum")
f_axes.plot(z, vec(mn), c = "green", label = "minimum")
f_axes.plot(z, vec(max_log), c = "darkorange", label = "max log probability")
f_axes.fill_between(z, vec(mx), vec(mn), alpha = 0.5)
# Derive the title from the configured subspace size instead of hard-coding it.
f_axes.set_title("Subspace: $(M)")
f_axes.legend()
fig.show()