# Hyperparameter Search
This notebook extends ``10x_dataset_training.ipynb`` with a hyperparameter search. The search is run as a Weights and Biases (wandb) sweep configured with ``method: 'grid'``, so every combination of the values listed in ``sweep_config`` is trained (set ``method`` to ``'random'`` or ``'bayes'`` to sample the space instead). The hyperparameters are:
- ``lr``: The learning rate of the optimizer.
- ``specific_layer_depth``: The number of architecture-specific layers (e.g., self-attention layers, BiLSTM/BiGRU layers, or convolutional layers).
- ``aa_embedding_dim``: The dimension of the amino acid embedding.
- ``depth_final_dense``: The number of dense (linear) layers in the final block of the network.
- ``model_name``: The architecture to use: ``self-attention``, ``bilstm``, ``bigru``, or ``cnn``. See the ``README.md`` for details about the implementations of these architectures.

In [2]:
import pandas as pd
import tcellmatch.api as tm
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from pytorch_model_summary import summary
from torchmetrics import Accuracy
import torch
import os
import numpy as np
import wandb

## Load Data

In [3]:
def load_model(indir='../tutorial_data/', filename='ffn_data_continuous_15k.npz'):
    """Create a ``tm.models.EstimatorFfn`` and attach the saved dataset and split indices.

    Reads the train/test arrays from ``<indir>/<filename>`` and loads the saved
    index file ``<indir>/SAVED_IDX``. It then drops training samples whose input
    is entirely zero ("zero-hot"). Finally, it inserts a singleton axis at
    position 1 of any 3-D ``x_train`` / ``x_test`` array, so both splits share
    the same rank.

    Args:
        indir: Directory holding the ``.npz`` archive and ``SAVED_IDX``.
        filename: Name of the ``.npz`` archive inside ``indir``.

    Returns:
        The estimator with data attached; no model has been built on it yet.
    """
    import warnings

    print('loading model...')
    ffn = tm.models.EstimatorFfn()
    # Use the archive as a context manager so its file handle is closed.
    # Indexing an NpzFile reads each array fully into memory before the close.
    with np.load(os.path.join(indir, filename)) as data:
        ffn.x_train = data["x_train"]
        ffn.covariates_train = data["covariates_train"]
        ffn.y_train = data["y_train"]
        ffn.x_test = data["x_test"]
        ffn.covariates_test = data["covariates_test"]
        ffn.y_test = data["y_test"]
        ffn.clone_train = data["clone_train"]
    ffn.load_idx(os.path.join(indir, 'SAVED_IDX'))

    # Keep a training sample if any position has a positive sum over the last
    # axis (presumably the one-hot amino-acid axis -- TODO confirm).
    # Reducing over every non-sample axis yields one flag per sample,
    # whatever the rank of x_train.
    position_sums = np.sum(ffn.x_train, axis=-1)
    keep = np.any(position_sums > 0, axis=tuple(range(1, position_sums.ndim)))
    n_dropped = int(keep.size - np.count_nonzero(keep))
    if n_dropped:
        # Filter every per-sample training array with the same mask, so inputs,
        # labels, covariates and clone ids stay aligned. This assumes
        # clone_train has one entry per training sample -- TODO confirm.
        for attr in ("x_train", "y_train", "covariates_train", "clone_train"):
            setattr(ffn, attr, getattr(ffn, attr)[keep])
        warnings.warn(
            f"Dropped {n_dropped} zero-hot training samples; indices loaded from "
            "SAVED_IDX may refer to the unfiltered data and no longer line up."
        )

    # Add the singleton axis at position 1 to both splits, so the data the
    # model is evaluated on has the same rank as the data it is built and
    # trained on.
    if ffn.x_train.ndim == 3:
        ffn.x_train = ffn.x_train[:, np.newaxis]
    if ffn.x_test.ndim == 3:
        ffn.x_test = ffn.x_test[:, np.newaxis]
    return ffn

In [None]:
# Smoke test: train a small self-attention model with fixed hyperparameters.
# This checks that data loading, model building and training run end to end
# before the sweep is launched.
ffn = load_model()
print(ffn.x_train.shape, ffn.y_train.shape, ffn.covariates_train.shape)

