In [None]:
import subprocess
import os


def make_fasta(output_dir = "results"):
    os.makedirs(output_dir, exist_ok=True)
    os.chdir(output_dir)

    download_variants = "wget https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/NA12878_HG001/latest/GRCh38/HG001_GRCh38_1_22_v4.2.1_benchmark.vcf.gz"
    download_reference_genome = "wget https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/references/GRCh38/GCA_000001405.15_GRCh38_no_alt_analysis_set.fasta.gz && gunzip GCA_000001405.15_GRCh38_no_alt_analysis_set.fasta"
    get_biallelic_variants = "bcftools view -v snps -m2 -M2 HG001_GRCh38_1_22_v4.2.1_benchmark.vcf.gz -Oz -o GM12878_SNPs_biallelic.vcf.gz"
    get_index = "bcftools index GM12878_SNPs_biallelic.vcf.gz"
    get_consensus = "bcftools consensus -f GCA_000001405.15_GRCh38_no_alt_analysis_set.fasta -H 1 GM12878_SNPs_biallelic.vcf.gz > GM12878.fasta"

    subprocess.run(download_variants, shell=True)
    subprocess.run(download_reference_genome, shell=True)
    subprocess.run(get_biallelic_variants, shell=True)
    subprocess.run(get_index, shell=True)
    subprocess.run(get_consensus, shell=True)

    os.remove(path="HG001_GRCh38_1_22_v4.2.1_benchmark.vcf.gz")
    os.remove(path="GCA_000001405.15_GRCh38_no_alt_analysis_set.fasta")
    os.remove(path="GM12878_SNPs_biallelic.vcf.gz")
    os.remove(path="GM12878_SNPs_biallelic.vcf.gz.csi")
    
    print(f"Done! Fasta file saved to {output_dir}/GM12878.fasta")

In [None]:
make_fasta()

In [None]:
from gtfparse import read_gtf
import polars as pl
from IPython.display import display

In [None]:
df = read_gtf("/data/common/genome/gencode.v47.basic.annotation.gtf")
filtered_df = df.filter(
    (pl.col('feature') == 'exon') &
    (pl.col('gene_type') == 'protein_coding') &
    (pl.col('seqname').is_in(['chr9']))
)
as_string = filtered_df.with_columns(pl.col('start').cast(pl.Utf8), pl.col('end').cast(pl.Utf8))
as_num = as_string.with_columns(pl.col('exon_number').cast(pl.Int64))
indexed_df = as_num.with_row_index()

display(indexed_df)

In [None]:
# Get lists of first and last indices
first_indices = []
last_indices = []

for _, group in indexed_df.group_by('transcript_id'):
    sorted_group = group.sort('exon_number')
    first_indices.append(sorted_group.row(0, named=True)['index'])
    last_indices.append(sorted_group.row(-1, named=True)['index'])

# Create update expressions
placeholder_df = indexed_df.with_columns([
    pl.when(pl.col("index").is_in(first_indices))
    .then(pl.lit("START"))
    .otherwise(pl.col("start"))
    .alias("start"),
    
    pl.when(pl.col("index").is_in(last_indices))
    .then(pl.lit("END"))
    .otherwise(pl.col("end"))
    .alias("end")
])

sorted_df = placeholder_df.sort('seqname', 'transcript_id', 'exon_number')

display(sorted_df)

In [None]:
import polars as pl
from IPython.display import display

In [None]:
quant_tsv_1 = pl.read_csv("../ENCFF189XTO.tsv", separator='\t')
quant_tsv_2 = pl.read_csv("../ENCFF971DVB.tsv", separator='\t')
display(quant_tsv_1)
display(quant_tsv_2)


In [None]:
joined_tsv = quant_tsv_1.join(quant_tsv_2, on='transcript_ID', how='inner')
display(joined_tsv)

averaged_counts = joined_tsv.with_columns(
    ((pl.col('rep1ENCSR368UNC') + pl.col('rep2ENCSR368UNC')) / 2).alias('transcript_count')
)
clean_tsv = averaged_counts.select("annot_transcript_id", "annot_transcript_name", "transcript_count")

