In [None]:
import torch
import scanpy as sc
import numpy as np
from tqdm import tqdm

from ot_cfm_package import (
    load_adata, create_ot_cfm_model, create_ot_cfm_optimizer, 
    create_training_dataloader, ExactOptimalTransportConditionalFlowMatcher, 
    correct_sources
)

In [None]:
# Load the Target2 well-level AnnData object; per its filename it has imputed
# features and drug annotations and was processed with pycytominer.
target2_moa = load_adata(
    "../../data/Tim_target2_wellres_featuresimputed_druginfoadded_pycytominer.h5ad"
)

In [None]:
# Init seeds, device, model, optimizer and flow matcher.
# Seed torch (CPU and all CUDA devices) and numpy *before* building the model so
# weight initialisation and any torch/numpy-based sampling downstream (batch
# shuffling, OT pairing) are reproducible across Restart & Run All.
# NOTE(review): seeding alone does not make CUDA kernels bit-deterministic; see
# the PyTorch reproducibility notes if exact reruns on GPU are required.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ot_cfm_model = create_ot_cfm_model(adata=target2_moa, use_pca=True).to(device)
ot_cfm_optimizer = create_ot_cfm_optimizer(ot_cfm_model)

# sigma is the flow matcher's noise level; presumably the std of the conditional
# probability path — confirm in ot_cfm_package.
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=0.1)

In [None]:
# Training hyperparameters (names are referenced by the training-loop cell).
epochs, batch_size = 50, 64

# Batches come from every source except "source_2", which is also excluded from
# correction later and kept as-is in the combined result (presumably it is the
# reference the other sources are mapped toward — confirm in ot_cfm_package).
dataloader = create_training_dataloader(
    target2_moa,
    batch_size=batch_size,
    exclude_source="source_2",
    use_pca=True,
)

In [None]:
# Training loop: one pass over the dataloader per epoch, reporting the mean batch
# loss of the epoch (the last batch's loss alone is a noisy, misleading summary).
# Put the model in training mode explicitly so re-running this cell after the
# correction cell does not train in eval mode (assumes an nn.Module, as .to(device) suggests).
ot_cfm_model.train()
for epoch in tqdm(range(epochs), desc="training"):
    running_loss = 0.0
    n_batches = 0
    for source_batch, target_batch, source_one_hot in dataloader:
        # Move batches to the model's device (GPU or CPU)
        source_batch = source_batch.to(device)
        target_batch = target_batch.to(device)
        source_one_hot = source_one_hot.to(device)

        # Forward pass; presumably the one-hot encodes the batch's source label
        outputs = ot_cfm_model(source_batch, source_one_hot)
        loss = FM.compute_loss(outputs, target_batch)

        # Backpropagation
        ot_cfm_optimizer.zero_grad()
        loss.backward()
        ot_cfm_optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Without this guard an empty dataloader raises NameError on `loss`, or,
    # on a re-run, silently reports the stale loss of a previous run.
    if n_batches == 0:
        raise RuntimeError("dataloader yielded no batches; check exclude_source and batch_size")

    # tqdm.write keeps the progress bar intact while logging per-epoch results
    tqdm.write(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / n_batches}")

In [None]:
# Correction: correct all sources except source_2 using the trained OT-CFM model.
# source_2 is passed as exclude_source here and to the training dataloader, and
# is appended below unchanged.
# Switch to inference mode (assumes an nn.Module, as .to(device) suggests).
ot_cfm_model.eval()
corrected_sources = correct_sources(target2_moa, ot_cfm_model, exclude_source="source_2", device=device)

# Keep source_2 as-is. Fail loudly on a wrong/missing label instead of silently
# producing a combined dataset without the reference source.
source_2_data = target2_moa[target2_moa.obs["Metadata_Source"] == "source_2"].copy()
assert source_2_data.n_obs > 0, "no observations with Metadata_Source == 'source_2'"

# Combine: corrected sources first (sc.concat default join, i.e. inner on vars),
# then source_2 with an outer join so variables missing on either side are kept.
# A distinct name for the intermediate keeps this cell idempotent and readable.
corrected_other_sources = sc.concat(corrected_sources, axis=0)
all_corrected_adata = sc.concat([corrected_other_sources, source_2_data], join="outer")

In [None]:
# Plot UMAP of the corrected data, coloured by source.
# sc.pl.umap only draws obsm["X_umap"]; after sc.concat that embedding is either
# missing or a stale, pre-correction one carried over from target2_moa, so it
# must be recomputed on the combined, corrected object.
# NOTE(review): sc.pp.neighbors uses obsm["X_pca"] when present, otherwise X.
# Confirm which representation correct_sources writes the corrected values to
# (the model was built with use_pca=True) and pass use_rep=... accordingly.
sc.pp.neighbors(all_corrected_adata)
sc.tl.umap(all_corrected_adata)
sc.pl.umap(all_corrected_adata, color="Metadata_Source", title="OT-CFM corrected data by source")