In [1]:
import numpy as np
import os
from scipy.stats import pearsonr

In [2]:
import time
import tabix
import torch
import selene_sdk
import pyBigWig
from torch import nn
from scipy.special import softmax
from matplotlib import pyplot as plt
from selene_sdk.targets import Target
from selene_sdk.samplers import RandomPositionsSampler
from selene_sdk.samplers.dataloader import SamplerDataLoader
from scipy.stats import spearmanr
# New floating-point tensors default to float32. set_default_dtype replaces the
# deprecated torch.set_default_tensor_type('torch.FloatTensor') call, which emitted
# a deprecation warning; the default device stays CPU either way.
torch.set_default_dtype(torch.float32)


  _C._set_default_tensor_type(t)


# Data

In [3]:
import pandas as pd
import selene_sdk

# Directory holding the reference FASTA, bigWig tracks and blacklist files.
# Override it with the PUFFIN_ROOT environment variable. os.path.join(..., "")
# guarantees a trailing separator because later cells build paths as
# `root + "<filename>"`.
root = os.path.join(os.environ.get("PUFFIN_ROOT", "/work/magroup/4DN/Puffin/"), "")

In [4]:
# Two handles on the same GRCh38 FASTA:
#   genome              - built with blacklist_regions='hg38' (handed to the sampler)
#   noblacklist_genome  - no blacklist (used to encode the validation chromosome)
fasta_path = root + "Homo_sapiens.GRCh38.dna.primary_assembly.fa"

genome = selene_sdk.sequences.Genome(input_path=fasta_path, blacklist_regions='hg38')
noblacklist_genome = selene_sdk.sequences.Genome(input_path=fasta_path)



In [5]:
import pyBigWig
import tabix
from selene_sdk.targets import Target
import numpy as np

