## Dataset Link: https://github.com/Aniruddha-Tapas/Predicting-Diseases-From-Symptoms/tree/master/Manual-Data

## Import Libraries

In [None]:
!pip install --quiet --upgrade tensorflow_federated

In [None]:
import os
import re
import csv
import pickle
import random
import time
import datetime
import unicodedata
import pandas as pd
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

## Directory-Paths

**Nothing in this section needs to change unless the dataset or its location changes.**



In [None]:
# Project root directory
BASE_DIR = '/content/drive/MyDrive/Research Works/International Papers/BigComp/2020/disease_prediction_attention'

# Input locations: data root and the manually prepared input files
DATA_DIR = os.path.join(BASE_DIR, 'data_attention_federated')
MANUAL_INPUT_DIR = os.path.join(DATA_DIR, 'manual_input')

# Output root shared by all runs
OUTPUT_DIR = os.path.join(DATA_DIR, 'output_2021July')

# Every run writes into its own timestamped sub-directory
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
CURRENT_OUTPUT_DIR = os.path.join(OUTPUT_DIR, current_time)

# Per-run folders for logs, images and models
LOG_DIR = os.path.join(CURRENT_OUTPUT_DIR, 'logs')
IMG_DIR = os.path.join(CURRENT_OUTPUT_DIR, 'image')
MODEL_DIR = os.path.join(CURRENT_OUTPUT_DIR, 'model')

# Create all output folders (no error if they already exist)
for _out_dir in (OUTPUT_DIR, CURRENT_OUTPUT_DIR, LOG_DIR, IMG_DIR, MODEL_DIR):
    os.makedirs(_out_dir, exist_ok=True)

## Analysis on Manual Data

In [None]:
# Print the full path of every file found below the manual input directory
for root, _subdirs, files in os.walk(MANUAL_INPUT_DIR):
    for name in files:
        print(os.path.join(root, name))

### Read CSV to Dataframe

In [None]:
# Load the training CSV, report its dimensions and preview the first rows
train_df = pd.read_csv(os.path.join(MANUAL_INPUT_DIR, 'Training.csv'))
print("Dataset with rows {} and columns {}".format(*train_df.shape))
train_df.head()

In [None]:
# Descriptive statistics of the training dataframe (displayed as the cell output)
train_df.describe()

In [None]:
# Load the testing CSV, report its dimensions and preview the first rows
test_df = pd.read_csv(os.path.join(MANUAL_INPUT_DIR, 'Testing.csv'))
print("Dataset with rows {} and columns {}".format(*test_df.shape))
test_df.head()

In [None]:
# Descriptive statistics of the testing dataframe (displayed as the cell output)
test_df.describe()

## Data Preprocessing (the manual data has already been split into 5 clients and the data files have been created)

In [None]:
def unicode_to_ascii(s):
  """Return ``s`` with its combining marks removed.

  The string is NFD-normalized first, so accented characters are split into
  a base character plus marks; characters of category 'Mn' are then dropped.
  """
  decomposed = unicodedata.normalize('NFD', s)
  kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
  return ''.join(kept)

In [None]:
def preprocess_sentence(w):
  """Normalize a sentence and wrap it in ``<start>`` / ``<end>`` tokens.

  The text is lower-cased, stripped and passed through ``unicode_to_ascii``;
  punctuation (? . ! , ¿) is padded with spaces, runs of spaces/double quotes
  are collapsed, and every character outside letters, that punctuation and
  "_" is replaced by a space.
  """
  text = unicode_to_ascii(w.lower().strip())

  # put a space on both sides of each punctuation mark
  text = re.sub(r"([?.!,¿])", r" \1 ", text)
  # collapse runs of spaces / double quotes into a single space
  text = re.sub(r'[" "]+', " ", text)
  # anything that is not a letter, ?.!,¿ or "_" becomes a space
  text = re.sub(r"[^a-zA-Z?.!,¿_]+", " ", text)

  # start/end markers tell the model where a sequence begins and stops
  return '<start> ' + text.strip() + ' <end>'

In [None]:
def create_dataset(path, num_examples):
  """Read a space-delimited data file into symptom and disease sentences.

  Every non-empty row is expected to hold symptom tokens followed by one
  disease label whose words are joined with underscores.

  Args:
    path: Path of the text file to read.
    num_examples: Maximum number of rows to use; ``None`` uses every row.

  Returns:
    A tuple ``(symptoms_list, diseases_list)`` of equally long lists. Each
    entry has the form ``'<start> tok1 tok2 ... <end>'``.
  """
  # The context manager guarantees the file is closed, even on a parse error.
  with open(path, 'r') as f:
    reader = csv.reader(f, delimiter=' ', skipinitialspace=True)
    # Blank lines yield empty rows, which have no disease label; skip them.
    rows = [row for row in reader if row]

  symptoms_list = []
  diseases_list = []

  # Slicing with None keeps all rows.
  for row in rows[:num_examples]:
    symptom_part = ''.join(item + ' ' for item in row[:-1])
    symptoms_list.append('<start> ' + symptom_part + '<end>')

    # The label is the last field; its underscores separate the words.
    disease_part = ''.join(word + ' ' for word in row[-1].split('_'))
    diseases_list.append('<start> ' + disease_part + '<end>')

  return symptoms_list, diseases_list