In [None]:
import polars as pl 
from gtfparse import read_gtf

In [None]:
gtf_file = read_gtf("/data/common/genome/gencode.v44.basic.annotation.gtf")
gtf_file.write_parquet("../reference_files/gencode.v44.basic.annotation.gtf.parquet")

In [None]:
import os
os.chdir("/zata/zippy/ramirezc/splice-model-benchmark")

import sys
splice_transformer_path = os.path.join(f"{os.getcwd}", 'reference_files', 'SpliceTransformer')
sys.path.append(splice_transformer_path)

import pandas as pd
import numpy as np
from pyfaidx import Fasta
import argparse
import vcf as pyvcf
from pyensembl import Genome
import tqdm
import os
from sptransformer import Annotator
import torch

In [None]:
annotator = Annotator()
gtf = annotator.gtf

tis_names = ['Adipose Tissue', 'Blood', 'Blood Vessel', 'Brain', 'Colon', 'Heart', 'Kidney',
                'Liver', 'Lung', 'Muscle', 'Nerve', 'Small Intestine', 'Skin', 'Spleen', 'Stomach']

input_seq = 'N'*4000 + 'ACGTAGGGCG' + 'N'*4000  # just an example
input_seq = annotator.model.one_hot_encode(input_seq)
input_seq = torch.tensor(input_seq).to(annotator.model.device)
print(input_seq.shape)
# the function step() accepts encoded sequence, (Batch, 4, Length),
# thus, the input_seq should have shape (1, 4, Length)
input_seq = input_seq.unsqueeze(0).float().transpose(1, 2)
output = annotator.model.step(input_seq)
print(output.shape)

In [None]:
import torch

save_path = 'model/weights/SpTransformer_pytorch.ckpt'
save_dict = torch.load(save_path, map_location='cpu')

new_state_dict = {}
for key, value in save_dict["state_dict"].items():
    if "attn.pos_emb.weights_" in key:
        new_key = key.replace("attn.pos_emb.weights_", "attn.pos_emb.weights.")
        new_state_dict[new_key] = value
    else:
        new_state_dict[key] = value

save_dict["state_dict"] = new_state_dict

new_save_path = 'model/weights/SpTransformer_pytorch_fixed.ckpt'
torch.save(save_dict, new_save_path)

print(f"Modified checkpoint saved to {new_save_path}")

In [None]:
import torch
from pangolin.model import *
import os
os.chdir("/zata/zippy/ramirezc/splice-model-benchmark")

model_path = "reference_files/pangolin/models/final.{model_index}.{model_num}.3"

model_nums = [0, 2, 4, 6]
models = []
for i in model_nums:
    for j in range(1, 6):
        model = Pangolin(L, W, AR)
        model.cuda()
        weights = torch.load(model_path.format(model_index=j, model_num=i))
        model.load_state_dict(weights)
        model.eval()
        models.append(model)
                
print(models)


In [None]:
import zarr
import os

In [None]:
splice_site_predicitons = zarr.open_group(store="/zata/zippy/ramirezc/splice-model-benchmark/results/pangolin_predictions.zarr/splice_site_predictions", mode="r")
splice_site_truth = zarr.open_group(store="/zata/zippy/ramirezc/splice-model-benchmark/results/pangolin_predictions.zarr/splice_site_truth", mode="r")
splice_sites = zarr.open_group(store="/zata/zippy/ramirezc/splice-model-benchmark/results/pangolin_predictions.zarr/splice_sites", mode="r")

In [None]:
print(splice_site_predicitons["chr1"][11963635])
print(splice_site_truth["chr1"][11963635])

In [None]:
print(splice_sites['metadata'][:200])

In [None]:
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model
import haiku as hk

In [None]:
parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    max_positions=5000 + 1,
)
forward_fn = hk.transform(forward_fn)
donor_idx = config.features.index('splice_donor')
acceptor_idx = config.features.index('splice_acceptor')
print(config.features)

In [None]:
import os
os.chdir("/zata/zippy/ramirezc/splice-model-benchmark")

import polars as pl
import numpy as np
from models.spliceai import SpliceAIEvaluator
from IPython.display import display