N_ATTENTION_LAYERS = 1  # number of stacked self-attention layers (one list entry per layer)
attention_size = [128] * N_ATTENTION_LAYERS
attention_heads = [16] * N_ATTENTION_LAYERS
print(attention_size, attention_heads)
ffn.build_self_attention(
    residual_connection=True,
    aa_embedding_dim=0,
    attention_size=attention_size,
    use_covariates=False,
    attention_heads=attention_heads,
    depth_final_dense=1,
    optimizer='adam',
    lr=0.001,
    loss='pois',
    label_smoothing=0
)
print('built')

EPOCHS = 2
batch_size = 16
print('starting training...')
train_curve, val_curve, antigen_loss, antigen_loss_val = ffn.train(
    epochs=EPOCHS,
    batch_size=batch_size,
    log_dir='training_runs',
    allow_early_stopping=False,
    print_loss=True,
    lr_schedule_factor=0.99999,
    use_wandb=False
)

loading model...


(15000, 1, 40, 26) (15000, 50) (15000, 2)
[128] [16]
built
ak to train
started training...
pre partition
Number of observations in evaluation data: 1526
Number of observations in training data: 13474
post partition partition
loaded up data
At beginning of epoch...
2.580904483795166
2.3816614151000977
3.5810325145721436
3.226235866546631
2.259986162185669
1.0143965482711792
1.2908402681350708
1.4610847234725952
0.8280069828033447
1.0991848707199097
1.5233083963394165
3.7885024547576904
1.2976499795913696
0.7468215823173523
1.4784337282180786
1.5324702262878418
1.0549285411834717
2.525811195373535
0.9838641285896301
2.117570400238037
1.357994556427002
4.205320358276367
1.6021696329116821
1.1233086585998535
0.9501087665557861
1.0942893028259277
1.0694996118545532
2.4533028602600098
0.9575803279876709
1.3155338764190674
1.5218037366867065
0.6309264302253723
0.8096922039985657
3.459019660949707
2.248868703842163
2.23818302154541
1.6943048238754272
2.4967267513275146
1.1491695642471313
0.942

0.7125205993652344
1.2693194150924683
0.737281084060669
1.544362187385559
0.732653796672821
0.6158252954483032
1.6790955066680908
4.130128383636475
0.498200923204422
2.3803458213806152
0.593722939491272
1.8451266288757324
3.191197395324707
0.5449780821800232
0.6674139499664307
0.7572245001792908
0.8805040121078491
0.7565365433692932
0.8900900483131409
0.4895645081996918
0.4885500371456146


## Hyperparameter Sweep with Weights and Biases

In [None]:
# Sweep search space. With method='grid', every combination of the values
# below is trained (4 * 3 * 5 * 4 * 5 = 1200 runs).
sweep_config = {
    'method': 'grid',  # can be random, grid, bayes
    'parameters': {
        param: {'values': choices}
        for param, choices in (
            ('lr', [0.001, 0.005, 0.01, 0.1]),  # learning rate
            ('aa_embedding_dim', [0, 10, 26]),
            ('depth_final_dense', [1, 2, 3, 5, 9]),
            ('model_name', ['self-attention', 'bilstm', 'bigru', 'cnn']),
            # i.e., bilstm depth, SA depth, conv depth
            ('specific_layer_depth', [1, 2, 3, 4, 5]),
        )
    },
}

In [None]:
def _build_model(ffn, config):
    """Build the architecture named by ``config.model_name`` on ``ffn``.

    ``config.specific_layer_depth`` sets how many architecture-specific layers
    are stacked: attention blocks, recurrent layers or conv layers.
    ``config.lr``, ``config.aa_embedding_dim`` and ``config.depth_final_dense``
    are forwarded to every builder, so each swept hyperparameter actually
    changes the model.

    Raises:
        ValueError: If ``config.model_name`` (case-insensitive) is not one of
            ``self-attention``, ``bilstm``, ``bigru`` or ``cnn``.
    """
    model_name = config.model_name.lower()
    depth = config.specific_layer_depth
    lr = config.lr
    aa_embedding_dim = config.aa_embedding_dim
    depth_final_dense = config.depth_final_dense

    if model_name == 'self-attention':
        attention_size = [128] * depth
        attention_heads = [16] * depth
        print(attention_size, attention_heads)
        ffn.build_self_attention(
            residual_connection=True,
            aa_embedding_dim=aa_embedding_dim,
            attention_size=attention_size,
            use_covariates=False,
            attention_heads=attention_heads,
            depth_final_dense=depth_final_dense,
            optimizer='adam',
            lr=lr,
            loss='pois',
            label_smoothing=0
        )
    elif model_name == 'bilstm':
        ffn.build_bilstm(
            topology=[32] * depth,
            residual_connection=True,
            aa_embedding_dim=aa_embedding_dim,
            optimizer='adam',
            lr=lr,
            loss='pois',
            label_smoothing=0,
            depth_final_dense=depth_final_dense,
            use_covariates=False,
            one_hot_y=False,
        )
    elif model_name == 'bigru':
        # Pass topology and depth_final_dense, so the swept layer depths
        # actually affect BiGRU runs.
        ffn.build_bigru(
            topology=[10] * depth,
            aa_embedding_dim=aa_embedding_dim,
            residual_connection=True,
            depth_final_dense=depth_final_dense,
            lr=lr,
            loss='pois',
        )
    elif model_name == 'cnn':
        # Forward lr and aa_embedding_dim, so those swept values affect CNN runs.
        ffn.build_conv(
            n_conv_layers=depth,
            depth_final_dense=depth_final_dense,
            pool_sizes=[2] * depth,
            pool_strides=[2] * depth,
            aa_embedding_dim=aa_embedding_dim,
            lr=lr,
            loss='pois',
        )
    else:
        raise ValueError(
            f"Unknown model_name {config.model_name!r}; expected one of "
            "'self-attention', 'bilstm', 'bigru', 'cnn'."
        )
    print('built')


