In [1]:
!pip install executorch

Collecting executorch
  Downloading executorch-0.6.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting expecttest (from executorch)
  Downloading expecttest-0.3.0-py3-none-any.whl.metadata (3.8 kB)
Collecting hypothesis (from executorch)
  Downloading hypothesis-6.135.31-py3-none-any.whl.metadata (5.6 kB)
Collecting parameterized (from executorch)
  Downloading parameterized-0.9.0-py2.py3-none-any.whl.metadata (18 kB)
Collecting pytest-xdist (from executorch)
  Downloading pytest_xdist-3.8.0-py3-none-any.whl.metadata (3.0 kB)
Collecting pytest-rerunfailures (from executorch)
  Downloading pytest_rerunfailures-15.1-py3-none-any.whl.metadata (20 kB)
Collecting ruamel.yaml (from executorch)
  Downloading ruamel.yaml-0.18.14-py3-none-any.whl.metadata (24 kB)
Collecting torch==2.7.0 (from executorch)
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchaudio==2.7.0 (from executorch)
  Downloading torchaudio-2.7.0-cp311-cp311-manyl

In [35]:
#@title SAMNet
import torch
import torch.nn as nn
import torch.nn.functional as F


class FastSal(nn.Module):
    def __init__(self, pretrained=None):
        super(FastSal, self).__init__()
        self.context_path = VAMM_backbone(pretrained)
        self.pyramid_pooling = PyramidPooling(128, 128)
        self.prepare = nn.ModuleList([
                convbnrelu(128, 128, k=1, s=1, p=0, relu=False),
                convbnrelu(96, 96, k=1, s=1, p=0, relu=False),
                convbnrelu(64, 64, k=1, s=1, p=0, relu=False),
                convbnrelu(32, 32, k=1, s=1, p=0, relu=False),
                convbnrelu(16, 16, k=1, s=1, p=0, relu=False)
                ])
        self.fuse = nn.ModuleList([
                DSConv3x3(128, 96, dilation=1),
                DSConv3x3(96, 64, dilation=2),
                DSConv5x5(64, 32, dilation=2),
                DSConv5x5(32, 16, dilation=2),
                DSConv5x5(16, 16, dilation=2)
                ])
        self.heads = nn.ModuleList([
                SalHead(in_channel=96),
                SalHead(in_channel=64),
                SalHead(in_channel=32),
                SalHead(in_channel=16),
                SalHead(in_channel=16)
                ])

    def forward(self, x): # (3, 1)
        ct_stage1, ct_stage2, ct_stage3, ct_stage4, ct_stage5 = self.context_path(x)
        # (16, 1/2) (32, 1/4) (64, 1/8)  (96, 1/16) (128, 1/32)
        ct_stage6 = self.pyramid_pooling(ct_stage5)                          # (128, 1/32)

        fused_stage1 = self.fuse[0](self.prepare[0](ct_stage5) + ct_stage6)  # (96, 1/32)
        refined1 = interpolate(fused_stage1, ct_stage4.size()[2:])           # (96, 1/16)

        fused_stage2 = self.fuse[1](self.prepare[1](ct_stage4) + refined1)   # (64, 1/16)
        refined2 = interpolate(fused_stage2, ct_stage3.size()[2:])           # (64, 1/8)

        fused_stage3 = self.fuse[2](self.prepare[2](ct_stage3) + refined2)   # (32, 1/8)
        refined3 = interpolate(fused_stage3, ct_stage2.size()[2:]) 		     # (32, 1/4)

        fused_stage4 = self.fuse[3](self.prepare[3](ct_stage2) + refined3)   # (16, 1/4)
        refined4 = interpolate(fused_stage4, ct_stage1.size()[2:])		     # (16, 1/2)

        fused_stage5 = self.fuse[4](self.prepare[4](ct_stage1) + refined4)   # (16, 1/2)

        output_side1 = interpolate(self.heads[0](fused_stage1), x.size()[2:])
        output_side2 = interpolate(self.heads[1](fused_stage2), x.size()[2:])
        output_side3 = interpolate(self.heads[2](fused_stage3), x.size()[2:])
        output_side4 = interpolate(self.heads[3](fused_stage4), x.size()[2:])
        output_main  = interpolate(self.heads[4](fused_stage5), x.size()[2:])

        return output_main, output_side1, output_side2, output_side3,output_side4  #torch.cat([output_main, output_side1, output_side2, output_side3, output_side4], dim=1)


interpolate = lambda x, size: F.interpolate(x, size=size, mode='bilinear', align_corners=True)


class convbnrelu(nn.Module):
    def __init__(self, in_channel, out_channel, k=3, s=1, p=1, g=1, d=1, bias=False, bn=True, relu=True):
        super(convbnrelu, self).__init__()
        conv = [nn.Conv2d(in_channel, out_channel, k, s, p, dilation=d, groups=g, bias=bias)]
        if bn:
            conv.append(nn.BatchNorm2d(out_channel))
        if relu:
            conv.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*conv)

    def forward(self, x):
        return self.conv(x)


class DSConv3x3(nn.Module):
    def __init__(self, in_channel, out_channel, stride=1, dilation=1, relu=True):
        super(DSConv3x3, self).__init__()
        self.conv = nn.Sequential(
                convbnrelu(in_channel, in_channel, k=3, s=stride, p=dilation, d=dilation, g=in_channel),
                convbnrelu(in_channel, out_channel, k=1, s=1, p=0, relu=relu)
                )

    def forward(self, x):
        return self.conv(x)


class DSConv5x5(nn.Module):
    def __init__(self, in_channel, out_channel, stride=1, dilation=1, relu=True):
        super(DSConv5x5, self).__init__()
        self.conv = nn.Sequential(
                convbnrelu(in_channel, in_channel, k=5, s=stride, p=2*dilation, d=dilation, g=in_channel),
                convbnrelu(in_channel, out_channel, k=1, s=1, p=0, relu=relu)
                )

    def forward(self, x):
        return self.conv(x)


class SalHead(nn.Module):
    def __init__(self, in_channel):
        super(SalHead, self).__init__()
        self.conv = nn.Sequential(
                nn.Dropout2d(p=0.1),
                nn.Conv2d(in_channel, 1, 1, stride=1, padding=0),
                nn.Sigmoid()
                )

    def forward(self, x):
        return self.conv(x)


class VAMM_backbone(nn.Module):
    def __init__(self, pretrained=None):
        super(VAMM_backbone, self).__init__()
        self.layer1 = nn.Sequential(
                convbnrelu(3, 16, k=3, s=2, p=1),
                VAMM(16, dilation_level=[1,2,3])
                )
        self.layer2 = nn.Sequential(
                DSConv3x3(16, 32, stride=2),
                VAMM(32, dilation_level=[1,2,3])
                )
        self.layer3 = nn.Sequential(
                DSConv3x3(32, 64, stride=2),
                VAMM(64, dilation_level=[1,2,3]),
                VAMM(64, dilation_level=[1,2,3]),
                VAMM(64, dilation_level=[1,2,3])
                )
        self.layer4 = nn.Sequential(
                DSConv3x3(64, 96, stride=2),
                VAMM(96, dilation_level=[1,2,3]),
                VAMM(96, dilation_level=[1,2,3]),
                VAMM(96, dilation_level=[1,2,3]),
                VAMM(96, dilation_level=[1,2,3]),
                VAMM(96, dilation_level=[1,2,3]),
                VAMM(96, dilation_level=[1,2,3])
                )
        self.layer5 = nn.Sequential(
                DSConv3x3(96, 128, stride=2),
                VAMM(128, dilation_level=[1,2]),
                VAMM(128, dilation_level=[1,2]),
                VAMM(128, dilation_level=[1,2])
                )

        if pretrained is not None:
            self.load_state_dict(torch.load(pretrained))
            print('Pretrained model loaded!')

    def forward(self, x):
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out5 = self.layer5(out4)

        return out1, out2, out3, out4, out5


