In [1]:
import os
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import from_networkx
from torch_geometric.utils import subgraph
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from torch_geometric.data import Dataset
from sklearn.model_selection import train_test_split
import copy 

In [2]:
torch.manual_seed(1234567)
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]


folder_path = 'second_order_client_neighbors'
sub_data_list = []
client_number = []

for filename in os.listdir(folder_path):
    if filename.endswith('.gml'):
      
        file_path = os.path.join(folder_path, filename)
        g = nx.read_gml(file_path)

        subgraph_nodes = list(g.nodes)
        subgraph_nodes = [int(node) for node in subgraph_nodes]  # Convert to integer if they are not

        sub_edge_index, _ = subgraph(subgraph_nodes, data.edge_index, relabel_nodes=True)
        print(sub_edge_index.shape)

        sub_data = Data(x=data.x[subgraph_nodes], edge_index=sub_edge_index, y=data.y[subgraph_nodes])
        sub_data_list.append(sub_data)
        client_number.append(int(filename.split('.')[0].split('_')[1]))
        


torch.Size([2, 746])
torch.Size([2, 164])
torch.Size([2, 570])
torch.Size([2, 92])
torch.Size([2, 228])
torch.Size([2, 854])
torch.Size([2, 136])
torch.Size([2, 378])
torch.Size([2, 102])
torch.Size([2, 856])
torch.Size([2, 144])
torch.Size([2, 128])
torch.Size([2, 734])
torch.Size([2, 132])
torch.Size([2, 862])
torch.Size([2, 100])
torch.Size([2, 136])
torch.Size([2, 418])
torch.Size([2, 72])
torch.Size([2, 98])
torch.Size([2, 142])
torch.Size([2, 148])
torch.Size([2, 238])
torch.Size([2, 92])
torch.Size([2, 124])
torch.Size([2, 368])
torch.Size([2, 350])
torch.Size([2, 414])
torch.Size([2, 1230])
torch.Size([2, 184])
torch.Size([2, 756])
torch.Size([2, 1790])
torch.Size([2, 312])
torch.Size([2, 182])
torch.Size([2, 156])
torch.Size([2, 82])
torch.Size([2, 166])
torch.Size([2, 376])
torch.Size([2, 330])
torch.Size([2, 178])
torch.Size([2, 120])
torch.Size([2, 144])
torch.Size([2, 160])
torch.Size([2, 340])
torch.Size([2, 148])
torch.Size([2, 220])
torch.Size([2, 790])
torch.Size([2, 9

In [3]:
test_data = None

g = nx.read_gml('second_order_test_graph.gml')
subgraph_nodes = list(g.nodes)
subgraph_nodes = [int(node) for node in subgraph_nodes]  # Convert to integer if they are not
sub_edge_index, _ = subgraph(subgraph_nodes, data.edge_index, relabel_nodes=True)
test_data = Data(x=data.x[subgraph_nodes], edge_index=sub_edge_index, y=data.y[subgraph_nodes])



In [4]:

g = nx.read_gml('new_facebook_network.gml')
neighbors = {}

for node in g.nodes:
    neighbors[node] = list(g.neighbors(node))

print(neighbors)

{'828': ['713', '705', '719', '805', '824', '745', '747', '823', '694', '830', '781', '697', '724', '827', '688', '853', '834', '703', '784', '815', '752', '728', '820', '842', '800', '819', '774', '726', '773', '766', '829', '780', '810', '696', '849', '718', '754', '760', '739', '856', '764', '687', '840', '847', '770', '797', '814', '727', '731', '708', '698', '817', '792', '845', '838', '848', '835', '807', '711', '779', '706', '772', '741', '709', '778', '737', '844', '693'], '713': ['828', '705', '719', '805', '824', '745', '747', '823', '694', '830', '781', '697', '724', '827', '688', '853', '834', '695', '784', '815', '752', '842', '800', '819', '774', '726', '773', '766', '829', '780', '810', '696', '760', '787', '840', '847', '814', '731', '708', '698', '738', '817', '734', '845', '701', '848', '748', '706', '777', '772', '818', '778', '831', '833', '722', '844', '736', '794', '693'], '705': ['828', '713', '719', '805', '824', '745', '747', '823', '694', '830', '781', '697', 

In [5]:
# read new_facebook_network.gml and create a fully connected graph

neighbors_fully_connected = {}

for node in g.nodes:
    neighbors_fully_connected[node] = list(g.nodes)
    neighbors_fully_connected[node].remove(node)

    
# print(neighbors_fully_connected)

In [6]:
print(sub_data_list[0])
print(sub_data_list[0].x.shape)
print(type(client_number))

Data(x=[188, 1433], edge_index=[2, 746], y=[188])
torch.Size([188, 1433])
<class 'list'>


In [None]:
#keep a  list of training and validation loss per epoch for each subgraph
train_losses = []
val_losses = []

In [7]:
import torch



def transductive_split(data, train_percent=0.8, val_percent=0.1):
    """
    Split graph data into training, validation, and testing sets for transductive learning.
    :param data: PyG Data object
    :param train_percent: Percentage of nodes to be used for training
    :param val_percent: Percentage of nodes to be used for validation
    :return: data object with train_mask, val_mask, and test_mask attributes added
    """
    # set a seed for reproducibility

    num_nodes = data.num_nodes
    train_size = int(train_percent * num_nodes)
    val_size = int(val_percent * num_nodes)

    # Create a random permutation of node indices
    perm = torch.randperm(num_nodes)

    # Create masks for training, validation, and testing nodes
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)

    train_mask[perm[:train_size]] = True
    val_mask[perm[train_size:train_size + val_size]] = True
    test_mask[perm[train_size + val_size:]] = True

    # Add masks to data object
    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask

    return data


In [44]:
from torch_geometric.nn import GCNConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   
    
class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 32)
        self.conv2 = GCNConv(32, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training) # p = 0.25
        x = self.conv2(x, edge_index)
        return x
    


model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)

In [9]:
#GAT model
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x

In [None]:
import matplotlib.pyplot as plt

# fig, axs = plt.subplots(10, 10, figsize=(50, 50))

def train(sub_data, model, optimizer, criterion):
      #put data on device
      sub_data = sub_data.to(device)
      model.train()
      optimizer.zero_grad() 
      out = model(sub_data.x, sub_data.edge_index)  # Perform a single forward pass.
      loss = criterion(out, sub_data.y)  # Compute the loss solely based on the training nodes.
      loss.backward() 
      optimizer.step() 
      return loss

def test(test_data, criterion, model):
      model.eval()
      out = model(test_data.x, test_data.edge_index)
      pred = out.argmax(dim=1)  # Use the class with highest probability.
      test_correct = pred == test_data.y # Check against ground-truth labels.
      test_acc = int(test_correct.sum()) / len(test_data.y)  # Derive ratio of correct predictions.
      test_loss = criterion(out, test_data.y)  # Compute validation loss
      
      return test_loss, test_acc


def validate(test_data, model, criterion):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Do not compute gradients during this step
        out = model(test_data.x, test_data.edge_index)  # Forward pass
        pred = out.argmax(dim=1)  # Get predicted classes
        val_correct = pred == test_data.y # Compare with ground-truth
        val_loss = criterion(out, test_data.y)  # Compute validation loss
        val_acc = int(val_correct.sum()) / int(len(test_data.y))  # Compute validation accuracy
    return val_loss.item(), val_acc  # Return validation loss and accuracy


def initial_training(epochs, learning_rate):
    test_losses = []
    test_accs = []
    dictionary_t_1 = {}
    dictionary_t = {}

    for i in range(len(sub_data_list)):
        sub_data = sub_data_list[i]
        model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
        criterion = torch.nn.CrossEntropyLoss()

        train_losses = [] 
        val_losses = []    

        # Training loop for each epoch (adjust the range as needed)
        for epoch in range(1, epochs + 1):
            train_loss = train(sub_data, model, optimizer, criterion)  # Assuming 'train' returns a loss
            val_loss, _ = validate(test_data, model, criterion)         # Assuming 'validate' returns a loss and something else (like accuracy)
            
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            
        
        test_loss, test_acc = test(test_data, criterion, model)
        test_losses.append(test_loss.item())
        test_accs.append(test_acc)
        row = i // 10
        col = i % 10

        model_weights = model.state_dict()
        client_number_num = client_number[i]
        dictionary_t_1[client_number_num] = model_weights
        
    # plt.tight_layout()
    # plt.show()
    
    return dictionary_t_1, test_losses, test_accs
    
    


In [None]:
dictionary_t_1, test_losses, test_accs = initial_training(10, 0.01)

# average test loss and accuracy before communication 

mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

In [None]:

import matplotlib.pyplot as plt

plt.hist(test_accs)
plt.title("Test accuracies")
plt.show()

In [10]:
def fed_avg(models):
    global_model = copy.deepcopy(models[0]) # start with the first model
    for k in global_model.keys():
        global_model[k] = torch.stack([model[k].float() for model in models], 0).mean(0)
    return global_model

In [11]:
def weighted_fed_avg(models, accuracies):
    #model and accuracies are dictionaries
    global_model = copy.deepcopy(models[0]) # start with the first model
    for k in global_model.keys():
        weighted_sum = 0
        for i in range(len(models)):
            weighted_sum += accuracies[i] * models[i][k].float()
        global_model[k] = weighted_sum / sum(accuracies)
    return global_model

In [12]:
def exponential_weighted_fed_avg(models, accuracies):
    accuracies = np.exp(accuracies) / np.sum(np.exp(accuracies))
    global_model = copy.deepcopy(models[0]) # start with the first model
    for k in global_model.keys():
        weighted_sum = 0
        for i in range(len(models)):
            weighted_sum += accuracies[i] * models[i][k].float()
        global_model[k] = weighted_sum / sum(accuracies)
    return global_model

In [13]:
def normalized_fed_avg(models, accuracies):
    global_model = copy.deepcopy(models[0])
    normalized_accuracies = [float(i)/sum(accuracies) for i in accuracies]
    for k in global_model.keys():
        weighted_sum = 0
        for i in range(len(models)):
            weighted_sum += normalized_accuracies[i] * models[i][k].float()
        global_model[k] = weighted_sum
    return global_model

In [14]:
def threshold_fed_avg(models, accuracies, threshold):
    global_model = copy.deepcopy(models[0])
    total_weight = 0
    for k in global_model.keys():
        weighted_sum = 0
        unweighted_sum = 0
        for i in range(len(models)):
            unweighted_sum += models[i][k].float()
            if accuracies[i] >= threshold:
                weighted_sum += accuracies[i] * models[i][k].float()
                total_weight += accuracies[i]
        if total_weight == 0:
            global_model[k] = unweighted_sum / len(models)
        else :
            global_model[k] = weighted_sum / total_weight
    return global_model

In [None]:


def communication(dictionary_t_1_og, neighbors, epochs, rounds, learning_rate):
    # make a copy of dictionary_t_1
    test_losses = []
    test_accs = []
    
    dictionary_t_1 = dictionary_t_1_og.copy()
    dictionary_t = {}
    for k in range(1, rounds + 1):

    # get the neighbprs of each client
        for i in range(len(sub_data_list)):
            sub_data = sub_data_list[i]
            
            model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)
            model.load_state_dict(dictionary_t_1[client_number[i]])
            
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
            criterion = torch.nn.CrossEntropyLoss()

            train_losses = [] 
            val_losses = []    

            # Training loop for each epoch (adjust the range as needed)
            for epoch in range(1, epochs + 1):
                train_loss = train(sub_data, model, optimizer, criterion)  # Assuming 'train' returns a loss
                val_loss, _ = validate(test_data, model, criterion)         # Assuming 'validate' returns a loss and something else (like accuracy)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                
            
            test_loss, test_acc = test(test_data, criterion, model)
            if k == rounds:
                test_losses.append(test_loss.item())
                test_accs.append(test_acc)

        
            # add the model weights to the t -1 dictionary
            model_weights = model.state_dict()
            client_number_num = client_number[i]
            dictionary_t_1[client_number_num] = model_weights
        

        for i in range(100):
            client_number_num = client_number[i]
            # go through neighbors of client_number
            string_client_num = str(client_number_num)
            client_neighbors = neighbors[string_client_num]
            
            neighbors_state_dicts = []
            neighbors_state_dicts.append(dictionary_t_1[client_number_num])
            for j in range(len(client_neighbors)):
                # get model weights of neighbor
                neighbor_model = dictionary_t_1[int(client_neighbors[j])]
                neighbors_state_dicts.append(neighbor_model)
            
            #call fed_avg
            average_state_dict = {}
            average_state_dict = fed_avg(neighbors_state_dicts)
                
                
            dictionary_t[client_number_num] = average_state_dict    
            
        dictionary_t_1 = dictionary_t

    return test_losses, test_accs
 
        
         

