### Import exploratory tools

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from ml_utils.misc import set_seed, print_layers, get_device
from ml_utils.data import Pad, Resize, generate_sample_data

In [2]:
import timm

class ConvNextV2BackBone(nn.Module):
    """ConvNeXt-V2 (atto) feature extractor adapted to single-channel input.

    Builds ``convnextv2_atto.fcmae`` with timm, swaps the first stem convolution
    for a 1-input-channel one, and keeps only ``stem`` and ``stages``.

    Args:
        pretrained: forwarded to ``timm.create_model``. The new stem conv is
            initialised from the original stem conv: its weights are averaged
            over the input-channel axis and its bias is copied unchanged.
    """

    def __init__(self, pretrained = True):
        super().__init__()

        model = timm.create_model('convnextv2_atto.fcmae', pretrained=pretrained)

        in_channels = 1
        rgb_conv = model.stem[0]

        # Same geometry as the original patchify conv, but with one input channel.
        gray_conv = nn.Conv2d(in_channels,
                              rgb_conv.out_channels,
                              kernel_size = rgb_conv.kernel_size,
                              stride      = rgb_conv.stride,
                              padding     = rgb_conv.padding,
                              bias        = rgb_conv.bias is not None)

        with torch.no_grad():
            # Average over the input-channel axis (dim 1 of [out, in, kH, kW]),
            # e.g. [40, 3, 4, 4] -> [40, 1, 4, 4].
            gray_conv.weight.copy_(rgb_conv.weight.mean(dim = 1, keepdim = True))
            # Copy the learned bias too; a fresh nn.Conv2d would otherwise keep a
            # randomly initialised bias.
            if rgb_conv.bias is not None:
                gray_conv.bias.copy_(rgb_conv.bias)

        model.stem[0] = gray_conv

        self.stem   = model.stem
        self.stages = model.stages

    def forward(self, x):
        """Run the stem and then every stage in order.

        Returns:
            A tuple ``(embedding_output, stage_feature_maps)``: the stem output and
            a list holding the output of each stage, in stage order.
        """
        embedding_output = self.stem(x)

        stage_feature_maps = []
        hidden_states = embedding_output
        for stage in self.stages:
            hidden_states = stage(hidden_states)
            stage_feature_maps.append(hidden_states)

        return embedding_output, stage_feature_maps

In [3]:
# Build the single-channel backbone with pretrained weights, place it on the
# selected device and switch it to inference mode (the assignment keeps the
# cell from echoing the module repr).
device = get_device()
model = ConvNextV2BackBone(pretrained = True).to(device).eval()

In [4]:
from bifpn.bifpn        import BiFPN
from bifpn.bifpn_config import BiFPNConfig
from bifpn.utils_build  import BackboneToBiFPNAdapter, BackboneToBiFPNAdapterConfig

In [5]:
# NOTE(review): the return value of generate_sample_data() is discarded; it is
# presumably called for a side effect — confirm it is still needed.
generate_sample_data()

# Seed the global NumPy and torch RNGs so the random batch (and any module
# initialised later in the notebook) is reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

B, C, H, W = 10, 1, 1920, 1920
batch_input = np.random.rand(B, C, H, W)

H_unify, W_unify = 1024, 1024
resizer = Resize(H_unify, W_unify)

# Resize is fed a (B*C, H, W) stack; its result is folded back to (B, C, H', W').
batch_input_unify = resizer(batch_input.reshape(B*C, H, W)).reshape(B, C, H_unify, W_unify)
batch_input_unify_tensor = torch.from_numpy(batch_input_unify).to(torch.float)



In [5]:
# This cell used to repeat the data-generation cell verbatim (same In [5]),
# silently drawing a second random batch. Check the batch instead of redrawing it.
assert batch_input_unify_tensor.shape == (B, C, H_unify, W_unify)
assert batch_input_unify_tensor.dtype == torch.float32