class VAMM(nn.Module):
    def __init__(self, channel, dilation_level=[1,2,4,8], reduce_factor=4):
        super(VAMM, self).__init__()
        self.planes = channel
        self.dilation_level = dilation_level
        self.conv = DSConv3x3(channel, channel, stride=1)
        self.branches = nn.ModuleList([
                DSConv3x3(channel, channel, stride=1, dilation=d) for d in dilation_level
                ])
        ### ChannelGate
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = convbnrelu(channel, channel, 1, 1, 0, bn=True, relu=True)
        self.fc2 = nn.Conv2d(channel, (len(self.dilation_level) + 1) * channel, 1, 1, 0, bias=False)
        self.fuse = convbnrelu(channel, channel, k=1, s=1, p=0, relu=False)
        ### SpatialGate
        self.convs = nn.Sequential(
                convbnrelu(channel, channel // reduce_factor, 1, 1, 0, bn=True, relu=True),
                DSConv3x3(channel // reduce_factor, channel // reduce_factor, stride=1, dilation=2),
                DSConv3x3(channel // reduce_factor, channel // reduce_factor, stride=1, dilation=4),
                nn.Conv2d(channel // reduce_factor, 1, 1, 1, 0, bias=False)
                )

    def forward(self, x):
        conv = self.conv(x)
        brs = [branch(conv) for branch in self.branches]
        brs.append(conv)
        gather = sum(brs)

        ### ChannelGate
        d = self.gap(gather)
        d = self.fc2(self.fc1(d))
        d = torch.unsqueeze(d, dim=1).view(-1, len(self.dilation_level) + 1, self.planes, 1, 1)

        ### SpatialGate
        s = self.convs(gather).unsqueeze(1)

        ### Fuse two gates
        f = d * s
        f = F.softmax(f, dim=1)

        return self.fuse(sum([brs[i] * f[:, i, ...] for i in range(len(self.dilation_level) + 1)]))	+ x


class PyramidPooling(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(PyramidPooling, self).__init__()
        hidden_channel = int(in_channel / 4)
        self.conv1 = convbnrelu(in_channel, hidden_channel, k=1, s=1, p=0)
        self.conv2 = convbnrelu(in_channel, hidden_channel, k=1, s=1, p=0)
        self.conv3 = convbnrelu(in_channel, hidden_channel, k=1, s=1, p=0)
        self.conv4 = convbnrelu(in_channel, hidden_channel, k=1, s=1, p=0)
        self.out = convbnrelu(in_channel*2, out_channel, k=1, s=1, p=0)

    def forward(self, x):
        size = x.size()[2:]
        feat1 = interpolate(self.conv1(F.adaptive_avg_pool2d(x, 1)), size)
        feat2 = interpolate(self.conv2(F.adaptive_avg_pool2d(x, 2)), size)
        feat3 = interpolate(self.conv3(F.adaptive_avg_pool2d(x, 3)), size)
        feat4 = interpolate(self.conv4(F.adaptive_avg_pool2d(x, 6)), size)
        x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
        x = self.out(x)

        return x

In [22]:
#@title HVPNet

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.PReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)

        return x


class gycblock(nn.Module):
    def __init__(self,channels_in,channels_out):
        super(gycblock, self).__init__()
        self.recp7 = nn.Sequential(
            BasicConv(channels_in*1, channels_in, kernel_size=7, dilation=1, padding=3, groups=channels_in, bias=False),
            BasicConv(channels_in, channels_in, kernel_size=1, dilation=1, bias=False),

            BasicConv(channels_in, channels_in, kernel_size=3, dilation=7, padding=7, groups=channels_in, bias=False,
                      relu=False),
            BasicConv(channels_in, channels_in, kernel_size=1, dilation=1, bias=False)
        )

        self.recp5 = nn.Sequential(
            BasicConv(channels_in*2, channels_in, kernel_size=5, dilation=1, padding=2, groups=channels_in, bias=False),
            BasicConv(channels_in, channels_in, kernel_size=1, dilation=1, bias=False),

            BasicConv(channels_in, channels_in, kernel_size=3, dilation=5, padding=5, groups=channels_in, bias=False,
                      relu=False),
            BasicConv(channels_in, channels_in, kernel_size=1, dilation=1, bias=False)
        )

        self.recp3 = nn.Sequential(
            BasicConv(channels_in*3, channels_in, kernel_size=3, dilation=1, padding=1, groups=channels_in, bias=False),
            BasicConv(channels_in, channels_in, kernel_size=1, dilation=1, bias=False),

            BasicConv(channels_in, channels_in, kernel_size=3, dilation=3, padding=3, groups=channels_in, bias=False,
                      relu=False),
            BasicConv(channels_in, channels_in, kernel_size=1, dilation=1, bias=False)
        )

        self.recp1 = nn.Sequential(
            BasicConv(channels_in*4, channels_out, kernel_size=1, dilation=1, bias=False,relu=True)
        )

    def forward(self, x):
        x0 = self.recp7(x)
        x1 = self.recp5(torch.cat([x,x0],dim=1))
        x2 = self.recp3(torch.cat([x,x0,x1],dim=1))
        out = self.recp1(torch.cat([x,x0,x1,x2],dim=1))

        return out


class Channelatt(nn.Module):
    def __init__(self, channel, reduction=4):
        super(Channelatt, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)

        return x * y.expand_as(x)


class Spatialatt(nn.Module):
    def __init__(self,channels_in):
        super(Spatialatt, self).__init__()
        kernel_size = 3
        self.spatial = BasicConv(channels_in, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
    def forward(self, x):
        x_out = self.spatial(x)
        scale = torch.sigmoid(x_out) # broadcasting

        return x * scale


class residual_att(nn.Module):
    def __init__(self,channels_in, reduction=4):
        super(residual_att, self).__init__()
        self.channel_att=Channelatt(channels_in, reduction=reduction)
        self.spatialatt=Spatialatt(channels_in)

    def forward(self, x):
        return x + self.spatialatt(self.channel_att(x))



class ConvBlock(nn.Module):
    def __init__(self, nIn, nOut, add=True):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(nIn, nOut, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(nOut)
        self.act = nn.PReLU(nOut)
        self.add = add

    def forward(self, input):
        output = self.conv(input)
        if self.add:
            output = input + output

        output = self.bn(output)
        output = self.act(output)

        return output


class DilatedParallelConvBlockD2(nn.Module):
    def __init__(self, nIn, nOut, add=False):
        super(DilatedParallelConvBlockD2, self).__init__()
        n = int(np.ceil(nOut / 2.))
        n2 = nOut - n

        self.conv0 = nn.Conv2d(nIn, nOut, 1, stride=1, padding=0, dilation=1, bias=False)
        self.conv1 = nn.Conv2d(n, n, 3, stride=1, padding=1, dilation=1, bias=False)
        self.conv2 = nn.Conv2d(n2, n2, 3, stride=1, padding=2, dilation=2, bias=False)

        self.bn = nn.BatchNorm2d(nOut)
        #self.act = nn.PReLU(nOut)
        self.add = add

    def forward(self, input):
        in0 = self.conv0(input)
        in1, in2 = torch.chunk(in0, 2, dim=1)
        b1 = self.conv1(in1)
        b2 = self.conv2(in2)
        output = torch.cat([b1, b2], dim=1)

        if self.add:
            output = input + output
        output = self.bn(output)
        #output = self.act(output)

        return output


class DownsamplerBlockDepthwiseConv(nn.Module):
    def __init__(self, nIn, nOut):
        super(DownsamplerBlockDepthwiseConv, self).__init__()
        self.nIn = nIn
        self.nOut = nOut

        if self.nIn < self.nOut:
            self.conv0 = nn.Conv2d(nIn, nOut-nIn, 1, stride=1, padding=0, dilation=1, groups=1, bias=False)
            self.conv1 = nn.Conv2d(nOut-nIn, nOut-nIn, 5, stride=2, padding=2, dilation=1, groups=nOut-nIn, bias=False)
            #self.pool = nn.MaxPool2d(2, stride=2)
            self.pool = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        else:
            self.conv0 = nn.Conv2d(nIn, nOut, 1, stride=1, padding=0, dilation=1, groups=1, bias=False)
            self.conv1 = nn.Conv2d(nOut, nOut, 5, stride=2, padding=2, dilation=1, groups=nOut, bias=False)

        self.bn = nn.BatchNorm2d(nOut)
        self.act = nn.PReLU(nOut)

    def forward(self, input):
        if self.nIn < self.nOut:
            output = torch.cat([self.conv1(self.conv0(input)), self.pool(input)], 1)
        else:
            output = self.conv1(self.conv0(input))

        output = self.bn(output)
        output = self.act(output)

        return output


class DownsamplerBlockConv(nn.Module):
    def __init__(self, nIn, nOut):
        super(DownsamplerBlockConv, self).__init__()
        self.nIn = nIn
        self.nOut = nOut

        if self.nIn < self.nOut:
            self.conv = nn.Conv2d(nIn, nOut-nIn, 3, stride=2, padding=1, bias=False)
            #self.pool = nn.MaxPool2d(2, stride=2)
            self.pool = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        else:
            self.conv = nn.Conv2d(nIn, nOut, 3, stride=2, padding=1, bias=False)

        self.bn = nn.BatchNorm2d(nOut)
        self.act = nn.PReLU(nOut)

    def forward(self, input):
        if self.nIn < self.nOut:
            output = torch.cat([self.conv(input), self.pool(input)], 1)
        else:
            output = self.conv(input)

        output = self.bn(output)
        output = self.act(output)

        return output


class Backbone(nn.Module):
    def __init__(self, P1=1, P2=1, P3=3, P4=5, reduction=4, pretrained=None):
        super(Backbone, self).__init__()
        self.level1_0 = DownsamplerBlockConv(3, 16)
        self.level1 = nn.ModuleList()
        for i in range(0, P1):
            self.level1.append(ConvBlock(16, 16))
        self.level1.append(residual_att(16, reduction=reduction))
        self.branch1 = nn.Conv2d(16, 16, 1, stride=1, padding=0,bias=False)
        self.br1 = nn.Sequential(nn.BatchNorm2d(16), nn.PReLU(16))

        self.level2_0 = DownsamplerBlockDepthwiseConv(16, 32)
        self.level2 = nn.ModuleList()
        for i in range(0, P2):
            self.level2.append(nn.Dropout2d(0.1, True))
            self.level2.append(gycblock(32, 32))
            self.level2.append(residual_att(32, reduction=reduction))
        self.branch2 = nn.Conv2d(32, 32, 1, stride=1, padding=0,bias=False)
        self.br2 = nn.Sequential(nn.BatchNorm2d(32), nn.PReLU(32))

        self.level3_0 = DownsamplerBlockDepthwiseConv(32, 64)
        self.level3 = nn.ModuleList()
        for i in range(0, P3):
            self.level3.append(nn.Dropout2d(0.1, True))
            self.level3.append(gycblock(64, 64))
            self.level3.append(residual_att(64, reduction=reduction))
        self.branch3 = nn.Conv2d(64, 64, 1, stride=1, padding=0,bias=False)
        self.br3 = nn.Sequential(nn.BatchNorm2d(64), nn.PReLU(64))

        self.level4_0 = DownsamplerBlockDepthwiseConv(64, 128)
        self.level4 = nn.ModuleList()
        for i in range(0, P4):
            self.level4.append(nn.Dropout2d(0.1, True))
            self.level4.append(gycblock(128, 128))
            self.level4.append(residual_att(128, reduction=reduction))
        self.branch4 = nn.Conv2d(128, 128, 1, stride=1, padding=0,bias=False)
        self.br4 = nn.Sequential(nn.BatchNorm2d(128), nn.PReLU(128))

        if pretrained is not None:
            self.load_state_dict(torch.load(pretrained))
            print('Pretrained Model Loaded!')

    def forward(self, input):
        output1_0 = self.level1_0(input)

        output1 = output1_0
        for layer in self.level1:
            output1 = layer(output1)
        output1 = self.br1(self.branch1(output1_0) + output1)

        output2_0 = self.level2_0(output1)
        output2 = output2_0
        for layer in self.level2:
            output2 = layer(output2)
        output2 = self.br2(self.branch2(output2_0) + output2)

        output3_0 = self.level3_0(output2)
        output3 = output3_0
        for layer in self.level3:
            output3 = layer(output3)
        output3 = self.br3(self.branch3(output3_0) + output3)

        output4_0 = self.level4_0(output3)
        output4 = output4_0
        for layer in self.level4:
            output4 = layer(output4)
        output4 = self.br4(self.branch4(output4_0) + output4)

        return output1, output2, output3, output4


class FastSal(nn.Module):
    '''
    This class defines the MiniNetV2 network
    '''
    def __init__(self, P1=1, P2=1, P3=3, P4=5, reduction=4, pretrained=None):
        super(FastSal, self).__init__()
        self.backbone = Backbone(P1, P2, P3, P4, reduction, pretrained)

        self.up3_conv4 = DilatedParallelConvBlockD2(128, 64)
        self.up3_conv3 = nn.Conv2d(64, 64, 1, stride=1, padding=0,bias=False)
        self.up3_bn3 = nn.BatchNorm2d(64)
        self.up3_act = nn.PReLU(64)

        self.up2_conv3 = DilatedParallelConvBlockD2(64, 32)
        self.up2_conv2 = nn.Conv2d(32, 32, 1, stride=1, padding=0,bias=False)
        self.up2_bn2 = nn.BatchNorm2d(32)
        self.up2_act = nn.PReLU(32)

        self.up1_conv2 = DilatedParallelConvBlockD2(32, 16)
        self.up1_conv1 = nn.Conv2d(16, 16, 1, stride=1, padding=0,bias=False)
        self.up1_bn1 = nn.BatchNorm2d(16)
        self.up1_act = nn.PReLU(16)

        self.classifier4 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(128, 1, 1, stride=1, padding=0, bias=False))
        self.classifier3 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(64, 1, 1, stride=1, padding=0, bias=False))
        self.classifier2 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(32, 1, 1, stride=1, padding=0, bias=False))
        self.classifier1 = nn.Sequential(nn.Dropout2d(0.1, False), nn.Conv2d(16, 1, 1, stride=1, padding=0, bias=False))

    def forward(self, input):
        output1, output2, output3, output4 = self.backbone(input)

        up4 = F.interpolate(output4, output3.size()[2:], mode='bilinear', align_corners=False)
        up3_conv4 = self.up3_conv4(up4)
        up3_conv3 = self.up3_bn3(self.up3_conv3(output3))
        up3 = self.up3_act(up3_conv4 + up3_conv3)

        up3 = F.interpolate(up3, output2.size()[2:], mode='bilinear', align_corners=False)
        up2_conv3 = self.up2_conv3(up3)
        up2_conv2 = self.up2_bn2(self.up2_conv2(output2))
        up2 = self.up2_act(up2_conv3 + up2_conv2)

        up2 = F.interpolate(up2, output1.size()[2:], mode='bilinear', align_corners=False)
        up1_conv2 = self.up1_conv2(up2)
        up1_conv1 = self.up1_bn1(self.up1_conv1(output1))
        up1 = self.up1_act(up1_conv2 + up1_conv1)

        classifier4 = torch.sigmoid(self.classifier4(up4))
        classifier3 = torch.sigmoid(self.classifier3(up3))
        classifier2 = torch.sigmoid(self.classifier2(up2))
        classifier1 = torch.sigmoid(self.classifier1(up1))
        classifier4 = F.interpolate(classifier4, input.size()[2:], mode='bilinear', align_corners=False)
        classifier3 = F.interpolate(classifier3, input.size()[2:], mode='bilinear', align_corners=False)
        classifier2 = F.interpolate(classifier2, input.size()[2:], mode='bilinear', align_corners=False)
        classifier1 = F.interpolate(classifier1, input.size()[2:], mode='bilinear', align_corners=False)

        return classifier1, classifier2, classifier3, classifier4 #torch.cat([classifier1, classifier2, classifier3, classifier4], dim=1)

In [9]:
class FixedInterpolate(nn.Module):
    def __init__(self, target_size):
        super().__init__()
        self.target_size = target_size

    def forward(self, x):
        return F.interpolate(x, size=self.target_size, mode='bilinear', align_corners=False)


# Example use
self.upsample_to_224 = FixedInterpolate((224, 224))

x = self.upsample_to_224(x)

NameError: name 'self' is not defined

In [29]:
#@title Seanet


from torch import nn
import torch
import torch.nn.functional as F
import numpy as np
import os
from torch.nn import Parameter


try:
    from torchvision.models.utils import load_state_dict_from_url  # torchvision 0.4+
except ModuleNotFoundError:
    try:
        from torch.hub import load_state_dict_from_url  # torch 1.x
    except ModuleNotFoundError:
        from torch.utils.model_zoo import load_url as load_state_dict_from_url  # torch 0.4.1

model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1):
        padding = (kernel_size - 1) // 2
        if dilation != 1:
            padding = dilation
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, dilation=dilation,
                      bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU6(inplace=True)
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, dilation=1):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, dilation=dilation),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, pretrained=None, num_classes=1000, width_mult=1.0):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t, c, n, s, d
            [1, 16, 1, 1, 1], # conv1 112*112*16
            [6, 24, 2, 2, 1], # conv2 56*56*24
            [6, 32, 3, 2, 1], # conv3 28*28*32
            [6, 64, 4, 2, 1],
            [6, 96, 3, 1, 1], # conv4 14*14*96
            [6, 160, 3, 2, 1],
            [6, 320, 1, 1, 1], # conv5 7*7*320
        ]

        # building first layer
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * max(1.0, width_mult))
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s, d in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                stride = s if i == 0 else 1
                dilation = d if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        res = []
        for idx, m in enumerate(self.features):
            x = m(x)
            if idx in [1, 3, 6, 13, 17]:
                res.append(x)
        return res