In [None]:
print([f"chr{i}" for i in range(1, 11, 2)])

In [None]:
evaluator = SpliceAIEvaluator()
sorted_df = evaluator._filter_gencode()
display(sorted_df)

In [None]:
quant_tsv_1 = pl.read_csv("reference_files/transcript_quantifications_rep1.tsv", separator='\t')
quant_tsv_2 = pl.read_csv("reference_files/transcript_quantifications_rep2.tsv", separator='\t')
joined_tsv = quant_tsv_1.join(quant_tsv_2, on='transcript_ID', how='inner')
averaged_counts = joined_tsv.with_columns(
    ((pl.col('rep1ENCSR368UNC') + pl.col('rep2ENCSR368UNC')) / 2).alias('transcript_count')
)
clean_tsv = averaged_counts.select("annot_transcript_id", "annot_transcript_name", "transcript_count")
expressed_transcripts = clean_tsv.filter(pl.col('transcript_count') >= 2.0)['annot_transcript_id'].to_list()

gtf = pl.read_parquet("reference_files/gencode.v29.primary_assembly.annotation_UCSC_names.gtf.parquet")
filtered_df = gtf.filter(
    (pl.col('feature') == 'transcript') &
    (pl.col('gene_type') == 'protein_coding') &
    (pl.col('seqname').is_in([f"chr{i}" for i in range(1, 11, 2)])) &
    (pl.col('transcript_id').is_in(expressed_transcripts))
)
display(filtered_df)

In [None]:
print(f"Average length of transcripts: {np.median((filtered_df['end'].to_numpy() - filtered_df['start'].to_numpy()))}")

In [None]:
import polars as pl
from IPython.display import display

In [None]:
vcf = pl.read_csv("HG001_GRCh38_1_22_v4.2.1_benchmark.vcf", separator="\t", comment_prefix="#", new_columns=["CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO", "FORMAT", "HG001"])
display(vcf)


In [None]:
filter_chroms = vcf.filter(pl.col("CHROM").is_in([f"chr{i}" for i in range(1, 11, 2)]))
display(filter_chroms)


In [None]:
filtered_gtf = pl.read_parquet("filtered_gtf.parquet")
display(filtered_gtf)

In [None]:
with_start = filtered_gtf.with_columns(
    pl.lit('start').alias('pos_type'),
    pl.col('start').alias('POS'))
with_end = filtered_gtf.with_columns(
    pl.lit('end').alias('pos_type'),
    pl.col('end').alias('POS')
)

concatenated = pl.concat([with_start, with_end])
dropped = concatenated.drop('start', 'end')
to_none = dropped.with_columns(pl.col('POS').replace("EXCLUDE", pl.lit(None)))
as_int = to_none.with_columns(pl.col('POS').cast(pl.Int64))
rename = as_int.with_columns(pl.col('seqname').cast(pl.String).alias('CHROM')).drop('seqname', 'index')
drop_null = rename.drop_nulls(subset=['POS']).with_row_index()

display(drop_null)

In [9]:
import polars as pl
from IPython.display import display
import os
os.chdir("/zata/zippy/ramirezc/splice-model-benchmark")


print("Filtering GENCODE GTF...")
quant_tsv_1 = pl.read_csv("reference_files/transcript_quantifications_rep1.tsv", separator='\t')
quant_tsv_2 = pl.read_csv("reference_files/transcript_quantifications_rep2.tsv", separator='\t')
joined_tsv = quant_tsv_1.join(quant_tsv_2, on='transcript_ID', how='inner')
averaged_counts = joined_tsv.with_columns(
    ((pl.col('rep1ENCSR368UNC') + pl.col('rep2ENCSR368UNC')) / 2).alias('transcript_count')
)
clean_tsv = averaged_counts.select("annot_transcript_id", "annot_transcript_name", "transcript_count")
expressed_transcripts = clean_tsv.filter(pl.col('transcript_count') >= 2.0)['annot_transcript_id'].to_list()
print(f"Number of expressed transcripts: {len(expressed_transcripts)}")

