In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import pdb
from tqdm import tqdm
from collections import OrderedDict
import copy

import numpy as np
import random
import matplotlib.pyplot as plt
%matplotlib inline

import sys
sys.path.insert(0, './src')
from data import Dataset
from models import HMC_vanilla

In [2]:
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'

In [3]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

In [4]:
args = dotdict({})

args['device'] = device
args['torchType'] = torch.float32
args['dataset_name'] = 'boston_housing'
if args['dataset_name'].find('mnist') > -1:
    args['num_epoches'] = 201
    args['print_info'] = 50
    args['n_IS'] = 10000
    
    args['train_batch_size'] = 100
    args['val_dataset'] = 10000
    args['val_batch_size'] = 100
    args['test_batch_size'] = 100
else:
    args['n_IS'] = 1000
    args['num_epoches'] = 10001
    args['print_info'] = 1000
    args['train_batch_size'] = 100
    args['val_dataset'] = 100
    args['val_batch_size'] = 20
    args['test_batch_size'] = 10

In [5]:
dataset = Dataset(args)
args.in_features = dataset.in_features[0]
dataset.in_features[0]

  self.df = pd.read_table(data_path, names=self.column_names, header=None, delim_whitespace=True)


Train data shape 406


13

In [6]:
X_train = torch.tensor(dataset.x_train, device=device, dtype=torch.float32)
y_train = torch.tensor(dataset.y_train, device=device, dtype=torch.float32)

X_val = torch.tensor(dataset.x_val, device=device, dtype=torch.float32)
y_val = torch.tensor(dataset.y_val, device=device, dtype=torch.float32)

In [7]:
def load_parameters(model, param_vector, params_shapes):
    left_boader = 0
    tensors = []
    for shape in params_shapes:
        cur_len = np.prod(shape)
        tensors.append(param_vector[left_boader: left_boader+cur_len].view(*shape))
        left_boader += cur_len
    d = OrderedDict()
    model_clone = copy.deepcopy(model)
    for i, name in enumerate(model_clone.named_parameters()):
        d[name[0]] = tensors[i]
    model_clone.load_state_dict(d)
    model_clone.to(model.device)
    return model_clone

In [8]:
class Simple_model(nn.Module):
    def __init__(self, args):
        super(Simple_model, self).__init__()
        in_features = args.in_features
        self.device = args.device
        self.l1 = nn.Linear(in_features, 3*in_features)
        self.l2 = nn.Linear(3*in_features, 1)
        self.activation = nn.Softplus()
        self.std_normal = torch.distributions.Normal(loc=torch.tensor(0., dtype=torch.float32, device=self.device),
                                                     scale=torch.tensor(1., dtype=torch.float32, device=self.device))
        
    def forward(self, x):
        h = self.activation(self.l1(x))
        h = self.l2(h)
        return h
    
    def get_logdensity(self, x, z):
        '''
        x - data vector
        z - vector of all parameters
        '''
#         pdb.set_trace()
        log_prior = self.std_normal.log_prob(z).sum()
        loaded_model = load_parameters(self, z, params_shapes)
        y_pred = loaded_model(x)
        log_likelihood = torch.distributions.Normal(loc=y_train,
                                                     scale=torch.tensor(1., dtype=torch.float32, device=self.device)).log_prob(y_pred).sum()
#         print('log_likelihood', log_likelihood)
        logdensity = log_likelihood + log_prior
        return logdensity

In [9]:
model = Simple_model(args)

In [10]:
params_shapes = [p.shape for p in list(model.parameters())]
print(params_shapes)
overall_params = sum([np.prod(el) for el in params_shapes])
print(f'Overall number of parameters is {overall_params}')

[torch.Size([39, 13]), torch.Size([39]), torch.Size([1, 39]), torch.Size([1])]
Overall number of parameters is 586


In [11]:
initial_parameters = torch.randn((overall_params, 1), dtype=torch.float32, device=device)

In [12]:
L = 100000 # num of hmc transitions
args.N = 3
args.alpha = 0.5
args.gamma = 0.001
args.use_partialref = False
args.use_barker = False


hmc_transition = HMC_vanilla(args)   

In [13]:
q_old = initial_parameters
q_new = q_old
a_list = []

model_new = copy.deepcopy(load_parameters(model, q_old, params_shapes))

for l in tqdm(range(L)):
    q_new, _, _, _, a = hmc_transition.make_transition(q_old=q_new, p_old=None, target_distr=model_new, x=X_train)
    a_list.append(a.cpu().numpy())
    model_new = copy.deepcopy(load_parameters(model_new, q_new, params_shapes))
    if l % 500 == 0:
        print(f'Mean acceptance rate {np.mean(a_list)}')
        a_list = []
        y_pred = model_new(X_train)
        print(f"At L={l} MSE on train is {((y_pred.squeeze() - y_train.squeeze())**2).mean()}")
        
        y_pred = model_new(X_val)
        print(f"At L={l} MSE on validation is {((y_pred.squeeze() - y_val.squeeze())**2).mean()}")

  0%|          | 1/100000 [00:00<4:34:02,  6.08it/s]

Mean acceptance rate 0.0
At L=0 MSE on train is 5095579.0
At L=0 MSE on validation is 5332677.5


  1%|          | 520/100000 [00:04<13:43, 120.86it/s]

Mean acceptance rate 0.47200000286102295
At L=500 MSE on train is 321977.21875
At L=500 MSE on validation is 259424.3125


  1%|          | 1014/100000 [00:08<13:38, 120.88it/s]

Mean acceptance rate 0.42399999499320984
At L=1000 MSE on train is 7659.78173828125
At L=1000 MSE on validation is 8970.2626953125


  2%|▏         | 1521/100000 [00:12<13:34, 120.96it/s]

Mean acceptance rate 0.2160000056028366
At L=1500 MSE on train is 2189.834228515625
At L=1500 MSE on validation is 2945.877685546875


  2%|▏         | 2015/100000 [00:16<13:22, 122.14it/s]

Mean acceptance rate 0.09000000357627869
At L=2000 MSE on train is 1475.8533935546875
At L=2000 MSE on validation is 1977.3248291015625


  3%|▎         | 2522/100000 [00:20<13:27, 120.77it/s]

