In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [12]:

class ConvBlock(nn.Module):
    """Conv1d -> Dropout -> BatchNorm1d -> ReLU building block.

    Operates on channel-first input of shape (batch, in_channels, length), as
    required by ``nn.Conv1d``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (also the BatchNorm feature count).
        kernel_size: Convolution kernel width.
        stride: Convolution stride.
        dilation: Convolution dilation.
        dropout_rate: Dropout probability applied to the convolution output.
        padding: Zero-padding added to *both* ends of the sequence. ``None`` (the
            default) keeps the historical value ``(kernel_size - 1) * dilation``.
            With stride 1 that makes the output ``(kernel_size - 1) * dilation``
            positions LONGER than the input. Pass
            ``(kernel_size - 1) * dilation // 2`` (odd kernel) to preserve the length.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, dropout_rate, padding=None):
        super().__init__()
        if padding is None:
            # NOTE(review): full (k-1)*d padding on both sides is what causal TCNs use,
            # but without trimming the right end it is neither causal nor
            # length-preserving; it is kept as the default for backward compatibility.
            padding = (kernel_size - 1) * dilation
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      dilation=dilation,
                      padding=padding,
                      stride=stride),
            # NOTE(review): dropout runs before BatchNorm, so during training BN
            # collects its running statistics on dropped-out activations.
            nn.Dropout(p=dropout_rate),
            nn.BatchNorm1d(num_features=out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, in_channels, length)."""
        return self.layers(x)


class SimpleCNN(nn.Module):
    """Stack of ConvBlocks followed by a task-specific output head.

    Input is (batch, length, in_channels); it is transposed to channel-first
    before the convolutions.

    The head is chosen by ``output_shape``:

    * ``int`` (ETGP, eQTLP): global max-pool over length, ``nn.Linear`` to
      ``output_shape`` units, then ReLU. Output: (batch, output_shape).
    * length-2 sequence ``(out_length, out_channels)`` (RSAP, TISP):
      ``AdaptiveMaxPool1d(out_length)`` then a 1x1 ``Conv1d`` to
      ``out_channels``. Output: (batch, out_length, out_channels).

    Args:
        in_channels: Number of input features per position.
        channels: Output channel count of each ConvBlock; one block per entry.
        output_shape: ``int`` or length-2 sequence, see above.
        input_length: Stored as ``self.input_length``; not used by the layers.
        dropout_rate: Dropout probability passed to every ConvBlock.
        kernel_sizes: Per-block kernel sizes (default 3 for every block).
            Entries beyond ``len(channels)`` are ignored.
        dilation_sizes: Per-block dilations (default 1 for every block).
            Entries beyond ``len(channels)`` are ignored.

    Raises:
        ValueError: if ``channels`` is empty, if ``kernel_sizes`` or
            ``dilation_sizes`` has fewer entries than ``channels``, or if
            ``output_shape`` is neither an int nor a length-2 sequence.
    """

    def __init__(self, in_channels=4, channels=(16, 32, 64), output_shape=2, input_length=450000, dropout_rate=0, kernel_sizes=None, dilation_sizes=None):
        super().__init__()

        n_blocks = len(channels)
        if n_blocks == 0:
            raise ValueError("channels must contain at least one entry")
        kernel_sizes = [3] * n_blocks if kernel_sizes is None else list(kernel_sizes)
        dilation_sizes = [1] * n_blocks if dilation_sizes is None else list(dilation_sizes)
        if len(kernel_sizes) < n_blocks or len(dilation_sizes) < n_blocks:
            raise ValueError(
                f"kernel_sizes ({len(kernel_sizes)}) and dilation_sizes ({len(dilation_sizes)}) "
                f"need at least one entry per block ({n_blocks})"
            )

        self.output_shape = output_shape
        self.input_length = input_length

        # Block i maps block_inputs[i] -> channels[i] channels.
        block_inputs = [in_channels] + list(channels[:-1])
        self.layers = nn.Sequential(*[
            ConvBlock(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=1,
                      dilation=d, dropout_rate=dropout_rate)
            for c_in, c_out, k, d in zip(block_inputs, channels, kernel_sizes, dilation_sizes)
        ])

        if isinstance(output_shape, int):  # ETGP, eQTLP
            self.fout = 'Linear'
            self.fc = nn.Linear(channels[-1], output_shape)
            # NOTE(review): this ReLU forces every prediction to be >= 0; if the head
            # produces classification logits, consider removing it.
            self.relu = nn.ReLU()
        elif hasattr(output_shape, "__len__") and len(output_shape) == 2:  # RSAP, TISP
            self.fout = 'Conv1d'
            # Adaptive pooling fixes the output length to exactly output_shape[0].
            self.adaptive_pool = nn.AdaptiveMaxPool1d(output_shape[0])
            # 1x1 conv changes the channel count without changing the length.
            self.final_conv = nn.Conv1d(channels[-1], output_shape[1], kernel_size=1)
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(
                f"output_shape must be an int or a length-2 sequence, got {output_shape!r}"
            )

    def forward(self, x):
        """Map (batch, length, in_channels) input to the head-specific output."""
        x = x.transpose(1, 2)  # -> (batch, in_channels, length) for Conv1d
        x = self.layers(x)
        if self.fout == 'Linear':
            # Global max over length. squeeze(-1), not a bare squeeze(), so that a
            # batch of size 1 or a single-channel final block keeps its dimensions.
            x = F.max_pool1d(x, x.size(2)).squeeze(-1)
            x = self.fc(x)
            x = self.relu(x)
        elif self.fout == 'Conv1d':
            x = self.adaptive_pool(x)
            x = self.final_conv(x)
            x = x.transpose(1, 2)  # -> (batch, out_length, out_channels)
        return x