In [6]:
# Forward-only pass through the backbone, so no autograd bookkeeping is needed.
with torch.no_grad():
    # `device` comes from get_device() and may be CPU or GPU, hence a neutral
    # name rather than "*_cuda".
    backbone_input = batch_input_unify_tensor.to(device)
    embedding_output, stage_feature_maps = model(backbone_input)

In [7]:
# One line per backbone stage output.
for fmap in stage_feature_maps:
    print(fmap.shape)

torch.Size([10, 40, 256, 256])
torch.Size([10, 80, 128, 128])
torch.Size([10, 160, 64, 64])
torch.Size([10, 320, 32, 32])


In [8]:
num_bifpn_features = 256

# Channel width of each backbone stage, read off the feature maps computed above
# instead of being hard-coded (40/80/160/320 for convnextv2_atto).
backbone_output_channels = {
    f"layer{stage_idx}": fmap.shape[1]
    for stage_idx, fmap in enumerate(stage_feature_maps, start = 1)
}

config = BackboneToBiFPNAdapterConfig(num_bifpn_features = num_bifpn_features, backbone_output_channels = backbone_output_channels)
# Put the adapter on the same device and in the same mode as the backbone whose
# outputs it consumes; otherwise a GPU device causes a device mismatch.
backbone_to_bifpn = BackboneToBiFPNAdapter(config).to(device).eval()

In [9]:
# Forward-only exploration: skip autograd bookkeeping for these large feature maps.
with torch.no_grad():
    bifpn_input_list = backbone_to_bifpn(stage_feature_maps)

In [10]:
# One line per BiFPN input level.
for bifpn_input in bifpn_input_list:
    print(bifpn_input.shape)

torch.Size([10, 256, 256, 256])
torch.Size([10, 256, 128, 128])
torch.Size([10, 256, 64, 64])
torch.Size([10, 256, 32, 32])


In [11]:
# NOTE(review): this rebinds the name of the imported `BiFPNConfig` class to the
# default config *instance*. Later cells (its display, `BiFPN(config = ...)`,
# `.NUM_FEATURES`) rely on that, so a rename such as `bifpn_config` must update
# all of them together.
default_bifpn_config = BiFPN.get_default_config()
BiFPNConfig = default_bifpn_config

In [12]:
# The BiFPN must consume exactly what the adapter produces: every input level
# carries NUM_FEATURES channels (the seg head also reads NUM_FEATURES as its
# input width).
assert BiFPNConfig.NUM_FEATURES == num_bifpn_features
assert all(t.shape[1] == BiFPNConfig.NUM_FEATURES for t in bifpn_input_list)
# Assumes NUM_LEVELS counts the input feature levels — TODO confirm in bifpn_config.
assert BiFPNConfig.NUM_LEVELS == len(bifpn_input_list)
BiFPNConfig

BiFPNConfig(RELU_INPLACE=False, DOWN_SCALE_FACTOR=0.5, UP_SCALE_FACTOR=2, NUM_BLOCKS=1, NUM_FEATURES=256, NUM_LEVELS=4, BASE_LEVEL=2, BN=BNConfig(EPS=1e-05, MOMENTUM=0.1), FUSION=FusionConfig(EPS=1e-05))

In [13]:
# Apply the BiFPN layer...
# `BiFPNConfig` here is the default-config instance bound in the previous cells.
# eval() keeps BatchNorm from updating its running statistics during this
# forward-only exploration, matching the backbone, which is also in eval mode.
bifpn = BiFPN(config = BiFPNConfig).to(device).eval()
with torch.no_grad():
    bifpn_output_list = bifpn(bifpn_input_list)

In [14]:
# One line per BiFPN output level (finest first).
for bifpn_output in bifpn_output_list:
    print(bifpn_output.shape)

torch.Size([10, 256, 256, 256])
torch.Size([10, 256, 128, 128])
torch.Size([10, 256, 64, 64])
torch.Size([10, 256, 32, 32])


