In [1]:
import torch
import torch.nn as nn
import torch.nn.modules.activation as activation
import math
import sys
sys.path.append('../model')  
from model import ExpActivation, Unsqueeze

### ConvNetDeep

In [2]:
c1 = nn.Conv1d(4, 100, 19)
bn1 = nn.BatchNorm1d(100)
rl1 = activation.ReLU()
mp1 = nn.MaxPool1d(3, 3)

c2 = nn.Conv1d(100, 200, 11)
bn2 = nn.BatchNorm1d(200)
rl2 = activation.ReLU()
mp2 = nn.MaxPool1d(3, 3)

c3 = nn.Conv1d(200, 200, 7)
bn3 = nn.BatchNorm1d(200)
rl3 = activation.ReLU()
mp3 = nn.MaxPool1d(4, 4) 

c7 = nn.Conv1d(200, 400, 4)
bn7 = nn.BatchNorm1d(400)
rl7 = activation.ReLU()
mp7 = nn.MaxPool1d(4, 4) 

# Block 4 : Fully Connected 1 :
d4 = nn.Linear(800, 800)  # 1000 for 200 input size
bn4 = nn.BatchNorm1d(800, 1e-05, 0.1, True)
rl4 = activation.ReLU()
dr4 = nn.Dropout(0.3)

# Block 5 : Fully Connected 2 :
d5 = nn.Linear(800, 800)
bn5 = nn.BatchNorm1d(800, 1e-05, 0.1, True)
rl5 = activation.ReLU()
dr5 = nn.Dropout(0.3)

# Block 6 :4Fully connected 3
num_classes = 1
d6 = nn.Linear(800, num_classes)
# self.sig = activation.Sigmoid()

In [23]:
x = torch.randn(10,4,608)
x = rl1(bn1(c1(x)))
print(f"Layer 1: {x.shape}")
x = mp1(x)
print(f"Layer 1: {x.shape}")


x = rl2(bn2(c2(x)))
print(f"Layer 2: {x.shape}")
x = mp2(x)
print(f"Layer 2: {x.shape}")


x = rl3(bn3(c3(x)))
print(f"Layer 3: {x.shape}")
x = mp3(x)
print(f"Layer 3: {x.shape}")

x = rl7(bn7(c7(x)))
print(f"Layer 3.5: {x.shape}")
x = mp7(x)
print(f"Layer 3.5: {x.shape}")

o = torch.flatten(x, start_dim=1)
print(f"o shape is {o.shape}")
o = rl4(bn4(d4(o)))
print(f"Layer 4: {o.shape}")
o = dr4(o)
print(f"Layer 4: {o.shape}")


o = rl5(bn5(d5(o)))
print(f"Layer 5: {o.shape}")
o = dr5(o)
print(f"Layer 5: {o.shape}")


o = d6(o)
print(f"Finalz: {o.shape}")


Layer 1: torch.Size([10, 100, 590])
Layer 1: torch.Size([10, 100, 196])
Layer 2: torch.Size([10, 200, 186])
Layer 2: torch.Size([10, 200, 62])
Layer 3: torch.Size([10, 200, 56])
Layer 3: torch.Size([10, 200, 14])
Layer 3.5: torch.Size([10, 400, 11])
Layer 3.5: torch.Size([10, 400, 2])
o shape is torch.Size([10, 800])
Layer 4: torch.Size([10, 800])
Layer 4: torch.Size([10, 800])
Layer 5: torch.Size([10, 800])
Layer 5: torch.Size([10, 800])
Finalz: torch.Size([10, 1])


In [13]:
11%2

1

### DanQ

