In [1]:
"""
This is the shrinkage version of the transport map. In this version,
the regression functions $f_i$ and the nugget parameters $d_i$ are
assumed to have specific structures, which are given by the parametric
covariance matrix. Unlike the example in the other notebook, here we
estimate the parameters of the parametric covariance matrix using the
integrated log-likelihood function.


Author: Anirban Chakraborty,
Last modified: May 13, 2024
"""
%load_ext autoreload
%autoreload 2

### Load necessary libraries

In [2]:
import torch
import numpy as np
from veccs import orderings
from gpytorch.kernels import MaternKernel
from sklearn.gaussian_process import kernels
from matplotlib import pyplot as plt

from batram.helpers import make_grid, GaussianProcessGenerator
from batram.legmods import Data, SimpleTM
from batram.shrinkmods import ShrinkTM, EstimableShrinkTM


  from .autonotebook import tqdm as notebook_tqdm


### Comparing log-score with the base transport maps (exponential kernel)

In [3]:
# Seed torch's global RNG so the random train/test permutations drawn below are reproducible.
SEED = 20240522
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fbed1c8c710>

In [4]:
## Kernel and location parameters
# The data are loaded from a pickle in the next cell. num_locs, dim_locs, nu_original
# and sd_noise are not read by any later cell here; presumably they record how that
# data was simulated (TODO confirm against the data-generation script).
num_locs = 30
dim_locs = 2
nu_original = 0.5
length_scale_original = 0.3  # also encoded in the results filename when saving/loading
sd_noise = 1e-6
largest_conditioning_set = 30  # max conditioning-set size; passed to orderings.find_nns_l2
sigmasq_f = 1.0  # also encoded in the results filename when saving/loading
# Training-set sizes are the list `numSamples` defined in the experiment-setup cell
# (a scalar `numSamples` here would only be shadowed by that list).

In [5]:
import pickle

# Load the pre-simulated dataset. SECURITY: pickle.load can execute arbitrary code,
# so only load files you created yourself or otherwise trust.
with open("../data/NR900ExpLST30SIGSQT10.pkl", "rb") as f:
    data = pickle.load(f)

locs, locsorder, gp = data["locs"], data["order"], data["gp"]
# Data columns are permuted into the stored ordering; rows are the samples used below.
# NOTE(review): `locs` itself is NOT permuted here - confirm the pickled locations are
# already in that ordering, otherwise `nn` and the data columns would disagree.
torchdata = data["data"][:, locsorder]
# Presumably the `largest_conditioning_set` nearest neighbours (L2) of each location.
nn = orderings.find_nns_l2(locs, largest_conditioning_set)

In [6]:
## Experiment setup: training-set sizes, repetitions and result containers
numSamples = [1, 2, 5, 10, 20, 30, 50, 80, 160, 200]  # training-set sizes to compare
reps = 10  # random train/test splits per training-set size
nsteps = 400  # optimisation steps per model fit
yreps = 50  # held-out samples used to estimate the log-score

# Log-score tables: one row per repetition, one column per training-set size.
logScore_tm = torch.zeros(reps, len(numSamples))
logScore_shrink = torch.zeros(reps, len(numSamples))

# Fitted models, appended in (training-size, repetition) order by the fitting loop.
tm_models, shrink_models = [], []

In [7]:
## fit models
for i, n in enumerate(numSamples):
    #if (n == 1):
    #    theta_init_fixed = torch.tensor([0.0, 0.0, -2.0, 2.0, 0.0, -0.7])
    #else:
    theta_init_fixed = None
    for _reps in range(reps):
        randperm = torch.randperm(torchdata.shape[0])
        obs = (torchdata[randperm, :])[0:n, :] #snip first n samples
        #if obs.dim() == 1:
        #    obs = obs.unsqueeze(0)
        obsTrain = obs
        #if (n > 1):
        #    obs = (obs - obs.mean(dim=0, keepdim=True)) / obs.std(dim=0, keepdim=True)

        # Create a `Data` object for use with the `SimpleTM`/ `ShrinkTM` model.
        data_tm = Data.new(torch.as_tensor(locs).float(), obs, torch.as_tensor(nn))
        data_shrink = Data.new(torch.as_tensor(locs).float(), obs, torch.as_tensor(nn))

        tm = SimpleTM(data_tm, theta_init=None, linear=False, smooth=1.5, nug_mult=4.0)
        opt = torch.optim.Adam(tm.parameters(), lr=0.01)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, nsteps)
        res = tm.fit(
            nsteps, 0.1, test_data=tm.data, optimizer=opt, scheduler=sched, batch_size=300
        )
        tm_models.append(tm)
        
        shrink_tm = EstimableShrinkTM(data=data_shrink, linear=False, 
                        transportmap_smooth=1.5, 
                        parametric_kernel= "exponential",
                        param_nu=0.5,
                        param_ls=1.0,
                        nug_mult_bounded=False,
                        theta_init=theta_init_fixed,
                        )
        
        opt2 = torch.optim.Adam(shrink_tm.parameters(), lr=0.01)
        sched2 = torch.optim.lr_scheduler.CosineAnnealingLR(opt2, nsteps)
        res2 = shrink_tm.fit(
            nsteps, 0.1, test_data=shrink_tm.data, optimizer=opt2, scheduler=sched2, batch_size=300,

        )
        shrink_models.append(shrink_tm)

        testsampnum = 50
        for _j in range(0, testsampnum):
            with torch.no_grad():
                logScore_tm[_reps, i] += tm.score((torchdata[randperm, :])[(200 + _j), :])/testsampnum
            logScore_shrink[_reps, i] += shrink_tm.score((torchdata[randperm, :])[(200 + _j), :])/testsampnum
        print(f"n ={n}, rep {_reps} done")
        print(f"tmscore = {logScore_tm[_reps, i]}, shrinkscore = {logScore_shrink[_reps, i]}")
    

