# Hyperparameter Search
This notebook is an implementation of ``10x_dataset_training.ipynb`` with hyperparameter search. The hyperparameter search is done using Weights and Biases (wandb) and performed with the sweep method, a grid search with random sampling. The hyperparameters are:
- ``learning_rate``: The learning rate of the optimizer.
- ``num_specific_layers``: The number of model-specific layers (i.e., self-attention layers, convolutional layers, etc.).
- ``aa_embedding_dim``: The dimension of the amino acid embedding.
- ``depth_final_dense``: The number of linear layers in the network.
- ``model_name``: The model to use. Either ``bilstm``, ``self_attention``, ``cnn``, or ``bigru``. See the ``README.md`` for more details about the implementations of these architectures. 

In [1]:
import pandas as pd
import tcellmatch.api as tm
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from pytorch_model_summary import summary
from torchmetrics import Accuracy
import torch
import os
import numpy as np
import wandb

# Load Data

In [3]:
ffn = tm.models.EstimatorFfn()
indir = '../tutorial_data/'
data = np.load(f"{indir}ffn_data_continuous_15k.npz")
ffn.x_train = data["x_train"]
ffn.covariates_train = data["covariates_train"]
ffn.y_train = data["y_train"]
ffn.x_test = data["x_test"]
ffn.covariates_test = data["covariates_test"]
ffn.y_test = data["y_test"]
ffn.clone_train = data["clone_train"]

In [4]:
ffn.load_idx(f'{indir}SAVED_IDX')

## Reshape Data

In [5]:
sums_across_last_dim = np.sum(ffn.x_train, axis=-1)

# Find rows which are not "zero-hot"
non_zero_hot_rows = np.any(sums_across_last_dim > 0, axis=-1)

ffn.x_train = ffn.x_train[non_zero_hot_rows]
ffn.x_test = ffn.x_test[non_zero_hot_rows]

IndexError: boolean index did not match indexed array along dimension 0; dimension is 5437 but corresponding boolean dimension is 15000

In [6]:
ffn.x_train = ffn.x_train[:, np.newaxis, :]

# Build Model

In [7]:
model_name = 'CNN'
USE_BIND_COUNTS = True

# Train model
Train this model for 2 epochs     

### Add WandB Search

In [8]:
sweep_config = {
   'method': 'grid',  # can be random, grid, bayes
   'parameters': {
       'lr': {  # learning rate
           'values': [0.001, 0.005, 0.01, 0.1]
       },
       'aa_embedding_dim': {
           'values': [0, 10, 26]
       },
       'depth_final_dense': {
           'values': [1, 2, 3, 5, 9]
       },
       'model_name': {
           'values': ['self-attention', 'bilstm', 'bigru', 'cnn']
       },
       # i.e., bilstm depth, SA depth, conv depth
       'specific_layer_depth': {
            'values': [1, 2, 3, 4, 5]
       }
   }
}

In [9]:
def train():
    # Initialize wandb with a sample project name
    wandb.init(project="TCR fitting")

    config = wandb.config
   
    # replace your hard-coded hyperparameters with config values
    model_name = config.model_name
    lr = config.lr
    aa_embedding_dim = config.aa_embedding_dim
    depth_final_dense = config.depth_final_dense

    if model_name.lower() == 'self-attention':
        attention_size = [128] * config.specific_layer_depth
        attention_heads = [16] * config.specific_layer_depth

        ffn.build_self_attention(
            residual_connection=True,
            aa_embedding_dim=aa_embedding_dim,
            attention_size=attention_size,
            use_covariates=False,
            attention_heads=attention_heads,
            depth_final_dense=depth_final_dense,
            optimizer='adam',
            lr=lr,
            loss='pois' if USE_BIND_COUNTS else 'wbce',
            label_smoothing=0
        )
    elif model_name.lower() == 'bilstm':
        topology = [32] * config.specific_layer_depth
        ffn.build_bilstm(
            topology=topology,
            residual_connection=True,
            aa_embedding_dim=aa_embedding_dim,
            optimizer='adam',
            lr=lr,
            loss='pois' if USE_BIND_COUNTS else 'wcbe',
            label_smoothing=0,
            depth_final_dense=depth_final_dense,
            use_covariates=False,
            one_hot_y=not USE_BIND_COUNTS
        )
    elif model_name.lower() == 'bigru':
        topology= [10] * config.specific_layer_depth
        ffn.build_bigru(
            aa_embedding_dim=aa_embedding_dim,
            residual_connection=True,
            lr=lr,
            loss='pois' if USE_BIND_COUNTS else 'wbce',
        )
    elif model_name.lower() == 'cnn':
        n_conv_layers = config.specific_layer_depth
        # filter_widths = [3, 5, 3] 
        # filters = [16, 32, 64]
        pool_sizes = [2] * n_conv_layers
        pool_strides = [2] * n_conv_layers
        ffn.build_conv(
            n_conv_layers=n_conv_layers,
            depth_final_dense=depth_final_dense,
            # filter_widths=filter_widths,
            # filters=filters,
            pool_sizes=pool_sizes,
            pool_strides=pool_strides,
            loss='pois' if USE_BIND_COUNTS else 'wbce',
        )

    # Training model
    EPOCHS = 2
    batch_size = 100000
    train_curve, val_curve, antigen_loss, antigen_loss_val = ffn.train(
        epochs=EPOCHS,
        batch_size=batch_size,
        log_dir='training_runs',
        allow_early_stopping=False,
        print_loss=True,
        lr_schedule_factor=0.99999,
        use_wandb=True
    )
   
    # Log metrics with wandb
    wandb.log({'Train Loss': train_curve[-1], 'Validation Loss': val_curve[-1]})

In [10]:
sweep_id = wandb.sweep(sweep_config)

Create sweep with ID: jdl4t6mu
Sweep URL: https://wandb.ai/jmboesen/uncategorized/sweeps/jdl4t6mu


In [11]:
wandb.agent(sweep_id, function=train)

[34m[1mwandb[0m: Agent Starting Run: 3bxl3c99 with config:
[34m[1mwandb[0m: 	aa_embedding_dim: 0
[34m[1mwandb[0m: 	depth_final_dense: 1
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	model_name: self-attention
[34m[1mwandb[0m: 	specific_layer_depth: 1
[34m[1mwandb[0m: Currently logged in as: [33mjmboesen[0m. Use [1m`wandb login --relogin`[0m to force relogin


started training...
Number of observations in evaluation data: 1498
Number of observations in training data: 13502
2.554189920425415
1.804654598236084


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


VBox(children=(Label(value='0.012 MB of 0.032 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.369231…