def train(epochs=None, batch_size=None):
    """Run one W&B sweep trial: build the model chosen by ``wandb.config`` and train it.

    This function is meant to be passed as ``function`` to ``wandb.agent``,
    which calls it with no arguments. Each call:
    - starts a wandb run;
    - reloads the data with ``load_model()``;
    - builds the architecture named by the sweep config;
    - trains it and logs the final training and validation loss.

    Args:
        epochs: Number of training epochs. If None, it is read from the sweep
            config key ``epochs``, falling back to 2.
        batch_size: Training batch size. If None, it is read from the sweep
            config key ``batch_size``, falling back to 100000.

    Raises:
        ValueError: If ``config.model_name`` names no known architecture.
    """
    # Initialize wandb with a sample project name
    wandb.init(project="TCR fitting")
    config = wandb.config

    if epochs is None:
        epochs = config.get('epochs', 2)
    if batch_size is None:
        # NOTE(review): 100000 exceeds the ~13.5k training samples, so each
        # epoch is a single full-batch step. Consider sweeping 'batch_size'.
        batch_size = config.get('batch_size', 100000)

    ffn = load_model()
    print(ffn.x_train.shape, ffn.y_train.shape, ffn.covariates_train.shape)
    _build_model(ffn, config)

    print('starting training...')
    train_curve, val_curve, antigen_loss, antigen_loss_val = ffn.train(
        epochs=epochs,
        batch_size=batch_size,
        log_dir='training_runs',
        allow_early_stopping=False,
        print_loss=True,
        lr_schedule_factor=0.99999,
        use_wandb=True
    )
    # Log the final losses so runs can be compared on the sweep dashboard.
    # Guard against empty curves, e.g. when epochs == 0.
    if len(train_curve) and len(val_curve):
        wandb.log({'Train Loss': train_curve[-1], 'Validation Loss': val_curve[-1]})

In [None]:
# Create the sweep in the same project that train() logs to via
# wandb.init(project="TCR fitting"). Without `project`, the sweep is created in
# the default project ("uncategorized"), and runs launched by the agent follow
# the sweep's project instead of "TCR fitting".
sweep_id = wandb.sweep(sweep_config, project="TCR fitting")

Create sweep with ID: v9p9h06w
Sweep URL: https://wandb.ai/jmboesen/uncategorized/sweeps/v9p9h06w


In [None]:
# Start a sweep agent that calls train() once for each hyperparameter
# combination the sweep controller hands out.
wandb.agent(sweep_id=sweep_id, function=train)

[34m[1mwandb[0m: Agent Starting Run: c3vrz4eu with config:
[34m[1mwandb[0m: 	aa_embedding_dim: 0
[34m[1mwandb[0m: 	depth_final_dense: 1
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	model_name: self-attention
[34m[1mwandb[0m: 	specific_layer_depth: 1
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


(15000, 40, 26) (15000, 50) (15000, 2)
[128] [16]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

Run c3vrz4eu errored: IndexError('tuple index out of range')
[34m[1mwandb[0m: [32m[41mERROR[0m Run c3vrz4eu errored: IndexError('tuple index out of range')
[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