Train Loss: 1201.395, Test Loss: 1201.395: 100%|██████████| 400/400 [00:05<00:00, 67.65it/s]
Train Loss: 796.407, Test Loss: 796.407: 100%|██████████| 400/400 [00:33<00:00, 11.92it/s]


n =1, rep 0 done
tmscore = -1966.5849609375, shrinkscore = -5605.60302734375


Train Loss: 1217.122, Test Loss: 1217.121: 100%|██████████| 400/400 [00:06<00:00, 64.65it/s]
Train Loss: 788.626, Test Loss: 788.626: 100%|██████████| 400/400 [00:33<00:00, 11.84it/s]


n =1, rep 1 done
tmscore = -1948.330078125, shrinkscore = -5369.23193359375


Train Loss: 1311.998, Test Loss: 1311.998: 100%|██████████| 400/400 [00:06<00:00, 64.10it/s]
Train Loss: 788.434, Test Loss: 788.434: 100%|██████████| 400/400 [00:33<00:00, 11.82it/s]


n =1, rep 2 done
tmscore = -1925.643798828125, shrinkscore = -3022.961669921875


Train Loss: 1066.363, Test Loss: 1066.363: 100%|██████████| 400/400 [00:07<00:00, 50.57it/s]
Train Loss: 767.775, Test Loss: 767.775: 100%|██████████| 400/400 [00:37<00:00, 10.78it/s]


n =1, rep 3 done
tmscore = -2034.82470703125, shrinkscore = -3255.986083984375


Train Loss: 1252.170, Test Loss: 1252.170: 100%|██████████| 400/400 [00:06<00:00, 63.73it/s]
Train Loss: 787.388, Test Loss: 787.388: 100%|██████████| 400/400 [00:33<00:00, 12.11it/s]


n =1, rep 4 done
tmscore = -1941.7713623046875, shrinkscore = -3663.657958984375


Train Loss: 1186.404, Test Loss: 1186.404: 100%|██████████| 400/400 [00:06<00:00, 64.37it/s]
Train Loss: 784.505, Test Loss: 784.505: 100%|██████████| 400/400 [00:34<00:00, 11.69it/s]


n =1, rep 5 done
tmscore = -1971.3192138671875, shrinkscore = -4840.90087890625


Train Loss: 1218.959, Test Loss: 1218.959: 100%|██████████| 400/400 [00:06<00:00, 65.48it/s]
Train Loss: 797.346, Test Loss: 797.345: 100%|██████████| 400/400 [00:34<00:00, 11.53it/s]


n =1, rep 6 done
tmscore = -1952.3804931640625, shrinkscore = -4037.850341796875


Train Loss: 1148.936, Test Loss: 1148.936: 100%|██████████| 400/400 [00:05<00:00, 66.83it/s]
Train Loss: 764.152, Test Loss: 764.152: 100%|██████████| 400/400 [00:33<00:00, 11.95it/s]


n =1, rep 7 done
tmscore = -1973.1871337890625, shrinkscore = -10120.32421875


Train Loss: 1430.407, Test Loss: 1430.407: 100%|██████████| 400/400 [00:06<00:00, 66.19it/s]
Train Loss: 805.278, Test Loss: 805.278: 100%|██████████| 400/400 [00:33<00:00, 12.02it/s]


n =1, rep 8 done
tmscore = -1896.8994140625, shrinkscore = -8016.3486328125


Train Loss: 1198.145, Test Loss: 1198.145: 100%|██████████| 400/400 [00:06<00:00, 65.57it/s]
Train Loss: 799.638, Test Loss: 799.638: 100%|██████████| 400/400 [00:33<00:00, 11.94it/s]


n =1, rep 9 done
tmscore = -1953.1220703125, shrinkscore = -2324.190673828125


Train Loss: 2320.613, Test Loss: 2320.613: 100%|██████████| 400/400 [00:07<00:00, 53.28it/s]
Train Loss: 1603.710, Test Loss: 1603.710: 100%|██████████| 400/400 [00:34<00:00, 11.67it/s]


n =2, rep 0 done
tmscore = -1899.1015625, shrinkscore = -1603.8526611328125


Train Loss: 2392.853, Test Loss: 2392.853: 100%|██████████| 400/400 [00:07<00:00, 52.88it/s]
Train Loss: 1588.710, Test Loss: 1588.710: 100%|██████████| 400/400 [00:34<00:00, 11.70it/s]


n =2, rep 1 done
tmscore = -1865.994873046875, shrinkscore = -1608.8367919921875


Train Loss: 2344.110, Test Loss: 2344.111: 100%|██████████| 400/400 [00:07<00:00, 52.88it/s]
Train Loss: 1557.017, Test Loss: 1557.017: 100%|██████████| 400/400 [00:34<00:00, 11.49it/s]


n =2, rep 2 done
tmscore = -1867.0574951171875, shrinkscore = -1608.5799560546875


Train Loss: 2208.547, Test Loss: 2208.547: 100%|██████████| 400/400 [00:07<00:00, 53.48it/s]
Train Loss: 1548.098, Test Loss: 1548.098: 100%|██████████| 400/400 [00:34<00:00, 11.74it/s]


n =2, rep 3 done
tmscore = -1911.2392578125, shrinkscore = -1608.9715576171875


Train Loss: 2270.717, Test Loss: 2270.717: 100%|██████████| 400/400 [00:07<00:00, 54.30it/s]
Train Loss: 1570.854, Test Loss: 1570.854: 100%|██████████| 400/400 [00:34<00:00, 11.64it/s]


n =2, rep 4 done
tmscore = -1911.545654296875, shrinkscore = -1604.018310546875


Train Loss: 2387.691, Test Loss: 2387.691: 100%|██████████| 400/400 [00:07<00:00, 52.96it/s]
Train Loss: 1577.678, Test Loss: 1577.678: 100%|██████████| 400/400 [00:34<00:00, 11.75it/s]


n =2, rep 5 done
tmscore = -1862.05078125, shrinkscore = -1604.736328125