class GenomicSignalFeatures(Target):
    """
    Selene `Target` that reads per-base signal from a list of bigWig files.

    File handles (pyBigWig for signal, tabix for optional blacklist files) are
    opened lazily on the first `get_feature_data` call, not in `__init__`.
    Intervals listed in a blacklist are either zeroed or overwritten with a
    scaled copy of another track's signal.
    """
    def __init__(self, input_paths, features, shape, blacklists=None, blacklists_indices=None,
        replacement_indices=None, replacement_scaling_factors=None):
        """
        Constructs a new `GenomicSignalFeatures` object.

        Parameters
        ----------
        input_paths : list of str
            bigWig paths, one per track; row i of the returned matrix is
            read from ``input_paths[i]``.
        features : list of str
            Feature names; used only for ``n_features`` and ``feature_index_dict``.
        shape : tuple of int
            Per-track shape; ``self.shape`` becomes ``(len(input_paths), *shape)``.
        blacklists : list of str, optional
            Paths of tabix-indexed interval files queried by (chrom, start, end).
        blacklists_indices : list, optional
            For each blacklist, the track row index(es) it applies to. When
            None (and no replacement is configured) every blacklist zeros
            all tracks.
        replacement_indices : list, optional
            For each blacklist, the track row index(es) whose signal replaces
            the blacklisted signal. When None, blacklisted signal is zeroed.
            Requires ``blacklists_indices`` and ``replacement_scaling_factors``.
        replacement_scaling_factors : list of float, optional
            For each blacklist, the multiplier applied to the replacement signal.
        """
        self.input_paths = input_paths
        self.initialized = False
        self.blacklists = blacklists
        self.blacklists_indices = blacklists_indices
        self.replacement_indices = replacement_indices
        self.replacement_scaling_factors = replacement_scaling_factors

        self.n_features = len(features)
        self.feature_index_dict = {feat: index for index, feat in enumerate(features)}
        self.shape = (len(input_paths), *shape)

    def get_feature_data(self, chrom, start, end, nan_as_zero=True, feature_indices=None):
        """
        Return the signal of ``chrom`` over columns ``start .. end - 1``.

        Parameters
        ----------
        chrom : str
        start, end : int
            The result has ``end - start`` columns, one per position.
        nan_as_zero : bool
            Replace NaN values with 0 (after blacklist handling).
        feature_indices : sequence of int, optional
            Track indices to return, in that order. None returns every track.

        Returns
        -------
        numpy.ndarray of float32, shape ``(n_selected_tracks, end - start)``
        """
        if not self.initialized:
            self.data = [pyBigWig.open(path) for path in self.input_paths]
            if self.blacklists is not None:
                self.blacklists = [tabix.open(blacklist) for blacklist in self.blacklists]
            self.initialized = True

        n_tracks = len(self.data)
        # Every track is read, even when `feature_indices` selects a subset: the
        # blacklist step addresses rows by absolute track index and may copy
        # signal from a track the caller did not request. Rows are selected last.
        wigmat = np.zeros((n_tracks, end - start), dtype=np.float32)
        for i in range(n_tracks):
            try:
                wigmat[i, :] = self.data[i].values(chrom, start, end, numpy=True)
            except Exception:
                # Report which file/interval failed, then propagate the error.
                print(chrom, start, end, self.input_paths[i], flush=True)
                raise

        if self.blacklists is not None:
            if self.replacement_indices is None:
                if self.blacklists_indices is not None:
                    # Zero each blacklist's intervals only in its own track(s).
                    for blacklist, rows in zip(self.blacklists, self.blacklists_indices):
                        for lo, hi in self._blacklisted_columns(blacklist, chrom, start, end):
                            wigmat[rows, lo:hi] = 0
                else:
                    # No per-blacklist tracks configured: zero the intervals in every track.
                    for blacklist in self.blacklists:
                        for lo, hi in self._blacklisted_columns(blacklist, chrom, start, end):
                            wigmat[:, lo:hi] = 0
            else:
                # Overwrite blacklisted intervals with a scaled copy of the replacement track(s).
                for blacklist, rows, src_rows, scale in zip(
                        self.blacklists, self.blacklists_indices,
                        self.replacement_indices, self.replacement_scaling_factors):
                    for lo, hi in self._blacklisted_columns(blacklist, chrom, start, end):
                        wigmat[rows, lo:hi] = wigmat[src_rows, lo:hi] * scale

        if nan_as_zero:
            wigmat[np.isnan(wigmat)] = 0
        if feature_indices is not None:
            wigmat = wigmat[np.asarray(feature_indices, dtype=int)]
        return wigmat

    @staticmethod
    def _blacklisted_columns(blacklist, chrom, start, end):
        """Yield ``(lo, hi)`` column bounds of blacklist intervals overlapping the query.

        ``lo`` is clipped at 0; ``hi`` may exceed ``end - start``, which numpy
        slicing clamps.
        """
        for _, s, e in blacklist.query(chrom, start, end):
            yield max(int(s) - start, 0), int(e) - start




# Ten stranded signal tracks. Indices 0-4 are the "plus" files and 5-9 the
# "minus" files in mirrored order, so track k pairs with track 9 - k.
# Intervals in the plus blacklist overwrite track 0 (cage_plus) with track 1
# (encodecage_plus) * 0.61357; the minus blacklist does the same for track 9
# (cage_minus) from track 8 (encodecage_minus). The origin of 0.61357 is not
# documented here.
tfeature = GenomicSignalFeatures(
    input_paths=[
        root + "agg.plus.bw.bedgraph.bw",
        root + "agg.encodecage.plus.v2.bedgraph.bw",
        root + "agg.encoderampage.plus.v2.bedgraph.bw",
        root + "agg.plus.grocap.bedgraph.sorted.merged.bw",
        root + "agg.plus.allprocap.bedgraph.sorted.merged.bw",
        root + "agg.minus.allprocap.bedgraph.sorted.merged.bw",
        root + "agg.minus.grocap.bedgraph.sorted.merged.bw",
        root + "agg.encoderampage.minus.v2.bedgraph.bw",
        root + "agg.encodecage.minus.v2.bedgraph.bw",
        root + "agg.minus.bw.bedgraph.bw",
    ],
    features=[
        'cage_plus', 'encodecage_plus', 'encoderampage_plus', 'grocap_plus', 'procap_plus',
        'procap_minus', 'grocap_minus', 'encoderampage_minus', 'encodecage_minus', 'cage_minus',
    ],
    shape=(100000,),
    blacklists=[
        root + "fantom.blacklist8.plus.bed.gz",
        root + "fantom.blacklist8.minus.bed.gz",
    ],
    blacklists_indices=[0, 9],
    replacement_indices=[1, 8],
    replacement_scaling_factors=[0.61357, 0.61357],
)



