## GAT IMPLEMENTATION

In [1]:
TOTAL_CLIENTS = 610
LR=0.01
INPUT_CHANNELS = 256
HIDDEN_CHANNELS = 256
EPOCHS = 300 #3
CLIENTS_COUNT= 10 #128

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, GATConv
from torch.autograd import Variable
from torch_geometric.data import Data

import numpy as np
import pandas as pd
import random

In [3]:
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.hidden_channels = INPUT_CHANNELS
        self.input_channels = HIDDEN_CHANNELS
        self.headsv1 = 4
        self.headsv2 = 1
 
        self.conv1 = GATConv(in_channels=self.input_channels, out_channels=self.hidden_channels, 
                               heads=self.headsv1, dropout=0.2)
#         self.conv2 = GATv2Conv(in_channels=self.hidden_channels*self.headsv1, out_channels=self.hidden_channels,
#                              heads=self.headsv2, dropout=0.6, concat=False,)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index  
        x_in = Variable(x, requires_grad=True)
        x = F.dropout(x_in, p=0.2, training=self.training) 
        x = self.conv1(x, edge_index)                   
        x = F.elu(x)
        
        y = x[0,:] * x[1:,:]
        y = torch.sum(y, dim=1, dtype=float)
        y = F.relu(y)
        y.retain_grad()
        return x_in, y
    
    

## Federated Pipeline

In [4]:
class FederatedNetwork:
    def __init__(self, device, state_dict):
        self.model = None
        self.optimizer = None
        self.criterion = None
        self.initialize_model(device, state_dict)
        
    def initialize_model(self, device, state_dict):
        self.model = GAT().to(device)
        if state_dict != None:
            self.model.load_state_dict(state_dict)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=LR, weight_decay=5e-4)
        self.criterion = nn.MSELoss()  #Square it later
        

In [5]:
class Client(FederatedNetwork):
    def __init__(self, client_id, rated_items, items_embeddings, device, state_dict):
        self.items_embeddings = dict()
        self.user_embedding = None
        self.items_rated = None
        self.items_embeddings_grad = None
        self.id = client_id
        self.graph_data = None
        self.y = None
        self.all_embeddings = None
        self.initialize_client(rated_items, items_embeddings, device)
        super().__init__(device, state_dict)
        
    def initialize_client(self, rated_items, items_embeddings, device):
        self.initialize_rated_items(rated_items)
        self.initialize_embeddings(items_embeddings)
        self.initalize_graph(device)
    
    def initialize_rated_items(self, rated_items):
        self.items_rated = rated_items
        
    def initialize_embeddings(self, items_embeddings):
        self.user_embedding = torch.nn.init.xavier_uniform_(torch.empty(1, 256))
        self.update_item_embeddings(items_embeddings)
        
    def update_item_embeddings(self, items_embeddings):
        self.all_embeddings = items_embeddings
        self.items_embeddings = items_embeddings[items_embeddings['movieId'].isin(self.items_rated["movieId"])]
            
            
    def get_items_embeddings_grad(self):
        return self.items_embeddings_grad
    
    def get_item_embeddings(self):
            return self.all_embeddings
    
    def generate_graph_from_data(self):
        list_a = [0]*(len(self.items_rated)) + [i for i in range(1, len(self.items_rated)+1)]
        list_b = [i for i in range(1, len(self.items_rated)+1)]+[0]*(len(self.items_rated)) 
        edge_index = torch.tensor([list_a,
                           list_b], dtype=torch.long)
        x = [self.user_embedding.numpy()[0], ]
        item_emb = self.items_embeddings['embeddings'].values
        x += [item_emb[i] for i in range(len(item_emb))] 
        x = torch.tensor(np.array(x), dtype=torch.float)
        
        y = torch.tensor(self.items_rated['rating'].values)
        
        return x, y, edge_index
     
    def initalize_graph(self, device):
        x, y, edge_index = self.generate_graph_from_data()
        self.graph_data = Data(x=x, edge_index=edge_index)
        self.graph_data = self.graph_data.to(device)
        self.y = y
        
    def item_count(self):
        return len(self.items_rated)
        
    
    def train_model(self, lr=LR):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=5e-4)
        self.model.train()
        
        for epoch in range(EPOCHS):
            self.optimizer.zero_grad()
            x, out = self.model(self.graph_data)
            loss = torch.sqrt(self.criterion(out, self.y))
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
            optimizer.step()  
            self.graph_data.x -= lr*x.grad
            
        self.__update_embeddings_from_model(self.graph_data.x)
        return loss
    
    def __update_embeddings_from_model(self, graph_x):
        self.user_embedding = graph_x[0]
        for i in range(1, len(graph_x)):
            index = self.all_embeddings.index[self.all_embeddings['movieId'] == self.items_embeddings.iloc[i-1]['movieId']]
            self.items_embeddings.iat[i-1, 1] =  graph_x[i]
            self.all_embeddings.iat[index[0], 1] =  graph_x[i]
          
    
    def evaluate_model(self, data=None):
        if data == None:
            data = self.graph_data
        self.model.eval()
        _, pred = self.model(data)
        pred = torch.round(pred.data)

        print("\n\nLocal:\nActual: ", self.y[:10])
        print("Predicted: ", pred[:10])

        correct = float(pred.eq(self.y).sum().item())
        acc = correct / len(self.y)
        print('Accuracy: {:.4f}'.format(acc))
        return acc
    
    def evaluate_global(self, model):
        data = self.graph_data
        model.eval()
        _, pred = model(data)
        pred = torch.round(pred.data)
        print("\n\nClient: ", self.id)
        print("Global\nActual: ", self.y[:10])
        print("Predicted: ", pred[:10])

        correct = float(pred.eq(self.y).sum().item())
        acc = correct / len(self.y)
        return acc
