In [2]:
import math, numpy as np, pandas as pd
import torch
import gpytorch
from matplotlib import pyplot as plt

from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy
import os
import time

from scipy.stats import spearmanr

home = os.path.join(os.path.dirname(os.path.realpath("__file__")), "..")
print(home)

  from .autonotebook import tqdm as notebook_tqdm


/home/jlparkinson1/Documents/gp_proteins/benchmarking_xGPR/jupyter_notebook/..


In [3]:
def get_xy(home, target_dir, get_files = False):
    os.chdir(os.path.join(home, "benchmark_evals"))
    os.chdir(target_dir)
    xfiles = [os.path.abspath(f) for f in os.listdir() if f.endswith("xvalues.npy")]
    yfiles = [os.path.abspath(f) for f in os.listdir() if f.endswith("yvalues.npy")]
    xfiles.sort()
    yfiles.sort()
    os.chdir(home)
    if get_files:
        return xfiles, yfiles
    
    x, y = [], []
    for i, xfile in enumerate(xfiles):
        x.append(np.load(xfile).astype(np.float32))
        y.append(np.load(yfiles[i]))
    
    x = np.vstack(x)
    y = np.concatenate(y)
    return x, y

In [4]:
class GPModel(ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True)
        super(GPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [5]:
def build_test_gpytorch_model(home, target_dataset, num_epochs, num_inducing,
                                     minibatch_size = 1000):
    target_dir = os.path.join(target_dataset, "train")
    trainx, trainy = get_xy(home, target_dir)
    trainx = torch.from_numpy(trainx)
    trainy = torch.from_numpy(trainy)
    
    inducing_points = trainx[:num_inducing, :]
    
    model = GPModel(inducing_points=inducing_points)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = model.cuda()
    likelihood = likelihood.cuda()
    model.train()
    likelihood.train()
    
    optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
    ], lr=0.01)

    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=trainy.size(0))

    loss_tally = []
    wallclock = time.time()
    for i in range(num_epochs):
        niter = 0
        for j in range(0, trainx.size(0), minibatch_size):
            xbatch, ybatch = trainx[j:j+minibatch_size,:].cuda(), trainy[j:j+minibatch_size].cuda()
            optimizer.zero_grad()
            output = model(xbatch)
            loss = -mll(output, ybatch)
            loss.backward()
            optimizer.step()
            niter += 1
            if niter % 10 == 0:
                loss_tally.append(loss.item())
                print("Epoch %s, iteration %s"%(i,niter))
            
    wallclock = time.time() - wallclock
    
    
    target_dir = os.path.join(target_dataset, "test")
    testx, testy = get_xy(home, target_dir)
    testx = torch.from_numpy(testx).cuda()
    
    preds = []
    with torch.no_grad():
        for i in range(0, testx.shape[0], 1000):
            preds.append(model(testx[i:i+1000]).mean.cpu().numpy())

    preds = np.concatenate(preds)
    sscore = spearmanr(preds, testy)
    mae = np.mean(np.abs(preds - testy))
    return wallclock, sscore, mae

In [64]:
results_dict = {"dataset":[], "num_epochs":[], "minibatch_size":[],
                "num_inducing_pts":[],
               "wallclock_time":[], "spearmanr":[], "mae":[]}

In [65]:
for num_epochs in [20, 40]:
    for num_inducing in [500, 3000]:
        for dataset in ["fluorescence_eval/onehot/standard",
                       "aav_eval/onehot/des_mut_split",
                       "aav_eval/onehot/mut_des_split",
                       "aav_eval/onehot/seven_vs_many_split",
                       "kin40k_dataset",
                       "song_dataset/y_norm",
                       "uci_protein_dataset"]:
            wclock, spearman, mae = build_test_gpytorch_model(home, dataset, num_epochs = num_epochs, 
                              num_inducing = num_inducing, minibatch_size = 1000)
            results_dict["dataset"].append(dataset)
            results_dict["num_epochs"].append(num_epochs)
            results_dict["num_inducing_pts"].append(num_inducing)
            results_dict["spearmanr"].append(spearman)
            results_dict["mae"].append(mae)
            results_dict["wallclock_time"].append(wclock)
            results_dict["minibatch_size"].append(1000)
    
    time.sleep(60)

Epoch 0, iteration 10
Epoch 0, iteration 20
Epoch 1, iteration 10
Epoch 1, iteration 20
Epoch 2, iteration 10
Epoch 2, iteration 20
Epoch 3, iteration 10
Epoch 3, iteration 20
Epoch 4, iteration 10
Epoch 4, iteration 20
Epoch 5, iteration 10
Epoch 5, iteration 20
Epoch 6, iteration 10
Epoch 6, iteration 20
Epoch 7, iteration 10
Epoch 7, iteration 20
Epoch 8, iteration 10
Epoch 8, iteration 20
Epoch 9, iteration 10
Epoch 9, iteration 20
Epoch 10, iteration 10
Epoch 10, iteration 20
Epoch 11, iteration 10
Epoch 11, iteration 20
Epoch 12, iteration 10
Epoch 12, iteration 20
Epoch 13, iteration 10
Epoch 13, iteration 20
Epoch 14, iteration 10
Epoch 14, iteration 20
Epoch 15, iteration 10
Epoch 15, iteration 20
Epoch 16, iteration 10
Epoch 16, iteration 20
Epoch 17, iteration 10
Epoch 17, iteration 20
Epoch 18, iteration 10
Epoch 18, iteration 20
Epoch 19, iteration 10
Epoch 19, iteration 20
Epoch 0, iteration 10
Epoch 0, iteration 20
Epoch 0, iteration 30
Epoch 0, iteration 40
Epoch 0, ite

