In [1]:
import re
from pyfaidx import Fasta
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import pyarrow.dataset as ds

from dataset import GeneIterableFixedB, split_filters_by_chroms
import torch
import pprint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# GTF column index -> name for the six fields this notebook keeps; the other
# columns of each record are not loaded.
gtf_fields = {
    0: "chrom",
    2: "feature",
    3: "start",
    4: "end",
    6: "strand",
    8: "attributes",
}
usecols = list(gtf_fields.keys())
colnames = list(gtf_fields.values())

# The file is a headerless, tab-separated, gzip-compressed table.
# comment="#" makes pandas ignore everything from a "#" to the end of the line.
gtf_read_options = dict(
    sep="\t",
    comment="#",
    header=None,
    names=colnames,
    usecols=usecols,
    compression="gzip",
)

gtf = pd.read_csv("data/raw_data/ENCFF129ZDE.gtf.gz", **gtf_read_options)

In [8]:
# Preview a window of annotation rows: positions 12 up to (not including) 30.
# To dump a single record as a dict instead:
#   pprint.pprint(gtf.iloc[1].to_dict(), width=120)
preview_rows = slice(12, 30)
gtf.iloc[preview_rows]

Unnamed: 0,chrom,feature,start,end,strand,attributes
12,chr1,gene,14404,29321,-,"gene_id ""ENSG00000227232.5""; gene_name ""WASH7P..."
13,chr1,transcript,14404,29570,-,"gene_id ""ENSG00000227232.5""; transcript_id ""EN..."
14,chr1,exon,29534,29570,-,"gene_id ""ENSG00000227232.5""; transcript_id ""EN..."
15,chr1,exon,24738,24891,-,"gene_id ""ENSG00000227232.5""; transcript_id ""EN..."
16,chr1,exon,18268,18366,-,"gene_id ""ENSG00000227232.5""; transcript_id ""EN..."
17,chr1,exon,17915,18061,-,"gene_id ""ENSG00000227232.5""; transcript_id ""EN..."
18,chr1,exon,17606,17742,-,"gene_id ""ENSG00000227232.5""; transcript_id ""EN..."
19,chr1,exon,17233,17368,-,"gene_id ""ENSG00000227232.5""; transcript_id ""EN..."
20,chr1,exon,16858,17055,-,"gene_id ""ENSG00000227232.5""; transcript_id ""EN..."
21,chr1,exon,16607,16765,-,"gene_id ""ENSG00000227232.5""; transcript_id ""EN..."


In [30]:
# Load the table from Parquet into a DataFrame.
df = pd.read_parquet("data/table/ENCFF180NXA.parquet")

# Report how many rows were loaded.
n_rows = df.shape[0]
print(n_rows)

129911


In [31]:
# Show one full record (table position 18197) as a plain dict.
row = df.iloc[18197]
row_as_dict = row.to_dict()
pprint.pprint(row_as_dict, width=120)

# Total transcript count = sum over rows of the length of each "transcripts" entry.
transcripts_per_row = df["transcripts"].map(len)
num_transcripts = transcripts_per_row.sum()

print("Transcripts:", num_transcripts)

{'chrom': 'chr16',
 'end': 57591681,
 'gene_id': 'ENSG00000159618.15',
 'start': 57542421,
 'strand': '+',
 'transcripts': array([{'transcript_id': 'ENST00000615867.4', 'tx_start': 57542687, 'tx_end': 57565313, 'exon_starts': array([57542687, 57562056, 57562384, 57563091, 57563848, 57565034]), 'exon_ends': array([57542701, 57562157, 57562459, 57563247, 57563979, 57565313]), 'rel_abundance': 1.0}],
      dtype=object)}
Transcripts: 351270


In [36]:
target = "ENCLB376UFGT000206877"


def _has_target_transcript(txs):
    """Return True if any record in ``txs`` has ``transcript_id`` equal to ``target``."""
    for tx in txs:
        if tx["transcript_id"] == target:
            return True
    return False


# Boolean row selector: True where the row's transcript list contains the target id.
row_hits = df["transcripts"].apply(_has_target_transcript)
matches = df[row_hits]

print("Number of matching genes:", len(matches))
matches[["gene_id", "chrom", "start", "end"]]

Number of matching genes: 0


Unnamed: 0,gene_id,chrom,start,end


In [4]:
# choose your parquet directory
parquet_dir = "data/table"

# get split filters
train_filter, val_filter, test_filter = split_filters_by_chroms()

# create dataset
dset = GeneIterableFixedB(
    parquet_dir=parquet_dir,
    split_filter=train_filter,
    tx_batch_size=10,
    pad_bp=5000,
    shuffle=False,   # deterministic for debugging
)