Train Loss: 2319.784, Test Loss: 2319.784: 100%|██████████| 400/400 [00:07<00:00, 52.91it/s]
Train Loss: 1573.851, Test Loss: 1573.852: 100%|██████████| 400/400 [00:33<00:00, 11.81it/s]


n =2, rep 6 done
tmscore = -1880.7691650390625, shrinkscore = -1608.9412841796875


Train Loss: 2279.977, Test Loss: 2279.977: 100%|██████████| 400/400 [00:07<00:00, 52.47it/s]
Train Loss: 1556.530, Test Loss: 1556.531: 100%|██████████| 400/400 [00:34<00:00, 11.58it/s]


n =2, rep 7 done
tmscore = -1883.5556640625, shrinkscore = -1609.3983154296875


Train Loss: 2356.992, Test Loss: 2356.992: 100%|██████████| 400/400 [00:07<00:00, 53.01it/s]
Train Loss: 1561.664, Test Loss: 1561.667: 100%|██████████| 400/400 [00:34<00:00, 11.72it/s]


n =2, rep 8 done
tmscore = -1874.176025390625, shrinkscore = -1603.3316650390625


Train Loss: 2370.829, Test Loss: 2370.829: 100%|██████████| 400/400 [00:07<00:00, 52.08it/s]
Train Loss: 1583.008, Test Loss: 1583.010: 100%|██████████| 400/400 [00:34<00:00, 11.58it/s]


n =2, rep 9 done
tmscore = -1883.96337890625, shrinkscore = -1611.78271484375


Train Loss: 5431.844, Test Loss: 5431.844: 100%|██████████| 400/400 [00:09<00:00, 44.40it/s]
Train Loss: 3878.462, Test Loss: 3878.458: 100%|██████████| 400/400 [00:35<00:00, 11.39it/s]


n =5, rep 0 done
tmscore = -1766.433349609375, shrinkscore = -1584.97216796875


Train Loss: 5393.869, Test Loss: 5393.870: 100%|██████████| 400/400 [00:09<00:00, 44.38it/s]
Train Loss: 3932.258, Test Loss: 3932.258: 100%|██████████| 400/400 [00:35<00:00, 11.27it/s]


n =5, rep 1 done
tmscore = -1774.6661376953125, shrinkscore = -1581.7413330078125


Train Loss: 5295.719, Test Loss: 5295.719: 100%|██████████| 400/400 [00:09<00:00, 43.71it/s]
Train Loss: 3879.283, Test Loss: 3879.280: 100%|██████████| 400/400 [00:35<00:00, 11.17it/s]


n =5, rep 2 done
tmscore = -1788.5936279296875, shrinkscore = -1583.3905029296875


Train Loss: 5363.590, Test Loss: 5363.591: 100%|██████████| 400/400 [00:13<00:00, 30.02it/s]
Train Loss: 3889.033, Test Loss: 3889.034: 100%|██████████| 400/400 [00:35<00:00, 11.24it/s]


n =5, rep 3 done
tmscore = -1777.4686279296875, shrinkscore = -1589.384765625


Train Loss: 5229.076, Test Loss: 5229.075: 100%|██████████| 400/400 [00:09<00:00, 44.38it/s]
Train Loss: 3831.374, Test Loss: 3831.374: 100%|██████████| 400/400 [00:35<00:00, 11.25it/s]


n =5, rep 4 done
tmscore = -1791.8006591796875, shrinkscore = -1579.2601318359375


Train Loss: 5352.429, Test Loss: 5352.429: 100%|██████████| 400/400 [00:09<00:00, 42.31it/s]
Train Loss: 3834.800, Test Loss: 3834.794: 100%|██████████| 400/400 [00:38<00:00, 10.43it/s]


n =5, rep 5 done
tmscore = -1767.6806640625, shrinkscore = -1584.99072265625


Train Loss: 5395.892, Test Loss: 5395.893: 100%|██████████| 400/400 [00:09<00:00, 40.97it/s]
Train Loss: 3889.267, Test Loss: 3889.267: 100%|██████████| 400/400 [00:37<00:00, 10.65it/s]


n =5, rep 6 done
tmscore = -1776.31201171875, shrinkscore = -1583.2235107421875


Train Loss: 5470.558, Test Loss: 5470.558: 100%|██████████| 400/400 [00:09<00:00, 43.45it/s]
Train Loss: 3895.787, Test Loss: 3895.793: 100%|██████████| 400/400 [00:35<00:00, 11.19it/s]


n =5, rep 7 done
tmscore = -1770.6029052734375, shrinkscore = -1591.340576171875


Train Loss: 5378.980, Test Loss: 5378.979: 100%|██████████| 400/400 [00:09<00:00, 44.29it/s]
Train Loss: 3912.529, Test Loss: 3912.529: 100%|██████████| 400/400 [00:35<00:00, 11.34it/s]


n =5, rep 8 done
tmscore = -1775.8067626953125, shrinkscore = -1583.2581787109375


Train Loss: 5352.320, Test Loss: 5352.320: 100%|██████████| 400/400 [00:08<00:00, 44.68it/s]
Train Loss: 3888.744, Test Loss: 3888.749: 100%|██████████| 400/400 [00:35<00:00, 11.39it/s]


n =5, rep 9 done
tmscore = -1769.9412841796875, shrinkscore = -1582.51123046875


Train Loss: 10022.440, Test Loss: 10022.438: 100%|██████████| 400/400 [00:11<00:00, 35.76it/s]
Train Loss: 7606.780, Test Loss: 7606.779: 100%|██████████| 400/400 [00:37<00:00, 10.68it/s]


n =10, rep 0 done
tmscore = -1694.2022705078125, shrinkscore = -1544.863525390625


Train Loss: 9937.213, Test Loss: 9937.213: 100%|██████████| 400/400 [00:11<00:00, 35.37it/s]
Train Loss: 7633.507, Test Loss: 7633.508: 100%|██████████| 400/400 [00:37<00:00, 10.69it/s]


n =10, rep 1 done
tmscore = -1705.702392578125, shrinkscore = -1544.53125