Epoch 17, iteration 140
Epoch 17, iteration 150
Epoch 17, iteration 160
Epoch 17, iteration 170
Epoch 17, iteration 180
Epoch 18, iteration 10
Epoch 18, iteration 20
Epoch 18, iteration 30
Epoch 18, iteration 40
Epoch 18, iteration 50
Epoch 18, iteration 60
Epoch 18, iteration 70
Epoch 18, iteration 80
Epoch 18, iteration 90
Epoch 18, iteration 100
Epoch 18, iteration 110
Epoch 18, iteration 120
Epoch 18, iteration 130
Epoch 18, iteration 140
Epoch 18, iteration 150
Epoch 18, iteration 160
Epoch 18, iteration 170
Epoch 18, iteration 180
Epoch 19, iteration 10
Epoch 19, iteration 20
Epoch 19, iteration 30
Epoch 19, iteration 40
Epoch 19, iteration 50
Epoch 19, iteration 60
Epoch 19, iteration 70
Epoch 19, iteration 80
Epoch 19, iteration 90
Epoch 19, iteration 100
Epoch 19, iteration 110
Epoch 19, iteration 120
Epoch 19, iteration 130
Epoch 19, iteration 140
Epoch 19, iteration 150
Epoch 19, iteration 160
Epoch 19, iteration 170
Epoch 19, iteration 180
Epoch 0, iteration 10
Epoch 0, ite

Epoch 0, iteration 30
Epoch 0, iteration 40
Epoch 0, iteration 50
Epoch 0, iteration 60
Epoch 0, iteration 70
Epoch 0, iteration 80
Epoch 0, iteration 90
Epoch 0, iteration 100
Epoch 0, iteration 110
Epoch 0, iteration 120
Epoch 0, iteration 130
Epoch 0, iteration 140
Epoch 0, iteration 150
Epoch 0, iteration 160
Epoch 0, iteration 170
Epoch 0, iteration 180
Epoch 0, iteration 190
Epoch 0, iteration 200
Epoch 0, iteration 210
Epoch 0, iteration 220
Epoch 0, iteration 230
Epoch 0, iteration 240
Epoch 0, iteration 250
Epoch 0, iteration 260
Epoch 0, iteration 270
Epoch 0, iteration 280
Epoch 0, iteration 290
Epoch 0, iteration 300
Epoch 0, iteration 310
Epoch 0, iteration 320
Epoch 0, iteration 330
Epoch 0, iteration 340
Epoch 0, iteration 350
Epoch 0, iteration 360
Epoch 0, iteration 370
Epoch 0, iteration 380
Epoch 0, iteration 390
Epoch 0, iteration 400
Epoch 0, iteration 410
Epoch 1, iteration 10
Epoch 1, iteration 20
Epoch 1, iteration 30
Epoch 1, iteration 40
Epoch 1, iteration 50


Epoch 8, iteration 350
Epoch 8, iteration 360
Epoch 8, iteration 370
Epoch 8, iteration 380
Epoch 8, iteration 390
Epoch 8, iteration 400
Epoch 8, iteration 410
Epoch 9, iteration 10
Epoch 9, iteration 20
Epoch 9, iteration 30
Epoch 9, iteration 40
Epoch 9, iteration 50
Epoch 9, iteration 60
Epoch 9, iteration 70
Epoch 9, iteration 80
Epoch 9, iteration 90
Epoch 9, iteration 100
Epoch 9, iteration 110
Epoch 9, iteration 120
Epoch 9, iteration 130
Epoch 9, iteration 140
Epoch 9, iteration 150
Epoch 9, iteration 160
Epoch 9, iteration 170
Epoch 9, iteration 180
Epoch 9, iteration 190
Epoch 9, iteration 200
Epoch 9, iteration 210
Epoch 9, iteration 220
Epoch 9, iteration 230
Epoch 9, iteration 240
Epoch 9, iteration 250
Epoch 9, iteration 260
Epoch 9, iteration 270
Epoch 9, iteration 280
Epoch 9, iteration 290
Epoch 9, iteration 300
Epoch 9, iteration 310
Epoch 9, iteration 320
Epoch 9, iteration 330
Epoch 9, iteration 340
Epoch 9, iteration 350
Epoch 9, iteration 360
Epoch 9, iteration 3

Epoch 17, iteration 130
Epoch 17, iteration 140
Epoch 17, iteration 150
Epoch 17, iteration 160
Epoch 17, iteration 170
Epoch 17, iteration 180
Epoch 17, iteration 190
Epoch 17, iteration 200
Epoch 17, iteration 210
Epoch 17, iteration 220
Epoch 17, iteration 230
Epoch 17, iteration 240
Epoch 17, iteration 250
Epoch 17, iteration 260
Epoch 17, iteration 270
Epoch 17, iteration 280
Epoch 17, iteration 290
Epoch 17, iteration 300
Epoch 17, iteration 310
Epoch 17, iteration 320
Epoch 17, iteration 330
Epoch 17, iteration 340
Epoch 17, iteration 350
Epoch 17, iteration 360
Epoch 17, iteration 370
Epoch 17, iteration 380
Epoch 17, iteration 390
Epoch 17, iteration 400
Epoch 17, iteration 410
Epoch 18, iteration 10
Epoch 18, iteration 20
Epoch 18, iteration 30
Epoch 18, iteration 40
Epoch 18, iteration 50
Epoch 18, iteration 60
Epoch 18, iteration 70
Epoch 18, iteration 80
Epoch 18, iteration 90
Epoch 18, iteration 100
Epoch 18, iteration 110
Epoch 18, iteration 120
Epoch 18, iteration 130
E