Mean acceptance rate 0.057999998331069946
At L=2500 MSE on train is 1253.603271484375
At L=2500 MSE on validation is 1595.167724609375


  3%|▎         | 3016/100000 [00:25<13:23, 120.65it/s]

Mean acceptance rate 0.041999999433755875
At L=3000 MSE on train is 1108.5826416015625
At L=3000 MSE on validation is 1573.8118896484375


  4%|▎         | 3523/100000 [00:29<13:01, 123.49it/s]

Mean acceptance rate 0.04600000008940697
At L=3500 MSE on train is 1006.2835693359375
At L=3500 MSE on validation is 1368.068603515625


  4%|▍         | 4017/100000 [00:33<12:55, 123.79it/s]

Mean acceptance rate 0.01600000075995922
At L=4000 MSE on train is 971.5740356445312
At L=4000 MSE on validation is 1346.2244873046875


  5%|▍         | 4524/100000 [00:37<12:51, 123.70it/s]

Mean acceptance rate 0.03200000151991844
At L=4500 MSE on train is 909.796142578125
At L=4500 MSE on validation is 1308.575927734375


  5%|▌         | 5018/100000 [00:41<12:46, 123.96it/s]

Mean acceptance rate 0.02199999988079071
At L=5000 MSE on train is 857.9004516601562
At L=5000 MSE on validation is 1167.844482421875


  6%|▌         | 5525/100000 [00:45<12:40, 124.22it/s]

Mean acceptance rate 0.03200000151991844
At L=5500 MSE on train is 777.5593872070312
At L=5500 MSE on validation is 1101.4154052734375


  6%|▌         | 6019/100000 [00:49<12:37, 124.03it/s]

Mean acceptance rate 0.03400000184774399
At L=6000 MSE on train is 701.91259765625
At L=6000 MSE on validation is 944.8441772460938


  7%|▋         | 6513/100000 [00:53<12:34, 123.84it/s]

Mean acceptance rate 0.019999999552965164
At L=6500 MSE on train is 668.30322265625
At L=6500 MSE on validation is 920.4038696289062


  7%|▋         | 7020/100000 [00:57<12:34, 123.26it/s]

Mean acceptance rate 0.004000000189989805
At L=7000 MSE on train is 658.8895874023438
At L=7000 MSE on validation is 910.2744140625


  8%|▊         | 7514/100000 [01:01<12:26, 123.81it/s]

Mean acceptance rate 0.0020000000949949026
At L=7500 MSE on train is 658.2933959960938
At L=7500 MSE on validation is 920.3516235351562


  8%|▊         | 8021/100000 [01:05<12:31, 122.41it/s]

Mean acceptance rate 0.017999999225139618
At L=8000 MSE on train is 625.3618774414062
At L=8000 MSE on validation is 863.683349609375


  9%|▊         | 8515/100000 [01:09<12:28, 122.30it/s]

Mean acceptance rate 0.009999999776482582
At L=8500 MSE on train is 604.39501953125
At L=8500 MSE on validation is 857.9099731445312


  9%|▉         | 9022/100000 [01:13<12:24, 122.24it/s]

Mean acceptance rate 0.017999999225139618
At L=9000 MSE on train is 583.7477416992188
At L=9000 MSE on validation is 826.4172973632812


 10%|▉         | 9516/100000 [01:17<12:29, 120.79it/s]

Mean acceptance rate 0.00800000037997961
At L=9500 MSE on train is 575.03955078125
At L=9500 MSE on validation is 810.7437133789062


 10%|█         | 10023/100000 [01:21<12:07, 123.75it/s]

Mean acceptance rate 0.004000000189989805
At L=10000 MSE on train is 571.5501098632812
At L=10000 MSE on validation is 812.6036376953125


 11%|█         | 10517/100000 [01:25<12:02, 123.80it/s]

Mean acceptance rate 0.009999999776482582
At L=10500 MSE on train is 557.8035888671875
At L=10500 MSE on validation is 800.8729248046875


 11%|█         | 11024/100000 [01:30<12:04, 122.74it/s]

Mean acceptance rate 0.012000000104308128
At L=11000 MSE on train is 542.986083984375
At L=11000 MSE on validation is 788.6560668945312


 12%|█▏        | 11518/100000 [01:34<12:06, 121.75it/s]

Mean acceptance rate 0.006000000052154064
At L=11500 MSE on train is 539.0338134765625
At L=11500 MSE on validation is 792.1278076171875


 12%|█▏        | 12025/100000 [01:38<12:12, 120.14it/s]

Mean acceptance rate 0.009999999776482582
At L=12000 MSE on train is 533.0916748046875
At L=12000 MSE on validation is 770.9227905273438


 13%|█▎        | 12519/100000 [01:42<12:05, 120.51it/s]

Mean acceptance rate 0.02199999988079071
At L=12500 MSE on train is 502.1614685058594
At L=12500 MSE on validation is 728.0952758789062


 13%|█▎        | 13013/100000 [01:46<11:43, 123.63it/s]

Mean acceptance rate 0.014000000432133675
At L=13000 MSE on train is 489.8977355957031
At L=13000 MSE on validation is 728.1110229492188


 14%|█▎        | 13520/100000 [01:50<11:56, 120.71it/s]

Mean acceptance rate 0.01600000075995922
At L=13500 MSE on train is 474.5532531738281
At L=13500 MSE on validation is 704.713134765625


 14%|█▍        | 14014/100000 [01:54<11:34, 123.87it/s]

Mean acceptance rate 0.006000000052154064
At L=14000 MSE on train is 465.29974365234375
At L=14000 MSE on validation is 672.5263671875


 15%|█▍        | 14521/100000 [01:58<11:38, 122.46it/s]

Mean acceptance rate 0.006000000052154064
At L=14500 MSE on train is 455.8808898925781
At L=14500 MSE on validation is 679.107666015625


 15%|█▌        | 15015/100000 [02:02<11:35, 122.16it/s]