In [9]:
# eQTL & enhancer-target gene
# Shape check only: eval() + no_grad() avoid keeping the 450k-long activations for a
# backward pass that never happens, and stop BatchNorm updating its running stats.
seq_len = 450000  # must match input_length so the dummy input mirrors real data
model = SimpleCNN(channels = [128,64,32],input_length=seq_len)
print(model)
x = torch.zeros((2, seq_len, 4))
print(x.shape)
model.eval()
with torch.no_grad():
    out = model(x)
print(out.shape)
assert out.shape == (2, 2)  # (batch, output_shape) with the default output_shape=2

SimpleCNN(
  (layers): Sequential(
    (0): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(4, 128, kernel_size=(3,), stride=(1,), padding=(2,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (1): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(128, 64, kernel_size=(3,), stride=(1,), padding=(2,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (2): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(64, 32, kernel_size=(3,), stride=(1,), padding=(2,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
  )
  (fc): Linear(in_features=32, out_features=2, bias=True)
  (relu): ReLU()
)
torch.Size([2, 

In [13]:
# Sequence Activity Prediction
# Shape check only: eval() + no_grad() avoid keeping multi-GB activations (the final
# 1024-channel feature map alone is ~1.6 GB in float32) for a backward pass that never
# happens, and stop BatchNorm updating its running stats.
seq_len = 196608  # must match input_length so the dummy input mirrors real data
model = SimpleCNN(channels = [16,64,256,1024], output_shape=(896,5313), input_length=seq_len, kernel_sizes = [25,15,15,15])
print(model)
x = torch.zeros((2, seq_len, 4))
print(x.shape)
model.eval()
with torch.no_grad():
    out = model(x)
print(out.shape)
assert out.shape == (2, 896, 5313)  # (batch, out_length, out_channels)

SimpleCNN(
  (layers): Sequential(
    (0): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(4, 16, kernel_size=(25,), stride=(1,), padding=(24,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (1): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(16, 64, kernel_size=(15,), stride=(1,), padding=(14,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (2): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(64, 256, kernel_size=(15,), stride=(1,), padding=(14,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (3): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(256, 1024, kernel_size=

In [11]:
# TISP
# Shape check only: eval() + no_grad() avoid keeping activations for a backward pass
# that never happens, and stop BatchNorm updating its running stats.
seq_len = 100000  # must match input_length so the dummy input mirrors real data
model = SimpleCNN(channels = [16,32,64], output_shape=(seq_len,10), input_length=seq_len)
print(model)
x = torch.zeros((2, seq_len, 4))
print(x.shape)
model.eval()
with torch.no_grad():
    out = model(x)
print(out.shape)
assert out.shape == (2, seq_len, 10)  # (batch, out_length, out_channels)

SimpleCNN(
  (layers): Sequential(
    (0): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(4, 16, kernel_size=(3,), stride=(1,), padding=(2,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (1): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=(2,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
    (2): ConvBlock(
      (layers): Sequential(
        (0): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(2,))
        (1): Dropout(p=0, inplace=False)
        (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU()
      )
    )
  )
  (adaptive_pool): AdaptiveMaxPool1d(output_size=100000)
  (final_conv): Conv1d(64, 10, kernel_