Epoch 8, iteration 40
Epoch 8, iteration 50
Epoch 8, iteration 60
Epoch 8, iteration 70
Epoch 8, iteration 80
Epoch 8, iteration 90
Epoch 8, iteration 100
Epoch 8, iteration 110
Epoch 8, iteration 120
Epoch 8, iteration 130
Epoch 8, iteration 140
Epoch 8, iteration 150
Epoch 8, iteration 160
Epoch 8, iteration 170
Epoch 8, iteration 180
Epoch 9, iteration 10
Epoch 9, iteration 20
Epoch 9, iteration 30
Epoch 9, iteration 40
Epoch 9, iteration 50
Epoch 9, iteration 60
Epoch 9, iteration 70
Epoch 9, iteration 80
Epoch 9, iteration 90
Epoch 9, iteration 100
Epoch 9, iteration 110
Epoch 9, iteration 120
Epoch 9, iteration 130
Epoch 9, iteration 140
Epoch 9, iteration 150
Epoch 9, iteration 160
Epoch 9, iteration 170
Epoch 9, iteration 180
Epoch 10, iteration 10
Epoch 10, iteration 20
Epoch 10, iteration 30
Epoch 10, iteration 40
Epoch 10, iteration 50
Epoch 10, iteration 60
Epoch 10, iteration 70
Epoch 10, iteration 80
Epoch 10, iteration 90
Epoch 10, iteration 100
Epoch 10, iteration 110
E

Epoch 0, iteration 50
Epoch 0, iteration 60
Epoch 1, iteration 10
Epoch 1, iteration 20
Epoch 1, iteration 30
Epoch 1, iteration 40
Epoch 1, iteration 50
Epoch 1, iteration 60
Epoch 2, iteration 10
Epoch 2, iteration 20
Epoch 2, iteration 30
Epoch 2, iteration 40
Epoch 2, iteration 50
Epoch 2, iteration 60
Epoch 3, iteration 10
Epoch 3, iteration 20
Epoch 3, iteration 30
Epoch 3, iteration 40
Epoch 3, iteration 50
Epoch 3, iteration 60
Epoch 4, iteration 10
Epoch 4, iteration 20
Epoch 4, iteration 30
Epoch 4, iteration 40
Epoch 4, iteration 50
Epoch 4, iteration 60
Epoch 5, iteration 10
Epoch 5, iteration 20
Epoch 5, iteration 30
Epoch 5, iteration 40
Epoch 5, iteration 50
Epoch 5, iteration 60
Epoch 6, iteration 10
Epoch 6, iteration 20
Epoch 6, iteration 30
Epoch 6, iteration 40
Epoch 6, iteration 50
Epoch 6, iteration 60
Epoch 7, iteration 10
Epoch 7, iteration 20
Epoch 7, iteration 30
Epoch 7, iteration 40
Epoch 7, iteration 50
Epoch 7, iteration 60
Epoch 8, iteration 10
Epoch 8, i

Epoch 4, iteration 230
Epoch 4, iteration 240
Epoch 4, iteration 250
Epoch 4, iteration 260
Epoch 4, iteration 270
Epoch 4, iteration 280
Epoch 4, iteration 290
Epoch 4, iteration 300
Epoch 4, iteration 310
Epoch 4, iteration 320
Epoch 4, iteration 330
Epoch 4, iteration 340
Epoch 4, iteration 350
Epoch 4, iteration 360
Epoch 4, iteration 370
Epoch 4, iteration 380
Epoch 4, iteration 390
Epoch 4, iteration 400
Epoch 4, iteration 410
Epoch 5, iteration 10
Epoch 5, iteration 20
Epoch 5, iteration 30
Epoch 5, iteration 40
Epoch 5, iteration 50
Epoch 5, iteration 60
Epoch 5, iteration 70
Epoch 5, iteration 80
Epoch 5, iteration 90
Epoch 5, iteration 100
Epoch 5, iteration 110
Epoch 5, iteration 120
Epoch 5, iteration 130
Epoch 5, iteration 140
Epoch 5, iteration 150
Epoch 5, iteration 160
Epoch 5, iteration 170
Epoch 5, iteration 180
Epoch 5, iteration 190
Epoch 5, iteration 200
Epoch 5, iteration 210
Epoch 5, iteration 220
Epoch 5, iteration 230
Epoch 5, iteration 240
Epoch 5, iteration 2

Epoch 13, iteration 80
Epoch 13, iteration 90
Epoch 13, iteration 100
Epoch 13, iteration 110
Epoch 13, iteration 120
Epoch 13, iteration 130
Epoch 13, iteration 140
Epoch 13, iteration 150
Epoch 13, iteration 160
Epoch 13, iteration 170
Epoch 13, iteration 180
Epoch 13, iteration 190
Epoch 13, iteration 200
Epoch 13, iteration 210
Epoch 13, iteration 220
Epoch 13, iteration 230
Epoch 13, iteration 240
Epoch 13, iteration 250
Epoch 13, iteration 260
Epoch 13, iteration 270
Epoch 13, iteration 280
Epoch 13, iteration 290
Epoch 13, iteration 300
Epoch 13, iteration 310
Epoch 13, iteration 320
Epoch 13, iteration 330
Epoch 13, iteration 340
Epoch 13, iteration 350
Epoch 13, iteration 360
Epoch 13, iteration 370
Epoch 13, iteration 380
Epoch 13, iteration 390
Epoch 13, iteration 400
Epoch 13, iteration 410
Epoch 14, iteration 10
Epoch 14, iteration 20
Epoch 14, iteration 30
Epoch 14, iteration 40
Epoch 14, iteration 50
Epoch 14, iteration 60
Epoch 14, iteration 70
Epoch 14, iteration 80
Ep

Epoch 4, iteration 20
Epoch 5, iteration 10
Epoch 5, iteration 20
Epoch 6, iteration 10
Epoch 6, iteration 20
Epoch 7, iteration 10
Epoch 7, iteration 20
Epoch 8, iteration 10
Epoch 8, iteration 20
Epoch 9, iteration 10
Epoch 9, iteration 20
Epoch 10, iteration 10
Epoch 10, iteration 20
Epoch 11, iteration 10
Epoch 11, iteration 20
Epoch 12, iteration 10
Epoch 12, iteration 20
Epoch 13, iteration 10
Epoch 13, iteration 20
Epoch 14, iteration 10
Epoch 14, iteration 20
Epoch 15, iteration 10
Epoch 15, iteration 20
Epoch 16, iteration 10
Epoch 16, iteration 20
Epoch 17, iteration 10
Epoch 17, iteration 20
Epoch 18, iteration 10
Epoch 18, iteration 20
Epoch 19, iteration 10
Epoch 19, iteration 20
Epoch 20, iteration 10
Epoch 20, iteration 20
Epoch 21, iteration 10
Epoch 21, iteration 20
Epoch 22, iteration 10
Epoch 22, iteration 20
Epoch 23, iteration 10
Epoch 23, iteration 20
Epoch 24, iteration 10
Epoch 24, iteration 20
Epoch 25, iteration 10
Epoch 25, iteration 20
Epoch 26, iteration 10