Mean acceptance rate 0.0
At L=15000 MSE on train is 455.8808898925781
At L=15000 MSE on validation is 679.107666015625


 16%|█▌        | 15522/100000 [02:06<11:37, 121.12it/s]

Mean acceptance rate 0.0
At L=15500 MSE on train is 455.8808898925781
At L=15500 MSE on validation is 679.107666015625


 16%|█▌        | 16016/100000 [02:11<11:21, 123.23it/s]

Mean acceptance rate 0.004000000189989805
At L=16000 MSE on train is 448.36749267578125
At L=16000 MSE on validation is 680.993896484375


 17%|█▋        | 16523/100000 [02:15<11:15, 123.58it/s]

Mean acceptance rate 0.009999999776482582
At L=16500 MSE on train is 439.0367736816406
At L=16500 MSE on validation is 677.0473022460938


 17%|█▋        | 17017/100000 [02:19<11:09, 123.86it/s]

Mean acceptance rate 0.014000000432133675
At L=17000 MSE on train is 420.5154724121094
At L=17000 MSE on validation is 635.3738403320312


 18%|█▊        | 17524/100000 [02:23<11:06, 123.72it/s]

Mean acceptance rate 0.017999999225139618
At L=17500 MSE on train is 392.2760009765625
At L=17500 MSE on validation is 606.3984375


 18%|█▊        | 18018/100000 [02:27<10:44, 127.26it/s]

Mean acceptance rate 0.006000000052154064
At L=18000 MSE on train is 388.2678527832031
At L=18000 MSE on validation is 599.117431640625


 19%|█▊        | 18525/100000 [02:31<10:45, 126.20it/s]

Mean acceptance rate 0.009999999776482582
At L=18500 MSE on train is 379.7176208496094
At L=18500 MSE on validation is 586.75732421875


 19%|█▉        | 19019/100000 [02:34<10:52, 124.08it/s]

Mean acceptance rate 0.014000000432133675
At L=19000 MSE on train is 368.1407775878906
At L=19000 MSE on validation is 535.157958984375


 20%|█▉        | 19526/100000 [02:39<10:34, 126.77it/s]

Mean acceptance rate 0.01600000075995922
At L=19500 MSE on train is 345.57623291015625
At L=19500 MSE on validation is 511.4585876464844


 20%|██        | 20020/100000 [02:43<10:53, 122.48it/s]

Mean acceptance rate 0.00800000037997961
At L=20000 MSE on train is 333.91363525390625
At L=20000 MSE on validation is 501.8994445800781


 21%|██        | 20514/100000 [02:47<10:48, 122.66it/s]

Mean acceptance rate 0.009999999776482582
At L=20500 MSE on train is 323.3324279785156
At L=20500 MSE on validation is 477.6851806640625


 21%|██        | 21021/100000 [02:51<10:29, 125.56it/s]

Mean acceptance rate 0.009999999776482582
At L=21000 MSE on train is 318.4874572753906
At L=21000 MSE on validation is 473.21630859375


 22%|██▏       | 21515/100000 [02:55<10:25, 125.49it/s]

Mean acceptance rate 0.012000000104308128
At L=21500 MSE on train is 304.5889587402344
At L=21500 MSE on validation is 456.097412109375


 22%|██▏       | 22022/100000 [02:59<10:22, 125.33it/s]

Mean acceptance rate 0.012000000104308128
At L=22000 MSE on train is 299.0369873046875
At L=22000 MSE on validation is 433.28521728515625


 23%|██▎       | 22516/100000 [03:03<10:16, 125.61it/s]

Mean acceptance rate 0.017999999225139618
At L=22500 MSE on train is 289.4076843261719
At L=22500 MSE on validation is 392.26129150390625


 23%|██▎       | 23023/100000 [03:07<10:14, 125.33it/s]

Mean acceptance rate 0.02199999988079071
At L=23000 MSE on train is 257.985595703125
At L=23000 MSE on validation is 402.5082092285156


 24%|██▎       | 23517/100000 [03:11<10:09, 125.48it/s]

Mean acceptance rate 0.006000000052154064
At L=23500 MSE on train is 255.81362915039062
At L=23500 MSE on validation is 400.0132751464844


 24%|██▍       | 24024/100000 [03:15<10:20, 122.39it/s]

Mean acceptance rate 0.0
At L=24000 MSE on train is 255.81362915039062
At L=24000 MSE on validation is 400.0132751464844


 25%|██▍       | 24518/100000 [03:19<10:17, 122.28it/s]

Mean acceptance rate 0.019999999552965164
At L=24500 MSE on train is 247.2977294921875
At L=24500 MSE on validation is 373.9684143066406


 25%|██▌       | 25025/100000 [03:23<10:12, 122.49it/s]

Mean acceptance rate 0.0020000000949949026
At L=25000 MSE on train is 247.1117706298828
At L=25000 MSE on validation is 368.3445129394531


 26%|██▌       | 25519/100000 [03:27<09:59, 124.20it/s]

Mean acceptance rate 0.004000000189989805
At L=25500 MSE on train is 246.76425170898438
At L=25500 MSE on validation is 366.72552490234375


 26%|██▌       | 26013/100000 [03:31<10:12, 120.82it/s]

Mean acceptance rate 0.009999999776482582
At L=26000 MSE on train is 235.22198486328125
At L=26000 MSE on validation is 355.5718688964844


 27%|██▋       | 26519/100000 [03:35<10:10, 120.29it/s]

Mean acceptance rate 0.0020000000949949026
At L=26500 MSE on train is 232.99191284179688
At L=26500 MSE on validation is 357.8990783691406


 27%|██▋       | 27013/100000 [03:39<09:58, 121.96it/s]

Mean acceptance rate 0.0020000000949949026
At L=27000 MSE on train is 232.87384033203125
At L=27000 MSE on validation is 362.6482238769531


 28%|██▊       | 27520/100000 [03:43<10:02, 120.33it/s]

Mean acceptance rate 0.004000000189989805
At L=27500 MSE on train is 231.9839324951172
At L=27500 MSE on validation is 375.00213623046875


 28%|██▊       | 28014/100000 [03:47<09:42, 123.52it/s]

