## Example: Sequence Data

Next, let's consider the AAV dataset, designed vs mutant split,
from the FLIP benchmark suite. For this dataset, we train on 200,000
amino acid sequences of length 57 and try to predict the fitness
of a pre-specified test set. Dallago et al. report that a standard
1d-CNN trained on this achieves a Spearman's r of 0.75, while
a 750-million parameter pretrained model that took 50 GPU-days of
time to train achieves a Spearman's r of 0.79.

We'll evaluate a convolution kernel and show that we can easily
match or outperform the deep learning baselines without too
much effort.

This was originally run using xGPR v0.2.

In [1]:
import os
import shutil
import subprocess
import math
import time
import zipfile

import pandas as pd
import numpy as np

from xGPR import xGPRegression as xGPReg
from xGPR import build_regression_dataset


In [2]:
#This may take a minute...
subprocess.run(["git", "clone", "https://github.com/J-SNACKKB/FLIP"])

shutil.move(os.path.join("FLIP", "splits", "aav", "full_data.csv.zip"), "full_data.csv.zip")
fname = "full_data.csv.zip"

with zipfile.ZipFile(fname, "r") as zip_ref:
    zip_ref.extractall()

os.remove("full_data.csv.zip")


shutil.rmtree("FLIP")

Cloning into 'FLIP'...


In [3]:
raw_data = pd.read_csv("full_data.csv")
os.remove("full_data.csv")


In [4]:
raw_data["input_seq"] = [f.upper().replace("*", "") for f in raw_data["mutated_region"].tolist()]

We'll use simple one-hot encoding for the sequences. This may take a minute to set up.

In [5]:
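def one_hot_encode(input_seq_list, y_values, chunk_size, ftype = "train"):
    """One-hot encode sequences and save them to disk in chunks of chunk_size.

    Returns the lists of x and y .npy files written to the current directory.
    """
    # 20 standard amino acids plus a gap character.
    aas = ["A", "C", "D", "E", "F", "G", "H", "I",
           "K", "L", "M", "N", "P", "Q", "R", "S", "T",
           "V", "W", "Y", "-"]
    output_x, output_y = [], []
    xfiles, yfiles = [], []
    fcounter = 0

    for seq, y_value in zip(input_seq_list, y_values):
        # Each sequence becomes a (1, 57, 21) array; positions past the end
        # of a shorter sequence are left as all zeros.
        encoded_x = np.zeros((1,57,21), dtype = np.uint8)
        for i, letter in enumerate(seq):
            encoded_x[0, i, aas.index(letter)] = 1

        output_x.append(encoded_x)
        output_y.append(y_value)

        # Once a full chunk has accumulated, write it to disk and start a new one.
        if len(output_x) >= chunk_size:
            xfiles.append(f"{fcounter}_{ftype}_xblock.npy")
            yfiles.append(f"{fcounter}_{ftype}_yblock.npy")
            np.save(xfiles[-1], np.vstack(output_x))
            np.save(yfiles[-1], np.asarray(output_y))
            fcounter += 1
            output_x, output_y = [], []
            print(f"Encoded file {fcounter}")
    # Note: sequences left over after the last full chunk are not saved.
    return xfiles, yfiles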
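
To make the encoding concrete, here is a small sketch that one-hot encodes a single toy sequence with the same 21-symbol alphabet. The helper name `encode_one`, the toy sequence and the shortened `max_len` are ours, chosen only for illustration; this is not part of xGPR.

```python
import numpy as np

# Same 21-symbol alphabet as above (20 amino acids plus a gap character).
AAS = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M",
       "N", "P", "Q", "R", "S", "T", "V", "W", "Y", "-"]
AA_TO_INDEX = {aa: i for i, aa in enumerate(AAS)}

def encode_one(seq, max_len=57):
    """One-hot encode one sequence into a (max_len, 21) array.

    Positions past the end of the sequence stay all-zero.
    """
    encoded = np.zeros((max_len, len(AAS)), dtype=np.uint8)
    for i, letter in enumerate(seq):
        encoded[i, AA_TO_INDEX[letter]] = 1
    return encoded

toy = encode_one("ACD-", max_len=6)  # hypothetical toy sequence
print(toy.shape)        # (6, 21)
print(toy.sum(axis=1))  # [1 1 1 1 0 0]
```

Each occupied position has exactly one nonzero entry, and the two unused positions at the end are all zeros.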

In [6]:
train_data = raw_data[raw_data["des_mut_split"]=="train"]
test_data = raw_data[raw_data["des_mut_split"]=="test"]


train_x_files, train_y_files = one_hot_encode(train_data["input_seq"].tolist(),
                                              train_data["score"].tolist(), 2000, "train")
test_x_files, test_y_files = one_hot_encode(test_data["input_seq"].tolist(),
                                            test_data["score"].tolist(), 2000, "test")

Encoded file 1
Encoded file 2
Encoded file 3
Encoded file 4
Encoded file 5
Encoded file 6
Encoded file 7
Encoded file 8
Encoded file 9
Encoded file 10
Encoded file 11
Encoded file 12
Encoded file 13
Encoded file 14
Encoded file 15
Encoded file 16
Encoded file 17
Encoded file 18
Encoded file 19
Encoded file 20
Encoded file 21
Encoded file 22
Encoded file 23
Encoded file 24
Encoded file 25
Encoded file 26
Encoded file 27
Encoded file 28
Encoded file 29
Encoded file 30
Encoded file 31
Encoded file 32
Encoded file 33
Encoded file 34
Encoded file 35
Encoded file 36
Encoded file 37
Encoded file 38
Encoded file 39
Encoded file 40
Encoded file 41
Encoded file 42
Encoded file 43
Encoded file 44
Encoded file 45
Encoded file 46
Encoded file 47
Encoded file 48
Encoded file 49
Encoded file 50
Encoded file 51
Encoded file 52
Encoded file 53
Encoded file 54
Encoded file 55
Encoded file 56
Encoded file 57
Encoded file 58
Encoded file 59
Encoded file 60
Encoded file 61
Encoded file 62
Encoded file 63
...

In [7]:
training_dset = build_regression_dataset(train_x_files, train_y_files, chunk_size = 2000)

Here we'll use the FHTConv1d kernel, a kernel for sequences. Convolution kernels are usually slower than RBF / Matern, especially if the sequence is long. We'll run a quick and dirty tuning experiment using 1024 random features, then fine-tune this using a larger number of random features just as we did for the tabular dataset.

Many kernels in xGPR have kernel-specific settings. For FHTConv1d, we can set two key options: whether to average over the sequence ("averaging", defaults to False) and the width of the convolution to use ("conv_width"). Just as with a convolutional network, the width of the convolution filters can affect performance. One way to choose a good setting: run hyperparameter tuning (e.g. with ``crude_bayes`` or ``crude_grid``) using a small number of RFFs (e.g. 1024 - 2048) for several different settings of "conv_width" and compare the scores you get. The score is a negative marginal log likelihood, so the smallest score achieved likely corresponds to the best value for "conv_width".
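
The selection rule for "conv_width" amounts to "score each candidate, keep the one with the lowest score". A minimal sketch is below. The helper `pick_best_setting` and the toy scores are hypothetical; with xGPR, `score_fn` would build a model with the given "conv_width", run the crude tuning step on the training set and return the best score it reports.

```python
def pick_best_setting(candidates, score_fn):
    """Return (best_candidate, scores); lower scores are better."""
    scores = {c: score_fn(c) for c in candidates}
    best = min(scores, key=scores.get)
    return best, scores

# Toy stand-in scores so the sketch runs on its own. These numbers are
# made up and are NOT results from the AAV dataset.
toy_scores = {5: 3.0, 9: 1.0, 13: 2.0}

best_width, all_scores = pick_best_setting([5, 9, 13], toy_scores.get)
print(best_width)  # 9
```

Since each candidate needs its own tuning run, keeping the number of RFFs small at this stage keeps the comparison cheap.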

In [8]:
aav_model = xGPReg(num_rffs = 1024, variance_rffs = 512,
                   kernel_choice = "FHTConv1d",
                   kernel_settings = {"conv_width":9, "averaging":False},
                   verbose = True, device = "gpu")

start_time = time.time()
hparams, niter, best_score = aav_model.tune_hyperparams_crude(training_dset)
end_time = time.time()

print(f"Best estimated negative marginal log likelihood: {best_score}")
print(f"Wallclock: {end_time - start_time}")

Grid point 0 acquired.
Grid point 1 acquired.
Grid point 2 acquired.
Grid point 3 acquired.
Grid point 4 acquired.
Grid point 5 acquired.
Grid point 6 acquired.
Grid point 7 acquired.
Grid point 8 acquired.
Grid point 9 acquired.
New hparams: [-2.1170445]
Additional acquisition 10.
New hparams: [-2.2990037]
Additional acquisition 11.
New hparams: [-2.2856463]
Best score achieved: 130393.122
Best hyperparams: [-1.0931465 -2.2856463]
Best estimated negative marginal log likelihood: 130393.122
Wallclock: 31.07062029838562


We now have a rough estimate of our hyperparameters, acquired using a sketchy kernel approximation
(num_rffs=1024) and a crude tuning procedure. Let's fine-tune this a little. We could use
the built-in tuning routine in xGPR the way we did for the tabular data, or we could use
Optuna (or some other library), or we could do a simple grid search. For illustrative
purposes here, we'll use Optuna with num_rffs=4096 (a somewhat better kernel
approximation) and see what that looks like. We'll search the region around the
hyperparameters obtained from ``tune_hyperparams_crude``. To run this
next piece, you'll need to have Optuna installed. Optuna is one of our
favorite tools for this and is often able to do a little better than other methods.
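
For comparison, the simple grid search alternative mentioned above could look roughly like the sketch below. The helper `grid_around`, the toy objective and the rounded starting point are all assumptions made so the example runs on its own; in practice the objective would evaluate the model's NMLL at each candidate instead.

```python
import itertools

def grid_around(center, half_width=1.0, points_per_axis=5):
    """Build a regular grid of candidate hyperparameter vectors around center."""
    step = 2 * half_width / (points_per_axis - 1)
    axes = [[c - half_width + k * step for k in range(points_per_axis)]
            for c in center]
    return list(itertools.product(*axes))

def toy_objective(hparams):
    # Stand-in for an NMLL evaluation: a bowl with its minimum at
    # (-0.5, -2.0), chosen only for illustration.
    lam, sigma = hparams
    return (lam + 0.5) ** 2 + (sigma + 2.0) ** 2

crude_estimate = (-1.0, -2.0)  # hypothetical, rounded starting point
candidates = grid_around(crude_estimate)
best = min(candidates, key=toy_objective)
print(len(candidates), best)  # 25 (-0.5, -2.0)
```

A grid costs one objective evaluation per point, so it scales poorly with the number of hyperparameters; a sampler such as Optuna's TPE spends its evaluations adaptively instead.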

In [9]:
import optuna
from optuna.samplers import TPESampler

def objective(trial):
    lambda_ = trial.suggest_float("lambda_", -2., 0.)
    sigma = trial.suggest_float("sigma", -3., -1.)
    hparams = np.array([lambda_, sigma])
    nmll = aav_model.exact_nmll(hparams, training_dset)
    return nmll

In [10]:
aav_model.num_rffs = 4096

sampler = TPESampler(seed=123)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=35)

[I 2023-11-28 09:45:09,310] A new study created in memory with name: no-name-e07ce6e5-f71f-4c9e-b987-5f01481589c9
[I 2023-11-28 09:45:26,211] Trial 0 finished with value: 116884.97261995246 and parameters: {'lambda_': -0.6070616288042767, 'sigma': -2.4277213300992413}. Best is trial 0 with value: 116884.97261995246.


Evaluated NMLL.


[I 2023-11-28 09:45:43,088] Trial 1 finished with value: 114181.22032483175 and parameters: {'lambda_': -1.5462970928715938, 'sigma': -1.8973704618342175}. Best is trial 1 with value: 114181.22032483175.


Evaluated NMLL.


[I 2023-11-28 09:45:59,962] Trial 2 finished with value: 113271.88301457987 and parameters: {'lambda_': -0.5610620604288739, 'sigma': -2.153787079751078}. Best is trial 2 with value: 113271.88301457987.


Evaluated NMLL.


[I 2023-11-28 09:46:16,900] Trial 3 finished with value: 112182.98616781154 and parameters: {'lambda_': -0.038471603230769036, 'sigma': -1.6303405228302734}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:46:33,838] Trial 4 finished with value: 112451.94625265888 and parameters: {'lambda_': -1.0381361970312781, 'sigma': -2.215764963611699}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:46:50,780] Trial 5 finished with value: 116679.78219362846 and parameters: {'lambda_': -1.3136439676982612, 'sigma': -1.5419005852319168}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:47:07,721] Trial 6 finished with value: 120520.54953081356 and parameters: {'lambda_': -1.1228555106407512, 'sigma': -2.8806442067808633}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:47:24,666] Trial 7 finished with value: 116443.92207441722 and parameters: {'lambda_': -1.2039114893391372, 'sigma': -1.5240091885359286}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:47:41,605] Trial 8 finished with value: 114600.58825822489 and parameters: {'lambda_': -1.635016539093, 'sigma': -2.649096487705015}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:47:58,542] Trial 9 finished with value: 112309.72714681382 and parameters: {'lambda_': -0.9368972523163233, 'sigma': -1.9363448258062679}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:48:15,484] Trial 10 finished with value: 118087.34232658173 and parameters: {'lambda_': -0.03129046371014721, 'sigma': -1.0097302786311597}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:48:32,427] Trial 11 finished with value: 112422.48824147781 and parameters: {'lambda_': -0.01866037953665966, 'sigma': -1.8420837673609056}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:48:49,370] Trial 12 finished with value: 117675.56910373297 and parameters: {'lambda_': -1.86792286212751, 'sigma': -1.6554794205210857}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:49:06,316] Trial 13 finished with value: 116479.30082963311 and parameters: {'lambda_': -0.6682392377917878, 'sigma': -1.334284575460567}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:49:23,258] Trial 14 finished with value: 112972.71033390879 and parameters: {'lambda_': -0.3272278654086118, 'sigma': -2.0359901306985515}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:49:40,203] Trial 15 finished with value: 112217.96159582998 and parameters: {'lambda_': -0.8172325331915455, 'sigma': -1.9106811994190158}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:49:57,157] Trial 16 finished with value: 117274.06183211948 and parameters: {'lambda_': -0.2603175001600282, 'sigma': -2.3111339546818375}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:50:14,100] Trial 17 finished with value: 112619.09588957127 and parameters: {'lambda_': -0.8180620085440145, 'sigma': -1.8200765777075854}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:50:31,043] Trial 18 finished with value: 116250.46632943655 and parameters: {'lambda_': -0.32229338117455153, 'sigma': -1.2419508741475265}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:50:47,986] Trial 19 finished with value: 114084.89684029567 and parameters: {'lambda_': -0.8671099982598062, 'sigma': -1.6394201895068492}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:51:04,927] Trial 20 finished with value: 113199.17909443154 and parameters: {'lambda_': -0.532872551556751, 'sigma': -2.136297897825651}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:51:21,869] Trial 21 finished with value: 112390.44567547613 and parameters: {'lambda_': -0.9186721161131867, 'sigma': -1.9082246988035152}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:51:38,814] Trial 22 finished with value: 113226.48339402488 and parameters: {'lambda_': -0.8068172896190198, 'sigma': -1.7220230284031095}. Best is trial 3 with value: 112182.98616781154.


Evaluated NMLL.


[I 2023-11-28 09:51:55,757] Trial 23 finished with value: 112114.72530880311 and parameters: {'lambda_': -1.0103391936363806, 'sigma': -2.0911671517463963}. Best is trial 23 with value: 112114.72530880311.


Evaluated NMLL.


[I 2023-11-28 09:52:12,696] Trial 24 finished with value: 112280.49477969573 and parameters: {'lambda_': -1.2392140586015046, 'sigma': -2.087070595587511}. Best is trial 23 with value: 112114.72530880311.


Evaluated NMLL.


[I 2023-11-28 09:52:29,638] Trial 25 finished with value: 114474.11009340384 and parameters: {'lambda_': -1.0468149298796818, 'sigma': -2.4317724702573495}. Best is trial 23 with value: 112114.72530880311.


Evaluated NMLL.


[I 2023-11-28 09:52:46,581] Trial 26 finished with value: 112187.96857642675 and parameters: {'lambda_': -0.7052050499628298, 'sigma': -2.049740057759603}. Best is trial 23 with value: 112114.72530880311.


Evaluated NMLL.


[I 2023-11-28 09:53:03,525] Trial 27 finished with value: 112561.07469358988 and parameters: {'lambda_': -0.4200102911513557, 'sigma': -2.020155365111411}. Best is trial 23 with value: 112114.72530880311.


Evaluated NMLL.


[I 2023-11-28 09:53:20,467] Trial 28 finished with value: 112025.64316648839 and parameters: {'lambda_': -0.15817474787806624, 'sigma': -1.7725920521360994}. Best is trial 28 with value: 112025.64316648839.


Evaluated NMLL.


[I 2023-11-28 09:53:37,410] Trial 29 finished with value: 112043.3765583417 and parameters: {'lambda_': -0.15429868271440328, 'sigma': -1.7910748578352933}. Best is trial 28 with value: 112025.64316648839.


Evaluated NMLL.


[I 2023-11-28 09:53:54,354] Trial 30 finished with value: 117769.18858275587 and parameters: {'lambda_': -0.19240335145379095, 'sigma': -2.3114322360951163}. Best is trial 28 with value: 112025.64316648839.


Evaluated NMLL.


[I 2023-11-28 09:54:11,296] Trial 31 finished with value: 112027.72722358054 and parameters: {'lambda_': -0.15545204796200673, 'sigma': -1.774862465080108}. Best is trial 28 with value: 112025.64316648839.


Evaluated NMLL.


[I 2023-11-28 09:54:28,241] Trial 32 finished with value: 112031.75142268102 and parameters: {'lambda_': -0.14590070421398404, 'sigma': -1.7753784396393886}. Best is trial 28 with value: 112025.64316648839.


Evaluated NMLL.


[I 2023-11-28 09:54:45,188] Trial 33 finished with value: 112163.16659169078 and parameters: {'lambda_': -0.457054472835695, 'sigma': -1.7798674390851137}. Best is trial 28 with value: 112025.64316648839.


Evaluated NMLL.


[I 2023-11-28 09:55:02,137] Trial 34 finished with value: 112021.80990490183 and parameters: {'lambda_': -0.17193176584567704, 'sigma': -1.7713977474762077}. Best is trial 34 with value: 112021.80990490183.


Evaluated NMLL.


Set the model hyperparameters to the best ones found by Optuna.

In [11]:
study.best_params

{'lambda_': -0.17193176584567704, 'sigma': -1.7713977474762077}

In [13]:
aav_model.set_hyperparams(np.array([-0.17193176, -1.771397]), training_dset)

Now we'll fit the model using 8192 RFFs. We like to use a more accurate kernel approximation when fitting than when tuning for two reasons. First, tuning is more expensive because the model has to be fit multiple times when tuning hyperparameters. Second, model performance usually
increases faster by increasing the number of RFFs used for fitting than for tuning. (Using 16,384 RFFs here for fitting further
increases test set performance, as you'd expect.)
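
To see why more RFFs give a more accurate kernel approximation, here is a generic sketch using textbook random Fourier features for an RBF kernel on random toy data. This is not the FHTConv1d kernel and not xGPR's implementation; the function names and data are ours.

```python
import numpy as np

def rbf_kernel(x, lengthscale=1.0):
    """Exact RBF kernel matrix for the rows of x."""
    sq_dists = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2 * lengthscale ** 2))

def random_fourier_features(x, num_rffs, rng, lengthscale=1.0):
    """Random features whose inner products approximate the RBF kernel."""
    w = rng.normal(scale=1.0 / lengthscale, size=(x.shape[1], num_rffs))
    b = rng.uniform(0.0, 2 * np.pi, size=num_rffs)
    return np.sqrt(2.0 / num_rffs) * np.cos(x @ w + b)

rng = np.random.default_rng(0)
x = rng.normal(size=(50, 5))  # toy data, not the AAV sequences
exact = rbf_kernel(x)

errors = {}
for num_rffs in (64, 1024, 8192):
    z = random_fourier_features(x, num_rffs, rng)
    # Mean absolute difference between approximate and exact kernel matrices.
    errors[num_rffs] = np.abs(z @ z.T - exact).mean()
    print(num_rffs, errors[num_rffs])
```

The approximation error shrinks roughly like one over the square root of the number of features, which is why going from 4096 to 8192 or 16,384 RFFs helps, but with diminishing returns.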

On GPU, for fitting, ``mode=exact`` works well up to 8,192 RFFs or so, while ``mode=cg``, although
slower for small numbers of RFFs, is more scalable. On this dataset, using 8,192 RFFs, "exact" takes about 70 seconds on our GPU.
We'll use CG here just for illustrative purposes. Notice that fitting with default settings takes about 45 iterations with
CG. We can speed this up by changing the defaults (see the Advanced Tutorials for more on how to do this).

``tol`` determines how tight the fit is. 1e-6 (default) is usually fine. Decreasing the number will improve performance, but
with rapidly diminishing returns, and will make fitting take longer. For noise-free data, or to get a small additional boost in
performance, use 1e-7. 1e-8 is (nearly always) overkill.
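
The trade-off between tolerance and fitting time can be seen with a plain textbook conjugate gradients solver on a small toy system. This sketch assumes the tolerance acts as a threshold on the relative residual; xGPR's own stopping rule and preconditioning may differ, and `conjugate_gradients` here is our own helper, not a library call.

```python
import numpy as np

def conjugate_gradients(a, b, tol=1e-6, max_iter=1000):
    """Solve a @ x = b for symmetric positive definite a.

    Stops once ||residual|| / ||b|| <= tol. Returns (x, iterations used).
    """
    x = np.zeros_like(b)
    r = b - a @ x
    p = r.copy()
    rs_old = r @ r
    b_norm = np.linalg.norm(b)
    for n_iter in range(1, max_iter + 1):
        ap = a @ p
        alpha = rs_old / (p @ ap)
        x = x + alpha * p
        r = r - alpha * ap
        rs_new = r @ r
        if np.sqrt(rs_new) / b_norm <= tol:
            return x, n_iter
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x, max_iter

rng = np.random.default_rng(0)
m = rng.normal(size=(200, 200))
a = m @ m.T + 200 * np.eye(200)  # symmetric positive definite toy system
b = rng.normal(size=200)

iterations = {}
for tol in (1e-4, 1e-6, 1e-8):
    x, iterations[tol] = conjugate_gradients(a, b, tol=tol)
    rel_resid = np.linalg.norm(a @ x - b) / np.linalg.norm(b)
    print(f"tol={tol:g}: {iterations[tol]} iterations, "
          f"relative residual {rel_resid:.2e}")
```

A tighter tolerance can never need fewer iterations than a looser one, since the solver follows the same sequence of iterates and simply stops later.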

In [14]:
aav_model.num_rffs = 8192
start_time = time.time()
aav_model.fit(training_dset, mode = 'cg', tol = 1e-6)
end_time = time.time()
print(f"Wallclock: {end_time - start_time}")

starting fitting
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
Using rank: 1024
Chunk 0 complete.
Chunk 10 complete.
Chunk 20 complete.
Chunk 30 complete.
Chunk 40 complete.
Chunk 50 complete.
Chunk 60 complete.
Chunk 70 complete.
Chunk 80 complete.
Chunk 90 complete.
Iteration 0
Iteration 5
Iteration 10
Iteration 15
Iteration 20
Iteration 25
Iteration 30
Iteration 35
Iteration 40
Iteration 45
Now performing variance calculations...
Fitting complete.
Wallclock: 419.55269742012024


In [15]:
start_time = time.time()
all_preds, ground_truth = [], []
for xfile, yfile in zip(test_x_files, test_y_files):
    x, y = np.load(xfile), np.load(yfile)
    ground_truth.append(y)
    preds = aav_model.predict(x, get_var = False)
    all_preds.append(preds)
    
all_preds, ground_truth = np.concatenate(all_preds), np.concatenate(ground_truth)
end_time = time.time()
print(f"Wallclock: {end_time - start_time}")

Wallclock: 3.5997021198272705


In [16]:
from scipy.stats import spearmanr

spearmanr(all_preds, ground_truth)

SignificanceResult(statistic=0.7548986561289404, pvalue=0.0)
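
As a reminder of what this metric measures, Spearman's r is the Pearson correlation of the ranks, so it only cares about how well the predictions order the test sequences. The sketch below uses synthetic data and a simplified helper of our own that assumes no tied values (`scipy.stats.spearmanr` handles ties properly).

```python
import numpy as np

def spearman_no_ties(x, y):
    """Spearman's r as the Pearson correlation of ranks (assumes no ties)."""
    rank_x = np.argsort(np.argsort(x))
    rank_y = np.argsort(np.argsort(y))
    return np.corrcoef(rank_x, rank_y)[0, 1]

rng = np.random.default_rng(0)
truth = rng.normal(size=1000)                # synthetic "ground truth"
noisy_preds = truth + rng.normal(size=1000)  # synthetic "predictions"

# A monotone transform of the predictions leaves the ranks, and therefore
# the score, unchanged.
r_raw = spearman_no_ties(noisy_preds, truth)
r_transformed = spearman_no_ties(np.exp(noisy_preds), truth)
print(r_raw, r_transformed)
```

This is why Spearman's r is a natural choice when the goal is to rank candidate sequences rather than to predict exact fitness values.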

A Spearman's r of just over 0.75 matches the performance of the 1d-CNN reported by Dallago et al.
for this dataset and is similar to the performance of a fine-tuned LLM (Spearman's r 0.79).
As discussed above, we can get further slight improvements in performance
just by tweaking this model. We can do even better by using a more informative
representation of the protein sequences. In our original paper we achieved a Spearman's r
of about 0.8 on this dataset, outperforming fine-tuned LLMs (and costing significantly less to train
than a fine-tuned LLM).
Whether small gains in performance from further "tweaking" or more informative representations are worthwhile
obviously depends on your application...

In [19]:
for testx, testy in zip(test_x_files, test_y_files):
    os.remove(testx)
    os.remove(testy)

In [None]:
for xfile, yfile in zip(train_x_files, train_y_files):
    os.remove(xfile)
    os.remove(yfile)