## Imports and configurations

In [None]:
# Standard library imports
import pickle
import warnings
import math
import base64
import time

# Third-party imports
import zmq
import numpy as np
import torch
import torch.nn as nn

# Configurations and settings
# Silence every warning category raised while the notebook runs.
warnings.filterwarnings('ignore')
# Keep NumPy's default printing (small values may use scientific notation).
np.set_printoptions(suppress=False)
# Print tensors in fixed-point rather than scientific notation.
torch.set_printoptions(sci_mode=False)

## Discriminator class

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores images as real or fake.

    The submodules are appended in a fixed order because the parameter names
    in ``state_dict()`` (``network.<index>.*``) must match the saved
    checkpoints. For a 1x28x28 input (the size the Generator in this notebook
    produces) the spatial size shrinks 28 -> 14 -> 7 -> 4 -> 1.

    Args:
        no_of_channels: number of channels of the input images.
        disc_dim: width of the first convolution; the next two stages use
            twice and four times this width.
    """

    def __init__(self, no_of_channels=1, disc_dim=32):
        super().__init__()
        widths = (disc_dim, disc_dim * 2, disc_dim * 4)

        # Stem: strided conv + LeakyReLU, no normalisation.
        layers = [
            nn.Conv2d(no_of_channels, widths[0], kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]

        # Two downsampling stages with batch norm; the second uses a 3x3 kernel.
        for c_in, c_out, kernel in ((widths[0], widths[1], 4), (widths[1], widths[2], 3)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out, track_running_stats=False),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Head: collapse to a single channel and squash into (0, 1).
        layers += [
            nn.Conv2d(widths[2], 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]

        self.network = nn.Sequential(*layers)

    def forward(self, input):
        """Run the network and flatten its sigmoid scores into a 1-D tensor.

        For 28x28 inputs this yields one score per image.
        """
        scores = self.network(input)
        return scores.view(-1)

## Generator class

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator that turns noise into images.

    The first layer expects 100 input channels (100-dimensional noise,
    presumably shaped (N, 100, 1, 1)). With a 1x1 input the spatial size grows
    1 -> 4 -> 7 -> 14 -> 28, and the final Tanh maps pixels into [-1, 1].
    Submodule order is fixed so ``state_dict()`` keys (``layers.<i>.<j>.*``)
    match the saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, extra conv_block options) for each stage.
        stages = (
            (100, 128, dict(padding=0)),
            (128, 64, dict(stride=2, ks=3)),
            (64, 32, dict(stride=2)),
            (32, 1, dict(stride=2, bn=False, out_layer=True)),
        )
        self.layers = nn.Sequential(
            *(self.conv_block(c_in, c_out, **opts) for c_in, c_out, opts in stages)
        )

    @staticmethod
    def conv_block(in_c, out_c, out_layer=False, ks=4, stride=1, padding=1, bias=False, bn=True):
        """Build ConvTranspose2d -> [BatchNorm2d if bn] -> (Tanh if out_layer else ReLU)."""
        modules = [nn.ConvTranspose2d(in_c, out_c, ks, stride=stride, padding=padding, bias=bias)]
        if bn:
            modules.append(nn.BatchNorm2d(out_c, track_running_stats=False))
        modules.append(nn.Tanh() if out_layer else nn.ReLU(True))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Map a noise batch to generated images."""
        images = self.layers(x)
        return images

## Helper to format a byte count as a human-readable size

In [None]:
def convert_size(size_bytes):
    """Format a byte count as a human-readable string such as ``"1.5 KB"``.

    Uses binary (1024-based) units from B up to YB, rounded to two decimals.
    Sizes of 1024 YB or more stay in YB instead of running past the unit table.

    Args:
        size_bytes: non-negative number of bytes (int or float).

    Returns:
        ``"0B"`` for zero, otherwise ``"<value> <unit>"``.

    Raises:
        ValueError: if ``size_bytes`` is negative.
    """
    if size_bytes == 0:
        return "0B"
    if size_bytes < 0:
        raise ValueError("size_bytes must be non-negative, got {!r}".format(size_bytes))
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    # Choose the unit with exact comparisons rather than floor(log(x, 1024)).
    # The log form can misround at exact powers of 1024, goes negative for
    # sizes below 1 byte, and indexes past "YB" for very large sizes.
    i = 0
    while i < len(size_name) - 1 and size_bytes >= 1024 ** (i + 1):
        i += 1
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])

## Loading the pretrained models and extracting their weights

In [None]:
# Build both networks and restore the pretrained weights that the server
# hands out to new clients as the base model.
ngpu = 1  # NOTE(review): not read anywhere in this notebook
G = Generator()
D = Discriminator()

# map_location='cpu' lets checkpoints saved from a GPU run load on a CPU-only
# server; the models themselves are built on the CPU above.
D.load_state_dict(torch.load('Dis_HE.ckpt', map_location='cpu'))
discrimiator = D.state_dict()

G.load_state_dict(torch.load('Gen_HE.ckpt', map_location='cpu'))
generator = G.state_dict()