Epoch 16, iteration 10
Epoch 16, iteration 20
Epoch 16, iteration 30
Epoch 16, iteration 40
Epoch 16, iteration 50
Epoch 16, iteration 60
Epoch 16, iteration 70
Epoch 16, iteration 80
Epoch 16, iteration 90
Epoch 16, iteration 100
Epoch 16, iteration 110
Epoch 16, iteration 120
Epoch 16, iteration 130
Epoch 16, iteration 140
Epoch 16, iteration 150
Epoch 16, iteration 160
Epoch 16, iteration 170
Epoch 16, iteration 180
Epoch 17, iteration 10
Epoch 17, iteration 20
Epoch 17, iteration 30
Epoch 17, iteration 40
Epoch 17, iteration 50
Epoch 17, iteration 60
Epoch 17, iteration 70
Epoch 17, iteration 80
Epoch 17, iteration 90
Epoch 17, iteration 100
Epoch 17, iteration 110
Epoch 17, iteration 120
Epoch 17, iteration 130
Epoch 17, iteration 140
Epoch 17, iteration 150
Epoch 17, iteration 160
Epoch 17, iteration 170
Epoch 17, iteration 180
Epoch 18, iteration 10
Epoch 18, iteration 20
Epoch 18, iteration 30
Epoch 18, iteration 40
Epoch 18, iteration 50
Epoch 18, iteration 60
Epoch 18, iterat

Epoch 35, iteration 80
Epoch 35, iteration 90
Epoch 35, iteration 100
Epoch 35, iteration 110
Epoch 35, iteration 120
Epoch 35, iteration 130
Epoch 35, iteration 140
Epoch 35, iteration 150
Epoch 35, iteration 160
Epoch 35, iteration 170
Epoch 35, iteration 180
Epoch 36, iteration 10
Epoch 36, iteration 20
Epoch 36, iteration 30
Epoch 36, iteration 40
Epoch 36, iteration 50
Epoch 36, iteration 60
Epoch 36, iteration 70
Epoch 36, iteration 80
Epoch 36, iteration 90
Epoch 36, iteration 100
Epoch 36, iteration 110
Epoch 36, iteration 120
Epoch 36, iteration 130
Epoch 36, iteration 140
Epoch 36, iteration 150
Epoch 36, iteration 160
Epoch 36, iteration 170
Epoch 36, iteration 180
Epoch 37, iteration 10
Epoch 37, iteration 20
Epoch 37, iteration 30
Epoch 37, iteration 40
Epoch 37, iteration 50
Epoch 37, iteration 60
Epoch 37, iteration 70
Epoch 37, iteration 80
Epoch 37, iteration 90
Epoch 37, iteration 100
Epoch 37, iteration 110
Epoch 37, iteration 120
Epoch 37, iteration 130
Epoch 37, it

Epoch 39, iteration 30
Epoch 39, iteration 40
Epoch 39, iteration 50
Epoch 39, iteration 60
Epoch 39, iteration 70
Epoch 0, iteration 10
Epoch 0, iteration 20
Epoch 0, iteration 30
Epoch 0, iteration 40
Epoch 0, iteration 50
Epoch 0, iteration 60
Epoch 1, iteration 10
Epoch 1, iteration 20
Epoch 1, iteration 30
Epoch 1, iteration 40
Epoch 1, iteration 50
Epoch 1, iteration 60
Epoch 2, iteration 10
Epoch 2, iteration 20
Epoch 2, iteration 30
Epoch 2, iteration 40
Epoch 2, iteration 50
Epoch 2, iteration 60
Epoch 3, iteration 10
Epoch 3, iteration 20
Epoch 3, iteration 30
Epoch 3, iteration 40
Epoch 3, iteration 50
Epoch 3, iteration 60
Epoch 4, iteration 10
Epoch 4, iteration 20
Epoch 4, iteration 30
Epoch 4, iteration 40
Epoch 4, iteration 50
Epoch 4, iteration 60
Epoch 5, iteration 10
Epoch 5, iteration 20
Epoch 5, iteration 30
Epoch 5, iteration 40
Epoch 5, iteration 50
Epoch 5, iteration 60
Epoch 6, iteration 10
Epoch 6, iteration 20
Epoch 6, iteration 30
Epoch 6, iteration 40
Epoch

Epoch 38, iteration 30
Epoch 39, iteration 10
Epoch 39, iteration 20
Epoch 39, iteration 30
Epoch 0, iteration 10
Epoch 0, iteration 20
Epoch 0, iteration 30
Epoch 0, iteration 40
Epoch 0, iteration 50
Epoch 0, iteration 60
Epoch 0, iteration 70
Epoch 0, iteration 80
Epoch 0, iteration 90
Epoch 0, iteration 100
Epoch 0, iteration 110
Epoch 0, iteration 120
Epoch 0, iteration 130
Epoch 0, iteration 140
Epoch 0, iteration 150
Epoch 0, iteration 160
Epoch 0, iteration 170
Epoch 0, iteration 180
Epoch 0, iteration 190
Epoch 0, iteration 200
Epoch 0, iteration 210
Epoch 0, iteration 220
Epoch 0, iteration 230
Epoch 0, iteration 240
Epoch 0, iteration 250
Epoch 0, iteration 260
Epoch 0, iteration 270
Epoch 0, iteration 280
Epoch 0, iteration 290
Epoch 0, iteration 300
Epoch 0, iteration 310
Epoch 0, iteration 320
Epoch 0, iteration 330
Epoch 0, iteration 340
Epoch 0, iteration 350
Epoch 0, iteration 360
Epoch 0, iteration 370
Epoch 0, iteration 380
Epoch 0, iteration 390
Epoch 0, iteration 4