In [None]:
#communication with weight averaging based on model accuracy


def communication(dictionary_t_1_og, neighbors, epochs, rounds, learning_rate):
    # make a copy of dictionary_t_1
    test_losses = []
    test_accs = []
    
    dictionary_t_1 = dictionary_t_1_og.copy()
    t_1_accuracies = {}
    dictionary_t = {}
    for k in range(1, rounds + 1):

    # get the neighbprs of each client
        for i in range(len(sub_data_list)):
            sub_data = sub_data_list[i]
            
            model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)
            model.load_state_dict(dictionary_t_1[client_number[i]])
            
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
            criterion = torch.nn.CrossEntropyLoss()

            train_losses = [] 
            val_losses = []    

            # Training loop for each epoch (adjust the range as needed)
            for epoch in range(1, epochs + 1):
                train_loss = train(sub_data, model, optimizer, criterion)  # Assuming 'train' returns a loss
                val_loss, _ = validate(test_data, model, criterion)         # Assuming 'validate' returns a loss and something else (like accuracy)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                
            
            test_loss, test_acc = test(test_data, criterion, model)
            if k == rounds:
                test_losses.append(test_loss.item())
                test_accs.append(test_acc)

        
            # add the model weights to the t -1 dictionary
            model_weights = model.state_dict()
            client_number_num = client_number[i]
            dictionary_t_1[client_number_num] = model_weights
            t_1_accuracies[client_number_num] = test_acc
        

        for i in range(100):
            client_number_num = client_number[i]
            # go through neighbors of client_number
            string_client_num = str(client_number_num)
            client_neighbors = neighbors[string_client_num]
            
            neighbors_state_dicts = []
            neighbors_state_dicts.append(dictionary_t_1[client_number_num])
            accuracies = []
            accuracies.append(t_1_accuracies[client_number_num])
            for j in range(len(client_neighbors)):
                # get model weights of neighbor
                neighbor_model = dictionary_t_1[int(client_neighbors[j])]
                neighbors_state_dicts.append(neighbor_model)
                accuracies.append(t_1_accuracies[int(client_neighbors[j])])
                
            
            #call fed_avg
            average_state_dict = {}
            average_state_dict = threshold_fed_avg(neighbors_state_dicts, accuracies, 0.15)
                
                
            dictionary_t[client_number_num] = average_state_dict    
            
        dictionary_t_1 = dictionary_t

    return test_losses, test_accs
 
        
         

