In [1]:
%load_ext autoreload
%autoreload 2
import pinot
import torch
import numpy as np
from itertools import chain


Using backend: pytorch


In [2]:
ds = pinot.data.moonshot_meta()

In [3]:
ds_tr, ds_te = pinot.data.utils.split(ds, [4, 1])

In [4]:
def get_separate_dataset(ds):
    n_tasks = ds[0][1].shape[-1]
    datasets = [[] for _ in range(n_tasks)]
    
    for g, y in ds:
        for idx in range(n_tasks):
            if np.isnan(y[idx].numpy()) == False:
                datasets[idx].append((g, y[idx][None]))
    
    return datasets


In [20]:
datasets = get_separate_dataset(ds_tr)
datasets = [pinot.data.utils.batch(data, len(data)) for data in datasets]

datasets_te = get_separate_dataset(ds_te)
datasets_te = [pinot.data.utils.batch(data, len(data)) for data in datasets_te]

In [23]:
net = pinot.representation.Sequential(
    pinot.representation.dgl_legacy.gn(),
    [32, 'tanh', 32, 'tanh', 32, 'tanh'])

In [24]:
gps = []
for idx in range(len(datasets)):
    model = pinot.inference.gp.gpr.exact_gpr.ExactGPR(
        kernel=pinot.inference.gp.kernels.deep_kernel.DeepKernel(
            base_kernel=pinot.inference.gp.kernels.rbf.RBF(scale=torch.ones(32)),
            representation=net))
    
    gps.append(model)

In [25]:
params = []

for gp in gps:
    params += list(gp.parameters())

opt = torch.optim.Adam(params, 1e-5)
for _ in range(500):
    opt.zero_grad()
    loss = 0.0
    for idx, gpr in enumerate(gps):
        g, y = datasets[idx][0]
        loss += gpr.loss(g, y)
        
    loss.backward()
    opt.step()
    print(loss)
        