In [None]:
# train_data = pd.read_csv(os.path.join(MANUAL_INPUT_DIR, 'Training.csv'))
# train_data.prognosis = train_data.prognosis.apply(lambda x: x.replace(' ','_'))

# # shuffle dataframe
# shuffled_train_data = shuffle(train_data)

# num_clients = 5
# split_interval = 1000
# client_data = []
# processed_client = []

# for i in range(num_clients):
#     data = shuffled_train_data.iloc[split_interval*i:split_interval*(i+1)]
#     client_data.append(data)
#     client_data[i].to_csv(os.path.join(MANUAL_INPUT_DIR, 'client_'+str(i)+'.csv'), index=False)

#     client = pd.DataFrame()
#     for j in range(len(client_data[i].columns)):
#       z = client_data[i].iloc[:,j].replace(1, client_data[i].columns[j])
#       client = client.append(z)
#     client = client.replace(0, '')
#     client = client.T
    
#     processed_client.append(client)
#     np.savetxt(os.path.join(MANUAL_INPUT_DIR, 'data_file_' + str(i) + '.txt'), processed_client[i].values, fmt='%s')

In [None]:
def tokenize(lang):
  """Fit a Keras tokenizer on ``lang`` and return padded id sequences.

  Args:
    lang: List of sentences.

  Returns:
    ``(tensor, lang_tokenizer)``: the post-padded integer sequences and the
    fitted tokenizer (created with ``filters=''``).
  """
  lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
  lang_tokenizer.fit_on_texts(lang)

  sequences = lang_tokenizer.texts_to_sequences(lang)
  padded = tf.keras.preprocessing.sequence.pad_sequences(sequences, padding='post')

  return padded, lang_tokenizer

In [None]:
def load_dataset(path, num_examples=None):
  """Load a data file and tokenize its symptom and disease sentences.

  Args:
    path: File handed to ``create_dataset``.
    num_examples: Forwarded unchanged to ``create_dataset``.

  Returns:
    ``(input_tensor, target_tensor, input_tokenizer, target_tokenizer)``.
  """
  symptoms, diseases = create_dataset(path, num_examples)

  # inputs (symptoms) and targets (diseases) get separate tokenizers
  input_ids, input_tokenizer = tokenize(symptoms)
  target_ids, target_tokenizer = tokenize(diseases)

  return input_ids, target_ids, input_tokenizer, target_tokenizer

In [None]:
def convert(lang, tensor):
  """Print ``id ----> word`` for every non-zero id in ``tensor``.

  Args:
    lang: Object exposing an ``index_word`` mapping from id to word.
    tensor: Iterable of token ids; zeros are skipped.
  """
  for token_id in tensor:
    if token_id == 0:
      continue
    print("%d ----> %s" % (token_id, lang.index_word[token_id]))

### Training Data

In [None]:
# Number of federated clients and the example count handed to load_dataset
num_clients = 5
num_examples = 1000

# Per-client containers; position i belongs to client i
symptom_tensor, disease_tensor = [], []
symptom_token, disease_token = [], []

max_length_symptom, max_length_disease = [], []

symptom_tensor_train, symptom_tensor_val = [], []
disease_tensor_train, disease_tensor_val = [], []

for i in range(num_clients):
  client_file = os.path.join(MANUAL_INPUT_DIR, 'data_file_' + str(i) + '.txt')
  in_tensor, tar_tensor, in_token, tar_token = load_dataset(client_file, num_examples)
  symptom_tensor.append(in_tensor)
  disease_tensor.append(tar_tensor)
  symptom_token.append(in_token)
  disease_token.append(tar_token)

  # Padded sequence length (second dimension) of each tensor
  max_length_symptom.append(symptom_tensor[i].shape[1])
  max_length_disease.append(disease_tensor[i].shape[1])

  # 80-20 split into training and validation sets
  split = train_test_split(symptom_tensor[i], disease_tensor[i], test_size=0.2)
  in_tensor_train, in_tensor_val, tar_tensor_train, tar_tensor_val = split

  symptom_tensor_train.append(in_tensor_train)
  symptom_tensor_val.append(in_tensor_val)
  disease_tensor_train.append(tar_tensor_train)
  disease_tensor_val.append(tar_tensor_val)

  # Sizes: train inputs, train targets, val inputs, val targets
  print(len(symptom_tensor_train[i]), len(disease_tensor_train[i]), len(symptom_tensor_val[i]), len(disease_tensor_val[i]))

  # print('Symptom_' + str(i) + '; index to word mapping')
  # convert(symptom_token[i], symptom_tensor_train[i][0])
  # print()
  # print('Disease_' + str(i) + '; index to word mapping')
  # convert(disease_token[i], disease_tensor_train[i][0])
  # print()

### Test Data

