In [4]:
from helical.models.state import stateFineTuningModel
import scanpy as sc
from helical.models.state import stateConfig

# Load the desired dataset
adata = sc.read_h5ad("competition_support_set/competition_val_template.h5ad")

# Get the desired label class
cell_types = list(adata.obs.cell_type)
label_set = set(cell_types)

# Create the fine-tuning model (no need to specify var_dims location)
config = stateConfig(
    batch_size=8,
    model_dir="competition/first_run",
    model_config="configs/config.yaml",
    freeze_backbone=True
)

model = stateFineTuningModel(
    configurer=config, 
    fine_tuning_head="classification", 
    output_size=len(label_set),
)

# Process the data for training
data = model.process_data(adata)

# Create a dictionary mapping the classes to unique integers for training
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))

for i in range(len(cell_types)):
    cell_types[i] = class_id_dict[cell_types[i]]

# Fine-tune
model.train(train_input_data=data, train_labels=cell_types)

INFO:helical.models.state.state_finetune:Loading existing config.yaml from: configs/config.yaml
INFO:helical.models.state.state_finetune:Loading pre-trained model from: competition/first_run/final.ckpt
INFO:helical.models.state.state_finetune:Backbone model frozen - only fine-tuning head will be trained
INFO:helical.models.state.state_finetune:Processing data for state model fine-tuning.
INFO:helical.models.state.state_finetune:Loaded perturbation mapping with 19792 perturbations
INFO:helical.models.state.state_finetune:Successfully processed the data for state model fine-tuning.
INFO:helical.models.state.state_finetune:Loaded head weights from competition/first_run/head_weights.pt
INFO:helical.models.state.state_finetune:Optimizer set up for fine-tuning head only
INFO:helical.models.state.state_finetune:Starting Fine-Tuning
Fine-Tuning: epoch 1/1: 100%|██████████| 12366/12366 [00:13<00:00, 895.02it/s, loss=0]
INFO:helical.models.state.state_finetune:Fine-Tuning Complete. Epochs: 1