gtf = pl.read_parquet("reference_files/gencode.v29.primary_assembly.annotation_UCSC_names.gtf.parquet")
filtered_df = gtf.filter(
    (pl.col('feature') == 'exon') &
    (pl.col('gene_type') == 'protein_coding') &
    (pl.col('seqname').is_in(['chr1', 'chr3', 'chr5', 'chr7', 'chr9']))
)
filtered_df = filtered_df.filter(pl.col('transcript_id').is_in(expressed_transcripts))
    
transcript_counts = (
    filtered_df
    .select(['seqname', 'transcript_id'])
    .unique()
    .group_by('seqname')
    .len()
    .sort('seqname')
)
print(f"Number of transcripts per chromsome: {transcript_counts}")

with_start = filtered_df.with_columns(
    pl.lit('start').alias('pos_type'),
    pl.col('start').alias('pos'))
with_end = filtered_df.with_columns(
    pl.lit('end').alias('pos_type'),
    pl.col('end').alias('pos')
)

concatenated = pl.concat([with_start, with_end])
dropped = concatenated.drop('start', 'end')
as_int = dropped.with_columns(pl.col('pos').cast(pl.Int64), pl.col('exon_number').cast(pl.Int64))
sorted_df = as_int.sort('seqname', 'transcript_id', 'exon_number', 'pos')

grouped = sorted_df.group_by('seqname', 'transcript_id').agg(pl.col('pos'))
remove_single_exons = grouped.filter(
    pl.col('pos').list.len() > 2
)
removed_start_and_end = remove_single_exons.with_columns(
    pl.col("pos").list.slice(1, pl.col("pos").list.len() - 2)
    .alias("pos")
)

expoloded_df = removed_start_and_end.explode('pos')
joined_df = sorted_df.join(expoloded_df, on=['seqname', 'transcript_id', 'pos'], how='inner').drop('index').with_row_index()
display(joined_df)

# Drop the first and last exon for each seqname (chrom), transcript_id


# rename = as_int.with_columns(pl.col('seqname').cast(pl.String)).drop('index', 'score')
# drop_null = rename.drop_nulls(subset=['pos']).with_row_index()

# sorted_df = drop_null.sort('seqname', 'transcript_id', 'exon_number')
# display(sorted_df)

Filtering GENCODE GTF...
Number of expressed transcripts: 10660
Number of transcripts per chromsome: shape: (5, 2)
┌─────────┬─────┐
│ seqname ┆ len │
│ ---     ┆ --- │
│ cat     ┆ u32 │
╞═════════╪═════╡
│ chr1    ┆ 281 │
│ chr3    ┆ 136 │
│ chr5    ┆ 113 │
│ chr7    ┆ 129 │
│ chr9    ┆ 98  │
└─────────┴─────┘


