Skip to content

Commit

Permalink
Merge pull request #82 from klarman-cell-observatory/yiming
Browse files Browse the repository at this point in the history
Clean up Visium I/O code
  • Loading branch information
yihming committed Dec 22, 2021
2 parents 43d7b0a + 612873c commit 0c7111b
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 48 deletions.
4 changes: 2 additions & 2 deletions pegasusio/multimodal_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def img(self) -> Union[pd.DataFrame, None]:
# Set the img field if needed
@img.setter
def img(self, img: pd.DataFrame):
    """Attach the image DataFrame to the underlying UnimodalData.

    Only valid when the selected UnimodalData holds Visium (spatial) data.

    Parameters
    ----------
    img: ``pd.DataFrame``
        Image metadata frame (one row per resolution) to store on the data.

    Raises
    ------
    AssertionError
        If no UnimodalData is selected or its modality is not ``"visium"``.
    """
    # NOTE(review): the source showed this assert duplicated (diff artifact);
    # one occurrence is sufficient.
    assert self._unidata is not None
    assert self._unidata.get_modality() == "visium", "data needs to be spatial"
    self._unidata.img = img

Expand Down Expand Up @@ -598,4 +598,4 @@ def _clean_tmp(self) -> dict:

def _addback_tmp(self, _tmp_multi) -> None:
for key, _tmp_dict in _tmp_multi.items():
self.data[key]._addback_tmp(_tmp_dict)
self.data[key]._addback_tmp(_tmp_dict)
2 changes: 1 addition & 1 deletion pegasusio/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,4 @@ def _infer_output_file_type(output_File: str) -> str:
write_scp_file(data, output_file, is_sparse = is_sparse, precision = precision)