#         print('Accuracy: {:.4f}'.format(acc))
    
    def __str__(self):
        return repr(self.items_rated)
        

In [6]:
class Server(FederatedNetwork):
    def __init__(self, items, device, empty):
        super().__init__(device, empty)
        self.items = items        
        self.items_embeddings = None
    
    def generate_item_embeddings(self):
        embeddings = torch.nn.init.xavier_uniform_(torch.empty(self.items.shape[0], 256))
        df = pd.DataFrame({"id": np.arange(1, embeddings.shape[0]+1)})
        df["embeddings"] = list(embeddings.numpy())
        self.items_embeddings = pd.concat([self.items['movieId'], df["embeddings"]], axis=1)
        return self.items_embeddings
    
    def get_item_embeddings(self):
        return self.items_embeddings   
        

In [7]:
class Driver:

    def __init__(self, device, state_dict):
        self.server = None
        self.clients = None
        self.ratings_data = None
        self.initialize_server(device)
        self.initialize_clients(device, state_dict)

    def initialize_clients(self, device, state_dict, client_count=CLIENTS_COUNT):
        self.ratings_data = pd.read_csv('ml-latest-small/ratings.csv')
        self.ratings_data.drop('timestamp', inplace=True, axis=1)
        clients = random.sample(range(1, TOTAL_CLIENTS+1), client_count)
        self.clients = []
        for i in range(0, client_count):
            id = clients[i]
            client_i = Client(id, self.ratings_data[self.ratings_data['userId'] == id], self.server.get_item_embeddings(), device, state_dict)
            self.clients.append(client_i)
            
#         self.clients = pd.DataFrame(self.clients, columns=['clients'])
        return self.clients
    
    def initialize_server(self, device):
        items = pd.read_csv('ml-latest-small/movies.csv')
        self.server = Server(items, device, None)
        embeddings = self.server.generate_item_embeddings()
        return embeddings
        
    def get_embeddings(self):
        return self.server.get_item_embeddings()


## Driver Code for Training on Client 1

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
global_model = GAT().to(device)
driver_obj = Driver(device, global_model.state_dict())

for training_round in range(1):
    total_items = 0
    weights = []
    embeddings = [] 
    items_rated = []
    for client in driver_obj.clients:
        client.train_model()
        total_items += client.item_count()
        weights.append(client.model.state_dict())
        embeddings.append(client.get_item_embeddings())
        items_rated.append(client.item_count())


    #AVERAGE: CONVERT TO WEIGHTED AVERAGE
    new_parameters = global_model.state_dict()

    for key in new_parameters:
        new_parameters[key] = weights[0][key]
        for i in range(1, len(weights)):
            new_parameters[key] += weights[i][key]
        new_parameters[key]/=float(CLIENTS_COUNT)

    global_model.load_state_dict(new_parameters)

    acc = 0
    l_acc = 0
    for client in driver_obj.clients:
#         acc += client.evaluate_global(global_model)
        l_acc += client.evaluate_model()
    
    print("\n\nRound: ", training_round)
    print('Global Accuracy: {:.4f}'.format(acc/CLIENTS_COUNT))
    print('Local Accuracy: {:.4f}'.format(l_acc/CLIENTS_COUNT))

    # for client in driver_obj.clients:
    #     client.update_weights(global_model.state_dict())
    #     client.update_embeddings(global_embeddings)



Local:
Actual:  tensor([3., 3., 2., 3., 3., 3., 3., 2., 3., 3.], dtype=torch.float64)
Predicted:  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)
Accuracy: 0.0000


Local:
Actual:  tensor([4., 4., 5., 5., 4., 5., 4., 4., 4., 5.], dtype=torch.float64)
Predicted:  tensor([4., 4., 4., 4., 4., 4., 4., 4., 4., 4.], dtype=torch.float64)
Accuracy: 0.3960


Local:
Actual:  tensor([4.0000, 4.5000, 4.5000, 4.0000, 4.0000, 4.0000, 4.0000, 4.0000, 5.0000,
        5.0000], dtype=torch.float64)
Predicted:  tensor([4., 4., 4., 4., 4., 4., 4., 4., 4., 4.], dtype=torch.float64)
Accuracy: 0.3723


Local:
Actual:  tensor([3.0000, 3.0000, 3.0000, 2.0000, 3.5000, 3.0000, 4.0000, 4.5000, 4.5000,
        1.0000], dtype=torch.float64)
Predicted:  tensor([3., 3., 3., 3., 3., 4., 3., 4., 4., 3.], dtype=torch.float64)
Accuracy: 0.2357


Local:
Actual:  tensor([4.5000, 2.5000, 4.5000, 4.0000, 4.5000, 4.0000, 4.5000, 4.0000, 4.0000,
        4.5000], dtype=torch.float64)
Predicted:  tensor([

In our experiments, we use graph attention network (GAT) [28]
as the GNN model, and use dot product to implement the rating
predictor. The user and item embeddings and their hidden represen-
tations learned by graph neural networks are 256-dim. The epoch
threshold 𝑇 is 2. The gradient clipping threshold 𝛿 is set to 0.1, and
the strength of Laplacian noise in the LDP module is set to 0.2 to
achieve 1-differential privacy. The number of pseudo interacted
items is set to 1,000. The number of users used in each round of
model training is 128, and the total number of epoch is 3. The ratio
of dropout [25] is 0.2. SGD is selected as the optimization algorithm,
and its learning rate is 0.01. The splits of datasets are the same as
those used in [2], and these hyperparameters are selected accordingo
to the validation performance. The metric used in our experiments
is rooted mean square error (RMSE), and we report the average
RMSE scores over the 10 repetitions.


FedAVG is used as aggregator