In [6]:
# Training sampler over `genome` with `tfeature` targets: 100 kb inputs whose
# whole span is the prediction bin. chr10 is held out for validation and
# chr8/chr9 for test. random_shift=0 and random_strand=False; strand flipping
# is done by hand in the training loop instead.
sampler = RandomPositionsSampler(
    reference_sequence=genome,
    target=tfeature,
    features=[''],
    sequence_length=100000,
    center_bin_to_predict=100000,
    position_resolution=1,
    random_shift=0,
    random_strand=False,
    validation_holdout=['chr10'],
    test_holdout=['chr8', 'chr9'],
)


In [7]:
# Training loader: one worker, batches of 16, with a fixed loader seed.
# (A larger setting previously used: num_workers=32, batch_size=32.)
seed = 3
sampler.mode = "train"
train_loader = SamplerDataLoader(sampler, batch_size=16, num_workers=1, seed=seed)

In [8]:
# Sanity check: shapes of one training batch and the first three one-hot
# rows of its first sequence.
for sequence, target in train_loader:
    print(sequence.shape, target.shape)
    print(sequence[0, :3])
    break


torch.Size([16, 100000, 4]) torch.Size([16, 10, 100000])
tensor([[0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.]], dtype=torch.float64)


# Model

In [9]:
# import torch
# import torch.nn as nn
# import torch.optim as optim

# Define the CNN
# class CNN(nn.Module):
#     def __init__(self):
#         super(CNN, self).__init__()
#         self.conv1 = nn.Conv1d(in_channels=4, out_channels=16, kernel_size=3, padding=1)
#         self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
#         self.conv3 = nn.Conv1d(in_channels=32, out_channels=10, kernel_size=3, padding=1)
#         self.relu = nn.ReLU()
        
#     def forward(self, x):
#         x = self.relu(self.conv1(x))
#         x = self.relu(self.conv2(x))
#         x = self.conv3(x)
#         return x

In [10]:

class ConvBlock(nn.Module):
    """Conv1d -> Dropout -> BatchNorm1d -> ReLU on channels-first ``(batch, C, length)`` input.

    By default the convolution is padded by ``(kernel_size - 1) * dilation`` on
    *each* side. With ``stride=1`` the output is therefore
    ``(kernel_size - 1) * dilation`` positions LONGER than the input (e.g. +2 for
    kernel 3, dilation 1), and stacked blocks accumulate this growth. Pass
    ``padding`` explicitly to change it; for example,
    ``(kernel_size - 1) * dilation // 2`` keeps the length unchanged when
    ``(kernel_size - 1) * dilation`` is even. The value is forwarded to
    ``nn.Conv1d`` as-is.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, dropout_rate, padding=None):
        super().__init__()
        if padding is None:
            # Historical default, kept so existing models/checkpoints behave the same.
            padding = (kernel_size - 1) * dilation
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      dilation=dilation,
                      padding=padding,
                      stride=stride),
            nn.Dropout(p=dropout_rate),
            nn.BatchNorm1d(num_features=out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


class SimpleCNN(nn.Module):
    """
    Stack of ConvBlocks followed by a per-sequence or per-position head.

    Input is ``(batch, length, in_channels)``; it is transposed to channels-first
    before the convolutions.

    * ``output_shape`` int: global max over the length axis, Linear, ReLU.
      Returns ``(batch, output_shape)``.
    * ``output_shape`` ``(length_out, n_out)``: AdaptiveMaxPool1d(length_out),
      then a 1x1 Conv1d to ``n_out`` channels. Returns ``(batch, length_out, n_out)``.

    ``kernel_sizes`` / ``dilation_sizes`` default to 3 / 1 for every layer.
    ``input_length`` is stored but not used by ``forward``.

    Raises
    ------
    ValueError
        If ``output_shape`` is neither an int nor a length-2 sequence.
    """
    def __init__(self, in_channels=4, channels=[16, 32, 64], output_shape=2, input_length=450000, dropout_rate=0, kernel_sizes=None, dilation_sizes=None):
        super().__init__()

        n_layers = len(channels)
        kernel_sizes = [3] * n_layers if kernel_sizes is None else kernel_sizes
        dilation_sizes = [1] * n_layers if dilation_sizes is None else dilation_sizes

        self.output_shape = output_shape
        self.input_length = input_length

        layers = []
        prev_channels = in_channels
        for idx, out_ch in enumerate(channels):
            layers.append(ConvBlock(in_channels=prev_channels, out_channels=out_ch,
                                    kernel_size=kernel_sizes[idx], stride=1,
                                    dilation=dilation_sizes[idx], dropout_rate=dropout_rate))
            prev_channels = out_ch
        self.layers = nn.Sequential(*layers)

        if isinstance(output_shape, int):  # ETGP, eQTLP
            self.fout = 'Linear'
            self.fc = nn.Linear(channels[-1], output_shape)
            self.relu = nn.ReLU()
        elif len(output_shape) == 2:  # RSAP, TISP
            self.fout = 'Conv1d'
            # Pools to exactly output_shape[0] positions regardless of conv output length.
            self.adaptive_pool = nn.AdaptiveMaxPool1d(output_shape[0])
            # 1x1 conv: change channel count without changing length.
            self.final_conv = nn.Conv1d(channels[-1], output_shape[1], kernel_size=1)
        else:
            # Fail here rather than with an AttributeError on `self.fout` in forward().
            raise ValueError(
                f"output_shape must be an int or a length-2 sequence, got {output_shape!r}")

    def forward(self, x):
        x = x.transpose(1, 2)  # (batch, length, C) -> (batch, C, length)
        x = self.layers(x)
        if self.fout == 'Linear':
            # Global max over length. amax keeps the batch axis even when batch == 1
            # (a trailing .squeeze() would drop it).
            x = x.amax(dim=2)
            x = self.fc(x)
            x = self.relu(x)
        elif self.fout == 'Conv1d':
            x = self.adaptive_pool(x)
            x = self.final_conv(x)
            x = x.transpose(1, 2)  # back to (batch, length_out, n_out)
        return x

In [11]:
# Per-position model: four conv blocks and a head producing (batch, 100000, 10),
# moved to the GPU and left in training mode.
model = SimpleCNN(
    input_length=100000,
    channels=[16, 64, 256, 1024],
    output_shape=(100000, 10),
).cuda()
model.train()

SimpleCNN(
  (layers): Sequential(
    (0): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(4, 16, kernel_size=(3,), stride=(1,), padding=(2,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (1): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(16, 64, kernel_size=(3,), stride=(1,), padding=(2,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (2): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(64, 256, kernel_size=(3,), stride=(1,), padding=(2,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (3): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(256, 1024, kernel_size=(3,), 

# Train

In [12]:
# Alternative loss: element-wise generalized (Poisson) KL divergence.
def PseudoPoissonKL(lpred, ltarget):
    """
    Element-wise ``t * log((t + eps) / (p + eps)) + p - t`` with ``eps = 1e-10``.

    ``lpred`` (p) and ``ltarget`` (t) are tensors of broadcast-compatible
    shapes. No reduction is applied; the caller averages or weights the result.
    """
    eps = 1e-10
    ratio = (ltarget + eps) / (lpred + eps)
    return ltarget * torch.log(ratio) + lpred - ltarget

In [13]:
# Loss, optimizer and learning-rate scheduler for the model built in the previous cell.

# Loss
criterion = nn.MSELoss()
# Alternative Poisson-style objective (unused):
# weights = torch.ones(10).cuda()
# criterion = PseudoPoissonKL

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
# The training loop calls scheduler.step(loss) with the training loss, which is
# supposed to go DOWN, so plateaus must be detected in 'min' mode. In 'max' mode
# every decrease in loss counts as a bad step, and the LR decays while training
# is still improving. Use 'max' only if the scheduler is stepped with a
# correlation-style metric instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10, threshold=0)


In [14]:
# # Example train function
# num_epochs=20
# for epoch in range(num_epochs):
#     optimizer.zero_grad()

#     x = sequence.permute(0, 2, 1).float().cuda()
#     y = target.float().cuda()
    
#     # forward pass
#     pred = model(x)
#     print(x.shape, pred.shape)
    
#     # Compute loss
#     loss = criterion(pred, y)


    
#     # Backward pass and optimization
#     loss.backward()
    
#     optimizer.step()

    
#     print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
#     break

In [None]:
# Puffin-style training loop: random reverse-complement augmentation, MSE loss,
# periodic logging/checkpointing, and a strand-averaged validation correlation.
CHECKPOINT_PREFIX = '/work/magroup/wenduoc/benchmark/Puffin/baseline_puffin/' + 'CNN_3'
LOG_EVERY = 500            # iterations between loss logging, LR scheduling and checkpoints
VALID_EVERY = 500          # iterations between validation passes
MAX_TRAIN_SAMPLES = 100000 # stop after this many training sequences
VALID_CHROM = "chr10"
# NOTE(review): 114364328 bp is the GRCh38 length of chr13; chr10 is 133,797,422 bp,
# so only the first ~114 Mb of chr10 are validated. Confirm this is intended.
VALID_LEN = 114364328
VALID_WINDOW = 100000      # model input length per validation window
VALID_STRIDE = 50000       # consecutive windows overlap by half
VALID_KEEP = (25000, 75000)  # only the central half of each window's prediction is kept
N_TRACK_PAIRS = 5          # track k is paired with its mirrored strand track -(k + 1)

i = 0
past_losses = []
firstvalid = True
bestcor = 0
n_samples = 0
while True:
    for sequence, target in train_loader:
        # With probability 0.5, flip the sequence along its length and base axes
        # and the target along its track and length axes.
        if torch.rand(1) < 0.5:
            sequence = sequence.flip([1, 2])
            target = target.flip([1, 2])

        n_samples += sequence.shape[0]
        x = sequence.float().cuda()
        y = target.float().cuda()

        optimizer.zero_grad()
        # The model returns (batch, length, tracks); targets are (batch, tracks, length).
        pred = model(x).transpose(1, 2)
        loss = criterion(pred, y)
        loss.backward()
        past_losses.append(loss.detach().cpu().numpy())
        optimizer.step()

        if i % LOG_EVERY == 0:
            print(i)
            print("train loss:" + str(np.mean(past_losses[-LOG_EVERY:])), flush=True)
            scheduler.step(loss)
            print(optimizer.param_groups[0]['lr'])
            torch.save(model, CHECKPOINT_PREFIX + '.checkpoint')
            torch.save(optimizer, CHECKPOINT_PREFIX + '.optimizer')

        if i % VALID_EVERY == 0:
            if firstvalid:
                # Loaded once and reused by every later validation pass.
                validseq = noblacklist_genome.get_encoding_from_coords(VALID_CHROM, 0, VALID_LEN)
                validcage = tfeature.get_feature_data(VALID_CHROM, 0, VALID_LEN)
                firstvalid = False
            model.eval()
            print(validseq.shape, flush=True)
            keep_lo, keep_hi = VALID_KEEP
            # For this VALID_LEN, dropping the last two starts keeps every
            # window a full VALID_WINDOW long.
            window_starts = np.arange(0, VALID_LEN, VALID_STRIDE)[:-2]
            with torch.no_grad():
                # float32 halves memory versus the float64 default. The predictions
                # are float32 already, so the stored values are unchanged.
                validpred = np.zeros((validcage.shape[0], VALID_LEN), dtype=np.float32)
                for ii in window_starts:
                    window = validseq[ii : ii + VALID_WINDOW, :]
                    pred = (
                        model(torch.FloatTensor(window[None, :, :]).cuda())
                        .transpose(1, 2)
                        .cpu()
                        .numpy()
                    )
                    # Reverse-complemented input; the output is flipped back
                    # (tracks and positions) to forward orientation.
                    pred2 = (
                        model(torch.FloatTensor(window[None, ::-1, ::-1].copy()).cuda())
                        .transpose(1, 2)
                        .cpu()
                        .numpy()[:, ::-1, ::-1]
                    )
                    validpred[:, ii + keep_lo : ii + keep_hi] = (
                        pred[0, :, keep_lo:keep_hi] * 0.5 + pred2[0, :, keep_lo:keep_hi] * 0.5
                    )

            # Correlations cover [0, last window start). This includes the first
            # keep_lo positions, which no window fills (they stay 0).
            valid_end = window_starts[-1]
            validcors = [
                np.corrcoef(validpred[k, :valid_end], validcage[k, :valid_end])[0, 1] * 0.5
                + np.corrcoef(validpred[-k - 1, :valid_end], validcage[-k - 1, :valid_end])[0, 1] * 0.5
                for k in range(N_TRACK_PAIRS)
            ]
            validcor, validcor2, validcor3, validcor4, validcor5 = validcors
            print("Cor {0} {1} {2} {3} {4}".format(*validcors))

            model.train()
            validsum = sum(validcors)
            if bestcor < validsum:
                bestcor = validsum
                torch.save(model, CHECKPOINT_PREFIX + '.best.checkpoint')
                torch.save(optimizer, CHECKPOINT_PREFIX + '.best.optimizer')
        i += 1
        del x, y, pred, loss

        if n_samples > MAX_TRAIN_SAMPLES:
            print('Done!')
            break
    break

0
train loss:0.27690426
0.005
(114364328, 4)
Cor 0.0030526697580957438 0.003542320188668639 0.0021053281855192276 0.005134597469163311 0.0019901055379286086
500
train loss:0.058204312
0.005
(114364328, 4)
Cor 0.0001596501467121537 -0.0001546204409545666 6.227556498842301e-05 0.0003124226539105827 0.001117784036814831
1000
train loss:0.00029515632
0.005
(114364328, 4)
Cor -0.003408185860242576 0.004527108244016793 0.0006406482951593019 0.0007513708979533687 0.011119373227336483
1500
train loss:0.0010260381
0.005
(114364328, 4)
Cor 0.0033114915808798795 0.002743869797427067 0.005745940827759937 -0.0043194761383138075 0.0027269302207097276
2000
train loss:0.0012247341
0.005
(114364328, 4)
Cor 0.00033004674037885714 -0.0033282302994240894 0.004765649715602938 -0.002632274495505231 0.013244089003423925
2500
train loss:0.0009184549
0.005
(114364328, 4)
Cor 0.004468767034515042 0.007286819862235727 0.005562013585270048 0.0025491465203971563 0.00892939027384013
3000
train loss:0.00062832783
0.