In [None]:
import subprocess
import os


def make_fasta(output_dir = "results"):
    """Build a GM12878 consensus FASTA from the GIAB HG001 benchmark variants.

    The pipeline runs these shell commands inside ``output_dir``:

    1. download the GIAB HG001 GRCh38 v4.2.1 benchmark VCF;
    2. download and decompress the GRCh38 no-alt analysis-set reference;
    3. keep only biallelic SNPs (``bcftools view -v snps -m2 -M2``);
    4. index the filtered VCF;
    5. write ``GM12878.fasta`` with ``bcftools consensus -H 1``.

    The downloaded and intermediate files are deleted afterwards, leaving only
    ``GM12878.fasta``. ``wget``, ``gunzip`` and ``bcftools`` must be on PATH.

    Note:
        The process working directory is changed to ``output_dir`` and is NOT
        restored; relative paths used after this call resolve from there.

    Args:
        output_dir: Directory (created if missing) that receives the FASTA.

    Raises:
        subprocess.CalledProcessError: If any pipeline command exits non-zero.
            The pipeline stops at that step and nothing is deleted, so the
            partial state can be inspected.
    """
    os.makedirs(output_dir, exist_ok=True)
    os.chdir(output_dir)

    download_variants = "wget https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/NA12878_HG001/latest/GRCh38/HG001_GRCh38_1_22_v4.2.1_benchmark.vcf.gz"
    download_reference_genome = "wget https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/references/GRCh38/GCA_000001405.15_GRCh38_no_alt_analysis_set.fasta.gz && gunzip GCA_000001405.15_GRCh38_no_alt_analysis_set.fasta"
    get_biallelic_variants = "bcftools view -v snps -m2 -M2 HG001_GRCh38_1_22_v4.2.1_benchmark.vcf.gz -Oz -o GM12878_SNPs_biallelic.vcf.gz"
    get_index = "bcftools index GM12878_SNPs_biallelic.vcf.gz"
    get_consensus = "bcftools consensus -f GCA_000001405.15_GRCh38_no_alt_analysis_set.fasta -H 1 GM12878_SNPs_biallelic.vcf.gz > GM12878.fasta"

    pipeline = (
        download_variants,
        download_reference_genome,
        get_biallelic_variants,
        get_index,
        get_consensus,
    )
    for command in pipeline:
        # check=True: abort on the first failing step instead of silently
        # continuing and reporting success over a missing or truncated FASTA.
        subprocess.run(command, shell=True, check=True)

    # Only reached when every step succeeded: drop everything but the FASTA.
    intermediates = (
        "HG001_GRCh38_1_22_v4.2.1_benchmark.vcf.gz",
        "GCA_000001405.15_GRCh38_no_alt_analysis_set.fasta",
        "GM12878_SNPs_biallelic.vcf.gz",
        "GM12878_SNPs_biallelic.vcf.gz.csi",
    )
    for intermediate in intermediates:
        os.remove(path=intermediate)

    print(f"Done! Fasta file saved to {output_dir}/GM12878.fasta")

In [None]:
# Run the download -> filter -> consensus pipeline with the default output directory ("results").
make_fasta()

In [None]:
from gtfparse import read_gtf
import polars as pl
from IPython.display import display

In [None]:
# Load the GENCODE v47 basic annotation.
df = read_gtf("/data/common/genome/gencode.v47.basic.annotation.gtf")

# Restrict to protein-coding exon records on chr9.
is_exon = pl.col('feature') == 'exon'
is_protein_coding = pl.col('gene_type') == 'protein_coding'
on_selected_chrom = pl.col('seqname').is_in(['chr9'])
filtered_df = df.filter(is_exon & is_protein_coding & on_selected_chrom)

# Coordinates are cast to strings; exon_number is cast to an integer.
as_string = filtered_df.with_columns(pl.col(['start', 'end']).cast(pl.Utf8))
as_num = as_string.with_columns(pl.col('exon_number').cast(pl.Int64))

# Prepend a row-number column (polars names it "index" by default).
indexed_df = as_num.with_row_index()

display(indexed_df)

