In [53]:
# Automatically reload edited project modules (e.g. src.bifpn) before each cell executes,
# so changes to the model code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [54]:
from src.bifpn import ResampleFeatureMap, BiFpn
import torch
import timm

In [55]:
# downsampling: reduction_ratio=2 halves the spatial size while projecting 40 -> 112 channels
inp = torch.randn(1, 40, 64, 64)
resample = ResampleFeatureMap(in_channels=40, out_channels=112, reduction_ratio=2)
out = resample(inp)
# Pin the expected result so a regression in ResampleFeatureMap fails loudly on Run All.
assert out.shape == (1, 112, 32, 32), out.shape
inp.shape, out.shape

(torch.Size([1, 40, 64, 64]), torch.Size([1, 112, 32, 32]))

In [56]:
# upsampling: reduction_ratio=0.5 doubles the spatial size while projecting 40 -> 112 channels
inp = torch.randn(1, 40, 64, 64)
resample = ResampleFeatureMap(in_channels=40, out_channels=112, reduction_ratio=0.5)
out = resample(inp)
# Pin the expected result so a regression in ResampleFeatureMap fails loudly on Run All.
assert out.shape == (1, 112, 128, 128), out.shape
inp.shape, out.shape

(torch.Size([1, 40, 64, 64]), torch.Size([1, 112, 128, 128]))

In [57]:
from typing import Callable

def get_feature_info(backbone):
    """Return the per-level ``num_chs`` / ``reduction`` dicts of a timm feature backbone.

    Handles both timm feature-info accessors:

    * old (timm <= 0.1.30): ``backbone.feature_info`` is callable and returns a list of
      dicts. Each dict is trimmed to just the ``num_chs`` and ``reduction`` keys.
    * new (timm >= 0.2): ``backbone.feature_info`` exposes ``get_dicts(keys=...)``.

    Args:
        backbone: a model created with ``timm.create_model(..., features_only=True)``.

    Returns:
        A list with one ``{'num_chs': ..., 'reduction': ...}`` dict per output level.
    """
    info = backbone.feature_info
    if callable(info):
        # Old accessor: timm versions <= 0.1.30; efficientnet, mobilenetv3 and related nets only.
        return [dict(num_chs=f['num_chs'], reduction=f['reduction']) for f in info()]
    # New accessor: timm >= 0.2, all models supported.
    return info.get_dicts(keys=['num_chs', 'reduction'])

In [58]:
# ResNet-18 trunk exposing all five feature stages (strides 2, 4, 8, 16, 32 per the output below).
backbone = timm.create_model(
    'resnet18',
    pretrained=True,
    features_only=True,
    out_indices=(0, 1, 2, 3, 4),
)

# Channel count and stride of every level, in the format BiFpn consumes.
feature_info = get_feature_info(backbone)
feature_info

[{'num_chs': 64, 'reduction': 2},
 {'num_chs': 64, 'reduction': 4},
 {'num_chs': 128, 'reduction': 8},
 {'num_chs': 256, 'reduction': 16},
 {'num_chs': 512, 'reduction': 32}]

In [59]:
# Build the BiFPN from the backbone's per-level channel counts (num_chs) and strides (reduction).
# NOTE(review): the forward pass below yields 64 channels at every level. That is presumably a
# BiFpn default, but backbone level 0 also has 64 channels. Confirm in src.bifpn.
fpn = BiFpn(feature_info)

In [60]:
# Push a random batch through the backbone purely to inspect the per-level feature shapes.
inputs = torch.randn(16, 3, 256, 256)
# In train mode, every forward pass would overwrite the pretrained BatchNorm running
# statistics with statistics of random noise. eval() prevents that. no_grad() skips
# storing activations for a backward pass that never happens.
backbone.eval()
with torch.no_grad():
    features = backbone(inputs)
for f in features:
    print(f.shape)
# Each level should match the channel count and stride reported by feature_info.
assert [f.shape[1] for f in features] == [fi['num_chs'] for fi in feature_info]
assert [f.shape[-1] for f in features] == [inputs.shape[-1] // fi['reduction'] for fi in feature_info]

torch.Size([16, 64, 128, 128])
torch.Size([16, 64, 64, 64])
torch.Size([16, 128, 32, 32])
torch.Size([16, 256, 16, 16])
torch.Size([16, 512, 8, 8])


In [61]:
# Run the BiFPN on the backbone features; gradients are not needed for a shape check.
with torch.no_grad():
    out = fpn(features)
for f in out:
    print(f.shape)
# The BiFPN should return one map per input level, each keeping that level's spatial size.
assert len(out) == len(features)
assert all(o.shape[-2:] == f.shape[-2:] for o, f in zip(out, features))

torch.Size([16, 64, 128, 128])
torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 32, 32])
torch.Size([16, 64, 16, 16])
torch.Size([16, 64, 8, 8])
