In [2]:
import sys
sys.path.append("/research/m324371/Project/adnexal/networks/")

import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class FeatureEnsemble2models(nn.Module):
    """Feature-level ensemble of two networks.

    Each model is flattened into an ``nn.Sequential`` built from its immediate
    children (``model.children()``) with the last ``clip`` children removed.
    The same input is passed through both truncated networks and the two
    outputs are concatenated along dimension 1.

    NOTE(review): re-chaining the children in an ``nn.Sequential`` bypasses the
    original model's own ``forward``; this is only equivalent to the original
    network if its forward is a plain chain of its children in registration
    order -- confirm for the backbones used.

    Args:
        model1: First ``nn.Module``; its output defines the target spatial size.
        model2: Second ``nn.Module``; its output is resized to match ``model1``'s
            when the trailing (spatial) dimensions differ.
        clip1: Number of trailing children to drop from ``model1``. ``0`` keeps
            every child.
        clip2: Number of trailing children to drop from ``model2``. ``0`` keeps
            every child.

    Raises:
        ValueError: If ``clip1`` or ``clip2`` is negative.
    """

    def __init__(self, model1, model2, clip1:int, clip2:int):
        super(FeatureEnsemble2models, self).__init__()

        self.clipped_model1 = self._clip_children(model1, clip1) # clip first model
        self.clipped_model2 = self._clip_children(model2, clip2) # clip second model

    @staticmethod
    def _clip_children(model, clip):
        """Return an ``nn.Sequential`` of ``model``'s children minus the last ``clip``."""
        if clip < 0:
            raise ValueError(f"clip must be a non-negative integer, got {clip}")

        layers = list(model.children())
        # An explicit end index is used instead of ``[:-clip]`` because ``[:-0]``
        # is ``[:0]`` and would silently drop *every* layer when clip == 0.
        # Clipping more layers than exist yields an empty Sequential.
        keep = max(len(layers) - clip, 0)
        return nn.Sequential(*layers[:keep])

    def forward(self, x):
        """Run ``x`` through both truncated models and concatenate on dim 1.

        The bilinear resize below expects 4-D feature maps; presumably the
        layout is (N, C, H, W) -- confirm against the backbones used.
        """
        features1 = self.clipped_model1(x)
        features2 = self.clipped_model2(x)

        # Ensure both feature maps have the same spatial dimensions
        if features1.shape[2:] != features2.shape[2:]:
            # Resize feature2's output to match feature1's spatial dimensions
            features2 = F.interpolate(features2, size=features1.shape[2:], mode='bilinear', align_corners=False)

        # Concatenate the feature maps along the channel dimension
        combined_features = torch.cat((features1, features2), dim=1)

        return combined_features

In [5]:
if __name__ == "__main__":
    # Smoke test: build both backbones, strip the last child of each, and
    # check the shape of the concatenated feature map for one random image.
    # The two backbone modules are resolved through the sys.path entry added
    # in the import cell.
    from res50pscse_512x28x28 import ResNet50Pscse_512x28x28
    from enetb2lpscse_384x28x28 import EfficientNetB2LPscse_384x28x28

    inp=torch.rand(1, 3, 224, 224)
    num_classes=2
    out_channels=[1024, 512, 256]
    pretrain=True
    dropout=0.3
    activation='leakyrelu'
    reduction=16
    
    model1 = ResNet50Pscse_512x28x28(num_classes, out_channels, pretrain, dropout, activation, reduction) 
    model2 = EfficientNetB2LPscse_384x28x28(num_classes, out_channels, pretrain, dropout, activation, reduction)
    
    feature_ensembled_model = FeatureEnsemble2models(model1, model2, clip1=1, clip2=1) # clip classification head

    # This is only a shape check, so run in inference mode without autograd:
    # eval() puts any dropout / normalization layers inside the backbones into
    # inference behaviour, and no_grad() avoids building a graph.
    # NOTE: the submodules are shared with model1/model2, so they are left in
    # eval mode too; call .train() again before training them.
    feature_ensembled_model.eval()
    with torch.no_grad():
        out = feature_ensembled_model(inp)
    
    print(out.shape) # torch.Size([1, 2048, 7, 7])

torch.Size([1, 2048, 7, 7])