def mobilenet_v2(pretrained=True, progress=True, **kwargs):
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        print("loading imagenet pretrained mobilenetv2")
        model.load_state_dict(state_dict, strict=False)
        print("loaded imagenet pretrained mobilenetv2")
    return model



class seanet_convbnrelu(nn.Module):
    def __init__(self, in_channel, out_channel, k=3, s=1, p=1, g=1, d=1, bias=False, bn=True, relu=True):
        super(seanet_convbnrelu, self).__init__()
        conv = [nn.Conv2d(in_channel, out_channel, k, s, p, dilation=d, groups=g, bias=bias)]
        if bn:
            conv.append(nn.BatchNorm2d(out_channel))
        if relu:
            conv.append(nn.PReLU(out_channel))
        self.conv = nn.Sequential(*conv)

    def forward(self, x):
        return self.conv(x)


class DSConv3x3(nn.Module):
    def __init__(self, in_channel, out_channel, stride=1, dilation=1, relu=True):
        super(DSConv3x3, self).__init__()
        self.conv = nn.Sequential(
            seanet_convbnrelu(in_channel, in_channel, k=3, s=stride, p=dilation, d=dilation, g=in_channel),
            seanet_convbnrelu(in_channel, out_channel, k=1, s=1, p=0, relu=relu)
        )

    def forward(self, x):
        return self.conv(x)


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=4):
        super(ChannelAttention, self).__init__()

        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=3):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = max_out
        x = self.conv1(x)
        return self.sigmoid(x)


