
## **Imports and libraries**

In [8]:
# Uncomment and run this cell once to install the library before running the rest of the notebook
# !pip install segmentation-models-pytorch

In [9]:
import segmentation_models_pytorch as smp
import torch


## Implementing FPN with ResNet 34

**Defining the Model:**

In [10]:
# Define the FPN model with ResNet34 encoder
model = smp.FPN(
    encoder_name="resnet34",        # Choose encoder
    encoder_weights="imagenet",     # Use pretrained weights
    decoder_pyramid_channels=256,   # Number of convolution filters in FPN
    decoder_segmentation_channels=128,  # Number of filters in segmentation blocks
    decoder_merge_policy='add',     # Merge policy
    decoder_dropout=0.2,            # Spatial dropout rate
    in_channels=3,                  # Number of input channels
    classes=3,                      # Number of output classes
    activation='softmax2d'          # Per-pixel softmax over the class dimension (avoids the implicit-dim warning)
)




**Creating the Input Tensor:**



In [11]:
# Define input tensor with random values
x = torch.rand(1, 3, 256, 256)  # Example input (batch_size, channels, height, width)


**Getting the Output:**

In [12]:
# Get the model output
mask = model(x)

# Print the shape of the output
print(mask.shape)  # Should be (1, num_classes, height, width)


torch.Size([1, 3, 256, 256])




## **Defining the variables**

In [13]:
# Auxiliary parameters
aux_params = dict(
    pooling='avg',
    dropout=0.5,
    activation='sigmoid',
    classes=4,
)
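
To make these settings concrete, here is a minimal sketch of what an auxiliary classification head with these parameters might look like. The function `make_aux_head` and its layer order are assumptions for illustration, not the library's actual code, and the 512-channel input assumes the deepest ResNet34 feature map.

```python
import torch
import torch.nn as nn


def make_aux_head(in_channels: int, classes: int, dropout: float) -> nn.Sequential:
    """Hypothetical stand-in for the auxiliary classification head."""
    return nn.Sequential(
        nn.AdaptiveAvgPool2d(1),          # pooling='avg': (N, C, H, W) -> (N, C, 1, 1)
        nn.Flatten(),                     # (N, C, 1, 1) -> (N, C)
        nn.Dropout(p=dropout),            # dropout=0.5
        nn.Linear(in_channels, classes),  # classes=4
        nn.Sigmoid(),                     # activation='sigmoid'
    )


head = make_aux_head(in_channels=512, classes=4, dropout=0.5)
head.eval()  # disable dropout so the forward pass is deterministic

# Fake deepest encoder feature map (assumed: 512 channels, 8x8 for a 256x256 input)
features = torch.randn(2, 512, 8, 8)
with torch.no_grad():
    label = head(features)

print(label.shape)  # torch.Size([2, 4])
```

The head returns one score per class for each image, which is why the `label` outputs later in the notebook have shape `(batch_size, 4)`.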

In [14]:
# Defining correct decoder_channels for encoder_depth=4

decoder_channels = [256, 128, 64, 32]

## Detailed Overview of Segmentation Models

#### Common Parameters
- `encoder_name`: Name of the classification model used as the encoder (e.g., `resnet34`).
- `encoder_depth`: Number of stages in the encoder, ranging from 3 to 5.
- `encoder_weights`: Pre-trained weights (e.g., `imagenet`).
- `decoder_channels`: List of integers specifying the `in_channels` parameter for convolutions in the decoder. For Unet-style decoders, the list length must match `encoder_depth`.
- `decoder_use_batchnorm`: Whether to use `BatchNorm2d` in the decoder.
- `decoder_attention_type`: Attention module used in the decoder (options: `None`, `scse`).
- `in_channels`: Number of input channels (default is 3).
- `classes`: Number of classes for the output mask.
- `activation`: Activation function applied after the final convolution layer.
- `aux_params`: Parameters for auxiliary output (classification head).
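
The relationship between `encoder_depth` and `decoder_channels` can be sketched in plain Python. The helpers below are hypothetical (they are not part of the library) and assume each encoder stage halves the spatial resolution, as it does for ResNet-style encoders.

```python
def encoder_stage_sizes(height, width, encoder_depth):
    """Spatial size of the feature map after each encoder stage (assumes 2x downsampling per stage)."""
    if not 3 <= encoder_depth <= 5:
        raise ValueError("encoder_depth must be between 3 and 5")
    return [(height // 2**s, width // 2**s) for s in range(1, encoder_depth + 1)]


def check_decoder_channels(decoder_channels, encoder_depth):
    """Raise if the decoder does not have exactly one block per encoder stage."""
    if len(decoder_channels) != encoder_depth:
        raise ValueError(
            f"encoder_depth is {encoder_depth}, "
            f"but decoder_channels has {len(decoder_channels)} entries"
        )


sizes = encoder_stage_sizes(256, 256, encoder_depth=4)
print(sizes)  # [(128, 128), (64, 64), (32, 32), (16, 16)]

check_decoder_channels([256, 128, 64, 32], encoder_depth=4)  # passes silently
```

With `encoder_depth=4`, the decoder has four upsampling blocks to bring the 16x16 feature map back to 256x256, so `decoder_channels` needs four entries.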


## **Code for Different Models**

### 1. Unet
- **Class**: `segmentation_models_pytorch.Unet`
- **Description**: Unet is a fully convolutional network designed for image semantic segmentation. It consists of an encoder (downsampling) and a decoder (upsampling) connected via skip connections.
- **Key Parameters**: See common parameters.
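
As a rough illustration of a skip connection, the sketch below upsamples a deep feature map and concatenates it with a shallower encoder feature before convolving. `TinyDecoderBlock` is a hypothetical, simplified block written for this note; the library's decoder blocks contain more layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyDecoderBlock(nn.Module):
    """Simplified Unet-style decoder block: upsample, concatenate the skip, convolve."""

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = F.interpolate(x, scale_factor=2, mode="nearest")  # double the spatial size
        x = torch.cat([x, skip], dim=1)                       # skip connection: stack channels
        return self.conv(x)


deep = torch.randn(1, 64, 16, 16)   # low-resolution decoder input
skip = torch.randn(1, 32, 32, 32)   # higher-resolution encoder feature
block = TinyDecoderBlock(in_channels=64, skip_channels=32, out_channels=16)
out = block(deep, skip)

print(out.shape)  # torch.Size([1, 16, 32, 32])
```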

In [15]:
# Create the model with custom encoder depth and decoder channels
model = smp.Unet('resnet34', encoder_depth=4, decoder_channels=decoder_channels, classes=4, aux_params=aux_params)

# Define input tensor 'x' with random values
x = torch.randn(1, 3, 256, 256)  # example input (batch size, channels, height, width)

# Get the output
mask, label = model(x)

# Print the shapes of the outputs
print(mask.shape)  # Should be (1, num_classes, height, width)
print(label.shape)  # Should be (1, num_aux_classes)

torch.Size([1, 4, 256, 256])
torch.Size([1, 4])


### 2. Unet++
- **Class**: `segmentation_models_pytorch.UnetPlusPlus`
- **Description**: Unet++ improves upon Unet by using nested and dense skip connections to better capture multi-scale features.
- **Key Parameters**: Same as Unet, with a more complex decoder.
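
The nested, dense skip pattern can be enumerated in a few lines of plain Python. This sketch follows the node notation of the Unet++ paper, where node `(i, j)` sits at depth `i` and position `j` along the skip path. It is an assumption that the library's decoder wires its blocks exactly this way, and `unetpp_node_inputs` is a hypothetical helper.

```python
def unetpp_node_inputs(levels):
    """Map each decoder node (i, j), j >= 1, to the nodes that feed it."""
    inputs = {}
    for j in range(1, levels + 1):
        for i in range(0, levels - j + 1):
            same_level = [(i, k) for k in range(j)]  # dense skips from earlier nodes at depth i
            from_below = (i + 1, j - 1)              # upsampled output of the deeper node
            inputs[(i, j)] = same_level + [from_below]
    return inputs


nodes = unetpp_node_inputs(levels=4)
print(len(nodes))     # 10
print(nodes[(0, 2)])  # [(0, 0), (0, 1), (1, 1)]
```

A plain Unet with the same depth would have only four decoder nodes, each fed by one encoder feature and one deeper decoder output.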

In [16]:
# Create the model with custom encoder depth and decoder channels
model = smp.UnetPlusPlus('resnet34', encoder_depth=4, decoder_channels=decoder_channels, classes=4, aux_params=aux_params)

# Define input tensor 'x' with random values
x = torch.randn(1, 3, 256, 256)  # example input (batch size, channels, height, width)

# Get the output
mask, label = model(x)

# Print the shapes of the outputs
print(mask.shape)  # Should be (1, num_classes, height, width)
print(label.shape)  # Should be (1, num_aux_classes)

torch.Size([1, 4, 256, 256])
torch.Size([1, 4])


### 3. MAnet
- **Class**: `segmentation_models_pytorch.MAnet`
- **Description**: Multi-scale Attention Net (MA-Net) captures rich contextual dependencies using Position-wise Attention Block (PAB) and Multi-scale Fusion Attention Block (MFAB).
- **Key Parameters**:
  - `decoder_pab_channels`: Number of channels for the PAB module in the decoder.
  - Additional parameters similar to Unet.

In [17]:
model = smp.MAnet('resnet34', encoder_depth=4, decoder_channels=decoder_channels, classes=4, aux_params=aux_params)

# Define input tensor 'x' with random values
x = torch.randn(1, 3, 256, 256)  # example input (batch size, channels, height, width)

# Get the output
mask, label = model(x)

# Print the shapes of the outputs
print(mask.shape)  # Should be (1, num_classes, height, width)
print(label.shape)  # Should be (1, num_aux_classes)

torch.Size([1, 4, 256, 256])
torch.Size([1, 4])


### 4. Linknet
- **Class**: `segmentation_models_pytorch.Linknet`
- **Description**: Linknet is a lightweight architecture designed for fast inference, using summation for fusing decoder blocks with skip connections.
- **Key Parameters**: Similar to Unet.

In [18]:
# Create the model
model = smp.Linknet('resnet34', classes=4, aux_params=aux_params)

# Define input tensor 'x' with random values
x = torch.randn(1, 3, 256, 256)  # example input (batch size, channels, height, width)

# Get the output
mask, label = model(x)

# Print the shapes of the outputs
print(mask.shape)  # Should be (1, num_classes, height, width)
print(label.shape)  # Should be (1, num_aux_classes)

torch.Size([1, 4, 256, 256])
torch.Size([1, 4])


### 5. FPN
- **Class**: `segmentation_models_pytorch.FPN`
- **Description**: Feature Pyramid Network (FPN) builds a top-down pyramid with lateral connections, merging encoder features from several scales before predicting the mask.
- **Key Parameters**:
  - `decoder_pyramid_channels`: Number of convolution filters in the FPN.
  - `decoder_segmentation_channels`: Number of convolution filters in the segmentation blocks.
  - `decoder_merge_policy`: How to merge pyramid features (options: `add`, `cat`).
  - `decoder_dropout`: Spatial dropout rate for the FPN.
  - Additional parameters similar to Unet.
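
The effect of `decoder_merge_policy` on the channel count can be seen with a small sketch. It assumes four pyramid levels that have already been brought to the same resolution with 128 channels each; `merge` is a hypothetical helper, not a library function.

```python
import torch


def merge(features, policy):
    """Combine same-sized feature maps by summing ('add') or stacking channels ('cat')."""
    if policy == "add":
        return torch.stack(features).sum(dim=0)
    if policy == "cat":
        return torch.cat(features, dim=1)
    raise ValueError(f"Unknown merge policy: {policy}")


# Four pyramid levels, assumed already resized to 64x64 with 128 channels each
features = [torch.randn(1, 128, 64, 64) for _ in range(4)]

added = merge(features, "add")
concatenated = merge(features, "cat")

print(added.shape)         # torch.Size([1, 128, 64, 64])
print(concatenated.shape)  # torch.Size([1, 512, 64, 64])
```

`add` keeps the channel count fixed, while `cat` multiplies it by the number of levels, which makes the following layers heavier.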

In [19]:
# Create the model
model = smp.FPN('resnet34', classes=4, aux_params=aux_params)

# Define input tensor 'x' with random values
x = torch.randn(1, 3, 256, 256)  # example input (batch size, channels, height, width)

# Get the output
mask, label = model(x)

# Print the shapes of the outputs
print(mask.shape)  # Should be (1, num_classes, height, width)
print(label.shape)  # Should be (1, num_aux_classes)


torch.Size([1, 4, 256, 256])
torch.Size([1, 4])


## Explanation and Corrected Code for Each Model

#### Common Issue I Faced and Solution:
The error occurs when a model in training mode receives a batch of size 1. Batch normalization needs more than one value per channel to compute batch statistics, so it fails in layers whose input has been pooled down to 1×1 (PyTorch typically reports `Expected more than 1 value per channel when training`). The solution is either to increase the batch size or to switch the model to evaluation mode for inference, where batch normalization uses its stored running statistics instead.
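
The behavior can be reproduced with a single `BatchNorm2d` layer, independent of any segmentation model. This is a minimal sketch in which the 1x1 input stands in for a globally pooled feature map. Whether that is exactly the layer that failed in the original run is an assumption.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8)
pooled = torch.randn(1, 8, 1, 1)  # batch size 1, one value per channel

# Training mode: batch statistics cannot be computed from a single value
bn.train()
try:
    bn(pooled)
    train_error = None
except ValueError as exc:
    train_error = str(exc)
print(train_error)  # message starts with "Expected more than 1 value per channel when training"

# Fix 1: a larger batch works in training mode
out_batch = bn(torch.randn(4, 8, 1, 1))
print(out_batch.shape)  # torch.Size([4, 8, 1, 1])

# Fix 2: evaluation mode uses running statistics, so batch size 1 is fine
bn.eval()
with torch.no_grad():
    out_eval = bn(pooled)
print(out_eval.shape)  # torch.Size([1, 8, 1, 1])
```

This also explains why the earlier models ran with a batch size of 1: their batch-normalized feature maps still had more than one spatial position per channel.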

---

#### DeepLabV3

**Explanation:**

- **Auxiliary Parameters:** Configures the auxiliary classification output, which adds an additional classification head to the model.
- **Model Creation:** Creates a DeepLabV3 model with a ResNet34 encoder.
- **Input Tensor:** Defines the input tensor with random values. Here, we use a batch size of 4 to avoid issues with batch normalization.
- **Output Handling:** Checks if the output is a tuple (when auxiliary parameters are used) and prints the shapes of the outputs.

---

#### DeepLabV3+

**Explanation:**

- **Auxiliary Parameters:** Configures the auxiliary classification output.
- **Model Creation:** Creates a DeepLabV3+ model with a ResNet34 encoder.
- **Model Evaluation Mode:** Sets the model to evaluation mode, so batch normalization uses its running statistics instead of batch statistics and dropout is disabled.
- **Input Tensor:** Defines the input tensor with random values.
- **Output Handling:** Checks if the output is a tuple and prints the shapes of the outputs.

---

#### PAN

**Explanation:**

- **Auxiliary Parameters:** Configures the auxiliary classification output.
- **Model Creation:** Creates a PAN model with a ResNet34 encoder.
- **Options:** Provides two options based on whether the model is used for training or inference.
  - **Training Mode:** Increase the batch size to more than 1 to avoid issues with batch normalization.
  - **Evaluation Mode:** Set the model to evaluation mode to use batch normalization with a batch size of 1.

---

#### Summary

- **DeepLabV3 and DeepLabV3+:** Use appropriate auxiliary parameters and check for tuple output. Adjust batch size as needed and use evaluation mode for inference to avoid batch normalization issues.
- **PAN:** Increase batch size for training or use evaluation mode for inference to avoid batch normalization issues.
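
Since the tuple check recurs in every cell below, it could be factored into a small helper. `split_outputs` is a hypothetical name introduced here, and the strings in the demo are placeholders standing in for real tensors.

```python
def split_outputs(output):
    """Return (mask, label); label is None when the model has no auxiliary head."""
    if isinstance(output, tuple):
        mask, label = output
        return mask, label
    return output, None


# Placeholder values standing in for model outputs
mask, label = split_outputs(("mask_tensor", "label_tensor"))
print(mask, label)  # mask_tensor label_tensor

mask_only, no_label = split_outputs("mask_tensor")
print(mask_only, no_label)  # mask_tensor None
```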



### 7. PAN
- **Class**: `segmentation_models_pytorch.PAN`
- **Description**: Pyramid Attention Network (PAN) focuses on aggregating multi-scale features with global attention mechanisms.
- **Key Parameters**:
  - `encoder_output_stride`: Downsampling factor between the input image and the deepest encoder features; PAN supports 16 or 32.
  - `decoder_channels`: Number of convolution filters in the decoder.
  - Additional parameters similar to Unet.
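
The output stride translates directly into the size of the deepest encoder feature map. The helper below is a hypothetical sketch; the restriction to 16 or 32 reflects the assumption that PAN accepts only those two values.

```python
def deepest_feature_size(height, width, output_stride):
    """Spatial size of the deepest encoder feature map for a given output stride."""
    if output_stride not in (16, 32):
        raise ValueError("output_stride is assumed to be 16 or 32")
    return height // output_stride, width // output_stride


size_16 = deepest_feature_size(256, 256, output_stride=16)
size_32 = deepest_feature_size(256, 256, output_stride=32)

print(size_16)  # (16, 16)
print(size_32)  # (8, 8)
```

A stride of 16 keeps a finer feature map at the cost of more computation in the decoder.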


### **For Training (Increase Batch Size)**

In [20]:
import torch
import segmentation_models_pytorch as smp

# Define auxiliary parameters
aux_params = dict(
    pooling='avg',
    dropout=0.5,
    activation='sigmoid',
    classes=4,
)

# Create the model with supported encoder output stride
model = smp.PAN(
    encoder_name='resnet34',
    encoder_output_stride=16,  # PAN supports an output stride of 16 or 32
    classes=4,
    aux_params=aux_params
)

# Define input tensor 'x' with random values for a larger batch size
x = torch.randn(4, 3, 256, 256)  # Increase batch size from 1 to 4

# Get the output
output = model(x)

# Print the shapes of the outputs
if isinstance(output, tuple):
    mask, label = output
    print(mask.shape)
    print(label.shape)
else:
    mask = output
    print(mask.shape)


torch.Size([4, 4, 256, 256])
torch.Size([4, 4])


### **For Inference (Use Evaluation Mode)**


In [21]:
import torch
import segmentation_models_pytorch as smp

# Define auxiliary parameters
aux_params = dict(
    pooling='avg',
    dropout=0.5,
    activation='sigmoid',
    classes=4,
)

# Create the model with supported encoder output stride
model = smp.PAN(
    encoder_name='resnet34',
    encoder_output_stride=16,
    classes=4,
    aux_params=aux_params
)

model.eval()  # Set model to evaluation mode

# Define input tensor 'x' with random values
x = torch.randn(1, 3, 256, 256)  # Can use a batch size of 1 in evaluation mode

# Get the output
output = model(x)

# Print the shapes of the outputs
if isinstance(output, tuple):
    mask, label = output
    print(mask.shape)
    print(label.shape)
else:
    mask = output
    print(mask.shape)


torch.Size([1, 4, 256, 256])
torch.Size([1, 4])


### 8. DeepLabV3
- **Class**: `segmentation_models_pytorch.DeepLabV3`
- **Description**: DeepLabV3 utilizes atrous convolution to capture multi-scale context by increasing the receptive field.
- **Key Parameters**: Similar to Unet.
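
Atrous (dilated) convolution can be illustrated with a plain `nn.Conv2d`. This sketch only shows how dilation widens the receptive field without adding weights. It is not the DeepLabV3 decoder itself, and `effective_kernel_size` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def effective_kernel_size(kernel_size, dilation):
    """Span covered by a dilated kernel along one axis."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)


x = torch.randn(1, 1, 32, 32)
results = []
for dilation in (1, 2, 4):
    # padding=dilation keeps the spatial size unchanged for a 3x3 kernel
    conv = nn.Conv2d(1, 1, kernel_size=3, dilation=dilation, padding=dilation)
    y = conv(x)
    n_weights = conv.weight.numel()
    results.append((dilation, effective_kernel_size(3, dilation), tuple(y.shape), n_weights))

for row in results:
    print(row)
# (1, 3, (1, 1, 32, 32), 9)
# (2, 5, (1, 1, 32, 32), 9)
# (4, 9, (1, 1, 32, 32), 9)
```

Each layer has the same nine weights, but the larger dilation rates cover a 5x5 or 9x9 area, which is how parallel branches with different rates can capture context at several scales.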

In [22]:
# Define auxiliary parameters
aux_params = dict(
    pooling='avg',  # Global pooling for the classification head ('avg' or 'max')
    dropout=0.5,
    activation='sigmoid',
    classes=4,
)

# Create the model
model = smp.DeepLabV3(
    encoder_name='resnet34',
    classes=4,
    aux_params=aux_params  # Set to None if you want to disable the auxiliary classifier
)

# Define input tensor 'x' with random values
x = torch.randn(4, 3, 256, 256)  # Increased batch size

# Get the output
output = model(x)  # A (mask, label) tuple when aux_params is set, otherwise just the mask

# Check outputs
if isinstance(output, tuple):
    mask, label = output
    print(mask.shape, label.shape)
else:
    mask = output
    print(mask.shape)


torch.Size([4, 4, 256, 256]) torch.Size([4, 4])



### 9. DeepLabV3+
- **Class**: `segmentation_models_pytorch.DeepLabV3Plus`
- **Description**: An enhanced version of DeepLabV3, which includes a decoder module to improve object localization.
- **Key Parameters**: Similar to Unet.


In [23]:
import torch
import segmentation_models_pytorch as smp

# Define auxiliary parameters
aux_params = dict(
    pooling='avg',  # Global pooling for the classification head ('avg' or 'max')
    dropout=0.5,
    activation='sigmoid',
    classes=4,
)

# Create the model
model = smp.DeepLabV3Plus(
    encoder_name='resnet34',
    classes=4,
    aux_params=aux_params
)
model.eval()  # Set model to evaluation mode

# Define input tensor 'x' with random values
x = torch.randn(1, 3, 256, 256)  # example input (batch size, channels, height, width)

# Get the output
output = model(x)  # A (mask, label) tuple when aux_params is set, otherwise just the mask

# Print the shapes of the outputs
if isinstance(output, tuple):
    mask, label = output
    print(mask.shape)
    print(label.shape)
else:
    mask = output
    print(mask.shape)


torch.Size([1, 4, 256, 256])
torch.Size([1, 4])
