In [1]:
import random
import time
import torch
from bayes_opt import BayesianOptimization

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import TensorDataset

from elfragmentador.utils import get_random_peptide
from elfragmentador import model

def concat_batches(batches):
    """Merge a sequence of batches field by field.

    Each batch is indexed positionally; the tensors found at the same position
    across all batches are joined with ``torch.cat`` (its default dimension).

    Args:
        batches: Non-empty sequence of indexable batches. The first batch
            determines how many fields are merged.

    Returns:
        tuple: One concatenated tensor per field of the first batch.
    """
    num_fields = len(batches[0])
    return tuple(
        torch.cat([batch[field] for batch in batches]) for field in range(num_fields)
    )

def prepare_input_tensors(num=50):
    """Build a ``TensorDataset`` of randomly generated peptide inputs.

    For each of ``num`` peptides a collision energy in ``[20, 30)``, a charge
    between 1 and 5 (inclusive) and a random sequence are drawn. Every peptide
    is then encoded with ``PepTransformerModel.torch_batch_from_seq`` and the
    resulting batches are concatenated field by field.

    Args:
        num: Number of random peptides to generate.

    Returns:
        TensorDataset: Dataset wrapping the concatenated input tensors.
    """
    peptide_specs = []
    for _ in range(num):
        # Draw order (nce, charge, seq) is fixed so that a seeded RNG is
        # consumed in the same sequence on every run.
        nce = 20 + (10 * random.random())
        charge = random.randint(1, 5)
        peptide_specs.append(
            {"nce": nce, "charge": charge, "seq": get_random_peptide()}
        )

    # All specs are generated before any of them is encoded.
    encoded_batches = [
        model.PepTransformerModel.torch_batch_from_seq(**spec)
        for spec in peptide_specs
    ]

    return TensorDataset(*concat_batches(batches=encoded_batches))

# Number of random peptides that make up the timing workload.
NUM_PEPTIDES = 50
input_tensors = prepare_input_tensors(NUM_PEPTIDES)
# batch_size=1 puts a single dataset row in each batch, so len(batches) equals
# len(input_tensors); the objective divides the elapsed time by len(batches).
batches = DataLoader(input_tensors, batch_size=1)

In [2]:

@torch.no_grad()
def measure_time(model, batches):
    """Time one forward pass of ``model`` over every batch in ``batches``.

    The model is switched to evaluation mode (and left in it) and is run with
    gradient tracking disabled by the ``torch.no_grad`` decorator.

    Args:
        model: Callable exposing ``eval()``; invoked as ``model(*batch)``.
        batches: Iterable of batches; each batch is unpacked into the model's
            positional arguments.

    Returns:
        float: Seconds elapsed across the whole loop (never negative).
    """
    model.eval()
    # perf_counter is monotonic and high resolution; time.time() follows the
    # wall clock, which can be adjusted (even backwards) while we measure.
    start = time.perf_counter()
    for batch in batches:
        model(*batch)

    return time.perf_counter() - start


