# Task 2.1

“Given a video of the modal mask of an object, predict the amodal mask of the same object”

Inputs:
- RGB Frames - N 3-channel images (N, 3, 256, 256)
- N Modal Masks of Object i - N Binary (1-channel) images (N, 1, 256, 256)

Outputs:
- N Amodal Masks of Object i - N Binary (1-channel) images (N, 1, 256, 256)

In [1]:
# PyTorch, Torchvision
import torch
from torch import nn
from torchvision.transforms import ToPILImage, ToTensor
from torchvision.utils import make_grid
from torchvision.io import write_video

# Common
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import random
import json
from IPython.display import Video

# Shared torchvision converters used by later cells:
# tensor_to_image turns a tensor into a PIL image, image_to_tensor does the reverse.
tensor_to_image, image_to_tensor = ToPILImage(), ToTensor()

In [2]:
# Report whether PyTorch can see a CUDA device (the recorded run shows False, i.e. CPU only).
torch.cuda.is_available()

False

### Model

In [3]:

def calculate_metrics(amodal_mask_preds,
                      amodal_mask_labels,
                      amodal_content_preds,
                      amodal_content_labels):
    """Placeholder hook for non-loss evaluation metrics; currently returns None.

    Intended for mask metrics (mIoU, accuracy, J&F scores) and for image
    generation metrics between predicted and ground-truth amodal content
    (Inception Score, FID, LPIPS, SSIM, PSNR). Ready-made implementations of
    these exist in libraries such as TorchMetrics.

    Args:
        amodal_mask_preds: predicted amodal mask logits.
        amodal_mask_labels: ground-truth amodal masks.
        amodal_content_preds: predicted amodal content.
        amodal_content_labels: ground-truth amodal content.
    """
    # Nothing is computed yet; falling off the end yields None.


#### Class - Unet_Image

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class conv2d_inplace_spatial(nn.Module):
    """Two 3x3 conv + BatchNorm stages followed by a resampling module.

    Layout: Conv3x3 -> BN -> GELU -> Conv3x3 -> BN -> ``activation`` -> ``pooling_function``.
    Both convolutions use padding 1 and stride 1, so they preserve height and
    width; any change in spatial size comes from ``pooling_function``
    (e.g. ``nn.MaxPool2d((2, 2))`` halves it, a 2x upsampler doubles it).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        pooling_function: module applied last. It is inserted as-is, so the
            same instance may be shared between blocks (fine for
            parameter-free modules such as pooling or interpolation).
        activation: module applied after the second BatchNorm. Defaults to a
            new ``nn.GELU()`` for each block; pass ``nn.Identity()`` to emit
            raw logits.
    """

    def __init__(self, in_channels, out_channels, pooling_function, activation=None):
        super().__init__()
        if activation is None:
            # Create the default per instance: a module used as a default
            # argument would be built once at definition time and shared by
            # every block constructed without an explicit activation.
            activation = nn.GELU()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
            activation,
            pooling_function,
        )

    def forward(self, x):
        """Apply the conv/BN stack and the trailing resampler to ``x``."""
        return self.double_conv(x)
    