Mean acceptance rate 0.012000000104308128
At L=28000 MSE on train is 223.15536499023438
At L=28000 MSE on validation is 343.9986572265625


 29%|██▊       | 28521/100000 [03:51<09:37, 123.74it/s]

Mean acceptance rate 0.0020000000949949026
At L=28500 MSE on train is 221.72666931152344
At L=28500 MSE on validation is 351.566162109375


 29%|██▉       | 29015/100000 [03:55<09:39, 122.51it/s]

Mean acceptance rate 0.00800000037997961
At L=29000 MSE on train is 215.40965270996094
At L=29000 MSE on validation is 351.3767395019531


 30%|██▉       | 29522/100000 [04:00<09:36, 122.24it/s]

Mean acceptance rate 0.004000000189989805
At L=29500 MSE on train is 213.82473754882812
At L=29500 MSE on validation is 352.60211181640625


 30%|███       | 30016/100000 [04:04<09:30, 122.67it/s]

Mean acceptance rate 0.0
At L=30000 MSE on train is 213.82473754882812
At L=30000 MSE on validation is 352.60211181640625


 31%|███       | 30523/100000 [04:08<09:12, 125.81it/s]

Mean acceptance rate 0.012000000104308128
At L=30500 MSE on train is 206.07919311523438
At L=30500 MSE on validation is 323.9244384765625


 31%|███       | 31017/100000 [04:12<09:09, 125.59it/s]

Mean acceptance rate 0.0020000000949949026
At L=31000 MSE on train is 205.713623046875
At L=31000 MSE on validation is 323.60394287109375


 32%|███▏      | 31524/100000 [04:16<09:05, 125.45it/s]

Mean acceptance rate 0.004000000189989805
At L=31500 MSE on train is 202.37820434570312
At L=31500 MSE on validation is 338.57318115234375


 32%|███▏      | 32018/100000 [04:20<09:02, 125.29it/s]

Mean acceptance rate 0.0020000000949949026
At L=32000 MSE on train is 201.62164306640625
At L=32000 MSE on validation is 335.01239013671875


 33%|███▎      | 32525/100000 [04:24<08:56, 125.71it/s]

Mean acceptance rate 0.004000000189989805
At L=32500 MSE on train is 200.96817016601562
At L=32500 MSE on validation is 337.8496398925781


 33%|███▎      | 33019/100000 [04:28<08:53, 125.57it/s]

Mean acceptance rate 0.0
At L=33000 MSE on train is 200.96817016601562
At L=33000 MSE on validation is 337.8496398925781


 34%|███▎      | 33526/100000 [04:32<08:49, 125.53it/s]

Mean acceptance rate 0.0
At L=33500 MSE on train is 200.96817016601562
At L=33500 MSE on validation is 337.8496398925781


 34%|███▍      | 34020/100000 [04:36<08:45, 125.45it/s]

Mean acceptance rate 0.004000000189989805
At L=34000 MSE on train is 198.84909057617188
At L=34000 MSE on validation is 323.76177978515625


 35%|███▍      | 34514/100000 [04:39<08:41, 125.50it/s]

Mean acceptance rate 0.0020000000949949026
At L=34500 MSE on train is 196.53713989257812
At L=34500 MSE on validation is 324.1536560058594


 35%|███▌      | 35021/100000 [04:43<08:36, 125.77it/s]

Mean acceptance rate 0.006000000052154064
At L=35000 MSE on train is 195.645751953125
At L=35000 MSE on validation is 311.44439697265625


 36%|███▌      | 35515/100000 [04:47<08:33, 125.64it/s]

Mean acceptance rate 0.00800000037997961
At L=35500 MSE on train is 194.81118774414062
At L=35500 MSE on validation is 317.7308044433594


 36%|███▌      | 36022/100000 [04:51<08:29, 125.56it/s]

Mean acceptance rate 0.004000000189989805
At L=36000 MSE on train is 192.23414611816406
At L=36000 MSE on validation is 317.162109375


 37%|███▋      | 36516/100000 [04:56<08:45, 120.70it/s]

Mean acceptance rate 0.009999999776482582
At L=36500 MSE on train is 187.61228942871094
At L=36500 MSE on validation is 308.0750427246094


 37%|███▋      | 37023/100000 [05:00<08:42, 120.59it/s]

Mean acceptance rate 0.0
At L=37000 MSE on train is 187.61228942871094
At L=37000 MSE on validation is 308.0750427246094


 38%|███▊      | 37517/100000 [05:04<08:25, 123.55it/s]

Mean acceptance rate 0.0020000000949949026
At L=37500 MSE on train is 186.39291381835938
At L=37500 MSE on validation is 307.40069580078125


 45%|████▌     | 45017/100000 [06:05<07:35, 120.77it/s]

Mean acceptance rate 0.0020000000949949026
At L=45000 MSE on train is 163.55438232421875
At L=45000 MSE on validation is 255.31564331054688


 46%|████▌     | 45524/100000 [06:09<07:31, 120.78it/s]

Mean acceptance rate 0.0
At L=45500 MSE on train is 163.55438232421875
At L=45500 MSE on validation is 255.31564331054688


 46%|████▌     | 46018/100000 [06:13<07:16, 123.58it/s]

Mean acceptance rate 0.004000000189989805
At L=46000 MSE on train is 162.6342315673828
At L=46000 MSE on validation is 257.6122741699219


 47%|████▋     | 46525/100000 [06:17<07:19, 121.76it/s]

Mean acceptance rate 0.0020000000949949026
At L=46500 MSE on train is 162.38916015625
At L=46500 MSE on validation is 259.32586669921875


 47%|████▋     | 47019/100000 [06:21<07:15, 121.75it/s]

Mean acceptance rate 0.006000000052154064
At L=47000 MSE on train is 159.91217041015625
At L=47000 MSE on validation is 251.8036651611328


 48%|████▊     | 47525/100000 [06:26<07:02, 124.06it/s]

Mean acceptance rate 0.0020000000949949026
At L=47500 MSE on train is 159.89883422851562
At L=47500 MSE on validation is 250.2607421875


 48%|████▊     | 48019/100000 [06:30<07:03, 122.60it/s]