In [2]:
class DanQ(nn.Module):
    """
    PyTorch implementation of DanQ (PMID: 27084946)
    """
    def __init__(self, input_length, num_classes, weight_path=None):
        """
        :param input_length: int, input sequence length
        :param num_classes: int, number of output classes
        :param weight_path: string, path to the file with model weights
        """
        super(DanQ, self).__init__()

        self._options = {
            "input_length": input_length,
            "num_classes": num_classes,
            "weight_path": weight_path
        }

        self.conv1 = nn.Conv1d(4, 320, kernel_size=26)
        self.act1 = nn.ReLU()
        self.maxp1 = nn.MaxPool1d(kernel_size=13, stride=13)

        self.bi_lstm_layer = nn.LSTM(320, 320, num_layers=1,
                                     batch_first=True, bidirectional=True)

        self._in_features_L1 = math.floor((input_length - 25) / 13.) * 640

        self.linear = nn.Sequential(
            nn.Linear(self._in_features_L1, 925),
            nn.ReLU(),
            nn.Linear(925, num_classes),
        )

        if weight_path:
            self.load_state_dict(torch.load(weight_path))

    def forward(self, input):      
        x = self.act1(self.conv1(input))
        x = nn.Dropout(0.2)(self.maxp1(x))
        x = x.transpose(1, 2)
        x, _ = self.bi_lstm_layer(x)
        x = x.contiguous().view(-1, self._in_features_L1)
        x = nn.Dropout(0.5)(x)
        x = self.linear(x)
        return x

In [5]:
# Usage of the DanQ model
input_length = 608
num_classes = 1

# Create an instance of the model
model = DanQ(input_length, num_classes)

# Prepare input tensor (10 samples, 4 channels, input_length of 608)
x = torch.randn(10, 4, input_length)

# Forward pass to get the output
output = model(x)

# Print the output shape
print(output.shape)

torch.Size([10, 1])


### DeepStar

In [5]:
c1 = nn.Conv1d(in_channels=4, out_channels=256, kernel_size=7,
                      padding=int((7 - 1) / 2))
b1 = nn.BatchNorm1d(256)
act1 = nn.ReLU()
max1 = nn.MaxPool1d(2, 2)
c2 = nn.Conv1d(in_channels=256, out_channels=60, kernel_size=3,
            padding=int((3 - 1) / 2))
b2 = nn.BatchNorm1d(60)
act2 = nn.ReLU()
max2 = nn.MaxPool1d(2, 2)
c3 = nn.Conv1d(in_channels=60, out_channels=60, kernel_size=5,
            padding=int((5 - 1) / 2))
b3 = nn.BatchNorm1d(60)
act3 = nn.ReLU()
max3 = nn.MaxPool1d(2, 2)
c4 = nn.Conv1d(in_channels=60, out_channels=120, kernel_size=3,
            padding=int((3 - 1) / 2))
b4 = nn.BatchNorm1d(120)
act4 = nn.ReLU()
max4 = nn.MaxPool1d(2, 2)
flat4 = nn.Flatten()

In [7]:
x = torch.randn(10,4,200)
print(f"{x.shape}")
print("Conv1")
x = c1(x)
print(f"{x.shape}")
x = b1(x)
print(f"{x.shape}")
x = act1(x)
print(f"{x.shape}")
x = max1(x)
print(f"{x.shape}")
print("Conv2")
x = c2(x)
print(f"{x.shape}")
x = b2(x)
print(f"{x.shape}")
x = act2(x)
print(f"{x.shape}")
x = max2(x)
print(f"{x.shape}")
print("Conv3")
x = c3(x)
print(f"{x.shape}")
x = b3(x)
print(f"{x.shape}")
x = act3(x)
print(f"{x.shape}")
x = max3(x)
print(f"{x.shape}")
print("Conv4")
x = c4(x)
print(f"{x.shape}")
x = b4(x)
print(f"{x.shape}")
x = act4(x)
print(f"{x.shape}")
x = max4(x)
print(f"{x.shape}")
x = flat4(x)
print(f"{x.shape}")


torch.Size([10, 4, 200])
Conv1
torch.Size([10, 256, 200])
torch.Size([10, 256, 200])
torch.Size([10, 256, 200])
torch.Size([10, 256, 100])
Conv2
torch.Size([10, 60, 100])
torch.Size([10, 60, 100])
torch.Size([10, 60, 100])
torch.Size([10, 60, 50])
Conv3
torch.Size([10, 60, 50])
torch.Size([10, 60, 50])
torch.Size([10, 60, 50])
torch.Size([10, 60, 25])
Conv4
torch.Size([10, 120, 25])
torch.Size([10, 120, 25])
torch.Size([10, 120, 25])
torch.Size([10, 120, 12])
torch.Size([10, 1440])