Epoch 8, iteration 290
Epoch 8, iteration 300
Epoch 8, iteration 310
Epoch 8, iteration 320
Epoch 8, iteration 330
Epoch 8, iteration 340
Epoch 8, iteration 350
Epoch 8, iteration 360
Epoch 8, iteration 370
Epoch 8, iteration 380
Epoch 8, iteration 390
Epoch 8, iteration 400
Epoch 8, iteration 410
Epoch 9, iteration 10
Epoch 9, iteration 20
Epoch 9, iteration 30
Epoch 9, iteration 40
Epoch 9, iteration 50
Epoch 9, iteration 60
Epoch 9, iteration 70
Epoch 9, iteration 80
Epoch 9, iteration 90
Epoch 9, iteration 100
Epoch 9, iteration 110
Epoch 9, iteration 120
Epoch 9, iteration 130
Epoch 9, iteration 140
Epoch 9, iteration 150
Epoch 9, iteration 160
Epoch 9, iteration 170
Epoch 9, iteration 180
Epoch 9, iteration 190
Epoch 9, iteration 200
Epoch 9, iteration 210
Epoch 9, iteration 220
Epoch 9, iteration 230
Epoch 9, iteration 240
Epoch 9, iteration 250
Epoch 9, iteration 260
Epoch 9, iteration 270
Epoch 9, iteration 280
Epoch 9, iteration 290
Epoch 9, iteration 300
Epoch 9, iteration 3

Epoch 17, iteration 70
Epoch 17, iteration 80
Epoch 17, iteration 90
Epoch 17, iteration 100
Epoch 17, iteration 110
Epoch 17, iteration 120
Epoch 17, iteration 130
Epoch 17, iteration 140
Epoch 17, iteration 150
Epoch 17, iteration 160
Epoch 17, iteration 170
Epoch 17, iteration 180
Epoch 17, iteration 190
Epoch 17, iteration 200
Epoch 17, iteration 210
Epoch 17, iteration 220
Epoch 17, iteration 230
Epoch 17, iteration 240
Epoch 17, iteration 250
Epoch 17, iteration 260
Epoch 17, iteration 270
Epoch 17, iteration 280
Epoch 17, iteration 290
Epoch 17, iteration 300
Epoch 17, iteration 310
Epoch 17, iteration 320
Epoch 17, iteration 330
Epoch 17, iteration 340
Epoch 17, iteration 350
Epoch 17, iteration 360
Epoch 17, iteration 370
Epoch 17, iteration 380
Epoch 17, iteration 390
Epoch 17, iteration 400
Epoch 17, iteration 410
Epoch 18, iteration 10
Epoch 18, iteration 20
Epoch 18, iteration 30
Epoch 18, iteration 40
Epoch 18, iteration 50
Epoch 18, iteration 60
Epoch 18, iteration 70
Ep

Epoch 25, iteration 240
Epoch 25, iteration 250
Epoch 25, iteration 260
Epoch 25, iteration 270
Epoch 25, iteration 280
Epoch 25, iteration 290
Epoch 25, iteration 300
Epoch 25, iteration 310
Epoch 25, iteration 320
Epoch 25, iteration 330
Epoch 25, iteration 340
Epoch 25, iteration 350
Epoch 25, iteration 360
Epoch 25, iteration 370
Epoch 25, iteration 380
Epoch 25, iteration 390
Epoch 25, iteration 400
Epoch 25, iteration 410
Epoch 26, iteration 10
Epoch 26, iteration 20
Epoch 26, iteration 30
Epoch 26, iteration 40
Epoch 26, iteration 50
Epoch 26, iteration 60
Epoch 26, iteration 70
Epoch 26, iteration 80
Epoch 26, iteration 90
Epoch 26, iteration 100
Epoch 26, iteration 110
Epoch 26, iteration 120
Epoch 26, iteration 130
Epoch 26, iteration 140
Epoch 26, iteration 150
Epoch 26, iteration 160
Epoch 26, iteration 170
Epoch 26, iteration 180
Epoch 26, iteration 190
Epoch 26, iteration 200
Epoch 26, iteration 210
Epoch 26, iteration 220
Epoch 26, iteration 230
Epoch 26, iteration 240
E

Epoch 33, iteration 410
Epoch 34, iteration 10
Epoch 34, iteration 20
Epoch 34, iteration 30
Epoch 34, iteration 40
Epoch 34, iteration 50
Epoch 34, iteration 60
Epoch 34, iteration 70
Epoch 34, iteration 80
Epoch 34, iteration 90
Epoch 34, iteration 100
Epoch 34, iteration 110
Epoch 34, iteration 120
Epoch 34, iteration 130
Epoch 34, iteration 140
Epoch 34, iteration 150
Epoch 34, iteration 160
Epoch 34, iteration 170
Epoch 34, iteration 180
Epoch 34, iteration 190
Epoch 34, iteration 200
Epoch 34, iteration 210
Epoch 34, iteration 220
Epoch 34, iteration 230
Epoch 34, iteration 240
Epoch 34, iteration 250
Epoch 34, iteration 260
Epoch 34, iteration 270
Epoch 34, iteration 280
Epoch 34, iteration 290
Epoch 34, iteration 300
Epoch 34, iteration 310
Epoch 34, iteration 320
Epoch 34, iteration 330
Epoch 34, iteration 340
Epoch 34, iteration 350
Epoch 34, iteration 360
Epoch 34, iteration 370
Epoch 34, iteration 380
Epoch 34, iteration 390
Epoch 34, iteration 400
Epoch 34, iteration 410
E

