### Data Loader

Before we can estimate any model, we need to load the data that we created in `concept.Rmd`. We'll reshape it so that each item in the dataset corresponds to one subject, which lets us draw batches of subjects.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset
from concept import ConceptData

# Fix the random seed for reproducibility
torch.manual_seed(20240210)

# Read the data produced in ../generate (saved under ../data)
samples_df = pd.read_csv("../data/blooms.csv")
concepts = pd.read_csv("../data/concepts.csv")

# Items 0-374 form the training set and items 375-499 the validation set
dataset = ConceptData(samples_df, concepts)
train, validation = Subset(dataset, torch.arange(375)), Subset(dataset, torch.arange(375, 500))
loaders = {
  "train": DataLoader(train, batch_size=16),
  "validate": DataLoader(validation, batch_size=16)
}
```
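To see how the split and batching behave without `ConceptData`, here is a small sketch that swaps in a `TensorDataset` of 500 random items (a stand-in, not the real data; the feature and label sizes are arbitrary). It uses the same index ranges as the cell above: the 375 training items give 24 batches of size 16 (the last one holds 7 items), and the 125 validation items give 8 batches (the last one holds 13). Passing `shuffle=True` is how a `DataLoader` draws a different random set of subjects for each batch; without it, as in the notebook's loaders, batches follow the dataset order.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in for ConceptData: 500 items, each with 10 features and a scalar label
# (sizes are arbitrary and chosen only for illustration)
torch.manual_seed(0)
toy = TensorDataset(torch.randn(500, 10), torch.randn(500))

# Same index ranges as the notebook: items 0-374 train, 375-499 validate
toy_train = Subset(toy, torch.arange(375))
toy_valid = Subset(toy, torch.arange(375, 500))

# shuffle=True reshuffles the training indices every epoch;
# the validation loader keeps the dataset order
train_loader = DataLoader(toy_train, batch_size=16, shuffle=True)
valid_loader = DataLoader(toy_valid, batch_size=16)

# Count the batches and look at the size of the final, partial batch
train_sizes = [x.shape[0] for x, _ in train_loader]
valid_sizes = [x.shape[0] for x, _ in valid_loader]
print(len(train_sizes), train_sizes[-1])  # 24 7
print(len(valid_sizes), valid_sizes[-1])  # 8 13
```

Shuffling changes which items land in each batch, not the batch sizes, so the counts are the same with or without it.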

We can now train the concept bottleneck model on these data loaders using a Lightning `Trainer`.

```python
import lightning as L
from concept import ConceptBottleneck, LitConcept

# Wrap the bottleneck model in its LightningModule and train for 100 epochs,
# writing logs and checkpoints under concept_logs/
model = ConceptBottleneck()
lit_model = LitConcept(model)
trainer = L.Trainer(max_epochs=100, default_root_dir="concept_logs")
trainer.fit(lit_model, loaders["train"], loaders["validate"])
```
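For readers less familiar with Lightning, `trainer.fit` roughly amounts to the loop below: in each epoch, take an optimizer step on every training batch, then evaluate the loss on the validation batches without gradients. This is a plain-PyTorch sketch, not `LitConcept` itself. The linear model, MSE loss, Adam optimizer, learning rate, and the helper name `fit` are all stand-ins chosen for illustration, since the real model, loss, and optimizer are defined in `concept.py`. Lightning additionally handles device placement, logging, and checkpointing.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def fit(model, train_loader, valid_loader, max_epochs=5, lr=1e-2):
    """Hypothetical stand-in for trainer.fit: train, then validate, each epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()  # stand-in loss; LitConcept defines its own
    history = []
    for epoch in range(max_epochs):
        # Training pass: one optimizer step per batch
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x).squeeze(-1), y)
            loss.backward()
            optimizer.step()
        # Validation pass: no gradients, average loss over all validation items
        model.eval()
        total, n = 0.0, 0
        with torch.no_grad():
            for x, y in valid_loader:
                total += loss_fn(model(x).squeeze(-1), y).item() * len(y)
                n += len(y)
        history.append(total / n)
    return history

# Toy regression data with a linear signal, split 375 / 125 as in the notebook
torch.manual_seed(0)
X = torch.randn(500, 10)
y = X @ torch.randn(10) + 0.1 * torch.randn(500)
train_loader = DataLoader(TensorDataset(X[:375], y[:375]), batch_size=16, shuffle=True)
valid_loader = DataLoader(TensorDataset(X[375:], y[375:]), batch_size=16)

val_losses = fit(nn.Linear(10, 1), train_loader, valid_loader)
print(val_losses)  # one validation loss per epoch
```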

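After training, we compute the model's predictions for every subject (training set first, then validation set) and save them to `../data/p_hat_concept.csv`.

```python
# Switch to evaluation mode and disable gradient tracking for inference
lit_model.model.eval()
p_hat = []
with torch.no_grad():
  # Neither loader shuffles, so rows follow the dataset order: training, then validation
  for x, _, _ in loaders["train"]:
    p_hat.append(lit_model.model(x)[1])  # keep the second element of the model's output
  for x, _, _ in loaders["validate"]:
    p_hat.append(lit_model.model(x)[1])

# Stack the batches into one tensor and convert to NumPy before writing
pd.DataFrame(torch.cat(p_hat).numpy()).to_csv("../data/p_hat_concept.csv")
```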
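Because neither loader shuffles, concatenating the per-batch predictions in loop order gives one row per subject in dataset order, with the 375 training subjects first and the 125 validation subjects after. The sketch below checks this with a toy `TensorDataset` and a stand-in linear-plus-sigmoid model (both assumptions, not the notebook's `ConceptBottleneck`), and writes the CSV to an in-memory buffer instead of `../data`.

```python
import io

import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, TensorDataset

# Toy inputs: 500 items with 4 features, split exactly as in the notebook
torch.manual_seed(0)
X = torch.randn(500, 4)
dataset = TensorDataset(X)
loaders = {
    "train": DataLoader(Subset(dataset, torch.arange(375)), batch_size=16),
    "validate": DataLoader(Subset(dataset, torch.arange(375, 500)), batch_size=16),
}

# Stand-in model: 4 features -> 3 outputs passed through a sigmoid
model = nn.Sequential(nn.Linear(4, 3), nn.Sigmoid())
model.eval()

# Collect batch predictions in loader order, as the notebook does
p_hat = []
with torch.no_grad():
    for split in ["train", "validate"]:
        for (x,) in loaders[split]:
            p_hat.append(model(x))
preds = torch.cat(p_hat)

# Reference: predictions for all 500 items in a single pass
with torch.no_grad():
    full = model(X)

buffer = io.StringIO()
pd.DataFrame(preds.numpy()).to_csv(buffer)
print(preds.shape)                  # torch.Size([500, 3])
print(torch.allclose(preds, full))  # True
```

If the training loader were created with `shuffle=True`, the first 375 rows would come out in a random order and would no longer line up with the subjects in the saved CSV.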

For future reference, these are the commands we used to create the conda environment and install the packages for this analysis.

```
conda create -n interpretability python=3.12
conda activate interpretability

conda install -y conda-forge::lightning
conda install -y conda-forge::pandas
conda install -y conda-forge::tensorboard
conda install -y conda-forge::transformers
conda install -y pytorch::captum
```
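Once the environment is active, one quick way to confirm that the packages above are importable is to look them up with `importlib`. This is a small sketch of our own, not part of the original setup. The helper name `check_environment` is hypothetical, and the module list assumes the usual import names (`torch` is included because `lightning` and `captum` both depend on it).

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version

# Import names for the packages installed above, plus torch (a dependency)
MODULES = ["lightning", "pandas", "tensorboard", "transformers", "captum", "torch"]

def check_environment(modules=MODULES):
    """Map each module name to its installed version, or None if it is missing."""
    status = {}
    for name in modules:
        if importlib.util.find_spec(name) is None:
            status[name] = None
            continue
        try:
            status[name] = version(name)
        except PackageNotFoundError:
            # Importable, but no distribution metadata under this name
            status[name] = "installed (version unknown)"
    return status

for name, ver in check_environment().items():
    print(f"{name:>12}: {ver or 'MISSING'}")
```

`find_spec` only locates each module without importing it, so the check stays fast even for heavy packages like `transformers`.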