In [1]:
import scanpy as sc
import torch, sys
import pandas as pd


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.3.5 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/lorenzoturiano/Bocconi/thesis/local/thesis-env/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/lorenzoturiano/Bocconi/thesis/local/thesis-env/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/lorenzoturiano/Bocconi/thesis/local/thesis-env/lib/python3.11/site-pack

In [7]:
# Read the combined dataset from disk into `adata`
combined_h5ad = "datasets/combined/9d5eb472-3657-4035-8aea-d3053934e120.h5ad"
adata = sc.read_h5ad(combined_h5ad)

In [3]:
# Number of observations per suspension type (displayed as the cell output)
suspension_counts = adata.obs["suspension_type"].value_counts()
suspension_counts

suspension_type
nucleus    568598
cell       135698
Name: count, dtype: int64

In [6]:
# Number of observations per annotated cell type (displayed as the cell output)
cell_type_counts = adata.obs["cell_type"].value_counts()
cell_type_counts

cell_type
regular ventricular cardiac myocyte     190710
fibroblast                              138055
endothelial cell                        131505
mural cell                              104593
myeloid cell                             51426
regular atrial cardiac myocyte           45911
lymphocyte                               24922
neural cell                               6622
adipocyte                                 6347
mast cell                                 1853
endothelial cell of lymphatic vessel      1295
mesothelial cell                          1057
Name: count, dtype: int64

In [9]:
# Keep the original (fine-grained) labels in a separate column.
# Snapshot them only once: a later cell overwrites `cell_type` with macro
# labels, so re-running this cell unguarded would replace the fine labels with
# macro labels and the next mapping pass would produce NaN for them.
if "cell_type_fine" not in adata.obs.columns:
    adata.obs["cell_type_fine"] = adata.obs["cell_type"].astype(str)

In [10]:
# Macro categories, each listing the fine-grained labels that collapse into it.
_macro_groups = (
    ("adipocyte", ("adipocyte",)),
    ("endothelial", ("endothelial cell", "endothelial cell of lymphatic vessel")),
    ("fibroblast", ("fibroblast",)),
    ("lymphocyte", ("lymphocyte",)),
    ("mast", ("mast cell",)),
    ("mesothelial", ("mesothelial cell",)),
    ("mural", ("mural cell",)),
    ("myeloid", ("myeloid cell",)),
    ("neural", ("neural cell",)),
    ("myocyte", ("regular atrial cardiac myocyte", "regular ventricular cardiac myocyte")),
)

# Invert the grouping into the fine -> macro lookup used for relabelling.
macro_map = {
    fine_label: macro_label
    for macro_label, fine_labels in _macro_groups
    for fine_label in fine_labels
}

In [11]:
# Relabel: look every fine-grained label up in `macro_map` and store the
# result as the new `cell_type` column.
fine_labels = adata.obs["cell_type_fine"]
adata.obs["cell_type"] = fine_labels.map(macro_map)

# Sanity check: labels missing from `macro_map` come back as NaN, so count them.
is_unmapped = adata.obs["cell_type"].isna()
unmapped = is_unmapped.sum()
print("Unmapped cells:", unmapped)

Unmapped cells: 0


In [12]:
# Observation counts per macro cell type after relabelling
macro_counts = adata.obs["cell_type"].value_counts()
print(macro_counts)

cell_type
myocyte        236621
fibroblast     138055
endothelial    132800
mural          104593
myeloid         51426
lymphocyte      24922
neural           6622
adipocyte        6347
mast             1853
mesothelial      1057
Name: count, dtype: int64


In [16]:
# Contingency table: macro cell type (rows) x suspension type (columns).
# Reindexing pins the column order and fills a missing column with zeros.
suspension_cols = ["cell", "nucleus"]
type_by_suspension = pd.crosstab(adata.obs["cell_type"], adata.obs["suspension_type"])
ct = type_by_suspension.reindex(columns=suspension_cols, fill_value=0)

# Per-type row total and the smaller of the two suspension counts
ct["total"] = ct.sum(axis=1)
ct["min(cell,nucleus)"] = ct[suspension_cols].min(axis=1)

# Append a column-wise "total" row. Every column is summed, so in the
# "min(cell,nucleus)" column this row holds the SUM of the per-type minima,
# not the smaller of the two grand totals.
ct.loc["total"] = ct.sum(axis=0)

print(ct)

