# Training SetQuence
This notebook can be used as a template to train a SetQuence model in a supervised manner for diverse target tasks, using training data that consists of sets of sequences.

## 1. Configuring the notebook

### Loading modules (basic and SetQuence-specific)

In [None]:
import torch
from pathlib import Path

In [None]:
from setquence.base import Config, Environment
from setquence.data import get_dataset_loader
from setquence.distributed import get_distributer
from setquence.models import get_model
from setquence.utils import get_optimizer, ns_to_dict
from setquence.utils.metrics import classification_metrics
from setquence.utils.slurm import slurm_config_to_dict

### Configuring the path to the JSON configuration file

In [None]:
CONFIG_PATH = "train_example_config.json"

### Loading the config
Configuration files are loaded with the <code>Config</code> class from <code>setquence.base</code>, which takes a <code>Path</code> to a properly formatted JSON file as its argument. The resulting configuration object fills in default values for any missing fields.

In [None]:
config = Config(Path(CONFIG_PATH))
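The notebook reads settings through attribute access (for example `config.model.name` or `config.data.train.batch_size`). The sketch below illustrates, under our own assumptions, how a JSON document can be merged over a dictionary of defaults and turned into nested namespaces that support this kind of access. The `DEFAULTS` dictionary, the helper names and the `"my_model"` value are hypothetical; this is not the implementation of `setquence.base.Config`.

```python
import json
from types import SimpleNamespace

# Hypothetical defaults; the real SetQuence defaults may differ
DEFAULTS = {
    "data": {"train": {"batch_size": 1, "shuffle": True}},
    "training": {"optimizer": {"name": "Adam"}},
}


def deep_merge(defaults, overrides):
    """Recursively fill keys missing from `overrides` with values from `defaults`."""
    merged = dict(defaults)  # shallow copy: nested dicts are replaced, never mutated
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def to_namespace(obj):
    """Convert nested dicts into SimpleNamespace objects for attribute access."""
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in obj.items()})
    return obj


def load_config(json_text):
    """Parse JSON, fill in defaults, and expose the result via attributes."""
    return to_namespace(deep_merge(DEFAULTS, json.loads(json_text)))


cfg = load_config('{"model": {"name": "my_model"}, "data": {"train": {"batch_size": 8}}}')
print(cfg.model.name, cfg.data.train.batch_size, cfg.data.train.shuffle)  # my_model 8 True
```

Here `shuffle` was absent from the JSON, so the default value is used, while the user-supplied `batch_size` takes precedence.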

### Setting up the computational environment
This step parses the SLURM configuration, if available, and builds an <code>Environment</code> object holding settings such as the available GPUs and the process ranks used for distributed training.

In [None]:
slurm_env = slurm_config_to_dict()
env = Environment(slurm_env)
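SLURM describes the layout of a job through environment variables such as `SLURM_PROCID` (global rank of the task), `SLURM_NTASKS` (total number of tasks) and `SLURM_LOCALID` (rank of the task within its node). A minimal sketch of reading them with single-process fallbacks follows; the helper `read_slurm_layout` and its returned keys are hypothetical and do not describe what `slurm_config_to_dict` actually returns.

```python
import os


def read_slurm_layout(environ=None):
    """Hypothetical helper: read rank information set by SLURM, falling back to a single process."""
    environ = os.environ if environ is None else environ
    return {
        "rank": int(environ.get("SLURM_PROCID", 0)),         # global rank of this task
        "world_size": int(environ.get("SLURM_NTASKS", 1)),   # total number of tasks in the job
        "local_rank": int(environ.get("SLURM_LOCALID", 0)),  # rank within the node, often used to pick a GPU
    }


print(read_slurm_layout({}))
# {'rank': 0, 'world_size': 1, 'local_rank': 0}
print(read_slurm_layout({"SLURM_PROCID": "3", "SLURM_NTASKS": "4", "SLURM_LOCALID": "1"}))
# {'rank': 3, 'world_size': 4, 'local_rank': 1}
```

Passing the mapping explicitly keeps the helper testable outside a SLURM allocation; by default it reads `os.environ`.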

## 2. Configuring the model and its distribution

### Create a model
The function <code>get_model</code> returns the model specified in <code>config.model.name</code>, which can be any architecture implemented in the current version of SetQuence. The model configuration (<code>config.model.config</code>) and the environment object must be passed at initialization. Model parameters are randomly initialized, except for the DNABERT encoder, if one is specified (see the instructions at [configs/template_instructions.md](https://github.com/danilexn/setquence/blob/main/configs/template_instructions.md)).

In [None]:
model = get_model(config.model.name)(config=config.model.config, env=env)
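Looking up an architecture by a string from the configuration is commonly done with a name-to-class registry. The sketch below shows that general pattern under our own assumptions; `MODEL_REGISTRY`, `register_model`, `get_model_class` and `ToyModel` are hypothetical names, not SetQuence's internals.

```python
# Hypothetical registry mapping names to model classes
MODEL_REGISTRY = {}


def register_model(name):
    """Class decorator that records a model class under a string name."""
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def get_model_class(name):
    """Return the class registered under `name`, with a helpful error otherwise."""
    try:
        return MODEL_REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown model {name!r}; available: {sorted(MODEL_REGISTRY)}") from None


@register_model("toy")
class ToyModel:
    def __init__(self, config, env):
        self.config = config
        self.env = env


# Same call shape as in the notebook: look up the class, then instantiate it
toy_model = get_model_class("toy")(config={"hidden_size": 16}, env=None)
print(type(toy_model).__name__)  # ToyModel
```

Returning the class rather than an instance lets the caller decide which arguments to pass, which matches the two-step call used in the cell above.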

### Create a distributer
The function <code>get_distributer</code> initializes a <code>Distributer</code>, which handles moving data across workers according to the chosen distribution strategy (e.g., Distributed Data Parallel or Data Parallel).

Then, the distributer is attached to the previously created model.

In [None]:
dist = get_distributer(config.model.distributer)(env).init(config.model.distribution)
model.distribute(dist)
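For reference, the sketch below shows what a Distributed Data Parallel strategy involves in plain PyTorch: initializing a process group and wrapping the model in `DistributedDataParallel`. It runs a single CPU worker (`world_size=1`) with the `gloo` backend and a file-based rendezvous; it is a generic illustration and not how SetQuence's `Distributer` is implemented. `torch.distributed` is imported as `torch_dist` so that it does not overwrite the notebook's `dist` variable.

```python
import os
import tempfile

import torch
import torch.distributed as torch_dist
from torch.nn.parallel import DistributedDataParallel as DDP


def ddp_single_process_demo():
    """Run one forward/backward pass through DDP with a single CPU worker."""
    # File-based rendezvous: the store file must not exist before initialization
    store_path = os.path.join(tempfile.mkdtemp(), "ddp_store")
    torch_dist.init_process_group(
        backend="gloo",  # CPU-capable backend
        init_method=f"file://{store_path}",
        rank=0,
        world_size=1,
    )
    try:
        net = torch.nn.Linear(4, 2)
        ddp_net = DDP(net)  # on CPU, DDP is used without device_ids
        out = ddp_net(torch.randn(8, 4))
        out.sum().backward()  # gradients are all-reduced across workers during backward
        return tuple(out.shape), net.weight.grad is not None
    finally:
        torch_dist.destroy_process_group()


result = ddp_single_process_demo()
print(result)  # ((8, 2), True)
```

With several workers, each process would run the same code with its own `rank`, and DDP would average gradients across them during `backward()`.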

## 3. Configuring and loading the dataset

### Loading the training dataset
The function <code>get_dataset_loader</code> selects the dataloader specified in <code>config.data.dataloader</code>, which then loads the data described under <code>config.data.train</code>. The environment must also be passed at initialization, as some dataloaders require it.

In [None]:
train_dataset = get_dataset_loader(config.data.dataloader)(config.data.train, env)
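Each training example is a set of sequences paired with a label. The toy `Dataset` below is a hypothetical illustration of that shape: every item stores a fixed number of fixed-length token-ID sequences, so PyTorch's default collate function can stack items into a batch. The class name, sizes and random data are assumptions; the real SetQuence dataloaders and their tensor layout may differ.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToySetOfSequencesDataset(Dataset):
    """Hypothetical dataset: each item is a (set_size x seq_len) tensor of token IDs plus a class label."""

    def __init__(self, n_items=10, set_size=5, seq_len=12, vocab_size=8, n_classes=2, seed=0):
        gen = torch.Generator().manual_seed(seed)  # reproducible random toy data
        self.sets = torch.randint(0, vocab_size, (n_items, set_size, seq_len), generator=gen)
        self.labels = torch.randint(0, n_classes, (n_items,), generator=gen)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.sets[idx], self.labels[idx]


toy_loader = DataLoader(ToySetOfSequencesDataset(), batch_size=4)
sets, labels = next(iter(toy_loader))
print(sets.shape, labels.shape)  # torch.Size([4, 5, 12]) torch.Size([4])
```

Each batch therefore has the shape (batch, sequences per set, sequence length), with one label per set.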

### Configuring data-loading
The distributer provides a <code>get_sampler</code> method, which returns a sampler suited to the selected distribution strategy, so that each worker draws its own share of the dataset.

In [None]:
train_sampler = dist.get_sampler(train_dataset, env, shuffle=config.data.train.shuffle)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=config.data.train.batch_size,
    num_workers=0,
    sampler=train_sampler,  # shuffling is handled by the sampler, so no shuffle= argument here
    drop_last=True,  # discard the final incomplete batch
)
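In plain PyTorch, the usual sampler for Distributed Data Parallel is `torch.utils.data.DistributedSampler`, which gives each rank a disjoint, strided subset of indices. The sketch below shows that behaviour with two simulated ranks; it assumes, without confirmation from the source, that `get_sampler` returns something comparable for DDP.

```python
from torch.utils.data import DistributedSampler

# Any object with __len__ can act as the dataset for index sampling
even_dataset = list(range(10))
odd_dataset = list(range(9))

# Without an initialized process group, num_replicas and rank must be passed explicitly
for rank in range(2):
    sampler = DistributedSampler(even_dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))
# 0 [0, 2, 4, 6, 8]
# 1 [1, 3, 5, 7, 9]

# With 9 items, the indices are padded by repeating index 0 so both ranks get 5 samples
print(list(DistributedSampler(odd_dataset, num_replicas=2, rank=1, shuffle=False)))
# [1, 3, 5, 7, 0]
```

Two practical consequences: with `shuffle=True`, `sampler.set_epoch(epoch)` should be called at the start of every epoch, otherwise each epoch repeats the same order; and the padding means a few examples can be seen twice, which slightly affects metrics computed over a distributed test set.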

### Loading the test dataset

In [None]:
test_dataset = get_dataset_loader(config.data.dataloader)(config.data.test, env)
test_sampler = dist.get_sampler(test_dataset, env)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=config.data.test.batch_size,
    num_workers=0,
    sampler=test_sampler
)
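Unlike the training loader, the test loader keeps the default `drop_last=False`, so the last, smaller batch is still evaluated. The arithmetic is easy to check with a toy list of 10 samples and `batch_size=4`, as in this small illustration:

```python
from torch.utils.data import DataLoader

samples = list(range(10))
train_like = DataLoader(samples, batch_size=4, drop_last=True)
test_like = DataLoader(samples, batch_size=4, drop_last=False)

# drop_last=True: floor(10 / 4) = 2 full batches, 2 samples are skipped
print(len(train_like), [len(batch) for batch in train_like])  # 2 [4, 4]
# drop_last=False: ceil(10 / 4) = 3 batches, the last one is partial
print(len(test_like), [len(batch) for batch in test_like])    # 3 [4, 4, 2]
```

Dropping the partial batch during training keeps batch sizes constant, while keeping it at test time ensures no example is left out of the evaluation.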

## 4. Initializing the optimizer and the evaluation callback

In [None]:
optimizer = get_optimizer(model, config.training.optimizer)
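One common way to build an optimizer from a configuration is to look up its class in `torch.optim` by name and pass the remaining fields as keyword arguments. The sketch below assumes a dictionary with a `"name"` key plus hyperparameters; that layout and the helper `build_optimizer` are hypothetical and may not match what `get_optimizer` expects in `config.training.optimizer`.

```python
import torch


def build_optimizer(model, opt_config):
    """Hypothetical: pick a torch.optim class by name and pass the other keys as kwargs."""
    opt_config = dict(opt_config)  # copy so the caller's config is not modified
    name = opt_config.pop("name")
    optimizer_cls = getattr(torch.optim, name)  # e.g. torch.optim.AdamW
    return optimizer_cls(model.parameters(), **opt_config)


net = torch.nn.Linear(4, 2)
opt = build_optimizer(net, {"name": "AdamW", "lr": 1e-4, "weight_decay": 0.01})
print(type(opt).__name__, opt.param_groups[0]["lr"])  # AdamW 0.0001
```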

In [None]:
def callback_fn(model, *args, **kwargs):
    # Run prediction on every rank; with distributed strategies, all workers may need to take part
    prediction_list, label_list = model.predict(test_dataloader)
    # Only the main process computes and reports metrics
    if env.rank != 0:
        return

    # Concatenate the per-batch outputs and labels on the CPU
    output = torch.cat([prediction.to("cpu") for prediction in prediction_list], dim=0)
    labels = torch.cat([label.to("cpu") for label in label_list], dim=0)

    with torch.no_grad():
        pred_labels = torch.argmax(output, dim=1)
        # An explicit dim avoids the implicit-dimension warning of torch.nn.Softmax()
        probs = torch.softmax(output, dim=1)

    test_metrics = classification_metrics(labels, pred_labels, probs)
    print(test_metrics)
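The callback turns raw model outputs (logits) into predicted labels with `argmax` and into class probabilities with a softmax over the class dimension (`dim=1`). The small worked example below uses made-up logits and computes accuracy as a stand-in metric; the exact metrics returned by `classification_metrics` are not shown in this notebook.

```python
import torch

# Made-up logits for 3 samples and 2 classes, with their true labels
logits = torch.tensor([[2.0, 0.5], [0.1, 1.2], [3.0, -1.0]])
labels = torch.tensor([0, 1, 1])

with torch.no_grad():
    pred_labels = torch.argmax(logits, dim=1)  # index of the largest logit per row
    probs = torch.softmax(logits, dim=1)       # each row sums to 1

accuracy = (pred_labels == labels).float().mean().item()
print(pred_labels.tolist(), round(accuracy, 3))  # [0, 1, 0] 0.667
```

Because softmax is monotonic, `argmax` over `probs` would give the same predictions as `argmax` over the logits.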

## 5. Train the model!

In [None]:
model.fit(
    config.training.config,
    train_dataloader=train_dataloader,
    optimizer=optimizer,
    callback_fn=callback_fn,
)
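To make the role of `callback_fn` concrete, here is a generic epoch loop that trains on every batch and then calls a callback once per epoch. It is a hypothetical sketch on a toy, linearly separable problem, not the implementation of SetQuence's `model.fit`, whose loss, schedule and callback frequency are set through `config.training.config`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def fit(model, train_dataloader, optimizer, epochs, callback_fn=None):
    """Hypothetical epoch loop: train on every batch, then call the callback once per epoch."""
    loss_fn = torch.nn.CrossEntropyLoss()
    history = []
    for epoch in range(epochs):
        model.train()
        if hasattr(train_dataloader.sampler, "set_epoch"):
            train_dataloader.sampler.set_epoch(epoch)  # reshuffle per epoch for DistributedSampler
        epoch_loss = 0.0
        for inputs, targets in train_dataloader:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        history.append(epoch_loss / len(train_dataloader))  # mean batch loss for this epoch
        if callback_fn is not None:
            model.eval()
            callback_fn(model, epoch=epoch)
    return history


torch.manual_seed(0)
X = torch.randn(32, 4)
y = (X[:, 0] > 0).long()  # toy labels that are linearly separable
net = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.5)
losses = fit(net, DataLoader(TensorDataset(X, y), batch_size=8, shuffle=True), opt,
             epochs=20, callback_fn=lambda m, epoch: None)
print(losses[0], losses[-1])  # the loss should decrease on this toy problem
```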