# Channel-wise Correlation
class CCorrM(nn.Module):
    def __init__(self, all_channel=24, all_dim=128):
        super(CCorrM, self).__init__()
        self.linear_e = nn.Linear(all_channel, all_channel, bias=False) #weight
        self.channel = all_channel
        self.dim = all_dim * all_dim
        self.conv1 = DSConv3x3(all_channel, all_channel, stride=1)
        self.conv2 = DSConv3x3(all_channel, all_channel, stride=1)

    def forward(self, exemplar, query):  # exemplar: f1, query: f2
        fea_size = query.size()[2:]
        exemplar = F.interpolate(exemplar, size=fea_size, mode="bilinear", align_corners=True)
        all_dim = fea_size[0] * fea_size[1]
        exemplar_flat = exemplar.view(-1, self.channel, all_dim)  # N,C1,H,W -> N,C1,H*W
        query_flat = query.view(-1, self.channel, all_dim)  # N,C2,H,W -> N,C2,H*W
        exemplar_t = torch.transpose(exemplar_flat, 1, 2).contiguous()  # batchsize x dim x num, N,H*W,C1
        exemplar_corr = self.linear_e(exemplar_t)  # batchsize x dim x num, N,H*W,C1
        A = torch.bmm(query_flat, exemplar_corr)  # ChannelCorrelation: N,C2,H*W x N,H*W,C1 = N,C2,C1

        A1 = F.softmax(A.clone(), dim=2)  # N,C2,C1. dim=2 is row-wise norm. Sr
        B = F.softmax(torch.transpose(A, 1, 2), dim=2)  # N,C1,C2 column-wise norm. Sc
        query_att = torch.bmm(A1, exemplar_flat).contiguous()  # N,C2,C1 X N,C1,H*W = N,C2,H*W
        exemplar_att = torch.bmm(B, query_flat).contiguous()  # N,C1,C2 X N,C2,H*W = N,C1,H*W

        exemplar_att = exemplar_att.view(-1, self.channel, fea_size[0], fea_size[1])  # N,C1,H*W -> N,C1,H,W
        exemplar_out = self.conv1(exemplar_att + exemplar)

        query_att = query_att.view(-1, self.channel, fea_size[0], fea_size[1])  # N,C2,H*W -> N,C2,H,W
        #query_out = self.conv1(query_att + query) #https://github.com/MathLee/SeaNet/issues/2

        query_out = self.conv2(query_att + query)

        return exemplar_out, query_out
"""
class CCorrM(nn.Module):
    def __init__(self, all_channel=24, all_dim=128):
        super(CCorrM, self).__init__()
        self.linear_e = nn.Linear(all_channel, all_channel, bias=False) #weight
        self.channel = all_channel
        self.dim = all_dim * all_dim
        self.conv1 = DSConv3x3(all_channel, all_channel, stride=1)
        self.conv2 = DSConv3x3(all_channel, all_channel, stride=1)

    def forward(self, exemplar, query):  # exemplar: f1, query: f2
        fea_size = query.size()[2:]
        exemplar = F.interpolate(exemplar, size=fea_size, mode="bilinear", align_corners=True)
        all_dim = fea_size[0] * fea_size[1]
        exemplar_flat = exemplar.view(-1, self.channel, all_dim)  # N,C1,H,W -> N,C1,H*W
        query_flat = query.view(-1, self.channel, all_dim)  # N,C2,H,W -> N,C2,H*W
        exemplar_t = torch.transpose(exemplar_flat, 1, 2).contiguous()  # batchsize x dim x num, N,H*W,C1
        exemplar_corr = self.linear_e(exemplar_t)  # batchsize x dim x num, N,H*W,C1
        A = torch.bmm(query_flat, exemplar_corr)  # ChannelCorrelation: N,C2,H*W x N,H*W,C1 = N,C2,C1

        A1 = F.softmax(A.clone(), dim=2)  # N,C2,C1. dim=2 is row-wise norm. Sr
        B = F.softmax(torch.transpose(A, 1, 2), dim=2)  # N,C1,C2 column-wise norm. Sc
        query_att = torch.bmm(A1, exemplar_flat).contiguous()  # N,C2,C1 X N,C1,H*W = N,C2,H*W
        exemplar_att = torch.bmm(B, query_flat).contiguous()  # N,C1,C2 X N,C2,H*W = N,C1,H*W

        exemplar_att = exemplar_att.view(-1, self.channel, fea_size[0], fea_size[1])  # N,C1,H*W -> N,C1,H,W
        exemplar_out = self.conv1(exemplar_att + exemplar)

        query_att = query_att.view(-1, self.channel, fea_size[0], fea_size[1])  # N,C2,H*W -> N,C2,H,W
        query_out = self.conv1(query_att + query)

        return exemplar_out, query_out
"""