index,seqname,source,feature,score,strand,frame,gene_id,gene_type,gene_name,level,havana_gene,transcript_id,transcript_type,transcript_name,transcript_support_level,tag,havana_transcript,exon_number,exon_id,ont,protein_id,ccdsid,pos_type,pos
u32,cat,cat,cat,f32,cat,i64,str,str,str,str,str,str,str,str,str,str,str,i64,str,str,str,str,str,i64
0,"""chr1""","""HAVANA""","""exon""",,"""-""",0,"""ENSG00000081870.11""","""protein_coding""","""HSPB11""","""2""","""OTTHUMG00000008408.4""","""ENST00000194214.9""","""protein_coding""","""HSPB11-201""","""1""","""basic,appris_principal_1,CCDS""","""OTTHUMT00000023114.1""",1,"""ENSE00001841796.1""","""""","""ENSP00000194214.5""","""CCDS41341.1""","""end""",53945929
1,"""chr1""","""HAVANA""","""exon""",,"""-""",0,"""ENSG00000081870.11""","""protein_coding""","""HSPB11""","""2""","""OTTHUMG00000008408.4""","""ENST00000194214.9""","""protein_coding""","""HSPB11-201""","""1""","""basic,appris_principal_1,CCDS""","""OTTHUMT00000023114.1""",2,"""ENSE00001334213.1""","""""","""ENSP00000194214.5""","""CCDS41341.1""","""start""",53939985
2,"""chr1""","""HAVANA""","""exon""",,"""-""",0,"""ENSG00000081870.11""","""protein_coding""","""HSPB11""","""2""","""OTTHUMG00000008408.4""","""ENST00000194214.9""","""protein_coding""","""HSPB11-201""","""1""","""basic,appris_principal_1,CCDS""","""OTTHUMT00000023114.1""",2,"""ENSE00001334213.1""","""""","""ENSP00000194214.5""","""CCDS41341.1""","""end""",53940097
3,"""chr1""","""HAVANA""","""exon""",,"""-""",0,"""ENSG00000081870.11""","""protein_coding""","""HSPB11""","""2""","""OTTHUMG00000008408.4""","""ENST00000194214.9""","""protein_coding""","""HSPB11-201""","""1""","""basic,appris_principal_1,CCDS""","""OTTHUMT00000023114.1""",3,"""ENSE00000772733.1""","""""","""ENSP00000194214.5""","""CCDS41341.1""","""start""",53930039
4,"""chr1""","""HAVANA""","""exon""",,"""-""",0,"""ENSG00000081870.11""","""protein_coding""","""HSPB11""","""2""","""OTTHUMG00000008408.4""","""ENST00000194214.9""","""protein_coding""","""HSPB11-201""","""1""","""basic,appris_principal_1,CCDS""","""OTTHUMT00000023114.1""",3,"""ENSE00000772733.1""","""""","""ENSP00000194214.5""","""CCDS41341.1""","""end""",53930145
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
9013,"""chr9""","""HAVANA""","""exon""",,"""+""",0,"""ENSG00000165060.12""","""protein_coding""","""FXN""","""2""","""OTTHUMG00000019977.11""","""ENST00000643639.1""","""protein_coding""","""FXN-207""","""""","""basic,appris_principal_1,CCDS""","""OTTHUMT00000052568.4""",3,"""ENSE00001089856.1""","""""","""ENSP00000496143.1""","""CCDS6626.1""","""start""",69053140
9014,"""chr9""","""HAVANA""","""exon""",,"""+""",0,"""ENSG00000165060.12""","""protein_coding""","""FXN""","""2""","""OTTHUMG00000019977.11""","""ENST00000643639.1""","""protein_coding""","""FXN-207""","""""","""basic,appris_principal_1,CCDS""","""OTTHUMT00000052568.4""",3,"""ENSE00001089856.1""","""""","""ENSP00000496143.1""","""CCDS6626.1""","""end""",69053260
9015,"""chr9""","""HAVANA""","""exon""",,"""+""",0,"""ENSG00000165060.12""","""protein_coding""","""FXN""","""2""","""OTTHUMG00000019977.11""","""ENST00000643639.1""","""protein_coding""","""FXN-207""","""""","""basic,appris_principal_1,CCDS""","""OTTHUMT00000052568.4""",4,"""ENSE00001089858.1""","""""","""ENSP00000496143.1""","""CCDS6626.1""","""start""",69064938
9016,"""chr9""","""HAVANA""","""exon""",,"""+""",0,"""ENSG00000165060.12""","""protein_coding""","""FXN""","""2""","""OTTHUMG00000019977.11""","""ENST00000643639.1""","""protein_coding""","""FXN-207""","""""","""basic,appris_principal_1,CCDS""","""OTTHUMT00000052568.4""",4,"""ENSE00001089858.1""","""""","""ENSP00000496143.1""","""CCDS6626.1""","""end""",69065035


In [None]:
print(removed_start_and_end['pos'][2].to_list())

In [4]:
display(gtf.filter(pl.col('transcript_id') == 'ENST00000610533.4', pl.col('feature') == 'exon').with_columns(pl.col('exon_number').cast(pl.Int64)).sort('seqname', 'exon_number', 'start'))