tensor([[2976.3604]], grad_fn=<AddBackward0>)
tensor([[2972.9854]], grad_fn=<AddBackward0>)
tensor([[2969.5415]], grad_fn=<AddBackward0>)
tensor([[2966.0208]], grad_fn=<AddBackward0>)
tensor([[2962.4165]], grad_fn=<AddBackward0>)
tensor([[2958.7227]], grad_fn=<AddBackward0>)
tensor([[2954.9360]], grad_fn=<AddBackward0>)
tensor([[2951.0532]], grad_fn=<AddBackward0>)
tensor([[2947.0740]], grad_fn=<AddBackward0>)
tensor([[2942.9976]], grad_fn=<AddBackward0>)
tensor([[2938.8247]], grad_fn=<AddBackward0>)
tensor([[2934.5588]], grad_fn=<AddBackward0>)
tensor([[2930.2029]], grad_fn=<AddBackward0>)
tensor([[2925.7634]], grad_fn=<AddBackward0>)
tensor([[2921.2454]], grad_fn=<AddBackward0>)
tensor([[2916.6580]], grad_fn=<AddBackward0>)
tensor([[2912.0105]], grad_fn=<AddBackward0>)
tensor([[2907.3130]], grad_fn=<AddBackward0>)
tensor([[2902.5754]], grad_fn=<AddBackward0>)
tensor([[2897.8101]], grad_fn=<AddBackward0>)
tensor([[2893.0276]], grad_fn=<AddBackward0>)
tensor([[2888.2400]], grad_fn=<Add

tensor([[2598.2009]], grad_fn=<AddBackward0>)
tensor([[2595.0344]], grad_fn=<AddBackward0>)
tensor([[2595.6086]], grad_fn=<AddBackward0>)
tensor([[2593.1929]], grad_fn=<AddBackward0>)
tensor([[2593.2610]], grad_fn=<AddBackward0>)
tensor([[2591.8215]], grad_fn=<AddBackward0>)
tensor([[2590.4539]], grad_fn=<AddBackward0>)
tensor([[2589.9028]], grad_fn=<AddBackward0>)
tensor([[2588.0903]], grad_fn=<AddBackward0>)
tensor([[2587.9312]], grad_fn=<AddBackward0>)
tensor([[2585.8101]], grad_fn=<AddBackward0>)
tensor([[2585.6533]], grad_fn=<AddBackward0>)
tensor([[2583.2119]], grad_fn=<AddBackward0>)
tensor([[2583.5969]], grad_fn=<AddBackward0>)
tensor([[2581.1304]], grad_fn=<AddBackward0>)
tensor([[2580.9956]], grad_fn=<AddBackward0>)
tensor([[2579.1472]], grad_fn=<AddBackward0>)
tensor([[2578.3127]], grad_fn=<AddBackward0>)
tensor([[2577.6226]], grad_fn=<AddBackward0>)
tensor([[2575.2390]], grad_fn=<AddBackward0>)
tensor([[2575.6951]], grad_fn=<AddBackward0>)
tensor([[2572.2019]], grad_fn=<Add

tensor([[2292.7639]], grad_fn=<AddBackward0>)
tensor([[2314.0178]], grad_fn=<AddBackward0>)
tensor([[2294.4377]], grad_fn=<AddBackward0>)
tensor([[2319.2593]], grad_fn=<AddBackward0>)
tensor([[2300.2241]], grad_fn=<AddBackward0>)
tensor([[2294.0005]], grad_fn=<AddBackward0>)
tensor([[2296.3372]], grad_fn=<AddBackward0>)
tensor([[2294.9707]], grad_fn=<AddBackward0>)
tensor([[2303.2588]], grad_fn=<AddBackward0>)
tensor([[2290.7654]], grad_fn=<AddBackward0>)
tensor([[2306.4810]], grad_fn=<AddBackward0>)
tensor([[2298.4932]], grad_fn=<AddBackward0>)
tensor([[2308.9827]], grad_fn=<AddBackward0>)
tensor([[2293.5779]], grad_fn=<AddBackward0>)
tensor([[2296.7771]], grad_fn=<AddBackward0>)
tensor([[2292.2292]], grad_fn=<AddBackward0>)
tensor([[2300.7087]], grad_fn=<AddBackward0>)
tensor([[2298.6313]], grad_fn=<AddBackward0>)
tensor([[2293.6211]], grad_fn=<AddBackward0>)
tensor([[2303.9717]], grad_fn=<AddBackward0>)
tensor([[2297.8267]], grad_fn=<AddBackward0>)
tensor([[2310.0967]], grad_fn=<Add

In [26]:
metrics = ['avg_nll', 'r2', 'rmse']
for idx, gp in enumerate(gps):
    g, y = datasets_te[idx][0]
    
    print('-----')
    
    print(idx)
    for metric in metrics:
        metric = getattr(pinot.metrics, metric)
        print(metric, metric(gp, g, y).detach().numpy())

-----
0
<function avg_nll at 0x13d468200> -0.27919927
<function r2 at 0x13d4680e0> 0.6519361
<function rmse at 0x13d462f80> 0.14225633
-----
1
<function avg_nll at 0x13d468200> 0.09541814
<function r2 at 0x13d4680e0> 0.44149286
<function rmse at 0x13d462f80> 0.2636996
-----
2
<function avg_nll at 0x13d468200> -0.050352775
<function r2 at 0x13d4680e0> -1.0841897
<function rmse at 0x13d462f80> 0.23096302
-----
3
<function avg_nll at 0x13d468200> -0.17185949
<function r2 at 0x13d4680e0> 0.64149666
<function rmse at 0x13d462f80> 0.17031498
-----
4
<function avg_nll at 0x13d468200> -0.004102695
<function r2 at 0x13d4680e0> 0.49314988
<function rmse at 0x13d462f80> 0.22750771
-----
5
<function avg_nll at 0x13d468200> -0.06378331
<function r2 at 0x13d4680e0> 0.15665519
<function rmse at 0x13d462f80> 0.22991867
-----
6
<function avg_nll at 0x13d468200> -0.36182767
<function r2 at 0x13d4680e0> 0.3867584
<function rmse at 0x13d462f80> 0.0035778624
-----
7
<function avg_nll at 0x13d468200> -0.272

In [36]:

gps = []
for idx in range(len(datasets)):
    net = pinot.representation.Sequential(
        pinot.representation.dgl_legacy.gn(),
        [32, 'tanh', 32, 'tanh', 32, 'tanh'])


    model = pinot.inference.gp.gpr.exact_gpr.ExactGPR(
        kernel=pinot.inference.gp.kernels.deep_kernel.DeepKernel(
            base_kernel=pinot.inference.gp.kernels.rbf.RBF(scale=torch.ones(32)),
            representation=net))
    
    gps.append(model)
    
    
for idx, gp in enumerate(gps):
    opt = torch.optim.Adam(gp.parameters(), 1e-5)
    for _ in range(500):
        opt.zero_grad()

        g, y = datasets[idx][0]
        loss = gp.loss(g, y)

        loss.backward()
        opt.step()
        
    print('-----')

    print(idx)
    for metric in metrics:
        g, y = datasets_te[idx][0]
        metric = getattr(pinot.metrics, metric)
        print(metric, metric(gp, g, y).detach().numpy())

        

-----
0
<function avg_nll at 0x13d468200> 0.4367059
<function r2 at 0x13d4680e0> 0.34786165
<function rmse at 0x13d462f80> 0.19472066
-----
1
<function avg_nll at 0x13d468200> 0.5991193
<function r2 at 0x13d4680e0> -0.019057631
<function rmse at 0x13d462f80> 0.35620043
-----
2
<function avg_nll at 0x13d468200> 0.7436651
<function r2 at 0x13d4680e0> -0.64520586
<function rmse at 0x13d462f80> 0.20520312
-----
3
<function avg_nll at 0x13d468200> 0.491587
<function r2 at 0x13d4680e0> 0.14984548
<function rmse at 0x13d462f80> 0.26227397
-----
4
<function avg_nll at 0x13d468200> 0.5207345
<function r2 at 0x13d4680e0> 0.056881487
<function rmse at 0x13d462f80> 0.3103412
-----
5
<function avg_nll at 0x13d468200> 0.5489281
<function r2 at 0x13d4680e0> 0.14821124
<function rmse at 0x13d462f80> 0.23106684
-----
6
<function avg_nll at 0x13d468200> 0.4421416
<function r2 at 0x13d4680e0> -5.554524
<function rmse at 0x13d462f80> 0.011697105
-----
7
<function avg_nll at 0x13d468200> 0.40021127
<functi

In [37]:
for g, y in datasets_te[0]:
    print(y)

tensor([[ 0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 9.8455e-01],
        [ 6.4150e-02],
        [ 1.8171e-01],
        [ 0.0000e+00],
        [ 5.7357e-01],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 2.5705e-02],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 9.5114e-03],
        [ 0.0000e+00],
        [ 1.0993e-01],
        [-8.1577e-03],
        [ 2.6035e-01],
        [ 5.4870e-01],
        [ 8.7545e-01],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 5.0294e-02],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [ 0.0000e+00],
        [-9.7614e-04],
        [-1.1300e-02],
        [ 2.5748e-01],
        [ 0.0000e+00],
        [ 2.2330e-01],
        [ 0.0000e+00],
        [ 1.2320e-01],
        [-2.4650e-02],
        [ 1.1033e-01],
        [ 2.5583e-01],
        [ 0