def optimize_time_budget(budget, batches):
    """Create the objective that scores an architecture by its inference time.

    Args:
        budget: Target time in seconds that ``elapsed / len(batches)`` should
            match.
        batches: Sized iterable of model inputs, forwarded to ``measure_time``.

    Returns:
        A function of ``(num_encoder_layers, num_decoder_layers, nhid, ninp,
        nhead)`` returning ``1 - abs(budget - elapsed / len(batches))``. The
        score reaches its maximum of 1 when the measured time equals the
        budget.
    """

    def _main(num_encoder_layers, num_decoder_layers, nhid, ninp, nhead):
        # Proposed values may be floats; truncate each one to an integer first.
        encoder_depth = int(num_encoder_layers)
        # NOTE(review): the decoder depth is the truncated *sum* of both layer
        # arguments, not num_decoder_layers on its own -- confirm intended.
        decoder_depth = int(num_encoder_layers + num_decoder_layers)
        hidden_width = int(nhid)
        input_width = int(ninp)
        head_count = int(nhead)

        # Both widths are rounded down to a multiple of the head count.
        arch_kwargs = {
            'num_encoder_layers': encoder_depth,
            'num_decoder_layers': decoder_depth,
            'nhid': int(hidden_width / head_count) * head_count,
            'ninp': int(input_width / head_count) * head_count,
            'nhead': head_count,
        }

        try:
            candidate = model.PepTransformerModel(**arch_kwargs)
        except AssertionError:
            # Show the offending configuration before propagating the error.
            print(arch_kwargs)
            raise

        total_seconds = measure_time(model=candidate, batches=batches)

        # Seconds per batch; this is per sample only when every batch holds a
        # single sample.
        seconds_per_batch = total_seconds / len(batches)
        return 1 - abs(budget - seconds_per_batch)

    return _main


In [3]:
import warnings
# Suppress library warnings so they do not drown the optimizer's progress table.
warnings.simplefilter("ignore")

import logging
# Raise the threshold on the actual root logger. getLogger() with no name
# always returns it, whereas getLogger("root") only does so on Python >= 3.9
# and otherwise creates an unrelated logger that happens to be named "root".
logging.getLogger().setLevel(logging.ERROR)

import pandas as pd

# Search space given to the optimizer as (lower, upper) bounds per parameter.
# The optimizer proposes float values; the objective built by
# optimize_time_budget truncates them to integers before building a model.
BOUNDS = {
    'num_encoder_layers':(2,6),
    # Added to num_encoder_layers inside the objective, so this acts as an
    # offset on top of the encoder depth -- NOTE(review): confirm intended.
    'num_decoder_layers':(1,4),
    'nhid':(64, 2048),
    'ninp':(64, 2048),
    'nhead':(2,8),
}

# Target time in seconds; the objective compares it with the measured
# elapsed time divided by len(batches).
BUDGET = 0.02

optimizer = BayesianOptimization(
    f=optimize_time_budget(budget=BUDGET, batches=batches),
    pbounds=BOUNDS,
)

# Every evaluation of the objective builds and times a fresh model.
optimizer.maximize(
    n_iter=500,
)