Epoch 34, iteration 20
Epoch 34, iteration 30
Epoch 35, iteration 10
Epoch 35, iteration 20
Epoch 35, iteration 30
Epoch 36, iteration 10
Epoch 36, iteration 20
Epoch 36, iteration 30
Epoch 37, iteration 10
Epoch 37, iteration 20
Epoch 37, iteration 30
Epoch 38, iteration 10
Epoch 38, iteration 20
Epoch 38, iteration 30
Epoch 39, iteration 10
Epoch 39, iteration 20
Epoch 39, iteration 30
Epoch 0, iteration 10
Epoch 0, iteration 20
Epoch 1, iteration 10
Epoch 1, iteration 20
Epoch 2, iteration 10
Epoch 2, iteration 20
Epoch 3, iteration 10
Epoch 3, iteration 20
Epoch 4, iteration 10
Epoch 4, iteration 20
Epoch 5, iteration 10
Epoch 5, iteration 20
Epoch 6, iteration 10
Epoch 6, iteration 20
Epoch 7, iteration 10
Epoch 7, iteration 20
Epoch 8, iteration 10
Epoch 8, iteration 20
Epoch 9, iteration 10
Epoch 9, iteration 20
Epoch 10, iteration 10
Epoch 10, iteration 20
Epoch 11, iteration 10
Epoch 11, iteration 20
Epoch 12, iteration 10
Epoch 12, iteration 20
Epoch 13, iteration 10
Epoch 13

Epoch 14, iteration 120
Epoch 14, iteration 130
Epoch 14, iteration 140
Epoch 14, iteration 150
Epoch 14, iteration 160
Epoch 14, iteration 170
Epoch 14, iteration 180
Epoch 15, iteration 10
Epoch 15, iteration 20
Epoch 15, iteration 30
Epoch 15, iteration 40
Epoch 15, iteration 50
Epoch 15, iteration 60
Epoch 15, iteration 70
Epoch 15, iteration 80
Epoch 15, iteration 90
Epoch 15, iteration 100
Epoch 15, iteration 110
Epoch 15, iteration 120
Epoch 15, iteration 130
Epoch 15, iteration 140
Epoch 15, iteration 150
Epoch 15, iteration 160
Epoch 15, iteration 170
Epoch 15, iteration 180
Epoch 16, iteration 10
Epoch 16, iteration 20
Epoch 16, iteration 30
Epoch 16, iteration 40
Epoch 16, iteration 50
Epoch 16, iteration 60
Epoch 16, iteration 70
Epoch 16, iteration 80
Epoch 16, iteration 90
Epoch 16, iteration 100
Epoch 16, iteration 110
Epoch 16, iteration 120
Epoch 16, iteration 130
Epoch 16, iteration 140
Epoch 16, iteration 150
Epoch 16, iteration 160
Epoch 16, iteration 170
Epoch 16, 

Epoch 34, iteration 10
Epoch 34, iteration 20
Epoch 34, iteration 30
Epoch 34, iteration 40
Epoch 34, iteration 50
Epoch 34, iteration 60
Epoch 34, iteration 70
Epoch 34, iteration 80
Epoch 34, iteration 90
Epoch 34, iteration 100
Epoch 34, iteration 110
Epoch 34, iteration 120
Epoch 34, iteration 130
Epoch 34, iteration 140
Epoch 34, iteration 150
Epoch 34, iteration 160
Epoch 34, iteration 170
Epoch 34, iteration 180
Epoch 35, iteration 10
Epoch 35, iteration 20
Epoch 35, iteration 30
Epoch 35, iteration 40
Epoch 35, iteration 50
Epoch 35, iteration 60
Epoch 35, iteration 70
Epoch 35, iteration 80
Epoch 35, iteration 90
Epoch 35, iteration 100
Epoch 35, iteration 110
Epoch 35, iteration 120
Epoch 35, iteration 130
Epoch 35, iteration 140
Epoch 35, iteration 150
Epoch 35, iteration 160
Epoch 35, iteration 170
Epoch 35, iteration 180
Epoch 36, iteration 10
Epoch 36, iteration 20
Epoch 36, iteration 30
Epoch 36, iteration 40
Epoch 36, iteration 50
Epoch 36, iteration 60
Epoch 36, iterat

Epoch 35, iteration 50
Epoch 35, iteration 60
Epoch 35, iteration 70
Epoch 36, iteration 10
Epoch 36, iteration 20
Epoch 36, iteration 30
Epoch 36, iteration 40
Epoch 36, iteration 50
Epoch 36, iteration 60
Epoch 36, iteration 70
Epoch 37, iteration 10
Epoch 37, iteration 20
Epoch 37, iteration 30
Epoch 37, iteration 40
Epoch 37, iteration 50
Epoch 37, iteration 60
Epoch 37, iteration 70
Epoch 38, iteration 10
Epoch 38, iteration 20
Epoch 38, iteration 30
Epoch 38, iteration 40
Epoch 38, iteration 50
Epoch 38, iteration 60
Epoch 38, iteration 70
Epoch 39, iteration 10
Epoch 39, iteration 20
Epoch 39, iteration 30
Epoch 39, iteration 40
Epoch 39, iteration 50
Epoch 39, iteration 60
Epoch 39, iteration 70
Epoch 0, iteration 10
Epoch 0, iteration 20
Epoch 0, iteration 30
Epoch 0, iteration 40
Epoch 0, iteration 50
Epoch 0, iteration 60
Epoch 1, iteration 10
Epoch 1, iteration 20
Epoch 1, iteration 30
Epoch 1, iteration 40
Epoch 1, iteration 50
Epoch 1, iteration 60
Epoch 2, iteration 10
E

