Load data 

In [1]:
import sys
sys.path.append("/Users/joaomata/Desktop/DTU/DeepLearning/ProjectDL/Midaa")
import midaa as maa
import scanpy as sc
import numpy as np
import pyro


In [2]:

# Load the real dataset
adata_beta_hfd = sc.read_h5ad(r"data/beta_cells_hfd.h5ad")

# Subsampling parameters (set N_SUBSAMPLE = None to keep every cell).
N_SUBSAMPLE = 1000
SUBSAMPLE_SEED = 0

# Dataset is too big: subsample for faster testing.
# Draw a seeded random subset rather than the first N rows (which can be biased if
# cells are stored grouped by sample/condition), and .copy() so downstream code
# works on a real AnnData object instead of a view of the full one.
if N_SUBSAMPLE is not None and adata_beta_hfd.n_obs > N_SUBSAMPLE:
    _rng = np.random.default_rng(SUBSAMPLE_SEED)
    _keep = np.sort(_rng.choice(adata_beta_hfd.n_obs, size=N_SUBSAMPLE, replace=False))
    adata_beta_hfd = adata_beta_hfd[_keep, :].copy()
print(f"Subsampled dataset shape: {adata_beta_hfd.X.shape}")

# The recorded run raised AttributeError on the next call. That usually means
# `import midaa` resolved to a different package/checkout than intended, so fail
# early with a message that says where the module was actually loaded from.
if not hasattr(maa, "get_input_params_adata"):
    _where = getattr(maa, "__file__", None) or list(getattr(maa, "__path__", []))
    raise AttributeError(
        f"midaa module loaded from {_where} has no 'get_input_params_adata'; "
        "check that sys.path points at the MIDAA package directory and that this "
        "version of MIDAA provides the helper."
    )

# Convert to MIDAA input
input_matrix, norm_factors, input_distribution = maa.get_input_params_adata(adata_beta_hfd)

# If you want to force RNA-seq likelihood, uncomment the next line:
# input_distribution = ["NB"]

# Check input matrix
print("Number of input matrices:", len(input_matrix))         # should be 1
print("Shape of first input matrix:", input_matrix[0].shape)  # expected: (cells in subsample, genes)

# Check normalization factors
print("Number of norm factors:", len(norm_factors))
print("Shape of first norm factor array:", norm_factors[0].shape)
print("First 5 norm factor values:", norm_factors[0][:5])

# Input distribution
print("Input distribution:", input_distribution)  # usually ["NB"] for RNA-seq


def _to_dense_numpy(m):
    """Return ``m`` as a dense NumPy array.

    Handles SciPy sparse matrices (whose ``.A`` shortcut is deprecated/removed in
    recent SciPy releases, so ``.toarray()`` is used) and torch tensors (detached and
    moved to CPU). Anything else goes through ``np.asarray``.
    """
    if hasattr(m, "toarray"):   # SciPy sparse
        return m.toarray()
    if hasattr(m, "detach"):    # torch.Tensor
        return m.detach().cpu().numpy()
    return np.asarray(m)


# Convert to a dense array so the value checks below work for any input container
X = _to_dense_numpy(input_matrix[0])
print("Min value in X:", X.min())
print("Any negatives?", (X < 0).sum())

# Pyro version
print("Pyro version:", pyro.__version__)  # should be 1.9.1

# Extra checks
print("Min:", X.min(), "Max:", X.max(), "NaNs:", np.isnan(X).sum())
print("Norm factors min:", norm_factors[0].min())
print("Norm factors min:", norm_factors[0].min())


Subsampled dataset shape: (1000, 16483)


AttributeError: module 'midaa' has no attribute 'get_input_params_adata'

Run model

In [None]:
# Fit MIDAA.
# Pyro keeps learned parameters in a global param store that outlives a single cell
# execution. Clear it first so re-running this cell starts from a fresh model instead
# of whatever a previous fit left behind. This is harmless if fit_MIDAA already
# clears the store internally.
pyro.clear_param_store()

aa_result = maa.fit_MIDAA(
    input_matrix,
    norm_factors,
    input_distribution,
    narchetypes=4,   # number of archetypes to infer
    torch_seed=42,   # presumably seeds torch's RNG for a reproducible fit
)

Inspect the MIDAA results, save the core matrices, and plot the latent space

In [None]:
import numpy as np

def extract_mida_matrices(aa_result, input_matrix):
    """Unpack the main MIDAA outputs from a fit result.

    Parameters
    ----------
    aa_result : dict
        Result of ``maa.fit_MIDAA``. It must contain an ``"inferred_quantities"``
        mapping with keys ``"A"``, ``"Z"``, ``"B"`` and ``"archetypes_inferred"``.
    input_matrix : sequence
        The model input list. Its first element is returned as ``X``.

    Returns
    -------
    tuple
        ``(A, B, Z, C, X, labels)``:

        - ``C`` is ``inferred_quantities["archetypes_inferred"]``.
        - ``labels`` is ``np.argmax(A, axis=1)``, i.e. the index of the dominant
          archetype for each row of ``A``.

    Raises
    ------
    KeyError
        If a required key is missing.
    """
    inferred = aa_result["inferred_quantities"]

    # Look-ups kept in the order A, Z, B, C so a missing key fails on the same key.
    memberships = inferred["A"]                      # presumably (n_cells, n_archetypes)
    latent_cells = inferred["Z"]                     # presumably (n_cells, latent_dim)
    archetype_coords = inferred["B"]                 # presumably (n_archetypes, latent_dim)
    gene_weights = inferred["archetypes_inferred"]   # NOTE(review): believed (n_genes, n_archetypes) — confirm

    data = input_matrix[0]

    # Dominant archetype per row of the membership matrix.
    dominant = np.argmax(memberships, axis=1)

    return memberships, archetype_coords, latent_cells, gene_weights, data, dominant

A, B, Z, C, X, labels = extract_mida_matrices(aa_result, input_matrix)

# Report the shape of every extracted array as a quick sanity check.
for _desc, _arr in (
    ("A (memberships)", A),
    ("B (latent archetype coords)", B),
    ("Z (latent cells)", Z),
    ("C (gene weights)", C),
    ("X (input data)", X),
    ("labels", labels),
):
    print(f"{_desc}:", _arr.shape)

import torch

# Persist the three core MIDAA matrices (memberships A, archetype coordinates B,
# gene weights C) together in a single checkpoint file.
torch.save(dict(A=A, B=B, C=C), "midaa_core_matrices.pth")

print("Saved A, B, C matrices to midaa_core_matrices.pth")


import matplotlib.pyplot as plt
import numpy as np

# Scatter the first two latent dimensions, colouring each cell by its dominant archetype.
fig, ax = plt.subplots(figsize=(7, 6))

scatter = ax.scatter(
    Z[:, 0], Z[:, 1],
    c=labels,
    s=10,
    cmap="tab10",       # qualitative colormap suits discrete archetype labels
    alpha=0.8,
)

ax.set_title("MIDAA Latent Space (Z) — Cells Colored by Archetype")
ax.set_xlabel("Z1")
ax.set_ylabel("Z2")

# With num=None, legend_elements() returns one handle per distinct value in
# `labels`, in sorted order. Name the handles from those same values.
# Hard-coding range(4) mislabels the legend whenever some archetype is never the
# dominant one, or when narchetypes != 4.
handles, _ = scatter.legend_elements(num=None)
present_archetypes = np.unique(labels)
ax.legend(handles, [f"Archetype {i}" for i in present_archetypes], title="Archetypes")

fig.tight_layout()
plt.show()