Mean acceptance rate 0.006000000052154064
At L=48000 MSE on train is 158.16136169433594
At L=48000 MSE on validation is 241.04823303222656


 49%|████▊     | 48513/100000 [06:34<07:06, 120.67it/s]

Mean acceptance rate 0.0
At L=48500 MSE on train is 158.16136169433594
At L=48500 MSE on validation is 241.04823303222656


 49%|████▉     | 49020/100000 [06:38<07:01, 121.04it/s]

Mean acceptance rate 0.0020000000949949026
At L=49000 MSE on train is 157.68838500976562
At L=49000 MSE on validation is 248.37307739257812


 50%|████▉     | 49514/100000 [06:42<06:49, 123.32it/s]

Mean acceptance rate 0.006000000052154064
At L=49500 MSE on train is 156.47769165039062
At L=49500 MSE on validation is 241.14767456054688


 50%|█████     | 50021/100000 [06:46<06:54, 120.64it/s]

Mean acceptance rate 0.006000000052154064
At L=50000 MSE on train is 153.57440185546875
At L=50000 MSE on validation is 249.76942443847656


 51%|█████     | 50515/100000 [06:50<06:51, 120.29it/s]

Mean acceptance rate 0.004000000189989805
At L=50500 MSE on train is 151.92269897460938
At L=50500 MSE on validation is 244.98545837402344


 51%|█████     | 51022/100000 [06:54<06:41, 121.93it/s]

Mean acceptance rate 0.0020000000949949026
At L=51000 MSE on train is 151.92649841308594
At L=51000 MSE on validation is 238.3008575439453


 52%|█████▏    | 51516/100000 [06:58<06:32, 123.56it/s]

Mean acceptance rate 0.0020000000949949026
At L=51500 MSE on train is 151.51156616210938
At L=51500 MSE on validation is 237.326171875


 52%|█████▏    | 52023/100000 [07:02<06:37, 120.76it/s]

Mean acceptance rate 0.004000000189989805
At L=52000 MSE on train is 150.67799377441406
At L=52000 MSE on validation is 240.80169677734375


 53%|█████▎    | 52517/100000 [07:07<06:32, 120.91it/s]

Mean acceptance rate 0.0020000000949949026
At L=52500 MSE on train is 149.33978271484375
At L=52500 MSE on validation is 238.2065887451172


 53%|█████▎    | 53024/100000 [07:11<06:27, 121.10it/s]

Mean acceptance rate 0.0020000000949949026
At L=53000 MSE on train is 148.9334716796875
At L=53000 MSE on validation is 231.20968627929688


 54%|█████▎    | 53518/100000 [07:15<06:25, 120.48it/s]

Mean acceptance rate 0.0
At L=53500 MSE on train is 148.9334716796875
At L=53500 MSE on validation is 231.20968627929688


 54%|█████▍    | 54025/100000 [07:19<06:20, 120.76it/s]

Mean acceptance rate 0.004000000189989805
At L=54000 MSE on train is 147.482666015625
At L=54000 MSE on validation is 237.68386840820312


 55%|█████▍    | 54519/100000 [07:23<06:17, 120.51it/s]

Mean acceptance rate 0.0020000000949949026
At L=54500 MSE on train is 146.6747589111328
At L=54500 MSE on validation is 230.3787078857422


 55%|█████▌    | 55013/100000 [07:27<06:08, 122.06it/s]

Mean acceptance rate 0.004000000189989805
At L=55000 MSE on train is 146.2729034423828
At L=55000 MSE on validation is 227.4002227783203


 56%|█████▌    | 55520/100000 [07:31<05:54, 125.49it/s]

Mean acceptance rate 0.0020000000949949026
At L=55500 MSE on train is 145.69692993164062
At L=55500 MSE on validation is 223.95298767089844


 56%|█████▌    | 56014/100000 [07:35<05:57, 122.97it/s]

Mean acceptance rate 0.00800000037997961
At L=56000 MSE on train is 141.89222717285156
At L=56000 MSE on validation is 211.516357421875


 57%|█████▋    | 56521/100000 [07:39<05:55, 122.33it/s]

Mean acceptance rate 0.0020000000949949026
At L=56500 MSE on train is 141.8092498779297
At L=56500 MSE on validation is 219.58409118652344


 57%|█████▋    | 57015/100000 [07:43<05:46, 123.89it/s]

Mean acceptance rate 0.0
At L=57000 MSE on train is 141.8092498779297
At L=57000 MSE on validation is 219.58409118652344


 58%|█████▊    | 57522/100000 [07:47<05:43, 123.77it/s]

Mean acceptance rate 0.0
At L=57500 MSE on train is 141.8092498779297
At L=57500 MSE on validation is 219.58409118652344


 58%|█████▊    | 58016/100000 [07:51<05:39, 123.84it/s]

Mean acceptance rate 0.0020000000949949026
At L=58000 MSE on train is 140.93136596679688
At L=58000 MSE on validation is 215.26437377929688


 59%|█████▊    | 58523/100000 [07:55<05:33, 124.36it/s]

Mean acceptance rate 0.0
At L=58500 MSE on train is 140.93136596679688
At L=58500 MSE on validation is 215.26437377929688


 59%|█████▉    | 59017/100000 [07:59<05:29, 124.33it/s]

Mean acceptance rate 0.0020000000949949026
At L=59000 MSE on train is 139.96351623535156
At L=59000 MSE on validation is 215.320068359375


 60%|█████▉    | 59524/100000 [08:04<05:32, 121.62it/s]

Mean acceptance rate 0.0
At L=59500 MSE on train is 139.96351623535156
At L=59500 MSE on validation is 215.320068359375


 60%|██████    | 60018/100000 [08:08<05:32, 120.25it/s]

Mean acceptance rate 0.0020000000949949026
At L=60000 MSE on train is 138.76605224609375
At L=60000 MSE on validation is 209.40415954589844


 61%|██████    | 60525/100000 [08:12<05:28, 120.32it/s]

