### Import exploratory tools

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from ml_utils.misc import set_seed, print_layers, get_device
from ml_utils.data import Pad, Resize, generate_sample_data

In [3]:
import timm

class ConvNextV2BackBone(nn.Module):
    """ConvNeXt V2 (atto) backbone adapted to single-channel input.

    The timm model 'convnextv2_atto.fcmae' is created and its first stem layer
    (the patch-embedding convolution) is replaced by a convolution that takes a
    single input channel. The new weights are the mean of the original weights
    over the input-channel axis; the original bias is carried over unchanged.
    Only the stem and the stages of the timm model are kept.

    Args:
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, pretrained = True):
        super().__init__()

        model = timm.create_model('convnextv2_atto.fcmae', pretrained=pretrained)

        num_channels = 1  # input channels accepted by the adapted stem
        channel_dim  = 1  # input-channel axis of a conv weight [out, in, kH, kW]

        patch_embd = model.stem[0]

        # Collapse the RGB kernels into one kernel per output channel.
        # The axis is named explicitly instead of reusing `num_channels`, which
        # only coincidentally had the same value as the axis index.
        ave_weight_patch_embd = patch_embd.weight.data.mean(dim = channel_dim, keepdim = True) # [40, 3, 4, 4] -> [40, 1, 4, 4]

        # Mirror the geometry of the layer being replaced rather than
        # hard-coding its output width, kernel size and stride.
        has_bias = patch_embd.bias is not None
        new_patch_embd = nn.Conv2d(
            num_channels,
            patch_embd.out_channels,
            kernel_size = patch_embd.kernel_size,
            stride      = patch_embd.stride,
            padding     = patch_embd.padding,
            bias        = has_bias,
        )
        new_patch_embd.weight.data = ave_weight_patch_embd
        if has_bias:
            # Keep the original bias; a freshly built Conv2d would otherwise
            # start from a random one and discard the loaded value.
            new_patch_embd.bias.data = patch_embd.bias.data.clone()
        model.stem[0] = new_patch_embd

        self.stem  = model.stem
        self.stages = model.stages

    def forward(self, x):
        """Run the stem and every stage, collecting each stage's output.

        Args:
            x: Input batch passed to the stem.

        Returns:
            A tuple ``(embedding_output, stage_feature_maps)``: the stem output
            and a list with the output of each stage, in stage order.
        """
        # Process input through embeddings
        embedding_output = self.stem(x)

        # Each stage consumes the previous stage's output, so this has to be a
        # sequential loop; every intermediate result is kept for the caller.
        stage_feature_maps = []
        hidden_states = embedding_output
        for stage in self.stages:
            hidden_states = stage(hidden_states)
            stage_feature_maps.append(hidden_states)

        return embedding_output, stage_feature_maps

In [4]:
# Ask the project helper for the compute device, build the backbone with
# pretrained weights on it, and switch it to evaluation mode.
device = get_device()
model = ConvNextV2BackBone(pretrained=True).to(device)
model = model.eval()  # eval() returns the module itself; assigning keeps the cell silent

In [8]:
from bifpn.bifpn        import BiFPN
from bifpn.bifpn_config import BiFPNConfig
from bifpn.utils_build  import BackboneToBiFPNAdapter

In [None]:
# Call the sample-data helper; its return value (if any) is not used in this cell.
generate_sample_data()


# Random batch laid out as (batch, channel, height, width), values in [0, 1).
B, C, H, W = 10, 1, 1920, 1920
batch_input = np.random.rand(B, C, H, W)

# Common spatial size every image is brought to.
H_unify, W_unify = 1024, 1024
resizer = Resize(H_unify, W_unify)

# Fold batch and channel into one leading axis for the resizer, then restore
# the 4-D layout at the new size.
# NOTE(review): presumably Resize maps (N, H, W) -> (N, H_unify, W_unify); confirm in ml_utils.data.
flat_images = batch_input.reshape(B * C, H, W)
batch_input_unify = resizer(flat_images).reshape(B, C, H_unify, W_unify)
batch_input_unify_tensor = torch.from_numpy(batch_input_unify).float()



In [None]:
# Random batch laid out as (batch, channel, height, width), values in [0, 1).
B, C, H, W = 10, 1, 1920, 1920
batch_input = np.random.rand(B, C, H, W)

# Common spatial size every image is brought to.
H_unify, W_unify = 1024, 1024
resizer = Resize(H_unify, W_unify)

# Fold batch and channel into one leading axis for the resizer, then restore
# the 4-D layout at the new size and convert to a float32 tensor.
# NOTE(review): presumably Resize maps (N, H, W) -> (N, H_unify, W_unify); confirm in ml_utils.data.
flat_images = batch_input.reshape(B * C, H, W)
batch_input_unify = resizer(flat_images).reshape(B, C, H_unify, W_unify)
batch_input_unify_tensor = torch.from_numpy(batch_input_unify).float()

In [None]:
# Move the batch to `device` (the `_cuda` suffix is only a name; the tensor
# lands on whatever device was selected), then run the backbone without
# autograd tracking.
batch_input_unify_tensor_cuda = batch_input_unify_tensor.to(device)
with torch.no_grad():
    embedding_output, stage_feature_maps = model(batch_input_unify_tensor_cuda)

In [9]:
# One line per backbone stage: the shape of its output feature map.
for feature_map in stage_feature_maps:
    print(feature_map.shape)

torch.Size([10, 40, 256, 256])
torch.Size([10, 80, 128, 128])
torch.Size([10, 160, 64, 64])
torch.Size([10, 320, 32, 32])


In [10]:
# Channel width requested from the adapter, and the channel count of each
# backbone stage output keyed by layer name (matches the shapes printed above).
num_bifpn_features = 256
stage_names = ("layer1", "layer2", "layer3", "layer4")
stage_channels = (40, 80, 160, 320)
backbone_output_channels = dict(zip(stage_names, stage_channels))

backbone_to_bifpn = BackboneToBiFPNAdapter(num_bifpn_features, backbone_output_channels)

In [11]:
# Adapt the backbone stage outputs for the BiFPN. This is a shape exploration,
# so run it without autograd tracking, like the backbone forward pass;
# otherwise a graph is recorded through these large activations.
# NOTE(review): the adapter is never moved to `device`; if it holds parameters
# and `device` is not the CPU this call may fail on a device mismatch - confirm.
with torch.no_grad():
    bifpn_input_list = backbone_to_bifpn(stage_feature_maps)

In [12]:
# One line per pyramid level: the shape of the adapted feature map.
for adapted_map in bifpn_input_list:
    print(adapted_map.shape)

torch.Size([10, 256, 256, 256])
torch.Size([10, 256, 128, 128])
torch.Size([10, 256, 64, 64])
torch.Size([10, 256, 32, 32])


In [13]:
# Fetch the default BiFPN configuration.
# NOTE(review): this rebinds `BiFPNConfig`, the name imported earlier from
# bifpn.bifpn_config; from here on it refers to the returned config object and
# the imported name is no longer reachable. A distinct name (e.g. `bifpn_config`)
# would avoid the shadowing, but the later cells that read `BiFPNConfig` would
# need the same rename.
BiFPNConfig = BiFPN.get_default_config()

In [14]:
# Echo the configuration object so the notebook displays its fields.
BiFPNConfig

BiFPNConfig(RELU_INPLACE=False, DOWN_SCALE_FACTOR=0.5, UP_SCALE_FACTOR=2, NUM_BLOCKS=1, NUM_FEATURES=256, NUM_LEVELS=4, BASE_LEVEL=2, BN=BNConfig(EPS=1e-05, MOMENTUM=0.1), FUSION=FusionConfig(EPS=1e-05))

In [15]:
# Apply the BiFPN layer...
# Build a BiFPN from the configuration fetched above and run it on the adapted
# feature maps. Only the output shapes are inspected, so the forward pass is
# done without autograd tracking, like the backbone forward pass; the largest
# input here is [10, 256, 256, 256], and recording a graph over tensors of
# that size is costly.
# NOTE(review): `bifpn` is never moved to `device`; confirm this matches the
# device of `bifpn_input_list`.
bifpn = BiFPN(config = BiFPNConfig)
with torch.no_grad():
    bifpn_output_list = bifpn(bifpn_input_list)

In [16]:
# One line per pyramid level: the shape of the BiFPN output.
for fused_map in bifpn_output_list:
    print(fused_map.shape)

torch.Size([10, 256, 256, 256])
torch.Size([10, 256, 128, 128])
torch.Size([10, 256, 64, 64])
torch.Size([10, 256, 32, 32])