In [None]:
# Load the test CSV and make the prognosis labels space-free
test_data = pd.read_csv(os.path.join(MANUAL_INPUT_DIR, 'Testing.csv'))
test_data.prognosis = test_data.prognosis.apply(lambda x: x.replace(' ','_'))

# Replace every 1 with the name of its column and every 0 with an empty string.
# DataFrame.append was removed in pandas 2.0, so the converted columns are
# combined with a single concat; joining them along axis=1 directly gives the
# rows-by-columns layout the old append-then-transpose code produced.
test_processed_df = pd.concat(
    [column.replace(1, name) for name, column in test_data.items()],
    axis=1,
)
test_processed_df = test_processed_df.replace(0, '')
test_processed_df.head()

In [None]:
# Write the processed test rows to test_file.txt, formatting every value as a string
np.savetxt(os.path.join(MANUAL_INPUT_DIR, 'test_file.txt'), test_processed_df.values, fmt='%s')

In [None]:
def create_testset(path, num_examples):
  """Read a space-delimited test file into raw symptoms and disease sentences.

  Unlike the disease sentences, the symptom strings are returned without
  ``<start>`` / ``<end>`` markers.

  Args:
    path: Path of the text file to read.
    num_examples: Maximum number of rows to use; ``None`` uses every row.

  Returns:
    A tuple ``(symptoms_list, diseases_list)`` of equally long lists.
    Symptoms look like ``'tok1 tok2 '`` (each token followed by a space);
    diseases look like ``'<start> word1 word2 <end>'``.
  """
  # The context manager guarantees the file is closed, even on a parse error.
  with open(path, 'r') as f:
    reader = csv.reader(f, delimiter=' ', skipinitialspace=True)
    # Blank lines yield empty rows, which have no disease label; skip them.
    rows = [row for row in reader if row]

  symptoms_list = []
  diseases_list = []

  # Slicing with None keeps all rows.
  for row in rows[:num_examples]:
    symptoms_list.append(''.join(item + ' ' for item in row[:-1]))

    # The label is the last field; its underscores separate the words.
    disease_part = ''.join(word + ' ' for word in row[-1].split('_'))
    diseases_list.append('<start> ' + disease_part + '<end>')

  return symptoms_list, diseases_list

## Create a tf.data dataset

In [None]:
# Model / batching hyper-parameters
BATCH_SIZE = 16
embedding_dim = 256
units = 1024

# Per-client bookkeeping for the training split
train_buffer_size, train_step_per_epoch = [], []
train_vocab_inp_size, train_vocab_tar_size = [], []
train_dataset = []

# Per-client bookkeeping for the validation split
val_buffer_size, val_step_per_epoch = [], []
val_vocab_inp_size, val_vocab_tar_size = [], []
val_dataset = []

# One sample batch per client and split
train_example_input_batch, train_example_target_batch = [], []
val_example_input_batch, val_example_target_batch = [], []