Train Loss: 10018.269, Test Loss: 10018.268: 100%|██████████| 400/400 [00:11<00:00, 35.56it/s]
Train Loss: 7637.903, Test Loss: 7637.901: 100%|██████████| 400/400 [00:38<00:00, 10.34it/s]


n =10, rep 2 done
tmscore = -1700.9046630859375, shrinkscore = -1545.0428466796875


Train Loss: 9863.574, Test Loss: 9863.574: 100%|██████████| 400/400 [00:11<00:00, 35.84it/s]
Train Loss: 7575.976, Test Loss: 7575.977: 100%|██████████| 400/400 [00:37<00:00, 10.78it/s]


n =10, rep 3 done
tmscore = -1705.173583984375, shrinkscore = -1542.4219970703125


Train Loss: 10136.759, Test Loss: 10136.759: 100%|██████████| 400/400 [00:11<00:00, 35.56it/s]
Train Loss: 7709.855, Test Loss: 7709.854: 100%|██████████| 400/400 [00:38<00:00, 10.52it/s]


n =10, rep 4 done
tmscore = -1693.2342529296875, shrinkscore = -1548.7978515625


Train Loss: 9932.921, Test Loss: 9932.922: 100%|██████████| 400/400 [00:11<00:00, 35.49it/s] 
Train Loss: 7510.034, Test Loss: 7510.033: 100%|██████████| 400/400 [00:37<00:00, 10.59it/s]


n =10, rep 5 done
tmscore = -1701.89599609375, shrinkscore = -1549.278564453125


Train Loss: 10050.332, Test Loss: 10050.331: 100%|██████████| 400/400 [00:11<00:00, 35.77it/s]
Train Loss: 7634.528, Test Loss: 7634.526: 100%|██████████| 400/400 [00:37<00:00, 10.78it/s]


n =10, rep 6 done
tmscore = -1696.5167236328125, shrinkscore = -1545.9456787109375


Train Loss: 10146.360, Test Loss: 10146.359: 100%|██████████| 400/400 [00:10<00:00, 36.42it/s]
Train Loss: 7677.694, Test Loss: 7677.696: 100%|██████████| 400/400 [00:41<00:00,  9.61it/s]


n =10, rep 7 done
tmscore = -1691.8944091796875, shrinkscore = -1548.076904296875


Train Loss: 9933.821, Test Loss: 9933.821: 100%|██████████| 400/400 [00:12<00:00, 31.75it/s] 
Train Loss: 7575.573, Test Loss: 7575.574: 100%|██████████| 400/400 [00:39<00:00, 10.24it/s]


n =10, rep 8 done
tmscore = -1701.80810546875, shrinkscore = -1545.738525390625


Train Loss: 9904.328, Test Loss: 9904.328: 100%|██████████| 400/400 [00:11<00:00, 35.41it/s]
Train Loss: 7582.192, Test Loss: 7582.192: 100%|██████████| 400/400 [00:42<00:00,  9.44it/s]


n =10, rep 9 done
tmscore = -1709.2147216796875, shrinkscore = -1549.6619873046875


Train Loss: 18031.753, Test Loss: 18031.754: 100%|██████████| 400/400 [00:15<00:00, 25.91it/s]
Train Loss: 14178.088, Test Loss: 14178.088: 100%|██████████| 400/400 [00:42<00:00,  9.36it/s]


n =20, rep 0 done
tmscore = -1599.8319091796875, shrinkscore = -1460.966064453125


Train Loss: 18519.215, Test Loss: 18519.215: 100%|██████████| 400/400 [00:16<00:00, 24.27it/s]
Train Loss: 14670.466, Test Loss: 14670.459: 100%|██████████| 400/400 [00:54<00:00,  7.39it/s]


n =20, rep 1 done
tmscore = -1615.2841796875, shrinkscore = -1472.8048095703125


Train Loss: 18455.984, Test Loss: 18455.984: 100%|██████████| 400/400 [00:16<00:00, 23.58it/s]
Train Loss: 14530.826, Test Loss: 14530.828: 100%|██████████| 400/400 [00:45<00:00,  8.83it/s]


n =20, rep 2 done
tmscore = -1608.81494140625, shrinkscore = -1471.1944580078125


Train Loss: 18324.880, Test Loss: 18324.879: 100%|██████████| 400/400 [00:16<00:00, 23.84it/s]
Train Loss: 14502.739, Test Loss: 14502.738: 100%|██████████| 400/400 [00:46<00:00,  8.61it/s]


n =20, rep 3 done
tmscore = -1607.293701171875, shrinkscore = -1466.4951171875


Train Loss: 18057.696, Test Loss: 18057.695: 100%|██████████| 400/400 [00:16<00:00, 24.49it/s]
Train Loss: 14224.105, Test Loss: 14224.101: 100%|██████████| 400/400 [00:49<00:00,  8.09it/s]


n =20, rep 4 done
tmscore = -1603.345947265625, shrinkscore = -1464.1221923828125


Train Loss: 18018.732, Test Loss: 18018.730: 100%|██████████| 400/400 [00:16<00:00, 24.74it/s]
Train Loss: 14234.865, Test Loss: 14234.866: 100%|██████████| 400/400 [00:45<00:00,  8.77it/s]


n =20, rep 5 done
tmscore = -1597.7862548828125, shrinkscore = -1453.5028076171875


Train Loss: 18054.746, Test Loss: 18054.746: 100%|██████████| 400/400 [00:16<00:00, 24.11it/s]
Train Loss: 14255.874, Test Loss: 14255.873: 100%|██████████| 400/400 [00:44<00:00,  8.94it/s]


n =20, rep 6 done
tmscore = -1606.1392822265625, shrinkscore = -1467.046142578125


Train Loss: 18260.354, Test Loss: 18260.355: 100%|██████████| 400/400 [00:16<00:00, 23.66it/s]
Train Loss: 14356.783, Test Loss: 14356.783: 100%|██████████| 400/400 [00:45<00:00,  8.71it/s]


