In [260]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
import torch
import torch.nn as nn
import torchvision
from torch.linalg import norm as tnorm
from collections import OrderedDict as od

In [318]:
class ResBlock(nn.Module):
    """Pre-activation 1-D residual block with an identity shortcut.

    forward(): [bn1 -> relu1 -> drop1] -> conv1 -> bn2 -> relu2 -> drop2 -> conv2,
    then ``out += x``. The bracketed part is skipped when ``is_first_block`` is True;
    BN / Dropout layers are skipped when ``use_bn`` / ``use_drop`` are False.

    Args:
        in_channels: channels of the input tensor (``bn1`` normalises the input, so
            it is sized to this).
        out_channels: channels produced by both convolutions.
        kernel_size: kernel size of both convolutions (``padding="same"``).
        stride: stride of ``conv1``. PyTorch rejects ``padding="same"`` for strided
            convolutions, so only ``stride=1`` can actually be constructed.
        groups: grouped-convolution factor of both convolutions.
        use_bn: apply the BatchNorm layers.
        use_drop: apply the Dropout layers.
        is_first_block: skip the leading BN/ReLU/Dropout on the input.
        drop_rate: drop probability of both Dropout layers (default 0.1, the value
            that used to be hard-coded).

    The shortcut has no projection: the input must have either ``out_channels``
    channels or a single channel, which broadcasts across all output channels
    (e.g. ``ResBlock(1, 1024, 3, is_first_block=True)``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride =  1, groups = 1,
                 use_bn = True, use_drop = True, is_first_block = False, drop_rate = 0.1):

        super(ResBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.use_bn = use_bn
        self.use_drop = use_drop
        self.is_first_block = is_first_block
        self.drop_rate = drop_rate

        # first stage: conv1, preceded in forward() by bn1/relu1/drop1 on the raw input
        self.conv1 = torch.nn.Conv1d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = self.kernel_size,
            padding = "same",
            stride = self.stride,
            groups = self.groups,
            bias = False)

        # bn1 is applied to the block input (before conv1), so it must see in_channels;
        # sizing it with out_channels crashed any non-first block with in != out.
        self.bn1 = nn.BatchNorm1d(self.in_channels)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout(self.drop_rate)
        # NOTE(review): not used by forward()
        self.max_pool = nn.MaxPool1d(self.kernel_size)

        # second stage: bn2 -> relu2 -> drop2 -> conv2 on the conv1 output
        self.bn2 = nn.BatchNorm1d(self.out_channels)
        self.relu2 = nn.ReLU()
        self.drop2 = nn.Dropout(self.drop_rate)
        self.conv2 = torch.nn.Conv1d(
            in_channels = self.out_channels,
            out_channels = self.out_channels,
            kernel_size = self.kernel_size,
            padding = "same",
            stride = 1,
            groups = self.groups,
            bias = False)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, in_channels, length)."""
        residual = x
        out = x

        # pre-activation of the input, omitted for the first block
        if not self.is_first_block:
            if self.use_bn:
                out = self.bn1(out)
            out = self.relu1(out)
            if self.use_drop:
                out = self.drop1(out)

        out = self.conv1(out)
        # the second conv
        if self.use_bn:
            out = self.bn2(out)
        out = self.relu2(out)
        if self.use_drop:
            out = self.drop2(out)
        out = self.conv2(out)

        # identity shortcut; a 1-channel residual broadcasts over out_channels
        out += residual

        return out