suspension_type    cell  nucleus   total  min(cell,nucleus)
cell_type                                                  
adipocyte             1     6346    6347                  1
endothelial       80629    52171  132800              52171
fibroblast         3163   134892  138055               3163
lymphocyte        11945    12977   24922              11945
mast                  1     1852    1853                  1
mesothelial         195      862    1057                195
mural             22641    81952  104593              22641
myeloid           14702    36724   51426              14702
myocyte            1924   234697  236621               1924
neural              497     6125    6622                497
total            135698   568598  704296             107240


In [None]:
# Partition the observations by suspension type; `.copy()` materialises each
# subset as an independent object instead of a view of `adata`.
suspension = adata.obs["suspension_type"]
RNA_data = adata[suspension == "cell"].copy()
GEX_data = adata[suspension == "nucleus"].copy()

# Persist both subsets as gzip-compressed .h5ad files
for subset, out_file in (
    (RNA_data, "datasets/combined/heart_suspension_cell.h5ad"),
    (GEX_data, "datasets/combined/heart_suspension_nucleus.h5ad"),
):
    subset.write_h5ad(out_file, compression="gzip")

: 

## Training

In [None]:
# Make the project-local packages under modules/ importable
sys.path.append('modules/')

from VAE_UNET import RNA_VAE_UNET
from data_loader import CustomDataloader

# Use the GPU when torch can see one, otherwise fall back to the CPU
use_gpu = torch.cuda.is_available()
device = torch.device("cuda") if use_gpu else torch.device("cpu")

# Build the loader over the nucleus (GEX_data) and cell (RNA_data) subsets.
# NOTE(review): how CustomDataloader pairs the two inputs is defined in
# modules/data_loader and is not visible here.
custom = CustomDataloader(GEX_data, RNA_data, batch_size=64, seed=42)
dataloader = custom.get_dataloader()

In [None]:
# Instantiate the model with one input feature per GEX_data variable and one
# output feature per RNA_data variable.
# NOTE(review): the meaning of `normalized`, `beta` and `threshold` is defined
# in modules/VAE_UNET — confirm there before tuning.
vae = RNA_VAE_UNET(
    input_dim=GEX_data.n_vars,
    output_dim=RNA_data.n_vars,
    latent_dim=32,
    normalized=False,
    device=device,
)

# Fit on the paired loader built in the previous cell
vae.train(dataloader, n_epochs=100, beta=0.1, threshold=False)

In [None]:
# Produce the model's output AnnData from the two subsets.
# NOTE(review): the exact semantics of `generate_anndata` live in
# modules/VAE_UNET and are not visible in this notebook.
RNA_fake = vae.generate_anndata(GEX_data, RNA_data, threshold=False)

# Quick range comparison of the three expression matrices
print("Max value gene_exp:", GEX_data.X.max())
print("Max value rna_exp :", RNA_data.X.max())
print("Max value rna_fake:", RNA_fake.X.max())

# `path` is never assigned anywhere in this notebook, so on a fresh kernel the
# original `path + ...` raised NameError. Keep an existing `path` if the session
# already defines one; otherwise fall back to the directory the other outputs
# of this notebook are written to.
path = globals().get("path", "datasets/combined/")
RNA_fake.write(path + "Fake_RNA_VAE_UNET.h5ad", compression="gzip")
print("Object saved")

In [3]:
# Read the second dataset (stored under datasets/single-nuclei/) into `adata_cortex`
cortex_h5ad = "datasets/single-nuclei/fe86d86c-16cc-4047-a741-d9e186b35175.h5ad"
adata_cortex = sc.read_h5ad(cortex_h5ad)

In [8]:
# Variable (gene) identifiers of each dataset
genes = adata.var_names
genes_cortex = adata_cortex.var_names

# Identifiers present in both datasets; shown as the cell output
isec = genes.intersection(genes_cortex)
isec

Index(['ENSG00000229905', 'ENSG00000237491', 'ENSG00000177757',
       'ENSG00000225880', 'ENSG00000230368', 'ENSG00000272438',
       'ENSG00000223764', 'ENSG00000187634', 'ENSG00000188976',
       'ENSG00000187961',
       ...
       'ENSG00000278704', 'ENSG00000277400', 'ENSG00000274847',
       'ENSG00000276256', 'ENSG00000273748', 'ENSG00000278817',
       'ENSG00000277196', 'ENSG00000278384', 'ENSG00000276345',
       'ENSG00000271254'],
      dtype='object', length=26467)