In [None]:

test_losses, test_accs = communication(dictionary_t_1, neighbors, 10, 10, 0.01)

# average test loss and accuracy after communication
mean_test_losses = sum(test_losses)/len(test_losses)

#mean of test_accs
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)


In [None]:
# print a distriution of the accuracies of test_losses and test_accs

import matplotlib.pyplot as plt

plt.hist(test_accs)
plt.title("Test accuracies")
plt.show()

In [None]:
# pr

# Fully Connected Graph

## 1 epoch, 100 rounds

In [None]:
# 1 epoch

dictionary_t_1, test_losses, test_accs = initial_training(1)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors_fully_connected, 1, 100)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)


## 5 epochs, 100 rounds

In [None]:
# 5 epoch

dictionary_t_1, test_losses, test_accs = initial_training(5)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors_fully_connected, 5, 100)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)

## 10 epochs, 100 rounds

In [None]:
# 01 epoch

dictionary_t_1, test_losses, test_accs = initial_training(10)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors_fully_connected, 10, 100)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)

# Our Social Network

## 1 epoch, 100 rounds

In [None]:
# 1 epoch

dictionary_t_1, test_losses, test_accs = initial_training(1)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 1, 100)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)

## 5 epochs, 100 rounds

In [None]:
# 5 epoch