n =20, rep 7 done
tmscore = -1608.364501953125, shrinkscore = -1472.2908935546875


Train Loss: 18076.564, Test Loss: 18076.566: 100%|██████████| 400/400 [00:16<00:00, 23.92it/s]
Train Loss: 14242.535, Test Loss: 14242.531: 100%|██████████| 400/400 [00:45<00:00,  8.78it/s]


n =20, rep 8 done
tmscore = -1602.3448486328125, shrinkscore = -1462.5814208984375


Train Loss: 18271.083, Test Loss: 18271.086: 100%|██████████| 400/400 [00:16<00:00, 23.82it/s]
Train Loss: 14422.457, Test Loss: 14422.457: 100%|██████████| 400/400 [00:45<00:00,  8.75it/s]


n =20, rep 9 done
tmscore = -1604.5, shrinkscore = -1468.582763671875


Train Loss: 25598.357, Test Loss: 25598.357: 100%|██████████| 400/400 [00:24<00:00, 16.24it/s]
Train Loss: 20551.760, Test Loss: 20551.766: 100%|██████████| 400/400 [00:56<00:00,  7.11it/s]


n =30, rep 0 done
tmscore = -1489.82275390625, shrinkscore = -1382.38623046875


Train Loss: 25579.109, Test Loss: 25579.107: 100%|██████████| 400/400 [00:23<00:00, 16.96it/s]
Train Loss: 20501.487, Test Loss: 20501.486: 100%|██████████| 400/400 [00:51<00:00,  7.75it/s]


n =30, rep 1 done
tmscore = -1504.5675048828125, shrinkscore = -1392.9783935546875


Train Loss: 25462.237, Test Loss: 25462.238: 100%|██████████| 400/400 [00:22<00:00, 17.63it/s]
Train Loss: 20344.243, Test Loss: 20344.246: 100%|██████████| 400/400 [00:51<00:00,  7.82it/s]


n =30, rep 2 done
tmscore = -1512.6856689453125, shrinkscore = -1395.0498046875


Train Loss: 25561.378, Test Loss: 25561.379: 100%|██████████| 400/400 [00:22<00:00, 17.50it/s]
Train Loss: 20415.739, Test Loss: 20415.744: 100%|██████████| 400/400 [00:51<00:00,  7.80it/s]


n =30, rep 3 done
tmscore = -1499.3192138671875, shrinkscore = -1393.3897705078125


Train Loss: 25417.518, Test Loss: 25417.516: 100%|██████████| 400/400 [00:23<00:00, 17.39it/s]
Train Loss: 20353.555, Test Loss: 20353.543: 100%|██████████| 400/400 [00:51<00:00,  7.78it/s]


n =30, rep 4 done
tmscore = -1492.8037109375, shrinkscore = -1381.4080810546875


Train Loss: 25451.586, Test Loss: 25451.586: 100%|██████████| 400/400 [00:22<00:00, 17.80it/s]
Train Loss: 20329.571, Test Loss: 20329.570: 100%|██████████| 400/400 [00:52<00:00,  7.68it/s]


n =30, rep 5 done
tmscore = -1498.0682373046875, shrinkscore = -1385.6807861328125


Train Loss: 25757.464, Test Loss: 25757.463: 100%|██████████| 400/400 [00:22<00:00, 17.60it/s]
Train Loss: 20575.268, Test Loss: 20575.270: 100%|██████████| 400/400 [00:51<00:00,  7.72it/s]


n =30, rep 6 done
tmscore = -1497.7872314453125, shrinkscore = -1381.810791015625


Train Loss: 25634.220, Test Loss: 25634.219: 100%|██████████| 400/400 [00:22<00:00, 17.52it/s]
Train Loss: 20470.124, Test Loss: 20470.125: 100%|██████████| 400/400 [00:51<00:00,  7.81it/s]


n =30, rep 7 done
tmscore = -1495.78662109375, shrinkscore = -1383.118408203125


Train Loss: 25572.911, Test Loss: 25572.908: 100%|██████████| 400/400 [00:23<00:00, 17.04it/s]
Train Loss: 20534.539, Test Loss: 20534.535: 100%|██████████| 400/400 [00:51<00:00,  7.75it/s]


n =30, rep 8 done
tmscore = -1507.7801513671875, shrinkscore = -1390.436279296875


Train Loss: 25650.904, Test Loss: 25650.904: 100%|██████████| 400/400 [00:22<00:00, 17.54it/s]
Train Loss: 20607.061, Test Loss: 20607.059: 100%|██████████| 400/400 [00:51<00:00,  7.75it/s]


n =30, rep 9 done
tmscore = -1512.4241943359375, shrinkscore = -1399.6514892578125


Train Loss: 37772.676, Test Loss: 37772.676: 100%|██████████| 400/400 [00:52<00:00,  7.69it/s]
Train Loss: 30666.699, Test Loss: 30666.693: 100%|██████████| 400/400 [01:20<00:00,  4.95it/s]


n =50, rep 0 done
tmscore = -1325.4990234375, shrinkscore = -1249.157958984375


Train Loss: 37475.875, Test Loss: 37475.875: 100%|██████████| 400/400 [00:58<00:00,  6.82it/s]
Train Loss: 30464.654, Test Loss: 30464.650: 100%|██████████| 400/400 [01:28<00:00,  4.54it/s]


n =50, rep 1 done
tmscore = -1330.1441650390625, shrinkscore = -1256.395263671875


Train Loss: 37067.039, Test Loss: 37067.039: 100%|██████████| 400/400 [00:55<00:00,  7.17it/s]
Train Loss: 30139.273, Test Loss: 30139.264: 100%|██████████| 400/400 [01:25<00:00,  4.68it/s]


n =50, rep 2 done
tmscore = -1333.3624267578125, shrinkscore = -1254.6654052734375


Train Loss: 37395.928, Test Loss: 37395.930: 100%|██████████| 400/400 [00:56<00:00,  7.14it/s]
Train Loss: 30395.842, Test Loss: 30395.842: 100%|██████████| 400/400 [01:23<00:00,  4.80it/s]