# Only the first batch yielded by the dataset is inspected below.
batch = next(iter(dset))

# --------------------------------------------------------
# 1. Print shapes
# --------------------------------------------------------
print("\n=== SHAPES ===")
for k, v in batch.items():
    if torch.is_tensor(v):
        print(f"{k:15s} {tuple(v.shape)} {v.dtype}")
    else:
        # Non-tensor entries have no shape/dtype; show their Python type instead.
        print(f"{k:15s} {type(v)}")

# --------------------------------------------------------
# 2. Check the first 3 examples
# --------------------------------------------------------
for title, key in (
    ("FIRST 3 CONTEXT POSITIONS", "context_pos"),
    ("FIRST 3 CONTEXT ROLES", "context_roles"),
    ("FIRST 3 TARGET IDX", "target_idx"),
    ("FIRST 3 TARGET ROLES", "target_role"),
):
    print(f"\n=== {title} ===")
    print(batch[key][:3])

# Per-sample count of valid suffix positions (row sums of the mask).
print("\n=== FIRST 3 SUFFIX MASK SUMS ===")
print(batch["suffix_mask"][:3].sum(dim=1))

# --------------------------------------------------------
# 3. Check that the suffix starts at c+1
# --------------------------------------------------------
print("\n=== VERIFY SUFFIX START ===")
ctx_pos = batch["context_pos"]
suffix_pos = batch["suffix_pos"]
# Never index past the batch: a batch with fewer than 5 samples would
# otherwise raise IndexError in the loop below.
n_to_check = min(5, ctx_pos.shape[0])
for i in range(n_to_check):
    # Context length = number of positive entries in this sample's mask.
    ctx_len = (batch["context_mask"][i] > 0).sum().item()
    if ctx_len == 0:
        print(f"Sample {i}: No context. suffix_pos starts at {suffix_pos[i,0].item()}")
    else:
        last_ctx = ctx_pos[i, ctx_len-1].item()
        suf0 = suffix_pos[i, 0].item()
        print(f"Sample {i}: last ctx = {last_ctx}, suffix starts at {suf0}, diff = {suf0 - last_ctx}")


=== SHAPES ===
X_context       (10, 7, 256) torch.float16
context_pos     (10, 7) torch.int64
context_roles   (10, 7) torch.int64
context_mask    (10, 7) torch.float32
X_suffix        (10, 24918, 256) torch.float16
suffix_pos      (10, 24918) torch.int64
suffix_mask     (10, 24918) torch.float32
target_idx      (10,) torch.int64
target_role     (10,) torch.int64
rel_abundance   (10,) torch.float32

=== FIRST 3 CONTEXT POSITIONS ===
tensor([[    0,     0,     0,     0,     0,     0,     0],
        [19820,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0]])

=== FIRST 3 CONTEXT ROLES ===
tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]])

=== FIRST 3 TARGET IDX ===
tensor([19820,    96,  9430])

=== FIRST 3 TARGET ROLES ===
tensor([0, 3, 0])

=== FIRST 3 SUFFIX MASK SUMS ===
tensor([24918.,  5097., 24918.])

=== VERIFY SUFFIX START ===
Sample 0: No context. suffix_pos starts at 0
Sample 1: las

In [9]:
import math
import random
from pathlib import Path

import pyarrow as pa
import pyarrow.dataset as ds
import torch

# Role labels indexed by integer role id; pretty_print_small_batch uses this to
# turn context_roles / target_role ids into readable names.
# NOTE(review): presumably TSS = transcription start site, D = splice donor,
# A = splice acceptor, PAS = polyadenylation site — confirm against dataset.py.
ID2ROLE = ["TSS", "D", "A", "PAS"]

def _check_prefix_mask(mask, name):
    """Assert every row of ``mask`` is a run of ones followed only by zeros.

    Args:
        mask: 2-D tensor of shape (rows, width).
        name: label used in the failure message.

    Returns:
        1-D tensor with the number of positive entries in each row.

    Raises:
        AssertionError: naming the first row that violates the pattern. Rows
            with no positive entries must be all zeros, so negative or NaN
            values are reported as well.
    """
    lengths = (mask > 0).sum(dim=1)
    columns = torch.arange(mask.shape[1], device=mask.device)
    # expected[r, c] is 1 for c < lengths[r] and 0 otherwise.
    expected = (columns.unsqueeze(0) < lengths.unsqueeze(1)).to(mask.dtype)
    # `!=` evaluates to True for NaN, so NaN entries count as violations.
    bad_rows = (mask != expected).any(dim=1).tolist()
    assert not any(bad_rows), (
        f"{name}: row {bad_rows.index(True)} is not prefix-ones followed by zeros"
    )
    return lengths


