In [None]:
import os
import scanpy as sc

# Directory with one .h5ad file per tissue section.
# Override with the SLICE_INPUT_DIR environment variable; default is the original path.
input_dir = os.environ.get(
    "SLICE_INPUT_DIR",
    "/work/magroup/skrieger/tissue_generator/quantized_slices/subclass_z1_d338_0_rotated",
)

# Every section file available for this dataset, in ascending section number.
ALL_SLICES = [
    "sec_05.h5ad", "sec_06.h5ad", "sec_08.h5ad", "sec_09.h5ad", "sec_10.h5ad",
    "sec_11.h5ad", "sec_12.h5ad", "sec_13.h5ad", "sec_14.h5ad", "sec_15.h5ad",
    "sec_16.h5ad", "sec_17.h5ad", "sec_18.h5ad", "sec_19.h5ad", "sec_24.h5ad",
    "sec_25.h5ad", "sec_26.h5ad", "sec_27.h5ad", "sec_28.h5ad", "sec_29.h5ad",
    "sec_30.h5ad", "sec_31.h5ad", "sec_32.h5ad", "sec_33.h5ad", "sec_35.h5ad",
    "sec_36.h5ad", "sec_37.h5ad", "sec_38.h5ad", "sec_39.h5ad", "sec_40.h5ad",
    "sec_42.h5ad", "sec_43.h5ad", "sec_44.h5ad", "sec_45.h5ad", "sec_46.h5ad",
    "sec_47.h5ad", "sec_48.h5ad", "sec_49.h5ad", "sec_50.h5ad", "sec_51.h5ad",
    "sec_52.h5ad", "sec_54.h5ad", "sec_55.h5ad", "sec_56.h5ad", "sec_57.h5ad",
    "sec_58.h5ad", "sec_59.h5ad", "sec_60.h5ad", "sec_61.h5ad", "sec_62.h5ad",
    "sec_64.h5ad", "sec_66.h5ad", "sec_67.h5ad",
]

# Sections loaded for this run. Set to ALL_SLICES to use every section.
sorted_slices = [
    "sec_30.h5ad", "sec_31.h5ad", "sec_32.h5ad", "sec_33.h5ad", "sec_35.h5ad",
    "sec_36.h5ad", "sec_37.h5ad", "sec_38.h5ad", "sec_39.h5ad", "sec_40.h5ad",
]

# Fail early on typos or out-of-order selections. Section numbers are
# zero-padded, so lexical order equals numeric order here.
_unknown_slices = sorted(set(sorted_slices) - set(ALL_SLICES))
assert not _unknown_slices, f"sections not in ALL_SLICES: {_unknown_slices}"
assert sorted_slices == sorted(sorted_slices), "sorted_slices must be in ascending section order"

representations = [sc.read_h5ad(os.path.join(input_dir, fname)) for fname in sorted_slices]
for rep in representations:
    # Downstream code reads obsm["spatial"]; point it at the "original_spatial" coordinates.
    rep.obsm["spatial"] = rep.obsm["original_spatial"]

slices = representations
print(f"Loaded {len(slices)} anndata objects into memory.")

In [None]:
# P(location): fit the alignment model on the loaded slices and read back their
# locations in the common coordinate frame.
from alignment_model import AlignementModel

# NOTE(review): z_posn has 3 entries while `slices` may hold more sections
# (10 in the current selection). Confirm whether AlignementModel expects one
# z position per slice or a fixed neighbourhood offset.
alignment_options = dict(
    z_posn=[-1, 0, 1],
    pin_key="parcellation_structure",
    use_ccf=True,
)
density_model = AlignementModel(slices, **alignment_options)
density_model.fit()
aligned_slices = density_model.get_common_coordinate_locations()


In [None]:
# P(gene_exp|token): fit the gene-expression model on the aligned slices,
# tokenize them, then hold out one test slice and one validation slice.
from gene_exp_model import GeneExpModel

# Positions of the held-out slices in the tokenized list (Python indexing;
# negative values count from the end).
TEST_SLICE_INDEX = -5
VAL_SLICE_INDEX = -2


