### Data Loader

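Before we can estimate any model, we need to load the data that we created in `linear.Rmd`. We'll reshape it so that each batch contains a group of subjects. Note that the loader below keeps the default `shuffle=False`, so batches follow the dataset's order rather than being randomly sampled.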

In [1]:
import pandas as pd
from torch.utils.data import DataLoader
from transformer import LinearData
from transformer import Transformer

# use the data from ../generate
samples_df = pd.read_csv("../data/blooms.csv")

dataset = LinearData(samples_df)
loader = DataLoader(dataset, batch_size=16)
x, y = next(iter(loader))
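
To make the batching concrete, here is a plain-Python sketch of how a loader with `batch_size=16` splits subject indices into batches, with and without shuffling. It does not use PyTorch; `subject_batches` is a hypothetical helper, and the count of 512 subjects is an assumption, picked only because it gives the 32 batches per epoch that appear in the training log further down.

```python
import random


def subject_batches(n_subjects, batch_size, shuffle=False, seed=0):
    """Yield lists of subject indices, one list per batch (hypothetical helper)."""
    order = list(range(n_subjects))
    if shuffle:
        # Permute the subjects so that each batch is a random subset
        random.Random(seed).shuffle(order)
    for start in range(0, n_subjects, batch_size):
        yield order[start:start + batch_size]


# Assumed subject count: 512 subjects / 16 per batch = 32 batches
sequential = list(subject_batches(512, 16))
shuffled = list(subject_batches(512, 16, shuffle=True))

print(len(sequential), sequential[0][:4])  # 32 [0, 1, 2, 3]
```

In PyTorch, passing `shuffle=True` to `DataLoader` reshuffles the data at every epoch. Keeping a separate unshuffled loader for prediction preserves the dataset order in the saved output.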




Next, let's define a model whose forward function returns predicted probabilities for the two classes, given the microbiome profile observed so far.

In [2]:
import torch

model = Transformer()
z, probs = model(torch.randn((16, 50, 144)))
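
The internals of `Transformer` are not shown in this section. One common way to turn two class scores into probabilities, assumed here purely for illustration, is a softmax. The plain-Python sketch below (the `softmax` helper and the example scores are hypothetical) shows that the outputs are nonnegative and sum to one.

```python
import math


def softmax(scores):
    """Convert raw class scores into probabilities that sum to 1."""
    shift = max(scores)  # subtract the max so that exp() cannot overflow
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Equal scores give equal probabilities for the two classes
print(softmax([0.0, 0.0]))  # [0.5, 0.5]

# A larger score for the first class gives it the larger probability
p = softmax([2.0, -1.0])
print(p[0] > p[1])  # True
```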

We can now train the model on the batches from our data loader, using a Lightning `Trainer`.

In [3]:
import lightning as L
from transformer import LitTransformer

lit_model = LitTransformer(model)
trainer = L.Trainer(max_epochs=40)
trainer.fit(model=lit_model, train_dataloaders=loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type        | Params
--------------------------------------
0 | model | Transformer | 8.5 M 
--------------------------------------
8.5 M     Trainable params
0         Non-trainable params
8.5 M     Total params
34.021    Total estimated model params size (MB)
/Users/ksankaran/miniconda3/envs/interpretability/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/Users/ksankaran/miniconda3/envs/interpretability/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (32) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower va

Epoch 39: 100%|██████████| 32/32 [00:06<00:00,  5.04it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=40` reached.




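Finally, we compute the predicted probabilities for every batch and save them to `../data/p_hat.csv`.

In [4]:
# Switch to evaluation mode and skip gradient tracking during prediction
lit_model.model.eval()
p_hat = []
with torch.no_grad():
  for x, _ in loader:
    # the model returns (z, probs); keep only the predicted probabilities
    p_hat.append(lit_model.model(x)[1])

# Stack the per-batch predictions (in loader order) and convert to NumPy,
# so that pandas stores plain floats rather than tensor objects
pd.DataFrame(torch.cat(p_hat).numpy()).to_csv("../data/p_hat.csv")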

For future reference, these are the packages we installed for this analysis.

```
conda install conda-forge::lightning
conda install conda-forge::pandas
conda install conda-forge::tensorboard
conda install pytorch::captum
```