dictionary_t_1, test_losses, test_accs = initial_training(5, 0.01)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 5, 100, 0.01)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)

In [None]:
#plotting the test accuracies

import matplotlib.pyplot as plt

plt.hist(test_accs)
plt.title("Test accuracies")
plt.show()


## 10 epochs, 100 rounds

In [None]:
# 10 epoch

dictionary_t_1, test_losses, test_accs = initial_training(10)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 10, 100)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)

## 20 epochs, 100 rounds

In [None]:
# 20 epoch

dictionary_t_1, test_losses, test_accs = initial_training(20)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 20, 100)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)

## 1 epochs, 10 rounds

In [None]:
# 1 epoch

dictionary_t_1, test_losses, test_accs = initial_training(1)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 1, 10)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)

## 5 epochs, 10 rounds

In [None]:
# 5 epoch

dictionary_t_1, test_losses, test_accs = initial_training(5)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 5, 10)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)

## 10 epochs, 10 rounds

In [None]:
# 10 epoch

dictionary_t_1, test_losses, test_accs = initial_training(10)
# print average test loss and average test accs
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss before communication: ", mean_test_losses)
print("mean test accuracy before communication: ", mean_test_accs)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 10, 10)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test loss after communication: ", mean_test_losses)
print("mean test accuracy after communication: ", mean_test_accs)

## Fully Connected Graph Results
### (Training with original subgraphs)

| Epochs | Communication Rounds | Avg Test Accuracy (Before) | Avg Test Accuracy (After) | Avg Test Loss (Before) | Avg Test Loss (After) |
| ------ | -------------------- | -------------------------- | ------------------------- | ---------------------- | --------------------- |
| 1      | 100                  | 15.07%                     | 14.29%                    | 1.947                  | 2.083                 |
| 5      | 100                  | 14.89%                     | 15.49%                    | 2.021                  | 2.943                 |
| 10     | 100                  | 14.73%                     | 14.90%                    | 2.387                  | 5.010                 |



## Social Network Graph Results
### (Training with original subgraphs)

| Epochs | Communication Rounds | Avg Test Accuracy (Before) | Avg Test Accuracy (After) | Avg Test Loss (Before) | Avg Test Loss (After) |
| ------ | -------------------- | -------------------------- | ------------------------- | ---------------------- | --------------------- |
| 1      | 100                  | 15.13%                     | 14.31%                    | 1.947                  | 1.996                 |
| 1      | 10                   | 15.40%                     | 15.12%                    | 1.946                  | 1.945                 |
| 5      | 100                  | 14.38%                     | 16.16%                    | 2.028                  | 2.730                 |
| 5      | 10                   | 14.52%                     | 14.49%                    | 2.014                  | 2.282                 |
| 10     | 100                  | 14.74%                     | 16.07%                    | 2.321                  | 4.402                 |
| 10     | 10                   | 14.69%                     | 15.01%                    | 2.349                  | 3.084                 |
| 20     | 100                  | 15.39%                     | 18.50%                    | 3.101                  | 4.654                 |


## Fully Connected Graph Results
### (Training with second order subgraphs)


| Epochs | Communication Rounds | Avg Test Accuracy (Before) | Avg Test Accuracy (After) | Avg Test Loss (Before) | Avg Test Loss (After) |
| ------ | -------------------- | -------------------------- | ------------------------- | ---------------------- | --------------------- |
| 1      | 100                  | 15.02%                     | 19.24%                    | 1.951                  | 2.052                 |
| 5      | 100                  | 15.22%                     | 31.80%                    | 2.114                 | 2.614                |
| 10     | 100                  | 15.69%                     | 29.41%                    | 2.794                 | 2.932              |

## Social Network Graphs Results
### (Training with second order subgraphs)

| Epochs | Communication Rounds | Avg Test Accuracy (Before) | Avg Test Accuracy (After) | Avg Test Loss (Before) | Avg Test Loss (After) |
|--------|----------------------|----------------------------|---------------------------|-----------------------|-----------------------|
| 1      | 100                  | 0.1532                     | 0.1961                    | 1.9492                | 1.9300                |
| 5      | 100                  | 0.1523                     | 0.3180                    | 2.1628                | 2.2851                |
| 10     | 100                  | 0.1593                     | 0.3135                    | 2.7872                | 2.9388                |
| 20     | 100                  | 0.1779                     | 0.2935                    | 3.4153                | 3.4956                |
| 1      | 10                   | 0.1486                     | 0.1543                    | 1.9489                | 1.9431                |
| 5      | 10                   | 0.1514                     | 0.1521                    | 2.1468                | 2.3963                |
| 10     | 10                   | 0.1590                     | 0.1783                    | 2.6969                | 3.1621                |


