In [None]:
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from torch.nn import functional as F
# Report the library versions so any output below can be tied to a specific environment.
print("\n".join(
    "{} {}".format(label, version)
    for label, version in (("PyTorch Version: ", torch.__version__),
                           ("Torchvision Version: ", torchvision.__version__))
))

In [None]:
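# Uncomment to install the dependency. `%pip` (unlike `!pip`) installs into the running
# kernel's environment; the version is pinned so the notebook stays reproducible.
# %pip install efficientnet_pytorch==0.7.1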

In [None]:
from efficientnet_pytorch import EfficientNet

In [None]:
# Load the two pretrained backbones (AdvProp checkpoints) that the ensemble combines:
# modelA <- efficientnet-b6, modelB <- efficientnet-b7 (loaded in that order).
# NOTE(review): the efficientnet_pytorch README says AdvProp weights expect inputs
# scaled to [-1, 1] (img * 2 - 1) instead of ImageNet mean/std normalisation —
# confirm the preprocessing used with these models.
modelA, modelB = (
    EfficientNet.from_pretrained(arch, advprop=True)
    for arch in ("efficientnet-b6", "efficientnet-b7")
)


In [None]:
class MyEnsemble(nn.Module):
    """Feature-level ensemble of two backbones sharing one linear classifier.

    Each backbone's ``extract_features`` output is globally average-pooled and
    flattened. The two feature vectors are concatenated, passed through a ReLU,
    and mapped to class logits by a single ``nn.Linear``.

    Args:
        modelA: first backbone; must expose ``extract_features(x)`` returning a
            (batch, channels, H, W) feature map.
        modelB: second backbone with the same requirement.
        nb_classes: number of output logits produced by the classifier.
    """

    # Feature-map channel counts of efficientnet-b6 / efficientnet-b7. Used only
    # when the width cannot be read from the backbone itself (see _feature_dim).
    _DEFAULT_DIM_A = 2304
    _DEFAULT_DIM_B = 2560

    def __init__(self, modelA, modelB, nb_classes=10):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB

        # EfficientNet has no ``fc`` attribute to swap for nn.Identity, so instead
        # forward() calls extract_features(), which stops before the backbone's own
        # pooling and final linear layer.

        # Width of the new head = sum of both backbones' feature channels. The widths
        # are read from the backbones, so other EfficientNet variants also work.
        in_features = (self._feature_dim(modelA, self._DEFAULT_DIM_A)
                       + self._feature_dim(modelB, self._DEFAULT_DIM_B))
        self.classifier = nn.Linear(in_features, nb_classes)

        # extract_features() returns an un-pooled map, so pooling happens here.
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)

    @staticmethod
    def _feature_dim(model, default):
        """Return the channel count of ``model.extract_features`` output.

        Reads ``model._conv_head.out_channels`` (the last convolution of an
        efficientnet_pytorch EfficientNet). Falls back to ``default`` when the
        backbone has no such attribute.
        """
        conv_head = getattr(model, "_conv_head", None)
        return getattr(conv_head, "out_channels", default)

    def _pooled_features(self, model, x):
        """Run ``model.extract_features`` on ``x`` and pool/flatten to (batch, channels)."""
        features = self._avg_pooling(model.extract_features(x))
        return torch.flatten(features, 1)

    def forward(self, x):
        """Return class logits of shape (batch, nb_classes) for the input batch ``x``."""
        # Clone so that any in-place op inside modelA cannot alter the tensor modelB receives.
        x1 = self._pooled_features(self.modelA, x.clone())
        x2 = self._pooled_features(self.modelB, x)
        x = torch.cat((x1, x2), dim=1)
        # NOTE(review): the ReLU zeroes any negative pooled activations before the head;
        # it is kept from the original design — confirm it is intended.
        return self.classifier(F.relu(x))

In [None]:
# Combine both backbones under a fresh 10-way linear head.
my_model = MyEnsemble(modelA, modelB, nb_classes=10)

In [None]:
# Sample data: one random 3-channel 224x224 image batch.
# A locally seeded generator makes the tensor (and so every output below) identical
# on each Restart & Run All, without touching the global RNG state.
inputs = torch.randn(1, 3, 224, 224, generator=torch.Generator().manual_seed(0))

In [None]:
# Forward pass on the sample batch.
# eval() switches every submodule to inference behaviour (e.g. BatchNorm stops
# updating its running statistics, dropout is disabled). This demo call on random
# data therefore cannot alter the pretrained backbones' state. no_grad() skips
# building an autograd graph nothing uses. Call my_model.train() before any fine-tuning.
my_model.eval()
with torch.no_grad():
    outputs = my_model(inputs)

In [None]:
outputs.size()  # expected torch.Size([1, 10]): batch of 1 sample x nb_classes=10 logits