n =50, rep 3 done
tmscore = -1316.7474365234375, shrinkscore = -1240.0615234375


Train Loss: 37349.493, Test Loss: 37349.500: 100%|██████████| 400/400 [00:52<00:00,  7.58it/s]
Train Loss: 30421.442, Test Loss: 30421.443: 100%|██████████| 400/400 [01:23<00:00,  4.79it/s]


n =50, rep 4 done
tmscore = -1328.3739013671875, shrinkscore = -1250.9864501953125


Train Loss: 37024.092, Test Loss: 37024.094: 100%|██████████| 400/400 [00:52<00:00,  7.56it/s]
Train Loss: 30032.129, Test Loss: 30032.129: 100%|██████████| 400/400 [01:22<00:00,  4.85it/s]


n =50, rep 5 done
tmscore = -1326.6346435546875, shrinkscore = -1247.4381103515625


Train Loss: 37139.301, Test Loss: 37139.301: 100%|██████████| 400/400 [00:52<00:00,  7.67it/s]
Train Loss: 30096.986, Test Loss: 30096.990: 100%|██████████| 400/400 [01:21<00:00,  4.89it/s]


n =50, rep 6 done
tmscore = -1317.6939697265625, shrinkscore = -1242.09130859375


Train Loss: 36718.840, Test Loss: 36718.836: 100%|██████████| 400/400 [00:52<00:00,  7.68it/s]
Train Loss: 29763.833, Test Loss: 29763.834: 100%|██████████| 400/400 [01:22<00:00,  4.86it/s]


n =50, rep 7 done
tmscore = -1315.208984375, shrinkscore = -1237.8804931640625


Train Loss: 37770.628, Test Loss: 37770.625: 100%|██████████| 400/400 [00:53<00:00,  7.44it/s]
Train Loss: 30741.844, Test Loss: 30741.846: 100%|██████████| 400/400 [01:23<00:00,  4.79it/s]


n =50, rep 8 done
tmscore = -1330.155517578125, shrinkscore = -1251.0748291015625


Train Loss: 36779.208, Test Loss: 36779.211: 100%|██████████| 400/400 [00:54<00:00,  7.37it/s]
Train Loss: 29939.000, Test Loss: 29938.996: 100%|██████████| 400/400 [01:27<00:00,  4.56it/s]


n =50, rep 9 done
tmscore = -1335.5052490234375, shrinkscore = -1254.8992919921875


Train Loss: 48994.021, Test Loss: 48994.027: 100%|██████████| 400/400 [02:00<00:00,  3.32it/s]
Train Loss: 39978.466, Test Loss: 39978.457: 100%|██████████| 400/400 [02:32<00:00,  2.63it/s]


n =80, rep 0 done
tmscore = -1147.5953369140625, shrinkscore = -1077.3245849609375


Train Loss: 49131.638, Test Loss: 49131.641: 100%|██████████| 400/400 [01:55<00:00,  3.46it/s]
Train Loss: 40040.146, Test Loss: 40040.152: 100%|██████████| 400/400 [02:30<00:00,  2.65it/s]


n =80, rep 1 done
tmscore = -1140.6610107421875, shrinkscore = -1075.501708984375


Train Loss: 50054.272, Test Loss: 50054.277: 100%|██████████| 400/400 [01:57<00:00,  3.40it/s]
Train Loss: 40838.156, Test Loss: 40838.156: 100%|██████████| 400/400 [02:28<00:00,  2.69it/s]


n =80, rep 2 done
tmscore = -1136.61572265625, shrinkscore = -1073.5975341796875


Train Loss: 49024.402, Test Loss: 49024.402: 100%|██████████| 400/400 [01:57<00:00,  3.42it/s]
Train Loss: 39910.669, Test Loss: 39910.672: 100%|██████████| 400/400 [02:28<00:00,  2.69it/s]


n =80, rep 3 done
tmscore = -1143.0423583984375, shrinkscore = -1074.656005859375


Train Loss: 49318.961, Test Loss: 49318.965: 100%|██████████| 400/400 [02:00<00:00,  3.33it/s]
Train Loss: 40262.741, Test Loss: 40262.746: 100%|██████████| 400/400 [02:31<00:00,  2.64it/s]


n =80, rep 4 done
tmscore = -1137.1361083984375, shrinkscore = -1071.521728515625


Train Loss: 49161.534, Test Loss: 49161.527: 100%|██████████| 400/400 [02:00<00:00,  3.33it/s]
Train Loss: 40111.980, Test Loss: 40111.984: 100%|██████████| 400/400 [02:32<00:00,  2.62it/s]


n =80, rep 5 done
tmscore = -1156.203857421875, shrinkscore = -1091.0806884765625


Train Loss: 48888.611, Test Loss: 48888.609: 100%|██████████| 400/400 [01:55<00:00,  3.47it/s]
Train Loss: 39797.958, Test Loss: 39797.953: 100%|██████████| 400/400 [02:31<00:00,  2.63it/s]


n =80, rep 6 done
tmscore = -1133.6881103515625, shrinkscore = -1067.6756591796875


Train Loss: 49482.600, Test Loss: 49482.602: 100%|██████████| 400/400 [01:55<00:00,  3.48it/s]
Train Loss: 40277.809, Test Loss: 40277.805: 100%|██████████| 400/400 [02:30<00:00,  2.67it/s]


n =80, rep 7 done
tmscore = -1140.6004638671875, shrinkscore = -1075.2900390625


Train Loss: 48886.845, Test Loss: 48886.840: 100%|██████████| 400/400 [01:56<00:00,  3.43it/s]
Train Loss: 39834.694, Test Loss: 39834.695: 100%|██████████| 400/400 [02:29<00:00,  2.68it/s]


n =80, rep 8 done
tmscore = -1138.98828125, shrinkscore = -1074.05712890625