# Centralized Training on Entire Cora Dataset

In [None]:
#create Dataset class

class GraphDataset(torch.utils.data.Dataset):
    
    def __init__(self, graph_data):
        self.graph_data = graph_data
        
    def __len__(self):
        return len(self.graph_data.x[0])
    
    def __getitem__(self, idx):
        print(idx)
        x = self.graph_data.x[idx,:]
        y = self.graph_data.y[idx]
        edge_index = self.graph_data.edge_index
        # connected_edges = (self.graph_data.edge_index[0] == idx) | (self.graph_data.edge_index[1] == idx)
        # edge_index = self.graph_data.edge_index[:, connected_edges]
        return x, edge_index, y, idx
    
    

In [None]:
# get cora_graph_training 

import os
import networkx as nx
from torch.utils.data import DataLoader


cora_graph_training = nx.read_gml('cora_graph_training.gml')
test_graph = test_data


graph_nodes = list(cora_graph_training.nodes)
graph_nodes = [int(node) for node in graph_nodes]  # Convert to integer if they are not

graph_edge_index, _ = subgraph(graph_nodes, data.edge_index, relabel_nodes=True)
graph_data = Data(x=data.x[graph_nodes], edge_index=graph_edge_index, y=data.y[graph_nodes])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   

def train_test(graph_data, epochs, dataset, test_data, device):
    # model = GCN(graph_data.num_node_features, dataset.num_classes).to(device)
    model  = GAT(graph_data.num_node_features, dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    criterion = torch.nn.CrossEntropyLoss()  
    epochs = 10
    # Training loop for each epoch (adjust the range as needed)
    for epoch in range(1, epochs + 1):
        train_loss = train(graph_data, model, optimizer, criterion)  # Assuming 'train' returns a loss
        val_loss, _ = validate(test_data, model, criterion)         # Assuming 'validate' returns a loss and something else (like accuracy) 

    test_loss, test_acc = test(test_data, criterion, model)
    
    # print("Centralized learning test loss: ", test_loss)
    print("Centralized learning test accuracy: ", test_acc)

train_test(graph_data, 10, dataset, test_data, device)




## Centralized Learning with Different Batch Sizes

In [None]:
from torch_geometric.data import DataLoader

def train_test_batches(graph_data, epochs, dataset, test_data, device, batch_size):
    model = GCN(graph_data.num_node_features, dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()  
    epochs = 10
    # Training loop for each epoch (adjust the range as needed)
    for epoch in range(1, epochs + 1):
        train_loss = train_batches(graph_data, model, optimizer, criterion, batch_size)  # Assuming 'train' returns a loss
        val_loss, _ = validate(test_data, model, criterion)         # Assuming 'validate' returns a loss and something else (like accuracy) 

    test_loss, test_acc = test(test_data, criterion, model)
    
    # print("Centralized learning test loss: ", test_loss)
    print("Batch size: ", batch_size)
    print("Centralized learning test accuracy: ", test_acc)

def train_batches(graph_data, model, optimizer, criterion, batch_size):
    # graph_dataset = [graph_data.x, graph_data.edge_index, graph_data.y]
    # print(graph_data.edge_index.shape)
    graph_dataset = GraphDataset(graph_data)
    graph_data_loader = DataLoader(graph_dataset, batch_size=batch_size, shuffle=True)
    # print(len(graph_data_loader.dataset))
    
    total_loss = 0
    model.train()
    
    for i, data in enumerate(graph_data_loader):
      x, edge_index, y, idx = data
    #   print(idx)
      nodes = list(idx)
      sub_edge_index, _ = subgraph(nodes, edge_index[0], relabel_nodes=True)
    #   print(sub_edge_index.shape)
      optimizer.zero_grad() 
      out = model(x, sub_edge_index) 
      loss = criterion(out, y)  
      loss.backward() 
      optimizer.step() 
      total_loss += loss.item()
      
    return total_loss/(i+1)

print()
batch_sizes = [10, 16, 32]
for batch_size in batch_sizes:
    train_test_batches(graph_data, 10, dataset, test_data, device, batch_size)


# Testing our Social Network

### Adam Optimizer, changing learning rates
### 1 epoch, 100 communication rounds

In [None]:
learning_rates = [0.1, 0.01, 0.001, 0.0001, 0.00001]


for learning_rate in learning_rates:
    dictionary_t_1, test_losses, test_accs = initial_training(1, learning_rate)

    test_losses, test_accs = communication(dictionary_t_1, neighbors, 1, 100, learning_rate)
    mean_test_losses = sum(test_losses)/len(test_losses)
    mean_test_accs = sum(test_accs)/len(test_accs)
    print("learning rate: ", learning_rate)
    print("mean test accuracy after communication: ", mean_test_accs)
    print("mean test loss after communication: ", mean_test_losses)
    print()
    

### 5 epochs, 100 rounds

In [None]:
learning_rates = [0.01, 0.001, 0.0001, 0.00001]

for learning_rate in learning_rates:
    dictionary_t_1, test_losses, test_accs = initial_training(5, learning_rate)

    test_losses, test_accs = communication(dictionary_t_1, neighbors, 5, 100, learning_rate)
    mean_test_losses = sum(test_losses)/len(test_losses)
    mean_test_accs = sum(test_accs)/len(test_accs)
    print("learning rate: ", learning_rate)
    print("mean test accuracy after communication: ", mean_test_accs)
    print("mean test loss after communication: ", mean_test_losses) # print(f"{mean_test_losses:.4f},{}\n")
    print()

### 10 epochs, 100 rounds

In [None]:
for learning_rate in learning_rates:
    dictionary_t_1, test_losses, test_accs = initial_training(10, learning_rate)

    test_losses, test_accs = communication(dictionary_t_1, neighbors, 10, 100, learning_rate)
    mean_test_losses = sum(test_losses)/len(test_losses)
    mean_test_accs = sum(test_accs)/len(test_accs)
    print("learning rate: ", learning_rate)
    print("mean test accuracy after communication: ", mean_test_accs)
    print("mean test loss after communication: ", mean_test_losses)
    print()

In [None]:
min_nodes = 1000000
for subgraph in sub_data_list:
    if subgraph.num_nodes < min_nodes:
        min_nodes = subgraph.num_nodes
print("min nodes: ", min_nodes)

In [None]:

def initial_training(epochs, learning_rate, batch_size):
    test_losses = []
    test_accs = []
    dictionary_t_1 = {}
    dictionary_t = {}
    
    for i in range(len(sub_data_list)):
        sub_data = sub_data_list[i]
        model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
        criterion = torch.nn.CrossEntropyLoss()

        train_losses = [] 
        val_losses = []    

        # Training loop for each epoch (adjust the range as needed)
        for epoch in range(1, epochs + 1):
            train_loss = train_batches(sub_data, model, optimizer, criterion, batch_size)  # Assuming 'train' returns a loss
            val_loss, _ = validate(test_data, model, criterion)         # Assuming 'validate' returns a loss and something else (like accuracy)
            
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            
        
        test_loss, test_acc = test(test_data, criterion, model)
        test_losses.append(test_loss.item())
        test_accs.append(test_acc)
        row = i // 10
        col = i % 10

        model_weights = model.state_dict()
        client_number_num = client_number[i]
        dictionary_t_1[client_number_num] = model_weights
    return dictionary_t_1, test_losses, test_accs
        
def communication(dictionary_t_1_og, neighbors, epochs, rounds, learning_rate, batch_size):
    # make a copy of dictionary_t_1
    test_losses = []
    test_accs = []
    
    dictionary_t_1 = dictionary_t_1_og.copy()
    dictionary_t = {}
    for k in range(1, rounds + 1):

    # get the neighbprs of each client
        for i in range(len(sub_data_list)):
            sub_data = sub_data_list[i]
            
            model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)
            model.load_state_dict(dictionary_t_1[client_number[i]])
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
            criterion = torch.nn.CrossEntropyLoss()
            train_losses = [] 
            val_losses = []    

            # Training loop for each epoch (adjust the range as needed)
            for epoch in range(1, epochs + 1):
                train_loss = train_batches(sub_data, model, optimizer, criterion, batch_size)  # Assuming 'train' returns a loss
                val_loss, _ = validate(test_data, model, criterion)         # Assuming 'validate' returns a loss and something else (like accuracy)
                train_losses.append(train_loss)
                val_losses.append(val_loss)                   
            test_loss, test_acc = test(test_data, criterion, model)
            if k == rounds:
                test_losses.append(test_loss.item())
                test_accs.append(test_acc)
            # add the model weights to the t -1 dictionary
            model_weights = model.state_dict()
            client_number_num = client_number[i]
            dictionary_t_1[client_number_num] = model_weights
        

        for i in range(100):
            client_number_num = client_number[i]
            # go through neighbors of client_number
            string_client_num = str(client_number_num)
            client_neighbors = neighbors[string_client_num]
            
            neighbors_state_dicts = []
            neighbors_state_dicts.append(dictionary_t_1[client_number_num])
            for j in range(len(client_neighbors)):
                # get model weights of neighbor
                neighbor_model = dictionary_t_1[int(client_neighbors[j])]
                neighbors_state_dicts.append(neighbor_model)
            
            #call fed_avg
            average_state_dict = {}
            average_state_dict = fed_avg(neighbors_state_dicts)
            dictionary_t[client_number_num] = average_state_dict    
            
        dictionary_t_1 = dictionary_t
    return test_losses, test_accs
    
    

def train_batches(graph_data, model, optimizer, criterion, batch_size):
    # graph_dataset = [graph_data.x, graph_data.edge_index, graph_data.y]
    graph_dataset = GraphDataset(graph_data)
    graph_data_loader = DataLoader(graph_dataset, batch_size=batch_size, shuffle=True)
    total_loss = 0
    model.train()
    for i, data in enumerate(graph_data_loader):
      x, edge_index, y, idx = data
      nodes = list(idx)
      sub_edge_index, _ = subgraph(nodes, edge_index[0], relabel_nodes=True)
    #   print(sub_edge_index.shape)
      optimizer.zero_grad() 
      out = model(x, sub_edge_index) 
      loss = criterion(out, y)  
      loss.backward() 
      optimizer.step() 
      total_loss += loss.item()  
    return total_loss/(i+1)

In [29]:
# only 188 nodes

class GraphDataset(torch.utils.data.Dataset):
    
    def __init__(self, graph_data):
        
        self.graph_data = graph_data
        
    def __len__(self):
        return len(self.graph_data.x)
    
    def __getitem__(self, idx):

        x = self.graph_data.x[idx,:]
        y = self.graph_data.y[idx]
        edge_index = self.graph_data.edge_index
        # connected_edges = (self.graph_data.edge_index[0] == idx) | (self.graph_data.edge_index[1] == idx)
        # edge_index = self.graph_data.edge_index[:, connected_edges]
        return x, edge_index, y, idx

## Using Adam Optimizer, lr = 0.01, and varying batch sizes

In [None]:
batch_sizes = [21, 16, 32, 64]


for batch_size in batch_sizes:
    dictionary_t_1, test_losses, test_accs = initial_training(1, 0.01 , batch_size)

    test_losses, test_accs = communication(dictionary_t_1, neighbors, 1, 100, 0.01, batch_size)
    mean_test_losses = sum(test_losses)/len(test_losses)
    mean_test_accs = sum(test_accs)/len(test_accs)
    print("batch size: ", batch_size)
    print("mean test accuracy after communication: ", mean_test_accs)
    print("mean test loss after communication: ", mean_test_losses)
    print()

### 5 epochs, 100 rounds

In [None]:
batch_sizes = [32, 64, 100, 128]

for batch_size in batch_sizes:
    dictionary_t_1, test_losses, test_accs = initial_training(5, 0.01 , batch_size)

    test_losses, test_accs = communication(dictionary_t_1, neighbors, 5, 100, 0.01, batch_size)
    mean_test_losses = sum(test_losses)/len(test_losses)
    mean_test_accs = sum(test_accs)/len(test_accs)
    print("batch size: ", batch_size)
    print("mean test accuracy after communication: ", mean_test_accs)
    print("mean test loss after communication: ", mean_test_losses)
    print()

In [None]:
batch_sizes = [16, 32, 64, 100, 128]

for batch_size in batch_sizes:
    dictionary_t_1, test_losses, test_accs = initial_training(10, 0.01 , batch_size)

    test_losses, test_accs = communication(dictionary_t_1, neighbors, 10, 100, 0.01, batch_size)
    mean_test_losses = sum(test_losses)/len(test_losses)
    mean_test_accs = sum(test_accs)/len(test_accs)
    print("batch size: ", batch_size)
    print("mean test accuracy after communication: ", mean_test_accs)
    print("mean test loss after communication: ", mean_test_losses)
    print()

# Different Federated Averaging Techniques

## Normalized federated averaging technique 

In [None]:
dictionary_t_1, test_losses, test_accs = initial_training(5, 0.01)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 5, 100, 0.01)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test accuracy after communication: ", mean_test_accs)
print("mean test loss after communication: ", mean_test_losses)

