Import the `seillra` package and other helpful packages:

In [None]:
import os, sys
import pandas as pd
import torch
import numpy as np
import collections

import seillra as sl

To load different parts of the model, select a rank, whether the output is chromatin profiles or sequence classes, whether you are predicting for arbitrary sequences or specifically for variants, and whether the model should run on the CPU or GPU. These options are selectable when using the `Sei_LLRA` model class; other functionality can be customized by using the model blocks explicitly.

In [2]:
# Select the QNNPACK quantized backend when this PyTorch build offers it
# (intended for Macs / ARM machines).
# NOTE(review): the check is on availability only, so this also switches the
# engine on non-Mac builds that ship QNNPACK.
_available_qengines = torch.backends.quantized.supported_engines
if "qnnpack" in _available_qengines:
    torch.backends.quantized.engine = "qnnpack"

In [None]:
# Model configuration
rank = 256                  # one of 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048
sequence_classes = False    # True -> sequence-class output instead of chromatin profiles
sequence_type = "sequence"  # or "variant"
quant = "CPU"               # or "GPU", "GPU_fp16", "GPU_int8"
model = sl.Sei_LLRA(
    k=rank,
    projection=sequence_classes,
    mode=sequence_type,
    quant=quant,
)

2026-01-21 14:58:42,821 - INFO - Checksum verified for url_a6038b62128b5b01_wts: 28a1a49ca62e4d67a62c170df3751f7255db6eea3923455c119c762dde446308
2026-01-21 14:58:42,822 - INFO - Loading state dict from /home/ejg66/.cache/seillra/1.5/url_a6038b62128b5b01_wts
2026-01-21 14:58:42,879 - INFO - Model weights loaded and set to eval mode.
2026-01-21 14:58:44,282 - INFO - Starting download: https://drive.google.com/uc?export=download&id=1DrlkcecVSgj3CH2924Zz8-Oj95Dyry3C -> /home/ejg66/.cache/seillra/1.5/url_9c83e76615711914_wts
2026-01-21 14:58:44,282 - INFO - Detected Google Drive URL, using download_from_gdrive for file_id=1DrlkcecVSgj3CH2924Zz8-Oj95Dyry3C
2026-01-21 14:58:47,613 - INFO - Successfully downloaded to /home/ejg66/.cache/seillra/1.5/url_9c83e76615711914_wts.part
2026-01-21 14:58:47,616 - INFO - Download complete: /home/ejg66/.cache/seillra/1.5/url_9c83e76615711914_wts
2026-01-21 14:58:47,665 - INFO - Checksum verified for url_9c83e76615711914_wts: ce0baa7e8533604ab579a37ada1848

This can also be done with `torch.nn.Sequential`. For an un-quantized (GPU) model, see the commented code. Note that one should be careful about handling forward and reverse-complement sequences.

In [None]:
# Build the pipeline from its individual blocks.
# Un-quantized (GPU) alternatives from the original notes:
#   sl.get_sei_trunk().load_weights()   (originally written as `sm.`; presumably a typo — TODO confirm)
#   sl.get_sei_head_llra(k=rank)
mod_trunk = sl.get_sei_trunk(quant=quant)
mod_head = sl.get_sei_head_llra(k=rank, quant=quant)
mod_projection = sl.get_sei_projection(quant=quant)
mod_projection.set_mode(sequence_type)

#- Make a full model: trunk -> head -> projection, applied in this order
stages = [
    ("trunk", mod_trunk),
    ("head", mod_head),
    ("projection", mod_projection),
]
mod = torch.nn.Sequential(collections.OrderedDict(stages))

AttributeError: module 'seillra' has no attribute 'get_sei_trunk_q'

Here is an example of running the model on a batch of random one-hot encoded sequences.

In [None]:
# 16 random sequences of length 4096 over {0,1,2,3}, one-hot encoded and
# permuted to (batch, 4, length) = (16, 4, 4096).
sequences = torch.randint(0, 4, (16, 4096))
x = torch.nn.functional.one_hot(sequences, num_classes=4).permute(0, 2, 1).float()
print(x.shape)
# - run the model; no_grad skips building an autograd graph for pure inference
with torch.no_grad():
    out = model(x.to("cpu"))
print(out.shape)

Now we can try a different rank, use the sequence classes, and use the GPU. Note that the model will remain on the CPU if you do not have access to a CUDA-enabled GPU (Apple M-series GPUs are not CUDA-enabled).

