In [2]:
import pandas as pd
import numpy as np
import anndata
import json
import os
import scanpy as sc
from sklearn.model_selection import train_test_split

## Data download

In [3]:

# Define data paths (everything lives under data_dir)
data_dir = "data_input"
os.makedirs(data_dir, exist_ok=True)

pancreas_adata_path = os.path.join(data_dir, "pancreas_full.h5ad")
train_path = os.path.join(data_dir, "pancreas_train.h5ad")
valid_path = os.path.join(data_dir, "pancreas_valid.h5ad")
test_path  = os.path.join(data_dir, "pancreas_test.h5ad")
genes_json_path = os.path.join(data_dir, "all_genes_list.json")

# Technologies held out entirely as the test set, and the seed for the train/valid split
TEST_TECHS = ["smartseq2", "celseq2"]
SPLIT_SEED = 42

# Download if missing, otherwise load from local file.
# NOTE: the full file is deleted at the end of this cell, so re-running the cell
# downloads it again.
pancreas_adata = sc.read(
    pancreas_adata_path,
    backup_url="https://figshare.com/ndownloader/files/24539828",
)
n_cells_total = pancreas_adata.n_obs

# Split dataset by technology: keep TEST_TECHS as the held-out test set
query_mask = pancreas_adata.obs["tech"].isin(TEST_TECHS).to_numpy()
pancreas_no_test = pancreas_adata[~query_mask].copy()
pancreas_test    = pancreas_adata[ query_mask].copy()

# 80/20 train/valid split on the remaining data, stratified by technology
y = pancreas_no_test.obs["tech"].astype("category")
indices = np.arange(pancreas_no_test.n_obs)

idx_train, idx_valid = train_test_split(
    indices,
    test_size=0.20,
    train_size=0.80,
    random_state=SPLIT_SEED,
    shuffle=True,
    stratify=y,  # stratify by technology
)

pancreas_train = pancreas_no_test[idx_train].copy()
pancreas_valid = pancreas_no_test[idx_valid].copy()

# Sanity checks: splits are disjoint, cover every cell, and train/valid hold no test techs
assert np.intersect1d(idx_train, idx_valid).size == 0, "train/valid overlap"
assert (
    pancreas_train.n_obs + pancreas_valid.n_obs + pancreas_test.n_obs == n_cells_total
), "splits do not cover the full dataset"
assert not (set(pancreas_train.obs["tech"]) | set(pancreas_valid.obs["tech"])) & set(TEST_TECHS), \
    "held-out technologies leaked into train/valid"

# Save splits
pancreas_train.write(train_path)
pancreas_valid.write(valid_path)
pancreas_test.write(test_path)

print(
    f"Train: {pancreas_train.n_obs} cells | "
    f"Valid: {pancreas_valid.n_obs} cells | "
    f"Test: {pancreas_test.n_obs} cells"
)

# Print counts per technology
print("\nCells per technology:")
for name, ad in [("Train", pancreas_train),
                 ("Valid", pancreas_valid),
                 ("Test", pancreas_test)]:
    counts = ad.obs["tech"].value_counts().sort_index()
    print(f"\n{name} split:")
    for tech, n in counts.items():
        print(f"  {tech}: {n}")

# --- Cleanup: delete the original full dataset file ---
del pancreas_adata  # drop reference to ensure no open handle
try:
    if os.path.exists(pancreas_adata_path):
        os.remove(pancreas_adata_path)
        print(f"Deleted '{pancreas_adata_path}'")
except OSError as e:
    # Best-effort cleanup: the splits are already saved, so only warn.
    print(f"[WARN] Could not delete '{pancreas_adata_path}': {e}")


# --- Save full gene list to JSON (same data_dir as the splits) ---
all_genes = pancreas_train.var_names.tolist()

with open(genes_json_path, "w") as f:
    json.dump(all_genes, f, indent=2)

print(f"Saved {len(all_genes)} genes to {genes_json_path}")

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 301M/301M [00:22<00:00, 13.7MB/s] 


Train: 9362 cells | Valid: 2341 cells | Test: 4679 cells

Cells per technology:

Train split:
  celseq: 803
  fluidigmc1: 510
  inDrop1: 1550
  inDrop2: 1379
  inDrop3: 2884
  inDrop4: 1042
  smarter: 1194

Valid split:
  celseq: 201
  fluidigmc1: 128
  inDrop1: 387
  inDrop2: 345
  inDrop3: 721
  inDrop4: 261
  smarter: 298

Test split:
  celseq2: 2285
  smartseq2: 2394
Deleted 'data_input/pancreas_full.h5ad'
Saved 19093 genes to data_input/all_genes_list.json


## Data inspection

In [4]:
# Load the training split written by the data-download cell
TRAIN_H5AD = "./data_input/pancreas_train.h5ad"
adata = anndata.read_h5ad(TRAIN_H5AD)

In [5]:
# Display the AnnData object summary
print("AnnData object summary:")
print(adata)