## Exponential weighted federated averaging technique

In [None]:
import numpy as np

dictionary_t_1, test_losses, test_accs = initial_training(5, 0.01)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 5, 100, 0.01)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test accuracy after communication: ", mean_test_accs)
print("mean test loss after communication: ", mean_test_losses)

# Threshold federated averaging technique

In [None]:
dictionary_t_1, test_losses, test_accs = initial_training(5, 0.01)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 5, 100, 0.01)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test accuracy after communication: ", mean_test_accs)
print("mean test loss after communication: ", mean_test_losses)

## Differential Privacy 

In [114]:
import torch
import torch.nn as nn

class DifferentiallyPrivateGCN(nn.Module):
    def __init__(self, gcn_model, lr, noise_multiplier):
        super(DifferentiallyPrivateGCN, self).__init__()
        self.gcn_model = gcn_model
        self.lr = lr
        self.noise_multiplier = noise_multiplier
        self.optimizer = torch.optim.Adam(self.gcn_model.parameters(), lr=self.lr, weight_decay=5e-4)

    def forward(self, x, edge_index):
        return self.gcn_model(x, edge_index)

    def train_step(self, x, edge_index, y):
        # self.gcn_model.zero_grad()
        self.optimizer.zero_grad()
        output = self.forward(x, edge_index)
        criterion = torch.nn.CrossEntropyLoss()
        loss = criterion(output, y)
        # loss = F.nll_loss(output, y)
        loss.backward() 
        
        for param in self.gcn_model.parameters():
            noise = torch.randn_like(param.grad) * self.noise_multiplier
            param.grad.add_(noise)
            # param.data.add_(-self.lr * param.grad)
            
        self.optimizer.step()

        return output, loss

    #   sub_data = sub_data.to(device)
    #   model.train()
    #   optimizer.zero_grad() 
    #   out = model(sub_data.x, sub_data.edge_index)  # Perform a single forward pass.
    #   loss = criterion(out, sub_data.y)  # Compute the loss solely based on the training nodes.
    #   loss.backward() 
    #   optimizer.step() 
    
    # def load_state_dict(self, state_dict):
    #     state_dict = {k.replace('gcn_model.', ''): v for k, v in state_dict.items()}
    #     self.gcn_model.load_state_dict(state_dict)