# Edge-based Enhancement Unit (EEU)
class EEU(nn.Module):
    def __init__(self, in_channel):
        super(EEU, self).__init__()
        self.avg_pool = nn.AvgPool2d((3, 3), stride=1, padding=1)
        self.conv_1 = nn.Conv2d(in_channel, in_channel, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(in_channel)
        self.sigmoid = nn.Sigmoid()
        self.PReLU = nn.PReLU(in_channel)

    def forward(self, x):
        edge = x - self.avg_pool(x)  # Xi=X-Avgpool(X)
        weight = self.sigmoid(self.bn1(self.conv_1(edge)))
        out = weight * x + x
        return self.PReLU(edge), out


# Edge Self-Alignment Module (ESAM)
class ESAM(nn.Module):
    def __init__(self, channel1=16, channel2=24):
        super(ESAM, self).__init__()

        self.smooth1 = DSConv3x3(channel1, channel2, stride=1, dilation=1)  # 16channel-> 24channel

        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.smooth2 = DSConv3x3(channel2, channel2, stride=1, dilation=1)  # 24channel-> 24channel

        self.eeu1 = EEU(channel2)
        self.eeu2 = EEU(channel2)
        self.ChannelCorrelation = CCorrM(channel2, 128)

    def forward(self, x1, x2):  # x1 16*144*14; x2 24*72*72

        x1_1 = self.smooth1(x1)
        edge1, x1_2 = self.eeu1(x1_1)

        x2_1 = self.smooth2(self.upsample2(x2))
        edge2, x2_2 = self.eeu2(x2_1)

        # Channel-wise Correlation
        x1_out, x2_out = self.ChannelCorrelation(x1_2, x2_2)

        return edge1, edge2, torch.cat([x1_out, x2_out], 1)  # (24*2)*144*144


# Dynamic Semantic Matching Module (DSMM)
class DSMM(nn.Module):
    def __init__(self, channel4=96, channel3=32):
        super(DSMM, self).__init__()

        self.fuse4 = seanet_convbnrelu(channel4, channel4, k=1, s=1, p=0, relu=True)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.smooth4 = DSConv3x3(channel4, channel4, stride=1, dilation=1)  # 96channel-> 96channel

        self.fuse3 = seanet_convbnrelu(channel3, channel3, k=1, s=1, p=0, relu=True)
        self.smooth3 = DSConv3x3(channel3, channel4, stride=1, dilation=1)  # 32channel-> 96channel
        self.ChannelCorrelation = CCorrM(channel4, 32)

    def forward(self, x4, k4, x3, k3):  # x4:96*18*18 k4:96*5*5; x3:32*36*36 k3:32*5*5
        B4, C4, H4, W4 = k4.size()
        B3, C3, H3, W3 = k3.size()

        x_B4, x_C4, x_H4, x_W4 = x4.size()  # 8*96*18*18
        x_B3, x_C3, x_H3, x_W3 = x3.size()  # 8*32*36*36

        x4_new = x4.clone()
        x3_new = x3.clone()
        # k4 = k4.view(C4, 1, H4, W4)
        # k3 = k3.view(C3, 1, H3, W3)
        for i in range(1, B4):
            kernel4 = k4[i, :, :, :]
            kernel3 = k3[i, :, :, :]
            kernel4 = kernel4.view(C4, 1, H4, W4)
            kernel3 = kernel3.view(C3, 1, H3, W3)
            # DDconv
            x4_r1 = F.conv2d(x4[i, :, :, :].view(1, C4, x_H4, x_W4), kernel4, stride=1, padding=2, dilation=1,
                             groups=C4)
            x4_r2 = F.conv2d(x4[i, :, :, :].view(1, C4, x_H4, x_W4), kernel4, stride=1, padding=4, dilation=2,
                             groups=C4)
            x4_r3 = F.conv2d(x4[i, :, :, :].view(1, C4, x_H4, x_W4), kernel4, stride=1, padding=6, dilation=3,
                             groups=C4)
            x4_new[i, :, :, :] = x4_r1 + x4_r2 + x4_r3

            # DDconv
            x3_r1 = F.conv2d(x3[i, :, :, :].view(1, C3, x_H3, x_W3), kernel3, stride=1, padding=2, dilation=1,
                             groups=C3)
            x3_r2 = F.conv2d(x3[i, :, :, :].view(1, C3, x_H3, x_W3), kernel3, stride=1, padding=4, dilation=2,
                             groups=C3)
            x3_r3 = F.conv2d(x3[i, :, :, :].view(1, C3, x_H3, x_W3), kernel3, stride=1, padding=6, dilation=3,
                             groups=C3)
            x3_new[i, :, :, :] = x3_r1 + x3_r2 + x3_r3
        # Pconv
        x4_all = self.fuse4(x4_new)
        x4_smooth = self.smooth4(self.upsample2(x4_all))
        # Pconv
        x3_all = self.fuse3(x3_new)
        x3_smooth = self.smooth3(x3_all)

        # Channel-wise Correlation
        x3_out, x4_out = self.ChannelCorrelation(x3_smooth, x4_smooth)

        return torch.cat([x3_out, x4_out], 1)  # (96*2)*32*32


class SalHead(nn.Module):
    def __init__(self, in_channel):
        super(SalHead, self).__init__()
        self.conv = nn.Sequential(
            nn.Dropout2d(p=0.1),
            nn.Conv2d(in_channel, 1, 1, stride=1, padding=0),
        )

    def forward(self, x):
        return self.conv(x)


class prediction_decoder(nn.Module):
    def __init__(self, channel5=320, channel34=192, channel12=48):
        super(prediction_decoder, self).__init__()
        # 9*9
        self.decoder5 = nn.Sequential(
            DSConv3x3(channel5, channel5, stride=1),
            DSConv3x3(channel5, channel5, stride=1),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),  # 36*36
            DSConv3x3(channel5, channel34, stride=1)
        )
        self.s5 = SalHead(channel34)  # 36*36

        # 36*36
        self.decoder34 = nn.Sequential(
            DSConv3x3(channel34 * 2, channel34, stride=1),
            DSConv3x3(channel34, channel34, stride=1),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),  # 144*144
            DSConv3x3(channel34, channel12, stride=1)
        )
        self.s34 = SalHead(channel12)  # 144*144

        # 144*144
        self.decoder12 = nn.Sequential(
            DSConv3x3(channel12 * 2, channel12, stride=1),
            DSConv3x3(channel12, channel12, stride=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 288*288
            DSConv3x3(channel12, channel12, stride=1)
        )
        self.s12 = SalHead(channel12)

    def forward(self, x5, x34, x12):
        x5_decoder = self.decoder5(x5)
        s5 = self.s5(x5_decoder)

        x34_decoder = self.decoder34(torch.cat([x5_decoder, x34], 1))
        s34 = self.s34(x34_decoder)

        x12_decoder = self.decoder12(torch.cat([x34_decoder, x12], 1))
        s12 = self.s12(x12_decoder)

        return s12, s34, s5


class SeaNet(nn.Module):
    def __init__(self, pretrained=True, channel=128):
        super(SeaNet, self).__init__()
        # Backbone model
        self.backbone = mobilenet_v2(pretrained)
        # input 256*256*3
        # conv1 128*128*16
        # conv2 64*64*24
        # conv3 32*32*32
        # conv4 16*16*96
        # conv5 8*8*320

        # Semantic Knowledge Compression(SKC) unit, k3 and k4
        self.conv5_conv4 = DSConv3x3(320, 96, stride=1)
        self.conv5_conv3 = DSConv3x3(320, 32, stride=1)
        self.pool = nn.AdaptiveAvgPool2d(5)

        self.dsmm = DSMM(96, 32)
        self.esam = ESAM(16, 24)

        self.prediction_decoder = prediction_decoder(320, 192, 48)

        self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        # generate backbone features
        conv1, conv2, conv3, conv4, conv5 = self.backbone(input)

        # Semantic Knowledge Compression(SKC) unit, kernel_conv4 (k4) and kernel_conv3 (k3)
        kernel_conv4 = self.pool(self.conv5_conv4(conv5))  # 96*5*5
        kernel_conv3 = self.pool(self.conv5_conv3(conv5))  # 32*5*5

        # conv34 is f_dsmm
        conv34 = self.dsmm(conv4, kernel_conv4, conv3, kernel_conv3)
        # conv12 is f_esam
        edge1, edge2, conv12 = self.esam(conv1, conv2)

        s12, s34, s5 = self.prediction_decoder(conv5, conv34, conv12)

        s5_up = self.upsample8(s5)
        s34_up = self.upsample2(s34)

        sigmoid1 = self.sigmoid(s12)
        sigmoid2 = self.sigmoid(s34_up)
        sigmoid3  = self.sigmoid(s5_up)

        return s12, s34_up, s5_up, sigmoid1, sigmoid2, sigmoid3, edge1, edge2

In [16]:
#@title u2net

import torch
import torch.nn as nn
import torch.nn.functional as F

class REBNCONV(nn.Module):
    def __init__(self,in_ch=3,out_ch=3,dirate=1):
        super(REBNCONV,self).__init__()

        self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self,x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout

## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):

    src = F.upsample(src,size=tar.shape[2:],mode='bilinear')

    return src


### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d =  self.rebnconv6d(torch.cat((hx7,hx6),1))
        hx6dup = _upsample_like(hx6d,hx5)

        hx5d =  self.rebnconv5d(torch.cat((hx6dup,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)


        hx5d =  self.rebnconv5d(torch.cat((hx6,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin

### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NET,self).__init__()

        self.stage1 = RSU7(in_ch,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(512,256,512)

        # decoder
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self,x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #-------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        out1 = F.sigmoid(d0)
        out2 = F.sigmoid(d1)
        out3 = F.sigmoid(d3)
        out4 = F.sigmoid(d4)
        out5 = F.sigmoid(d5)
        out6 = F.sigmoid(d6)
        return out1, out2, out3, out4, out5, out6

        #return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

### U^2-Net small ###
class U2NETP(nn.Module):

    def __init__(self,in_ch=3,out_ch=1):
        super(U2NETP,self).__init__()

        self.stage1 = RSU7(in_ch,16,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,16,64)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(64,16,64)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(64,16,64)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(64,16,64)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(64,16,64)

        # decoder
        self.stage5d = RSU4F(128,16,64)
        self.stage4d = RSU4(128,16,64)
        self.stage3d = RSU5(128,16,64)
        self.stage2d = RSU6(128,16,64)
        self.stage1d = RSU7(128,16,64)

        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(64,out_ch,3,padding=1)

        self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self,x):

        hx = x

        #stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #decoder
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))


        #side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,d1)

        d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        return F.sigmoid(d1) #F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

In [46]:
#@title FLIM

import numpy as np
import time
import torch
import torch.nn as nn
import json
import torch.nn.functional as F

arch_path = 'arch.json' # Must be a .json file with the architecture definition



# Load JSON config
with open(arch_path, 'r') as f:
    config = json.load(f)



import torch
import torch.nn as nn
import torch.nn.functional as F

class MeanBasedDecoder(nn.Module):
    """
    based on paper https://arxiv.org/pdf/2504.20872 FLIM-based Salient Object Detection Networks with Adaptive Decoders
    """
    def __init__(self):
        super(MeanBasedDecoder, self).__init__()

    def forward(self, features, original_size, decoder_weights):
        B, C, H, W = features.shape
        decoder_weights = decoder_weights.to(dtype=torch.float32, device=features.device)

        # Expand decoder_weights for broadcasting: (C,) → (1, C, 1, 1)
        alpha_weights = decoder_weights.view(1, C, 1, 1)  # foreground mask (1s & 0s)

        # Compute μ1 and μ2 without branching
        foreground_mask = alpha_weights
        background_mask = 1.0 - alpha_weights

        foreground_sum = (features * foreground_mask).sum(dim=1, keepdim=True)  # (B, 1, H, W)
        background_sum = (features * background_mask).sum(dim=1, keepdim=True)  # (B, 1, H, W)

        # Count foreground and background channels
        fg_count = torch.sum(foreground_mask)
        bg_count = torch.sum(background_mask)

        # Avoid divide-by-zero
        fg_count = fg_count.clamp(min=1.0)
        bg_count = bg_count.clamp(min=1.0)

        μ1 = foreground_sum / fg_count
        μ2 = background_sum / bg_count

        # Compare μ1 and μ2, build alpha vector
        μ1_gt_μ2 = (μ1 > μ2).float()  # (B, 1, H, W)
        μ1_lt_μ2 = (μ1 < μ2).float()

        # Expand boolean results to match feature shape
        μ1_gt_μ2_exp = μ1_gt_μ2.expand(-1, C, -1, -1)
        μ1_lt_μ2_exp = μ1_lt_μ2.expand(-1, C, -1, -1)

        # Expand decoder weights to match feature shape
        decoder_mask = decoder_weights.view(1, C, 1, 1).expand(B, C, H, W)

        # Define α: +1, −1, or 0
        alpha = torch.where(
            (decoder_mask == 1.0) & (μ1_gt_μ2_exp == 1.0),
            torch.ones_like(features),
            torch.where(
                (decoder_mask == 0.0) & (μ1_lt_μ2_exp == 1.0),
                -torch.ones_like(features),
                torch.zeros_like(features)
            )
        )

        # Apply α to features
        weighted_features = features * alpha

        # Sum channels → saliency map
        saliency = weighted_features.sum(dim=1, keepdim=True)

        # Resize to original size
        output = F.interpolate(saliency, size=original_size, mode='bilinear', align_corners=True)

        return output

class MeanBasedDecoderV1(nn.Module):

    """

	Has data depedent flow like conditions, torch.tensor conversion that is a problem in import. So, Don’t use shapes like feature.shape[1] in control flow., Avoid conditions like if tensor.sum() > 0: inside the module., Use tensor ops like torch.where, torch.cond, or precompute indices.
    """
    def __init__(self):
        super(MeanBasedDecoder, self).__init__()

    def forward(self, features, original_size, decoder_weights):
        # features: (B, C, H, W)
        batch_size, channels, height, width = features.shape
        decoder_weights = decoder_weights.to(features.device)

        # Create foreground and background masks
        foreground_mask = (decoder_weights == 1)  # shape: (C,)
        background_mask = (decoder_weights == 0)

        # Compute μ1 and μ2 (mean over foreground/background channels)
        mean1 = features[:, foreground_mask, :, :].mean(dim=1)  # shape: (B, H, W)
        mean2 = features[:, background_mask, :, :].mean(dim=1)  # shape: (B, H, W)

        # Expand for broadcasting: shape → (B, C, H, W)
        mean1_exp = mean1.unsqueeze(1)
        mean2_exp = mean2.unsqueeze(1)

        # Expand weights for computation: (C,) → (1, C, 1, 1)
        weights = decoder_weights.view(1, channels, 1, 1).float()

        # Construct α vector: +1, -1, or 0 based on paper logic
        alpha = torch.where(
            (weights == 1) & (mean1_exp > mean2_exp), torch.ones_like(weights),
            torch.where((weights == 0) & (mean1_exp < mean2_exp), -torch.ones_like(weights), torch.zeros_like(weights))
        )

        # Apply weights to features
        weighted_features = features * alpha  # (B, C, H, W)

        # Sum across channels to produce output map
        output = weighted_features.sum(dim=1, keepdim=True)  # shape: (B, 1, H, W)

        # Interpolate to original size
        output_resized = F.interpolate(output, size=original_size, mode='bilinear', align_corners=True)

        return output_resized


class ConvNet(nn.Module):
    def __init__(self, config, original_size):
        super(ConvNet, self).__init__()
        layers = []
        input_channels = 3  # Assuming RGB input; modify as needed
        stdev = config.get("stdev_factor", 0.01)

        for i in range(1, config["nlayers"] + 1):
            layer_cfg = config[f"layer{i}"]
            conv_cfg = layer_cfg["conv"]
            pool_cfg = layer_cfg["pooling"]
            use_relu = layer_cfg.get("relu", False)

            # Convolution layer
            layers.append(nn.Conv2d(
                in_channels=input_channels,
                out_channels=conv_cfg["noutput_channels"],
                kernel_size=tuple(conv_cfg["kernel_size"][:2]),
                dilation=tuple(conv_cfg["dilation_rate"][:2]),
                padding=1
            ))

            # Pooling layer
            pool_type = pool_cfg["type"]
            pool_args = {
                'kernel_size': tuple(pool_cfg["size"][:2]),
                'stride': pool_cfg["stride"],
                'padding': 1
            }
            if pool_type == "avg_pool":
                layers.append(nn.AvgPool2d(**pool_args))
            elif pool_type == "max_pool":
                layers.append(nn.MaxPool2d(**pool_args))

            # Activation
            if use_relu:
                layers.append(nn.ReLU())

            input_channels = conv_cfg["noutput_channels"]

        self.feature_extractor = nn.Sequential(*layers)
        #self.classifier = nn.Linear(30 * 8 * 8, 10)  # Adjust size & output classes as needed

        decoder_input_size = 30 #config[f"layer{config["nlayers"]}"]["conv"]["noutput_channels"]

        adaptation_function="robust_weights"
        decoder_type = 'decoder_3'
        filter_by_size=False
       # device='cpu'
        #self.decoder = FLIMAdaptiveDecoderLayer(decoder_input_size, adaptation_function=adaptation_function, decoder_type=decoder_type, filter_by_size=filter_by_size)

        self.decoder =  MeanBasedDecoder()

        self.original_size = original_size #(512,510)
        self.decoder_weights = torch.randint(low=0, high=2, size=(30,))


    def forward(self, x):
        x = self.feature_extractor(x)
        #x = torch.flatten(x, 1)
        #x = self.classifier(x)
        x = self.decoder(x, self.original_size, self.decoder_weights )
        return x

# Example: Build model
model = ConvNet(config, (224, 224)).eval()
print(model)

### Test implementation
sample_inputs = torch.randn(1, 3, 224, 224)
with torch.no_grad():
	out = model(sample_inputs)
	print(out.shape)


ConvNet(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): AvgPool2d(kernel_size=(2, 2), stride=1, padding=1)
    (2): ReLU()
    (3): Conv2d(30, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): AvgPool2d(kernel_size=(2, 2), stride=1, padding=1)
    (5): ReLU()
  )
  (decoder): MeanBasedDecoder()
)
torch.Size([1, 1, 224, 224])


In [36]:
model = FastSal().eval().to("cpu")#SeaNet(pretrained=False)#FastSal().eval().to("cpu") #U2NET().eval().to("cpu")
#print(model)

### Test implementation
sample_inputs = torch.randn(1, 3, 224, 224)
with torch.no_grad():
	out = model(sample_inputs)
	#print(out.shape)


In [3]:
from torch.fx import symbolic_trace
traced = symbolic_trace(model)
print(traced.graph)


graph():
    %input_1 : [num_users=6] = placeholder[target=input]
    %backbone_level1_0_conv : [num_users=1] = call_module[target=backbone.level1_0.conv](args = (%input_1,), kwargs = {})
    %backbone_level1_0_pool : [num_users=1] = call_module[target=backbone.level1_0.pool](args = (%input_1,), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.cat](args = ([%backbone_level1_0_conv, %backbone_level1_0_pool], 1), kwargs = {})
    %backbone_level1_0_bn : [num_users=1] = call_module[target=backbone.level1_0.bn](args = (%cat,), kwargs = {})
    %backbone_level1_0_act : [num_users=3] = call_module[target=backbone.level1_0.act](args = (%backbone_level1_0_bn,), kwargs = {})
    %backbone_level1_0_conv_1 : [num_users=1] = call_module[target=backbone.level1.0.conv](args = (%backbone_level1_0_act,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%backbone_level1_0_act, %backbone_level1_0_conv_1), kwargs = {})
    %backbone_level1_0_bn_1 : [num

In [4]:
for name, param in model.named_parameters():
    if param.requires_grad == False:
      print(f"{name}: requires_grad={param.requires_grad}")


In [12]:
for name, param in model.named_parameters():
    print(f"{name}: {param.dtype}")


backbone.level1_0.conv.weight: torch.float32
backbone.level1_0.bn.weight: torch.float32
backbone.level1_0.bn.bias: torch.float32
backbone.level1_0.act.weight: torch.float32
backbone.level1.0.conv.weight: torch.float32
backbone.level1.0.bn.weight: torch.float32
backbone.level1.0.bn.bias: torch.float32
backbone.level1.0.act.weight: torch.float32
backbone.level1.1.channel_att.fc.0.weight: torch.float32
backbone.level1.1.channel_att.fc.2.weight: torch.float32
backbone.level1.1.spatialatt.spatial.conv.weight: torch.float32
backbone.level1.1.spatialatt.spatial.bn.weight: torch.float32
backbone.level1.1.spatialatt.spatial.bn.bias: torch.float32
backbone.branch1.weight: torch.float32
backbone.br1.0.weight: torch.float32
backbone.br1.0.bias: torch.float32
backbone.br1.1.weight: torch.float32
backbone.level2_0.conv0.weight: torch.float32
backbone.level2_0.conv1.weight: torch.float32
backbone.level2_0.bn.weight: torch.float32
backbone.level2_0.bn.bias: torch.float32
backbone.level2_0.act.weight: 

In [6]:
for name, buffer in model.named_buffers():
    print(f"{name}: {buffer.dtype}")


NameError: name 'model' is not defined

https://docs.pytorch.org/executorch/stable/backends-xnnpack.html


https://blog.tensorflow.org/2023/11/half-precision-inference-doubles-on-device-inference-performance.html

Interestingly, XNNPack does support FP16 inference — but only on devices with hardware support (like ARMv8.2 CPUs). If your model has FP16 weights and includes the right metadata, XNNPack can transparently switch to FP16 for performance gains

In [6]:
model = model.to(torch.float32)


# Export

https://docs.pytorch.org/executorch/stable/getting-started.html

In [5]:
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


def main() -> None:
    model = FastSal().eval() #U2NET().eval().to("cpu")#SeaNet(pretrained=False).eval().to("cpu")#FastSal().eval().to("cpu")
    sample_inputs = (torch.randn(1, 3, 224, 224), )

    et_program = to_edge_transform_and_lower(
        torch.export.export(model, sample_inputs),
        partitioner=[XnnpackPartitioner()],
    ).to_executorch()

    with open("hvpnet_xnnpack_fp32.pte", "wb") as file:
        #et_program.write_to_file(file)
        file.write(et_program.buffer)

    print("Finished!")
if __name__ == "__main__":
    main()

Finished!


In [None]:
import torch
from executorch.runtime import Runtime
from typing import List

runtime = Runtime.get()

input_tensor: torch.Tensor = torch.randn(1, 3, 224, 224)
program = runtime.load_program("hvpnet_xnnpack_fp32.pte")
method = program.load_method("forward")
output: List[torch.Tensor] = method.execute([input_tensor])
print("Run succesfully via executorch")


[program.cpp:135] InternalConsistency verification requested but not available


# Test

In [1]:
from executorch.runtime import Runtime

In [5]:
import executorch
executorch.extension.
print(executorch.version)

AttributeError: module 'executorch' has no attribute 'version'

In [3]:
runtime = Runtime.get()
program = runtime.load_program("hvpnet_xnnpack_fp32.pte")
method = program.load_method("forward")

TypeError: Runtime.load_program() got an unexpected keyword argument 'program_verification'

In [4]:
method.metadata


MethodMeta(name='forward', num_inputs=1, input_tensor_meta=['TensorInfo(sizes=[1, 3, 224, 224], dtype=Float, is_memory_planned=True, nbytes=602112)'], num_outputs=7, output_tensor_meta=['TensorInfo(sizes=[1, 1, 224, 224], dtype=Float, is_memory_planned=True, nbytes=200704)', 'TensorInfo(sizes=[1, 1, 224, 224], dtype=Float, is_memory_planned=True, nbytes=200704)', 'TensorInfo(sizes=[1, 1, 224, 224], dtype=Float, is_memory_planned=True, nbytes=200704)', 'TensorInfo(sizes=[1, 1, 224, 224], dtype=Float, is_memory_planned=True, nbytes=200704)', 'TensorInfo(sizes=[1, 1, 224, 224], dtype=Float, is_memory_planned=True, nbytes=200704)', 'TensorInfo(sizes=[1, 1, 224, 224], dtype=Float, is_memory_planned=True, nbytes=200704)', 'TensorInfo(sizes=[1, 1, 224, 224], dtype=Float, is_memory_planned=True, nbytes=200704)'])

In [6]:
import torch
sample_image = torch.randn(1, 3, 224, 224).to("cpu").to(torch.float32)
sample_image.dtype

torch.float32

In [7]:
#sample_image = sample_image.to(torch.float32)


In [9]:
/var/colab/app.log

NameError: name 'var' is not defined

In [None]:
try:
    result = method.execute((sample_image,))
    print("Success:", result)
except Exception as e:
    print("ExecuTorch crash:", e)


In [17]:
result[0].shape

torch.Size([1, 1, 224, 224])

In [14]:
!pip show executorch


Name: executorch
Version: 0.6.0
Summary: On-device AI across mobile, embedded and edge for PyTorch
Home-page: https://pytorch.org/executorch/
Author: 
Author-email: PyTorch Team <packages@pytorch.org>
License: BSD License

For "ExecuTorch" software

Copyright (c) Meta Platforms, Inc. and affiliates.
Copyright 2023 Arm Limited and/or its affiliates.
Copyright (c) Qualcomm Innovation Center, Inc.
Copyright (c) 2023 Apple Inc.
Copyright (c) 2024 MediaTek Inc.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

 * Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

 * Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

 * Neither the name Meta nor the names of its con

# ONNX

In [9]:
!pip install onnxruntime onnx onnxsim


Collecting onnx
  Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.9 kB)
Collecting onnxsim
  Downloading onnxsim-0.4.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Downloading onnx-1.18.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m92.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxsim-0.4.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m64.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: onnx, onnxsim
Successfully installed onnx-1.18.0 onnxsim-0.4.36


In [24]:
import torch
import torch.onnx

# Create the model and set to evaluation mode
model = FastSal()
model.eval()

# Dummy input
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "hvpnet.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['classifier1', 'classifier2', 'classifier3', 'classifier4'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'classifier1': {0: 'batch_size'},
        'classifier2': {0: 'batch_size'},
        'classifier3': {0: 'batch_size'},
        'classifier4': {0: 'batch_size'},
    }
)
print("✅ ONNX model exported as .onnx")


