Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions python/genvarloader/_dataset/_genotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@nb.njit(parallel=True, nogil=True, cache=True)
def get_diffs_sparse(
geno_offset_idxs: NDArray[np.integer],
geno_offset_idx: NDArray[np.integer],
geno_v_idxs: NDArray[np.integer],
geno_offsets: NDArray[np.integer],
ilens: NDArray[np.integer],
Expand All @@ -24,7 +24,7 @@ def get_diffs_sparse(

Parameters
----------
geno_offset_idxs : NDArray[np.intp]
geno_offset_idx : NDArray[np.intp]
Shape = (n_regions, ploidy) Indices for each region into offsets.
geno_v_idxs : NDArray[np.int32]
Shape = (variants*samples*ploidy) Sparse genotypes i.e. variant indices for ALT genotypes.
Expand All @@ -43,11 +43,11 @@ def get_diffs_sparse(
v_starts : Optional[NDArray[np.int32]]
Shape = (total_variants) Positions of unique variants.
"""
n_queries, ploidy = geno_offset_idxs.shape
n_queries, ploidy = geno_offset_idx.shape
diffs = np.empty((n_queries, ploidy), np.int32)
for query in nb.prange(n_queries):
for hap in nb.prange(ploidy):
o_idx = geno_offset_idxs[query, hap]
o_idx = geno_offset_idx[query, hap]
if geno_offsets.ndim == 1:
o_s, o_e = geno_offsets[o_idx], geno_offsets[o_idx + 1]
else:
Expand Down Expand Up @@ -118,7 +118,7 @@ def reconstruct_haplotypes_from_sparse(
out_offsets: NDArray[np.integer],
regions: NDArray[np.integer],
shifts: NDArray[np.integer],
geno_offset_idxs: NDArray[np.integer],
geno_offset_idx: NDArray[np.integer],
geno_offsets: NDArray[np.integer],
geno_v_idxs: NDArray[np.integer],
v_starts: NDArray[np.integer],
Expand All @@ -135,6 +135,9 @@ def reconstruct_haplotypes_from_sparse(
):
"""Reconstruct haplotypes from reference sequence and variants.

Batched parallel driver: dispatches to :func:`reconstruct_haplotype_from_sparse`
(singular) for each ``(query, hap)`` pair.

Parameters
----------
out : NDArray[np.uint8]
Expand All @@ -145,7 +148,7 @@ def reconstruct_haplotypes_from_sparse(
Shape = (batch, 3) Regions to reconstruct haplotypes.
shifts : NDArray[np.uint32]
Shape = (batch, ploidy) Shifts for each region.
geno_offset_idxs: NDArray[np.intp]
geno_offset_idx: NDArray[np.intp]
Shape = (batch, ploidy) Indices for each region into offsets.
geno_offsets : NDArray[np.uint32]
Shape = (batch*ploidy + 1) Offsets into genos.
Expand Down Expand Up @@ -174,7 +177,7 @@ def reconstruct_haplotypes_from_sparse(
annot_ref_pos : NDArray[np.int32] | None
Ragged buffer for shape (batch, ploidy, ~length). Reference positions for annotations.
"""
batch_size, ploidy = geno_offset_idxs.shape
batch_size, ploidy = geno_offset_idx.shape
for query in nb.prange(batch_size):
q = regions[query]
c_idx: int = q[0]
Expand All @@ -185,7 +188,7 @@ def reconstruct_haplotypes_from_sparse(

for hap in nb.prange(ploidy):
# index for full sparse genos
o_idx = geno_offset_idxs[query, hap]
o_idx = geno_offset_idx[query, hap]
if geno_offsets.ndim == 1:
o_s, o_e = geno_offsets[o_idx], geno_offsets[o_idx + 1]
else:
Expand Down Expand Up @@ -244,7 +247,10 @@ def reconstruct_haplotype_from_sparse(
annot_v_idxs: NDArray[np.integer] | None = None,
annot_ref_pos: NDArray[np.integer] | None = None,
):
"""Reconstruct a haplotype from reference sequence and variants.
"""Reconstruct a single haplotype from reference sequence and variants.

Single-haplotype inner kernel. Use :func:`reconstruct_haplotypes_from_sparse`
(plural) to reconstruct a batch in parallel.

Parameters
----------
Expand Down Expand Up @@ -419,7 +425,7 @@ def reconstruct_haplotype_from_sparse(
def choose_exonic_variants(
starts: NDArray[np.integer],
ends: NDArray[np.integer],
geno_offset_idxs: NDArray[np.integer],
geno_offset_idx: NDArray[np.integer],
geno_v_idxs: NDArray[np.integer],
geno_offsets: NDArray[np.integer],
v_starts: NDArray[np.integer],
Expand All @@ -433,7 +439,7 @@ def choose_exonic_variants(
Shape = (n_regions) Start positions for each region.
ends : NDArray[np.int32]
Shape = (n_regions) Ends for each region.
geno_offset_idxs : NDArray[np.intp]
geno_offset_idx : NDArray[np.intp]
Shape = (n_regions, ploidy) Indices for each region into offsets.
geno_offsets : NDArray[np.int64]
Shape = (total_variants + 1) Offsets into sparse genotypes.
Expand All @@ -446,12 +452,12 @@ def choose_exonic_variants(
deterministic : bool
Whether to deterministically assign variants to groups
"""
n_regions, ploidy = geno_offset_idxs.shape
n_regions, ploidy = geno_offset_idx.shape

lengths = np.empty((n_regions, ploidy), np.int64)
for query in nb.prange(n_regions):
for hap in range(ploidy):
o_idx = geno_offset_idxs[query, hap]
o_idx = geno_offset_idx[query, hap]
if geno_offsets.ndim == 1:
o_s, o_e = geno_offsets[o_idx], geno_offsets[o_idx + 1]
else:
Expand All @@ -468,7 +474,7 @@ def choose_exonic_variants(
ref_start: int = starts[query]
ref_end: int = ends[query]
for hap in nb.prange(ploidy):
o_idx = geno_offset_idxs[query, hap]
o_idx = geno_offset_idx[query, hap]
# Mirror filter_af's (2, n_slices) indexing (sibling kernel below).
if geno_offsets.ndim == 1:
o_s, o_e = geno_offsets[o_idx], geno_offsets[o_idx + 1]
Expand Down Expand Up @@ -521,7 +527,7 @@ def _choose_exonic_variants(

@nb.njit(parallel=True, nogil=True, cache=True)
def filter_af(
geno_offset_idxs: NDArray[np.integer],
geno_offset_idx: NDArray[np.integer],
geno_offsets: NDArray[np.integer],
geno_v_idxs: NDArray[np.integer],
afs: NDArray[np.number],
Expand All @@ -530,7 +536,7 @@ def filter_af(
) -> tuple[NDArray[np.bool_], NDArray[OFFSET_TYPE]]:
"""Filter variants based on allele frequency, marking them to keep or not."""

batch_size, ploidy = geno_offset_idxs.shape
batch_size, ploidy = geno_offset_idx.shape

if geno_offsets.ndim == 1:
keep_offsets = geno_offsets.astype(OFFSET_TYPE)
Expand All @@ -549,7 +555,7 @@ def filter_af(
for query in nb.prange(batch_size):
for hap in range(ploidy):
# index for full sparse genos
o_idx = geno_offset_idxs[query, hap]
o_idx = geno_offset_idx[query, hap]
if geno_offsets.ndim == 1:
o_s, o_e = geno_offsets[o_idx], geno_offsets[o_idx + 1]
else:
Expand Down
49 changes: 26 additions & 23 deletions python/genvarloader/_dataset/_haps.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,13 @@ def _haplotype_ilens(
) -> NDArray[np.int32]:
"""`idx` must be 1D."""
# (b p)
geno_offset_idxs = self._get_geno_offset_idx(idx, self.genotypes)
geno_offset_idx = self._get_geno_offset_idx(idx, self.genotypes)

if self.filter == "exonic":
keep, keep_offsets = choose_exonic_variants(
starts=regions[:, 1],
ends=regions[:, 2],
geno_offset_idxs=geno_offset_idxs,
geno_offset_idx=geno_offset_idx,
geno_v_idxs=self.genotypes.data,
geno_offsets=self.genotypes.offsets,
v_starts=self.variants.start,
Expand All @@ -319,7 +319,7 @@ def _haplotype_ilens(

# (r s p)
hap_ilens = get_diffs_sparse(
geno_offset_idxs=geno_offset_idxs,
geno_offset_idx=geno_offset_idx,
geno_v_idxs=self.genotypes.data,
geno_offsets=self.genotypes.offsets,
ilens=self.variants.ilen,
Expand Down Expand Up @@ -353,7 +353,7 @@ def haplotype_lengths_for_plan(
keep, keep_offsets = choose_exonic_variants(
starts=regions[:, 1],
ends=regions[:, 2],
geno_offset_idxs=geno_offset_idx,
geno_offset_idx=geno_offset_idx,
geno_v_idxs=self.genotypes.data,
geno_offsets=self.genotypes.offsets,
v_starts=self.variants.start,
Expand Down Expand Up @@ -453,7 +453,7 @@ def get_haps_and_shifts(
assert_never(self.kind)

return (
out, # type: ignore | pylance doesn't like this but it's correct behavior for the signature
out,
req.geno_offset_idx,
req.shifts,
req.diffs,
Expand Down Expand Up @@ -488,7 +488,7 @@ def _prepare_request(
keep, keep_offsets = choose_exonic_variants(
starts=regions[:, 1],
ends=regions[:, 2],
geno_offset_idxs=geno_offset_idx,
geno_offset_idx=geno_offset_idx,
geno_v_idxs=self.genotypes.data,
geno_offsets=self.genotypes.offsets,
v_starts=self.variants.start,
Expand Down Expand Up @@ -545,10 +545,13 @@ def _get_geno_offset_idx(
idx: NDArray[np.integer],
genotypes: Ragged[V_IDX_TYPE],
) -> NDArray[np.intp]:
r_idx, s_idx = np.unravel_index(idx, genotypes.shape[:2]) # type: ignore
r_idx, s_idx = np.unravel_index(idx, genotypes.shape[:2]) # type: ignore[no-matching-overload] # Ragged.shape is tuple[int | None, ...]; numpy overload expects all-int
ploid_idx = np.arange(genotypes.shape[-2], dtype=np.intp)
rsp_idx = (r_idx[:, None], s_idx[:, None], ploid_idx)
geno_offset_idx = np.ravel_multi_index(rsp_idx, genotypes.shape[:-1]) # type: ignore
# (region, sample, ploid) index tuple for ravel_multi_index.
region_sample_ploid_idx = (r_idx[:, None], s_idx[:, None], ploid_idx)
geno_offset_idx = np.ravel_multi_index(
region_sample_ploid_idx, genotypes.shape[:-1]
) # type: ignore[no-matching-overload] # Ragged.shape is tuple[int | None, ...]; numpy overload expects all-int
return geno_offset_idx

def _get_variants(
Expand All @@ -560,7 +563,7 @@ def _get_variants(
keep_offsets: NDArray[np.integer] | None = None,
) -> RaggedVariants:
# TODO: maybe filter variants for region, shifts?
r, s = np.unravel_index(idx, self.genotypes.shape[:2]) # type: ignore
r, s = np.unravel_index(idx, self.genotypes.shape[:2]) # type: ignore[no-matching-overload] # Ragged.shape is tuple[int | None, ...]; numpy overload expects all-int
# (b p ~v)
genos = cast(Ragged[V_IDX_TYPE], self.genotypes[r, s])

Expand Down Expand Up @@ -598,7 +601,7 @@ def _get_variants(
# guaranteed to have same shape as genotypes but need to make it contiguous/copy the data
dosages = self.dosages[r, s]
if _keep is not None:
dosages = ak.to_regular(dosages[_keep], 1) # type: ignore
dosages = ak.to_regular(dosages[_keep], 1)
fields["dosage"] = Ragged(ak.to_packed(dosages))

fields.update(
Expand Down Expand Up @@ -650,7 +653,7 @@ def _reconstruct_haplotypes(self, req: ReconstructionRequest) -> Ragged[np.bytes
req.out_offsets,
)
reconstruct_haplotypes_from_sparse(
geno_offset_idxs=req.geno_offset_idx,
geno_offset_idx=req.geno_offset_idx,
out=haps.data,
out_offsets=haps.offsets,
regions=req.regions,
Expand Down Expand Up @@ -681,7 +684,7 @@ def _reconstruct_haplotypes(self, req: ReconstructionRequest) -> Ragged[np.bytes
out_buf = np.empty(total, np.uint8)

reconstruct_haplotypes_from_sparse(
geno_offset_idxs=flat_geno_idx.reshape(-1, 1),
geno_offset_idx=flat_geno_idx.reshape(-1, 1),
out=out_buf,
out_offsets=splice_plan.permuted_out_offsets,
regions=permuted_regions,
Expand Down Expand Up @@ -741,7 +744,7 @@ def _reconstruct_annotated_haplotypes(

# annot offsets match haps offsets, so we share them.
reconstruct_haplotypes_from_sparse(
geno_offset_idxs=req.geno_offset_idx,
geno_offset_idx=req.geno_offset_idx,
out=haps.data,
out_offsets=haps.offsets,
regions=req.regions,
Expand Down Expand Up @@ -778,7 +781,7 @@ def _reconstruct_annotated_haplotypes(
annot_pos_buf = np.empty(total, np.int32)

reconstruct_haplotypes_from_sparse(
geno_offset_idxs=flat_geno_idx.reshape(-1, 1),
geno_offset_idx=flat_geno_idx.reshape(-1, 1),
out=out_buf,
out_offsets=splice_plan.permuted_out_offsets,
regions=permuted_regions,
Expand Down Expand Up @@ -824,7 +827,7 @@ def _permute_request_for_splice(
NDArray[np.bool_] | None,
NDArray[np.integer] | None,
]:
"""Permute the per-element arrays in ``req`` according to ``splice_plan.perm``.
"""Permute the per-element arrays in ``req`` according to ``splice_plan.permutation``.

``geno_offset_idx`` and ``shifts`` have shape ``(B, P)``; flatten to
``(B*P,)`` in (query, ploidy) C-order, then permute. The kernel then
Expand All @@ -833,27 +836,27 @@ def _permute_request_for_splice(
assert req.splice_plan is not None
splice_plan = req.splice_plan
ploidy = req.shifts.shape[1] if req.shifts.ndim > 1 else 1
perm = splice_plan.perm
permutation = splice_plan.permutation

flat_geno_idx = req.geno_offset_idx.reshape(-1)[perm].astype(
flat_geno_idx = req.geno_offset_idx.reshape(-1)[permutation].astype(
np.intp, copy=False
)
flat_shifts = req.shifts.reshape(-1)[perm].astype(np.int32, copy=False)
flat_shifts = req.shifts.reshape(-1)[permutation].astype(np.int32, copy=False)
# regions has shape (B, 3). For (B*P, 3), each query repeats P times
# consecutively, then we apply the same perm.
# consecutively, then we apply the same permutation.
regions_flat = np.repeat(req.regions, ploidy, axis=0)
permuted_regions = regions_flat[perm]
permuted_regions = regions_flat[permutation]

# keep / keep_offsets: per-k granularity (length B*P + 1).
if req.keep is not None and req.keep_offsets is not None:
keep_lens = np.diff(req.keep_offsets)
keep_lens_perm = keep_lens[perm]
keep_lens_perm = keep_lens[permutation]
keep_offsets_perm = lengths_to_offsets(
keep_lens_perm.astype(np.int64), dtype=np.int64
)
keep_perm = np.empty(int(keep_lens_perm.sum()), dtype=np.bool_)
write_cursor = 0
for k_old in perm:
for k_old in permutation:
s = int(req.keep_offsets[k_old])
e = int(req.keep_offsets[k_old + 1])
width = e - s
Expand Down
12 changes: 6 additions & 6 deletions python/genvarloader/_dataset/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,11 +573,11 @@ def with_tracks(
new_tracks = self._tracks.with_tracks(None)
elif isinstance(tracks, str):
new_tracks = self._tracks.with_tracks([tracks]).to_kind(
_kind, # type: ignore
_kind, # type: ignore[bad-argument-type] # _kind is broader union; runtime branch ensures correct subtype
)
else:
new_tracks = self._tracks.with_tracks(tracks).to_kind(
_kind, # type: ignore
_kind, # type: ignore[bad-argument-type] # _kind is broader union; runtime branch ensures correct subtype
)

# Validate: at least one of (seqs, tracks) must remain active.
Expand Down Expand Up @@ -998,7 +998,7 @@ def haplotype_lengths(
if out_reshape is not None:
hap_lens = hap_lens.reshape(
*out_reshape,
self._seqs.genotypes.shape[-2], # type: ignore
self._seqs.genotypes.shape[-2],
)

return hap_lens
Expand Down Expand Up @@ -1148,7 +1148,7 @@ def write_transformed_track(
overwrite=overwrite,
)

return replace(self, _tracks=new_tracks) # type: ignore
return replace(self, _tracks=new_tracks) # type: ignore[bad-return] # dataclasses.replace returns Self but pyrefly widens to base Dataset union

def write_annot_tracks(
self, tracks: dict[str, str | Path | pl.DataFrame], overwrite: bool = False
Expand Down Expand Up @@ -1554,7 +1554,7 @@ def __getitem__(
def __getitem__(
self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx]
) -> SEQ | TRK | tuple[SEQ, TRK]:
return super().__getitem__(idx) # type: ignore
return super().__getitem__(idx) # type: ignore[bad-return] # base Dataset returns broad union; SEQ/TRK typevars narrow at use sites


class RaggedDataset(Dataset, Generic[MaybeRSEQ, MaybeRTRK]):
Expand Down Expand Up @@ -1709,4 +1709,4 @@ def __getitem__(
def __getitem__(
self, idx: StrIdx | tuple[StrIdx] | tuple[StrIdx, StrIdx]
) -> RSEQ | RTRK | tuple[RSEQ, RTRK]:
return super().__getitem__(idx) # type: ignore
return super().__getitem__(idx) # type: ignore[bad-return] # base Dataset returns broad union; RSEQ/RTRK typevars narrow at use sites
Loading
Loading