In [22]:
import os
import tempfile
import datasets
from huggingface_hub import login
import shutil
from pathlib import Path
from hest import iter_hest
import numpy as np
import matplotlib.pyplot as plt

In [23]:
hf_token = "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"
cache_dir = "/mnt/HDD8TO/data/cache"

In [28]:
class HEST:
    def __init__(self, hf_token=None, cache_dir=None):
        self.hf_token = hf_token
        self.cache_dir = Path(cache_dir)
        if self.hf_token == "from_env":
            self.hf_token = os.getenv("HF_TOKEN")
        if self.cache_dir is None:
            self.cache_dir = Path(tempfile.gettempdir(), "hest")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        if hf_token is not None:
            from huggingface_hub import login
            login(self.hf_token)

    def empty_cache(self):
        shutil.rmtree(self.cache_dir)

    def cache_dataset(self, dataset_id, ):
        datasets.load_dataset(
            'MahmoodLab/hest',
            cache_dir=self.cache_dir,
            patterns=[f"*{dataset_id}[_.]**"],
        )

    def load_dataset(self, dataset_id, fullres=False):

        if not dataset_id in os.listdir(self.cache_dir):
            self.cache_dataset(dataset_id)
        for st in iter_hest(self.cache_dir, id_list=dataset_id): # Replaced by one that is present
            sdata = st.to_spatial_data(fullres=fullres)
        return sdata


In [29]:
hest = HEST(hf_token=hf_token, cache_dir=cache_dir)

In [None]:
sdata

SpatialData object
├── Images
│     ├── 'ST_downscaled_hires_image': SpatialImage[cyx] (3, 2496, 2400)
│     ├── 'ST_downscaled_lowres_image': SpatialImage[cyx] (3, 1000, 962)
│     └── 'ST_fullres_image': DataTree[cyx] (3, 19968, 19200), (3, 9984, 9600), (3, 4992, 4800), (3, 2496, 2400), (3, 1248, 1200), (3, 624, 600), (3, 312, 300), (3, 156, 150)
├── Shapes
│     ├── 'cellvit': GeoDataFrame shape: (37483, 3) (2D shapes)
│     ├── 'locations': GeoDataFrame shape: (1084, 2) (2D shapes)
│     └── 'tissue_contours': GeoDataFrame shape: (2, 2) (2D shapes)
└── Tables
      └── 'table': AnnData (1084, 36601)
with coordinate systems:
    ▸ 'ST_downscaled_hires', with elements:
        ST_downscaled_hires_image (Images), cellvit (Shapes), locations (Shapes), tissue_contours (Shapes)
    ▸ 'ST_downscaled_lowres', with elements:
        ST_downscaled_lowres_image (Images), cellvit (Shapes), locations (Shapes), tissue_contours (Shapes)
    ▸ 'ST_fullres', with elements:
        ST_fullres_image 

In [32]:
hf_token = "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"
cache_dir = "/mnt/HDD8TO/data/cache"

In [35]:
def download_in_cache(ids_to_query, cache_dir):
    datasets.load_dataset(
        'MahmoodLab/hest',
        cache_dir=cache_dir,
        patterns=[f"*{id}[_.]**" for id in ids_to_query],
    )

In [36]:
download_in_cache(["INT1","INT2"], cache_dir)

Fetching 24 files: 100%|██████████| 24/24 [00:41<00:00,  1.74s/it]


Unzipping cell vit segmentation...


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]
Generating train split: 2 examples [00:00, 2104.52 examples/s]


In [39]:
os.listdir(cache_dir)

['MahmoodLab___hest',
 'cellvit_seg',
 'thumbnails',
 'tissue_seg',
 'metadata',
 '_mnt_HDD8TO_data_cache_MahmoodLab___hest_custom_config-4453f252e9cb3f84_1.0.0_94127ca856cb5f26aa6d5ab751be03921c2cc400b324e6285c015001d87154f7.lock',
 'patches_vis',
 'patches',
 '.cache',
 'st',
 'spatial_plots',
 'wsis',
 'pixel_size_vis']

In [47]:
def load_dataset(ids_to_query, cache_dir, fullres=False):
    for st in iter_hest(hest_dir=cache_dir, id_list=ids_to_query):
        sdata = st.to_spatial_data(fullres=fullres)
    return sdata

In [49]:
st_data = load_dataset(cache_dir=cache_dir,ids_to_query=["INT1","INT2"])