✅ ONNX model exported as fastsal.onnx


In [18]:

import torch
import torch.onnx

# Create the model and set to evaluation mode
model = U2NET()
model.eval()

# Dummy input
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "u2net.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['out1', 'out2', 'out3', 'out4', 'out5', 'out6'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'out1': {0: 'batch_size'},
        'out2': {0: 'batch_size'},
        'out3': {0: 'batch_size'},
        'out4': {0: 'batch_size'},
        'out5': {0: 'batch_size'},
        'out6': {0: 'batch_size'},
    }
)
print("✅ ONNX model exported as .onnx")


  src = F.upsample(src,size=tar.shape[2:],mode='bilinear')


✅ ONNX model exported as .onnx


In [None]:

import torch
import torch.onnx

# Create the model and set to evaluation mode
model = U2NET()
model.eval()

# Dummy input
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "u2net.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['out1', 'out2', 'out3', 'out4', 'out5', 'out6'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'out1': {0: 'batch_size'},
        'out2': {0: 'batch_size'},
        'out3': {0: 'batch_size'},
        'out4': {0: 'batch_size'},
        'out5': {0: 'batch_size'},
        'out6': {0: 'batch_size'},
    }
)
print("✅ ONNX model exported as .onnx")


In [32]:
import torch
import torch.onnx