class Upscale(nn.Module):
    """Module wrapper around ``F.interpolate`` with a fixed ``scale_factor``.

    Being an ``nn.Module`` lets the resize be placed inside ``nn.Sequential``.

    Args:
        scale_factor: multiplier per spatial dimension, e.g. ``(2, 2)``.
        mode: interpolation mode passed to ``F.interpolate``.
        align_corners: forwarded only for the interpolating modes that accept
            it (linear, bilinear, bicubic, trilinear). ``F.interpolate`` raises
            ``ValueError`` if it is not None for other modes such as
            ``'nearest'``, so it is dropped for them.
    """

    # Modes for which F.interpolate accepts a non-None align_corners.
    _ALIGN_CORNERS_MODES = frozenset({'linear', 'bilinear', 'bicubic', 'trilinear'})

    def __init__(self, scale_factor=(2, 2), mode='bilinear', align_corners=False):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` by ``self.scale_factor`` using ``self.mode``."""
        align_corners = self.align_corners if self.mode in self._ALIGN_CORNERS_MODES else None
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=align_corners)

class Unet_Image(nn.Module):
    """Single-frame U-Net predicting an amodal mask and amodal RGB content.

    The input is an RGB frame concatenated channel-wise with its modal mask
    (``in_channels`` = 4 by default). Four encoder stages (32/64/128/256
    channels) each halve H and W. The decoder upsamples back with skip
    connections and ends in two heads without a final activation:

    * amodal mask logits: (b, 1, H, W)
    * amodal content:     (b, 3, H, W)

    H and W should be divisible by 16 so the skip connections line up.

    Args:
        in_channels: channels of the concatenated input image.
        mask_content_preds: if True, ``forward`` multiplies the content
            prediction by the rounded sigmoid of the mask prediction.
    """

    def __init__(self, in_channels=4, mask_content_preds=False):
        super().__init__()
        # Parameter-free 2x downsampler, shared by all encoder stages.
        self.mpool_2 = nn.MaxPool2d((2, 2))

        # Encoder: each stage changes the channel count and halves H and W.
        self.down1 = conv2d_inplace_spatial(in_channels, 32, self.mpool_2)
        self.down2 = conv2d_inplace_spatial(32, 64, self.mpool_2)
        self.down3 = conv2d_inplace_spatial(64, 128, self.mpool_2)
        self.down4 = conv2d_inplace_spatial(128, 256, self.mpool_2)

        # Parameter-free 2x bilinear upsampler, shared by all decoder stages.
        self.upscale_2 = Upscale(scale_factor=(2, 2), mode='bilinear', align_corners=False)

        # Decoder: each stage doubles H and W. up2, up3 and the heads take twice
        # the previous stage's output channels because a skip connection is
        # concatenated onto dim 1 before them.
        self.up1 = conv2d_inplace_spatial(256, 128, self.upscale_2)
        self.up2 = conv2d_inplace_spatial(256, 64, self.upscale_2)
        self.up3 = conv2d_inplace_spatial(128, 32, self.upscale_2)

        # Output heads: Identity activation, so they return raw logits/values.
        self.up4_amodal_mask = conv2d_inplace_spatial(64, 1, self.upscale_2, activation=nn.Identity())
        self.up4_amodal_content = conv2d_inplace_spatial(64, 3, self.upscale_2, activation=nn.Identity())

        # Should forward() mask the content prediction by the predicted mask?
        self.mask_content_preds = mask_content_preds

        # Losses. NOTE: despite its name, ``mse_loss`` is an L1 (mean absolute
        # error) loss; the attribute name is kept for existing references.
        self.mse_loss = nn.L1Loss()
        self.bce_loss = nn.BCEWithLogitsLoss()

    def encode(self, x):
        """Run the encoder and return the four multi-scale feature maps.

        For ``x`` of shape (b, in_channels, H, W) the outputs are
        x1 (b, 32, H/2, W/2), x2 (b, 64, H/4, W/4),
        x3 (b, 128, H/8, W/8) and x4 (b, 256, H/16, W/16).
        """
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        return x1, x2, x3, x4

    def decode(self, h1, h2, h3, h4):
        """Decode ``encode`` outputs into (mask logits, content predictions).

        Returns tensors of shape (b, 1, H, W) and (b, 3, H, W).
        """
        h4 = self.up1(h4)                     # (b, 256, H/16, W/16) -> (b, 128, H/8, W/8)
        h34 = torch.cat((h3, h4), dim=1)      # (b, 256, H/8, W/8)

        h34 = self.up2(h34)                   # -> (b, 64, H/4, W/4)
        h234 = torch.cat((h2, h34), dim=1)    # (b, 128, H/4, W/4)

        h234 = self.up3(h234)                 # -> (b, 32, H/2, W/2)
        h1234 = torch.cat((h1, h234), dim=1)  # (b, 64, H/2, W/2)

        logits_amodal_mask = self.up4_amodal_mask(h1234)        # (b, 1, H, W)
        logits_amodal_content = self.up4_amodal_content(h1234)  # (b, 3, H, W)
        return logits_amodal_mask, logits_amodal_content

    def encode_decode(self, x):
        """Run the whole U-Net on an image batch ``x`` of shape (b, c, h, w).

        Returns (logits_amodal_mask, logits_amodal_content).
        """
        # The unpacking doubles as a check that x is 4-D (no time dimension).
        _b, _c, _h, _w = x.shape

        # Multiscale features x1..x4. The bottleneck x4 is the natural place
        # to add extra layers that enrich the latent.
        x1, x2, x3, x4 = self.encode(x)
        return self.decode(x1, x2, x3, x4)

    def forward(self, batch):
        """Compute predictions, losses and metrics for one batch.

        Args:
            batch: dict with 'frames' (bs, 3, n_frames, h, w),
                'modal_masks' (bs, 1, n_frames, h, w), and the targets
                'amodal_masks' and 'amodal_content'. This image model expects
                n_frames == 1; dim 2 (time) is squeezed out of inputs and targets.

        Returns:
            (loss, metrics): the scalar training loss (mask + content) and a
            dict of plain Python numbers for logging.
        """
        # Concatenate RGB and modal mask on the channel dim, then drop time.
        model_input = torch.cat((batch['frames'], batch['modal_masks']), dim=1)
        model_input = model_input.squeeze(2)

        # Targets, presumably in [0, 1]: the mask is compared with logits via
        # BCEWithLogitsLoss, the content with an L1 loss.
        amodal_mask_labels = batch['amodal_masks'].float().squeeze(2)
        amodal_content_labels = batch['amodal_content'].float().squeeze(2)

        logits_amodal_mask, logits_amodal_content = self.encode_decode(model_input)

        if self.mask_content_preds:
            # Zero the content wherever the predicted mask rounds to 0.
            logits_amodal_content = logits_amodal_mask.sigmoid().round() * logits_amodal_content

        mask_loss, content_loss = self.loss_function(logits_amodal_mask,
                                                     amodal_mask_labels,
                                                     logits_amodal_content,
                                                     amodal_content_labels)
        loss = mask_loss + content_loss

        metrics = {
            'loss': loss.item(),
            'mask_loss': mask_loss.item(),
            'content_loss': content_loss.item(),
            'other_metrics_to_monitor': 0,  # add more metrics here - just make sure they are a number
        }
        # Merge extra evaluation metrics; calculate_metrics may return None.
        extra_metrics = calculate_metrics(logits_amodal_mask.detach(),
                                          amodal_mask_labels,
                                          logits_amodal_content.detach(),
                                          amodal_content_labels)
        if extra_metrics:
            metrics.update(extra_metrics)
        return loss, metrics

    def loss_function(self,
                      amodal_mask_preds,
                      amodal_mask_labels,
                      amodal_content_preds,
                      amodal_content_labels):
        """Return (mask_loss, content_loss).

        mask_loss: BCE-with-logits between raw mask logits and binary targets.
        content_loss: per-pixel L1 between predicted and target RGB content.
        """
        mask_loss = self.bce_loss(amodal_mask_preds, amodal_mask_labels)
        content_loss = self.mse_loss(amodal_content_preds, amodal_content_labels)
        return mask_loss, content_loss

#### Example Usage

Test model usage

In [5]:
# Smoke test: a U-Net for 4-channel inputs (RGB + modal mask) shaped
# (batch_size, 4, height, width).
model = Unet_Image(in_channels=4)

# A random dummy batch is enough to check that the layers compose.
batch_size = 1
rand_input = torch.randn(batch_size, 4, 256, 256)

# Call encode_decode on the raw tensor: forward() needs a full batch dict
# with labels, which is not available yet.
logits_amodal_mask, logits_amodal_content = model.encode_decode(rand_input)
print('Model output:', logits_amodal_mask.shape, logits_amodal_content.shape)

# Expected: amodal mask logits of shape (batch_size, 1, h, w) and amodal
# content of shape (batch_size, 3, h, w), an RGB prediction of the object.

Model output: torch.Size([1, 1, 256, 256]) torch.Size([1, 3, 256, 256])


#### Class - Unet_Image + LSTM

In [5]:
class Unet_Image_LSTM(Unet_Image):
    """Video U-Net: per-frame ``Unet_Image`` encoder/decoder plus a temporal LSTM.

    Each frame is encoded independently. At the bottleneck, an LSTM runs
    along the time axis separately at every spatial location: one sequence
    per (sample, y, x), with the channel vector as features and weights shared
    across locations. Its hidden states replace the bottleneck features, and
    every frame is then decoded independently.

    Why per location: feeding the whole flattened (C, h, w) bottleneck to one
    LSTM needs C*h*w = 256*16*16 = 65,536 features for 256x256 frames (an
    LSTM of roughly 34 billion parameters when hidden == input). It would also
    accept only the single resolution it was sized for. The per-location LSTM
    has about 0.5M parameters and works for any H, W divisible by 16.

    Args:
        in_channels: channels per frame (RGB + modal mask = 4).
        mask_content_preds: see ``Unet_Image``.
        input_size: reference (H, W) used only to record the bottleneck size
            in ``bottleneck_h`` / ``bottleneck_w``.
    """

    def __init__(self, in_channels=4, mask_content_preds=False, input_size=(128, 128)):
        super().__init__(in_channels, mask_content_preds)
        # Infer the bottleneck shape with a dummy pass. Switching to eval()
        # stops this dummy batch from updating BatchNorm running statistics
        # (torch.no_grad alone does not); the previous mode is then restored.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            _, _, _, x4 = self.encode(torch.zeros(1, in_channels, *input_size))
        self.train(was_training)

        _, bottleneck_channels, h_b, w_b = x4.shape
        self.bottleneck_channels = bottleneck_channels
        self.bottleneck_h = h_b
        self.bottleneck_w = w_b

        # One LSTM shared by all bottleneck locations; its features are the
        # channels. hidden_size must equal the channel count because the
        # decoder's first stage (up1) expects that many input channels.
        self.lstm_input_size = bottleneck_channels
        self.lstm_hidden_size = bottleneck_channels
        self.lstm = nn.LSTM(
            input_size=self.lstm_input_size,
            hidden_size=self.lstm_hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=False
        )

    def encode_frames(self, x, n=None):
        """Encode each frame of a (bs, c, n_frames, h, w) clip independently.

        Args:
            x: clip tensor; frames are taken along dim 2.
            n: number of leading frames to encode; defaults to all of them.

        Returns:
            (x1_seq, x2_seq, x3_seq, x4_seq): the ``encode`` outputs stacked on
            a new time dim 2, e.g. x4_seq is (bs, 256, n, h/16, w/16).
        """
        if n is None:
            n = x.shape[2]
        if n < 1:
            raise ValueError(f'encode_frames needs at least one frame, got n={n}')
        # Frames are encoded one at a time (not folded into the batch), so
        # BatchNorm statistics are computed per frame.
        per_frame = [self.encode(x[:, :, t]) for t in range(n)]
        # zip(*...) regroups [(x1, x2, x3, x4), ...] into per-level sequences.
        return tuple(torch.stack(level, dim=2) for level in zip(*per_frame))

    def run_lstm(self, h4_seq):
        """Run the temporal LSTM independently at every bottleneck location.

        Args:
            h4_seq: (bs, ch, n, h_b, w_b) bottleneck features from ``encode_frames``.

        Returns:
            Tensor of the same shape holding the LSTM hidden states.

        Raises:
            ValueError: if ``ch`` differs from the LSTM input size.
        """
        bs, ch, n, h_b, w_b = h4_seq.shape
        if ch != self.lstm_input_size:
            raise ValueError(f'run_lstm expected {self.lstm_input_size} bottleneck channels, got {ch}')
        # (bs, ch, n, h, w) -> (bs, h, w, n, ch) -> (bs*h*w, n, ch): one sequence per location.
        sequences = h4_seq.permute(0, 3, 4, 2, 1).reshape(bs * h_b * w_b, n, ch)
        lstm_out, _ = self.lstm(sequences)  # (bs*h*w, n, hidden)
        # Undo the reshape/permute: (bs, h, w, n, hidden) -> (bs, hidden, n, h, w).
        return lstm_out.reshape(bs, h_b, w_b, n, self.lstm_hidden_size).permute(0, 4, 3, 1, 2)

    def decode_frames(self, h1_seq, h2_seq, h3_seq, h4_seq, n=None):
        """Decode each time step independently with the ``Unet_Image`` decoder.

        Args:
            h1_seq..h4_seq: per-level feature sequences with time on dim 2.
            n: number of leading frames to decode; defaults to all of them.

        Returns:
            (mask logits (bs, 1, n, h, w), content (bs, 3, n, h, w)).
        """
        if n is None:
            n = h4_seq.shape[2]
        if n < 1:
            raise ValueError(f'decode_frames needs at least one frame, got n={n}')
        per_frame = [
            self.decode(h1_seq[:, :, t], h2_seq[:, :, t], h3_seq[:, :, t], h4_seq[:, :, t])
            for t in range(n)
        ]
        mask_frames, content_frames = zip(*per_frame)
        return torch.stack(mask_frames, dim=2), torch.stack(content_frames, dim=2)

    def forward(self, batch):
        """Compute predictions, losses and metrics for a batch of clips.

        Args:
            batch: dict with 'frames' (bs, 3, n_frames, h, w) and
                'modal_masks' (bs, 1, n_frames, h, w), plus targets
                'amodal_masks' and 'amodal_content'. The targets are assumed
                to keep the time dim, i.e. (bs, 1, n_frames, h, w) and
                (bs, 3, n_frames, h, w), matching the predictions.

        Returns:
            (loss, metrics): scalar mask + content loss and a dict of plain
            Python numbers for logging.
        """
        frames = batch['frames']
        modal_masks = batch['modal_masks']
        n = frames.shape[2]

        # Concatenate RGB and modal mask on the channel dim: (bs, c+1, n, h, w).
        model_input = torch.cat((frames, modal_masks), dim=1)

        # Targets keep their time dimension to match the (bs, C, n, h, w)
        # predictions. Squeezing dim 2 would drop it for single-frame clips
        # and make BCEWithLogitsLoss reject the size mismatch.
        amodal_mask_labels = batch['amodal_masks'].float()
        amodal_content_labels = batch['amodal_content'].float()

        x1_seq, x2_seq, x3_seq, x4_seq = self.encode_frames(x=model_input, n=n)
        # Temporal mixing happens only at the bottleneck; skip features stay per frame.
        lstm_out_spatial = self.run_lstm(x4_seq)
        logits_amodal_mask, logits_amodal_content = self.decode_frames(
            x1_seq, x2_seq, x3_seq, lstm_out_spatial, n=n)

        if self.mask_content_preds:
            # Zero the content wherever the predicted mask rounds to 0.
            logits_amodal_content = logits_amodal_mask.sigmoid().round() * logits_amodal_content

        mask_loss, content_loss = self.loss_function(logits_amodal_mask,
                                                     amodal_mask_labels,
                                                     logits_amodal_content,
                                                     amodal_content_labels)
        loss = mask_loss + content_loss

        metrics = {
            'loss': loss.item(),
            'mask_loss': mask_loss.item(),
            'content_loss': content_loss.item(),
            'other_metrics_to_monitor': 0,  # add more metrics here - just make sure they are a number
        }
        # Merge extra evaluation metrics; calculate_metrics may return None.
        extra_metrics = calculate_metrics(logits_amodal_mask.detach(),
                                          amodal_mask_labels,
                                          logits_amodal_content.detach(),
                                          amodal_content_labels)
        if extra_metrics:
            metrics.update(extra_metrics)
        return loss, metrics


#### Pass dummy inputs through the new model

In [6]:
# Build the video model. It consumes clips shaped (batch_size, 4, n_frames, height, width):
# RGB frames (3 channels) concatenated with modal masks (1 channel) on dim 1.
model = Unet_Image_LSTM(in_channels=4, mask_content_preds=True)

# Random dummy clip (8 frames of 256x256) used to exercise the layers step by step below.
batch_size = 1
rand_input = torch.randn(batch_size, 4, 8, 256, 256)
bs, c, n, h, w = rand_input.shape

# Expected outputs, with per-frame predictions stacked on dim 2:
# amodal mask logits (batch_size, 1, n_frames, h, w) and
# amodal content (batch_size, 3, n_frames, h, w).

In [7]:
# Step 1: run only the per-frame encoder, without recording an autograd graph.
with torch.set_grad_enabled(False):
    x1_seq, x2_seq, x3_seq, x4_seq = model.encode_frames(x=rand_input, n=n)


In [None]:
# Display the LSTM submodule created in Unet_Image_LSTM.__init__ for inspection.
model.lstm

In [8]:
# Bottleneck sequence shape: (bs, channels, n_frames, h/16, w/16).
x4_seq.size()

torch.Size([1, 256, 8, 16, 16])

In [9]:
# Step 2: temporal LSTM over the bottleneck features, without recording an autograd graph.
with torch.set_grad_enabled(False):
    lstm_out_spatial = model.run_lstm(x4_seq)

: 