Train Loss: 48892.861, Test Loss: 48892.855: 100%|██████████| 400/400 [01:59<00:00,  3.35it/s]
Train Loss: 39686.714, Test Loss: 39686.715: 100%|██████████| 400/400 [02:31<00:00,  2.64it/s]


n =80, rep 9 done
tmscore = -1141.988037109375, shrinkscore = -1075.772216796875


Train Loss: 64196.827, Test Loss: 64196.824: 100%|██████████| 400/400 [06:41<00:00,  1.00s/it]
Train Loss: 49738.547, Test Loss: 49738.543: 100%|██████████| 400/400 [07:24<00:00,  1.11s/it]


n =160, rep 0 done
tmscore = -899.3291625976562, shrinkscore = -840.2755737304688


Train Loss: 63819.691, Test Loss: 63819.680: 100%|██████████| 400/400 [06:37<00:00,  1.01it/s]
Train Loss: 49472.042, Test Loss: 49472.047: 100%|██████████| 400/400 [07:24<00:00,  1.11s/it]


n =160, rep 1 done
tmscore = -911.7537231445312, shrinkscore = -851.6375122070312


Train Loss: 63381.926, Test Loss: 63381.922: 100%|██████████| 400/400 [06:32<00:00,  1.02it/s]
Train Loss: 49021.733, Test Loss: 49021.734: 100%|██████████| 400/400 [07:49<00:00,  1.17s/it]


n =160, rep 2 done
tmscore = -895.9644165039062, shrinkscore = -836.5612182617188


Train Loss: 63380.457, Test Loss: 63380.453: 100%|██████████| 400/400 [06:31<00:00,  1.02it/s]
Train Loss: 48943.169, Test Loss: 48943.172: 100%|██████████| 400/400 [08:13<00:00,  1.23s/it]


n =160, rep 3 done
tmscore = -918.4982299804688, shrinkscore = -857.136474609375


Train Loss: 63782.191, Test Loss: 63782.188: 100%|██████████| 400/400 [07:47<00:00,  1.17s/it]
Train Loss: 49425.646, Test Loss: 49425.648: 100%|██████████| 400/400 [08:23<00:00,  1.26s/it]


n =160, rep 4 done
tmscore = -912.4564819335938, shrinkscore = -851.0782470703125


Train Loss: 63693.185, Test Loss: 63693.184: 100%|██████████| 400/400 [07:11<00:00,  1.08s/it]
Train Loss: 49305.405, Test Loss: 49305.402: 100%|██████████| 400/400 [09:42<00:00,  1.46s/it]


n =160, rep 5 done
tmscore = -900.166259765625, shrinkscore = -839.1046142578125


Train Loss: 64222.130, Test Loss: 64222.129: 100%|██████████| 400/400 [08:06<00:00,  1.22s/it]
Train Loss: 49893.178, Test Loss: 49893.180: 100%|██████████| 400/400 [08:22<00:00,  1.26s/it]


n =160, rep 6 done
tmscore = -902.2489624023438, shrinkscore = -841.3084106445312


Train Loss: 62565.639, Test Loss: 62565.641: 100%|██████████| 400/400 [06:56<00:00,  1.04s/it]
Train Loss: 48296.724, Test Loss: 48296.727: 100%|██████████| 400/400 [07:30<00:00,  1.13s/it]


n =160, rep 7 done
tmscore = -917.39892578125, shrinkscore = -855.1844482421875


Train Loss: 63467.828, Test Loss: 63467.820: 100%|██████████| 400/400 [07:55<00:00,  1.19s/it]
Train Loss: 49120.059, Test Loss: 49120.074: 100%|██████████| 400/400 [07:34<00:00,  1.14s/it]


n =160, rep 8 done
tmscore = -901.1205444335938, shrinkscore = -839.5804443359375


Train Loss: 63971.374, Test Loss: 63971.371: 100%|██████████| 400/400 [07:04<00:00,  1.06s/it]
Train Loss: 49531.233, Test Loss: 49531.234: 100%|██████████| 400/400 [07:58<00:00,  1.20s/it]


n =160, rep 9 done
tmscore = -889.491943359375, shrinkscore = -828.227294921875


Train Loss: 65598.717, Test Loss: 65598.719: 100%|██████████| 400/400 [12:45<00:00,  1.91s/it]
Train Loss: 48615.229, Test Loss: 48615.227: 100%|██████████| 400/400 [13:21<00:00,  2.00s/it]


n =200, rep 0 done
tmscore = -844.0306396484375, shrinkscore = -785.5910034179688


Train Loss: 65453.957, Test Loss: 65453.953: 100%|██████████| 400/400 [11:29<00:00,  1.72s/it]
Train Loss: 48505.518, Test Loss: 48505.516: 100%|██████████| 400/400 [12:34<00:00,  1.89s/it]


n =200, rep 1 done
tmscore = -855.439453125, shrinkscore = -798.2098388671875


Train Loss: 65663.208, Test Loss: 65663.211: 100%|██████████| 400/400 [11:22<00:00,  1.71s/it]
Train Loss: 48752.936, Test Loss: 48752.949: 100%|██████████| 400/400 [12:40<00:00,  1.90s/it]


n =200, rep 2 done
tmscore = -850.2902221679688, shrinkscore = -792.2962036132812


Train Loss: 66089.812, Test Loss: 66089.812: 100%|██████████| 400/400 [11:40<00:00,  1.75s/it]
Train Loss: 49102.777, Test Loss: 49102.781: 100%|██████████| 400/400 [12:30<00:00,  1.88s/it]


n =200, rep 3 done
tmscore = -837.4471435546875, shrinkscore = -780.4719848632812


Train Loss: 65118.223, Test Loss: 65118.227: 100%|██████████| 400/400 [11:06<00:00,  1.67s/it]
Train Loss: 48145.826, Test Loss: 48145.832: 100%|██████████| 400/400 [12:26<00:00,  1.87s/it]


n =200, rep 4 done
tmscore = -856.7238159179688, shrinkscore = -799.0363159179688