# Display the first few rows of the observation metadata
print("\nFirst 5 rows of adata.obs:")
print(adata.obs.head())

# Display available layers
print("\nAvailable layers in adata:")
print(adata.layers.keys())

# Display the first 5x5 block of the counts layer (only if it exists)
if "counts" in adata.layers:
    counts_block = adata.layers["counts"][:5, :5]
    # Sparse layers must be densified before the DataFrame constructor accepts them.
    if hasattr(counts_block, "toarray"):
        counts_block = counts_block.toarray()
    print("\nFirst 5x5 of counts layer:")
    print(pd.DataFrame(counts_block,
                       columns=adata.var_names[:5],
                       index=adata.obs_names[:5]))
else:
    print("\nNo 'counts' layer found in adata.")

AnnData object summary:
AnnData object with n_obs × n_vars = 9362 × 19093
    obs: 'tech', 'celltype', 'size_factors'
    layers: 'counts'

First 5 rows of adata.obs:
                                   tech celltype  size_factors
3rd-C86_S85                  fluidigmc1    delta      5.060723
human3_lib4.final_cell_0804     inDrop3    alpha      0.010361
human3_lib4.final_cell_0815     inDrop3     beta      0.011553
Sample_163                      smarter     beta      1.000000
human3_lib1.final_cell_0737     inDrop3    alpha      0.014493

Available layers in adata:
KeysView(Layers with keys: counts)

First 5x5 of counts layer:
                             A1BG  A1CF  A2M  A2ML1  A4GALT
3rd-C86_S85                  14.7  11.0  0.0    3.0     0.0
human3_lib4.final_cell_0804   0.0   0.0  0.0    0.0     0.0
human3_lib4.final_cell_0815   0.0   0.0  0.0    0.0     0.0
Sample_163                    0.0   0.0  0.0    0.0     0.0
human3_lib1.final_cell_0737   0.0   1.0  0.0    0.0     0.0


In [24]:
# 1. How many unique technologies are present, and what are their names?
tech_col = adata.obs["tech"]

print("adata.obs[tech]")
print(tech_col.head())

unique_techs = tech_col.unique()
print("Unique technologies:", unique_techs)

# nunique() ignores missing labels, so it can differ from len(unique_techs)
num_tech = tech_col.nunique()
print("---------------")
print(f"Number of unique technologies: {num_tech}")

adata.obs[tech]
3rd-C86_S85                    fluidigmc1
human3_lib4.final_cell_0804       inDrop3
human3_lib4.final_cell_0815       inDrop3
Sample_163                        smarter
human3_lib1.final_cell_0737       inDrop3
Name: tech, dtype: category
Categories (7, object): ['celseq', 'fluidigmc1', 'inDrop1', 'inDrop2', 'inDrop3', 'inDrop4', 'smarter']
Unique technologies: ['fluidigmc1', 'inDrop3', 'smarter', 'inDrop1', 'celseq', 'inDrop4', 'inDrop2']
Categories (7, object): ['celseq', 'fluidigmc1', 'inDrop1', 'inDrop2', 'inDrop3', 'inDrop4', 'smarter']
---------------
Number of unique technologies: 7


In [None]:
# 2. How many samples (cells) belong to each technology?
tech_labels = adata.obs["tech"]
tech_counts = tech_labels.value_counts()  # sorted by count, descending
print(tech_counts)

tech
inDrop3       2884
inDrop1       1550
inDrop2       1379
smarter       1194
inDrop4       1042
celseq         803
fluidigmc1     510
Name: count, dtype: int64


In [None]:
# 3. What is the total number of genes measured in the dataset?
# n_vars is the number of genes (columns) of the AnnData object; it does not
# depend on any particular layer being present.
nbr_genes = adata.n_vars
print(f"Total number of genes: {nbr_genes}")

Number of total number of genes: 19093


In [None]:
# 4. What is the total number of samples (cells) in the dataset?
# n_obs is the number of cells (rows) of the AnnData object; it does not
# depend on any particular layer being present.
nbr_samples = adata.n_obs
print(f"Total number of samples: {nbr_samples}")

Number of total number of genes: 9362


## Variance analysis

In [None]:
# 5. For each technology, calculate the variance of each gene across all cells.

# Extract counts matrix and create a DataFrame (cells x genes)
# NOTE(review): variances are computed on the raw "counts" layer, so technologies
# with large raw values (e.g. fluidigmc1, smarter in the output below) dominate.
counts = adata.layers["counts"]
# Sparse layers must be densified before the DataFrame constructor accepts them.
if hasattr(counts, "toarray"):
    counts = counts.toarray()

counts_df = pd.DataFrame(
    counts,
    index=adata.obs_names,
    columns=adata.var_names
)