Epoch 30, iteration 10
Epoch 30, iteration 20
Epoch 30, iteration 30
Epoch 31, iteration 10
Epoch 31, iteration 20
Epoch 31, iteration 30
Epoch 32, iteration 10
Epoch 32, iteration 20
Epoch 32, iteration 30
Epoch 33, iteration 10
Epoch 33, iteration 20
Epoch 33, iteration 30
Epoch 34, iteration 10
Epoch 34, iteration 20
Epoch 34, iteration 30
Epoch 35, iteration 10
Epoch 35, iteration 20
Epoch 35, iteration 30
Epoch 36, iteration 10
Epoch 36, iteration 20
Epoch 36, iteration 30
Epoch 37, iteration 10
Epoch 37, iteration 20
Epoch 37, iteration 30
Epoch 38, iteration 10
Epoch 38, iteration 20
Epoch 38, iteration 30
Epoch 39, iteration 10
Epoch 39, iteration 20
Epoch 39, iteration 30
Epoch 0, iteration 10
Epoch 0, iteration 20
Epoch 0, iteration 30
Epoch 0, iteration 40
Epoch 0, iteration 50
Epoch 0, iteration 60
Epoch 0, iteration 70
Epoch 0, iteration 80
Epoch 0, iteration 90
Epoch 0, iteration 100
Epoch 0, iteration 110
Epoch 0, iteration 120
Epoch 0, iteration 130
Epoch 0, iteration 1

Epoch 8, iteration 30
Epoch 8, iteration 40
Epoch 8, iteration 50
Epoch 8, iteration 60
Epoch 8, iteration 70
Epoch 8, iteration 80
Epoch 8, iteration 90
Epoch 8, iteration 100
Epoch 8, iteration 110
Epoch 8, iteration 120
Epoch 8, iteration 130
Epoch 8, iteration 140
Epoch 8, iteration 150
Epoch 8, iteration 160
Epoch 8, iteration 170
Epoch 8, iteration 180
Epoch 8, iteration 190
Epoch 8, iteration 200
Epoch 8, iteration 210
Epoch 8, iteration 220
Epoch 8, iteration 230
Epoch 8, iteration 240
Epoch 8, iteration 250
Epoch 8, iteration 260
Epoch 8, iteration 270
Epoch 8, iteration 280
Epoch 8, iteration 290
Epoch 8, iteration 300
Epoch 8, iteration 310
Epoch 8, iteration 320
Epoch 8, iteration 330
Epoch 8, iteration 340
Epoch 8, iteration 350
Epoch 8, iteration 360
Epoch 8, iteration 370
Epoch 8, iteration 380
Epoch 8, iteration 390
Epoch 8, iteration 400
Epoch 8, iteration 410
Epoch 9, iteration 10
Epoch 9, iteration 20
Epoch 9, iteration 30
Epoch 9, iteration 40
Epoch 9, iteration 50


Epoch 16, iteration 230
Epoch 16, iteration 240
Epoch 16, iteration 250
Epoch 16, iteration 260
Epoch 16, iteration 270
Epoch 16, iteration 280
Epoch 16, iteration 290
Epoch 16, iteration 300
Epoch 16, iteration 310
Epoch 16, iteration 320
Epoch 16, iteration 330
Epoch 16, iteration 340
Epoch 16, iteration 350
Epoch 16, iteration 360
Epoch 16, iteration 370
Epoch 16, iteration 380
Epoch 16, iteration 390
Epoch 16, iteration 400
Epoch 16, iteration 410
Epoch 17, iteration 10
Epoch 17, iteration 20
Epoch 17, iteration 30
Epoch 17, iteration 40
Epoch 17, iteration 50
Epoch 17, iteration 60
Epoch 17, iteration 70
Epoch 17, iteration 80
Epoch 17, iteration 90
Epoch 17, iteration 100
Epoch 17, iteration 110
Epoch 17, iteration 120
Epoch 17, iteration 130
Epoch 17, iteration 140
Epoch 17, iteration 150
Epoch 17, iteration 160
Epoch 17, iteration 170
Epoch 17, iteration 180
Epoch 17, iteration 190
Epoch 17, iteration 200
Epoch 17, iteration 210
Epoch 17, iteration 220
Epoch 17, iteration 230
E

Epoch 24, iteration 400
Epoch 24, iteration 410
Epoch 25, iteration 10
Epoch 25, iteration 20
Epoch 25, iteration 30
Epoch 25, iteration 40
Epoch 25, iteration 50
Epoch 25, iteration 60
Epoch 25, iteration 70
Epoch 25, iteration 80
Epoch 25, iteration 90
Epoch 25, iteration 100
Epoch 25, iteration 110
Epoch 25, iteration 120
Epoch 25, iteration 130
Epoch 25, iteration 140
Epoch 25, iteration 150
Epoch 25, iteration 160
Epoch 25, iteration 170
Epoch 25, iteration 180
Epoch 25, iteration 190
Epoch 25, iteration 200
Epoch 25, iteration 210
Epoch 25, iteration 220
Epoch 25, iteration 230
Epoch 25, iteration 240
Epoch 25, iteration 250
Epoch 25, iteration 260
Epoch 25, iteration 270
Epoch 25, iteration 280
Epoch 25, iteration 290
Epoch 25, iteration 300
Epoch 25, iteration 310
Epoch 25, iteration 320
Epoch 25, iteration 330
Epoch 25, iteration 340
Epoch 25, iteration 350
Epoch 25, iteration 360
Epoch 25, iteration 370
Epoch 25, iteration 380
Epoch 25, iteration 390
Epoch 25, iteration 400
E