In [15]:
convol = nn.Sequential(
            nn.Conv1d(in_channels=4, out_channels=256, kernel_size=7,
                      padding=int((7 - 1) / 2)),  # same padding
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(2, 2),
            nn.Conv1d(in_channels=256, out_channels=60, kernel_size=3,
                      padding=int((3 - 1) / 2)),  # same padding
            nn.BatchNorm1d(60),
            nn.ReLU(),
            nn.MaxPool1d(2, 2),
            nn.Conv1d(in_channels=60, out_channels=60, kernel_size=5,
                      padding=int((5 - 1) / 2)),  # same padding
            nn.BatchNorm1d(60),
            nn.ReLU(),
            nn.MaxPool1d(2, 2),
            nn.Conv1d(in_channels=60, out_channels=120, kernel_size=3,
                      padding=int((3 - 1) / 2)),  # same padding
            nn.BatchNorm1d(120),
            nn.ReLU(),
            nn.MaxPool1d(2, 2),
            nn.Flatten()
        )

linear = nn.Sequential(
            nn.Linear(4560, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 2)
        )

In [16]:
x = torch.randn(10,4,608)
o = convol(x)
print(o.shape)
o = linear(o)
print(o.shape)

torch.Size([10, 4560])
torch.Size([10, 2])


### ExplaiNN

In [3]:
class ExplaiNN(nn.Module):
    """
    The ExplaiNN model (PMID: 37370113)
    """
    def __init__(self, num_cnns, input_length, num_classes, filter_size=19, num_fc=2, pool_size=7, pool_stride=7,
                 weight_path=None):
        """
        :param num_cnns: int, number of independent cnn units
        :param input_length: int, input sequence length
        :param num_classes: int, number of outputs
        :param filter_size: int, size of the unit's filter, default=19
        :param num_fc: int, number of FC layers in the unit, default=2
        :param pool_size: int, size of the unit's maxpooling layer, default=7
        :param pool_stride: int, stride of the unit's maxpooling layer, default=7
        :param weight_path: string, path to the file with model weights
        """
        super(ExplaiNN, self).__init__()

        self._options = {
            "num_cnns": num_cnns,
            "input_length": input_length,
            "num_classes": num_classes,
            "filter_size": filter_size,
            "num_fc": num_fc,
            "pool_size": pool_size,
            "pool_stride": pool_stride,
            "weight_path": weight_path
        }

        if num_fc == 0:
            self.linears = nn.Sequential(
                nn.Conv1d(in_channels=4 * num_cnns, out_channels=1 * num_cnns, kernel_size=filter_size,
                          groups=num_cnns),
                nn.BatchNorm1d(num_cnns),
                ExpActivation(),
                nn.MaxPool1d(input_length - (filter_size-1)),
                nn.Flatten())
        elif num_fc == 1:
            self.linears = nn.Sequential(
                nn.Conv1d(in_channels=4 * num_cnns, out_channels=1 * num_cnns, kernel_size=filter_size,
                          groups=num_cnns),
                nn.BatchNorm1d(num_cnns),
                ExpActivation(),
                nn.MaxPool1d(pool_size, pool_stride),
                nn.Flatten(),
                Unsqueeze(),
                nn.Conv1d(in_channels=int(((input_length - (filter_size-1)) - (pool_size-1)-1)/pool_stride + 1) * num_cnns,
                          out_channels=1 * num_cnns, kernel_size=1,
                          groups=num_cnns),
                nn.BatchNorm1d(1 * num_cnns, 1e-05, 0.1, True),
                nn.ReLU(),
                nn.Flatten())
        elif num_fc == 2:
            self.linears = nn.Sequential(
                nn.Conv1d(in_channels=4 * num_cnns, out_channels=1 * num_cnns, kernel_size=filter_size,
                          groups=num_cnns),
                nn.BatchNorm1d(num_cnns),
                ExpActivation(),
                nn.MaxPool1d(pool_size, pool_stride),
                nn.Flatten(),
                Unsqueeze(),
                nn.Conv1d(in_channels=int(((input_length - (filter_size-1)) - (pool_size-1)-1)/pool_stride + 1) * num_cnns,
                          out_channels=100 * num_cnns, kernel_size=1,
                          groups=num_cnns),
                nn.BatchNorm1d(100 * num_cnns, 1e-05, 0.1, True),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Conv1d(in_channels=100 * num_cnns,
                          out_channels=1 * num_cnns, kernel_size=1,
                          groups=num_cnns),
                nn.BatchNorm1d(1 * num_cnns, 1e-05, 0.1, True),
                nn.ReLU(),
                nn.Flatten())
        else:
            self.linears = nn.Sequential(
                nn.Conv1d(in_channels=4 * num_cnns, out_channels=1 * num_cnns, kernel_size=filter_size,
                          groups=num_cnns),
                nn.BatchNorm1d(num_cnns),
                ExpActivation(),
                nn.MaxPool1d(pool_size, pool_stride),
                nn.Flatten(),
                Unsqueeze(),
                nn.Conv1d(in_channels=int(((input_length - (filter_size-1)) - (pool_size-1)-1)/pool_stride + 1) * num_cnns,
                          out_channels=100 * num_cnns, kernel_size=1,
                          groups=num_cnns),
                nn.BatchNorm1d(100 * num_cnns, 1e-05, 0.1, True),
                nn.ReLU())

            self.linears_bg = nn.ModuleList([nn.Sequential(nn.Dropout(0.3),
                                                           nn.Conv1d(in_channels=100 * num_cnns,
                                                                     out_channels=100 * num_cnns, kernel_size=1,
                                                                     groups=num_cnns),
                                                           nn.BatchNorm1d(100 * num_cnns, 1e-05, 0.1, True),
                                                           nn.ReLU()) for i in range(num_fc - 2)])

            self.last_linear = nn.Sequential(nn.Dropout(0.3),
                                             nn.Conv1d(in_channels=100 * num_cnns, out_channels=1 * num_cnns,
                                                       kernel_size=1,
                                                       groups=num_cnns),
                                             nn.BatchNorm1d(1 * num_cnns, 1e-05, 0.1, True),
                                             nn.ReLU(),
                                             nn.Flatten())

        self.final = nn.Linear(num_cnns, num_classes)

        if weight_path:
            self.load_state_dict(torch.load(weight_path))

    def forward(self, x):
        x = x.repeat(1, self._options["num_cnns"], 1)
        print(f"x.repeat {x.shape}")
        if self._options["num_fc"] <= 2:
            outs = self.linears(x)
        else:
            outs = self.linears(x)
            for i in range(len(self.linears_bg)):
                outs = self.linears_bg[i](outs)
            outs = self.last_linear(outs)
        print(f"outs {outs.shape}")
        out = self.final(outs)
        print(f"out {out.shape}")
        return out