class Bottleneck(nn.Module):
    """1-D bottleneck residual block with a squeeze-and-excitation (SE) gate.

    Main path: 1x1 conv (inplanes -> planes) -> 3-tap conv with ``stride`` ->
    1x1 conv (planes -> 4 * planes), each followed by BatchNorm, with ReLU after
    the first two. The SE gate average-pools that result over time, squeezes it to
    ``planes // 4`` channels and back, and scales the main path by its sigmoid.
    The shortcut is ``x``, or ``downsample(x)`` when a downsample module is given;
    the sum goes through a final ReLU.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        wide = planes * 4          # output width, == planes * expansion
        squeezed = planes // 4     # SE bottleneck width

        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(wide)
        self.relu = nn.ReLU(inplace=True)

        # squeeze-and-excitation gate
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.conv_down = nn.Conv1d(wide, squeezed, kernel_size=1, bias=False)
        self.conv_up = nn.Conv1d(squeezed, wide, kernel_size=1, bias=False)
        self.sig = nn.Sigmoid()

        # optional shortcut projection supplied by the caller
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(se_gate * main_path(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)

        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))

        gate = self.sig(self.conv_up(self.relu(self.conv_down(self.global_pool(h)))))
        return self.relu(gate * h + shortcut)


class SEResNet(nn.Module):
    """1-D SE-ResNet backbone mapping a single-channel signal to ``num_classes`` values.

    Stem: two stride-2, 7-tap convolutions (1 -> 64 -> 32 channels), each followed
    by BatchNorm, ReLU and a stride-2 max-pool. Then four stages from ``_make_layer``
    (64/128/256/512 planes, strides 1/2/2/2), a global average over the remaining
    time axis and a linear layer. ``fc`` expects 512 * block.expansion == 2048
    features, so ``block`` must have ``expansion == 4`` (as ``Bottleneck`` does).

    Returns a tensor of shape (batch, 1, num_classes).
    """

    def __init__(self, block, layers, num_classes = 100):
        super(SEResNet, self).__init__()
        self.inplanes = 32  # channels leaving the stem (conv2 / bn2)
        self.conv1 = nn.Conv1d(1, 64, kernel_size = 7, stride=2,
                               padding=3, bias=False)
        self.conv2 = nn.Conv1d(64, 32, kernel_size = 7, stride = 2,
                               padding = 3, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(32)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Average over the whole remaining time axis. The former nn.AvgPool1d(7)
        # (copied from 2-D ImageNet ResNets) ignored final positions past the 7th and
        # failed when the final length was below 7 or produced >1 window (fc mismatch).
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(2048, num_classes)

        # He-style normal init for every 1-D convolution; BatchNorm starts as identity.
        # (The former isinstance checks for nn.Conv2d / nn.BatchNorm2d matched nothing.)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; only the first may stride/downsample.

        A 1x1 conv + BatchNorm shortcut projection is added when the stride or the
        channel count changes. Updates ``self.inplanes`` to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes, planes * block.expansion,
                          kernel_size = 1, stride = stride, bias = False),
                nn.BatchNorm1d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` (batch, 1, length) to (batch, 1, num_classes)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)          # (batch, 2048, 1)
        x = x.view(x.size(0), -1)    # (batch, 2048)
        x = self.fc(x)

        # add a singleton channel axis so the result can feed a Conv1d(1, ...)
        return x.view(x.shape[0], 1, x.shape[-1])

class SEResCnn(nn.Module):
    """Single-signal branch: SE-ResNet embedding followed by four down-sampling stages.

    ``se_layer`` (an SEResNet with ``num_classes=output_din``) turns the
    single-channel input into a (batch, 1, output_din) tensor. Each stage then
    applies ResBlock -> stride-2 conv -> BatchNorm -> ReLU -> max-pool (kernel 3,
    stride 1, padding 1, which keeps the length). Stage widths are
    ``intermedial_layers`` = 1024, 512, 256, 128 channels.

    Returns:
        Tuple ``(x1, x2, x3, x4)`` of the four stage outputs, first stage first.
        ``input_din`` is stored but not used.
    """

    def __init__(self, block, layers, input_din, output_din):

        super(SEResCnn, self).__init__()
        self.block = block
        self.layers = layers
        self.input_din = input_din
        self.output_din = output_din
        self.intermedial_layers = [1024, 512, 256, 128]
        widths = self.intermedial_layers

        def down_conv(c_in, c_out):
            # 3-tap, stride-2 convolution: roughly halves the sequence length
            return nn.Conv1d(c_in, c_out, kernel_size = 3, stride = 2, padding = 1, bias = False)

        # creation order is kept as-is so random weight init draws are unchanged
        self.conv1 = down_conv(widths[0], widths[0])
        self.conv2 = down_conv(widths[0], widths[1])
        self.conv3 = down_conv(widths[1], widths[2])
        self.conv4 = down_conv(widths[2], widths[3])
        self.bn1 = nn.BatchNorm1d(widths[0])
        self.bn2 = nn.BatchNorm1d(widths[1])
        self.bn3 = nn.BatchNorm1d(widths[2])
        self.bn4 = nn.BatchNorm1d(widths[3])
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size = 3, stride = 1, padding = 1)

        self.se_layer = SEResNet(self.block, self.layers, self.output_din)
        self.resblock_layer1 = ResBlock(1, widths[0], 3, is_first_block = True)
        self.resblock_layer2 = ResBlock(widths[0], widths[0], 3)
        self.resblock_layer3 = ResBlock(widths[1], widths[1], 3)
        self.resblock_layer4 = ResBlock(widths[2], widths[2], 3)

    def forward(self, x):
        """Run the SE-ResNet embedding and return the four stage outputs."""
        stages = (
            (self.resblock_layer1, self.conv1, self.bn1),
            (self.resblock_layer2, self.conv2, self.bn2),
            (self.resblock_layer3, self.conv3, self.bn3),
            (self.resblock_layer4, self.conv4, self.bn4),
        )
        h = self.se_layer(x)
        features = []
        for resblock, conv, bn in stages:
            h = self.maxpool(self.relu(bn(conv(resblock(h)))))
            features.append(h)
        return tuple(features)



class BPNET_blk1(nn.Module):
    """Two-branch feature extractor.

    Splits the input along its last axis into two signals and runs each through its
    own SEResCnn. Every stage output is L2-normalised along the time axis, and the
    two branches are concatenated along time, stage by stage.

    Args:
        block, layers, input_din, output_din: forwarded to both SEResCnn branches.
        split_idx: index on the last axis where the first signal ends and the second
            begins. The default of 1250 matches the 2 x 1250-sample inputs used in
            this notebook.
    """

    def __init__(self, block, layers, input_din, output_din, split_idx = 1250):

        super(BPNET_blk1, self).__init__()
        self.block = block
        self.layers = layers
        self.input_din = input_din
        self.output_din = output_din
        self.split_idx = split_idx

        self.ecg_branch = SEResCnn(block, layers, self.input_din, self.output_din)
        self.ppg_branch = SEResCnn(block, layers, self.input_din, self.output_din)

    @staticmethod
    def _l2_normalize(feat):
        """Scale each (batch, channel) row of ``feat`` to unit L2 norm along dim 2.

        ``normalize`` clamps the norm at 1e-12. An all-zero row, which is possible
        after ReLU, therefore stays zero instead of becoming NaN as ``feat / norm`` did.
        """
        return nn.functional.normalize(feat, p = 2, dim = 2, eps = 1e-12)

    def forward(self, x):
        """Return the four concatenated, normalised stage features (finest first)."""
        # The first split_idx samples feed ecg_branch and the rest feed ppg_branch.
        # The branch names come from the attribute names; TODO confirm the signal order.
        # The ECG slice previously started at 1, which silently dropped sample 0.
        ecg = x[:, :, :self.split_idx]
        ppg = x[:, :, self.split_idx:]
        ecg_feats = self.ecg_branch(ecg)
        ppg_feats = self.ppg_branch(ppg)

        return tuple(
            torch.cat((self._l2_normalize(e), self._l2_normalize(p)), dim = -1)
            for e, p in zip(ecg_feats, ppg_feats)
        )

class FPN_SERESCNN(nn.Module):
    """Fuse four 1-D feature maps with a torchvision FeaturePyramidNetwork.

    After the FPN, the levels are merged top-down once more: the coarsest level is
    upsampled and added to the next finer one. The result is a single map at the
    finest resolution.

    Args:
        layers: channel count of each input level, finest first (the FPN's
            ``in_channels_list``).
        filters_out: number of FPN output channels.
    """

    def __init__(self, layers, filters_out):

        super(FPN_SERESCNN, self).__init__()

        self.layers = layers
        self.filters_out = filters_out
        self.fpn = torchvision.ops.FeaturePyramidNetwork(self.layers, self.filters_out)

    def forward(self, x1, x2, x3, x4):
        """Return a (batch, filters_out, L1) tensor, where L1 is the length of ``x1``.

        ``x1``..``x4`` are (batch, channels, length) maps, finest first. Each gets a
        dummy height axis because the torchvision FPN uses 2-D convolutions.
        """
        feats = od()
        # Keys are inserted in the same order as ``layers`` (finest first).
        feats['feat4'] = torch.unsqueeze(x1, dim = 2)
        feats['feat3'] = torch.unsqueeze(x2, dim = 2)
        feats['feat2'] = torch.unsqueeze(x3, dim = 2)
        feats['feat1'] = torch.unsqueeze(x4, dim = 2)
        out = self.fpn(feats)

        # squeeze(2) removes only the dummy height axis. A bare squeeze() also removed
        # the batch axis when batch == 1, which broke the 1-D interpolation below.
        levels = [out[k].squeeze(2) for k in ('feat4', 'feat3', 'feat2', 'feat1')]
        merged = levels[-1]
        for finer in reversed(levels[:-1]):
            # Upsample to the exact finer length. With align_corners=True this equals
            # scale_factor=2 whenever the lengths double.
            merged = finer + nn.functional.interpolate(
                merged, size = finer.shape[-1], mode = "linear", align_corners = True)
        return merged


class BPNET_blk2(nn.Module):
    """Backbone (BPNET_blk1), then FPN fusion, then three 1-D convolutions.

    For the 2 x 1250-sample inputs used in this notebook the output is
    (batch, 8, 254).

    Args:
        block, layers, input_din: forwarded to BPNET_blk1.
        output_din: stored but not used; BPNET_blk1 is always built with 1024.
        layers_fpn: channel count of each FPN input level, finest first. ``None``
            (the default) means [1024, 512, 256, 128].
        filters_out: FPN output channels, which are also the input channels of ``conv1``.
    """

    def __init__(self, block, layers, input_din, output_din, layers_fpn = None, filters_out = 5):

        super(BPNET_blk2, self).__init__()

        # A fresh list per instance. The former list default was one object shared by
        # every call, and copying also avoids aliasing a list passed by the caller.
        if layers_fpn is None:
            layers_fpn = [1024, 512, 256, 128]

        self.block = block
        self.layers = layers
        self.layers_fpn = list(layers_fpn)
        self.input_din = input_din
        self.output_din = output_din
        self.filters_out = filters_out
        # NOTE(review): output_din is not forwarded; the backbone embedding length is fixed at 1024.
        self.blk1 = BPNET_blk1(self.block, self.layers, self.input_din, 1024)
        self.fpn = FPN_SERESCNN(self.layers_fpn, self.filters_out)
        # "valid" padding with strides 1/2/2 turns a 1024-sample FPN map into 254 samples.
        # NOTE(review): there is no activation between these convolutions, so together
        # they form a single linear map. Confirm this is intended.
        self.conv1 = nn.Conv1d(self.filters_out, 32, kernel_size = 3, stride = 1, padding = "valid",
                               bias = False)
        self.conv2 = nn.Conv1d(32, 16, kernel_size = 3, stride = 2, padding = "valid",
                               bias = False)

        self.conv3 = nn.Conv1d(16, 8, kernel_size = 3, stride = 2,  padding = "valid",
                               bias = False)

    def forward(self, x):
        """Map the (batch, 1, 2 * 1250) input to (batch, 8, 254) features."""
        xx1, xx2, xx3, xx4 = self.blk1(x)
        out = self.fpn(xx1, xx2, xx3, xx4)
        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv3(out)
        return out

class BPNET_blk3(nn.Module):
    """Full model: BPNET_blk2 features feeding two MLP heads, one scalar each.

    Returns a (batch, 2) tensor: column 0 comes from ``mlp_ecg`` and column 1 from
    ``mlp_ppg``.

    NOTE(review): despite their names, both heads read the same fused features.
    ``loss_bpnet`` treats column 0 as SBP and column 1 as DBP; confirm.

    Args:
        block, layers, input_din, output_din, filters_out: forwarded to BPNET_blk2.
        layers_fpn: FPN input channels, finest first. ``None`` (the default) means
            [1024, 512, 256, 128].
    """

    def __init__(self, block, layers, input_din, output_din, layers_fpn = None, filters_out = 5):

        super(BPNET_blk3, self).__init__()

        # A fresh list per instance instead of a shared mutable default.
        if layers_fpn is None:
            layers_fpn = [1024, 512, 256, 128]

        self.block = block
        self.layers = layers
        self.layers_fpn = list(layers_fpn)
        self.input_din = input_din
        self.output_din = output_din
        self.filters_out = filters_out

        self.bpnet = BPNET_blk2(self.block, self.layers, self.input_din, self.output_din, self.layers_fpn, self.filters_out)

        # in_features is the flattened BPNET_blk2 output: 8 channels x 254 samples
        # for 2 x 1250-sample inputs.
        # The trailing AdaptiveAvgPool1d(1) receives a 2-D (batch, 512) tensor. PyTorch
        # treats that as unbatched (C=batch, L=512), so it averages the 512 features
        # into one value per sample, giving (batch, 1).
        self.mlp_ecg = nn.Sequential(nn.Linear(in_features = 254*8, out_features = 1024, bias = True),
                                     nn.ReLU(),
                                     nn.Linear(in_features = 1024, out_features = 512, bias = True),
                                     nn.AdaptiveAvgPool1d(1))

        self.mlp_ppg = nn.Sequential(nn.Linear(in_features = 254*8, out_features = 1024, bias = True),
                                     nn.ReLU(),
                                     nn.Linear(in_features = 1024, out_features = 512, bias = True),
                                     nn.AdaptiveAvgPool1d(1))

    def forward(self, x):
        """Map the (batch, 1, 2 * 1250) input to (batch, 2) predictions."""
        out = self.bpnet(x)
        # (batch, 8, 254) -> (batch, 2032)
        out = out.flatten(start_dim = 1)
        y1 = self.mlp_ecg(out)
        y2 = self.mlp_ppg(out)

        return torch.cat((y1, y2), dim = -1)

def loss_bpnet(y, y_pred, eps = 1e-12):
    """Blood-pressure loss: squared SBP error plus a weighted squared DBP error.

    Args:
        y: targets of shape (batch, 2); column 0 is SBP and column 1 is DBP.
        y_pred: predictions with the same shape as ``y``.
        eps: added to the DBP squared error in the weight's denominator. Without it,
            a sample with zero DBP error produced inf * 0 = NaN.

    Returns:
        Scalar batch mean of ``e_sbp + (e_sbp / (e_dbp + eps)) * e_dbp``, where
        ``e_*`` are per-sample squared errors.

    NOTE(review): because the weight is e_sbp / e_dbp, the DBP term cancels
    algebraically (up to eps). The loss is therefore ~2 * e_sbp, and the DBP error
    affects neither its value nor its gradient. The original variable name
    "corrcoef" hints that the paper uses a correlation coefficient here; verify the
    formula against the paper.
    """
    sq_err = (y_pred - y) ** 2
    e_sbp = sq_err[:, 0]
    e_dbp = sq_err[:, 1]
    weight = e_sbp / (e_dbp + eps)

    return (e_sbp + weight * e_dbp).mean()



In [319]:
# Shape check for a single SEResCnn branch on a random (batch, 1, 1250) signal.
torch.manual_seed(0)  # reproducible random input and weight init
test = torch.rand(64, 1, 1250)

model3 = SEResCnn(Bottleneck, [3, 4, 6, 3], 1, 1024)
x1, x2, x3, x4 = model3(test)
print(f"input shape: {test.shape}")
for i, feat in enumerate((x1, x2, x3, x4), start=1):
    print(f"output{i} shape: {feat.shape}")

inpout shape: torch.Size([64, 1, 1250])
output1 shape: torch.Size([64, 1024, 512])
output2 shape: torch.Size([64, 512, 256])
output3 shape: torch.Size([64, 256, 128])
output4 shape: torch.Size([64, 128, 64])


In [267]:
# Shape check for the two-branch BPNET_blk1 on a (batch, 1, 2 * 1250) input.
test = torch.rand(12, 1, 1250 * 2)
model1 = BPNET_blk1(Bottleneck, [3, 4, 6, 3], 1, 1024)
x1, x2, x3, x4 = model1(test)
print(f"input shape: {test.shape}")
# one label per output; the fourth used to be printed as "output3"
for i, feat in enumerate((x1, x2, x3, x4), start=1):
    print(f"output{i} shape: {feat.shape}")

inpout shape: torch.Size([12, 1, 2500])
output1 shape: torch.Size([12, 1024, 1024])
output2 shape: torch.Size([12, 512, 512])
output3 shape: torch.Size([12, 256, 256])
output3 shape: torch.Size([12, 128, 128])


In [234]:
t2 = torch.rand(12, 1, 10)  # small toy tensor for trying out the per-row L2 normalisation
(t2, t2.shape)

(tensor([[[0.7428, 0.7525, 0.0418, 0.5083, 0.2307, 0.7825, 0.8495, 0.1337,
           0.1783, 0.2107]],
 
         [[0.3677, 0.2938, 0.5692, 0.4827, 0.6315, 0.7112, 0.8840, 0.4853,
           0.1218, 0.6406]],
 
         [[0.3661, 0.1583, 0.0711, 0.8283, 0.0228, 0.3586, 0.7416, 0.3084,
           0.5966, 0.8224]],
 
         [[0.0448, 0.3671, 0.9465, 0.5501, 0.4943, 0.2382, 0.4843, 0.6858,
           0.9965, 0.3052]],
 
         [[0.7042, 0.4075, 0.6133, 0.2222, 0.0796, 0.6925, 0.9508, 0.7436,
           0.7198, 0.3171]],
 
         [[0.5681, 0.8154, 0.5081, 0.5280, 0.5121, 0.2160, 0.2327, 0.8745,
           0.8192, 0.9995]],
 
         [[0.6482, 0.9921, 0.0076, 0.3258, 0.1107, 0.6575, 0.8438, 0.5387,
           0.3134, 0.5029]],
 
         [[0.5921, 0.6933, 0.3374, 0.2255, 0.9485, 0.5597, 0.0519, 0.1808,
           0.9089, 0.3794]],
 
         [[0.5400, 0.4547, 0.9560, 0.5182, 0.9414, 0.3568, 0.5995, 0.1502,
           0.7529, 0.0997]],
 
         [[0.8458, 0.0398, 0.2669, 0.8781, 0.0

In [241]:
norm_t2 = tnorm(t2, ord = 2, dim = 2, keepdim = True)  # per-row L2 norm, shape (12, 1, 1)
t2_norm = t2.div(norm_t2)
(norm_t2, norm_t2.shape)

(tensor([[[1.6909]],
 
         [[1.7678]],
 
         [[1.6299]],
 
         [[1.8513]],
 
         [[1.9103]],
 
         [[2.0790]],
 
         [[1.8189]],
 
         [[1.7928]],
 
         [[1.9113]],
 
         [[1.8822]],
 
         [[1.7790]],
 
         [[1.6315]]]),
 torch.Size([12, 1, 1]))

In [247]:
t2_norm = torch.div(t2, norm_t2)  # recomputes t2 / norm_t2 for display
(t2_norm, t2_norm.shape)

(tensor([[[0.4393, 0.4450, 0.0247, 0.3006, 0.1365, 0.4627, 0.5024, 0.0791,
           0.1054, 0.1246]],
 
         [[0.2080, 0.1662, 0.3220, 0.2730, 0.3572, 0.4023, 0.5000, 0.2745,
           0.0689, 0.3623]],
 
         [[0.2246, 0.0971, 0.0436, 0.5082, 0.0140, 0.2200, 0.4550, 0.1892,
           0.3660, 0.5045]],
 
         [[0.0242, 0.1983, 0.5113, 0.2971, 0.2670, 0.1287, 0.2616, 0.3704,
           0.5383, 0.1648]],
 
         [[0.3686, 0.2133, 0.3211, 0.1163, 0.0417, 0.3625, 0.4977, 0.3893,
           0.3768, 0.1660]],
 
         [[0.2732, 0.3922, 0.2444, 0.2540, 0.2463, 0.1039, 0.1119, 0.4206,
           0.3941, 0.4807]],
 
         [[0.3564, 0.5454, 0.0042, 0.1791, 0.0608, 0.3615, 0.4639, 0.2962,
           0.1723, 0.2765]],
 
         [[0.3303, 0.3867, 0.1882, 0.1258, 0.5291, 0.3122, 0.0290, 0.1008,
           0.5070, 0.2116]],
 
         [[0.2825, 0.2379, 0.5002, 0.2711, 0.4925, 0.1867, 0.3137, 0.0786,
           0.3939, 0.0521]],
 
         [[0.4494, 0.0212, 0.1418, 0.4665, 0.0

In [222]:
t2[0, 0] / t2_norm[0, 0, 0]  # NOTE(review): exploratory; recorded output predates the current t2 (In [222])

tensor([2.1260, 2.1828, 2.6122, 5.1026, 4.2017, 2.5588, 5.3950, 5.1737, 5.9358,
        1.4470])

In [243]:
(t2[0, 0], t2_norm[0, 0, 0])  # first raw row next to its first normalised element

(tensor([0.7428, 0.7525, 0.0418, 0.5083, 0.2307, 0.7825, 0.8495, 0.1337, 0.1783,
         0.2107]),
 tensor(0.4393))

In [248]:
# Sanity check: every row of t2_norm should now have unit L2 norm.
row_norms = tnorm(t2_norm, ord = 2, dim = 2, keepdim = True)
assert torch.allclose(row_norms, torch.ones_like(row_norms)), "rows of t2_norm are not unit-norm"
row_norms

tensor([[[1.0000]],

        [[1.0000]],

        [[1.0000]],

        [[1.0000]],

        [[1.0000]],

        [[1.0000]],

        [[1.0000]],

        [[1.0000]],

        [[1.0000]],

        [[1.0000]],

        [[1.0000]],

        [[1.0000]]])

In [228]:
# NOTE(review): this cell ran as In [228], before the cells that define the current t2 / t2_norm, so its recorded output is stale; re-run in order.
t2_norm

tensor([[[0.1683, 0.1728, 0.2068, 0.4040, 0.3326, 0.2026, 0.4271, 0.4096,
          0.4699, 0.1146]],

        [[0.0066, 0.0781, 0.5628, 0.5322, 0.1585, 0.2395, 0.2216, 0.1635,
          0.4296, 0.2259]],

        [[0.3841, 0.6002, 0.1577, 0.2400, 0.1205, 0.0340, 0.3670, 0.3219,
          0.2310, 0.3201]],

        [[0.1417, 0.1320, 0.2181, 0.1164, 0.3828, 0.1357, 0.4450, 0.3609,
          0.3844, 0.5102]],

        [[0.0139, 0.2605, 0.1502, 0.4891, 0.4885, 0.0368, 0.1905, 0.3963,
          0.2806, 0.3978]],

        [[0.3699, 0.3498, 0.2860, 0.0387, 0.5117, 0.4381, 0.0621, 0.4136,
          0.1575, 0.0633]],

        [[0.1899, 0.3632, 0.2091, 0.0601, 0.4521, 0.3690, 0.0093, 0.4184,
          0.4913, 0.1657]],

        [[0.2638, 0.2833, 0.2794, 0.0776, 0.2655, 0.0299, 0.4727, 0.2101,
          0.4694, 0.4548]],

        [[0.0536, 0.3793, 0.2426, 0.5145, 0.0302, 0.2181, 0.1960, 0.5866,
          0.3010, 0.0899]],

        [[0.5213, 0.3599, 0.2031, 0.4046, 0.0469, 0.3081, 0.3415, 0.0304,

In [271]:
# Stand-alone FeaturePyramidNetwork check on the BPNET_blk1 outputs x1..x4.
# Each map gets a dummy height axis because the torchvision FPN is 2-D.
xx = od(
    (name, torch.unsqueeze(feat, dim = 2))
    for name, feat in zip(('feat4', 'feat3', 'feat2', 'feat1'), (x1, x2, x3, x4))
)
fpn_m = torchvision.ops.FeaturePyramidNetwork([1024, 512, 256, 128], 5)
out = fpn_m(xx)
tuple(out[k].shape for k in ('feat4', 'feat3', 'feat2', 'feat1'))

(torch.Size([12, 5, 1, 1024]),
 torch.Size([12, 5, 1, 512]),
 torch.Size([12, 5, 1, 256]),
 torch.Size([12, 5, 1, 128]))

In [279]:
# Manual top-down merge of the FPN levels, mirroring FPN_SERESCNN.forward.
# squeeze(2) removes only the dummy height axis; a bare squeeze() also dropped the
# batch axis for batch size 1 and broke the 1-D interpolation.
aux1 = nn.functional.interpolate(out['feat1'].squeeze(2), scale_factor = 2, mode = "linear", align_corners = True)
aux3 = out['feat2'].squeeze(2) + aux1
aux3_2 = nn.functional.interpolate(aux3, scale_factor = 2, mode = 'linear', align_corners = True)
aux4 = out['feat3'].squeeze(2) + aux3_2
aux4_2 = nn.functional.interpolate(aux4, scale_factor = 2, mode = 'linear', align_corners = True)
final = out['feat4'].squeeze(2) + aux4_2
final.shape


torch.Size([12, 5, 1024])

In [320]:
test = torch.rand(64, 1, 2 * 1250)  # two 1250-sample signals side by side along time (split by BPNET_blk1)
test.size()

torch.Size([64, 1, 2500])

In [321]:
model4 = BPNET_blk2(Bottleneck, [3, 4, 6, 3], 1, 1024,
                    layers_fpn = [1024, 512, 256, 128], filters_out = 32)
out4 = model4(test)
out4.shape  # recorded output: (64, 8, 254)

torch.Size([64, 8, 254])

In [322]:
model4 = BPNET_blk3(Bottleneck, [3, 4, 6, 3], 1, 1024,
                    layers_fpn = [1024, 512, 256, 128], filters_out = 32)
out4 = model4(test)
out4.shape  # recorded output: (64, 2)

torch.Size([64, 2])

In [323]:
y_pred = torch.rand(64, 2)  # random (64, 2) tensor used as a dummy target in the loss check

In [324]:
# loss_bpnet takes (target, prediction): the random y_pred tensor is the dummy target here and out4 is the model output.
loss_bpnet(y_pred, out4)

tensor(0.5967, grad_fn=<MeanBackward0>)