### Transformer on Simulated Trajectories

Before we can estimate any model, we need to reshape the data so that we can sample random subjects in each batch. This is handled by the `LinearData` dataset class, defined in the accompanying `transformer.py` script. We have reserved 375 samples for training and 125 for validation. The data (`blooms.csv`) can be downloaded from [this link](https://github.com/krisrs1128/interpretability_review/tree/main/data).

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset
from transformer import LinearData, Transformer

torch.manual_seed(20240210)
subsample_size = 10000
samples_df = pd.read_csv("../data/blooms.csv", nrows=subsample_size)

# Reshape the raw rows into items that can be batched by subject
dataset = LinearData(samples_df)

# First 75% of the dataset items for training, remaining 25% for validation.
# Split on len(dataset) rather than on the number of CSV rows read, in case
# LinearData combines several rows into a single item.
n_train = int(0.75 * len(dataset))
train = Subset(dataset, torch.arange(n_train))
validation = Subset(dataset, torch.arange(n_train, len(dataset)))
loaders = {
  "train": DataLoader(train, batch_size=16),
  "validate": DataLoader(validation, batch_size=16)
}
```
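`LinearData` is defined in `transformer.py` and isn't reproduced here. To give a rough sense of what reshaping for subject-level batching involves, the sketch below groups a long-format table (one row per subject and time point) into one `(time, features)` tensor per subject, then splits subjects at random rather than by position. The class `SubjectTrajectories`, the column names `subject`, `time`, and `label`, and the toy data are all hypothetical and need not match `blooms.csv`.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, Subset


class SubjectTrajectories(Dataset):
    """Hypothetical stand-in for LinearData: one item per subject.

    Assumes a long-format table with columns `subject`, `time`, `label`,
    plus one numeric column per feature, and a single label per subject.
    """

    def __init__(self, df, subject_col="subject", time_col="time", label_col="label"):
        id_cols = {subject_col, time_col, label_col}
        feature_cols = [c for c in df.columns if c not in id_cols]
        self.x, self.y = [], []
        # Sort so each subject's rows are in time order, then stack them per subject
        for _, rows in df.sort_values([subject_col, time_col]).groupby(subject_col):
            values = rows[feature_cols].to_numpy()
            self.x.append(torch.tensor(values, dtype=torch.float32))
            # Take the subject's label from its first row
            self.y.append(int(rows[label_col].iloc[0]))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        # A (time, features) tensor and an integer class label
        return self.x[i], self.y[i]


# Toy long-format data: 8 subjects x 5 time points x 3 features
n_subjects, n_times, n_features = 8, 5, 3
toy_df = pd.DataFrame({
    "subject": [s for s in range(n_subjects) for _ in range(n_times)],
    "time": [t for _ in range(n_subjects) for t in range(n_times)],
    "label": [s % 2 for s in range(n_subjects) for _ in range(n_times)],
})
for j in range(n_features):
    toy_df[f"taxon_{j}"] = [float(i + j) for i in range(len(toy_df))]

toy_dataset = SubjectTrajectories(toy_df)

# Random 75/25 split over subjects rather than over rows
perm = torch.randperm(len(toy_dataset), generator=torch.Generator().manual_seed(0))
n_toy_train = int(0.75 * len(toy_dataset))
toy_train = Subset(toy_dataset, perm[:n_toy_train].tolist())
toy_val = Subset(toy_dataset, perm[n_toy_train:].tolist())

# shuffle=True draws a random set of subjects for each batch
xb, yb = next(iter(DataLoader(toy_train, batch_size=4, shuffle=True)))
print(xb.shape, yb.shape)  # torch.Size([4, 5, 3]) torch.Size([4])
```

Because each item is a whole subject, a `DataLoader` with `shuffle=True` draws a random set of subjects for every batch. The default collate function can only stack trajectories of equal length, so unequal lengths would need padding or a custom `collate_fn`.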

Next, let's define a model whose forward function returns predicted probabilities for the two classes, given the microbiome profile observed so far. The forward pass returns two outputs, `z` and `probs`; the second holds the class probabilities. To make sure this works as expected, we pass in a batch of random data.

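```python
model = Transformer()

# Smoke test on random input of shape (16, 50, 144), presumably
# (batch, time points, features); probs holds the class probabilities
z, probs = model(torch.randn((16, 50, 144)))
print(probs.shape)
```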
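The `Transformer` class in `transformer.py` isn't shown here. The following is a hypothetical, minimal sketch of one way such a model could map a `(batch, time, features)` input to class probabilities at every time step, using PyTorch's built-in `nn.TransformerEncoder` with a causal mask so that the prediction at time *t* only uses the profile up to *t*. The class name, layer sizes, learned positional embedding, and the choice to predict at every time step are assumptions, not details of the original model.

```python
import torch
from torch import nn


class TinyTrajectoryTransformer(nn.Module):
    """Hypothetical sketch; not the Transformer class from transformer.py."""

    def __init__(self, n_features=144, d_model=32, n_heads=4, n_layers=2,
                 n_classes=2, max_len=50):
        super().__init__()
        self.embed = nn.Linear(n_features, d_model)
        # Learned positional embedding: one vector per time step
        self.pos = nn.Parameter(torch.zeros(1, max_len, d_model))
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=64, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, n_layers, enable_nested_tensor=False)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        # x: (batch, time, features)
        n_times = x.shape[1]
        # Boolean causal mask: True marks future positions that may not be attended to
        mask = torch.triu(
            torch.ones(n_times, n_times, dtype=torch.bool, device=x.device), diagonal=1
        )
        z = self.encoder(self.embed(x) + self.pos[:, :n_times], mask=mask)
        probs = torch.softmax(self.head(z), dim=-1)  # (batch, time, n_classes)
        return z, probs


toy_transformer = TinyTrajectoryTransformer()
z_toy, p_toy = toy_transformer(torch.randn(16, 50, 144))
print(z_toy.shape, p_toy.shape)  # torch.Size([16, 50, 32]) torch.Size([16, 50, 2])
```

The causal mask is what makes "given the profile so far" literal: changing inputs after time *t* leaves the probabilities at time *t* unchanged (in evaluation mode, where dropout is off).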

We can now train the model on the training loader with a PyTorch Lightning `Trainer`, using the validation loader to monitor performance. Training and validation accuracies can be inspected by pointing TensorBoard at the `lightning_logs` directory (e.g., `tensorboard --logdir=path/to/lightning_logs`).

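```python
import lightning as L
from transformer import LitTransformer

# Wrap the model so that the Lightning Trainer can run the training loop
lit_model = LitTransformer(model)
trainer = L.Trainer(max_epochs=70)

# %time is IPython magic that reports how long the fit takes
%time trainer.fit(lit_model, loaders["train"], loaders["validate"])
```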
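`LitTransformer` is defined in `transformer.py` and is not shown. As a rough, hypothetical picture of the work the trainer repeats each epoch, here is a plain-PyTorch sketch of one training pass followed by a validation accuracy check. It assumes the model returns `(z, probs)` with `probs` of shape `(batch, n_classes)` and integer class labels; the actual loss and prediction target used by `LitTransformer` may differ (for example, predictions at every time step). `run_epoch`, `MeanPoolClassifier`, and the toy data are illustrative names only.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, train_loader, val_loader, optimizer):
    """One pass over the training data, then validation accuracy (sketch)."""
    model.train()
    for x, y in train_loader:
        _, probs = model(x)
        # Negative log-likelihood of the true class; clamp avoids log(0)
        loss = nn.functional.nll_loss(torch.log(probs.clamp_min(1e-8)), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in val_loader:
            _, probs = model(x)
            correct += (probs.argmax(dim=-1) == y).sum().item()
            total += len(y)
    return correct / total


class MeanPoolClassifier(nn.Module):
    """Hypothetical toy model with the same (z, probs) output convention."""

    def __init__(self, n_features, n_classes=2):
        super().__init__()
        self.linear = nn.Linear(n_features, n_classes)

    def forward(self, x):
        z = x.mean(dim=1)  # average over time steps
        return z, torch.softmax(self.linear(z), dim=-1)


torch.manual_seed(0)
x_toy = torch.randn(64, 10, 3)
y_toy = (x_toy[:, :, 0].mean(dim=1) > 0).long()  # class depends on feature 0
toy_loader = DataLoader(TensorDataset(x_toy, y_toy), batch_size=16, shuffle=True)
toy_val_loader = DataLoader(TensorDataset(x_toy, y_toy), batch_size=16)

toy_clf = MeanPoolClassifier(n_features=3)
opt = torch.optim.Adam(toy_clf.parameters(), lr=0.1)
for epoch in range(20):
    acc = run_epoch(toy_clf, toy_loader, toy_val_loader, opt)
print(acc)
```

For brevity the toy example evaluates on its own training data. `Trainer.fit` additionally handles logging to `lightning_logs`, checkpointing, and device placement, which is why the notebook delegates this loop to Lightning.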

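If we are interested, we can also extract the predicted probabilities for every sample. We put the model in evaluation mode, disable gradient tracking, and iterate over every batch in the training and validation loaders. Since neither loader shuffles, the predictions come out in dataset order, training samples first.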

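```python
lit_model.model.eval()  # evaluation mode (turns off dropout, if any)
p_hat = []
with torch.no_grad():
  # The model's second output holds the predicted probabilities
  for x, _ in loaders["train"]:
    p_hat.append(lit_model.model(x)[1])

  for x, _ in loaders["validate"]:
    p_hat.append(lit_model.model(x)[1])

# Stack the per-batch results, assuming the batch is the first dimension
p_hat = torch.cat(p_hat)
```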
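To compare the probabilities across splits or against the true labels, it can help to record which split and label each prediction belongs to. Below is a hypothetical helper, `collect_probabilities`, that does this with pandas. It assumes the model returns `(z, probs)` with `probs` of shape `(batch, n_classes)` and that the loaders are not shuffled; `ToyModel` and the random toy loaders stand in for `lit_model.model` and `loaders`.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset


def collect_probabilities(model, loaders):
    """Stack predicted probabilities from each loader into one DataFrame (sketch).

    Assumes model(x) returns (z, probs) with probs of shape (batch, n_classes)
    and that each loader yields (x, y) pairs in a fixed order.
    """
    model.eval()
    frames = []
    with torch.no_grad():
        for split, loader in loaders.items():
            prob_batches, label_batches = [], []
            for x, y in loader:
                prob_batches.append(model(x)[1])
                label_batches.append(y)
            probs = torch.cat(prob_batches).numpy()
            df = pd.DataFrame(probs, columns=[f"p_class_{k}" for k in range(probs.shape[1])])
            df["label"] = torch.cat(label_batches).numpy()
            df["split"] = split
            frames.append(df)
    return pd.concat(frames, ignore_index=True)


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in returning (z, probs) like the notebook's model."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 2)

    def forward(self, x):
        z = x.mean(dim=1)
        return z, torch.softmax(self.linear(z), dim=-1)


torch.manual_seed(0)
x_toy = torch.randn(20, 5, 3)
y_toy = torch.randint(0, 2, (20,))
toy_loaders = {
    "train": DataLoader(TensorDataset(x_toy[:15], y_toy[:15]), batch_size=4),
    "validate": DataLoader(TensorDataset(x_toy[15:], y_toy[15:]), batch_size=4),
}
pred_df = collect_probabilities(ToyModel(), toy_loaders)  # 15 train rows, 5 validate rows

# Per-split accuracy (the toy model is untrained, so these numbers are not meaningful)
pred_class = pred_df[["p_class_0", "p_class_1"]].to_numpy().argmax(axis=1)
print((pred_class == pred_df["label"]).groupby(pred_df["split"]).mean())
```

With the notebook's objects, `collect_probabilities(lit_model.model, loaders)` would be the analogous call, provided the shape assumption above holds.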