In [51]:
st_data.images

{'ST_downscaled_hires_image': <xarray.SpatialImage 'ST_downscaled_lowres_image' (c: 3, y: 2400, x: 2400)> Size: 17MB
dask.array<transpose, shape=(3, 2400, 2400), dtype=int8, chunksize=(3, 2400, 2400), chunktype=numpy.ndarray>
Coordinates:
  * c        (c) int64 24B 0 1 2
  * y        (y) float64 19kB 0.5 1.5 2.5 3.5 ... 2.398e+03 2.398e+03 2.4e+03
  * x        (x) float64 19kB 0.5 1.5 2.5 3.5 ... 2.398e+03 2.398e+03 2.4e+03
Attributes:
    transform:  {'ST_downscaled_hires': Identity }, 'ST_downscaled_lowres_image': <xarray.SpatialImage 'ST_downscaled_lowres' (c: 3, y: 1000, x: 1000)> Size: 3MB
dask.array<array, shape=(3, 1000, 1000), dtype=uint8, chunksize=(3, 1000, 1000), chunktype=numpy.ndarray>
Coordinates:
  * c        (c) int64 24B 0 1 2
  * y        (y) float64 8kB 0.5 1.5 2.5 3.5 4.5 ... 996.5 997.5 998.5 999.5
  * x        (x) float64 8kB 0.5 1.5 2.5 3.5 4.5 ... 996.5 997.5 998.5 999.5
Attributes:
    transform:  {'ST_downscaled_lowres': Identity }}

In [None]:
# download id in cache
# load from cache
# delete cache


In [None]:
hf_token = "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"
cache_dir = "/mnt/HDD8TO/data/cache"

In [61]:
hest.sdata

AttributeError: 'HEST' object has no attribute 'sdata'

In [None]:
image = np.array(sdata.images["ST_fullres_image"]["scale0"].image)

In [None]:
import os
import tempfile
import datasets
from huggingface_hub import login
import shutil
from pathlib import Path
from hest import iter_hest
import numpy as np
import matplotlib.pyplot as plt

def convert_center_to_slicing(coords_center, dimensions, reverse_order=False):
    """
    Convert a center coordinate and dimensions to a slicing.
    If reverse_order is True, the slicing is returned in the reverse order.
    example
        input: np.array([150, 350, 550]), np.array([100, 100, 100])
        output: (slice(100,200), slice(300,400), slice(500,600))
    """
    dimensions = np.array(dimensions).astype(int)
    coords_center = np.array(coords_center).astype(int)
    half_dimensions = np.array(dimensions) // 2
    coords_min = coords_center - half_dimensions
    coords_max = coords_center - half_dimensions + dimensions
    slicing = tuple([slice(min_, max_) for min_, max_ in zip(coords_min, coords_max)])
    if reverse_order:
        return slicing[::-1]
    return slicing


class HEST:
    def __init__(self, hf_token=None, cache_dir=None):
        self.hf_token = hf_token
        self.cache_dir = Path(cache_dir)
        if self.hf_token == "from_env":
            self.hf_token = os.getenv("HF_TOKEN")
        if self.cache_dir is None:
            self.cache_dir = Path(tempfile.gettempdir(), "hest")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        if hf_token is not None:
            from huggingface_hub import login
            login(self.hf_token)

    def empty_cache(self):
        shutil.rmtree(self.cache_dir)

    def cache_dataset(self, dataset_id, ):
        datasets.load_dataset(
            'MahmoodLab/hest',
            cache_dir=self.cache_dir,
            patterns=[f"*{dataset_id}[_.]**"],
        )

    def load_dataset(self, dataset_id, fullres=True):
        if not dataset_id in os.listdir(self.cache_dir):
            self.cache_dataset(dataset_id)
        for st in iter_hest(self.cache_dir, id_list=['INT1']): # Replaced by one that is present
            sdata = st.to_spatial_data(fullres=fullres)
        sdata = self.repatch(sdata)
        return sdata


    def repatch(self, sdata):
        image = np.array(sdata.images["ST_fullres_image"]["scale0"].image)
        DIMENSIONS = (128, 128)
        geometries = sdata.shapes["locations"]["geometry"]
        slicings = geometries.apply(lambda x: (slice(None),) + convert_center_to_slicing((int(x.x), int(x.y)), DIMENSIONS)).values
        patchs = []
        for slicing in slicings:
            patchs.append(image[slicing].swapaxes(0, 2))
        sdata.tables['table'].obsm['embeddings'] = np.array(patchs)
        return sdata


