# Configure data directories

In [1]:
import os

# Every location below is resolved relative to the directory the notebook runs in.
pwd = os.getcwd()
print(pwd)

# --- Image locations ---------------------------------------------------------
# HE images (downsampled)
HE_dir = os.path.join(pwd, "data/chevrier/images/Downsampled")

# Raw slide images, and the raw MOB images inside them
raw_img_dir = os.path.join(pwd, "data/chevrier/images/Raw")
MOB_dir = os.path.join(raw_img_dir, "MOB/MOB_raw/Raw")

# Raw subimages are written into a folder next to the raw MOB images
MOB_subimg_dir = os.path.join(MOB_dir, "Subimages")

# --- Sequencing data ---------------------------------------------------------
# RNA sequence and location data: path to the /h5ad folder
raw_data_dir = os.path.join(pwd, "data/chevrier/h5ad")

# One (sample name, base directory, file name) entry per h5ad file to load
files_to_load = [("MOB_1", raw_data_dir, "GSM8243006_MOB_1_raw_annotated.h5ad")]

print(MOB_subimg_dir)


/home/jesu3424/EmpiricalBayes/cebmf_seqRNA
/home/jesu3424/EmpiricalBayes/cebmf_seqRNA/data/chevrier/images/Raw/MOB/MOB_raw/Raw/Subimages


# Helper functions

In [2]:
import glob
import os
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc

def _resolve_image_path(image_dir, sample, exts=(".png", ".tif", ".tiff", ".jpg", ".jpeg")):

    for ext in exts:
        p = os.path.join(image_dir, f"{sample}{ext}")
        if os.path.exists(p):
            return p
    hits = []
    for ext in exts:
        hits.extend(glob.glob(os.path.join(image_dir, f"{sample}*{ext}")))
    return hits[0] if hits else None

def load_and_prepare(sample_name, dir_path, filename):
    """
    Read one h5ad file and normalise its ``.obs`` table.

    The file is read from ``<dir_path>/<sample_name>/<filename>``. Afterwards:
      * barcodes are prefixed with ``<sample_name>_`` and used as the index,
      * ``SampleID`` is set to ``sample_name`` (``Sample`` too, if absent),
      * ``Subregion`` and ``Tissue Subregion`` are merged into ``Subregion``,
      * the ``Unnamed: 0`` column is dropped when present.

    Returns:
        The AnnData object returned by ``sc.read_h5ad``, modified in place.
    """
    adata = sc.read_h5ad(os.path.join(dir_path, sample_name, filename))

    # Barcode -> make globally unique and set as index
    has_barcode_col = "Barcode" in adata.obs.columns
    raw_bc = (
        adata.obs["Barcode"].astype(str)
        if has_barcode_col
        else pd.Index(adata.obs_names).astype(str)
    )
    uniq_bc = sample_name + "_" + raw_bc
    adata.obs["Barcode"] = uniq_bc
    adata.obs_names = uniq_bc

    # SampleID is always overwritten; Sample only filled in when missing
    adata.obs["SampleID"] = sample_name
    if "Sample" not in adata.obs.columns:
        adata.obs["Sample"] = sample_name

    def _column_or_empty(column):
        # Existing column, or an all-missing object Series on the current index
        if column in adata.obs.columns:
            return adata.obs[column]
        return pd.Series(index=adata.obs_names, dtype="object")

    # Harmonize Subregion: values already in "Subregion" win, gaps are filled
    # from "Tissue Subregion"
    primary = _column_or_empty("Subregion")
    fallback = _column_or_empty("Tissue Subregion")
    adata.obs["Subregion"] = primary.combine_first(fallback)

    # Remove the now-redundant source column and the junk column
    for unwanted in ("Tissue Subregion", "Unnamed: 0"):
        if unwanted in adata.obs.columns:
            adata.obs.drop(columns=[unwanted], inplace=True)

    return adata