In [21]:
from dataclasses import dataclass, field, asdict, is_dataclass
from typing import List

class SegLateralLayer(nn.Module):
    """Lateral branch of the segmentation head for a single feature level.

    Stacks ``max(num_layers, 1)`` blocks of 3x3 conv (padding 1) -> GroupNorm ->
    ReLU. If ``num_layers > 0``, every block is followed by a bilinear upsampling
    by ``base_scale_factor``. Otherwise a single block is applied and the spatial
    size is left unchanged.
    """

    def __init__(self, in_channels, out_channels, num_groups, num_layers, base_scale_factor = 2):
        super().__init__()

        # Only levels coarser than the target resolution need upsampling.
        self.enables_upsample = num_layers > 0

        # The finest level still gets one conv block (without upsampling), so
        # every branch applies the same kind of projection to out_channels.
        block_count = max(num_layers, 1)
        block_inputs = [in_channels] + [out_channels] * (block_count - 1)

        blocks = []
        for block_in in block_inputs:
            blocks.append(nn.Sequential(
                nn.Conv2d(in_channels  = block_in,
                          out_channels = out_channels,
                          kernel_size  = 3,
                          padding      = 1),
                nn.GroupNorm(num_groups, out_channels),
                nn.ReLU(),
            ))
        self.layers = nn.ModuleList(blocks)

        self.base_scale_factor = base_scale_factor

    def forward(self, x):
        """Apply every conv block, upsampling after each one when enabled."""
        out = x
        for block in self.layers:
            out = block(out)
            if not self.enables_upsample:
                continue
            out = F.interpolate(out,
                                scale_factor  = self.base_scale_factor,
                                mode          = 'bilinear',
                                align_corners = False)
        return out


@dataclass
class SegHeadConfig:
    """Hyper-parameters of the segmentation head.

    UP_SCALE_FACTOR holds one entry per backbone/BiFPN level, finest first
    (layer1 .. layer4): the spatial stride of that level relative to the network
    input. Every level is brought to the resolution of the first entry.
    """
    # Per-level stride w.r.t. the input: layer1, layer2, layer3, layer4.
    UP_SCALE_FACTOR: List[int] = field(default_factory=lambda: [4, 8, 16, 32])
    # Groups for the GroupNorm in each lateral conv block.
    NUM_GROUPS           : int  = 32
    # Channel width of every lateral branch and of the fused feature map.
    OUT_CHANNELS         : int  = 256
    # Number of output channels of the prediction conv.
    NUM_CLASSES          : int  = 3
    # Upsampling factor applied after each lateral conv block.
    BASE_SCALE_FACTOR    : int  = 2
    # True -> ConvTranspose2d final upsampling, False -> bilinear interpolation.
    USES_LEARNED_UPSAMPLE: bool = False

In [22]:
from math import log

seghead_config = SegHeadConfig()

# Create the prediction head...
base_scale_factor = seghead_config.BASE_SCALE_FACTOR
# Stride of the finest level (first UP_SCALE_FACTOR entry). Every lateral branch
# is upsampled to this resolution before fusion.
finest_scale_factor = seghead_config.UP_SCALE_FACTOR[0]


def _num_upsample_steps(scale_factor, target_scale_factor, step):
    """Return the integer n >= 0 with ``target_scale_factor * step ** n == scale_factor``.

    ``step`` is the per-block upsampling factor of SegLateralLayer, so it is used as
    the log base. Raises ValueError instead of silently truncating a non-integer
    logarithm with int().
    """
    if step <= 1:
        raise ValueError(f"BASE_SCALE_FACTOR must be > 1, got {step}")
    n = round(log(scale_factor / target_scale_factor) / log(step))
    if n < 0 or target_scale_factor * step ** n != scale_factor:
        raise ValueError(
            f"UP_SCALE_FACTOR entry {scale_factor} is not {target_scale_factor} * {step}**n "
            f"for any integer n >= 0"
        )
    return n