In [162]:
sdata_list = []
for id in ["INT1", "INT2"]:
    hest = HEST(hf_token=hf_token, cache_dir=cache_dir)
    sdata_list.append(hest.load_dataset(id))


In [164]:
sdata_list[0]

In [103]:
hest = HEST(hf_token="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", cache_dir="./data")

In [104]:
sdata = hest.load_dataset("INT1",fullres=True)

In [105]:
sdata.repatch()

<__main__.HEST at 0x72e418e36fc0>

In [132]:
sdata.sdata.tables['table'].obsm['embeddings'] = sdata.patchs

In [138]:
sdata.sdata.tables['table'].obs

Unnamed: 0,in_tissue,array_row,array_col,pxl_row_in_fullres,pxl_col_in_fullres,n_genes_by_counts,log1p_n_genes_by_counts,total_counts,log1p_total_counts,pct_counts_in_top_50_genes,pct_counts_in_top_100_genes,pct_counts_in_top_200_genes,pct_counts_in_top_500_genes,total_counts_mito,log1p_total_counts_mito,pct_counts_mito,region,instance_id
AAACAGAGCGACTCCT-1,1,14,94,6724,5167,3638,8.199464,9904.0,9.200795,23.374394,33.885299,45.032310,58.208805,130.0,4.875197,1.312601,locations,0
AAACCACTACACAGAT-1,1,3,117,4207,3072,1926,7.563720,4568.0,8.427050,26.751313,39.623468,51.773205,67.250438,162.0,5.093750,3.546410,locations,1
AAACGACAGTCTTGCC-1,1,2,118,4098,2881,1122,7.023759,2169.0,7.682482,27.247580,39.741817,52.973721,71.323190,64.0,4.174387,2.950669,locations,2
AAACGAGACGGTTGAT-1,1,35,79,8365,9166,2557,7.846981,6079.0,8.712760,23.424905,35.038658,46.603060,61.243626,77.0,4.356709,1.266656,locations,3
AAACGCCCGAGATCGG-1,1,4,108,5192,3262,889,6.791221,1578.0,7.364547,27.249683,40.177440,54.752852,75.348542,59.0,4.094345,3.738910,locations,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTGTAATCCGTACTCG-1,1,35,55,10992,9166,2771,7.927324,7612.0,8.937613,26.497635,38.899107,50.604309,64.240673,130.0,4.875197,1.707830,locations,1079
TTGTCGTTCAGTTACC-1,1,22,58,10664,6690,1248,7.130099,2357.0,7.765569,25.668222,37.802291,50.233347,68.264743,59.0,4.094345,2.503182,locations,1080
TTGTGGTATAGGTATG-1,1,24,126,3222,7071,594,6.388561,916.0,6.821107,28.602620,42.139738,56.986900,89.737991,57.0,4.060443,6.222707,locations,1081
TTGTTCAGTGTGCTAC-1,1,24,64,10007,7071,3195,8.069655,8524.0,9.050758,23.732989,34.760676,46.609573,60.183013,121.0,4.804021,1.419521,locations,1082


In [141]:
sdata.sdata.tables['table'].obsm['embeddings']

array([[[[157, 160, 201],
         [151, 156, 196],
         [139, 144, 186],
         ...,
         [207, 207, 217],
         [207, 207, 215],
         [212, 211, 219]],

        [[161, 165, 202],
         [164, 168, 203],
         [158, 162, 199],
         ...,
         [210, 210, 218],
         [211, 210, 216],
         [213, 212, 218]],

        [[163, 166, 199],
         [169, 173, 202],
         [165, 168, 199],
         ...,
         [214, 213, 219],
         [215, 214, 220],
         [216, 215, 221]],

        ...,

        [[149, 155, 187],
         [187, 185, 207],
         [201, 198, 219],
         ...,
         [217, 216, 222],
         [217, 216, 222],
         [217, 216, 222]],

        [[139, 149, 185],
         [179, 179, 203],
         [195, 193, 215],
         ...,
         [217, 216, 222],
         [217, 216, 222],
         [217, 216, 222]],

        [[131, 144, 186],
         [171, 173, 198],
         [189, 189, 213],
         ...,
         [217, 216, 222],
        