Mean acceptance rate 0.004000000189989805
At L=60500 MSE on train is 136.86695861816406
At L=60500 MSE on validation is 211.1036834716797


 61%|██████    | 61017/100000 [08:16<05:23, 120.33it/s]

Mean acceptance rate 0.0
At L=61000 MSE on train is 136.86695861816406
At L=61000 MSE on validation is 211.1036834716797


 62%|██████▏   | 61524/100000 [08:20<05:20, 120.22it/s]

Mean acceptance rate 0.0020000000949949026
At L=61500 MSE on train is 136.24368286132812
At L=61500 MSE on validation is 216.27671813964844


 62%|██████▏   | 62017/100000 [08:24<05:04, 124.64it/s]

Mean acceptance rate 0.004000000189989805
At L=62000 MSE on train is 135.75503540039062
At L=62000 MSE on validation is 211.21046447753906


 63%|██████▎   | 62524/100000 [08:28<04:58, 125.56it/s]

Mean acceptance rate 0.0020000000949949026
At L=62500 MSE on train is 135.6682586669922
At L=62500 MSE on validation is 212.39686584472656


 63%|██████▎   | 63018/100000 [08:32<04:54, 125.43it/s]

Mean acceptance rate 0.0
At L=63000 MSE on train is 135.6682586669922
At L=63000 MSE on validation is 212.39686584472656


 64%|██████▎   | 63525/100000 [08:36<04:50, 125.54it/s]

Mean acceptance rate 0.004000000189989805
At L=63500 MSE on train is 133.71458435058594
At L=63500 MSE on validation is 203.8143768310547


 64%|██████▍   | 64019/100000 [08:40<04:46, 125.44it/s]

Mean acceptance rate 0.0020000000949949026
At L=64000 MSE on train is 132.94114685058594
At L=64000 MSE on validation is 200.50831604003906


 65%|██████▍   | 64526/100000 [08:44<04:42, 125.74it/s]

Mean acceptance rate 0.0
At L=64500 MSE on train is 132.94114685058594
At L=64500 MSE on validation is 200.50831604003906


 65%|██████▌   | 65020/100000 [08:48<04:49, 120.89it/s]

Mean acceptance rate 0.0
At L=65000 MSE on train is 132.94114685058594
At L=65000 MSE on validation is 200.50831604003906


 66%|██████▌   | 65514/100000 [08:52<04:42, 122.00it/s]

Mean acceptance rate 0.0
At L=65500 MSE on train is 132.94114685058594
At L=65500 MSE on validation is 200.50831604003906


 66%|██████▌   | 66021/100000 [08:56<04:37, 122.56it/s]

Mean acceptance rate 0.004000000189989805
At L=66000 MSE on train is 132.41305541992188
At L=66000 MSE on validation is 209.31788635253906


 67%|██████▋   | 66515/100000 [09:00<04:33, 122.26it/s]

Mean acceptance rate 0.0
At L=66500 MSE on train is 132.41305541992188
At L=66500 MSE on validation is 209.31788635253906


 67%|██████▋   | 67022/100000 [09:05<04:30, 122.01it/s]

Mean acceptance rate 0.0
At L=67000 MSE on train is 132.41305541992188
At L=67000 MSE on validation is 209.31788635253906


 68%|██████▊   | 67516/100000 [09:09<04:25, 122.55it/s]

Mean acceptance rate 0.0020000000949949026
At L=67500 MSE on train is 131.63192749023438
At L=67500 MSE on validation is 202.6354217529297


 68%|██████▊   | 68023/100000 [09:13<04:21, 122.06it/s]

Mean acceptance rate 0.0
At L=68000 MSE on train is 131.63192749023438
At L=68000 MSE on validation is 202.6354217529297


 69%|██████▊   | 68516/100000 [09:17<04:21, 120.23it/s]

Mean acceptance rate 0.0
At L=68500 MSE on train is 131.63192749023438
At L=68500 MSE on validation is 202.6354217529297


 69%|██████▉   | 69023/100000 [09:21<04:17, 120.35it/s]

Mean acceptance rate 0.0
At L=69000 MSE on train is 131.63192749023438
At L=69000 MSE on validation is 202.6354217529297


 70%|██████▉   | 69516/100000 [09:25<04:10, 121.55it/s]

Mean acceptance rate 0.0
At L=69500 MSE on train is 131.63192749023438
At L=69500 MSE on validation is 202.6354217529297


 70%|███████   | 70023/100000 [09:29<04:07, 121.29it/s]

Mean acceptance rate 0.0
At L=70000 MSE on train is 131.63192749023438
At L=70000 MSE on validation is 202.6354217529297


 71%|███████   | 70517/100000 [09:33<03:54, 125.49it/s]

Mean acceptance rate 0.0
At L=70500 MSE on train is 131.63192749023438
At L=70500 MSE on validation is 202.6354217529297


 71%|███████   | 71024/100000 [09:37<03:51, 125.41it/s]

Mean acceptance rate 0.004000000189989805
At L=71000 MSE on train is 131.417236328125
At L=71000 MSE on validation is 203.238525390625


 72%|███████▏  | 71518/100000 [09:41<03:46, 125.48it/s]

Mean acceptance rate 0.004000000189989805
At L=71500 MSE on train is 130.2141876220703
At L=71500 MSE on validation is 199.6804656982422


 72%|███████▏  | 72025/100000 [09:45<03:49, 121.79it/s]

Mean acceptance rate 0.0020000000949949026
At L=72000 MSE on train is 128.7960968017578
At L=72000 MSE on validation is 196.40292358398438


 73%|███████▎  | 72519/100000 [09:49<03:45, 121.84it/s]

Mean acceptance rate 0.0
At L=72500 MSE on train is 128.7960968017578
At L=72500 MSE on validation is 196.40292358398438


 73%|███████▎  | 73013/100000 [09:53<03:38, 123.62it/s]

Mean acceptance rate 0.0
At L=73000 MSE on train is 128.7960968017578
At L=73000 MSE on validation is 196.40292358398438


 74%|███████▎  | 73520/100000 [09:57<03:34, 123.68it/s]