|   iter    |  target   |   nhead   |   nhid    |   ninp    | num_de... | num_en... |
-------------------------------------------------------------------------------------
| [0m 1       [0m | [0m 0.8784  [0m | [0m 5.664   [0m | [0m 1.068e+0[0m | [0m 868.8   [0m | [0m 3.204   [0m | [0m 3.558   [0m |
| [95m 2       [0m | [95m 0.9936  [0m | [95m 3.625   [0m | [95m 430.9   [0m | [95m 337.8   [0m | [95m 2.731   [0m | [95m 3.1     [0m |
| [0m 3       [0m | [0m 0.8822  [0m | [0m 2.582   [0m | [0m 817.0   [0m | [0m 806.2   [0m | [0m 2.011   [0m | [0m 5.117   [0m |
| [0m 4       [0m | [0m 0.8152  [0m | [0m 7.331   [0m | [0m 1.982e+0[0m | [0m 891.9   [0m | [0m 3.412   [0m | [0m 3.144   [0m |
| [0m 5       [0m | [0m 0.8357  [0m | [0m 6.592   [0m | [0m 65.15   [0m | [0m 1.173e+0[0m | [0m 3.274   [0m | [0m 2.289   [0m |
| [0m 6       [0m | [0m 0.9917  [0m | [0m 4.58    [0m | [0m 431.0   [0m | [0m 329.2   [0m | [0m 1.06

In [4]:
# Collect every probed parameter set and its score in a single pass over the
# optimizer results, then tabulate them with the score in a "Target" column.
_probed_params, _scores = [], []
for _result in optimizer.res:
    _probed_params.append(_result["params"])
    _scores.append(_result["target"])

x_obs = pd.DataFrame(_probed_params)
x_obs["Target"] = _scores
x_obs

Unnamed: 0,nhead,nhid,ninp,num_decoder_layers,num_encoder_layers,Target
0,5.664336,1068.371454,868.833242,3.204300,3.558123,0.878355
1,3.625233,430.859471,337.825087,2.730968,3.099728,0.993642
2,2.582101,817.028404,806.213283,2.010582,5.116899,0.882202
3,7.330757,1982.492847,891.912342,3.412071,3.143991,0.815231
4,6.592314,65.148256,1173.348437,3.273565,2.289407,0.835672
...,...,...,...,...,...,...
500,2.632558,1037.675814,103.848356,3.068476,5.399551,0.999067
501,8.000000,575.351882,92.193556,4.000000,2.000000,0.993035
502,8.000000,1009.466715,64.000000,1.000000,6.000000,0.997500
503,2.337868,933.794860,161.062386,3.403104,2.473007,0.995685


In [5]:
# Keep only the observations whose score is within 10% of the time budget:
# Target = 1 - |BUDGET - t|, so Target > 1 - 0.1 * BUDGET means the measured
# time t differs from BUDGET by less than a tenth of it.
_min_target = 1 - (BUDGET * 0.1)
_close_enough = x_obs["Target"] > _min_target
df = x_obs[_close_enough].sort_values("Target").reset_index(drop=True)
print(df)
# The CSV holds the raw (float) optimizer parameters of the retained rows.
df.to_csv("bayes_opt_arches.csv", index=False)

       nhead         nhid        ninp  num_decoder_layers  num_encoder_layers  \
0   8.000000  1925.943331  203.673323            1.000000            2.000000   
1   2.000000  1513.306466  120.738941            4.000000            2.000000   
2   2.482978   263.073498  419.452400            1.001440            2.875290   
3   4.743766  1711.348043  166.943240            3.919750            2.034334   
4   8.000000   974.282187  286.511349            1.000000            2.000000   
5   2.443741  1670.729762  149.342009            3.060752            2.276151   
6   7.252317   533.767692   64.037326            3.251644            4.410376   
7   8.000000   348.310051   97.772862            4.000000            6.000000   
8   7.417538  1840.508452  233.716065            1.325357            2.024808   
9   2.105893   714.354237  154.235054            1.455196            4.833155   
10  7.134612   717.633714  118.463566            1.621987            4.527051   
11  2.000000   720.554957  3

In [9]:
# Convert the raw optimizer values into the head count and widths used when
# the models were timed: truncate nhead, then round nhid and ninp down to a
# multiple of it.
# NOTE(review): the layer-count columns are left as raw floats, whereas the
# objective truncates them (and adds the encoder count to the decoder count)
# before building a model -- confirm whether they should be converted too.
# NOTE(review): bayes_opt_arches.csv is written by an earlier cell, so it does
# not contain these normalised values.
df["nhead"] = df["nhead"].astype("int")
for _width_col in ("nhid", "ninp"):
    _units_per_head = (df[_width_col] / df["nhead"]).astype("int")
    df[_width_col] = _units_per_head * df["nhead"]
df

Unnamed: 0,nhead,nhid,ninp,num_decoder_layers,num_encoder_layers,Target
0,8,1920,200,1.0,2.0,0.998093
1,2,1512,120,4.0,2.0,0.998135
2,2,262,418,1.00144,2.87529,0.998153
3,4,1708,164,3.91975,2.034334,0.998212
4,8,968,280,1.0,2.0,0.998354
5,2,1670,148,3.060752,2.276151,0.99839
6,7,532,63,3.251644,4.410376,0.998426
7,8,344,96,4.0,6.0,0.998477
8,7,1834,231,1.325357,2.024808,0.998533
9,2,714,154,1.455196,4.833155,0.998546