data._addback_tmp(_tmp_multi)
logger.info(f"{file_type} file '{output_file}' is written.")
logger.info(f"{file_type} file '{output_file}' is written.")
2 changes: 1 addition & 1 deletion pegasusio/spatial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
barcode_multigraphs: Optional[Dict[str, csr_matrix]] = None,
feature_multigraphs: Optional[Dict[str, csr_matrix]] = None,
cur_matrix: str = "raw.data",
img=None,
img: Optional[pd.DataFrame] = None,
) -> None:
assert metadata["modality"] == "visium"
super().__init__(
Expand Down
51 changes: 22 additions & 29 deletions pegasusio/spatial_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
import os, re

import pandas as pd

Expand All @@ -14,10 +14,6 @@ def process_spatial_metadata(df):
df.set_index("barcodekey", inplace=True)


def is_image(filename):
    """Return ``True`` when *filename* ends with a supported image extension (``.png`` or ``.jpg``)."""
    supported_exts = (".png", ".jpg")
    return any(filename.endswith(ext) for ext in supported_exts)


def load_visium_folder(input_path) -> MultimodalData:
"""
Method to read the visium spatial data folder
Expand Down Expand Up @@ -70,30 +66,27 @@ def load_visium_folder(input_path) -> MultimodalData:
with open(f"{spatial_path}/scalefactors_json.json") as fp:
scale_factors = json.load(fp)

arr = os.listdir(spatial_path)
for png in arr:
if not is_image(png):
continue
if "hires" in png:
with Image.open(f"{spatial_path}/{png}") as data:
data.load()
dict = {
"sample_id": sample_id,
"image_id": "hires",
"data": data,
"scaleFactor": scale_factors["tissue_hires_scalef"],
}
img = img.append(dict, ignore_index=True)
elif "lowres" in png:
with Image.open(f"{spatial_path}/{png}") as data:
data.load()
dict = {
"sample_id": sample_id,
"image_id": "lowres",
"data": data,
"scaleFactor": scale_factors["tissue_lowres_scalef"],
}
img = img.append(dict, ignore_index=True)
def get_image_data(filepath, sample_id, image_id, scaleFactor):
    """Load one tissue image and package it as a row dict for the image DataFrame.

    Parameters
    ----------
    filepath: ``str``
        Path to the PNG image file.
    sample_id: ``str``
        Sample identifier for this Visium dataset.
    image_id: ``str``
        Resolution tag, e.g. ``"hires"`` or ``"lowres"``.
    scaleFactor: ``float``
        Scale factor from ``scalefactors_json.json`` for this resolution.

    Returns
    -------
    ``dict`` with keys ``sample_id``, ``image_id``, ``data``, ``scaleFactor``.
    """
    # Use a context manager and force-load the pixel data so the underlying
    # file handle is closed here instead of leaking until GC (PIL opens lazily).
    with Image.open(filepath) as data:
        data.load()
    # Renamed from `dict`, which shadowed the builtin.
    image_item = {
        "sample_id": sample_id,
        "image_id": image_id,
        "data": data,
        "scaleFactor": scaleFactor,
    }
    return image_item

for png in [f for f in os.listdir(spatial_path) if re.match(".*\.png", f)]:
if ("_hires_" in png) or ("_lowres_" in png):
filepath = f"{spatial_path}/{png}"
res_tag = "hires" if "_hires_" in png else "lowres"
image_item = get_image_data(
filepath,
sample_id,
res_tag,
scale_factors[f"tissue_{res_tag}_scalef"]
)
img = img.append(image_item, ignore_index=True)

assert not img.empty, "the image data frame is empty"
spdata = SpatialData(
Expand Down
2 changes: 1 addition & 1 deletion pegasusio/unimodal_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,4 +661,4 @@ def _clean_tmp(self) -> dict:
def _addback_tmp(self, _tmp_dict: dict) -> None:
for key in ["metadata", "barcode_multiarrays", "feature_multiarrays", "barcode_multigraphs", "feature_multigraphs"]:
if key in _tmp_dict:
getattr(self, key).update(_tmp_dict[key])
getattr(self, key).update(_tmp_dict[key])
2 changes: 1 addition & 1 deletion pegasusio/vdj_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def load_10x_vdj_file(input_csv: str, genome: str = None, modality: str = None)
from pegasusio.cylib.funcs import convert_10x_vdj_to_vdjdata
except ModuleNotFoundError:
print("No module named 'pegasusio.cylib.funcs'")

df = pd.read_csv(input_csv, na_filter = False) # Otherwise, '' will be converted to NaN
idx = df["productive"] == (True if df["productive"].dtype.kind == "b" else "True")
df = df[idx]
Expand Down
2 changes: 1 addition & 1 deletion pegasusio/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __getitem__(self, key: Union[str, "Ellipsis"]) -> Union[csr_matrix, Dict[str
if key is Ellipsis:
for key in self.parent:
if key not in self.multigraphs:
self.multigraphs[key] = self.parent[key][self.index][:,self.index]
self.multigraphs[key] = self.parent[key][self.index][:,self.index]
return self.multigraphs

if key not in self.multigraphs:
Expand Down
23 changes: 11 additions & 12 deletions pegasusio/zarr_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ def read_series(self, group: zarr.Group, name: str) -> Union[pd.Categorical, np.
# categorical column
return pd.Categorical.from_codes(group[name][...], categories = group[f'_categories/{name}'][...], ordered = group[name].attrs['ordered'])
else:
if isinstance(group[name], zarr.core.Array):
return group[name][...]
elif isinstance(group[name], zarr.hierarchy.Group):
if isinstance(group[name], zarr.hierarchy.Group):
ll = []
for data in group[name].arrays():
ll.append(PIL.Image.fromarray(data[1][...]))
return ll

else:
return group[name][...]

def read_dataframe(self, group: zarr.Group) -> pd.DataFrame:
columns = group.attrs.get('columns', None)
if columns is None:
Expand Down Expand Up @@ -195,10 +195,10 @@ def read_mapping(self, group: zarr.Group) -> dict:

if 'scalar' in group.attrs:
res_dict.update(group.attrs['scalar'])

for key in group.array_keys():
res_dict[key] = self.read_array(group, key)

for key in group.group_keys():
sub_group = group[key]
data_type = sub_group.attrs['data_type']
Expand Down Expand Up @@ -255,7 +255,7 @@ def read_unimodal_data(self, group: zarr.Group) -> UnimodalData:
)
if isinstance (unidata, SpatialData):
unidata.img = self.read_dataframe(group["img"]) if "img" in group else dict()

if group.attrs.get("_cur_matrix", None) is not None:
unidata.select_matrix(group.attrs["_cur_matrix"])

Expand Down Expand Up @@ -317,9 +317,8 @@ def write_dataframe(self, parent: zarr.Group, name: str, df: pd.DataFrame) -> No
colgroup = group.create_group(col, overwrite = True)
x = 0
for data in df[col].values:
npdata = np.array(data)
x = x+1
self.write_series(colgroup, col + str(x), npdata)
self.write_series(colgroup, col + str(x), np.array(data))
else:
self.write_series(group, col, df[col].values)
group.attrs.update(**attrs_dict)
Expand All @@ -329,7 +328,7 @@ def write_array(self, group: zarr.Group, name: str, array: np.ndarray) -> None:
group.create_dataset(name, data = array, shape = array.shape, chunks = calc_chunk(array.shape), dtype = dtype, compressor = COMPRESSOR, overwrite = True)

def write_record_array(self, parent: zarr.Group, name: str, array: np.recarray) -> None:
group = parent.create_group(name, overwrite = True)
group = parent.create_group(name, overwrite = True)
attrs_dict = {'data_type' : 'record_array', 'columns' : list(array.dtype.names)}
for col in array.dtype.names:
self.write_array(group, col, array[col])
Expand Down Expand Up @@ -470,5 +469,5 @@ def write_multimodal_data(self, data: MultimodalData, overwrite: bool = True) ->
for key in data.data.deleted:
del self.root[key]
for key in data.data.accessed:
self.write_unimodal_data(self.root, key, data.get_data(key), overwrite = key in data.data.modified)
self.root.attrs['_selected'] = data._selected
self.write_unimodal_data(self.root, key, data.get_data(key), overwrite = key in data.data.modified)
self.root.attrs['_selected'] = data._selected

0 comments on commit 0c7111b

Please sign in to comment.