# Group by the technology labels directly (their index is obs_names, identical to
# counts_df's index) instead of adding a "tech" column, which could collide with a
# gene of the same name. observed=True keeps only technologies present in this
# split and avoids the pandas FutureWarning for categorical groupers.
gene_variances_by_tech = counts_df.groupby(adata.obs["tech"], observed=True).var()
print(gene_variances_by_tech)


                             A1BG  A1CF  A2M  A2ML1  A4GALT  A4GNT  AA06  \
3rd-C86_S85                  14.7  11.0  0.0    3.0     0.0    0.0   0.0   
human3_lib4.final_cell_0804   0.0   0.0  0.0    0.0     0.0    0.0   0.0   
human3_lib4.final_cell_0815   0.0   0.0  0.0    0.0     0.0    0.0   0.0   
Sample_163                    0.0   0.0  0.0    0.0     0.0    0.0   0.0   
human3_lib1.final_cell_0737   0.0   1.0  0.0    0.0     0.0    0.0   0.0   

                                   AAAS       AACS  AACSP1  ...  ZWILCH  \
3rd-C86_S85                    2.000000  32.650002     0.0  ...     1.0   
human3_lib4.final_cell_0804    0.000000   0.000000     0.0  ...     0.0   
human3_lib4.final_cell_0815    0.000000   0.000000     0.0  ...     0.0   
Sample_163                   125.726196   0.000000     0.0  ...     0.0   
human3_lib1.final_cell_0737    1.000000   0.000000     0.0  ...     0.0   

                             ZWINT  ZXDA    ZXDB    ZXDC      ZYG11B  ZYX  \
3rd-C86_S85    

  gene_variances_by_tech = counts_df.groupby("tech").var()


                  A1BG          A1CF           A2M      A2ML1     A4GALT  \
tech                                                                       
celseq        0.711420      0.803843      0.077942   0.314592   0.057393   
fluidigmc1  836.826904  48139.230469  80871.554688  27.673862  92.704498   
inDrop1       0.014628      0.337339      0.217571   0.000000   0.070607   
inDrop2       0.004335      0.480227      0.281990   0.000000   0.097883   
inDrop3       0.002079      0.226632      0.042091   0.000000   0.153713   
inDrop4       0.003828      0.545368      0.163191   0.000000   0.150532   
smarter      99.531158    418.537323      0.415753   0.081800   0.111717   

               A4GNT  AA06          AAAS           AACS    AACSP1  ...  \
tech                                                               ...   
celseq      0.005021   0.0      0.157013       0.792771  0.008743  ...   
fluidigmc1  0.001961   0.0  37999.804688  110174.632812  0.000000  ...   
inDrop1     0.00000

In [None]:
# 6. Compute a weighted average of these variances for each gene, using the number of cells per technology as weights.

# Weights for the weighted average: number of cells per technology.
# The weighted average itself is computed in the next cell.
cell_counts = adata.obs["tech"].value_counts()
for line in ("Count cells per technology:", cell_counts):
    print(line)

Count cells per technology:
tech
inDrop3       2884
inDrop1       1550
inDrop2       1379
smarter       1194
inDrop4       1042
celseq         803
fluidigmc1     510
Name: count, dtype: int64


In [38]:
# 7. Save a list containing the top 2000 genes
# Number of highly variable genes to keep
N_TOP_GENES = 2000

# Align weights (cells per technology) with the rows of the variance table
weights = cell_counts.loc[gene_variances_by_tech.index]

# Multiply variances by weights
weighted_variances = gene_variances_by_tech.mul(weights, axis=0)

# Weighted average variance per gene. The numerator's sum() skips NaN variances
# (e.g. a technology with < 2 cells), so the denominator must only include the
# weights of technologies whose variance is defined for that gene.
defined_weights = gene_variances_by_tech.notna().mul(weights, axis=0).sum(axis=0)
weighted_avg_variance = weighted_variances.sum(axis=0) / defined_weights

# Sort genes by weighted variance in descending order and take the top N_TOP_GENES
top_2000_genes = weighted_avg_variance.sort_values(ascending=False).head(N_TOP_GENES)

# Format floats only for this printout instead of changing the global pandas option
with pd.option_context("display.float_format", "{:.6f}".format):
    print("Print top variable genes:")
    print(top_2000_genes.head())

# Get just the gene names as a list
top2000_genes = top_2000_genes.index.tolist()
assert len(top2000_genes) == min(N_TOP_GENES, weighted_avg_variance.size)

print("\nShow first 5 genes")
print(top2000_genes[:5])  # Show first 5 genes
print(f"\nNumber of genes selected: {len(top2000_genes)}")


# Save to a JSON file
os.makedirs("./data_output", exist_ok=True)
with open("./data_output/top2000_genes_centralized.json", "w") as f:
    json.dump(top2000_genes, f, indent=2)

Print top variable genes:
GCG   5211096399.928920
INS    203439060.944606
TTR    160181022.505291
SST    139494309.512972
PPY    116951271.415973
dtype: float64

Show first 5 genes
['GCG', 'INS', 'TTR', 'SST', 'PPY']

Number of genes selected: 2000