Epoch 33, iteration 160
Epoch 33, iteration 170
Epoch 33, iteration 180
Epoch 33, iteration 190
Epoch 33, iteration 200
Epoch 33, iteration 210
Epoch 33, iteration 220
Epoch 33, iteration 230
Epoch 33, iteration 240
Epoch 33, iteration 250
Epoch 33, iteration 260
Epoch 33, iteration 270
Epoch 33, iteration 280
Epoch 33, iteration 290
Epoch 33, iteration 300
Epoch 33, iteration 310
Epoch 33, iteration 320
Epoch 33, iteration 330
Epoch 33, iteration 340
Epoch 33, iteration 350
Epoch 33, iteration 360
Epoch 33, iteration 370
Epoch 33, iteration 380
Epoch 33, iteration 390
Epoch 33, iteration 400
Epoch 33, iteration 410
Epoch 34, iteration 10
Epoch 34, iteration 20
Epoch 34, iteration 30
Epoch 34, iteration 40
Epoch 34, iteration 50
Epoch 34, iteration 60
Epoch 34, iteration 70
Epoch 34, iteration 80
Epoch 34, iteration 90
Epoch 34, iteration 100
Epoch 34, iteration 110
Epoch 34, iteration 120
Epoch 34, iteration 130
Epoch 34, iteration 140
Epoch 34, iteration 150
Epoch 34, iteration 160
E

Epoch 25, iteration 30
Epoch 26, iteration 10
Epoch 26, iteration 20
Epoch 26, iteration 30
Epoch 27, iteration 10
Epoch 27, iteration 20
Epoch 27, iteration 30
Epoch 28, iteration 10
Epoch 28, iteration 20
Epoch 28, iteration 30
Epoch 29, iteration 10
Epoch 29, iteration 20
Epoch 29, iteration 30
Epoch 30, iteration 10
Epoch 30, iteration 20
Epoch 30, iteration 30
Epoch 31, iteration 10
Epoch 31, iteration 20
Epoch 31, iteration 30
Epoch 32, iteration 10
Epoch 32, iteration 20
Epoch 32, iteration 30
Epoch 33, iteration 10
Epoch 33, iteration 20
Epoch 33, iteration 30
Epoch 34, iteration 10
Epoch 34, iteration 20
Epoch 34, iteration 30
Epoch 35, iteration 10
Epoch 35, iteration 20
Epoch 35, iteration 30
Epoch 36, iteration 10
Epoch 36, iteration 20
Epoch 36, iteration 30
Epoch 37, iteration 10
Epoch 37, iteration 20
Epoch 37, iteration 30
Epoch 38, iteration 10
Epoch 38, iteration 20
Epoch 38, iteration 30
Epoch 39, iteration 10
Epoch 39, iteration 20
Epoch 39, iteration 30


In [66]:
results_df = pd.DataFrame.from_dict(results_dict)

In [67]:
results_df

Unnamed: 0,dataset,num_epochs,minibatch_size,num_inducing_pts,wallclock_time,spearmanr,mae
0,fluorescence_eval/onehot/standard,20,1000,500,24.940711,"(0.6078755834404194, 0.0)",0.591547
1,aav_eval/onehot/des_mut_split,20,1000,500,111.698156,"(0.5760761882358676, 0.0)",2.543405
2,aav_eval/onehot/mut_des_split,20,1000,500,45.940478,"(0.6713233432967354, 0.0)",2.981712
3,aav_eval/onehot/seven_vs_many_split,20,1000,500,38.740947,"(0.5620174037251007, 0.0)",4.198997
4,kin40k_dataset,20,1000,500,14.629709,"(0.972348994253338, 0.0)",0.187201
5,song_dataset,20,1000,500,191.756128,"(0.5626222427006633, 0.0)",0.570566
6,uci_protein_dataset,20,1000,500,16.827276,"(0.7089183254442435, 0.0)",0.566337
7,fluorescence_eval/onehot/standard,20,1000,3000,620.018441,"(0.6341950687567559, 0.0)",0.552083
8,aav_eval/onehot/des_mut_split,20,1000,3000,4737.131272,"(0.623803408022923, 0.0)",2.351078
9,aav_eval/onehot/mut_des_split,20,1000,3000,1946.572681,"(0.6550181954678969, 0.0)",2.939273


In [68]:
os.chdir(os.path.join(home, "final_results"))
if "gpytorch_results.txt" not in os.listdir():
    results_df.to_csv("gpytorch_results.txt")

In [6]:
results_dict_rd2 = {"dataset":[], "num_epochs":[], "minibatch_size":[],
                "num_inducing_pts":[],
               "wallclock_time":[], "spearmanr":[], "mae":[]}

In [7]:
for num_epochs in [40]:
    for num_inducing, dataset in zip([25,427,2968],
                    ["gb1_eval/onehot/one_vs_rest",
                    "gb1_eval/onehot/two_vs_rest",
                    "gb1_eval/onehot/three_vs_rest"]):
        mbatch_size = max(250, num_inducing)
        wclock, spearman, mae = build_test_gpytorch_model(home, dataset, num_epochs = num_epochs, 
                              num_inducing = num_inducing, minibatch_size = mbatch_size)
        results_dict_rd2["dataset"].append(dataset)
        results_dict_rd2["num_epochs"].append(num_epochs)
        results_dict_rd2["num_inducing_pts"].append(num_inducing)
        results_dict_rd2["spearmanr"].append(spearman)
        results_dict_rd2["mae"].append(mae)
        results_dict_rd2["wallclock_time"].append(wclock)
        results_dict_rd2["minibatch_size"].append(mbatch_size)
        print("%s complete"%dataset)
    
        time.sleep(60)

torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:1672.)
  res = torch.triangular_solve(right_tensor, self.evaluate(), upper=self.upper).solution


gb1_eval/onehot/one_vs_rest complete
gb1_eval/onehot/two_vs_rest complete
gb1_eval/onehot/three_vs_rest complete


In [8]:
results_df2 = pd.DataFrame.from_dict(results_dict_rd2)

In [9]:
results_df2

Unnamed: 0,dataset,num_epochs,minibatch_size,num_inducing_pts,wallclock_time,spearmanr,mae
0,gb1_eval/onehot/one_vs_rest,40,250,25,0.797858,"(0.292581334541054, 2.196010253954897e-171)",0.957145
1,gb1_eval/onehot/two_vs_rest,40,427,427,0.50322,"(0.47783903509123976, 0.0)",0.880399
2,gb1_eval/onehot/three_vs_rest,40,2968,2968,51.30728,"(0.8199098255774282, 0.0)",0.738451
