In [3]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
np.random.seed(42)

from src.data_utils import *
from src.yarin_gal_net import net

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# Load the full Concrete Strength dataset and carve off the held-out "DB" split.
# Temporaries keep the full data out of X_train / y_train, so those names only ever hold the training portion.
_DATA_FILE = "Data/Concrete_Strength.txt"
_features, _targets = get_concrete(_DATA_FILE)
X_train, X_train_db, y_train, y_train_db = split_to_create_db(_features, _targets)
del _features, _targets  # leave only the split arrays in the kernel namespace

In [5]:
# Show the split sizes, then assert the invariants the rest of the notebook relies on.
print(f"X_train.shape = {X_train.shape}")
print(f"X_train_db.shape = {X_train_db.shape}")
print(f"y_train.shape = {y_train.shape}")
print(f"y_train_db.shape = {y_train_db.shape}")

# Features and targets must stay row-aligned within each split.
assert X_train.shape[0] == y_train.shape[0], "X_train / y_train row counts differ"
assert X_train_db.shape[0] == y_train_db.shape[0], "X_train_db / y_train_db row counts differ"
# The net trained on X_train is later evaluated on X_train_db, so feature shapes must match.
assert X_train.shape[1:] == X_train_db.shape[1:], "train and DB splits have different feature shapes"

X_train.shape = (824, 8)
X_train_db.shape = (206, 8)
y_train.shape = (824,)
y_train_db.shape = (206,)


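# Train (or Load) Concrete Net

With `load_model=True`, the cell loads the previously saved 4000-epoch model from `Assets/`. Set it to `False` to train the network and save it there instead.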

In [7]:
# Network configuration for `net` from src.yarin_gal_net.
n_hidden = 40
num_hidden_layers = 1
n_epochs = 4000
# True: load previously saved weights from Assets/; False: build the net and save it afterwards.
load_model = True
# Derive the file name from n_epochs so the saved/loaded model always matches the configured epoch count.
_MODEL_FILE = f"Assets/concrete_net_{n_epochs}_epochs"

concrete_net = net(X_train, y_train,
                   n_hidden=([n_hidden] * num_hidden_layers),
                   n_epochs=n_epochs,
                   normalize=True,
                   tau=0.05,
                   dropout=0.005,
                   # presumably tells `net` to skip training when weights will be loaded — TODO confirm
                   load_model=load_model
                  )

if load_model:
    concrete_net.load_model(f"{_MODEL_FILE}.h5")
else:
    # NOTE(review): saved without the ".h5" suffix but loaded with it above —
    # presumably save_model appends it; confirm in src/yarin_gal_net.
    concrete_net.save_model(_MODEL_FILE)

In [9]:
# Score the net on the held-out DB split; T=1000 is presumably the number of MC samples (see net.predict).
rmse_standard_pred, mc_rmse, test_ll, y_mc = concrete_net.predict(
    X_train_db,
    y_train_db,
    T=1000,
)
print(
    "rmse_standard_pred, mc_rmse, test_ll:",
    rmse_standard_pred,
    mc_rmse,
    test_ll,
)


rmse_standard_pred, mc_rmse, test_ll: 10.378569616762983 9.99389982007692 -4.170848079407234


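# Create DB

For every MC sample count $T \in [2, 10000]$, run `predict` on the held-out DB split and cache the `(rmse_standard_pred, mc_rmse, test_ll, y_mc)` tuples to `Assets/` with joblib. The full run takes about 8 hours; the last cell reloads the cached file.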

In [13]:
import joblib
from tqdm import tqdm
def create_db(net, X_train_db, y_train_db, max_mc_iters, out_dir="Assets"):
    """Evaluate ``net.predict`` for every T in [2, max_mc_iters] and cache the results with joblib.

    Parameters
    ----------
    net : object
        Model exposing ``predict(X, y, T=...)`` that returns a 4-item sequence
        ``(rmse_standard_pred, mc_rmse, test_ll, y_mc)``. Note: inside this
        function the name shadows the ``net`` class imported from src.yarin_gal_net.
    X_train_db, y_train_db :
        Inputs and targets passed unchanged to ``net.predict``.
    max_mc_iters : int
        Largest T evaluated (inclusive); also embedded in the output file name.
    out_dir : str, default "Assets"
        Directory the joblib file is written to; created if it does not exist.

    Returns
    -------
    dict
        Maps each T to the tuple ``(rmse_standard_pred, mc_rmse, test_ll, y_mc)``.
        The same dict is written to ``{out_dir}/db_concrete_{max_mc_iters}_iters.jblib``.

    Notes
    -----
    Every T re-runs ``predict`` from scratch, so total work grows roughly
    quadratically with ``max_mc_iters`` (the recorded max_mc_iters=10000 run took ~8 h).
    """
    db = {}
    for t in tqdm(range(2, max_mc_iters + 1)):
        # Unpacking also checks that predict returned exactly four items.
        rmse_standard_pred, mc_rmse, test_ll, y_mc = net.predict(X_train_db, y_train_db, T=t)
        db[t] = (rmse_standard_pred, mc_rmse, test_ll, y_mc)
    os.makedirs(out_dir, exist_ok=True)
    joblib.dump(db, f"{out_dir}/db_concrete_{max_mc_iters}_iters.jblib", compress=True)
    print(f"db_{max_mc_iters}_iters created")
    # Previously returned None, although the notebook assigns the result of this call.
    return db

In [14]:
# create_db is expensive (the recorded max_mc_iters=10000 run took ~8 hours), so skip it when its
# output already exists. This path must match the one create_db writes to.
MAX_MC_ITERS = 10000
_DB_FILE = f"Assets/db_concrete_{MAX_MC_ITERS}_iters.jblib"
if os.path.exists(_DB_FILE):
    print(f"{_DB_FILE} already exists; skipping create_db")
else:
    create_db(concrete_net, X_train_db, y_train_db, MAX_MC_ITERS)

100%|████████████████| 9999/9999 [8:06:13<00:00,  2.92s/it]


db_10000_iters created


In [12]:
%%time
db = joblib.load("Assets/db_concrete_10000_iters.jblib")

Wall time: 970 µs
