In [53]:
import torch
from torch import nn, Tensor
from torch.nn import functional as F

from torchvision.models.segmentation.deeplabv3 import IntermediateLayerGetter, DeepLabHead
from torchvision.models import resnet50, mobilenet_v3_large

In [50]:
class SimpleChangeDetectionModel(nn.Module):
    """Siamese change-detection model.

    A single backbone (shared weights) encodes both images, the two feature
    maps are concatenated along the channel axis, passed through the
    classifier, and the result is bilinearly resized to the input size.

    Args:
        backbone: module called on each image; must return a dict containing
            an ``"out"`` feature tensor (e.g. an ``IntermediateLayerGetter``).
        classifier: module applied to the concatenated features, i.e. it must
            accept ``2 * C`` channels where ``C`` is the channel count of the
            backbone's ``"out"`` tensor.
    """

    def __init__(self, backbone: nn.Module, classifier: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Classify the pair ``(x1, x2)`` and upsample to the input resolution.

        Args:
            x1: first image batch; assumed (N, C, H, W) since the output is
                bilinearly resized to ``x1.shape[-2:]`` — TODO confirm layout.
            x2: second image batch; must have exactly the same shape as ``x1``.

        Returns:
            The classifier output resized to ``x1``'s spatial size.

        Raises:
            ValueError: if ``x1`` and ``x2`` do not have the same shape.
        """
        # Differently sized inputs can still produce equally sized feature
        # maps (strided convolutions round down), which would silently
        # concatenate misaligned features — reject the mismatch explicitly.
        if x1.shape != x2.shape:
            raise ValueError(
                "x1 and x2 must have the same shape, got "
                f"{tuple(x1.shape)} and {tuple(x2.shape)}"
            )
        input_shape = x1.shape[-2:]

        # Shared weights: the same backbone encodes both images.
        # Contract: the backbone returns a dict of tensors keyed by "out".
        feat1 = self.backbone(x1)["out"]
        feat2 = self.backbone(x2)["out"]

        out = torch.cat([feat1, feat2], dim=1)
        out = self.classifier(out)
        return F.interpolate(out, size=input_shape, mode="bilinear", align_corners=False)

In [51]:
# Build a Siamese DeepLabV3-style change detector on a ResNet-50 encoder
# (randomly initialised: weights=None).
# NOTE(review): 12 bands presumably matches the imagery used later — confirm.
IN_CHANNELS = 12   # channels per input image
NUM_CLASSES = 2    # output channels of the segmentation head

backbone = resnet50(weights=None, replace_stride_with_dilation=[False, True, True])
# layer4's output channel count (2048 for ResNet-50); read it before the
# IntermediateLayerGetter below discards the `fc` head.
feature_channels = backbone.fc.in_features

# Replace the stem conv with one accepting IN_CHANNELS, keeping every other
# hyper-parameter of the original layer.
conv1 = backbone.conv1
backbone.conv1 = nn.Conv2d(
    in_channels=IN_CHANNELS,
    out_channels=conv1.out_channels,
    kernel_size=conv1.kernel_size,
    stride=conv1.stride,
    padding=conv1.padding,
    dilation=conv1.dilation,
    groups=conv1.groups,
    # `bias` is a bool flag; passing the (possibly None) bias tensor itself
    # only works because None is falsy and would fail for a real bias tensor.
    bias=conv1.bias is not None,
)

return_layers = {"layer4": "out"}
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

# Features from both images are concatenated, so the head sees twice the channels.
classifier = DeepLabHead(feature_channels * 2, NUM_CLASSES)
model = SimpleChangeDetectionModel(backbone, classifier)


In [52]:
inp = torch.zeros(2, 12, 512, 512)

# Shape smoke test. Run in eval mode without autograd so the check neither
# builds a large autograd graph nor updates BatchNorm running statistics
# with dummy data; restore the previous train/eval mode afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out_shape = model(inp, inp).shape
finally:
    model.train(was_training)

# Output keeps the batch size and the input's spatial resolution.
assert out_shape[0] == inp.shape[0] and out_shape[-2:] == inp.shape[-2:], out_shape
out_shape

torch.Size([2, 2048, 64, 64])


torch.Size([2, 2, 512, 512])