def split_holdout_slices(items, test_index, val_index):
    """Split ``items`` into ``(train_items, test_item, val_item)``.

    Args:
        items: Sequence of slices.
        test_index: Index of the test item (negative counts from the end).
        val_index: Index of the validation item (negative counts from the end).

    Returns:
        A list of every other item in the original order, followed by the
        test item and the validation item.

    Raises:
        IndexError: if either index is out of range for ``items``.
        ValueError: if both indices refer to the same item.
    """
    n_items = len(items)
    positions = []
    for label, index in (("test_index", test_index), ("val_index", val_index)):
        if not -n_items <= index < n_items:
            raise IndexError(f"{label}={index} is out of range for {n_items} slices")
        positions.append(index % n_items)  # normalize negative indices
    test_pos, val_pos = positions
    if test_pos == val_pos:
        raise ValueError("test and validation slices must be different")
    train_items = [item for pos, item in enumerate(items) if pos not in (test_pos, val_pos)]
    return train_items, items[test_pos], items[val_pos]


gene_exp_model = GeneExpModel(aligned_slices, use_subclass=True)
gene_exp_model.fit()
all_tokenized_slices = gene_exp_model.get_tokenized_slices()

# `slices_tokenized` holds only the training slices, as later cells expect.
slices_tokenized, test_slice, val_slice = split_holdout_slices(
    all_tokenized_slices, TEST_SLICE_INDEX, VAL_SLICE_INDEX
)


In [None]:
# P(region|location): build a CelltypeModel over the training slices, using the
# token vocabulary size from the gene-expression model and the held-out
# validation slice, then fit it.
from celltype_model import CelltypeModel

region_model = CelltypeModel(
    slices_tokenized,
    gene_exp_model.num_tokens,
    val_slice=val_slice,
    epochs=100,
    learning_rate=0.001,
    batch_size=32,
    device="cuda",
)
region_model.fit()
# Alternative to fit(): reuse a saved checkpoint instead of retraining.
# region_model.load_model("ops/distance_transform_temp/best_model.pt")

In [None]:
# Evaluate the trained region model on the held-out test slice and plot
# ground truth next to the argmax and sampled predictions.
import numpy as np
import torch  # explicit: previously only available via the star imports below
from matplotlib.pyplot import rc_context
from metrics import *
from analysis import *

device = "cuda"
SEED = 0                     # seed for the stochastic (sample_from_probs=True) prediction
SOFT_ACCURACY_RADIUS = 0.05  # neighbourhood radius passed to soft_accuracy

# Model input: the test slice's aligned coordinates with per-cell entropy
# appended as an extra final column.
aligned_coords = torch.tensor(
    np.asarray(test_slice.obsm["aligned_spatial"]), dtype=torch.float32, device=device
)
# Convert to a plain ndarray first. torch.tensor on a pandas Series falls back to
# positional Series.__getitem__, which breaks on string-indexed obs in newer pandas.
density_tensor = torch.tensor(
    test_slice.obs["entropy"].to_numpy(dtype=np.float32), dtype=torch.float32, device=device
)
xyz_samples = torch.cat([aligned_coords, density_tensor.unsqueeze(-1)], dim=-1)

adata_argmax, _ = generate_anndata_from_samples(region_model, xyz_samples, device, sample_from_probs=False)
# The sampler's RNG source is not visible here, so seed both torch and numpy
# to make the sampled prediction reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)
adata_sampled, _ = generate_anndata_from_samples(region_model, xyz_samples, device, sample_from_probs=True)

print(soft_accuracy(
    test_slice.obs["token"].to_numpy().tolist(),
    test_slice.obsm["aligned_spatial"],
    adata_sampled.obs["token"].tolist(),
    adata_sampled.obsm["spatial"],
    radius=SOFT_ACCURACY_RADIUS,
))

assign_shared_colors([adata_argmax, adata_sampled, test_slice], color_key="token")
# (anndata, color_key, spot_size). Identical order and settings to the individual calls.
for plot_adata, plot_key, plot_spot_size in (
    (test_slice, "token", 0.001),
    (test_slice, "entropy", 0.002),
    (adata_argmax, "token", 0.001),
    (adata_sampled, "token", 0.001),
):
    plot_spatial_with_palette(plot_adata, color_key=plot_key, spot_size=plot_spot_size, figsize=(30, 30))