In [118]:
import matplotlib.pyplot as plt
from opacus import PrivacyEngine


def test(test_data, criterion, model):
      model.eval()
      out = model(test_data.x, test_data.edge_index)
      pred = out.argmax(dim=1)  # Use the class with highest probability.
      test_correct = pred == test_data.y # Check against ground-truth labels.
      test_acc = int(test_correct.sum()) / len(test_data.y)  # Derive ratio of correct predictions.
      test_loss = criterion(out, test_data.y)  # Compute validation loss
      
      return test_loss, test_acc


def validate(test_data, model, criterion):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Do not compute gradients during this step
        out = model(test_data.x, test_data.edge_index)  # Forward pass
        pred = out.argmax(dim=1)  # Get predicted classes
        val_correct = pred == test_data.y # Compare with ground-truth
        val_loss = criterion(out, test_data.y)  # Compute validation loss
        val_acc = int(val_correct.sum()) / int(len(test_data.y))  # Compute validation accuracy
    return val_loss.item(), val_acc  


def initial_training(epochs, learning_rate):
    test_losses = []
    test_accs = []
    dictionary_t_1 = {}
    dictionary_t = {}

    for i in range(len(sub_data_list)):
        sub_data = sub_data_list[i]
        gcn_model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)
        model = DifferentiallyPrivateGCN(gcn_model, learning_rate, 0.01)
        criterion = torch.nn.CrossEntropyLoss()
        train_losses = [] 
        val_losses = []    

        # Training loop for each epoch (adjust the range as needed)
        for epoch in range(1, epochs + 1):
            model.train()  # Set the model to training mode
            output, loss = model.train_step(sub_data.x, sub_data.edge_index, sub_data.y)
            train_losses.append(loss.item())
            val_loss, _ = validate(test_data, model, criterion)  # Assuming 'validate' returns a loss and something else (like accuracy)
            val_losses.append(val_loss)
            
        
        test_loss, test_acc = test(test_data, criterion, model)
        test_losses.append(test_loss.item())
        test_accs.append(test_acc)


        model_weights =  model.gcn_model.state_dict()
        client_number_num = client_number[i]
        dictionary_t_1[client_number_num] = model_weights
        
    
    return dictionary_t_1, test_losses, test_accs