# Create the model and set to evaluation mode
model = SeaNet(pretrained=False)
model.eval()

# Dummy input
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "seanet.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['s12', 's34_up', 's5_up', 'sigmoid1', 'sigmoid2', 'sigmoid3', 'edge1', 'edge2'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        's12': {0: 'batch_size'},
        's34_up': {0: 'batch_size'},
        's5_up': {0: 'batch_size'},
        'sigmoid1': {0: 'batch_size'},
        'sigmoid2': {0: 'batch_size'},
        'sigmoid3': {0: 'batch_size'},
        'edge1': {0: 'batch_size'},
        'edge2': {0: 'batch_size'},
    }
)
print("✅ ONNX model exported as .onnx")


✅ ONNX model exported as .onnx


In [47]:
model = ConvNet(config, (224, 224)).eval()

import torch
import torch.onnx

# Create the model and set to evaluation mode
model = ConvNet(config, (224, 224)).eval()
model.eval()

# Dummy input
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "flim_small.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['x'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'x': {0: 'batch_size'},
    }
)
print("✅ ONNX model exported as .onnx")


✅ ONNX model exported as .onnx


In [38]:
import torch
import torch.onnx

# Create the model and set to evaluation mode
model = FastSal()
model.eval()

# Dummy input
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "samnet.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=[ 'output_main', 'output_side1', 'output_side2', 'output_side3', 'output_side4' ],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output_main': {0: 'batch_size'},
        'output_side1': {0: 'batch_size'},
        'output_side2': {0: 'batch_size'},
        'output_side3': {0: 'batch_size'},
        'output_side4': {0: 'batch_size'},
    }
)
print("✅ ONNX model exported as .onnx")