num_upscale_layer_list = [
    _num_upsample_steps(scale_factor, finest_scale_factor, base_scale_factor)
    for scale_factor in seghead_config.UP_SCALE_FACTOR
]
lateral_layer_in_channels = BiFPNConfig.NUM_FEATURES

# One lateral branch per level, in UP_SCALE_FACTOR order (finest first). This is
# also the order of the BiFPN outputs, so the fusion step must reverse both together.
seg_lateral_layers = nn.ModuleList([
    SegLateralLayer(in_channels       = lateral_layer_in_channels,
                    out_channels      = seghead_config.OUT_CHANNELS,
                    num_groups        = seghead_config.NUM_GROUPS,
                    num_layers        = num_upscale_layers,
                    base_scale_factor = base_scale_factor)
    for num_upscale_layers in num_upscale_layer_list
]).to(device)

# 1x1 conv mapping the fused features to one score map per class.
head_segmask = nn.Conv2d(in_channels  = seghead_config.OUT_CHANNELS,
                         out_channels = seghead_config.NUM_CLASSES,
                         kernel_size  = 1,
                         padding      = 0,).to(device)

if seghead_config.USES_LEARNED_UPSAMPLE:
    # A transposed conv with stride s, kernel s + 2 and padding 1 upsamples by
    # exactly s: (in - 1) * s - 2 + (s + 2) = s * in. For s = 4 this is kernel 6.
    head_upsample_layer = nn.ConvTranspose2d(in_channels  = seghead_config.NUM_CLASSES,
                                             out_channels = seghead_config.NUM_CLASSES,
                                             kernel_size  = finest_scale_factor + 2,
                                             stride       = finest_scale_factor,
                                             padding      = 1,).to(device)

In [23]:
# Fuse feature maps at each resolution (from low res to high res)...
# Both sequences are finest-first, so reversing them together keeps every lateral
# branch paired with its own BiFPN level. zip() would silently drop extra levels.
if not bifpn_output_list or len(seg_lateral_layers) != len(bifpn_output_list):
    raise ValueError(
        f"expected one lateral layer per BiFPN level, got {len(seg_lateral_layers)} "
        f"layers for {len(bifpn_output_list)} levels"
    )

fmap_acc = None
with torch.no_grad():
    for lateral_layer, bifpn_output in zip(seg_lateral_layers[::-1], bifpn_output_list[::-1]):
        fmap = lateral_layer(bifpn_output)
        # Out-of-place add so no branch output is modified in place.
        fmap_acc = fmap if fmap_acc is None else fmap_acc + fmap

In [25]:
# Every lateral branch should have been brought to the finest BiFPN resolution.
assert fmap_acc.shape[1] == seghead_config.OUT_CHANNELS
assert fmap_acc.shape[-2:] == bifpn_output_list[0].shape[-2:]
fmap_acc.shape

torch.Size([10, 256, 256, 256])

In [27]:
# Make prediction at the finest feature resolution, then upscale it by that
# level's stride (the first UP_SCALE_FACTOR entry) to the input resolution.
finest_scale_factor = seghead_config.UP_SCALE_FACTOR[0]

with torch.no_grad():
    pred_map = head_segmask(fmap_acc)

    # Upscale...
    if seghead_config.USES_LEARNED_UPSAMPLE:
        pred_map = head_upsample_layer(pred_map)
    else:
        pred_map = F.interpolate(pred_map,
                                 scale_factor  = finest_scale_factor,
                                 mode          = 'bilinear',
                                 align_corners = False)

In [29]:
# One score map per class at the (resized) input resolution.
assert pred_map.shape == (B, seghead_config.NUM_CLASSES, H_unify, W_unify)
pred_map.shape

torch.Size([10, 3, 1024, 1024])