In [None]:
# Collect, per transcript, the row index of its lowest- and highest-numbered exon.
first_indices, last_indices = [], []
for _, transcript_exons in indexed_df.group_by('transcript_id'):
    ordered_row_ids = transcript_exons.sort('exon_number')['index']
    first_indices.append(ordered_row_ids[0])
    last_indices.append(ordered_row_ids[-1])

# On the first exon of a transcript the start becomes "START";
# on the last exon the end becomes "END". All other values are kept.
start_with_placeholder = (
    pl.when(pl.col("index").is_in(first_indices))
    .then(pl.lit("START"))
    .otherwise(pl.col("start"))
    .alias("start")
)
end_with_placeholder = (
    pl.when(pl.col("index").is_in(last_indices))
    .then(pl.lit("END"))
    .otherwise(pl.col("end"))
    .alias("end")
)
placeholder_df = indexed_df.with_columns(start_with_placeholder, end_with_placeholder)

sorted_df = placeholder_df.sort(['seqname', 'transcript_id', 'exon_number'])

display(sorted_df)

In [None]:
import polars as pl
from IPython.display import display

In [None]:
# Read both tab-separated quantification tables and show them in order.
quant_tsv_1, quant_tsv_2 = (
    pl.read_csv(tsv_path, separator='\t')
    for tsv_path in ("../ENCFF189XTO.tsv", "../ENCFF971DVB.tsv")
)
display(quant_tsv_1, quant_tsv_2)


In [None]:
# Keep only transcripts present in both tables.
joined_tsv = quant_tsv_1.join(quant_tsv_2, how='inner', on='transcript_ID')
display(joined_tsv)

# Mean of the two replicate count columns, stored as 'transcript_count'.
replicate_mean = (pl.col('rep1ENCSR368UNC') + pl.col('rep2ENCSR368UNC')) / 2
averaged_counts = joined_tsv.with_columns(replicate_mean.alias('transcript_count'))

# Reduce to the transcript identifier, its name, and the averaged count.
clean_tsv = averaged_counts.select(
    ["annot_transcript_id", "annot_transcript_name", "transcript_count"]
)

In [None]:
import polars as pl 
from gtfparse import read_gtf

In [None]:
# Parse the GENCODE v44 basic annotation and store it as a parquet file.
gtf_source = "/data/common/genome/gencode.v44.basic.annotation.gtf"
parquet_target = "../reference_files/gencode.v44.basic.annotation.gtf.parquet"

gtf_file = read_gtf(gtf_source)
gtf_file.write_parquet(parquet_target)

In [None]:
import os
import sys

# Work from the project root so the relative paths below resolve.
os.chdir("/zata/zippy/ramirezc/splice-model-benchmark")

# Make the SpliceTransformer checkout importable.
# os.getcwd must be *called*: formatting the bare function object yields
# "<built-in function getcwd>", which is not a usable directory.
splice_transformer_path = os.path.join(os.getcwd(), 'reference_files', 'SpliceTransformer')
sys.path.append(splice_transformer_path)

import pandas as pd
import numpy as np
from pyfaidx import Fasta
import argparse
import vcf as pyvcf
from pyensembl import Genome
import tqdm
import os
from sptransformer import Annotator
import torch

In [None]:
# Instantiate the annotator and keep a handle on its GTF attribute.
annotator = Annotator()
gtf = annotator.gtf

# Tissue labels (not used in this cell).
tis_names = [
    'Adipose Tissue', 'Blood', 'Blood Vessel', 'Brain', 'Colon', 'Heart', 'Kidney',
    'Liver', 'Lung', 'Muscle', 'Nerve', 'Small Intestine', 'Skin', 'Spleen', 'Stomach',
]

# Example input: a 10-base sequence padded with 4000 'N' on each side.
flank = 'N' * 4000
example_seq = flank + 'ACGTAGGGCG' + flank
encoded_seq = annotator.model.one_hot_encode(example_seq)
input_seq = torch.tensor(encoded_seq).to(annotator.model.device)
print(input_seq.shape)

# step() takes an encoded batch shaped (Batch, 4, Length), so add a leading
# batch dimension, convert to float, and swap the last two axes.
input_seq = input_seq.unsqueeze(0).float().transpose(1, 2)
output = annotator.model.step(input_seq)
print(output.shape)

