In [2]:
import torch
import torch.nn as nn
import torchvision.models as models

def inflate_conv2d_to_conv3d(conv2d_layer, depth=3):
    """Build an ``nn.Conv3d`` whose kernel is an inflated copy of a 2D kernel.

    The 2D weight is repeated ``depth`` times along a new kernel axis
    (inserted before height and width) and divided by ``depth``. The slices
    therefore sum back to the original 2D kernel, so an input that is
    constant along the new axis produces the 2D layer's response wherever
    the kernel window does not overlap the padding of that axis.

    Stride and dilation along the new axis are 1. Numeric padding along the
    new axis is ``depth // 2``. ``groups``, ``dilation``, ``padding_mode``,
    the bias, and the weight's device/dtype are carried over from the 2D
    layer.

    Args:
        conv2d_layer: The ``nn.Conv2d`` to inflate. It is not modified.
        depth: Kernel size along the new axis (must be >= 1).

    Returns:
        A new ``nn.Conv3d`` initialised from ``conv2d_layer``.

    Raises:
        TypeError: If ``conv2d_layer`` is not an ``nn.Conv2d``.
        ValueError: If ``depth`` is smaller than 1.
    """
    if not isinstance(conv2d_layer, nn.Conv2d):
        raise TypeError(
            f"expected nn.Conv2d, got {type(conv2d_layer).__name__}"
        )
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")

    weight2d = conv2d_layer.weight
    kernel_h, kernel_w = conv2d_layer.kernel_size

    padding = conv2d_layer.padding
    if isinstance(padding, str):
        # String paddings cannot be star-unpacked. 'same' is valid for
        # Conv3d as-is; 'valid' means no spatial padding, while the new
        # axis keeps the depth // 2 padding used for numeric paddings.
        if padding == "valid":
            padding = (depth // 2, 0, 0)
    else:
        padding = (depth // 2, *padding)

    # Use the layer's own channel counts: weight.shape[1] is
    # in_channels // groups, which is wrong as in_channels when groups > 1.
    conv3d_layer = nn.Conv3d(
        conv2d_layer.in_channels,
        conv2d_layer.out_channels,
        (depth, kernel_h, kernel_w),
        stride=(1, *conv2d_layer.stride),
        padding=padding,
        dilation=(1, *conv2d_layer.dilation),
        groups=conv2d_layer.groups,
        bias=conv2d_layer.bias is not None,
        padding_mode=conv2d_layer.padding_mode,
    )
    # Keep the new layer on the same device and in the same dtype as the
    # layer it replaces.
    conv3d_layer = conv3d_layer.to(device=weight2d.device, dtype=weight2d.dtype)

    with torch.no_grad():
        # (out, in, h, w) -> (out, in, depth, h, w); dividing by depth keeps
        # the sum over the new axis equal to the original kernel.
        weight3d = weight2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth
        conv3d_layer.weight.copy_(weight3d)
        if conv2d_layer.bias is not None:
            conv3d_layer.bias.copy_(conv2d_layer.bias)
    return conv3d_layer

def inflate_densenet2d_to_3d(model_2d, depth=3):
    """Recursively replace every ``nn.Conv2d`` in a module tree with a 3D one.

    The module tree is walked depth-first. Each ``nn.Conv2d`` child is
    swapped for the ``nn.Conv3d`` produced by ``inflate_conv2d_to_conv3d``.
    The replacement happens in place: ``model_2d`` itself is mutated and
    returned, no copy is made.

    Only ``nn.Conv2d`` modules are replaced; every other module type is
    left exactly as it was.
    NOTE(review): other 2D-only layers (e.g. batch-norm or pooling modules)
    are therefore untouched -- confirm callers convert or bypass them.

    Args:
        model_2d: Root ``nn.Module`` to convert.
        depth: Kernel size along the new axis, forwarded to
            ``inflate_conv2d_to_conv3d``.

    Returns:
        ``model_2d`` after in-place conversion. If ``model_2d`` is itself an
        ``nn.Conv2d``, the newly created ``nn.Conv3d`` is returned instead.
    """
    # A bare convolution has no children to walk, so convert it directly
    # rather than handing it back unchanged.
    if isinstance(model_2d, nn.Conv2d):
        return inflate_conv2d_to_conv3d(model_2d, depth=depth)

    for name, module in model_2d.named_children():
        if isinstance(module, nn.Conv2d):
            setattr(model_2d, name, inflate_conv2d_to_conv3d(module, depth=depth))
        else:
            inflate_densenet2d_to_3d(module, depth=depth)
    return model_2d

if __name__ == "__main__":
    # Demo: inflate a DenseNet-121 loaded with pretrained weights.
    # NOTE(review): newer torchvision releases prefer a ``weights=`` argument
    # over ``pretrained=`` -- confirm against the installed version.
    model_2d = models.densenet121(pretrained=True)
    # inflate_densenet2d_to_3d mutates its argument and returns it, so
    # model_3d is the same object as model_2d (no 2D copy is kept).
    model_3d = inflate_densenet2d_to_3d(model_2d, depth=3)
    print(model_3d.features.conv0)  # Spot check: show the first inflated convolution



Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
