In [4]:
%load_ext autoreload
%autoreload 2
from scalable_gps.data import get_dataset
# UCI regression benchmarks used in this notebook, listed in order of increasing N.
datasets = [
    'pol', 'elevators', 'bike', 'kin40k', 'protein',
    'keggdirected', '3droad', 'song', 'buzz', 'houseelectric',
]

# Load split 0 of every benchmark (with normalise=True) and report the train/test sizes.
# NOTE: data_train / data_test are rebound on each iteration, so once this cell finishes
# they refer to the LAST dataset in the list; later cells read them directly.
for dataset in datasets:
    data_train, data_test = get_dataset(dataset, split=0, normalise=True)
    # Trailing "\n" reproduces the blank separator line between datasets.
    print(f"train: {data_train.N}, test: {data_test.N}" + "\n")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
pol dataset, N=15000, d=26
train: 13500, test: 1500

elevators dataset, N=16599, d=18
train: 14940, test: 1659

bike dataset, N=17379, d=17
train: 15642, test: 1737

kin40k dataset, N=40000, d=8
train: 36000, test: 4000

protein dataset, N=45730, d=9
train: 41157, test: 4573

keggdirected dataset, N=48827, d=20
train: 43945, test: 4882

3droad dataset, N=434874, d=3
train: 391387, test: 43487

song dataset, N=515345, d=90
train: 463811, test: 51534

buzz dataset, N=583250, d=77
train: 524925, test: 58325

houseelectric dataset, N=2049280, d=11
train: 1844352, test: 204928



In [2]:
import jax
import jax.numpy as jnp
from kernels import Matern32Kernel

# Kernel hyperparameters, set by hand in this notebook.
kernel_config = {
    'signal_scale': 1.0,
    # Use a float literal: jnp.array([1]) is an int32 array, and jax.grad refuses
    # integer-dtype inputs, so an integer length scale could never be optimised.
    # A single entry — presumably shared across all input dimensions; TODO confirm
    # against Matern32Kernel.
    'length_scale': jnp.array([1.0]),
}
# Observation-noise scale passed to the GP model in a later cell.
noise_scale = 1.0

kernel = Matern32Kernel(kernel_config=kernel_config)

In [6]:

from models import SGDGPModel
import ml_collections

config = ml_collections.ConfigDict()
config.train_config = ml_collections.ConfigDict()

# Optimiser settings passed to SGDGPModel.compute_representer_weights.
config.train_config.learning_rate = 1e-1
config.train_config.lr_schedule_name = None
config.train_config.lr_schedule_config = None
config.train_config.momentum = 0.9
config.train_config.polyak = 1e-3
config.train_config.iterations = 10
config.train_config.eval_every = 5
config.train_config.n_features_optim = 100
config.train_config.recompute_features = True
config.train_config.iterative_idx = True
config.train_config.time_budget_in_seconds = None
config.train_config.eval_every_in_seconds = None

optim_key = jax.random.PRNGKey(123)

sgd_gp = SGDGPModel(noise_scale=noise_scale, kernel=kernel)

# Search for a batch size that runs: start with the whole training set and halve it
# after every failed attempt.
# NOTE(review): data_train / data_test are whatever the dataset-loading cell left bound
# after its final loop iteration (houseelectric in the saved output) — hidden state.
config.train_config.batch_size = data_train.N
success = False
while not success:
    try:
        print(f"Trying batch size = {config.train_config.batch_size}")
        alpha, info = sgd_gp.compute_representer_weights(
            key=optim_key,
            train_ds=data_train,
            test_ds=data_test,
            config=config.train_config,
            metrics_list=['loss'],
            exact_metrics=None
        )
        success = True
    except Exception as e:
        # Failures are assumed to be device out-of-memory errors (TODO confirm). Report
        # them instead of swallowing silently, so that genuine bugs stay visible.
        print(f"  attempt failed with {type(e).__name__}: {str(e)[:200]}")
        next_batch_size = config.train_config.batch_size // 2
        if next_batch_size < 1:
            # Previously the batch size could reach 0 and the loop would retry forever;
            # stop and surface the last error instead.
            raise
        config.train_config.batch_size = next_batch_size

Trying batch size = 1844352


  0%|          | 0/10 [03:08<?, ?it/s]


Trying batch size = 922176


  0%|          | 0/10 [03:02<?, ?it/s]


Trying batch size = 461088


  0%|          | 0/10 [02:57<?, ?it/s]


Trying batch size = 230544


  0%|          | 0/10 [02:54<?, ?it/s]


Trying batch size = 115272


  0%|          | 0/10 [02:51<?, ?it/s]


Trying batch size = 57636


  0%|          | 0/10 [02:50<?, ?it/s]


Trying batch size = 28818


  0%|          | 0/10 [02:49<?, ?it/s]


Trying batch size = 14409


  0%|          | 0/10 [02:47<?, ?it/s]


Trying batch size = 7204


  0%|          | 0/10 [02:46<?, ?it/s]


Trying batch size = 3602


  0%|          | 0/10 [02:46<?, ?it/s]


Trying batch size = 1801


100%|██████████| 10/10 [03:04<00:00, 18.40s/it]