Mean acceptance rate 0.0
At L=73500 MSE on train is 128.7960968017578
At L=73500 MSE on validation is 196.40292358398438


 74%|███████▍  | 74014/100000 [10:01<03:32, 122.30it/s]

Mean acceptance rate 0.0
At L=74000 MSE on train is 128.7960968017578
At L=74000 MSE on validation is 196.40292358398438


 75%|███████▍  | 74521/100000 [10:06<03:27, 122.55it/s]

Mean acceptance rate 0.0
At L=74500 MSE on train is 128.7960968017578
At L=74500 MSE on validation is 196.40292358398438


 75%|███████▌  | 75015/100000 [10:10<03:24, 122.25it/s]

Mean acceptance rate 0.0
At L=75000 MSE on train is 128.7960968017578
At L=75000 MSE on validation is 196.40292358398438


 76%|███████▌  | 75522/100000 [10:14<03:22, 120.85it/s]

Mean acceptance rate 0.004000000189989805
At L=75500 MSE on train is 128.52272033691406
At L=75500 MSE on validation is 191.60116577148438


 76%|███████▌  | 76016/100000 [10:18<03:17, 121.34it/s]

Mean acceptance rate 0.0
At L=76000 MSE on train is 128.52272033691406
At L=76000 MSE on validation is 191.60116577148438


 77%|███████▋  | 76523/100000 [10:22<03:10, 123.33it/s]

Mean acceptance rate 0.0
At L=76500 MSE on train is 128.52272033691406
At L=76500 MSE on validation is 191.60116577148438


 77%|███████▋  | 77017/100000 [10:26<03:10, 120.83it/s]

Mean acceptance rate 0.0
At L=77000 MSE on train is 128.52272033691406
At L=77000 MSE on validation is 191.60116577148438


 78%|███████▊  | 77524/100000 [10:30<03:06, 120.68it/s]

Mean acceptance rate 0.0
At L=77500 MSE on train is 128.52272033691406
At L=77500 MSE on validation is 191.60116577148438


 78%|███████▊  | 78018/100000 [10:34<03:02, 120.66it/s]

Mean acceptance rate 0.0
At L=78000 MSE on train is 128.52272033691406
At L=78000 MSE on validation is 191.60116577148438


 79%|███████▊  | 78525/100000 [10:38<02:53, 123.73it/s]

Mean acceptance rate 0.0
At L=78500 MSE on train is 128.52272033691406
At L=78500 MSE on validation is 191.60116577148438


 79%|███████▉  | 79019/100000 [10:42<02:49, 123.80it/s]

Mean acceptance rate 0.0
At L=79000 MSE on train is 128.52272033691406
At L=79000 MSE on validation is 191.60116577148438


 80%|███████▉  | 79513/100000 [10:46<02:48, 121.78it/s]

Mean acceptance rate 0.0
At L=79500 MSE on train is 128.52272033691406
At L=79500 MSE on validation is 191.60116577148438


 80%|████████  | 80020/100000 [10:51<02:41, 123.71it/s]

Mean acceptance rate 0.0
At L=80000 MSE on train is 128.52272033691406
At L=80000 MSE on validation is 191.60116577148438


 81%|████████  | 80514/100000 [10:55<02:37, 123.81it/s]

Mean acceptance rate 0.0
At L=80500 MSE on train is 128.52272033691406
At L=80500 MSE on validation is 191.60116577148438


 81%|████████  | 81021/100000 [10:59<02:33, 123.70it/s]

Mean acceptance rate 0.0
At L=81000 MSE on train is 128.52272033691406
At L=81000 MSE on validation is 191.60116577148438


 82%|████████▏ | 81515/100000 [11:03<02:31, 122.20it/s]

Mean acceptance rate 0.0
At L=81500 MSE on train is 128.52272033691406
At L=81500 MSE on validation is 191.60116577148438


 82%|████████▏ | 82022/100000 [11:07<02:26, 122.36it/s]

Mean acceptance rate 0.0
At L=82000 MSE on train is 128.52272033691406
At L=82000 MSE on validation is 191.60116577148438


 83%|████████▎ | 82516/100000 [11:11<02:22, 122.56it/s]

Mean acceptance rate 0.0020000000949949026
At L=82500 MSE on train is 128.52084350585938
At L=82500 MSE on validation is 179.34811401367188


 83%|████████▎ | 83022/100000 [11:15<02:21, 120.19it/s]

Mean acceptance rate 0.0020000000949949026
At L=83000 MSE on train is 128.18824768066406
At L=83000 MSE on validation is 190.2625274658203


 84%|████████▎ | 83516/100000 [11:19<02:12, 123.99it/s]

Mean acceptance rate 0.0020000000949949026
At L=83500 MSE on train is 127.86314392089844
At L=83500 MSE on validation is 188.88916015625


 84%|████████▍ | 84023/100000 [11:23<02:09, 123.64it/s]

Mean acceptance rate 0.0020000000949949026
At L=84000 MSE on train is 126.73289489746094
At L=84000 MSE on validation is 196.5257110595703


 85%|████████▍ | 84517/100000 [11:27<02:05, 123.69it/s]

Mean acceptance rate 0.0020000000949949026
At L=84500 MSE on train is 126.69186401367188
At L=84500 MSE on validation is 194.022216796875


 85%|████████▌ | 85024/100000 [11:31<02:01, 123.57it/s]

Mean acceptance rate 0.0
At L=85000 MSE on train is 126.69186401367188
At L=85000 MSE on validation is 194.022216796875


 86%|████████▌ | 85518/100000 [11:35<01:54, 126.50it/s]

Mean acceptance rate 0.0
At L=85500 MSE on train is 126.69186401367188
At L=85500 MSE on validation is 194.022216796875


 86%|████████▌ | 86025/100000 [11:39<01:53, 123.51it/s]

Mean acceptance rate 0.0020000000949949026
At L=86000 MSE on train is 125.81099700927734
At L=86000 MSE on validation is 195.156982421875


 87%|████████▋ | 86519/100000 [11:43<01:49, 123.21it/s]

Mean acceptance rate 0.0
At L=86500 MSE on train is 125.81099700927734
At L=86500 MSE on validation is 195.156982421875


 87%|████████▋ | 87013/100000 [11:47<01:45, 123.06it/s]