In [4]:
model = ExplaiNN(num_cnns = 5, input_length = 608, num_classes = 2, filter_size = 19, num_fc = 2, pool_size = 7,
                  pool_stride = 7, weight_path = None)

In [28]:
num_cnns = 50
filter_size = 19
input_length = 608
num_classes = 2

conv = nn.Conv1d(in_channels=4*num_cnns, out_channels=1*num_cnns, kernel_size=filter_size, groups = num_cnns)
bc = nn.BatchNorm1d(num_cnns)
act = ExpActivation()
max = nn.MaxPool1d(608-19+1)
flat = nn.Flatten()
sqz = Unsqueeze()

final = nn.Linear(num_cnns, num_classes)
con2 = nn.Conv1d(in_channels=int(((input_length - (filter_size-1)) - (pool_size-1)-1)/pool_stride + 1) * num_cnns,
                          out_channels=1 * num_cnns, kernel_size=1,
                          groups=num_cnns)
bc2 = nn.BatchNorm1d(1 * num_cnns, 1e-05, 0.1, True)


In [4]:
pool_size = 7
pool_stride = 7
num_cnns = 5
input_length = 608
num_classes = 2
filter_size = 19
num_fc = 2

# fc = 0
linears = nn.Sequential(
                nn.Conv1d(in_channels=4 * num_cnns, out_channels=1 * num_cnns, kernel_size=filter_size,
                          groups=num_cnns),
                nn.BatchNorm1d(num_cnns),
                ExpActivation(),
                nn.MaxPool1d(input_length - (filter_size-1)),
                nn.Flatten())