def validate_batch_basic(batch, verbose=True):
    """Check tensor shapes and mask layout of one batch.

    Verifies that context and suffix embeddings agree on batch size and
    embedding dimension, that every per-position tensor matches the (B, M)
    or (B, S) layout and every per-sample tensor is (B,), and that both
    masks are prefix-ones-then-zeros in every row.

    Args:
        batch: mapping with the keys X_context, context_pos, context_roles,
            context_mask, X_suffix, suffix_pos, suffix_mask, target_idx,
            target_role and rel_abundance.
        verbose: when true, print the dimensions and mean mask lengths.

    Raises:
        AssertionError: on the first violated invariant.
    """
    X_context = batch["X_context"]
    X_suffix = batch["X_suffix"]

    B, M, D = X_context.shape
    B2, S, D2 = X_suffix.shape

    assert B == B2, "Batch size mismatch between context and suffix"
    assert D == D2, "Embedding dims mismatch between context and suffix"

    expected_shapes = {
        "context_pos": (B, M),
        "context_roles": (B, M),
        "context_mask": (B, M),
        "suffix_pos": (B, S),
        "suffix_mask": (B, S),
        "target_idx": (B,),
        "target_role": (B,),
        "rel_abundance": (B,),
    }
    for key, shape in expected_shapes.items():
        actual = tuple(batch[key].shape)
        assert actual == shape, f"{key}: expected shape {shape}, got {actual}"

    # masks must be prefix-ones then zeros (checked for every row, including
    # rows with no positive entries)
    ctx_lengths = _check_prefix_mask(batch["context_mask"], "context_mask")
    suf_lengths = _check_prefix_mask(batch["suffix_mask"], "suffix_mask")

    if verbose:
        print("✓ basic shapes and masks OK")
        print(f"  B={B}, M={M}, S={S}, D={D}")
        print("  mean ctx length:", ctx_lengths.float().mean().item())
        print("  mean suffix length:", suf_lengths.float().mean().item())


def validate_suffix_alignment(batch, verbose=True):
    """Check that each sample's suffix continues right after its context.

    Lengths are the counts of positive mask entries. For every sample with a
    non-empty suffix:
      * with no context, the first suffix position must be 0;
      * otherwise it must equal the last context position plus one;
      * the valid suffix positions must be strictly increasing.

    Raises AssertionError naming the offending sample. Prints a confirmation
    line when ``verbose`` is true.
    """
    context_pos, context_mask = batch["context_pos"], batch["context_mask"]
    suffix_pos, suffix_mask = batch["suffix_pos"], batch["suffix_mask"]

    # Both position tensors are unpacked as 2-D (samples x positions).
    n_samples, _ctx_width = context_pos.shape
    _, _suf_width = suffix_pos.shape

    ctx_counts = (context_mask > 0).sum(dim=1)
    suf_counts = (suffix_mask > 0).sum(dim=1)

    for i in range(n_samples):
        ctx_len = ctx_counts[i].item()
        suf_len = suf_counts[i].item()

        # Nothing to align when the suffix is empty.
        if suf_len == 0:
            continue

        if ctx_len > 0:
            last_ctx = context_pos[i, ctx_len - 1].item()
            suf0 = suffix_pos[i, 0].item()
            assert suf0 == last_ctx + 1, f"Sample {i}: suffix start {suf0} != last_ctx+1 {last_ctx+1}"
        else:
            # Without any context the suffix has to begin at position 0.
            assert suffix_pos[i, 0].item() == 0, f"Sample {i}: no context but suffix_pos[0] != 0"

        # Each valid position must be smaller than its successor.
        valid = suffix_pos[i, :suf_len]
        ascending = valid[:-1] < valid[1:]
        assert torch.all(ascending), f"Sample {i}: suffix positions not strictly increasing"

    if verbose:
        print("✓ suffix alignment and monotonicity OK")


def validate_target_alignment(batch, verbose=True):
    """Check that every sample's target index falls inside its valid suffix.

    The suffix length of a sample is the number of positive entries in its
    ``suffix_mask`` row; ``target_idx`` must satisfy ``0 <= idx < length``.

    Args:
        batch: mapping providing ``suffix_mask`` (2-D) and ``target_idx``.
            No other key is read.
        verbose: when true, print a confirmation line on success.

    Raises:
        AssertionError: naming the first sample whose target index is out of
            range.
    """
    suffix_mask = batch["suffix_mask"]
    target_idx = batch["target_idx"]

    suf_lengths = (suffix_mask > 0).sum(dim=1)

    for i in range(suffix_mask.shape[0]):
        suf_len = suf_lengths[i].item()
        ti = target_idx[i].item()
        assert 0 <= ti < suf_len, f"Sample {i}: target_idx {ti} out of suffix length {suf_len}"

    if verbose:
        print("✓ target_idx in-range for all samples")