def get_downsamp_img_size(samples, image_dir):
    """
    Return the pixel dimensions of each sample's downsampled image.

    Args:
        samples: Iterable of sample names; each is resolved to a file in
            ``image_dir`` via ``_resolve_image_path``.
        image_dir (str): Directory holding the downsampled images.

    Returns:
        dict: ``{sample: (n_rows, n_cols)}`` - the first two entries of the
        loaded array's ``.shape``, i.e. height first, then width.

    Raises:
        FileNotFoundError: No image could be resolved for a sample.
        ValueError: The image is neither grayscale (2D) nor RGB/RGBA (3D with
            3 or 4 channels in the last axis).
    """
    img_sizes = {}
    for sample in samples:
        path = _resolve_image_path(image_dir, sample)
        if path is None:
            raise FileNotFoundError(f"No image found for sample '{sample}' in {image_dir}")
        img = skimage.io.imread(path)

        # Only the spatial shape is needed, so read it straight off the array
        # instead of converting the whole image to grayscale float32 first.
        # The accepted layouts are the same as before: grayscale, RGB, RGBA.
        is_gray = img.ndim == 2
        is_color = img.ndim == 3 and img.shape[2] in (3, 4)
        if not (is_gray or is_color):
            raise ValueError(f"Unexpected image shape {img.shape} for '{sample}'")

        img_sizes[sample] = tuple(img.shape[:2])

    return img_sizes

def get_downsamp_subimage_size_coords_seqRNA(
    adata,
    cell_filter,
    celltype_col,
    image_dir,
    sample_col="Sample_ID",
    x_col="x_scaled_image",
    y_col="y_scaled_image",
):
    """
    Collect, per sample, the spot spacing, the row-major sorted spot
    coordinates and the matching gene-expression table.

    Args:
        adata: AnnData-like object; ``.obs`` must contain ``celltype_col``,
            ``sample_col``, ``x_col`` and ``y_col``.
        cell_filter: ``None`` or an empty/blank string to keep every spot, a
            boolean mask (pandas Series or NumPy array), or a callable that
            receives ``adata.obs`` and returns such a mask.
        celltype_col (str): Column copied to ``subtype`` on the internal copy.
        image_dir: Not used by this function; kept for interface compatibility.
        sample_col (str): Column holding the sample identifier.
        x_col, y_col (str): Columns holding the spot coordinates.

    Returns:
        tuple: ``(subimage_diam, subimage_x_sorted, subimage_y_sorted,
        sorted_gene_expression_df)``. The first three are dicts keyed by
        sample. The DataFrame holds the expression rows, in the same sorted
        order, of the LAST sample processed only (an empty frame with the gene
        columns when no spot is selected). With several samples the frames of
        the earlier ones are discarded.

    Raises:
        ValueError: ``cell_filter`` is of an unsupported type (including a
            non-blank string).
    """
    # Subset
    if cell_filter is None or (isinstance(cell_filter, str) and cell_filter.strip() == ""):
        subset = adata.copy()
    elif isinstance(cell_filter, (pd.Series, np.ndarray)):
        subset = adata[cell_filter].copy()
    elif callable(cell_filter):
        subset = adata[cell_filter(adata.obs)].copy()
    else:
        raise ValueError(
            "cell_filter must be None, an empty string, a boolean mask "
            "(pandas Series or NumPy array), or a callable."
        )

    # Written to the private copy only; also fails fast if celltype_col is missing
    subset.obs["subtype"] = subset.obs[celltype_col]
    samples = subset.obs[sample_col].unique().tolist()

    print(f"Plotting {len(subset)} cells across {len(samples)} samples: {samples}")

    subimage_diam = {}
    subimage_x_sorted = {}
    subimage_y_sorted = {}

    # Defined up front so that an empty selection returns an empty table
    # instead of failing on the return statement.
    sorted_gene_expression_df = pd.DataFrame(columns=subset.var_names)

    # Get the size and coordinates for each sample
    for sample in samples:
        adata_sample = subset[subset.obs[sample_col] == sample].copy()

        # Convert the gene expression data into a (dense) dataframe
        expression = adata_sample.X
        if not isinstance(expression, np.ndarray):
            expression = expression.toarray()
        gene_expression_df = pd.DataFrame(
            expression,
            index=adata_sample.obs_names,
            columns=adata_sample.var_names,
        )

        # Sort the spot metadata in row-major order (y first, then x)
        sorted_obs = adata_sample.obs.sort_values(by=[y_col, x_col], ascending=[True, True])

        # Reorder the expression rows to match, by label. The sorted table is
        # deliberately not assigned back to adata_sample.obs: that would
        # reorder the labels without reordering adata_sample.X.
        sorted_gene_expression_df = gene_expression_df.loc[sorted_obs.index]

        print(f"Shape of sorted gene_expressions dataframe is {sorted_gene_expression_df.shape}")

        x_sorted = sorted_obs[x_col].values
        y_sorted = sorted_obs[y_col].values

        # Spot size is estimated as the x distance between the first two spots
        # in sorted order.
        # NOTE(review): this equals the spot pitch only if those two spots are
        # neighbours in the same row - confirm for the data in use.
        if len(x_sorted) > 1:
            diameter = abs(x_sorted[1] - x_sorted[0])
        else:
            diameter = 32  # fallback if only one spot

        subimage_diam[sample] = diameter
        subimage_x_sorted[sample] = x_sorted
        subimage_y_sorted[sample] = y_sorted

    return subimage_diam, subimage_x_sorted, subimage_y_sorted, sorted_gene_expression_df