Train Loss: 65133.762, Test Loss: 65133.777: 100%|██████████| 400/400 [11:40<00:00,  1.75s/it]
Train Loss: 48282.595, Test Loss: 48282.605: 100%|██████████| 400/400 [12:39<00:00,  1.90s/it]


n =200, rep 5 done
tmscore = -843.465576171875, shrinkscore = -783.4490356445312


Train Loss: 65531.885, Test Loss: 65531.891: 100%|██████████| 400/400 [11:26<00:00,  1.72s/it]
Train Loss: 48536.232, Test Loss: 48536.230: 100%|██████████| 400/400 [12:47<00:00,  1.92s/it]


n =200, rep 6 done
tmscore = -830.169677734375, shrinkscore = -774.0859985351562


Train Loss: 65766.617, Test Loss: 65766.609: 100%|██████████| 400/400 [11:39<00:00,  1.75s/it]
Train Loss: 48947.441, Test Loss: 48947.449: 100%|██████████| 400/400 [12:51<00:00,  1.93s/it]


n =200, rep 7 done
tmscore = -852.1978149414062, shrinkscore = -793.5626831054688


Train Loss: 65612.867, Test Loss: 65612.859: 100%|██████████| 400/400 [11:29<00:00,  1.72s/it]
Train Loss: 48647.128, Test Loss: 48647.129: 100%|██████████| 400/400 [12:48<00:00,  1.92s/it]


n =200, rep 8 done
tmscore = -847.1380615234375, shrinkscore = -790.2738037109375


Train Loss: 65556.092, Test Loss: 65556.102: 100%|██████████| 400/400 [11:31<00:00,  1.73s/it]
Train Loss: 48551.624, Test Loss: 48551.629: 100%|██████████| 400/400 [12:54<00:00,  1.94s/it]


n =200, rep 9 done
tmscore = -848.4058227539062, shrinkscore = -790.2371215820312


In [8]:
# Persist the fitted models and log-score tables.
# round() rather than int() when encoding parameters into the filename: int() truncates,
# e.g. int(100 * 0.29) == 28 because 100 * 0.29 == 28.999999999999996.
# NOTE: models are pickled as whole objects, so reloading them requires the same class
# names to still exist in the defining modules.
results_path = (
    f"../results/modelsNR_LST{round(100 * length_scale_original)}"
    f"_SQT{round(100 * sigmasq_f)}.pt"
)
torch.save({
    "gp_generator": gp,
    "tm_models": tm_models,
    "shrink_models": shrink_models,
    # detach(): store plain score tensors without any autograd history
    "tm_logscore": logScore_tm.detach(),
    "shrink_logscore": logScore_shrink.detach(),
    "numSamples": numSamples,
}, results_path)

In [9]:
# Reload saved results, using the same filename scheme as the save cell.
# weights_only=False: the file holds pickled model objects, not just tensors, and
# PyTorch >= 2.6 defaults to weights_only=True. Only do this for trusted files, since
# unpickling can execute arbitrary code. An AttributeError here means a pickled class
# was renamed after the file was written; re-fit or re-expose the old class name.
results_path = (
    f"../results/modelsNR_LST{round(100 * length_scale_original)}"
    f"_SQT{round(100 * sigmasq_f)}.pt"
)
models_history = torch.load(results_path, weights_only=False)
shrink_models = models_history["shrink_models"]


AttributeError: Can't get attribute 'EstimableShrinkTMRefactor' on <module 'batram.shrinkmods' from '/mnt/c/Projects/batram-shrink2param/src/batram/shrinkmods.py'>

In [None]:
# Training-set sizes and repetitions per size. Models were saved in
# (training-size, repetition) order, so their count must divide evenly by len(Ns).
Ns = models_history["numSamples"]
reps = len(shrink_models) // len(Ns)
assert reps * len(Ns) == len(shrink_models), (
    f"{len(shrink_models)} models cannot be split evenly over {len(Ns)} training sizes"
)

In [None]:
# Inspect one fitted shrinkage model. Models are stored in (training-size, repetition)
# order, so flat index k belongs to training size Ns[k // reps], repetition k % reps.
inspect_idx = 29
print(f"model {inspect_idx}: n = {Ns[inspect_idx // reps]}, rep {inspect_idx % reps}")
_tmp = shrink_models[inspect_idx]

In [None]:
# Show exp(nugget_shrinkage_factor) of the inspected model, computed without autograd.
with torch.no_grad():
    nugget_factor = _tmp.nugget_shrinkage_factor.exp()
print(nugget_factor)

In [None]:
# Table of exp(nugget_shrinkage_factor), shape (len(Ns), reps). Models are stored in
# (training-size, repetition) order, so model row * reps + col is size Ns[row], rep col.
with torch.no_grad():
    nugget_shrink_factors = torch.tensor(
        [
            [shrink_models[row * reps + col].nugget_shrinkage_factor.exp().item() for col in range(reps)]
            for row in range(len(Ns))
        ]
    ).reshape(len(Ns), reps)

# Per-training-size summaries across repetitions.
mean_nug_shrink = nugget_shrink_factors.mean(dim=1)
median_nug_shrink = nugget_shrink_factors.median(dim=1)[0]
max_nug_shrink = nugget_shrink_factors.max(dim=1)[0]
min_nug_shrink = nugget_shrink_factors.min(dim=1)[0]

In [None]:
# Nugget shrinkage factor vs. training-set size: the line is the mean over repetitions
# and the shaded band spans the min-max over repetitions.
size_pos = range(len(Ns))  # evenly spaced positions; tick labels show the actual sizes
fig, ax = plt.subplots()
ax.plot(size_pos, mean_nug_shrink, label="mean over repetitions")
ax.fill_between(size_pos, min_nug_shrink, max_nug_shrink, alpha=0.3,
                label="min-max over repetitions")
ax.set_xticks(list(size_pos))
ax.set_xticklabels(Ns)
ax.set(
    xlabel="Size of training data",
    ylabel="exp(nugget_shrinkage_factor)",
    title="Estimated nugget shrinkage factor vs. training size",
)
ax.legend();