def pretty_print_small_batch(batch, n=5):
    """Print a per-sample summary of up to ``n`` samples of a batch.

    For each sample this shows the relative abundance, the context and suffix
    lengths (row sums of the masks), the valid context positions and their
    role names, the first few suffix positions, and the target index and role.
    Role ids are translated through the module-level ``ID2ROLE`` list.
    """
    shown = min(n, batch["X_context"].shape[0])

    pos_ctx = batch["context_pos"]
    mask_ctx = batch["context_mask"]
    roles_ctx = batch["context_roles"]
    pos_suf = batch["suffix_pos"]
    mask_suf = batch["suffix_mask"]
    idx_target = batch["target_idx"]
    role_target = batch["target_role"]
    abundance = batch["rel_abundance"]

    print("\n=== SMALL BATCH SUMMARY ===")
    for sample in range(shown):
        n_ctx = int(mask_ctx[sample].sum().item())
        n_suf = int(mask_suf[sample].sum().item())

        print(f"\nSample {sample}")
        print("  rel_abundance:", float(abundance[sample].item()))
        print("  ctx_len:", n_ctx, "suffix_len:", n_suf)

        valid_ctx_pos = pos_ctx[sample, :n_ctx]
        print("  context_pos:", valid_ctx_pos.tolist())

        valid_ctx_roles = roles_ctx[sample, :n_ctx]
        print("  context_roles:", [ID2ROLE[int(role_id)] for role_id in valid_ctx_roles])

        # Show at most the first five valid suffix positions.
        head = min(5, n_suf)
        print("  suffix_pos[0:5]:", pos_suf[sample, :head].tolist())

        print("  target_idx:", int(idx_target[sample].item()))
        target_role_id = int(role_target[sample].item())
        print("  target_role:", ID2ROLE[target_role_id])


def validate_one_batch(batch, verbose=True):
    """Run all batch validators in order, then print a summary when verbose.

    The validators are called as shape/mask checks first, then suffix
    alignment, then target alignment; each raises on failure. With
    ``verbose`` true, a five-sample summary is printed afterwards.
    """
    checks = (
        validate_batch_basic,
        validate_suffix_alignment,
        validate_target_alignment,
    )
    for check in checks:
        check(batch, verbose=verbose)

    if not verbose:
        return
    pretty_print_small_batch(batch, n=5)

In [11]:
# Build a shuffled training-split dataset and validate the first batch it yields.
parquet_dir = "data/table"
split_filters = split_filters_by_chroms()
train_filter, val_filter, test_filter = split_filters

dset_options = dict(
    parquet_dir=parquet_dir,
    split_filter=train_filter,
    tx_batch_size=16,
    pad_bp=5000,
    shuffle=True,
)
dset = GeneIterableFixedB(**dset_options)

batch_iter = iter(dset)
batch = next(batch_iter)
validate_one_batch(batch, verbose=True)

✓ basic shapes and masks OK
  B=16, M=5, S=12525, D=256
  mean ctx length: 2.25
  mean suffix length: 8011.25
✓ suffix alignment and monotonicity OK
✓ target_idx in-range for all samples

=== SMALL BATCH SUMMARY ===

Sample 0
  rel_abundance: 0.1875
  ctx_len: 0 suffix_len: 12525
  context_pos: []
  context_roles: []
  suffix_pos[0:5]: [0, 1, 2, 3, 4]
  target_idx: 5000
  target_role: TSS

Sample 1
  rel_abundance: 0.1875
  ctx_len: 1 suffix_len: 7524
  context_pos: [5000]
  context_roles: ['TSS']
  suffix_pos[0:5]: [5001, 5002, 5003, 5004, 5005]
  target_idx: 166
  target_role: D

Sample 2
  rel_abundance: 0.1875
  ctx_len: 2 suffix_len: 7357
  context_pos: [5000, 5167]
  context_roles: ['TSS', 'D']
  suffix_pos[0:5]: [5168, 5169, 5170, 5171, 5172]
  target_idx: 468
  target_role: A

Sample 3
  rel_abundance: 0.1875
  ctx_len: 3 suffix_len: 6888
  context_pos: [5000, 5167, 5636]
  context_roles: ['TSS', 'D', 'A']
  suffix_pos[0:5]: [5637, 5638, 5639, 5640, 5641]
  target_idx: 144
  ta