# Locate VSI image files
def _resolve_vsi_path(image_dir, recursive=True):
    """
    Find all .vsi files in the given directory.

    Args:
        directory (str): Path to the directory to search.
        recursive (bool): If True, search subdirectories as well.

    Returns:
        list[str]: List of full paths to .vsi files.
    """
    directory = Path(image_dir)
    
    if recursive:
        files = list(directory.rglob("*.vsi"))
    else:
        files = list(directory.glob("*.vsi"))

    return [str(f.resolve()) for f in files]

def _subimage_to_uint8_gray(subimg, percentiles):
    """Turn one cropped sub-image into a contrast-stretched, rotated uint8 grayscale array."""
    # Handle grayscale, RGB, RGBA
    if subimg.ndim == 2:
        gray = subimg
    elif subimg.ndim == 3:
        # The channel count is the size of the last axis
        if subimg.shape[2] == 4:
            subimg = skimage.color.rgba2rgb(subimg)
        gray = skimage.color.rgb2gray(subimg)
    else:
        raise ValueError(f"Unexpected image shape {subimg.shape}")

    gray = img_as_float32(gray)

    if percentiles is not None:
        low, high = np.percentile(gray, percentiles)
        # Fall back to the full range when the percentile window is degenerate
        if not np.isfinite(low) or not np.isfinite(high) or high <= low:
            low, high = np.min(gray), np.max(gray)
        gray = skimage.exposure.rescale_intensity(gray, in_range=(low, high))

    # Rotate the image by 90 degrees (k=-1)
    gray_rot = np.rot90(gray, k=-1)

    # If array is float, scale to 0-255 and convert to uint8.
    # np.ptp() is used because the ndarray.ptp() method was removed in NumPy 2.0.
    if gray_rot.dtype != np.uint8:
        gray_rot = (255 * (gray_rot - gray_rot.min()) / (np.ptp(gray_rot) + 1e-8)).astype(np.uint8)

    return gray_rot


