In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, transforms
import torch.nn.functional as F
from typing import Optional

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score

import cv2
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [3]:
from torchvision.models import resnet18, resnet101, resnet50, ResNet18_Weights

class ResNetClassifier(nn.Module):
    """Binary classifier on top of a randomly initialised ResNet-18 backbone.

    The backbone's stem convolution is replaced by a 1-input-channel version and
    its classification head by a single-output linear layer; ``forward`` turns
    that single logit into a probability with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        backbone = resnet18(weights=None)  # no pretrained weights
        head_inputs = backbone.fc.in_features
        # 1-channel stem with the same kernel/stride/padding as the stock conv1.
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Single-logit head for binary classification.
        backbone.fc = nn.Linear(head_inputs, 1)
        self.resnet = backbone
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return one probability per sample.

        Assumes ``x`` is a batch of single-channel images, (batch, 1, H, W) —
        TODO confirm against the data pipeline.
        """
        logits = self.resnet(x).squeeze(1)  # (batch, 1) -> (batch,)
        return self.sigmoid(logits)

In [4]:
from torchvision.models import resnet18
import torch.nn as nn

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoints written by the local servers, one per server (servers 1-3).
model_paths = [
    f"../../LocalServers/saved_models/local_server_{server_id}_RESNET.pth"
    for server_id in (1, 2, 3)
]
num_models = len(model_paths)

# Global model into which the local weights are combined.
global_model = ResNetClassifier().to(device)


In [5]:
# Load each local server's checkpoint into its own ResNetClassifier instance.
# map_location=device remaps tensors that were saved on a different device
# (e.g. a GPU) so the checkpoints also load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True to restrict unpickling).
trained_models = []
for path in model_paths:
    model = ResNetClassifier().to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    trained_models.append(model)

In [6]:
# Start the global model from the first local server's weights. strict=True makes
# load_state_dict fail loudly if the checkpoint's keys do not match the model.
first_local_state = trained_models[0].state_dict()
global_model.load_state_dict(first_local_state, strict=True)

<All keys matched successfully>

In [45]:
# Sum every local model's weights into a fresh accumulator.
# - Start from zeros, not from global_model.state_dict(): the global model was
#   loaded from trained_models[0] earlier, so starting from it would count that
#   model twice once all local models are added.
# - Tensors returned by state_dict() share storage with the module, so `+=` on
#   them would silently overwrite global_model's weights and make this cell give
#   a different result each time it is re-run. Fresh tensors avoid both problems.
global_weights = {
    key: torch.zeros_like(tensor)
    for key, tensor in global_model.state_dict().items()
}

for local_model in trained_models:
    local_weights = local_model.state_dict()
    for key in global_weights:
        global_weights[key] += local_weights[key]

In [46]:
# Turn the summed weights into their mean and install them in the global model.
# The quotient must be stored: a bare `a / b` expression is computed and discarded.
# Division is out-of-place because integer-typed entries (e.g. batch-norm step
# counters, if present) cannot be divided in place into a float result. Writing
# into a new dict also keeps this cell safe to re-run.
assert num_models > 0, "no local models to average"
averaged_weights = {key: tensor / num_models for key, tensor in global_weights.items()}

# load_state_dict copies each averaged tensor into the model's existing
# parameters/buffers (converting dtype as needed), so global_model.state_dict()
# now returns the averaged weights.
global_model.load_state_dict(averaged_weights)

In [47]:
# Write the global model's current weights to disk as the aggregated checkpoint.
avg_checkpoint_path = "AVG_global_model.pth"
torch.save(obj=global_model.state_dict(), f=avg_checkpoint_path)