In [1]:
import sys
from pathlib import Path
from tempfile import TemporaryDirectory

import genvarloader as gvl
import numba as nb
import numpy as np
import polars as pl
import seqpro as sp
import pooch
from loguru import logger
from einops import rearrange
from tqdm.auto import tqdm

# Tutorial: Geuvadis

In this tutorial we'll see how to use GenVarLoader (GVL) to:

1. Write a GVL dataset
2. Add transforms
3. Lazily subset it (train/test splits)
4. Get a PyTorch DataLoader
5. Cache transformed tracks on disk (optional)
6. Evaluate Basenji2 across genes and individuals (optional)

This tutorial also assumes you have read ["What's a gvl.Dataset?"](https://genvarloader.readthedocs.io/en/stable/dataset.html).

## Logging

A quick note on logging: GenVarLoader uses [loguru](https://loguru.readthedocs.io/en/stable/index.html) for logging. We will enable it at the "INFO" level to get some additional information from GVL for this tutorial.

In [2]:
logger.remove()
logger.add(sys.stderr, level="INFO")
logger.enable("genvarloader")

## Download the data

The Geuvadis dataset consists of 451 individuals from the 1000 Genomes Project that have both whole-genome sequencing and RNA-seq from lymphoblastoid cell lines (LCLs). We'll see how to use GVL to get a high-performance dataloader that yields haplotypes and tracks for training or running inference with sequence models. For the sake of this tutorial, we'll only work with chromosome 22 so everything can run in a few minutes.

Downloading this data should take ~5-10 minutes and is the slowest step in this notebook.

In [3]:
# GRCh38 chromosome 22 sequence
reference = pooch.retrieve(
    url="https://ftp.ensembl.org/pub/release-112/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna.chromosome.22.fa.gz",
    known_hash="sha256:974f97ac8ef7ffae971b63b47608feda327403be40c27e391ee4a1a78b800df5",
    progressbar=True,
)
if not Path(f"{reference[:-3]}.bgz").exists():
    !gzip -dc {reference} | bgzip > {reference[:-3]}.bgz
reference = reference[:-3] + ".bgz"

# PLINK 2 files
variants = pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/1kGP.chr22.pgen",
    known_hash="md5:31aba970e35f816701b2b99118dfc2aa",
    progressbar=True,
    fname="1kGP.chr22.pgen",
)
pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/1kGP.chr22.psam",
    known_hash="md5:eefa7aad5acffe62bf41df0a4600129c",
    progressbar=True,
    fname="1kGP.chr22.psam",
)
pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/1kGP.chr22.pvar",
    known_hash="md5:5f922af91c1a2f6822e2f1bb4469d12b",
    progressbar=True,
    fname="1kGP.chr22.pvar",
)

# BigWigs and sample ID mapping
bw_paths = pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/bw_chr22.tar.gz",
    known_hash="md5:14bf72e9e9d3e2318d07315c4a2675fb",
    progressbar=True,
    processor=pooch.Untar(),
)
bw_table_path = pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/bigwig_table.csv",
    known_hash="md5:7fe7c55b61c7dfa66cfd0a49336f3b08",
    progressbar=True,
)

# BED
bed_path = pooch.retrieve(
    url="doi:10.5281/zenodo.13656224/chr22_egenes.bed",
    known_hash="md5:ccb55548e4ddd416d50dbe6638459421",
    progressbar=True,
)

## Writing the GVL dataset

We'll specify a directory to store the dataset (similar to Zarr stores).

In [4]:
tmp_dir = TemporaryDirectory(suffix=".gvl")
ds_path = tmp_dir.name

We'll also need a table or dictionary specifying the sample name for each BigWig. Tables must have at least the columns `sample` and `path`, as seen below. The join updates the paths to match the actual download locations.
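The join above matches each row to its downloaded file by file name (the last `/`-separated component of the path). A minimal pure-Python sketch of the same idea, where the file names and cache directory are hypothetical placeholders:

```python
from pathlib import Path

# Hypothetical table rows: bare file names, as in the original bigwig_table.csv.
table_rows = [
    {"sample": "HG00236", "path": "ERR000001.bw"},
    {"sample": "NA20519", "path": "ERR000002.bw"},
]
# Hypothetical absolute paths returned by the download step (order may differ).
downloaded = [
    "/home/user/.cache/pooch/bw_chr22/ERR000002.bw",
    "/home/user/.cache/pooch/bw_chr22/ERR000001.bw",
]

# Index the real paths by their final component, i.e. the file name.
by_name = {Path(p).name: p for p in downloaded}

# Swap each bare file name for its full download path.
sample_to_path = {row["sample"]: by_name[row["path"]] for row in table_rows}
print(sample_to_path)
```

A dictionary lookup raises `KeyError` on a missing file, whereas the polars join silently drops unmatched rows, so it is worth checking the row count of the joined table.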

In [5]:
bigwig_table = (
    pl.read_csv(bw_table_path)
    .join(
        pl.Series(bw_paths).to_frame("realpath"),
        left_on="path",
        right_on=pl.col("realpath").str.split("/").list.get(-1),
    )
    .drop("path")
    .rename({"realpath": "path"})
)
bigwig_table.head()

| sample (str) | read_count (i64) | path (str) |
| --- | --- | --- |
| HG00236 | 34548283 | /carter/users/dlaub/.cache/poo… |
| HG00259 | 53041143 | /carter/users/dlaub/.cache/poo… |
| NA20519 | 36620358 | /carter/users/dlaub/.cache/poo… |
| NA20811 | 24398971 | /carter/users/dlaub/.cache/poo… |
| NA20768 | 30019566 | /carter/users/dlaub/.cache/poo… |


Finally, we'll need a BED file specifying what regions to include in the dataset. We can either specify a path or a polars DataFrame. We'll use [gvl.read_bedlike](https://genvarloader.readthedocs.io/en/latest/api.html#genvarloader.read_bedlike) to conveniently read the BED file into memory and subset it to just the top 20 eGenes for this tutorial. The BED file provided corresponds to transcription start sites of eGenes, sorted in descending order by their absolute sum of coefficients.

In [6]:
bed = gvl.read_bedlike(bed_path)[:20]
bed.head()

| chrom (str) | chromStart (i64) | chromEnd (i64) | name (str) | score (f64) | strand (str) |
| --- | --- | --- | --- | --- | --- |
| chr22 | 41699499 | 41699499 | ENSG00000167077 | | + |
| chr22 | 42835412 | 42835412 | ENSG00000100266 | | - |
| chr22 | 20858983 | 20858983 | ENSG00000099940 | | + |
| chr22 | 20707691 | 20707691 | ENSG00000241973 | | - |
| chr22 | 49918167 | 49918167 | ENSG00000184164 | | + |


Now we're ready to write the dataset.

Because the `bed` above specifies only the transcription start site (TSS) of each gene (chromStart == chromEnd), we'll expand those regions to $2^{17}$ (131,072) bp, the input length for Basenji2, using [gvl.with_length](https://genvarloader.readthedocs.io/en/latest/api.html#genvarloader.with_length).
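Conceptually, expanding a TSS to a fixed length means placing a window of that length around the region's midpoint. The sketch below uses a hypothetical helper `center_on` and an assumed rounding convention; `gvl.with_length` may round odd lengths or midpoints differently.

```python
def center_on(start: int, end: int, length: int) -> tuple[int, int]:
    """Hypothetical helper: a [start, end) window of `length` bp around the midpoint."""
    midpoint = (start + end) // 2
    new_start = midpoint - length // 2
    return new_start, new_start + length


# First eGene from the BED above (TSS at 41,699,499)
start, end = center_on(41_699_499, 41_699_499, 2**17)
print(start, end, end - start)  # 41633963 41765035 131072
```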

We'll also instantiate a [gvl.BigWigs](https://genvarloader.readthedocs.io/en/latest/api.html#genvarloader.BigWigs) from the above table (we could also use a dictionary). We'll name this track "read-depth" so we can manage different transformations of the track data or provide multiple tracks for the same samples. Later, we'll add a transformed track for $\ln(\text{CPM}+1)$ to see this in action.

Finally, we'll pass `max_jitter` as 128 bp. This allows the sequences and tracks to be randomly jittered by up to 128 bp in either direction. Note that `max_jitter` only sets the upper limit: when we open the dataset later, jitter defaults to 0, as the dataset summary below shows (`Jitter: 0 (max: 128)`).
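One way to picture why the maximum jitter is fixed at write time: if a window may shift by up to `max_jitter` bp, data must be available for `max_jitter` bp beyond each edge of the region. This is a minimal NumPy sketch of that bookkeeping under that assumption, not GVL's internal implementation; the same shift would be applied to both the sequence and its tracks.

```python
import numpy as np

rng = np.random.default_rng(0)
max_jitter = 128
region_start, region_end = 41_633_963, 41_765_035  # a 131,072 bp region

# Assumption: stored data extends max_jitter bp past each edge of the region,
# so any shift in [-max_jitter, max_jitter] stays inside what is available.
stored_start = region_start - max_jitter
stored_end = region_end + max_jitter

# Draw random shifts (inclusive bounds) and move the whole window by each one.
shifts = rng.integers(-max_jitter, max_jitter + 1, size=1000)
windows = [(region_start + s, region_end + s) for s in shifts]
```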

In [7]:
gvl.write(
    path=ds_path,
    bed=gvl.with_length(bed, 2**17),  # change region length to 131,072 bp
    variants=variants,
    bigwigs=gvl.BigWigs.from_table(name="read-depth", table=bigwig_table),
    max_jitter=128,  # allow up to 128 bp jitter
    overwrite=True,
)

[32m2025-03-19 20:27:36.453[0m | [1mINFO    [0m | [36mgenvarloader._dataset._write[0m:[36mwrite[0m:[36m99[0m - [1mWriting dataset to /tmp/tmpmcduh61m.gvl[0m
[32m2025-03-19 20:27:36.454[0m | [1mINFO    [0m | [36mgenvarloader._dataset._write[0m:[36mwrite[0m:[36m104[0m - [1mFound existing GVL store, overwriting.[0m
[32m2025-03-19 20:27:36.541[0m | [1mINFO    [0m | [36mgenvarloader._dataset._write[0m:[36mwrite[0m:[36m172[0m - [1mUsing 451 samples.[0m
[32m2025-03-19 20:27:36.541[0m | [1mINFO    [0m | [36mgenvarloader._dataset._write[0m:[36mwrite[0m:[36m178[0m - [1mWriting genotypes.[0m


  0%|          | 0/1 [00:00<?, ?it/s]

[32m2025-03-19 20:27:42.068[0m | [1mINFO    [0m | [36mgenvarloader._dataset._write[0m:[36mwrite[0m:[36m197[0m - [1mWriting BigWig intervals.[0m


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

[32m2025-03-19 20:27:46.202[0m | [1mINFO    [0m | [36mgenvarloader._dataset._write[0m:[36mwrite[0m:[36m204[0m - [1mFinished writing.[0m


Note that [gvl.write](https://genvarloader.readthedocs.io/en/latest/api.html#genvarloader.write) automatically uses the intersection of samples across the source files. Here the sources already contain exactly the same samples, but if we had used PLINK files for all 3,202 samples from the 1000 Genomes Project, it would have identified and used the 451 shared samples.
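The intersection step can be pictured with plain Python sets. This is a sketch with hypothetical sample lists; keeping the genotype file's order is an assumption made for illustration, not a statement about how GVL orders samples.

```python
# Hypothetical sample lists from two sources (e.g. a PLINK .psam and a BigWig table).
genotyped = ["HG00096", "HG00236", "HG00259", "NA20519", "NA20811"]
with_bigwig = ["NA20519", "HG00236", "NA20811", "HG00259"]

# Keep only samples present in both sources, in the genotype file's order.
shared = set(with_bigwig)
samples = [s for s in genotyped if s in shared]
print(samples)  # ['HG00236', 'HG00259', 'NA20519', 'NA20811']
```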

## Dataloader

Now that the dataset is written, we can add a transform, split it, and get a PyTorch dataloader in ~10 lines of code.

In [8]:
def transform(haps, tracks):
    # One-hot encode haplotypes and move the alphabet axis before length:
    # (..., length) -> (..., alphabet, length)
    haps = rearrange(
        sp.DNA.ohe(haps), "... length alphabet -> ... alphabet length"
    ).astype(np.float32)
    return haps, tracks


ds = (
    gvl.Dataset.open(ds_path, reference=reference)
    .with_seqs("haplotypes")
    .with_tracks("read-depth")
    .with_len(2**17)  # fixed-length 131,072 bp outputs
    .with_transform(transform)
)
# Train on the first 80% of samples
n_train = round(ds.n_samples * 0.8)
train_ds = ds.subset_to(samples=slice(0, n_train))
dl = train_ds.to_dataloader(batch_size=16, num_workers=0, shuffle=True)

2025-03-19 20:27:46.208 | INFO     | genvarloader._dataset._impl:open:227 - Loading reference genome into memory. This typically has a modest memory footprint (a few GB) and greatly improves performance.
2025-03-19 20:27:46.243 | INFO     | genvarloader._dataset._impl:open:269 - Opened dataset:
GVL store at /tmp/tmpmcduh61m.gvl
Is subset: False
# of regions: 20
# of samples: 451
Output length: ragged
Jitter: 0 (max: 128)
Deterministic: True
Sequence type: reference [haplotypes] annotated
Active tracks: read-depth
Tracks available: read-depth


GVL uses Numba JIT-compiled functions extensively, so the first call to `gvl.write`, the first batch from a dataloader, etc. will often take much longer than subsequent calls because of compilation. Numba also lets GVL run multithreaded almost everywhere it can, so `num_workers=0` or `1` is usually the best choice for dataloader throughput.
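The training split above keeps the first 80% of samples with `slice(0, n_train)`. The index arithmetic is sketched below in NumPy; the held-out set would presumably be the complementary slice, e.g. `ds.subset_to(samples=slice(n_train, None))` (an assumption based on the slice form used above).

```python
import numpy as np

n_samples = 451
n_train = round(n_samples * 0.8)  # round(360.8) == 361

# Contiguous split over sample indices; the test set is everything after n_train.
sample_idx = np.arange(n_samples)
train_idx = sample_idx[slice(0, n_train)]
test_idx = sample_idx[slice(n_train, None)]
print(n_train, len(train_idx), len(test_idx))  # 361 361 90
```

A contiguous split is simple and reproducible; if sample order correlates with population, splitting on a shuffled permutation of `sample_idx` may give more balanced sets.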

In [9]:
haps, tracks = next(iter(dl))
print(haps.shape, tracks.shape)

torch.Size([16, 2, 4, 131072]) torch.Size([16, 1, 2, 131072])


After one-hot encoding, the haplotypes have shape `(batch, ploidy, alphabet, length)` and the tracks have shape `(batch, tracks, ploidy, length)`.
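To make the shapes concrete, here is a NumPy stand-in for the one-hot encoding plus `rearrange` step. It assumes haplotypes arrive as single-byte (`S1`) arrays of shape `(batch, ploidy, length)`; in this sketch, bases outside ACGT (such as N) become all-zero columns, which may differ from `sp.DNA.ohe`.

```python
import numpy as np

ALPHABET = np.frombuffer(b"ACGT", dtype="S1")


def one_hot_channels_first(seqs: np.ndarray) -> np.ndarray:
    """One-hot encode an S1 array (..., length) into float32 (..., 4, length)."""
    ohe = (seqs[..., None] == ALPHABET).astype(np.float32)  # (..., length, alphabet)
    return np.moveaxis(ohe, -1, -2)  # (..., alphabet, length)


batch = np.array(
    [[list("ACGTNACG"), list("TTGCAACG")],
     [list("GGGGCCCC"), list("ACGTACGT")]],
    dtype="S1",
)  # (batch=2, ploidy=2, length=8)
x = one_hot_channels_first(batch)
print(x.shape)  # (2, 2, 4, 8)
```

`np.moveaxis(ohe, -1, -2)` performs the same axis swap as the einops pattern `"... length alphabet -> ... alphabet length"`.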

## Pre-computing transformed tracks (optional)

Suppose we would like to normalize the read depth across the dataset to account for library size. We could compute this on-the-fly, but GVL also offers a way to write this data back to disk to cache this computation and potentially improve performance. Note that this is the most technical part of this tutorial, so feel free to skip this and come back later.
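For reference, the normalization used in this section (matching the `np.log1p(... / sample_library_sizes[sample] * 1e6)` call in the transform code) can be written as follows, where $x_s(i)$ is the read depth of sample $s$ at position $i$ and $N_s$ is that sample's total read count (the `read_count` column):

$$
\text{CPM}_s(i) = \frac{x_s(i)}{N_s} \times 10^6, \qquad y_s(i) = \ln\left(1 + \text{CPM}_s(i)\right)
$$

Because `np.log1p` computes the natural logarithm of $1 + x$, the transformed values are in natural-log units.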

In [10]:
sample_library_sizes = (
    pl.Series(ds.samples)
    .to_frame("sample")
    .join(bigwig_table, on="sample", how="left")["read_count"]
    .to_numpy()
)
sample_library_sizes[:5]

array([27256165, 43941108, 39687917, 22341838, 23258231])

For this step, we'll use [Dataset.write_transformed_track](https://genvarloader.readthedocs.io/en/latest/api.html#genvarloader.Dataset.write_transformed_track) which expects a transform function to be given. From the docs:

> The arguments given to the transform will be the dataset indices, region indices, and sample indices as numpy arrays and the tracks themselves as a [Ragged](https://genvarloader.readthedocs.io/en/latest/api.html#genvarloader.Ragged) array with shape `(regions, samples)`. The tracks must be a [Ragged](https://genvarloader.readthedocs.io/en/latest/api.html#genvarloader.Ragged) array since regions may be different lengths to accommodate indels. This function should then return the transformed tracks as a [Ragged](https://genvarloader.readthedocs.io/en/latest/api.html#genvarloader.Ragged) array with the same shape and lengths.
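A ragged array can be stored as one flat buffer plus offsets, where segment `i` is `data[offsets[i]:offsets[i + 1]]`. The sketch below applies the same $\ln(1 + \text{CPM})$ transform to such a layout in plain NumPy (no Numba) by repeating each segment's library size across its elements. The data, offsets, and library sizes are made up for illustration, and this is an alternative formulation, not GVL's implementation.

```python
import numpy as np

# Three segments of lengths 2, 3, and 1 stored in one flat buffer.
data = np.array([1.0, 3.0, 2.0, 2.0, 2.0, 5.0], dtype=np.float32)
offsets = np.array([0, 2, 5, 6])
seg_sample = np.array([0, 1, 0])       # sample index of each segment (hypothetical)
library_sizes = np.array([1e6, 2e6])   # hypothetical total read counts per sample

# Broadcast each segment's library size to all of its elements,
# then transform the whole flat buffer in one vectorized step.
lengths = np.diff(offsets)
per_element_size = np.repeat(library_sizes[seg_sample], lengths)
out = np.log1p(data / per_element_size * 1e6).astype(np.float32)
```

The offsets are unchanged, so the result keeps the same shape and segment lengths, as the docs require. The trade-off versus the Numba loop is an extra array the size of `data`.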

Below, you can see an example of a transform of ragged data that uses Numba to accelerate the computation. Note that working with [Ragged](https://genvarloader.readthedocs.io/en/latest/api.html#genvarloader.Ragged) arrays is often not necessary with on-the-fly transformations, since data for deep learning is readily processed to be uniform length before any transformations.

In [11]:
@nb.njit(parallel=True, nogil=True, fastmath=True)
def inner_transform(s_idx, data, offsets):
    # Each segment data[offsets[i]:offsets[i + 1]] is one (region, sample) track.
    # Numba freezes the global `sample_library_sizes` array as a compile-time constant.
    log_cpm = np.empty_like(data)
    for i in nb.prange(len(offsets) - 1):
        start = offsets[i]
        end = offsets[i + 1]
        sample = s_idx[i]
        # ln(1 + CPM), normalizing by this sample's library size
        log_cpm[start:end] = np.log1p(
            data[start:end] / sample_library_sizes[sample] * 1e6
        )
    return log_cpm


def log_cpm(r_idx, s_idx, tracks: gvl.Ragged[np.float32]):
    data = inner_transform(s_idx, tracks.data, tracks.offsets)
    return gvl.Ragged.from_offsets(data, tracks.shape, tracks.offsets)


# Write a new track "lcpb" computed from "read-depth", with a 1 GiB memory budget
ds = ds.write_transformed_track(
    "lcpb", "read-depth", log_cpm, overwrite=True, max_mem=1 * 2**30
)
ds


GVL store at /tmp/tmpmcduh61m.gvl
Is subset: False
# of regions: 20
# of samples: 451
Output length: 131072
Jitter: 0 (max: 128)
Deterministic: True
Sequence type: reference [haplotypes] annotated
Active tracks: read-depth
Tracks available: lcpb, read-depth

**If the above cell crashes the kernel, it may have run out of RAM; reducing `max_mem` can fix this.**
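To get a feel for what `max_mem` is up against, here is a back-of-the-envelope estimate of how many chunks a fixed byte budget implies if every (region, sample) track must be held in memory. The helper is hypothetical and ignores overhead and indel-dependent lengths; also note that `inner_transform` allocates an output array the same size as its input (`np.empty_like`), so peak usage per chunk may be roughly double.

```python
import math


def estimate_n_chunks(n_regions, n_samples, length, itemsize, max_mem):
    """Hypothetical estimate: chunks needed so each chunk's tracks fit in max_mem bytes."""
    total_bytes = n_regions * n_samples * length * itemsize
    return math.ceil(total_bytes / max_mem)


# 20 regions x 451 samples x 131,072 float32 values is about 4.4 GiB
print(estimate_n_chunks(20, 451, 2**17, 4, 1 * 2**30))  # 5
print(estimate_n_chunks(20, 451, 2**17, 4, 2**29))      # 9 (half the budget)
```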

After writing the transformed track to disk, we can see the dataset now has the `"lcpb"` track available (note the list of available tracks is always sorted).

## Evaluating Basenji2 on personalized expression (optional)

Note: this section requires PyTorch and basenji2-pytorch to be installed.

Here, we'll show a (very) quick and dirty demo of some of the results found by [Huang et al. Nat Gen 2023](https://www.nature.com/articles/s41588-023-01574-w) with Basenji2. We also recommend running this with a GPU since inference with Basenji2 may take quite a while otherwise.

In [12]:
human_targets = pl.read_csv(
    "https://github.com/calico/basenji/blob/master/manuscripts/cross2020/targets_human.txt?raw=true",
    separator="\t",
)
# Find the CAGE track for GM12878, a lymphoblastoid cell line
cage_gm12878 = human_targets.filter(
    pl.col("description").str.contains(r"(?i)cage.*gm12878")
)
target = cage_gm12878.item(0, "index")
cage_gm12878

| index (i64) | genome (i64) | identifier (str) | file (str) | clip (i64) | scale (i64) | sum_stat (str) | description (str) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 5110 | 0 | CNhs12333 | /home/drk/tillage/datasets/hum… | 384 | 1 | sum | CAGE:B lymphoblastoid cell lin… |


**If the above cell takes more than a few seconds, try interrupting and re-running it; sometimes GitHub fails to respond, so the file doesn't download. The same goes for the cell below, where recount3 can get stuck.**

In [13]:
count_df = pl.read_csv(
    "https://duffel.rail.bio/recount3/human/data_sources/sra/gene_sums/42/ERP001942/sra.gene_sums.ERP001942.G029.gz",
    separator="\t",
    comment_prefix="#",
)
accessions = bigwig_table.with_columns(
    accession=pl.col("path").str.extract(r"(ERR\d+)")
)["accession"]
counts = (
    bed.join(
        count_df.select("gene_id", *accessions),
        left_on="name",
        right_on=pl.col("gene_id").str.split(".").list.get(0),
        maintain_order="left",
    )
    .select(*accessions)
    .to_numpy()
)
counts.shape

(20, 451)
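The join above matches versioned GENCODE IDs from recount3 (e.g. `ENSG….13`) to the unversioned IDs in the BED file by dropping everything after the first `.`. A pure-Python sketch of that matching, with hypothetical version suffixes and counts:

```python
# Hypothetical rows keyed by versioned gene IDs, as in a recount3 gene_sums table.
count_rows = {
    "ENSG00000167077.13": [5, 7],
    "ENSG00000100266.18": [2, 0],
    "ENSG00000099940.12": [9, 4],
}
bed_names = ["ENSG00000100266", "ENSG00000167077"]  # BED order to preserve

# Strip the ".<version>" suffix so IDs match the unversioned BED names.
unversioned = {gene_id.split(".")[0]: c for gene_id, c in count_rows.items()}

# Look up counts in BED order (like maintain_order="left" in the polars join).
counts = [unversioned[name] for name in bed_names]
print(counts)  # [[2, 0], [5, 7]]
```

The polars `join` defaults to an inner join, so a BED gene missing from recount3 would be dropped silently; the `(20, 451)` shape above confirms all 20 genes matched.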

In [14]:
import torch
from basenji2_pytorch import Basenji2, basenji2_params, basenji2_weights

device = "cuda" if torch.cuda.is_available() else "cpu"

torch.set_float32_matmul_precision("medium")

basenji2 = Basenji2(basenji2_params["model"]).to(device)
basenji2.load_state_dict(torch.load(basenji2_weights(), weights_only=True))
basenji2.eval();

In [15]:
def transform(haps, *args):
    haps = rearrange(
        sp.DNA.ohe(haps), "... length alphabet -> ... alphabet length"
    ).astype(np.float32)
    return haps, *args


ds = (
    gvl.Dataset.open(ds_path, reference)
    .with_len(2**17)
    .with_indices(True)
    .with_tracks(None)
    .with_transform(transform)
)

2025-03-19 20:27:59.608 | INFO     | genvarloader._dataset._impl:open:227 - Loading reference genome into memory. This typically has a modest memory footprint (a few GB) and greatly improves performance.
2025-03-19 20:27:59.639 | INFO     | genvarloader._dataset._impl:open:269 - Opened dataset:
GVL store at /tmp/tmpmcduh61m.gvl
Is subset: False
# of regions: 20
# of samples: 451
Output length: ragged
Jitter: 0 (max: 128)
Deterministic: True
Sequence type: reference [haplotypes] annotated
Active tracks: lcpb, read-depth
Tracks available: lcpb, read-depth


**If you're using a GPU, you may need to use a smaller batch size depending on how much GPU RAM you have.**

In [16]:
batch_size = 48
# number of output bins for Basenji2, each corresponds to 128 bp of sequence
n_bins = 896
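As a quick check of where 896 comes from (assuming, as in the Basenji2 design, that the 1,024 bins spanning the input are cropped equally from both edges), the arithmetic is:

```python
input_len = 2**17  # 131,072 bp of input sequence
bin_size = 128     # bp per output bin
n_bins = 896       # bins Basenji2 predicts

total_bins = input_len // bin_size         # bins spanning the whole input
cropped_bins = (total_bins - n_bins) // 2  # bins trimmed from each edge
covered_bp = n_bins * bin_size             # bp covered by the predictions
print(total_bins, cropped_bins, covered_bp)  # 1024 64 114688
```

So the predictions cover the central 114,688 bp of each 131,072 bp input.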

Compute predictions for all genes using reference sequences:

In [17]:
ref_preds = np.empty((ds.n_regions, n_bins), dtype=np.float32)
with torch.no_grad():
    # One sample is enough since reference sequences are identical across samples
    for ref, r_idx, _ in tqdm(
        ds.subset_to(samples=0)
        .with_seqs("reference")
        .to_dataloader(batch_size=batch_size, num_workers=0)
    ):
        ref_preds[r_idx] = basenji2(ref.to(device))[..., target].numpy(force=True)


Next, we'll compute the Pearson correlation across genes between predicted (transformed) CAGE-seq read depth and mean expression. We'll use a 5-bin (640 bp) window that starts 9 bins (1,152 bp) upstream of the TSS, since this gave the highest correlation in some quick testing. This is somewhat expected since CAGE-seq reads should fall in the 5' UTR region, and we haven't thoroughly confirmed that the TSS coordinates we're using match exactly what Basenji2 was trained on.
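The window is selected with the slice `896 // 2 - 9 : 896 // 2 - 4` in the next cell. Converting it to base pairs, relative to the center of the prediction window (assumed here to sit at the TSS, up to rounding in how regions were centered):

```python
n_bins, bin_size = 896, 128
center_bin = n_bins // 2  # bin 448 starts at the center of the prediction window
window = range(center_bin - 9, center_bin - 4)  # same bins as the slice

# Window edges in bp relative to the center
start_bp = (window.start - center_bin) * bin_size
end_bp = (window.stop - center_bin) * bin_size
print(list(window), start_bp, end_bp, end_bp - start_bp)
# [439, 440, 441, 442, 443] -1152 -512 640
```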

In [18]:
ref_x_gene = np.corrcoef(
    ref_preds[..., 896 // 2 - 9 : 896 // 2 - 4].mean(-1), counts.mean(-1), rowvar=False
)[0, 1]
ref_x_gene

np.float64(0.47156382152055587)

We'd expect this to be the highest possible correlation Basenji2 can achieve on these genes. Let's see how it does across individuals and across genes with haplotypes.

In [19]:
# preds has shape (regions, samples, bins)
preds = np.empty(ds.full_shape + (n_bins,), dtype=np.float64)
with torch.no_grad():
    for haps, r_idx, s_idx in tqdm(
        ds.to_dataloader(batch_size=batch_size, num_workers=0)
    ):
        # haps: (batch, ploidy, alphabet, length); predict on the first haplotype only
        preds[r_idx, s_idx] = basenji2(haps[:, 0].to(device))[..., target].numpy(
            force=True
        )


In [20]:
ave_pearson_x_idv = np.diag(
    np.corrcoef(preds[..., 896 // 2 - 9 : 896 // 2 - 4].mean(-1), counts), 20
).mean()
ave_pearson_x_gene = np.diag(
    np.corrcoef(preds[..., 896 // 2 - 9 : 896 // 2 - 4].mean(-1), counts, rowvar=False),
    451,
).mean()
ave_pearson_x_idv, ave_pearson_x_gene

(np.float64(-0.005341827841689839), np.float64(0.4334261598507151))

The average correlation across genes with haplotypes (0.43) is only slightly lower than with reference sequences (0.47), but, as Huang et al. and others have found, the correlation across individuals is approximately 0 on average (-0.005 here).
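The averages above rely on an offset-diagonal trick: `np.corrcoef(a, b)` stacks the rows of both inputs and correlates every pair, so the diagonal offset by the number of rows pairs `a[i]` with `b[i]`. With `rowvar=False`, columns become the variables instead. A small check on random data (the array sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, n_people = 3, 5
pred = rng.normal(size=(n_genes, n_people))
obs = rng.normal(size=(n_genes, n_people))

# Rows are variables: offset n_genes pairs pred[i] with obs[i] (one value per gene).
across_individuals = np.diag(np.corrcoef(pred, obs), n_genes)

# Columns are variables: offset n_people pairs pred[:, j] with obs[:, j] (one per person).
across_genes = np.diag(np.corrcoef(pred, obs, rowvar=False), n_people)

# The same quantities computed one pair at a time.
manual_idv = np.array([np.corrcoef(pred[i], obs[i])[0, 1] for i in range(n_genes)])
manual_gene = np.array([np.corrcoef(pred[:, j], obs[:, j])[0, 1] for j in range(n_people)])
```

The trick builds the full correlation matrix (902 x 902 for 451 individuals), which is cheap here; for much larger inputs, computing only the paired correlations directly would save memory.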