In [1]:
pip install torchinfo

Note: you may need to restart the kernel to use updated packages.


In [11]:
import torch
import torch.nn as nn
from torchinfo import summary
device="cuda" if torch.cuda.is_available else "cpu"

class conv_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.relu(self.bn1(self.conv1(inputs)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x

class encoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x

class segUnet(nn.Module):
    def __init__(self, num_classes, in_channels=3, depth=5, start_filts=64):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        """ Encoders """
        self.encoders = nn.ModuleList([encoder_block(in_channels, start_filts)])
        self.encoders.extend([encoder_block(start_filts * (2 ** i), start_filts * (2 ** (i + 1))) for i in range(depth - 1)])

        """ Bottleneck """
        self.bottleneck = conv_block(start_filts * (2 ** (depth - 1)), start_filts * (2 ** depth))

        """ Decoders """
        self.decoders = nn.ModuleList([decoder_block(start_filts * (2 ** i), start_filts * (2 ** (i - 1))) for i in range(depth, 0, -1)])

        """ Classifier """
        self.outputs = nn.Conv2d(start_filts, num_classes, kernel_size=1)

    def forward(self, inputs):
        skips = []
        x = inputs
        for encoder in self.encoders:
            x, p = encoder(x)
            skips.append(x)
            x = p

        x = self.bottleneck(x)

        for i, decoder in enumerate(self.decoders):
            x = decoder(x, skips[-(i+1)])

        outputs = self.outputs(x)
        return outputs

# Check for CUDA availability
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate the model and move it to the appropriate device
model = segUnet(num_classes=3, in_channels=3, depth=3, start_filts=8).to(device)

# Print model summary using torchinfo
summary(model, input_size=(1, 3,16, 16), device=device)
x=torch.randn((1,3,16,16))
print(model(x))

tensor([[[[ 0.2934,  0.7496,  0.0298,  0.6885,  0.2576,  0.4183,  0.6098,
            0.5581,  0.4596,  0.6959,  0.7601,  0.6012,  0.1844,  0.6143,
            0.3729,  0.3686],
          [ 0.2226,  0.2489,  0.7342, -0.0440,  0.5824,  0.3492,  0.7194,
            0.0847,  0.3853,  0.9299, -0.0835,  0.7122,  0.4734,  0.4886,
            0.0520,  0.5180],
          [ 0.3903,  0.9785,  0.2720,  0.3218,  0.4021,  0.5943,  0.2930,
            0.8108, -0.0618,  0.5473,  0.7174,  0.3333,  0.2123,  1.3025,
            0.6022,  0.4878],
          [ 0.4036,  0.3093,  0.2662,  0.4898,  0.1147,  0.3808,  0.9981,
            0.1506,  1.2372,  1.1260,  0.4228,  0.0624,  0.9751,  0.1683,
            0.4065,  0.7740],
          [ 0.2886,  0.2034, -0.1459,  0.4930,  0.8807, -0.5096,  0.0829,
            0.3380,  0.3801,  0.9325,  0.6704,  0.5806,  0.0240,  0.2116,
           -0.0629,  0.5145],
          [ 0.4846,  0.4377,  0.7570,  0.3116,  0.6457,  0.0429,  0.8268,
           -0.2849,  0.6896,  0.4779

In [17]:
from urllib.request import urlopen
from PIL import Image
import timm
import torch
import torch.nn as nn
from torchinfo import summary
import segmentation_models_pytorch as smp
class Res34Unet(nn.Module):
    def __init__(self, num_classes, in_channels=3,start_filts=64, depth=4, negative_slope=0.01):
        super().__init__()

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.depth = depth
        self.negative_slope=negative_slope
        self.start_filts=start_filts 
        self.internal_masks=[]

        decoder_channels=tuple([(2**i)*self.start_filts for i in range(self.depth,0,-1)])
        self.model =  smp.Unet(
            encoder_name="resnet152",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights="imagenet",
            decoder_use_batchnorm=True,
            decoder_channels=decoder_channels,
            encoder_depth=self.depth ,  # use `imagenet` pre-trained weights for encoder initialization
            in_channels=self.in_channels,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=num_classes,
        )

    def forward(self, inputs):
        return self.model(inputs)
    
device = "cuda" if torch.cuda.is_available() else "cpu"

model6=Res34Unet(num_classes=3,in_channels=3,start_filts=4,depth=4)
# Print model summary using torchinfo
summary(model6, input_size=(1, 3,64,64), device=device)

Layer (type:depth-idx)                             Output Shape              Param #
Res34Unet                                          [1, 3, 64, 64]            --
├─Unet: 1-1                                        [1, 3, 64, 64]            --
│    └─ResNetEncoder: 2-1                          [1, 3, 64, 64]            13,114,368
│    │    └─Conv2d: 3-1                            [1, 64, 32, 32]           9,408
│    │    └─BatchNorm2d: 3-2                       [1, 64, 32, 32]           128
│    │    └─ReLU: 3-3                              [1, 64, 32, 32]           --
│    │    └─MaxPool2d: 3-4                         [1, 64, 16, 16]           --
│    │    └─Sequential: 3-5                        [1, 64, 16, 16]           221,952
│    │    └─Sequential: 3-6                        [1, 128, 8, 8]            1,116,416
│    │    └─Sequential: 3-7                        [1, 256, 4, 4]            6,822,400
│    └─UnetDecoder: 2-2                            [1, 64, 64, 64]           --
│   

In [13]:
torch.set_printoptions(threshold=torch.inf)

x=torch.randn((1,3,16,16))
print(model6(x))

tensor([[[[ 5.1729e-02, -8.2027e-01, -6.7957e-01, -1.1168e+00, -6.9178e-01,
           -2.4227e-01, -6.0260e-01, -3.4603e-01, -2.2835e-01,  7.2157e-01,
           -2.5448e-01, -7.2431e-01,  3.0075e-01, -1.3035e+00, -5.3631e-01,
            5.6581e-01],
          [ 4.9090e-01, -9.1410e-01, -7.4854e-01, -6.9451e-01, -4.7431e-01,
           -1.0747e+00, -1.0025e+00, -1.9974e-01, -1.0810e+00, -5.8953e-01,
           -1.6099e-01, -2.9569e-01,  1.2158e-01, -5.3102e-01, -6.1296e-01,
            8.3656e-01],
          [-6.0371e-02, -4.7486e-01, -8.1461e-02, -9.5704e-01, -6.0165e-01,
           -3.5285e-01, -7.1410e-01, -1.1587e+00, -1.4112e+00, -1.6577e+00,
           -1.0785e+00, -7.6726e-01,  2.3278e-01,  5.0061e-01, -3.2182e-01,
            8.1539e-01],
          [ 7.1781e-01, -7.5229e-01, -7.2120e-01, -3.1911e-01, -6.3669e-01,
           -8.4784e-01, -1.4153e+00, -1.4862e+00, -2.0706e+00, -1.6637e+00,
           -9.1859e-01, -6.5429e-01,  1.0354e-01,  1.3613e-01, -1.4652e-01,
            8