def save_vsi_subimages_to_file(
    image_dir,
    downsamp_tiff_size,
    downsamp_cent_coord,
    downsamp_tissue_size,
    downsamp_patch_size,
    downsamp_x_coords,
    downsamp_y_coords,
    gene_expressions,
    output_path = None,
    channel=0,
    z_index=0,
    percentiles=(0, 55),
):
    """
    Crop one sub-image per spot out of each .vsi slide found directly in
    ``image_dir`` and save it as ``<output_path>/image_<i>.png``.

    Downsampled coordinates are scaled to raw-slide pixels with an integer
    downsampling rate, ``round(raw_width / downsamp_tiff_size[0])``.

    Args:
        image_dir (str): Directory searched (non-recursively) for .vsi files.
        downsamp_tiff_size: Size of the downsampled whole-slide image; only
            index 0 (its width) is used.
        downsamp_cent_coord: (x, y) centre of the tissue in the downsampled
            whole-slide image.
        downsamp_tissue_size: Extent of the tissue in the downsampled image;
            index 0 is used as the horizontal extent, index 1 as the vertical.
        downsamp_patch_size: Side length of one spot patch in downsampled pixels.
        downsamp_x_coords, downsamp_y_coords: Spot coordinates in downsampled
            pixels; must support element-wise multiplication (e.g. NumPy arrays).
        gene_expressions: Table whose ``.iloc[i]`` is the expression of spot ``i``.
        output_path (str): Directory the PNG files are written to. Required,
            despite the ``None`` default.
        channel: Currently unused; accepted for interface compatibility.
        z_index (int): Z plane to read from the slide.
        percentiles: (low, high) percentiles used for contrast stretching, or
            ``None`` to skip it.

    Returns:
        tuple: ``(raw_coords, seq_RNA)`` - per spot, the (x, y) coordinate
        scaled to raw-slide pixels (not offset by the tissue origin) and the
        matching ``gene_expressions`` row.

    Raises:
        TypeError: ``output_path`` is None.
        FileNotFoundError: ``image_dir`` contains no .vsi file.
        ValueError: A cropped sub-image is neither 2D nor 3D.

    Note:
        With more than one .vsi file, every file writes the same
        ``image_<i>.png`` names and the returned lists describe the last file
        only, so later files overwrite earlier ones.
    """
    if output_path is None:
        raise TypeError("output_path must be a directory path, not None")

    paths = _resolve_vsi_path(image_dir, recursive=False)

    # _resolve_vsi_path returns a list, so "nothing found" is an empty list
    if not paths:
        raise FileNotFoundError(f"No images found in {image_dir}")

    # Create the output directory once instead of once per spot
    os.makedirs(output_path, exist_ok=True)

    raw_coords = []
    seq_RNA = []

    for path in paths:
        reader = bioformats.ImageReader(path)
        try:
            # Get raw image metadata
            width = reader.rdr.getSizeX()
            height = reader.rdr.getSizeY()
            channels = reader.rdr.getSizeC()
            zslices = reader.rdr.getSizeZ()
            timepoints = reader.rdr.getSizeT()

            print(f"Full image size: {width} x {height}")
            print(f"Channels: {channels}, Z: {zslices}, T: {timepoints}")

            # Calculate the central location of the sample of interest in the raw slide
            downsamp_rate = round(width / downsamp_tiff_size[0])

            print(f"Downsampling rate is {downsamp_rate}")

            raw_cent_coord = (downsamp_cent_coord[0] * downsamp_rate, downsamp_cent_coord[1] * downsamp_rate)

            print(f"Central coordinate on downsampled .tiff image: {downsamp_cent_coord}, central coordinate on raw slide image: {raw_cent_coord}")

            # Calculate the size of the corresponding tissue in the raw slide image
            raw_half_tissue_width = int(downsamp_tissue_size[0] // 2 * downsamp_rate)
            raw_half_tissue_height = int(downsamp_tissue_size[1] // 2 * downsamp_rate)

            print(f"half_width * half_height of tissue in raw slide is {raw_half_tissue_width} * {raw_half_tissue_height}")

            # Calculate the start location of the corresponding tissue in the raw slide image
            x_tissue_min = max(raw_cent_coord[0] - raw_half_tissue_width, 0)
            y_tissue_min = max(raw_cent_coord[1] - raw_half_tissue_height, 0)

            print(f"Start location of tissue in raw slide is ({x_tissue_min}, {y_tissue_min})")

            # Calculate the size of the subimage in the raw slide image
            raw_half_patch_size = int(downsamp_patch_size // 2 * downsamp_rate)
            patch_side = raw_half_patch_size * 2

            print(f"half_size of subimage in raw slide is {raw_half_patch_size}")

            # Calculate the scaled X and Y coordinates in the raw slide image
            raw_x_coords = downsamp_x_coords * downsamp_rate
            raw_y_coords = downsamp_y_coords * downsamp_rate

            # The number of subimages that need to be saved
            n_spots = len(downsamp_x_coords)

            # Restarted for every slide: the result describes the last one read
            raw_coords = []
            seq_RNA = []

            for i in range(n_spots):
                # Central coordinates of this subimage, relative to the tissue origin
                x0 = raw_x_coords[i]
                y0 = raw_y_coords[i]

                raw_coords.append((x0, y0))
                seq_RNA.append(gene_expressions.iloc[i])

                # Top-left corner of the crop in slide pixels, clamped at 0.
                # Cast to int explicitly: the reader expects integer pixel offsets.
                x_min = int(max(x0 - raw_half_patch_size + x_tissue_min, 0))
                y_min = int(max(y0 - raw_half_patch_size + y_tissue_min, 0))

                # Crop the corresponding subimage: (x, y, width, height)
                subimg = reader.read(z=z_index, t=0, series=0, XYWH=(x_min, y_min, patch_side, patch_side))  # NumPy array

                gray_rot = _subimage_to_uint8_gray(subimg, percentiles)

                # Save subimage to output_path
                filename = os.path.join(output_path, f"image_{i}.png")
                Image.fromarray(gray_rot).save(filename)
        finally:
            # Release the reader even when reading or saving fails part-way
            reader.close()

    return raw_coords, seq_RNA

# Start the Java Virtual Machine

In [3]:
import bioformats
import javabridge

# Start the JVM so that Bio-Formats can read the .vsi slide images.
# run_headless=True starts it in headless mode; the extra JVM argument sets the
# SLF4J simple-logger default level to "error".
# NOTE(review): the output of the preprocessing cell still contains DEBUG lines
# from loci.*, so this flag may not reach the logging backend actually in use - confirm.
# NOTE(review): no matching javabridge.kill_vm() call is visible in this notebook.
javabridge.start_vm(class_path=bioformats.JARS, max_heap_size="512m", run_headless=True,
    args=["-Dorg.slf4j.simpleLogger.defaultLogLevel=error"])

# Preprocess data for PyTorch Dataset

In [4]:
from pathlib import Path
from PIL import Image
import skimage
import skimage.io
from skimage.color import rgb2gray
from skimage import img_as_float32

# Samples to process. This list is maintained separately from files_to_load
# and is used to look up the images and to drive the cropping loop below.
MOB_samples = ['MOB_1']

adatas = []

# Read and normalise every h5ad file listed in files_to_load
for sample_name, dir_path, filename in files_to_load:
    print(f"Reading {sample_name}")
    adatas.append(load_and_prepare(sample_name, dir_path, filename))

# Combine datasets if more than one (row-wise, keeping the union of the genes)
combined = ad.concat(adatas, axis=0, join="outer", label=None, index_unique=None)

# Sanity check: barcodes were prefixed with the sample name, so none should repeat
dups = combined.obs["Barcode"].duplicated().sum()
print(f"Combined shape: {combined.shape}")
print(f"Duplicate Barcodes: {dups}")
print("SampleIDs present:", combined.obs["SampleID"].unique())

# Get the size of downsampled tissue images.
# The sizes come from an image array's .shape, i.e. (rows, cols).
downsamp_img_sizes = get_downsamp_img_size(samples=MOB_samples, image_dir=HE_dir)
print(f"Downsampled image sizes: {downsamp_img_sizes}")

# Get the size of downsampled subimages within a tissue and X/Y coordinates.
# The first three results are dicts keyed by sample; gene_expressions is a single
# table that holds the rows of the last sample processed only.
downsamp_subimg_sizes, downsamp_x_coords, downsamp_y_coords, gene_expressions = get_downsamp_subimage_size_coords_seqRNA(
    adata = combined,
    cell_filter = combined.obs["SampleID"].str.contains("MOB", case=False, na=False),
    celltype_col = "Subregion",
    image_dir = HE_dir,
    sample_col="SampleID",
    x_col="X_Scaled",
    y_col="Y_Scaled",
)

# Specify size of downsampled .tiff image and central coordinate of the tissue of interest in the downsampled .tiff image
# NOTE(review): these values are hard-coded; presumably measured by hand for this slide - confirm.
downsamp_tiff_size = (9695, 3838)
downsamp_cent_coord = (5028, 2664)

# Save subimages from raw microscope images to file.
# NOTE(review): downsamp_img_sizes holds (rows, cols), but save_vsi_subimages_to_file
# uses index 0 of that argument as the horizontal extent and index 1 as the vertical
# one. For a nearly square image such as (1899, 1900) the effect is negligible - confirm the intent.
# Every iteration writes to the same MOB_subimg_dir and rebinds subimg_coords and
# subimg_seqRNA, so with several samples only the last one is kept.
for sample in MOB_samples:
    subimg_coords, subimg_seqRNA = save_vsi_subimages_to_file(
        MOB_dir, 
        downsamp_tiff_size, 
        downsamp_cent_coord, 
        downsamp_img_sizes.get(sample), 
        downsamp_subimg_sizes.get(sample), 
        downsamp_x_coords.get(sample), 
        downsamp_y_coords.get(sample),
        gene_expressions,
        output_path = MOB_subimg_dir,
    )

# Report on the results of the last sample processed
print(f"Length of subimg_coords is: {len(subimg_coords)}")
print(f"Length of RNA sequence list and length of each RNA sequence is: {len(subimg_seqRNA)} and {subimg_seqRNA[0].shape}")

Reading MOB_1


  adata.obs["Subregion"] = sub_a.combine_first(sub_b)


Combined shape: (7716, 17243)
Duplicate Barcodes: 0
SampleIDs present: ['MOB_1']
Downsampled image sizes: {'MOB_1': (1899, 1900)}
Plotting 7716 cells across 1 samples: ['MOB_1']
Shape of sorted gene_expressions dataframe is (7716, 17243)
10:58:23.742 [Thread-0] DEBUG loci.common.NIOByteBufferProvider -- Using mapped byte buffer? false
10:58:23.760 [Thread-0] DEBUG loci.formats.ClassList -- Could not find loci.formats.in.URLReader
java.lang.ClassNotFoundException: loci.formats.in.URLReader
	at java.base/jdk.internal.loader.BuiltinClassLoader.loadClass(BuiltinClassLoader.java:581)
	at java.base/jdk.internal.loader.ClassLoaders$AppClassLoader.loadClass(ClassLoaders.java:178)
	at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:527)
	at java.base/java.lang.Class.forName0(Native Method)
	at java.base/java.lang.Class.forName(Class.java:315)
	at loci.formats.ClassList.parseLine(ClassList.java:196)
	at loci.formats.ClassList.parseFile(ClassList.java:258)
	at loci.formats.ClassList.<i

# Generate PyTorch Dataset

In [5]:
'''
Load every subimage from disk, pair each one with its entry in subimg_coords, and wrap the pairs in a PyTorch Dataset.
Create a separate PyTorch Dataset for subimg_seqRNA.
'''
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset

# Convert image to PyTorch tensor (C x H x W) with values in [0, 1].
# Module-level object: RawSubimageSpotDataset below looks it up as a global.
transform = transforms.Compose([transforms.ToTensor()])

# Create the gene expression PyTorch Dataset
class GeneExpressionDataset(Dataset):
    """
    PyTorch Dataset of per-spot gene-expression vectors.

    Each item is a 1D float32 tensor built from the values of one entry of
    the ``subimg_seqRNA`` sequence passed to the constructor.
    """

    def __init__(self, subimg_seqRNA):
        # Build every tensor eagerly, one per spot, in input order
        vectors = []
        for spot_expression in subimg_seqRNA:
            vectors.append(torch.tensor(spot_expression.values, dtype=torch.float32))
        self.data = vectors

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

# Define and create the subimage PyTorch Dataset
class RawSubimageSpotDataset(Dataset):
    """
    PyTorch Dataset pairing each saved subimage with its (x, y) coordinate.

    The constructor reads ``image_<i>.png`` from ``subimg_dir`` for every
    index ``i`` of ``subimg_coords`` and keeps all results in memory.

    Each item is ``(tensor_img, tensor_label)``: the image after the
    module-level ``transform`` with an extra leading axis of size 1, and the
    coordinate entry converted with ``torch.tensor``.
    """

    def __init__(self, subimg_dir, subimg_coords):
        # All (image, label) pairs, in the order of subimg_coords
        self.tensor_list = []

        for spot_idx, coord in enumerate(subimg_coords):
            png_path = os.path.join(subimg_dir, f"image_{spot_idx}.png")
            pixels = np.array(Image.open(png_path))
            image_tensor = transform(pixels).unsqueeze(0)
            label_tensor = torch.tensor(coord)
            self.tensor_list.append((image_tensor, label_tensor))

    def __getitem__(self, idx):
        return self.tensor_list[idx]

    def __len__(self):
        return len(self.tensor_list)

# Build the image/coordinate dataset and the gene-expression dataset for MOB,
# then show the first image/coordinate pair as a quick check
raw_dataset_MOB = RawSubimageSpotDataset(MOB_subimg_dir, subimg_coords)
seqRNA_dataset_MOB = GeneExpressionDataset(subimg_seqRNA)
first_pair = raw_dataset_MOB[0]
print(f"First subimage object is {first_pair}")

First subimage object is (tensor([[[[0.8745, 0.8784, 0.8784,  ..., 0.7961, 1.0000, 1.0000],
          [0.9569, 0.8784, 0.8784,  ..., 0.7961, 0.9647, 0.9490],
          [1.0000, 0.9137, 0.9137,  ..., 0.6706, 0.8392, 0.9216],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 0.8745, 0.8745],
          [0.9686, 0.9686, 0.9686,  ..., 1.0000, 0.8745, 0.8314],
          [0.9686, 0.8863, 0.8863,  ..., 1.0000, 0.8745, 0.8314]]]]), tensor([8472.5473, 5907.0455], dtype=torch.float64))


In [6]:
import glob

# Delete every PNG subimage from disk. The datasets built above already hold
# the image tensors in memory.
png_pattern = os.path.join(MOB_subimg_dir, "*.png")
for png_path in glob.glob(png_pattern):
    os.remove(png_path)
print(f"Finished deleting subimages")

Finished deleting subimages