In [None]:
# Second configuration: smallest rank, sequence-class output, GPU backend.
rank = 1                    # one of 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048
sequence_classes = True     # or False
sequence_type = "sequence"  # or "variant"
quant = "GPU"               # or "CPU", "GPU_fp16", "GPU_int8"
model2 = sl.Sei_LLRA(
    k=rank,
    projection=sequence_classes,
    mode=sequence_type,
    quant=quant,
)


In [None]:
# The text above says the model stays on the CPU when CUDA is unavailable, so
# send the input to the GPU only when CUDA actually exists (a hard-coded
# "cuda" would crash on CPU-only machines).
device = "cuda" if torch.cuda.is_available() else "cpu"

sequences = torch.randint(0, 4, (16, 4096))
x = torch.nn.functional.one_hot(sequences, num_classes=4).permute(0, 2, 1).float()
print(x.shape)
# - run the model (inference only, so no autograd graph)
with torch.no_grad():
    out = model2(x.to(device))
print(out.shape)

The model can also produce predictions for variants. We will set up example reference and alternate allele sequences.

In [None]:
rank = 1                    # one of 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048
sequence_classes = True     # or False
sequence_type = "variant"   # or "sequence"
quant = "CPU"               # or "GPU", "GPU_fp16", "GPU_int8"
# Use the same `quant` keyword as every other Sei_LLRA call in this notebook.
# NOTE(review): the original passed `device="cpu"` here; confirm Sei_LLRA
# does not also accept a `device` argument.
model3 = sl.Sei_LLRA(k=rank, projection=sequence_classes, mode=sequence_type, quant=quant)


In [None]:
# Random reference batch; the alternate batch differs only at the centre base.
ref_sequences = torch.randint(0, 4, (16, 4096))
alt_sequences = ref_sequences.clone()
center_idx = 4096 // 2
# (+1) % 4 guarantees the alternate base differs from the reference base.
alt_sequences[:, center_idx] = (ref_sequences[:, center_idx] + 1) % 4


x_ref = torch.nn.functional.one_hot(ref_sequences, num_classes=4).permute(0, 2, 1).float()
x_alt = torch.nn.functional.one_hot(alt_sequences, num_classes=4).permute(0, 2, 1).float()

print(x_ref.shape, x_alt.shape)
# The variant-mode model takes a (ref, alt) pair. Named so it does not
# shadow the builtin `input`.
ref_alt_pair = (x_ref, x_alt)
with torch.no_grad():
    out_ref, out_alt = model3(ref_alt_pair)
print(out_ref.shape, out_alt.shape)