# fc = 1
linears = nn.Sequential(
                nn.Conv1d(in_channels=4 * num_cnns, out_channels=1 * num_cnns, kernel_size=filter_size,
                          groups=num_cnns),
                nn.BatchNorm1d(num_cnns),
                ExpActivation(),
                nn.MaxPool1d(pool_size, pool_stride),
                nn.Flatten(),
                Unsqueeze(),
                nn.Conv1d(in_channels=int(((input_length - (filter_size-1)) - (pool_size-1)-1)/pool_stride + 1) * num_cnns,
                          out_channels=1 * num_cnns, kernel_size=1,
                          groups=num_cnns),
                nn.BatchNorm1d(1 * num_cnns, 1e-05, 0.1, True),
                nn.ReLU(),
                nn.Flatten())

linears = nn.Sequential(
                nn.Conv1d(in_channels=4 * num_cnns, out_channels=1 * num_cnns, kernel_size=filter_size,
                          groups=num_cnns),
                nn.BatchNorm1d(num_cnns),
                ExpActivation(),
                nn.MaxPool1d(pool_size, pool_stride),
                nn.Flatten(),
                Unsqueeze(),
                nn.Conv1d(in_channels=int(((input_length - (filter_size-1)) - (pool_size-1)-1)/pool_stride + 1) * num_cnns,
                          out_channels=100 * num_cnns, kernel_size=1,
                          groups=num_cnns),
                nn.BatchNorm1d(100 * num_cnns, 1e-05, 0.1, True),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Conv1d(in_channels=100 * num_cnns,
                          out_channels=1 * num_cnns, kernel_size=1,
                          groups=num_cnns),
                nn.BatchNorm1d(1 * num_cnns, 1e-05, 0.1, True),
                nn.ReLU(),
                nn.Flatten())

In [10]:
#### Test when fc == 2
conv = nn.Conv1d(in_channels=4 * num_cnns, out_channels=1 * num_cnns, kernel_size=filter_size,
                          groups=num_cnns)
bc = nn.BatchNorm1d(num_cnns)
act = ExpActivation()
max = nn.MaxPool1d(pool_size, pool_stride)
flat = nn.Flatten()
sqz = Unsqueeze()
con2 = nn.Conv1d(in_channels=int(((input_length - (filter_size-1)) - (pool_size-1)-1)/pool_stride + 1) * num_cnns,
                          out_channels=100 * num_cnns, kernel_size=1,
                          groups=num_cnns)
bc2 = nn.BatchNorm1d(100 * num_cnns, 1e-05, 0.1, True)
relu2 = nn.ReLU()
dp2 = nn.Dropout(0.3)
conv3 = nn.Conv1d(in_channels=100 * num_cnns,
                          out_channels=1 * num_cnns, kernel_size=1,
                          groups=num_cnns)
bn3 = nn.BatchNorm1d(1 * num_cnns, 1e-05, 0.1, True)
relu3 = nn.ReLU()
flat3 = nn.Flatten()
final = nn.Linear(num_cnns, num_classes)


In [11]:
#### Test when fc == 2
x = torch.randn(10,4,608)
print(f"input x shape {x.shape}")
x = x.repeat(1, num_cnns, 1)
print(f"Add num_cnns repeat {x.shape}")
x = conv(x)
print(x.shape)
x = bc(x)
print(f"Batch Norm {x.shape}")
x = act(x)
print(x.shape)
x = max(x)
print(f"max pooling {x.shape}")
x = flat(x)
print(x.shape)
x = sqz(x)
print(f"squeez {x.shape}")
x = con2(x)
print(x.shape)
x = bc2(x)
print(x.shape)
x = relu2(x)
print(f"relu2 {x.shape}")
x = dp2(x)
print(f"dp2 {x.shape}")
x = conv3(x)
print(f"conv3 {x.shape}")
x = bn3(x)
print(f"bn3 {x.shape}")
x = relu3(x)
print(f"relu3 {x.shape}")
x = flat3(x)
print(f"flat3 {x.shape}")
x = final(x)
print(f"final {x.shape}")


