# In [7]:
import torch
import torch.nn as nn
import torchvision.models as models

class InceptionV3Model(nn.Module):
    """Inception V3 backbone with a single-logit head for binary classification.

    The pre-trained backbone parameters are frozen; only the replacement
    ``fc`` layer (created after the freeze) is left trainable.  ``forward``
    returns sigmoid probabilities in the open interval (0, 1).
    """

    def __init__(self) -> None:
        super().__init__()

        # Load the ImageNet pre-trained Inception V3 model.  Newer torchvision
        # releases deprecate ``pretrained=True`` in favour of the ``weights``
        # enum; use the enum when it exists and fall back to the legacy flag
        # otherwise so the class works on both old and new versions.
        weights_enum = getattr(models, "Inception_V3_Weights", None)
        if weights_enum is not None:
            self.model = models.inception_v3(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.model = models.inception_v3(pretrained=True)

        # Freeze every parameter that exists at this point (feature extractor
        # and the original classifier heads).
        for param in self.model.parameters():
            param.requires_grad = False

        # Replace the final classifier.  The new layer is created after the
        # freeze loop, so its parameters keep the default requires_grad=True.
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, 1)  # Adapted for binary classification

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone on ``x`` and return sigmoid probabilities.

        Args:
            x: Input batch passed straight to the wrapped Inception V3 model
                (presumably shaped (N, 3, 299, 299) -- confirm against the
                data pipeline).

        Returns:
            ``torch.sigmoid`` applied to the main classifier output.
        """
        out = self.model(x)
        # In training mode Inception V3 may return a (logits, aux_logits)
        # tuple.  Decide by the actual return type rather than by the
        # training flag: if the backbone hands back a plain tensor while
        # training, tuple-unpacking it would split along the batch dimension.
        if isinstance(out, tuple):
            out = out[0]  # keep the main logits, discard the auxiliary head
        return torch.sigmoid(out)

# Build the shared ("global") model instance
global_model = InceptionV3Model()

# Persist only the model's state_dict (its parameters and buffers) to disk,
# then report where it was written
checkpoint_path = 'global_model.pth'
global_state = global_model.state_dict()
torch.save(global_state, checkpoint_path)
print("Global model created and saved as 'global_model.pth'")


# Output: Global model created and saved as 'global_model.pth'