seqname,source,feature,start,end,score,strand,frame,gene_id,gene_type,gene_name,level,havana_gene,transcript_id,transcript_type,transcript_name,transcript_support_level,tag,havana_transcript,exon_number,exon_id,ont,protein_id,ccdsid
cat,cat,cat,i64,i64,f32,cat,i64,str,str,str,str,str,str,str,str,str,str,str,i64,str,str,str,str
"""chr7""","""ENSEMBL""","""exon""",44123511,44123559,,"""-""",0,"""ENSG00000106628.10""","""protein_coding""","""POLD2""","""3""","""OTTHUMG00000022909.15""","""ENST00000610533.4""","""protein_coding""","""POLD2-218""","""1""","""basic,CCDS""","""""",1,"""ENSE00003727843.1""","""""","""ENSP00000480186.1""","""CCDS75586.1"""
"""chr7""","""ENSEMBL""","""exon""",44121834,44122109,,"""-""",0,"""ENSG00000106628.10""","""protein_coding""","""POLD2""","""3""","""OTTHUMG00000022909.15""","""ENST00000610533.4""","""protein_coding""","""POLD2-218""","""1""","""basic,CCDS""","""""",2,"""ENSE00003725493.1""","""""","""ENSP00000480186.1""","""CCDS75586.1"""
"""chr7""","""ENSEMBL""","""exon""",44117943,44118064,,"""-""",0,"""ENSG00000106628.10""","""protein_coding""","""POLD2""","""3""","""OTTHUMG00000022909.15""","""ENST00000610533.4""","""protein_coding""","""POLD2-218""","""1""","""basic,CCDS""","""""",3,"""ENSE00003570606.1""","""""","""ENSP00000480186.1""","""CCDS75586.1"""
"""chr7""","""ENSEMBL""","""exon""",44117619,44117742,,"""-""",0,"""ENSG00000106628.10""","""protein_coding""","""POLD2""","""3""","""OTTHUMG00000022909.15""","""ENST00000610533.4""","""protein_coding""","""POLD2-218""","""1""","""basic,CCDS""","""""",4,"""ENSE00003789985.1""","""""","""ENSP00000480186.1""","""CCDS75586.1"""
"""chr7""","""ENSEMBL""","""exon""",44117133,44117247,,"""-""",0,"""ENSG00000106628.10""","""protein_coding""","""POLD2""","""3""","""OTTHUMG00000022909.15""","""ENST00000610533.4""","""protein_coding""","""POLD2-218""","""1""","""basic,CCDS""","""""",5,"""ENSE00000680774.1""","""""","""ENSP00000480186.1""","""CCDS75586.1"""
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""chr7""","""ENSEMBL""","""exon""",44116430,44116510,,"""-""",0,"""ENSG00000106628.10""","""protein_coding""","""POLD2""","""3""","""OTTHUMG00000022909.15""","""ENST00000610533.4""","""protein_coding""","""POLD2-218""","""1""","""basic,CCDS""","""""",7,"""ENSE00000680769.1""","""""","""ENSP00000480186.1""","""CCDS75586.1"""
"""chr7""","""ENSEMBL""","""exon""",44116115,44116272,,"""-""",0,"""ENSG00000106628.10""","""protein_coding""","""POLD2""","""3""","""OTTHUMG00000022909.15""","""ENST00000610533.4""","""protein_coding""","""POLD2-218""","""1""","""basic,CCDS""","""""",8,"""ENSE00000680766.1""","""""","""ENSP00000480186.1""","""CCDS75586.1"""
"""chr7""","""ENSEMBL""","""exon""",44115766,44115893,,"""-""",0,"""ENSG00000106628.10""","""protein_coding""","""POLD2""","""3""","""OTTHUMG00000022909.15""","""ENST00000610533.4""","""protein_coding""","""POLD2-218""","""1""","""basic,CCDS""","""""",9,"""ENSE00000680764.1""","""""","""ENSP00000480186.1""","""CCDS75586.1"""
"""chr7""","""ENSEMBL""","""exon""",44115295,44115396,,"""-""",0,"""ENSG00000106628.10""","""protein_coding""","""POLD2""","""3""","""OTTHUMG00000022909.15""","""ENST00000610533.4""","""protein_coding""","""POLD2-218""","""1""","""basic,CCDS""","""""",10,"""ENSE00000680760.1""","""""","""ENSP00000480186.1""","""CCDS75586.1"""