input x shape torch.Size([10, 4, 608])
Add num_cnns repeat torch.Size([10, 20, 608])
torch.Size([10, 5, 590])
Batch Norm torch.Size([10, 5, 590])
torch.Size([10, 5, 590])
max pooling torch.Size([10, 5, 84])
torch.Size([10, 420])
squeez torch.Size([10, 420, 1])
torch.Size([10, 500, 1])
torch.Size([10, 500, 1])
relu2 torch.Size([10, 500, 1])
dp2 torch.Size([10, 500, 1])
conv3 torch.Size([10, 5, 1])
bn3 torch.Size([10, 5, 1])
relu3 torch.Size([10, 5, 1])
flat3 torch.Size([10, 5])
final torch.Size([10, 2])


In [32]:
x = torch.randn(10,4,608)
print(x.shape)
x = x.repeat(1, num_cnns, 1)
print(x.shape)
x = linears(x)
print(x.shape)

x = final(x)
print(x.shape)

torch.Size([10, 4, 608])


NameError: name 'num_cnns' is not defined

In [30]:
num_motifs = 256
x1 = torch.randn(10, 4, 608)
#for i in range (num_motifs):
conv1 = nn.Conv1d(in_channels=4, out_channels=1, kernel_size=19, padding = 'same')
bc1 = nn.BatchNorm1d(1)
act1 = ExpActivation()
max1 = nn.MaxPool1d(7, 7)
flat1 = nn.Flatten()

ln2 = nn.Linear(86, 20)
bc2 = nn.BatchNorm1d(20)
act2 = ExpActivation()
drop2 = nn.Dropout(p=0.3)

ln3 = nn.Linear(20, 1)
bc3 = nn.BatchNorm1d(1)
act3 = ExpActivation()
flat3 = nn.Flatten()

encoder = []
for i in range(256):
    x = conv1(x1)
    print(x.shape)
    x = bc1(x)
    print(f"Batch Norm {x.shape}")
    x = act1(x)
    print(x.shape)
    x = max1(x)
    print(f"max pooling {x.shape}")
    x = flat1(x)
    print(x.shape)


    print("Linear Layer 2")
    x = ln2(x)
    print(x.shape)
    x = bc2(x)
    x = act1(x)
    print(f"Batch Norm 2 {x.shape}")
    x = drop2(x)
    print(x.shape)

    print("Linear Layer 3")
    x = ln3(x)
    print(x.shape)
    x = bc3(x)
    print(f"Batch Norm 2 {x.shape}")
    x = act1(x)
    x = flat3(x)
    print(x.shape)
    encoder.append(x)

results = torch.cat(encoder, dim=-1)





torch.Size([10, 1, 608])
Batch Norm torch.Size([10, 1, 608])
torch.Size([10, 1, 608])
max pooling torch.Size([10, 1, 86])
torch.Size([10, 86])
Linear Layer 2
torch.Size([10, 20])
Batch Norm 2 torch.Size([10, 20])
torch.Size([10, 20])
Linear Layer 3
torch.Size([10, 1])
Batch Norm 2 torch.Size([10, 1])
torch.Size([10, 1])
torch.Size([10, 1, 608])
Batch Norm torch.Size([10, 1, 608])
torch.Size([10, 1, 608])
max pooling torch.Size([10, 1, 86])
torch.Size([10, 86])
Linear Layer 2
torch.Size([10, 20])
Batch Norm 2 torch.Size([10, 20])
torch.Size([10, 20])
Linear Layer 3
torch.Size([10, 1])
Batch Norm 2 torch.Size([10, 1])
torch.Size([10, 1])
torch.Size([10, 1, 608])
Batch Norm torch.Size([10, 1, 608])
torch.Size([10, 1, 608])
max pooling torch.Size([10, 1, 86])
torch.Size([10, 86])
Linear Layer 2
torch.Size([10, 20])
Batch Norm 2 torch.Size([10, 20])
torch.Size([10, 20])
Linear Layer 3
torch.Size([10, 1])
Batch Norm 2 torch.Size([10, 1])
torch.Size([10, 1])
torch.Size([10, 1, 608])
Batch Nor

