In [1]:
import json
from pyfasta import Fasta
from torch.utils.data import Dataset
import os
import numpy as np
from config import Fapath, IN_MAP, repdict, SeqTable, EL
import matplotlib.pyplot as plt
class DataGenerator(Dataset):
    def __init__(self, EL, data):
        self.data =data
        self.fafiles = {}
        self.EL = EL

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        d = self.data[index]
        species, chrom, start, end, strand, name = d["species"], d["chrom"], int(d["start"]), int(d["end"]), d["strand"], d["name"]
        if species not in self.fafiles:
            self.fafiles[species] = Fasta(os.path.join(Fapath, species+".fa"))

        seq = self.fafiles[species][chrom][start:end].upper()
        label = np.zeros([end-start-self.EL, 3])
        for v in d["label"]:
            idx, value = v
            idx = int(idx)
            if idx-self.EL//2 >= 0 and idx < end-start-self.EL//2:
                value = np.array([float(_) for _ in value])
                label[idx-self.EL//2] = value
        if strand == "-":
            seq = [repdict[_] for _ in seq][::-1]
            label = np.copy(label[::-1])
        seq = IN_MAP[[SeqTable[_] for _ in seq]][:, :4]
        label[:, 0] = 1.-label[:, 1:].sum(-1)

        return {"X": seq, "single_pred_psi": label, "species": species, "chrom": chrom, "name": name, "strand": strand, "txstart": start, "txend": end}


case_gene_name="SOX13"
save_path="experiments/11_feature_analysis/test_results"
with open("data/Npy/test/human.json", "r") as f:
    train_data=json.load(f)
candidate_data_points=[_ for _ in train_data if _["name"]==case_gene_name]
candidate_data_points.append(
    {'species': 'hg19', 
     'chrom': 'chr1', 
     'start': 204064741, 
     'end': 204099741,
     'strand': '+',
     'label': [[204082241-204064741-1, [1-1e-3, 0, 1e-3]]], 
     'name': 'SOX13'}
    )
candidate_dataset=DataGenerator(data=candidate_data_points, EL=EL)
print(len(candidate_data_points), candidate_data_points, EL)


6 [{'species': 'hg19', 'chrom': 'chr1', 'start': 204027838, 'end': 204062838, 'strand': '+', 'label': [[15000, [0, 0, 0.947368421052632]], [15678, [0, 0, 0.0178571428571429]], [19474, [0, 0, 0.001]]], 'name': 'SOX13'}, {'species': 'hg19', 'chrom': 'chr1', 'start': 204067261, 'end': 204102261, 'strand': '+', 'label': [[14781, [0, 0.964285714285714, 0]], [15000, [0, 0, 0.928571428571429]], [15122, [0, 0, 0.0144927536231884]], [16187, [0, 0.942028985507246, 0]], [16298, [0, 0, 0.971428571428571]], [16385, [0, 0.94392523364486, 0]], [16471, [0, 0, 0.975206611570248]], [18215, [0, 0, 0.001]], [18370, [0, 0.00819672131147541, 0]], [18373, [0, 0.983606557377049, 0]], [18545, [0, 0, 1.0]], [18988, [0, 0.991735537190083, 0]], [19056, [0, 0, 0.946153846153846]], [19459, [0, 0.576923076923077, 0]], [19462, [0, 0.343283582089552, 0]], [19573, [0, 0, 0.933962264150943]], [23737, [0, 0.001, 0]], [23766, [0, 0.926605504587156, 0]], [23851, [0, 0, 0.867924528301887]], [24103, [0, 0.766666666666667, 0]

In [2]:
import importlib, copy
import torch
config_path="experiments/1_evaluate_on_test_and_val/RefSplice_human_test_config".replace("/", ".")
config = importlib.import_module(config_path)
config = config.config
Models = [copy.deepcopy(config.model) for _ in config.model_path]
[m.load_state_dict(torch.load(b))
         for m, b in zip(Models, config.model_path)]
[m.eval() for m in Models]

[32m2024-01-14 15:30:52.180[0m | [1mINFO    [0m | [36mmodels.delta_pretrain[0m:[36m__init__[0m:[36m172[0m - [1mthe number of parameters is 8075145[0m


[MainModel(
   (encode): DataParallel(
     (module): Encode(
       (encodenet): Sequential(
         (0): ResidualUnit(
           (net): Sequential(
             (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
             (1): ReLU()
             (2): Conv1d(64, 64, kernel_size=(11,), stride=(1,))
             (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
             (4): ReLU()
             (5): Dropout(p=0.3, inplace=False)
             (6): Conv1d(64, 64, kernel_size=(11,), stride=(1,))
           )
         )
         (1): ResidualUnit(
           (net): Sequential(
             (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
             (1): ReLU()
             (2): Conv1d(64, 64, kernel_size=(11,), stride=(1,))
             (3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
             (4): ReLU()
             (5): Dropout(p=0.3

In [3]:
@ torch.no_grad()
def get_gradient_for_all_positions(input):
    seq, true_expression= torch.tensor(input["X"]).float(
        ), torch.tensor(input["single_pred_psi"]).float()
    start, end=input["txstart"], input["txend"]
    seq_global=seq.cuda()
    
    acceptor_splice_site=torch.nonzero(true_expression[..., 1])
    donor_splice_site=torch.nonzero(true_expression[..., 2])
    
    ret_lines={}
    table=["A", "C", "G", "T"]
    pred_global=0
    for model in Models:
        pred_global=pred_global+model.encode(seq_global[None])
    pred_global=pred_global/len(Models)
    
    for s in torch.cat([acceptor_splice_site, donor_splice_site], 0):
        site=pred_global[0, s][..., 1:]
        te=true_expression[s][..., 1:]
        with open(os.path.join(save_path, f"{(s+start+15000+1).detach().cpu().numpy()[0]}.csv"), "w") as f:
            f.writelines("Splice site,gt_ssu,pred_ssu,impact_positions,alt_base,impact\n")
            for i in range(s-110, s+110):
                res=[[],[],[]]
                for j in range(4):
                    seq=seq_global.clone()
                    seq[i+15000]=0
                    seq[i+15000, j]=1
                    pred=0
                    for model in Models:
                        pred=pred+model.encode(seq[None]) 
                    pred=pred/len(Models)-pred_global
    
                    res[0].append(str(table[j]))
                    res[1].append(str((pred[0,s,2]).detach().cpu().numpy()[0]))
                f.writelines(f'{(s+start+15000+1).detach().cpu().numpy()[0]},{te[te>0].detach().cpu().numpy()[0]},{site[te>0].detach().cpu().numpy()[0]},{i+start+15000+1},{";".join(res[0])},{";".join(res[1])}\n')
    
    return ret_lines


for d in candidate_dataset:
    ana_results=get_gradient_for_all_positions(d)
  

In [4]:
# predict site1 and site2
hg19_ref=Fasta("hgfile/hg19.fa")
hg19seq1=hg19_ref["chr1"][204082241-1-35001//2:204082241-1-35001//2+35001]
hg19seq2=hg19_ref["chr1"][204082262-1-35001//2:204082262-1-35001//2+35001]
hg19seq1=IN_MAP[[SeqTable[_] for _ in hg19seq1.upper()]][:, :4]
hg19seq2=IN_MAP[[SeqTable[_] for _ in hg19seq2.upper()]][:, :4]
@torch.no_grad()
def pred_seq(species, chrom, position, hgref1, hgref2, hg19seq, seqlength=35001):
    position=position-1
    reference=Fasta("hgfile/{}.fa".format(species))
    seq=reference[chrom][position-seqlength//2:position-seqlength//2+seqlength]
    seq=IN_MAP[[SeqTable[_] for _ in seq.upper()]][:, :4]
    reflabel=np.zeros([seqlength-30000, 3])
    reflabel[(seqlength-30000)//2, 1]=hgref1
    reflabel[(seqlength-30000)//2, 2]=hgref2
    reflabel[..., 0]=1-reflabel[..., 1]-reflabel[..., 2]
    
    pred=sum([model.encode(torch.tensor(seq).cuda()[None].float(), torch.tensor(hg19seq).cuda()[None].float(), torch.tensor(reflabel).cuda()[None].float()) for model in Models])/len(Models)
    assert len(pred)==1
    return ",".join(map(str, pred[0][(seqlength-30000)//2].tolist()[1:]))

hgsite1=pred_seq("hg19", "chr1", 204082241, 0, 0, hg19seq1)
hgsite2=pred_seq("hg19", "chr1", 204082262, 0, 0.928571428571429, hg19seq2)
pantro5site1=pred_seq("panTro5", "chr1", 183005948 , 0, 0, hg19seq1)
pantro5site2=pred_seq("panTro5", "chr1", 183005969, 0, 0.928571428571429, hg19seq2)
rhemac10site1=pred_seq("rheMac10", "chr1", 61620711  , 0, 0, hg19seq1)
rhemac10site2=pred_seq("rheMac10", "chr1", 61620729 , 0, 0.928571428571429, hg19seq2)
calJac4site1=pred_seq("calJac4", "chr19", 28524575, 0, 0, hg19seq1)
calJac4site2=pred_seq("calJac4", "chr19", 28524605, 0, 0.928571428571429, hg19seq2)
mm10site1=pred_seq("mm10", "chr1", 133393073, 0, 0, hg19seq1)
mm10site2=pred_seq("mm10", "chr1", 133393044, 0, 0.928571428571429, hg19seq2)
    
rn6site1=pred_seq("rn6", "chr13", 50468362, 0, 0, hg19seq1)
rn6site2=pred_seq("rn6", "chr13", 50468391, 0, 0.928571428571429, hg19seq2)

susscr11site1=pred_seq("susScr11", "chr9", 64768091 , 0, 0, hg19seq1)
susscr11site2=pred_seq("susScr11", "chr9", 64768120, 0, 0.928571428571429, hg19seq2)

bostau9site1=pred_seq("bosTau9", "chr16", 1899711  , 0, 0, hg19seq1)
bostau9site2=pred_seq("bosTau9", "chr16", 1899739, 0, 0.928571428571429, hg19seq2)

with open(os.path.join(save_path, "site1_site2_summary.csv"), "w") as f:
    f.writelines("reference genome,chrom,site,predicted_acceptor_prob,predicted_donor_prob\n")
    f.writelines(f"hg19,chr1,204082241,{hgsite1}\n")
    f.writelines(f"hg19,chr1,204082262,{hgsite2}\n")
    
    f.writelines(f"panTro5,chr1,183005948,{pantro5site1}\n")
    f.writelines(f"panTro5,chr1,183005969,{pantro5site2}\n")
    
    f.writelines(f"rheMac10,chr1,61620711,{rhemac10site1}\n")
    f.writelines(f"rheMac10,chr1,61620729,{rhemac10site2}\n")
    
    f.writelines(f"calJac4,chr19,28524575,{calJac4site1}\n")
    f.writelines(f"calJac4,chr19,28524605,{calJac4site2}\n")
    
    f.writelines(f"mm10,chr1,133393073,{mm10site1}\n")
    f.writelines(f"mm10,chr1,133393044,{mm10site2}\n")
    
    f.writelines(f"rn6,chr13,50468362,{rn6site1}\n")
    f.writelines(f"rn6,chr13,50468391,{rn6site2}\n")
    
    f.writelines(f"susScr11,chr9,64768091,{susscr11site1}\n")
    f.writelines(f"susScr11,chr9,64768120,{susscr11site2}\n")
    
    f.writelines(f"bosTau9,chr16,1899711,{bostau9site1}\n")
    f.writelines(f"bosTau9,chr16,1899739,{bostau9site2}\n")

In [5]:
print((sum([model.encode(torch.tensor(hg19seq2).cuda()[None].float()) for model in Models])/len(Models))[0,2500])
hg19seq2[35001//2+1]=0
hg19seq2[35001//2+1,0]=0
print((sum([model.encode(torch.tensor(hg19seq2).cuda()[None].float()) for model in Models])/len(Models))[0,2500])

tensor([4.6745e-02, 1.3523e-05, 9.5324e-01], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([8.8867e-01, 1.2614e-05, 1.1132e-01], device='cuda:0',
       grad_fn=<SelectBackward0>)
