In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import scanpy as sc
import torch.nn.functional as F
from torch.utils.data import DataLoader
import joblib
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, ConfusionMatrixDisplay, confusion_matrix
from peft import get_peft_model, LoraConfig, TaskType, PeftModel, PeftConfig, LoraModel
from tqdm import tqdm
from scimilarity.nn_models import Encoder, Decoder
import os
import json

from app.utils import load_preprocessed_data, load_artifacts, compute_embeddings, evaluate_knn
from app.model.wrapper import load_encoder, load_decoder, load_lora_encoder
from app.logger_config import get_logger
logger = get_logger()

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Relative path of the artifacts directory.
ARTIFACTS_DIR = "artifacts"

# Display settings: never truncate DataFrame cell contents or hide columns.
pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", None)

In [None]:

# ----------------- CONFIG -----------------
RAW_DATA_PATH = "data/raw/dataset.h5ad"
SAMPLE_SIZE = 10  # number of cells to sample
SEED = 42         # fixes which cells are sampled, so the exported JSON samples are reproducible
# ------------------------------------------

# 1️⃣ Load raw dataset
adata = sc.read_h5ad(RAW_DATA_PATH)
print(f"Loaded raw data: {adata.shape[0]} cells x {adata.shape[1]} genes")

# 2️⃣ Get gene names from AnnData
gene_order = adata.var_names.tolist()

# 3️⃣ Sample N distinct cells (row positions in `adata`).
# A seeded Generator makes the draw identical on every Restart & Run All.
# The size is clamped to the number of cells because choice(..., replace=False)
# raises ValueError when asked for more items than exist.
rng = np.random.default_rng(SEED)
sample_indices = rng.choice(adata.n_obs, size=min(SAMPLE_SIZE, adata.n_obs), replace=False)


Loaded raw data: 65479 cells x 31460 genes


In [20]:
# Inspect which row positions of `adata` were drawn; these are the cells exported to JSON in the next cell.
sample_indices

array([64405, 16320, 54512,  7774, 45822, 55116, 63988, 10211, 11982,
       22610])

In [41]:
# Export each sampled cell's non-zero expression values to its own JSON file
# (1.json, 2.json, ... in sampling order) under JSON_SAMPLES_DIR.
JSON_SAMPLES_DIR = "data/json_samples"
os.makedirs(JSON_SAMPLES_DIR, exist_ok=True)  # create once rather than on every iteration

# Plain array of gene names so positional lookup is cheap inside the loop.
gene_names = np.asarray(adata.var_names)

for e, idx in enumerate(sample_indices):
    row = adata.X[idx]
    # Densify the row: sparse matrices expose .toarray(); dense arrays / np.matrix do not.
    if hasattr(row, "toarray"):
        expr_vec = row.toarray().ravel()
    else:
        expr_vec = np.asarray(row).ravel()

    # Keep only non-zero genes. flatnonzero scans the vector in C (instead of a
    # Python loop over every gene) and returns indices in ascending order, so the
    # dict keys keep the original gene order. NaN counts as non-zero, as before.
    expr_dict = {gene_names[i]: float(expr_vec[i]) for i in np.flatnonzero(expr_vec)}

    # Save as JSON
    json_path = os.path.join(JSON_SAMPLES_DIR, f"{e+1}.json")
    with open(json_path, "w") as f:
        json.dump({"expression": expr_dict}, f, indent=2)