In [None]:
import torch

# Load the original checkpoint onto the CPU.
save_path = 'model/weights/SpTransformer_pytorch.ckpt'
save_dict = torch.load(save_path, map_location='cpu')

# Rename "attn.pos_emb.weights_<x>" keys to "attn.pos_emb.weights.<x>".
# str.replace returns keys without that substring unchanged.
new_state_dict = {
    key.replace("attn.pos_emb.weights_", "attn.pos_emb.weights."): value
    for key, value in save_dict["state_dict"].items()
}
save_dict["state_dict"] = new_state_dict

# Write the patched checkpoint next to the original.
new_save_path = 'model/weights/SpTransformer_pytorch_fixed.ckpt'
torch.save(save_dict, new_save_path)

print(f"Modified checkpoint saved to {new_save_path}")

In [None]:
import torch
from pangolin.model import *
import os
os.chdir("/zata/zippy/ramirezc/splice-model-benchmark")

# Checkpoint file pattern: model_index runs 1..5, model_num is one of model_nums.
model_path = "reference_files/pangolin/models/final.{model_index}.{model_num}.3"

# Use the GPU when one is available; otherwise fall back to the CPU instead of
# crashing in model.cuda() / torch.load on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_nums = [0, 2, 4, 6]
models = []
for i in model_nums:
    for j in range(1, 6):
        # L, W and AR come from the wildcard import of pangolin.model.
        model = Pangolin(L, W, AR)
        model.to(device)
        # map_location places the stored tensors directly on the chosen device.
        weights = torch.load(
            model_path.format(model_index=j, model_num=i),
            map_location=device,
        )
        model.load_state_dict(weights)
        model.eval()  # inference mode
        models.append(model)

print(models)


In [1]:
import zarr
import os

In [24]:
# All three groups live under the same zarr store; open each read-only.
predictions_store = "/zata/zippy/ramirezc/splice-model-benchmark/results/pangolin_predictions.zarr"

splice_site_predicitons = zarr.open_group(store=f"{predictions_store}/splice_site_predictions", mode="r")
splice_site_truth = zarr.open_group(store=f"{predictions_store}/splice_site_truth", mode="r")
splice_sites = zarr.open_group(store=f"{predictions_store}/splice_sites", mode="r")

In [31]:
# Compare the predicted and true values stored at one index of the "chr1" arrays.
site_index = 11963635
print(splice_site_predicitons["chr1"][site_index])
print(splice_site_truth["chr1"][site_index])

0.9036884307861328
1


In [14]:
# Show the first 200 entries stored under 'metadata' in the splice_sites group.
print(splice_sites['metadata'][:200])

[('chr1',  53945928, '-') ('chr1',  53939984, '-')
 ('chr1',  53940096, '-') ('chr1',  53930038, '-')
 ('chr1',  53930144, '-') ('chr1',  53928371, '-')
 ('chr1',  53928439, '-') ('chr1',  53923903, '-')
 ('chr1',  53923946, '-') ('chr1',  53921560, '-')
 ('chr1',  11934854, '+') ('chr1',  11947975, '+')
 ('chr1',  11948066, '+') ('chr1',  11949772, '+')
 ('chr1',  11949905, '+') ('chr1',  11950356, '+')
 ('chr1',  11950519, '+') ('chr1',  11952622, '+')
 ('chr1',  11952734, '+') ('chr1',  11954829, '+')
 ('chr1',  11954892, '+') ('chr1',  11956916, '+')
 ('chr1',  11957013, '+') ('chr1',  11957841, '+')
 ('chr1',  11957942, '+') ('chr1',  11958515, '+')
 ('chr1',  11958646, '+') ('chr1',  11960645, '+')
 ('chr1',  11960766, '+') ('chr1',  11963531, '+')
 ('chr1',  11963635, '+') ('chr1',  11964174, '+')
 ('chr1',  11964299, '+') ('chr1',  11964643, '+')
 ('chr1',  11964784, '+') ('chr1',  11965479, '+')
 ('chr1',  11965592, '+') ('chr1',  11966250, '+')
 ('chr1',  11966315, '+') ('chr