for i in range(num_clients):
  n_train = len(symptom_tensor_train[i])
  n_val = len(symptom_tensor_val[i])

  # Shuffle buffers cover the whole split
  train_buffer_size.append(n_train)
  val_buffer_size.append(n_val)

  # Number of full batches per epoch
  train_step_per_epoch.append(n_train // BATCH_SIZE)
  val_step_per_epoch.append(n_val // BATCH_SIZE)

  # Vocabulary sizes (+1 for the padding id 0); identical for train and val
  inp_vocab = len(symptom_token[i].word_index) + 1
  tar_vocab = len(disease_token[i].word_index) + 1
  train_vocab_inp_size.append(inp_vocab)
  train_vocab_tar_size.append(tar_vocab)
  val_vocab_inp_size.append(inp_vocab)
  val_vocab_tar_size.append(tar_vocab)

  # Shuffled datasets of full batches (incomplete last batch is dropped)
  train_pairs = (symptom_tensor_train[i], disease_tensor_train[i])
  train_dataset.append(
      tf.data.Dataset.from_tensor_slices(train_pairs)
      .shuffle(train_buffer_size[i])
      .batch(BATCH_SIZE, drop_remainder=True))

  val_pairs = (symptom_tensor_val[i], disease_tensor_val[i])
  val_dataset.append(
      tf.data.Dataset.from_tensor_slices(val_pairs)
      .shuffle(val_buffer_size[i])
      .batch(BATCH_SIZE, drop_remainder=True))

  # Keep one batch of each split for the shape checks further down
  sample_in, sample_tar = next(iter(train_dataset[i]))
  train_example_input_batch.append(sample_in)
  train_example_target_batch.append(sample_tar)
  print("Train --> ", train_example_input_batch[i].shape, train_example_target_batch[i].shape)

  sample_in, sample_tar = next(iter(val_dataset[i]))
  val_example_input_batch.append(sample_in)
  val_example_target_batch.append(sample_tar)
  print("Val --> ", val_example_input_batch[i].shape, val_example_target_batch[i].shape)

## Encoder Model

In [None]:
class Encoder(tf.keras.Model):
  """Embedding layer followed by a GRU that returns sequences and state."""

  def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
    super().__init__()
    self.batch_sz = batch_sz
    self.enc_units = enc_units
    self.vocab_size = vocab_size

    # token ids -> dense vectors
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

    # sequential processing of the embedded tokens; both the per-step
    # outputs and the final state are returned
    self.gru = tf.keras.layers.GRU(
        enc_units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform')

  def call(self, x, hidden):
    """Encode ``x`` starting from ``hidden``; return ``(output, state)``."""
    embedded = self.embedding(x)
    output, state = self.gru(embedded, initial_state = hidden)
    return output, state

  def initialize_hidden_state(self):
    """Return an all-zero state of shape ``(batch_sz, enc_units)``."""
    state_shape = (self.batch_sz, self.enc_units)
    return tf.zeros(state_shape)

### Local Encoders

In [None]:
# One encoder per client, its zero initial state, and the output of a
# forward pass on that client's sample training batch
local_encoder = []
local_hidden = []
local_output = []

for i in range(num_clients):
  local_encoder.append(Encoder(train_vocab_inp_size[i], embedding_dim, units, BATCH_SIZE))
  local_hidden.append(local_encoder[i].initialize_hidden_state())

  # run the sample batch through the encoder; the returned state is not kept
  sample_out, _ = local_encoder[i](train_example_input_batch[i], local_hidden[i])
  local_output.append(sample_out)

  print ('Encoder ' + str(i) + ' output shape: (batch size, sequence length, units) {}'.format(local_output[i].shape))
  print ('Encoder ' + str(i) + ' hidden state shape: (batch size, units) {}'.format(local_hidden[i].shape))

### Global Encoder

In [None]:
# The global encoder is sized with client 0's input vocabulary
global_encoder = Encoder(train_vocab_inp_size[0], embedding_dim, units, BATCH_SIZE)

# Forward pass on client 0's sample batch, starting from the zero state
global_output, global_hidden = global_encoder(
    train_example_input_batch[0],
    global_encoder.initialize_hidden_state())

print ('Global Encoder output shape: (batch size, sequence length, units) {}'.format(global_output.shape))
print ('Global Encoder  Hidden state shape: (batch size, units) {}'.format(global_hidden.shape))

## Attention Layer

In [None]:
class BahdanauAttention(tf.keras.layers.Layer):
  """Additive attention: ``score = V(tanh(W1(query) + W2(values)))``."""

  def __init__(self, units):
    super().__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values):
    """Return ``(context_vector, attention_weights)``.

    Shapes, as noted by the original authors:
      query:   (batch_size, hidden size)
      values:  (batch_size, max_len, hidden size)
      context: (batch_size, hidden_size)
      weights: (batch_size, max_length, 1)
    """
    # add a time axis so the query broadcasts against every time step
    expanded_query = tf.expand_dims(query, 1)

    # (batch_size, max_length, units) before V, (batch_size, max_length, 1) after
    combined = self.W1(expanded_query) + self.W2(values)
    score = self.V(tf.nn.tanh(combined))

    # normalize the scores over the time axis
    attention_weights = tf.nn.softmax(score, axis=1)

    # weighted sum of the values over the time axis
    context_vector = tf.reduce_sum(attention_weights * values, axis=1)

    return context_vector, attention_weights

### Local Attention

In [None]:
# Shape check: one 10-unit attention layer per client, applied to that
# client's encoder state and encoder output
attention_layer = []
attention_result = []
attention_weights = []

for i in range(num_clients):
  attention_layer.append(BahdanauAttention(10))
  context, weights = attention_layer[i](local_hidden[i], local_output[i])

  attention_result.append(context)
  attention_weights.append(weights)

  print('Attention ' + str(i) + ' result shape: (batch size, units) {}'.format(attention_result[i].shape))
  print('Attention ' + str(i) + ' weights shape: (batch_size, sequence_length, 1) {}'.format(attention_weights[i].shape))

### Global Attention

In [None]:
# Shape check of a 10-unit attention layer on the global encoder's results
global_attention = BahdanauAttention(10)
global_attention_result, global_attention_weight = global_attention(
    global_hidden,
    global_output,
)

print('Global Attention result shape: (batch size, units) {}'.format(global_attention_result.shape))
print('Global Attention weights shape: (batch_size, sequence_length, 1) {}'.format(global_attention_weight.shape))

## Decoder Model

In [None]:
class Decoder(tf.keras.Model):
  """One-step GRU decoder that attends over the encoder output."""

  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    super().__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    self.vocab_size = vocab_size

    # token ids -> dense vectors
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

    # recurrent layer returning both per-step outputs and the final state
    self.gru = tf.keras.layers.GRU(
        dec_units,
        return_sequences=True,
        return_state=True,
        recurrent_initializer='glorot_uniform')

    # projection of the GRU output onto the vocabulary
    self.fc = tf.keras.layers.Dense(vocab_size)

    # attention over the encoder output
    self.attention = BahdanauAttention(dec_units)

  def call(self, x, hidden, enc_output):
    """Run one decoding step; return ``(logits, state, attention_weights)``.

    ``hidden`` is used only as the attention query; the GRU is called
    without an explicit initial state.
    """
    # enc_output shape == (batch_size, max_length, hidden_size)
    context_vector, attention_weights = self.attention(hidden, enc_output)

    # (batch_size, 1, embedding_dim) after the embedding
    embedded = self.embedding(x)

    # prepend the context: (batch_size, 1, embedding_dim + hidden_size)
    gru_input = tf.concat([tf.expand_dims(context_vector, 1), embedded], axis=-1)

    output, state = self.gru(gru_input)

    # flatten to (batch_size * 1, hidden_size)
    flat_output = tf.reshape(output, (-1, output.shape[2]))

    # (batch_size, vocab)
    logits = self.fc(flat_output)

    return logits, state, attention_weights

### Local Decoders

In [None]:
# One decoder per client, each run once on a random (BATCH_SIZE, 1) input
local_decoder = []

local_decoder_output = []

for i in range(num_clients):
  local_decoder.append(Decoder(train_vocab_tar_size[i], embedding_dim, units, BATCH_SIZE))

  random_input = tf.random.uniform((BATCH_SIZE, 1))
  sample_logits, _, _ = local_decoder[i](random_input, local_hidden[i], local_output[i])
  local_decoder_output.append(sample_logits)
  print ('Decoder ' + str(i) + ' output shape: (batch_size, vocab size) {}'.format(local_decoder_output[i].shape))

### Global Decoder

In [None]:
# The global decoder is sized with client 0's target vocabulary and run once
# on a random (BATCH_SIZE, 1) input
global_decoder = Decoder(train_vocab_tar_size[0], embedding_dim, units, BATCH_SIZE)
global_decoder_output, _, _ = global_decoder(
    tf.random.uniform((BATCH_SIZE, 1)),
    global_hidden,
    global_output,
)
print ('Global Decoder output shape: (batch_size, vocab size) {}'.format(global_decoder_output.shape))

## Optimizer and Loss Function

In [None]:
# A single Adam optimizer and an unreduced cross-entropy on logits
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
    reduction='none',
)

def loss_function(real, pred):
  """Mean cross-entropy with positions whose target id is 0 zeroed out.

  The mean is taken over all positions, including the zeroed ones.
  """
  per_position = loss_object(real, pred)

  # 1.0 where the target is a real token, 0.0 where it is id 0
  is_token = tf.math.logical_not(tf.math.equal(real, 0))
  per_position *= tf.cast(is_token, dtype=per_position.dtype)

  return tf.reduce_mean(per_position)

## Checkpoints (Object-based saving)

In [None]:
# Per-client checkpoint directory, file prefix and Checkpoint object.
# Every Checkpoint tracks the shared optimizer plus that client's models.
checkpoint_dir, checkpoint_prefix, checkpoint = [], [], []

for i in range(num_clients):
  checkpoint_dir.append(os.path.join(CURRENT_OUTPUT_DIR, 'training_checkpoints_' + str(i)))
  checkpoint_prefix.append(os.path.join(checkpoint_dir[i], "ckpt"))
  checkpoint.append(
      tf.train.Checkpoint(
          optimizer=optimizer,
          encoder=local_encoder[i],
          decoder=local_decoder[i]))

## Training Clients

In [None]:
def train_step(inp, targ, enc_hidden, encoder, decoder, target_lang):
    """Run one teacher-forced training step and apply the gradients.

    Args:
      inp: Batch of input id sequences.
      targ: Batch of target id sequences.
      enc_hidden: Initial encoder state.
      encoder, decoder: Models being trained.
      target_lang: Target tokenizer; supplies the id of '<start>'.

    Returns:
      The summed per-step loss divided by the target sequence length.
    """
    start_id = target_lang.word_index['<start>']
    total = 0

    with tf.GradientTape() as tape:
        enc_output, dec_hidden = encoder(inp, enc_hidden)
        dec_input = tf.expand_dims([start_id] * BATCH_SIZE, 1)

        # Teacher forcing: the true token at step t is the input for step t+1
        for t in range(1, targ.shape[1]):
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            total += loss_function(targ[:, t], predictions)
            dec_input = tf.expand_dims(targ[:, t], 1)

    batch_loss = total / int(targ.shape[1])

    # gradients of the summed loss with respect to both models' variables
    trainable = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(total, trainable)
    optimizer.apply_gradients(zip(grads, trainable))

    return batch_loss

In [None]:
def get_model_loss(inp, targ, enc_hidden, encoder, decoder, target_lang):
    """Compute the teacher-forced loss of a batch without updating weights.

    Takes the same arguments as ``train_step`` and returns the summed
    per-step loss divided by the target sequence length.
    """
    start_id = target_lang.word_index['<start>']

    enc_output, dec_hidden = encoder(inp, enc_hidden)
    dec_input = tf.expand_dims([start_id] * BATCH_SIZE, 1)

    total = 0
    for t in range(1, targ.shape[1]):
        predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
        total += loss_function(targ[:, t], predictions)
        # the true token becomes the next decoder input
        dec_input = tf.expand_dims(targ[:, t], 1)

    return total / int(targ.shape[1])

In [None]:
# Validation loss of every (still untrained) local model on its own data
enc_hidden = []

for i in range(num_clients):
    enc_hidden.append(local_encoder[i].initialize_hidden_state())

    total_val_loss = 0
    for inp, targ in val_dataset[i].take(val_step_per_epoch[i]):
        total_val_loss += get_model_loss(
            inp, targ, enc_hidden[i], local_encoder[i], local_decoder[i], disease_token[i])

    print('Client ' + str(i) + ' val_loss: {:.4f}'.format(total_val_loss / val_step_per_epoch[i]))

In [None]:
# One (initially empty) history list per client, keyed by client index
loss_history = {client: [] for client in range(num_clients)}

# The global model's history is stored under key 5
global_index = 5
loss_history[global_index] = []

In [None]:
EPOCHS = 5
NUM_ROUNDS = 30


def _average_weights(models):
    """Return the variable-by-variable mean of the weights of ``models``.

    ``get_weights()`` yields one array per variable and those arrays have
    different shapes, so they are averaged per variable instead of being
    packed into a single ragged NumPy array (an error on NumPy >= 1.24).
    """
    per_model = [model.get_weights() for model in models]
    return [np.mean(group, axis=0) for group in zip(*per_model)]


def _mean_loss(dataset, steps, enc_hid, encoder, decoder, target_lang):
    """Return the average ``get_model_loss`` over the first ``steps`` batches."""
    total = 0
    for inp, targ in dataset.take(steps):
        total += get_model_loss(inp, targ, enc_hid, encoder, decoder, target_lang)
    return total / steps


# NOTE(review): set_weights() needs identical weight shapes, so every client
# and the global model must share the same vocabulary sizes - confirm this
# holds for the per-client tokenizers.
for round_idx in range(NUM_ROUNDS):

    print('Round: ', str(round_idx))

    for i in range(num_clients):
        # Every client starts the round from the current global model
        local_encoder[i].set_weights(global_encoder.get_weights())
        local_decoder[i].set_weights(global_decoder.get_weights())

        # Zero initial encoder state, reused for every batch of this client
        enc_hid = local_encoder[i].initialize_hidden_state()

        for epoch in range(EPOCHS):
            for inp, targ in train_dataset[i].take(train_step_per_epoch[i]):
                train_step(inp, targ, enc_hid, local_encoder[i], local_decoder[i], disease_token[i])

            # saving (checkpoint) the model every 5 epochs
            if (epoch + 1) % 5 == 0:
                checkpoint[i].save(file_prefix = checkpoint_prefix[i])

        # NOTE(review): the reported training loss comes from train_step, so
        # this pass also performs one more epoch of weight updates. Kept as
        # is to preserve the training schedule; switch to get_model_loss if a
        # pure evaluation pass is intended.
        train_batch_loss = 0
        for inp, targ in train_dataset[i].take(train_step_per_epoch[i]):
            train_batch_loss += train_step(inp, targ, enc_hid, local_encoder[i], local_decoder[i], disease_token[i])

        train_loss = train_batch_loss / train_step_per_epoch[i]

        # Validation loss of this client's model on every client's validation
        # data; the list is built once so that all datasets enter the average
        val_loss_dtst_list = [
            _mean_loss(val_dataset[dtst_index], val_step_per_epoch[dtst_index], enc_hid,
                       local_encoder[i], local_decoder[i], disease_token[dtst_index])
            for dtst_index in range(num_clients)
        ]

        val_loss = sum(val_loss_dtst_list) / num_clients
        print()
        print('Client: ', i)
        print('train_loss: ', train_loss)
        print('val_loss: ', val_loss)

        loss_history[i].append((train_loss, val_loss))

    #############################################################################################################
    # Federated averaging: the global model gets the mean of the local weights

    global_encoder.set_weights(_average_weights(local_encoder))
    global_decoder.set_weights(_average_weights(local_decoder))

    #############################################################################################################
    # Calculate loss for global model

    enc_hid = global_encoder.initialize_hidden_state()

    # Both lists are created once, outside the loop, so every client's loss
    # contributes to the averages below
    train_loss_dtst_list = []
    val_loss_dtst_list = []

    for dtst_index in range(num_clients):
        # Training loss of the global model on this client's training data
        train_loss_dtst_list.append(
            _mean_loss(train_dataset[dtst_index], train_step_per_epoch[dtst_index], enc_hid,
                       global_encoder, global_decoder, disease_token[dtst_index]))

        # Validation loss of the global model on this client's validation data
        val_loss_dtst_list.append(
            _mean_loss(val_dataset[dtst_index], val_step_per_epoch[dtst_index], enc_hid,
                       global_encoder, global_decoder, disease_token[dtst_index]))

    val_loss = sum(val_loss_dtst_list) / num_clients
    train_loss = sum(train_loss_dtst_list) / num_clients
    print()
    print('Global Model')
    print('train_loss: ', train_loss)
    print('val_loss: ', val_loss)
    loss_history[global_index].append((train_loss, val_loss))
    print('#########################################################')

    #############################################################################################################

In [None]:
# Persist the per-round loss history of this run as a pickle file
loss_history_file = os.path.join(CURRENT_OUTPUT_DIR, 'training_history.txt')

with open(loss_history_file, 'wb') as file_pi:
    pickle.dump(loss_history, file_pi)

In [None]:
# Reload the history written by this notebook; the context manager closes
# the file handle, which the previous inline open() never did.
with open(loss_history_file, "rb") as history_fh:
    loss_history = pickle.load(history_fh)

In [None]:
def plot_graph(model_loss_history, model_type='Client ', index=''):
  """Plot, save and show the train/val loss curves of one model.

  Args:
    model_loss_history: Sequence of ``(train_loss, val_loss)`` pairs whose
      items expose ``.numpy()``.
    model_type: Label prefix used in the title and the file name.
    index: Label suffix (e.g. the client index).
  """
  points = [(entry[0].numpy(), entry[1].numpy()) for entry in model_loss_history]
  train_loss = [train for train, _ in points]
  val_loss = [val for _, val in points]

  plt.figure(figsize=(10,4))
  for curve, curve_label in ((train_loss, 'train'), (val_loss, 'val')):
    plt.plot(curve, label=curve_label)
  plt.title('Training History ({}{}) '.format(model_type, index))
  plt.xlabel('Rounds')
  plt.ylabel('Loss')
  plt.grid(True)
  plt.legend()

  # the image is written into IMG_DIR before it is displayed
  target = os.path.join(IMG_DIR, '{}{}_history.png'.format(model_type, index))
  plt.savefig(target, dpi=300, bbox_inches = 'tight')
  plt.show()

In [None]:
# One loss-history figure per client
for client_id in range(num_clients):
  plot_graph(loss_history[client_id], index=client_id)

In [None]:
# Loss-history figure of the global model, stored under key 5 of loss_history
global_index = 5
plot_graph(loss_history[global_index], model_type='Global')

In [None]:
# Training loss per round of the five clients (keys 0-4) and the global model (key 5)
plt.figure(figsize=(12,6))

for i in range(6):
    curve_label = 'Client {}'.format(i+1) if i<5 else 'Global Model'
    plt.plot([item[0].numpy() for item in loss_history[i]], label=curve_label)

plt.title('Training Loss')
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.legend()
img_path = os.path.join(IMG_DIR, 'training_loss.png')
plt.savefig(img_path, dpi=300, bbox_inches = 'tight')

In [None]:
# Validation loss per round of the five clients (keys 0-4) and the global model (key 5)
plt.figure(figsize=(12,6))

for i in range(6):
    curve_label = 'Client {}'.format(i+1) if i<5 else 'Global Model'
    plt.plot([item[1].numpy() for item in loss_history[i]], label=curve_label)

plt.title('Validation Loss')
plt.xlabel('Rounds')
plt.ylabel('Loss')
plt.grid(True)
plt.legend()
img_path = os.path.join(IMG_DIR, 'validation_loss.png')
plt.savefig(img_path, dpi=300, bbox_inches = 'tight')

In [None]:
def plot_attention(attention, sentence, predicted_sentence, img_path):
  """Draw the attention matrix as a heat map, save it and show it.

  Args:
    attention: 2-D array of attention weights.
    sentence: List of input tokens (x-axis labels).
    predicted_sentence: List of predicted tokens (y-axis labels).
    img_path: File the figure is saved to.
  """
  figure = plt.figure(figsize=(10,10))
  axes = figure.add_subplot(1, 1, 1)
  axes.matshow(attention, cmap='viridis')

  label_font = {'fontsize': 14}

  # a leading empty label precedes the token labels on both axes
  axes.set_xticklabels([''] + sentence, fontdict=label_font, rotation=90)
  axes.set_yticklabels([''] + predicted_sentence, fontdict=label_font)

  # one tick per unit on both axes
  axes.xaxis.set_major_locator(ticker.MultipleLocator(1))
  axes.yaxis.set_major_locator(ticker.MultipleLocator(1))

  plt.savefig(img_path, dpi=300, bbox_inches = 'tight')
  plt.show()

In [None]:
# Restore, for every client, the most recent checkpoint found in its directory
for client_id in range(num_clients):
  latest = tf.train.latest_checkpoint(checkpoint_dir[client_id])
  checkpoint[client_id].restore(latest)

In [None]:
def evaluate(sentence, max_len_tar, max_len_in, in_lan, encoder, tar_lan, decoder):
  """Greedily decode ``sentence`` and record the attention of every step.

  Args:
    sentence: Raw input text; it is run through ``preprocess_sentence``.
    max_len_tar: Maximum number of decoding steps.
    max_len_in: Length the input ids are padded to.
    in_lan: Input tokenizer (``word_index`` is used).
    encoder, decoder: Models used for the prediction.
    tar_lan: Target tokenizer (``word_index`` and ``index_word`` are used).

  Returns:
    ``(result, sentence, attention_plot)``: the predicted words (each
    followed by a space), the preprocessed sentence, and a
    ``(max_len_tar, max_len_in)`` array of attention weights.
  """
  attention_plot = np.zeros((max_len_tar, max_len_in))

  sentence = preprocess_sentence(sentence)

  # every token must be in the input vocabulary (KeyError otherwise)
  token_ids = [in_lan.word_index[token] for token in sentence.split(' ')]
  padded = tf.keras.preprocessing.sequence.pad_sequences([token_ids],
                                                         maxlen=max_len_in,
                                                         padding='post')
  inputs = tf.convert_to_tensor(padded)

  # encode with a zero initial state for a batch of one
  enc_out, dec_hidden = encoder(inputs, [tf.zeros((1, units))])
  dec_input = tf.expand_dims([tar_lan.word_index['<start>']], 0)

  predicted_words = []

  for t in range(max_len_tar):
    predictions, dec_hidden, attention_weights = decoder(dec_input,
                                                         dec_hidden,
                                                         enc_out)

    # keep this step's attention weights for plotting later on
    attention_plot[t] = tf.reshape(attention_weights, (-1, )).numpy()

    predicted_id = tf.argmax(predictions[0]).numpy()
    word = tar_lan.index_word[predicted_id]
    predicted_words.append(word)

    # stop as soon as the end marker is produced
    if word == '<end>':
      break

    # the predicted id is the next decoder input
    dec_input = tf.expand_dims([predicted_id], 0)

  result = ''.join(word + ' ' for word in predicted_words)
  return result, sentence, attention_plot

In [None]:
def translate(sentence, max_len_tar, max_len_in, in_lan, encoder, tar_lan, decoder, img_path):
  """Predict the disease for ``sentence``, print it and plot the attention.

  All arguments except ``img_path`` are forwarded to ``evaluate``;
  ``img_path`` is where ``plot_attention`` saves the figure.
  """
  result, sentence, attention_plot = evaluate(
      sentence, max_len_tar, max_len_in, in_lan, encoder, tar_lan, decoder)

  print('Input: %s' % (sentence))
  print('Predicted translation: {}'.format(result))

  input_tokens = sentence.split(' ')
  output_tokens = result.split(' ')

  # crop the attention matrix to the tokens that are actually present
  attention_plot = attention_plot[:len(output_tokens), :len(input_tokens)]
  plot_attention(attention_plot, input_tokens, output_tokens, img_path)

In [None]:
# Prediction of the global model for a hand-written symptom list
img_path = os.path.join(IMG_DIR, 'sample_1.png')
translate(
    u'itching skin_rash nodal_skin_eruptions continuous_sneezing chest_pain',
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Prediction of the global model for a longer hand-written symptom list
img_path = os.path.join(IMG_DIR, 'sample_2.png')
translate(
    u'continuous_sneezing chills fatigue cough high_fever headache muscle_pain chest_pain swelled_lymph_nodes malaise phlegm throat_irritation redness_of_eyes sinus_pressure runny_nose congestion loss_of_smell',
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Symptom strings and disease sentences read from the saved test file
sym, dis = create_testset(os.path.join(MANUAL_INPUT_DIR, 'test_file.txt'), None)

In [None]:
# Test row 19: symptoms, expected disease, then the global model's prediction
print(sym[19])
print(dis[19])

img_path = os.path.join(IMG_DIR, 'sample_3.png')
translate(
    sym[19],
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Test row 7: symptoms, expected disease, then the global model's prediction
print(sym[7])
print(dis[7])

img_path = os.path.join(IMG_DIR, 'sample_4.png')
translate(
    sym[7],
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Test row 9: symptoms, expected disease, then the global model's prediction
print(sym[9])
print(dis[9])

img_path = os.path.join(IMG_DIR, 'sample_5.png')
translate(
    sym[9],
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Test row 17: symptoms, expected disease, then the global model's prediction
print(sym[17])
print(dis[17])

img_path = os.path.join(IMG_DIR, 'sample_6.png')
translate(
    sym[17],
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Test row 26: symptoms, expected disease, then the global model's prediction
print(sym[26])
print(dis[26])

img_path = os.path.join(IMG_DIR, 'sample_7.png')
translate(
    sym[26],
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Test row 27: symptoms, expected disease, then the global model's prediction
print(sym[27])
print(dis[27])

img_path = os.path.join(IMG_DIR, 'sample_8.png')
translate(
    sym[27],
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Test row 37: symptoms, expected disease, then the global model's prediction
print(sym[37])
print(dis[37])

img_path = os.path.join(IMG_DIR, 'sample_9.png')
translate(
    sym[37],
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Test row 1: symptoms, expected disease, then the global model's prediction
print(sym[1])
print(dis[1])

img_path = os.path.join(IMG_DIR, 'sample_10.png')
translate(
    sym[1],
    max_length_disease[0],
    max_length_symptom[0],
    symptom_token[0],
    global_encoder,
    disease_token[0],
    global_decoder,
    img_path,
)

In [None]:
# Print every expected disease sentence of the test set
print(dis)