SymbolicValueError: Unsupported: ONNX export of operator adaptive_avg_pool2d, input size not accessible. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues  [Caused by the value 'x defined in (%x : Float(*, 128, *, *, strides=[6272, 49, 7, 1], requires_grad=1, device=cpu) = onnx::Add(%2468, %input.1752), scope: __main__.FastSal::/__main__.VAMM_backbone::context_path/torch.nn.modules.container.Sequential::layer5/__main__.VAMM::layer5.3 # /tmp/ipython-input-35-4151006688.py:203:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Add'.] 
    (node defined in /tmp/ipython-input-35-4151006688.py(203): forward
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1741): _slow_forward
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1762): _call_impl
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1751): _wrapped_call_impl
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/container.py(240): forward
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1741): _slow_forward
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1762): _call_impl
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1751): _wrapped_call_impl
/tmp/ipython-input-35-4151006688.py(158): forward
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1741): _slow_forward
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1762): _call_impl
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1751): _wrapped_call_impl
/tmp/ipython-input-35-4151006688.py(35): forward
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1741): _slow_forward
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1762): _call_impl
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1751): _wrapped_call_impl
/usr/local/lib/python3.11/dist-packages/torch/jit/_trace.py(129): wrapper
/usr/local/lib/python3.11/dist-packages/torch/jit/_trace.py(138): forward
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1762): _call_impl
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py(1751): _wrapped_call_impl
/usr/local/lib/python3.11/dist-packages/torch/jit/_trace.py(1501): _get_trace_graph
/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py(878): _trace_and_get_graph_from_model
/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py(971): _create_jit_graph
/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py(1087): _model_to_graph
/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py(1467): _export
/usr/local/lib/python3.11/dist-packages/torch/onnx/utils.py(529): export
/usr/local/lib/python3.11/dist-packages/torch/onnx/__init__.py(396): export
/tmp/ipython-input-38-2846696901.py(12): <cell line: 0>
/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py(3553): run_code
/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py(3473): run_ast_nodes
/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py(3257): run_cell_async
/usr/local/lib/python3.11/dist-packages/IPython/core/async_helpers.py(78): _pseudo_sync_runner
/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py(3030): _run_cell
/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py(2975): run_cell
/usr/local/lib/python3.11/dist-packages/ipykernel/zmqshell.py(528): run_cell
/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py(383): do_execute
/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py(730): execute_request
/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py(406): dispatch_shell
/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py(499): process_one
/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py(510): dispatch_queue
/usr/lib/python3.11/asyncio/events.py(84): _run
/usr/lib/python3.11/asyncio/base_events.py(1936): _run_once
/usr/lib/python3.11/asyncio/base_events.py(608): run_forever
/usr/local/lib/python3.11/dist-packages/tornado/platform/asyncio.py(205): start
/usr/local/lib/python3.11/dist-packages/ipykernel/kernelapp.py(712): start
/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py(992): launch_instance
/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py(37): <module>
<frozen runpy>(88): _run_code
<frozen runpy>(198): _run_module_as_main
)

    Inputs:
        #0: 2468 defined in (%2468 : Float(*, 128, *, *, strides=[6272, 49, 7, 1], requires_grad=1, device=cpu) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%input.1864, %context_path.layer5.3.fuse.conv.1.weight, %context_path.layer5.3.fuse.conv.1.bias, %context_path.layer5.3.fuse.conv.1.running_mean, %context_path.layer5.3.fuse.conv.1.running_var), scope: __main__.FastSal::/__main__.VAMM_backbone::context_path/torch.nn.modules.container.Sequential::layer5/__main__.VAMM::layer5.3/__main__.convbnrelu::fuse/torch.nn.modules.container.Sequential::conv/torch.nn.modules.batchnorm.BatchNorm2d::conv.1 # /usr/local/lib/python3.11/dist-packages/torch/nn/functional.py:2822:0
    )  (type 'Tensor')
        #1: input.1752 defined in (%input.1752 : Float(*, 128, *, *, strides=[6272, 49, 7, 1], requires_grad=1, device=cpu) = onnx::Add(%2406, %input.1636), scope: __main__.FastSal::/__main__.VAMM_backbone::context_path/torch.nn.modules.container.Sequential::layer5/__main__.VAMM::layer5.2 # /tmp/ipython-input-35-4151006688.py:203:0
    )  (type 'Tensor')
    Outputs:
        #0: x defined in (%x : Float(*, 128, *, *, strides=[6272, 49, 7, 1], requires_grad=1, device=cpu) = onnx::Add(%2468, %input.1752), scope: __main__.FastSal::/__main__.VAMM_backbone::context_path/torch.nn.modules.container.Sequential::layer5/__main__.VAMM::layer5.3 # /tmp/ipython-input-35-4151006688.py:203:0
    )  (type 'Tensor')

In [None]:
#Optional
from onnxsim import simplify
import onnx

model_path = "fastsal.onnx"
onnx_model = onnx.load(model_path)
model_simplified, check = simplify(onnx_model)

if check:
    onnx.save(model_simplified, "fastsal_simplified.onnx")
    print("✅ Simplified model saved as fastsal_simplified.onnx")
else:
    print("❌ Simplification failed.")


Test

In [48]:
import onnxruntime as ort
import numpy as np

# Load the ONNX model
session = ort.InferenceSession("flim_small.onnx")

# Print input/output names
print("Inputs:", [i.name for i in session.get_inputs()])
print("Outputs:", [o.name for o in session.get_outputs()])

# Create a dummy input tensor (batch size 1, 3 channels, 224x224 image)
dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Run inference
outputs = session.run(None, {"input": dummy_input})

# Check output shapes
for i, out in enumerate(outputs):
    print(f"Output {i + 1}: shape = {out.shape}")


Inputs: ['input']
Outputs: ['x']
Output 1: shape = (1, 1, 224, 224)


https://github.com/microsoft/onnxruntime-inference-examples/blob/main/mobile/README.md


In [15]:
!python prepare.py --output_dir /content/outputs

Using cached MobileNet v2-1.0-int8 model from /root/.cache/onnx/hub/validated/vision/classification/mobilenet/model/cc028fe6cae7bc11a4ff53cfc9b79c920e8be65ce33a904ec3e2a8f66d77f95f_mobilenetv2-12-int8.onnx
Using cached MobileNet v2-1.0-fp32 model from /root/.cache/onnx/hub/validated/vision/classification/mobilenet/model/c0c3f76d93fa3fd6580652a45618618a220fced18babf65774ed169de0432ad5_mobilenetv2-12.onnx


In [21]:
val inputStream = assets.open("u2net.onnx")
val modelBytes = inputStream.readBytes()


SyntaxError: invalid syntax (ipython-input-21-1429801123.py, line 1)