Here is an example workflow for variant effect prediction. For instructions on downloading the data, see the manuscript repository: [https://github.com/egilfeather/lowrank-s2f-code](https://github.com/egilfeather/lowrank-s2f-code).

In [None]:

# Variant table to score, plus the labelled MPRA table used for evaluation below.
file_path = "./data/MPRA_eQTL.vcf"
df = pd.read_csv("./data/MPRA_eQTL.tsv", header=0, sep="\t")
print(df.head())

In [None]:
import seillra as sl

# Rank-256 model with sequence-class output, scoring (ref, alt) variant pairs on CPU.
rank = 256
model = sl.Sei_LLRA(k=rank, quant="CPU", projection=True, mode="variant")

In [None]:
import numpy as np
import pandas as pd
import torch
# from pyfaidx import Fasta
# from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, Subset
import random
import h5py
import os

from Bio import SeqIO
from Bio.Seq import Seq
from pybedtools import BedTool
from itertools import islice
import math

LOOKUP = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

# Byte value -> one-hot row lookup used by VariantDataset.returnonehot. Every
# byte that is not A/C/G/T maps to -1 and is encoded as an all-zero column.
_ONEHOT_TABLE = np.full(256, -1, dtype=np.int64)
for _base, _row in LOOKUP.items():
    _ONEHOT_TABLE[ord(_base)] = _row
del _base, _row


class VariantDataset(torch.utils.data.Dataset):
    """Map-style dataset of one-hot reference/alternate windows around variants.

    Each item is ``(ref_onehot, alt_onehot, row)``. Both tensors are float32
    with shape ``(4, window_size)``, and ``row`` is the variant's table row as
    a list. Variants that cannot be scored yield ``(None, None, None)`` so a
    collate function such as ``safe_collate`` can drop them. Such variants
    are those with an unknown chromosome, an allele with characters other
    than A/C/G/T, or a REF allele matching neither the genome nor the ALT
    allele.
    """

    def __init__(self, file_path, fasta_path = "", window_size=4096):
        """
        Args:
            file_path (str): Tab-separated variant table without a header and
                with five columns (chromosome, 1-based position, a third
                field, REF, ALT). Lines starting with ``#`` are ignored.
            fasta_path (str): Path to the FASTA file for sequence retrieval.
            window_size (int): Length of the sequence window built around
                each variant. Defaults to 4096.
        """
        self.genome = self._load_fasta(fasta_path)
        self.vcf_positions = pd.read_csv(file_path, comment="#", sep="\t", header = None)
        # NOTE(review): the third column is labelled "STRAND" here, but the
        # evaluation code later calls it "NAME". Confirm what it holds.
        self.vcf_positions.columns = ["CHROM", "POS", "STRAND", "REF", "ALT" ]
        # Normalise chromosome names to the "chr" prefix style.
        self.vcf_positions["CHROM"] = self.vcf_positions["CHROM"].apply(
            lambda x: f"chr{x}" if not str(x).startswith("chr") else x
        )
        self.window_size = window_size

    def __len__(self):
        return len(self.vcf_positions)

    @staticmethod
    def _is_acgt(allele):
        """Return True if ``allele`` is a non-empty string of A/C/G/T only."""
        return len(allele) > 0 and all(base in LOOKUP for base in allele)

    def __getitem__(self, index):
        """Return ``(ref_onehot, alt_onehot, row)`` or ``(None, None, None)`` if unscorable."""
        vcf_row = self.vcf_positions.iloc[index]
        chrom, pos, strand, ref, alt = vcf_row
        # str() guards against NaN cells; upper() matches the upper-cased genome.
        ref = str(ref).upper()
        alt = str(alt).upper()

        if chrom not in self.genome or not self._is_acgt(ref) or not self._is_acgt(alt):
            return None, None, None

        half = self.window_size // 2
        center = pos + (len(ref) // 2)
        # `pos` is 1-based; the extra -1 makes sequence[ref_start] the base at `pos`.
        start = center - half - 1
        end = start + self.window_size

        sequence = self._get_sequence(chrom, start, end)
        # Find where REF should be
        ref_start = half - (len(ref) // 2)
        ref_end = ref_start + len(ref)
        ref_seq_segment = sequence[ref_start:ref_end]

        # Skip if reference doesn't match
        if ref_seq_segment != ref:
            if ref_seq_segment != alt:
                return None, None, None
            # The genome carries the ALT allele: rebuild the reference window.
            sequence = sequence[:ref_start] + ref + sequence[ref_end:]

        alt_sequence = sequence[:ref_start] + alt + sequence[ref_end:]
        # Indels change the length: pad with N or truncate back to the window.
        if len(alt_sequence) < self.window_size:
            alt_sequence += "N" * (self.window_size - len(alt_sequence))
        elif len(alt_sequence) > self.window_size:
            alt_sequence = alt_sequence[:self.window_size]

        row = vcf_row.to_numpy().tolist()
        ref_onehot = torch.tensor(self.returnonehot(sequence), dtype=torch.float32)
        alt_onehot = torch.tensor(self.returnonehot(alt_sequence), dtype=torch.float32)
        return ref_onehot, alt_onehot, row


    def _load_fasta(self, fasta_path):
        """
        Load a FASTA file into a dictionary for fast sequence retrieval.
        """
        genome = {}
        for record in SeqIO.parse(fasta_path, "fasta"):
            genome[record.id] = str(record.seq)
        return genome

    def _get_sequence(self, chrom, start, end, strand = "+"):
        """
        Return the upper-cased genome slice [start, end) of ``chrom``, N-padded.

        Positions before 0 or past the chromosome end are filled with ``N``,
        so the result always has length ``end - start`` (for ``end >= start``).
        ``strand`` is currently unused.
        """
        chrom_seq = self.genome[chrom]
        chrom_len = len(chrom_seq)
        # Positions in [start, min(end, 0)) lie before the chromosome.
        n_left = max(0, min(end, 0) - start)
        lo = max(start, 0)
        hi = min(end, chrom_len)
        middle = chrom_seq[lo:hi] if lo < hi else ""
        # Positions in [max(start, chrom_len), end) lie after the chromosome.
        n_right = max(0, end - max(start, chrom_len))
        return ("N" * n_left + middle + "N" * n_right).upper()
    
    def _insert_allele(self, chrom, start, end, ref, alt, strand = "+"):
        """
        Replace the REF allele at the center of the sequence with ALT.

        NOTE(review): not called anywhere in this file. Its slice
        ``[start_idx:end_idx]`` has length ``(len(ref) + 1) // 2`` rather
        than ``len(ref)``, so multi-base REF alleles never match. Verify
        before using it.

        Args:
            sequence (str): the reference sequence window
            ref (str): the reference allele (from VCF)
            alt (str): the alternate allele (from VCF)

        Returns:
            str: the modified sequence with the allele inserted
        """
        for char in ref:
            if char not in LOOKUP:
                return None
        for char in alt:
            if char not in LOOKUP:
                return None
        sequence = self._get_sequence(chrom, start, end, strand)
        sequence = sequence.upper()
        center_idx = self.window_size//2
        ref_len = len(ref)
        alt_len = len(alt)
        diff = ref_len - alt_len

        start_idx = center_idx -1 - ref_len//2
        end_idx = start_idx + (ref_len +1)//2

        ref_seq_segment = sequence[start_idx:end_idx]

        if ref_seq_segment != ref:
            return None
        new_sequence = sequence[:start_idx] + alt + sequence[end_idx:]
        if diff >=0:
            new_sequence = self._get_sequence(chrom, start-(diff//2), start, strand) + new_sequence + self._get_sequence(chrom, end, end+((diff+1)//2), strand)
        else:
            new_sequence = new_sequence[(-diff//2):-(-diff+1)//2]

        return new_sequence.upper()  

    def returnonehot(self, string):
        """
        One-hot encode a DNA sequence.

        Args:
            string (str): DNA sequence (case-insensitive).

        Returns:
            np.ndarray: float32 matrix of shape (4, len(sequence)) with rows
            in A, C, G, T order. Columns for any other character (``N``,
            IUPAC ambiguity codes, ...) are all zero.
        """
        # One byte per character; non-ASCII characters become "?" so the
        # length is preserved and they encode as all-zero columns.
        encoded = string.upper().encode("ascii", "replace")
        codes = np.fromiter(encoded, dtype=np.uint8, count=len(encoded))
        rows = _ONEHOT_TABLE[codes]
        cols = np.flatnonzero(rows >= 0)
        out = np.zeros((4, len(codes)), dtype=np.float32)
        out[rows[cols], cols] = 1
        return out


class SeqDataLoader(DataLoader):
    """DataLoader that can cap how many samples it yields per epoch.

    The cap is rounded up to whole batches and never exceeds the number of
    batches the underlying ``DataLoader`` would produce.
    """

    def __init__(self, dataset, *, batch_size=1, n_samples=None, **kwargs):
        """
        Args:
            dataset: Dataset to load from.
            batch_size (int | None): Samples per batch, as for ``DataLoader``.
            n_samples (int | None): Maximum number of samples to yield per
                epoch, rounded up to a whole number of batches. ``None``
                iterates over the full dataset.
            **kwargs: Forwarded to ``torch.utils.data.DataLoader``.
        """
        self.user_n_samples = n_samples  # Save the user's request
        self.batch_size_here = batch_size
        # Constructor arguments kept for re-instantiation (PyTorch Lightning
        # style attribute name — presumably consumed there; confirm).
        self.__pl_cls_kwargs__ = {
            "dataset": dataset,
            "batch_size": batch_size,
            "n_samples": n_samples,
            **kwargs
        }
        if getattr(dataset, "mode", None) == "variant_prediction":
            kwargs["collate_fn"] = safe_collate

        super().__init__(dataset=dataset, batch_size=batch_size, **kwargs)

        # Batches the base loader would yield; this already accounts for
        # drop_last, custom samplers and batch_sampler.
        max_batches = super().__len__()
        if self.user_n_samples is not None:
            # batch_size=None disables batching, so each "batch" is one sample.
            per_batch = batch_size if batch_size else 1
            user_n_batches = math.ceil(self.user_n_samples / per_batch)
            self._effective_batches = min(user_n_batches, max_batches)
        else:
            self._effective_batches = max_batches

    def __iter__(self):
        """Iterate the base loader, stopping after the capped number of batches."""
        return islice(super().__iter__(), self._effective_batches)

    def __len__(self):
        """Number of batches yielded per epoch (after applying ``n_samples``)."""
        return self._effective_batches

def safe_collate(batch):
    """Collate ``(ref, alt, row)`` triples, dropping those whose ``ref`` is None.

    Returns ``None`` (after printing a notice) when no item survives.
    Otherwise it returns the stacked ref tensors, the stacked alt tensors and
    the rows stacked with ``np.stack``.
    """
    survivors = list(filter(lambda triple: triple[0] is not None, batch))
    if len(survivors) == 0:
        print("Not Batch")
        return None  # every item in this batch was unusable

    refs, alts, rows = zip(*survivors)
    stacked_refs = torch.stack(refs)
    stacked_alts = torch.stack(alts)
    stacked_rows = np.stack(rows)
    return stacked_refs, stacked_alts, stacked_rows

class VariantDataLoader(SeqDataLoader):
    """``SeqDataLoader`` that always collates with ``safe_collate``.

    Any caller-supplied ``collate_fn`` is discarded. Items the dataset
    returns as ``(None, None, None)`` are therefore removed from each batch.
    """

    def __init__(self, dataset, *, batch_size=1, n_samples = None, **kwargs):
        """
        Args:
            dataset: Dataset yielding ``(ref, alt, row)`` triples.
            batch_size (int): Samples per batch.
            n_samples (int | None): Per-epoch sample cap, rounded up to
                whole batches; ``None`` means no cap.
            **kwargs: Other ``DataLoader`` options; ``collate_fn`` is ignored.
        """
        forwarded = {
            "collate_fn": safe_collate,
            **{key: value for key, value in kwargs.items() if key != "collate_fn"},
        }
        super().__init__(dataset=dataset, batch_size=batch_size, n_samples=n_samples, **forwarded)



In [None]:
from tqdm import tqdm
# NOTE(review): this import replaces the VariantDataset / VariantDataLoader
# classes defined in the previous cell with the `sei_lora` package versions.
# Confirm which implementation is intended.
from sei_lora.dataloaders import VariantDataset, VariantDataLoader
dataset = VariantDataset(file_path=file_path, fasta_path = "./resources/hg38_UCSC.fa", window_size = 4096)
dataloader = VariantDataLoader(dataset=dataset, batch_size=32, shuffle=False, num_workers=0)
model.eval()

# Per-batch predictions and variant rows, accumulated by appending to lists.
all_cp_ref = []
all_cp_alt = []
all_vcf = []

progress_bar = tqdm(dataloader, desc=f"Running {rank} benchmark")

# Inference only. Without no_grad every stored output would keep its
# autograd graph alive until the end of the loop.
with torch.no_grad():
    for batch in progress_bar:
        # The collate function may return None when every variant in the
        # batch was skipped (safe_collate defined above does this).
        if batch is None:
            continue
        ref, alt, vcf = batch

        cp_outputs = model((ref, alt))  # (ref predictions, alt predictions)

        # Move to CPU so the .numpy() conversion below also works on GPU.
        all_cp_ref.append(cp_outputs[0].detach().cpu())
        all_cp_alt.append(cp_outputs[1].detach().cpu())
        all_vcf.append(vcf)

all_cp_ref = torch.cat(all_cp_ref, dim=0).numpy()
all_cp_alt = torch.cat(all_cp_alt, dim=0).numpy()

all_vcf = np.concatenate(all_vcf, axis=0)



In [None]:
from sklearn.metrics import roc_auc_score
# Variant effect per output: alternate prediction minus reference prediction.
sc_diff = all_cp_alt - all_cp_ref

df_pred = pd.DataFrame(all_vcf, columns=["CHROM", "POS", "NAME", "REF", "ALT"])
# Rows were stacked into a string array during collation; restore integer positions.
df_pred["POS"] = df_pred["POS"].astype(int)

seqclass_path = "./resources/seqclass.names"
sc_names = []
with open(seqclass_path, "r") as f:
    for line in f:
        parts = line.strip().split()
        if not parts:
            continue  # blank line: nothing to name (would otherwise fail on parts[0])
        # Multi-token lines: drop the first token (presumably a class ID —
        # confirm) and hyphen-join the rest into the column name.
        sc_names.append("-".join(parts[1:]) if len(parts) > 1 else parts[0])

# Keep only the first 40 sequence classes.
df_sc = pd.DataFrame(sc_diff[:, :40], columns=sc_names[:40])

# Both frames use a default RangeIndex, so they line up row by row.
df_pred = pd.concat([df_pred, df_sc], axis=1)


# Binary task: "over" vs "under" consequence, scored by the Promoter-class effect.
df_ou = df[df['consequence'].isin(['over', 'under'])].copy()
df_combine_ou = df_ou.merge(df_pred, left_on = ["chrom", "pos", "ref", "alt"], right_on=["CHROM", "POS", "REF", "ALT"], how = "inner")
binary_labels_ou = (df_combine_ou['consequence'] == 'over')
roc_promoter_ou = roc_auc_score(binary_labels_ou, df_combine_ou["Promoter"])
print(roc_promoter_ou)