# Parameter names and tensors in state_dict order. `vals` (discriminator) and
# `vals1` (generator) are the lists the server loop pickles for new clients.
keys = list(discrimiator.keys())
vals = list(discrimiator.values())
vals111 = [v.numpy().shape for v in vals]  # per-tensor shapes, for inspection

keys1 = list(generator.keys())
vals1 = list(generator.values())

## Helpers to print elapsed time

In [None]:
def _format_elapsed(start, end):
    """Return the seconds between two timestamps formatted as ``HH:MM:SS.ss``."""
    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    return "{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)


def elapsed_time_total(start, end):
    """Print the total training time between two ``time.time()`` timestamps."""
    print("Total Training Time: " + _format_elapsed(start, end))


def elapsed_time_avg(start, end):
    """Print the averaging overhead between two ``time.time()`` timestamps."""
    print("Averaging overhead: " + _format_elapsed(start, end))

## Helpers to write and read base64-encoded data files

In [None]:
def write_data(file_name, data):
    """Base64-encode ``data`` and write it to ``file_name``.

    Every bytes-like object (``bytes``, ``bytearray``, ``memoryview``, ...) is
    encoded, so :func:`read_data` can always decode what was written.

    Raises:
        TypeError: if ``data`` is not bytes-like (e.g. a ``str``). The error is
            raised before the file is opened, so an existing file is left intact.
    """
    # Encode first: a failure here must not truncate an existing file.
    encoded = base64.b64encode(data)
    with open(file_name, 'wb') as f:
        f.write(encoded)


def read_data(file_name):
    """Read ``file_name`` and return its base64-decoded contents as ``bytes``."""
    with open(file_name, "rb") as f:
        data = f.read()
    return base64.b64decode(data)

## Main server loop

In [None]:
# Federated-averaging server.
# * ROUTER socket, port 5555: a new client sends b"New" twice and gets back the
#   base discriminator weights (`vals`) and generator weights (`vals1`).
#   Training clients send their pickled discriminator and generator updates.
# * PUB socket, port 5557: once `client_num_` updates of each model have been
#   collected, their averages are broadcast (discriminator first, then generator).
# NOTE(review): each iteration pairs two consecutive recv_multipart() calls, so
# it assumes a client's two messages arrive back to back. Interleaved traffic
# from different clients would be mis-paired, and a lone b"New" would reach
# pickle.loads and raise.

NEW_CLIENT_MSG = b"New"
# Discriminator updates are distinguished from generator updates by length,
# which equals the number of entries in the discriminator's state_dict.
disc_update_len = len(vals)

client_num = 0      # clients that have requested the base model
client_num_ = 10    # updates of each model required before averaging
c = 0               # update pairs received so far
data_list_dicriminator = []
data_list_generator = []

context = zmq.Context()
socket = context.socket(zmq.ROUTER)
pub_socket = context.socket(zmq.PUB)
try:
    socket.bind("tcp://*:5555")
    pub_socket.bind("tcp://*:5557")

    start_total = time.time()
    print("The server is running now!")

    while c < client_num_ * 10:
        ident1, msg1 = socket.recv_multipart()
        ident2, msg2 = socket.recv_multipart()

        if msg1 == NEW_CLIENT_MSG and msg2 == NEW_CLIENT_MSG:
            client_num += 1
            socket.send_multipart([ident1, pickle.dumps(vals)])
            socket.send_multipart([ident2, pickle.dumps(vals1)])
            print("Base model sent to the new client!")
            continue

        print("Training round started")
        # SECURITY: pickle.loads executes arbitrary code embedded in the
        # payload; only expose these ports to trusted clients.
        updates = (pickle.loads(msg1), pickle.loads(msg2))
        for update in updates:
            if len(update) == disc_update_len:
                data_list_dicriminator.append(update)
            else:
                data_list_generator.append(update)

        print(len(data_list_dicriminator))
        print(len(data_list_generator))

        if (len(data_list_dicriminator) == client_num_
                and len(data_list_generator) == client_num_):
            print("Enough data received")

            start_avg = time.time()
            # NOTE(review): sum() starts from 0, so each update must support
            # `0 + update` and `/ int` element-wise (e.g. numpy object arrays).
            # Plain lists of tensors would raise TypeError here; confirm
            # against the client code.
            cipher1 = sum(data_list_dicriminator) / client_num_
            cipher2 = sum(data_list_generator) / client_num_
            print("Averaged discriminator and generator computed")
            end_avg = time.time()
            elapsed_time_avg(start_avg, end_avg)

            message1 = pickle.dumps(cipher1)
            print("Averaged discriminator size {}".format(convert_size(len(message1))))
            message2 = pickle.dumps(cipher2)
            print("Averaged generator size {}".format(convert_size(len(message2))))

            # Subscribers tell the two models apart by arrival order.
            pub_socket.send(message1)
            pub_socket.send(message2)
            print("Sent!")

            data_list_dicriminator = []
            data_list_generator = []

        c += 1
finally:
    # Release both ports even on error or interrupt, so re-running this cell
    # does not fail with "Address already in use".
    socket.close()
    pub_socket.close()
    context.term()

end_total = time.time()
elapsed_time_total(start_total, end_total)
print(c)