# Hyperparameter Search
This notebook is an implementation of ``10x_dataset_training.ipynb`` with hyperparameter search. The hyperparameter search is done using Weights and Biases (wandb) and performed with the sweep method, a grid search with random sampling. The hyperparameters are:
- ``learning_rate``: The learning rate of the optimizer.
- ``num_specific_layers``: The number of model-specific layers (i.e., self-attention layers, convolutional layers, etc.).
- ``aa_embedding_dim``: The dimension of the amino acid embedding.
- ``depth_final_dense``: The number of linear layers in the network.
- ``model_name``: The model to use. Either ``bilstm``, ``self_attention``, ``cnn``, or ``bigru``. See the ``README.md`` for more details about the implementations of these architectures. 

In [1]:
import pandas as pd
import tcellmatch.api as tm
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from pytorch_model_summary import summary
from torchmetrics import Accuracy
import torch
import os
import numpy as np
import wandb

# Build Model

In [2]:
USE_BIND_COUNTS = True
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [12]:
def load_model():
    ffn = tm.models.EstimatorFfn()
    indir = '../tutorial_data/'
    data = np.load(f"{indir}ffn_data_continuous_15k.npz")
    ffn.x_train = data["x_train"]
    ffn.covariates_train = data["covariates_train"]
    ffn.y_train = data["y_train"]
    ffn.x_test = data["x_test"]
    ffn.covariates_test = data["covariates_test"]
    ffn.y_test = data["y_test"]
    ffn.clone_train = data["clone_train"]
    ffn.load_idx(f'{indir}SAVED_IDX')
    
    sums_across_last_dim = np.sum(ffn.x_train, axis=-1)

    # Find rows which are not "zero-hot"
    non_zero_hot_rows = np.any(sums_across_last_dim > 0, axis=-1)
    non_zero_hot_rows = np.squeeze(non_zero_hot_rows)
    
    ffn.x_train = ffn.x_train[non_zero_hot_rows]
    ffn.y_train = ffn.y_train[non_zero_hot_rows]
    return ffn

In [13]:
# Initialize wandb with a sample project name

config = wandb.config


lr = 0.001
aa_embedding_dim = 1
depth_final_dense = 1
ffn = load_model()


attention_size = [128] * 1
attention_heads = [16] * 1

ffn.build_self_attention(
    residual_connection=True,
    aa_embedding_dim=aa_embedding_dim,
    attention_size=attention_size,
    use_covariates=False,
    attention_heads=attention_heads,
    depth_final_dense=depth_final_dense,
    optimizer='adam',
    lr=lr,
    loss='pois' if USE_BIND_COUNTS else 'wbce',
    label_smoothing=0
)


# Add WandB Search

In [19]:
sweep_config = {
   'method': 'grid',  # can be random, grid, bayes
   'parameters': {
       'lr': {  # learning rate
           'values': [0.001, 0.005, 0.01, 0.1]
       },
       'aa_embedding_dim': {
           'values': [0, 10, 26]
       },
       'depth_final_dense': {
           'values': [1, 2, 3, 5, 9]
       },
       'model_name': {
           'values': ['self-attention', 'bilstm', 'bigru', 'cnn']
       },
       # i.e., bilstm depth, SA depth, conv depth
       'specific_layer_depth': {
            'values': [1, 2, 3, 4, 5]
       }
   }
}

In [20]:
def train():
    # Initialize wandb with a sample project name
    wandb.init(project="TCR fitting")

    config = wandb.config
   
    # replace your hard-coded hyperparameters with config values
    model_name = config.model_name
    lr = config.lr
    aa_embedding_dim = config.aa_embedding_dim
    depth_final_dense = config.depth_final_dense
    ffn = load_model()

    if model_name.lower() == 'self-attention':
        attention_size = [128] * config.specific_layer_depth
        attention_heads = [16] * config.specific_layer_depth

        ffn.build_self_attention(
            residual_connection=True,
            aa_embedding_dim=aa_embedding_dim,
            attention_size=attention_size,
            use_covariates=False,
            attention_heads=attention_heads,
            depth_final_dense=depth_final_dense,
            optimizer='adam',
            lr=lr,
            loss='pois' if USE_BIND_COUNTS else 'wbce',
            label_smoothing=0
        )
    elif model_name.lower() == 'bilstm':
        topology = [32] * config.specific_layer_depth
        ffn.build_bilstm(
            topology=topology,
            residual_connection=True,
            aa_embedding_dim=aa_embedding_dim,
            optimizer='adam',
            lr=lr,
            loss='pois' if USE_BIND_COUNTS else 'wcbe',
            label_smoothing=0,
            depth_final_dense=depth_final_dense,
            use_covariates=False,
            one_hot_y=not USE_BIND_COUNTS
        )
    elif model_name.lower() == 'bigru':
        topology= [10] * config.specific_layer_depth
        ffn.build_bigru(
            aa_embedding_dim=aa_embedding_dim,
            residual_connection=True,
            lr=lr,
            loss='pois' if USE_BIND_COUNTS else 'wbce',
        )
    elif model_name.lower() == 'cnn':
        n_conv_layers = config.specific_layer_depth
        # filter_widths = [3, 5, 3] 
        # filters = [16, 32, 64]
        pool_sizes = [2] * n_conv_layers
        pool_strides = [2] * n_conv_layers
        ffn.build_conv(
            n_conv_layers=n_conv_layers,
            depth_final_dense=depth_final_dense,
            # filter_widths=filter_widths,
            # filters=filters,
            pool_sizes=pool_sizes,
            pool_strides=pool_strides,
            loss='pois' if USE_BIND_COUNTS else 'wbce',
        )

    # Training model
    EPOCHS = 5
    batch_size = 16
    ffn.model = ffn.model.to(device=device)
    train_curve, val_curve, antigen_loss, antigen_loss_val = ffn.train(
        epochs=EPOCHS,
        batch_size=batch_size,
        log_dir='training_runs',
        allow_early_stopping=False,
        print_loss=False,
        lr_schedule_factor=0.99999,
        use_wandb=True
    )
   
    # Log metrics with wandb
    wandb.log({'Train Loss': train_curve[-1], 'Validation Loss': val_curve[-1]})

In [21]:
sweep_id = wandb.sweep(sweep_config)

Create sweep with ID: tzxirs4o
Sweep URL: https://wandb.ai/jmboesen/uncategorized/sweeps/tzxirs4o


In [22]:
wandb.agent(sweep_id, function=train)

[34m[1mwandb[0m: Agent Starting Run: 4ukda19b with config:
[34m[1mwandb[0m: 	aa_embedding_dim: 0
[34m[1mwandb[0m: 	depth_final_dense: 1
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	model_name: self-attention
[34m[1mwandb[0m: 	specific_layer_depth: 1


<IPython.core.display.HTML object>
VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max=1.0)))
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>




started training...
Number of observations in evaluation data: 1459
Number of observations in training data: 13541


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


Error in callback <function _WandbInit._pause_backend at 0x7f4eeb252660> (for post_run_cell):


BrokenPipeError: [Errno 32] Broken pipe