Mean acceptance rate 0.0
At L=87000 MSE on train is 125.81099700927734
At L=87000 MSE on validation is 195.156982421875


 88%|████████▊ | 87520/100000 [11:51<01:41, 123.05it/s]

Mean acceptance rate 0.0
At L=87500 MSE on train is 125.81099700927734
At L=87500 MSE on validation is 195.156982421875


 88%|████████▊ | 88014/100000 [11:55<01:34, 126.88it/s]

Mean acceptance rate 0.0
At L=88000 MSE on train is 125.81099700927734
At L=88000 MSE on validation is 195.156982421875


 89%|████████▊ | 88521/100000 [11:59<01:33, 123.20it/s]

Mean acceptance rate 0.0020000000949949026
At L=88500 MSE on train is 125.29684448242188
At L=88500 MSE on validation is 194.91163635253906


 89%|████████▉ | 89015/100000 [12:03<01:29, 123.26it/s]

Mean acceptance rate 0.0020000000949949026
At L=89000 MSE on train is 125.12390899658203
At L=89000 MSE on validation is 191.71788024902344


 90%|████████▉ | 89522/100000 [12:08<01:24, 124.03it/s]

Mean acceptance rate 0.0
At L=89500 MSE on train is 125.12390899658203
At L=89500 MSE on validation is 191.71788024902344


 90%|█████████ | 90016/100000 [12:12<01:22, 121.02it/s]

Mean acceptance rate 0.0020000000949949026
At L=90000 MSE on train is 124.6076431274414
At L=90000 MSE on validation is 198.19293212890625


 91%|█████████ | 90522/100000 [12:16<01:19, 119.85it/s]

Mean acceptance rate 0.004000000189989805
At L=90500 MSE on train is 123.51329803466797
At L=90500 MSE on validation is 191.47727966308594


 91%|█████████ | 91015/100000 [12:20<01:12, 123.47it/s]

Mean acceptance rate 0.0
At L=91000 MSE on train is 123.51329803466797
At L=91000 MSE on validation is 191.47727966308594


 92%|█████████▏| 91522/100000 [12:24<01:09, 121.89it/s]

Mean acceptance rate 0.0
At L=91500 MSE on train is 123.51329803466797
At L=91500 MSE on validation is 191.47727966308594


 92%|█████████▏| 92016/100000 [12:28<01:05, 121.93it/s]

Mean acceptance rate 0.0
At L=92000 MSE on train is 123.51329803466797
At L=92000 MSE on validation is 191.47727966308594


 93%|█████████▎| 92523/100000 [12:32<01:00, 123.93it/s]

Mean acceptance rate 0.0
At L=92500 MSE on train is 123.51329803466797
At L=92500 MSE on validation is 191.47727966308594


 93%|█████████▎| 93017/100000 [12:36<00:56, 124.41it/s]

Mean acceptance rate 0.0
At L=93000 MSE on train is 123.51329803466797
At L=93000 MSE on validation is 191.47727966308594


 94%|█████████▎| 93524/100000 [12:40<00:52, 122.44it/s]

Mean acceptance rate 0.0
At L=93500 MSE on train is 123.51329803466797
At L=93500 MSE on validation is 191.47727966308594


 94%|█████████▍| 94018/100000 [12:44<00:48, 123.77it/s]

Mean acceptance rate 0.0
At L=94000 MSE on train is 123.51329803466797
At L=94000 MSE on validation is 191.47727966308594


 95%|█████████▍| 94525/100000 [12:48<00:44, 123.67it/s]

Mean acceptance rate 0.0
At L=94500 MSE on train is 123.51329803466797
At L=94500 MSE on validation is 191.47727966308594


 95%|█████████▌| 95019/100000 [12:52<00:40, 123.74it/s]

Mean acceptance rate 0.0020000000949949026
At L=95000 MSE on train is 123.0765609741211
At L=95000 MSE on validation is 186.82516479492188


 96%|█████████▌| 95513/100000 [12:56<00:35, 126.54it/s]

Mean acceptance rate 0.0
At L=95500 MSE on train is 123.0765609741211
At L=95500 MSE on validation is 186.82516479492188


 96%|█████████▌| 96020/100000 [13:00<00:32, 123.62it/s]

Mean acceptance rate 0.0
At L=96000 MSE on train is 123.0765609741211
At L=96000 MSE on validation is 186.82516479492188


 97%|█████████▋| 96514/100000 [13:04<00:28, 123.50it/s]

Mean acceptance rate 0.0
At L=96500 MSE on train is 123.0765609741211
At L=96500 MSE on validation is 186.82516479492188


 97%|█████████▋| 97021/100000 [13:08<00:24, 123.75it/s]

Mean acceptance rate 0.0
At L=97000 MSE on train is 123.0765609741211
At L=97000 MSE on validation is 186.82516479492188


 98%|█████████▊| 97515/100000 [13:12<00:20, 122.14it/s]

Mean acceptance rate 0.0
At L=97500 MSE on train is 123.0765609741211
At L=97500 MSE on validation is 186.82516479492188


 98%|█████████▊| 98021/100000 [13:17<00:16, 121.01it/s]

Mean acceptance rate 0.0
At L=98000 MSE on train is 123.0765609741211
At L=98000 MSE on validation is 186.82516479492188


 99%|█████████▊| 98515/100000 [13:21<00:12, 120.42it/s]

Mean acceptance rate 0.0
At L=98500 MSE on train is 123.0765609741211
At L=98500 MSE on validation is 186.82516479492188


 99%|█████████▉| 99022/100000 [13:25<00:08, 120.40it/s]

Mean acceptance rate 0.0
At L=99000 MSE on train is 123.0765609741211
At L=99000 MSE on validation is 186.82516479492188


100%|█████████▉| 99516/100000 [13:29<00:04, 119.30it/s]

Mean acceptance rate 0.0
At L=99500 MSE on train is 123.0765609741211
At L=99500 MSE on validation is 186.82516479492188


100%|██████████| 100000/100000 [13:33<00:00, 122.91it/s]