In [19]:
class ExplaiNN2(nn.Module):
    """
    Create num_cnns CNN. Each CNN 
    [B, 4, 608] -> Conv1d(kernal=19) -> [B, 1, 608] -> BN -> ExpAct -> MaxPool(7,7,padding = 'same')
    -> [B, 1, 86] -> [B, 86] -> Linear(86, fc_filter = 20) -> [B, 20] -> BN -> ExpAct -> Dropout ->
    Linear(20, fc_filter = 1) -> [B,1] -> BN -> ExpAct -> flat -> [B, 1]

    """
    def __init__(self, num_cnns, input_length, num_classes, 
                 filter_size = 19, num_fc=2, pool_size=7, pool_stride=7, 
                 fc_filter1 = 20, fc_filter2 = 1, drop_out = 0.3, weight_path = None):
        super(ExplaiNN2, self).__init__()
        self._options = {
            "num_cnns": num_cnns,
            "input_length": input_length,
            "num_classes": num_classes,
            "filter_size": filter_size,
            "num_fc": num_fc,
            "pool_size": pool_size,
            "pool_stride": pool_stride,
            "weight_path": weight_path
        }
        self.linears = nn.Sequential(
            # Convolution Layer
            nn.Conv1d(in_channels=4, out_channels=1, kernel_size=filter_size, padding = 'same'),
            nn.BatchNorm1d(1),
            ExpActivation(),
            nn.MaxPool1d(pool_size, pool_stride), # pool_size=7, pool_stride=7
            nn.Flatten(),
            # Linear Layer 1
            nn.Linear(((input_length-pool_size)//pool_stride) + 1, fc_filter1),
            nn.BatchNorm1d(fc_filter1),
            ExpActivation(),
            nn.Dropout(p=drop_out),
            # Linear Layer 2
            nn.Linear(fc_filter1, fc_filter2),
            nn.BatchNorm1d(fc_filter2),
            ExpActivation(),
            nn.Flatten()
        )
        self.final = nn.Linear(num_cnns*fc_filter2, num_classes)

        if weight_path:
            self.load_state_dict(torch.load(weight_path))
        
    def forward(self, x):
        encoder = []
        for i in range(self._options['num_cnns']):
            xnn = self.linears(x)
            encoder.append(xnn)
        encoder = torch.cat(encoder, dim=-1)
        results = self.final(encoder)
        return x, results



In [20]:
x = torch.randn(10, 4, 608)
model = ExplaiNN2(num_cnns = 256, input_length = 608, num_classes = 1, 
                 filter_size = 19, num_fc=2, pool_size=7, pool_stride=7, 
                 fc_filter1 = 20, fc_filter2 = 1, drop_out = 0.3, weight_path = None)
input, output = model(x)
print(output.shape)

torch.Size([10, 1])


In [4]:
class ExplaiNN3(nn.Module):
    """
    [B,4,608] 
    -> replicate => [B, 4 * n_cnn, 608] 
    -> Conv1d(4 * n_cnn, n_cnn, kernel = 19) => [B, n_cnn, 590] 
    -> ExpAct => [B, n_cnn, 590] 
    -> MaxPool (7,7) => [B, n_cnn, 84] 
    -> flat => [B, n_cnn * 84] 
    -> Unsqueeze => [B, n_cnn * 84, 1] 
    -> Conv1d (n_cnn * 84, n_cnn * 100, kernel = 1) => [B, n_cnn * 100, 1] 
    -> BatchNorm -> ReLu -> Drop_out 0.3 => [B, n_cnn * 100, 1] 
    -> Conv1d (n_cnn * 100, n_cnn * 1, kernel = 1) => [B, n_cnn, 1]
    -> BatchNorm -> ReLu => [B, n_cnn, 1]
    -> flat => [B, n_cnn] 
    -> Linear (n_cnn, num_classes) => [B, num_classes]
    """
    def __init__(self, num_cnns, input_length, num_classes, 
                 filter_size = 19, num_fc=2, pool_size=7, pool_stride=7, 
                 drop_out = 0.3, weight_path = None):
        super(ExplaiNN3, self).__init__()
        self._options = {
            "num_cnns": num_cnns,
            "input_length": input_length,
            "num_classes": num_classes,
            "filter_size": filter_size,
            "num_fc": num_fc,
            "pool_size": pool_size,
            "pool_stride": pool_stride,
            "weight_path": weight_path
        }
        self.linears = nn.Sequential(
            nn.Conv1d(in_channels=4 * num_cnns, out_channels=1 * num_cnns, kernel_size=filter_size,
                        groups=num_cnns),
            nn.BatchNorm1d(num_cnns),
            ExpActivation(),
            nn.MaxPool1d(pool_size, pool_stride),
            nn.Flatten(),
            Unsqueeze(),
            nn.Conv1d(in_channels=int(((input_length - (filter_size-1)) - (pool_size-1)-1)/pool_stride + 1) * num_cnns,
                        out_channels=100 * num_cnns, kernel_size=1,
                        groups=num_cnns),
            nn.BatchNorm1d(100 * num_cnns, 1e-05, 0.1, True),
            nn.ReLU(),
            nn.Dropout(drop_out),
            nn.Conv1d(in_channels=100 * num_cnns,
                        out_channels=1 * num_cnns, kernel_size=1,
                        groups=num_cnns),
            nn.BatchNorm1d(1 * num_cnns, 1e-05, 0.1, True),
            nn.ReLU(),
            nn.Flatten()
            )
        self.final = nn.Linear(num_cnns, num_classes)

        if weight_path:
            self.load_state_dict(torch.load(weight_path))
        
    def forward(self, x):
        x = x.repeat(1, self._options["num_cnns"], 1)
        print(x.shape)
        outs = self.linears(x)
        print(outs.shape)
        results = self.final(outs)
        print(results.shape)
        return results

In [5]:
x = torch.randn(10, 4, 608)
model = ExplaiNN3(num_cnns = 256, input_length = 608, num_classes = 2, 
                 filter_size = 19, num_fc=2, pool_size=7, pool_stride=7, 
                 drop_out = 0.3, weight_path = None)

In [6]:
output = model(x)

torch.Size([10, 1024, 608])
torch.Size([10, 256])
torch.Size([10, 2])


In [7]:
x = torch.rand(322, 4, 608)
x = x.repeat(1, 100, 1)
model = ExplaiNN3(num_cnns = 100, input_length = 608, num_classes = 1, 
                 filter_size = 19, num_fc=2, pool_size=7, pool_stride=7, 
                 drop_out = 0.3, weight_path = None)

In [9]:
act = model.linears[:3]
b = act(x)
b.shape

torch.Size([322, 100, 590])

In [17]:
import numpy as np
activations = torch.randn(10, 100, 590)
activations = np.array(activations)
print(activations.shape)
activation_threshold = 0.5 * np.amax(activations, axis=(0, 2))
print(activation_threshold.shape)
n_filters = activations.shape[1]
pwm = np.full((n_filters, 4, 19), .25)
n_samples = activations.shape[0]
activation_indices = []
for i in range(n_filters):
    act_seqs_list = []
    for j in range(n_samples):
        indices = np.where(activations[j,i,:] > activation_threshold[i])
        for start in indices[0]:
            activation_indices.append(start)
            end = start + filter_size
            act_seqs_list.append(sequences[j, :, start:end])

(10, 100, 590)
(100,)
(array([ 27, 100, 122, 123, 139, 200, 204, 212, 226, 238, 277, 281, 334,
       335, 357, 368, 447, 451, 463, 488, 501, 568, 587]),)
(array([  0,  17,  38,  41,  52,  70, 122, 138, 140, 186, 277, 323, 332,
       339, 348, 358, 396, 444, 447, 476, 501, 515, 523, 549]),)
(array([ 65, 104, 194, 232, 315, 399, 420, 423, 435, 482, 539, 572]),)
(array([ 28,  55,  67,  91,  95, 109, 126, 177, 184, 191, 200, 257, 264,
       283, 293, 359, 395, 407, 409, 433, 439, 444, 456, 513, 521, 523,
       552, 584]),)
(array([  9,  40, 128, 196, 247, 265, 268, 325, 371, 424, 466, 483, 496,
       509, 519, 535, 549, 552]),)
(array([126, 174, 194, 238, 265, 325, 357, 358, 419, 428, 442, 447, 450,
       454, 478, 506, 563, 581]),)
(array([ 30,  37,  88, 200, 209, 221, 225, 291, 299, 311, 319, 408, 531,
       577]),)
(array([ 53, 116, 125, 187, 198, 204, 209, 298, 301, 349, 376, 406, 430,
       433, 454, 456, 475, 494, 502, 529, 546, 549]),)
(array([ 24,  26,  50,  90, 104, 106, 1

In [None]:
a = np.full((100, 4, 19), .25).shape
type(a)

tuple