In [119]:

def communication(dictionary_t_1_og, neighbors, epochs, rounds, learning_rate):
    # make a copy of dictionary_t_1
    test_losses = []
    test_accs = []
    
    dictionary_t_1 = dictionary_t_1_og.copy()
    dictionary_t = {}
    for k in range(1, rounds + 1):

    # get the neighbprs of each client
        for i in range(len(sub_data_list)):
            sub_data = sub_data_list[i]
            
            gcn_model = GCN(sub_data.num_node_features, dataset.num_classes).to(device)
            state_dict = dictionary_t_1[client_number[i]]
            gcn_model.load_state_dict(state_dict)
            model = DifferentiallyPrivateGCN(gcn_model, learning_rate, 0.01)
            criterion = torch.nn.CrossEntropyLoss()
            train_losses = [] 
            val_losses = []    

            for epoch in range(1, epochs + 1):
                
                model.train()
                output, loss = model.train_step(sub_data.x, sub_data.edge_index, sub_data.y)
                train_losses.append(loss.item())
                val_loss, _ = validate(test_data, model, criterion)
                val_losses.append(val_loss)
                
            
            test_loss, test_acc = test(test_data, criterion, model)
            if k == rounds:
                test_losses.append(test_loss.item())
                test_accs.append(test_acc)

        
            # add the model weights to the t -1 dictionary
            model_weights =  model.gcn_model.state_dict()
            client_number_num = client_number[i]
            dictionary_t_1[client_number_num] = model_weights
        

        for i in range(100):
            client_number_num = client_number[i]
            # go through neighbors of client_number
            string_client_num = str(client_number_num)
            client_neighbors = neighbors[string_client_num]
            
            neighbors_state_dicts = []
            neighbors_state_dicts.append(dictionary_t_1[client_number_num])
            for j in range(len(client_neighbors)):
                # get model weights of neighbor
                neighbor_model = dictionary_t_1[int(client_neighbors[j])]
                neighbors_state_dicts.append(neighbor_model)
            
            #call fed_avg
            average_state_dict = {}
            average_state_dict = fed_avg(neighbors_state_dicts)
                
            dictionary_t[client_number_num] = average_state_dict    
            
        dictionary_t_1 = dictionary_t

    return test_losses, test_accs
 
        

In [120]:
dictionary_t_1, test_losses, test_accs = initial_training(5, 0.01)

test_losses, test_accs = communication(dictionary_t_1, neighbors, 5, 100, 0.01)
mean_test_losses = sum(test_losses)/len(test_losses)
mean_test_accs = sum(test_accs)/len(test_accs)
print("mean test accuracy after communication: ", mean_test_accs)
print("mean test loss after communication: ", mean_test_losses)

mean test accuracy after communication:  0.6669329073482426
mean test loss after communication:  1.173405447602272
