In [None]:
from google.colab import drive

# Google Drive is attached under this directory; the CSV read later in the
# notebook is addressed through a path that starts here.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
!pip install --upgrade tensorflow tensorflow-federated


Collecting tensorflow
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting tensorflow-federated
  Downloading tensorflow_federated-0.87.0-py3-none-manylinux_2_31_x86_64.whl.metadata (19 kB)
Collecting tensorboard<2.19,>=2.18 (from tensorflow)
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting attrs~=23.1 (from tensorflow-federated)
  Downloading attrs-23.2.0-py3-none-any.whl.metadata (9.5 kB)
Collecting dp-accounting==0.4.3 (from tensorflow-federated)
  Downloading dp_accounting-0.4.3-py3-none-any.whl.metadata (1.8 kB)
Collecting google-vizier==0.1.11 (from tensorflow-federated)
  Downloading google_vizier-0.1.11-py3-none-any.whl.metadata (10 kB)
Collecting jaxlib==0.4.14 (from tensorflow-federated)
  Downloading jaxlib-0.4.14-cp310-cp310-manylinux2014_x86_64.whl.metadata (2.0 kB)
Collecting jax==0.4.14 (from tensorflow-federated)
  Downloading jax-0.4.14.tar.gz (1.3 MB)
[2K     [90m━━━━━━━

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from sklearn.model_selection import train_test_split

# Load dataset (empty strings in the CSV are read as missing values).
dataset = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/urls_final_complete.csv', low_memory=False, na_values='')

# Prepare data: the label column is split off, every other column is a feature.
X = dataset.drop('URL_Type_obf_Type', axis=1)
y = dataset['URL_Type_obf_Type']

# Features are later cast to float32 unchanged, so a missing value would reach
# the network as NaN and make the loss non-finite. Stop with a clear message.
columns_with_missing = X.columns[X.isna().any()].tolist()
if columns_with_missing:
    raise ValueError(
        f'Feature columns contain missing values, clean or impute them first: {columns_with_missing}'
    )

label_map = {'benign': 0, 'phishing': 1, 'malware': 2, 'defacement': 3, 'spam': 4}

# Series.map() yields NaN for any value that is not a key of label_map, and the
# later cast to int64 would silently turn that NaN into a meaningless class id.
# Reject unknown or missing labels up front instead of training on them.
unmapped_labels = y[~y.isin(list(label_map))]
if not unmapped_labels.empty:
    unknown_values = sorted(map(str, unmapped_labels.unique()))
    raise ValueError(
        f'{len(unmapped_labels)} rows have labels outside label_map: {unknown_values}'
    )
y = y.map(label_map).astype(np.int64)

# Hold out 10% of the rows for the final evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

# Partition the training rows into 6 equal, contiguous shards, one per simulated client.
num_clients = 6
client_data_size = len(X_train) // num_clients

# Shard k covers rows [k * size, (k + 1) * size). Because of the floor division
# above, the trailing len(X_train) % num_clients rows are not assigned to any client.
shard_rows = [
    slice(shard * client_data_size, shard * client_data_size + client_data_size)
    for shard in range(num_clients)
]
clients_data = [(X_train[rows], y_train[rows]) for rows in shard_rows]

# Convert each client's data to a tf.data.Dataset that TFF can consume
def create_tf_dataset_for_client(client_data):
    """Turn one client's shard into a shuffled, batched ``tf.data.Dataset``.

    Args:
        client_data: ``(features, labels)`` pair. Both objects are read through
            ``.values`` (pandas objects in this notebook); features are cast to
            float32 and labels to int64.

    Returns:
        A dataset of ``(features, labels)`` batches of up to 64 examples,
        shuffled with a buffer as large as the shard itself.
    """
    features, labels = client_data
    feature_array = features.values.astype(np.float32)
    label_array = labels.values.astype(np.int64)
    example_count = features.shape[0]

    examples = tf.data.Dataset.from_tensor_slices((feature_array, label_array))
    # A buffer covering every example makes the shuffle span the whole shard.
    shuffled = examples.shuffle(buffer_size=example_count)
    return shuffled.batch(64)

# One dataset per simulated client, in the same order as clients_data.
federated_train_data = list(map(create_tf_dataset_for_client, clients_data))

# Define the model
def create_keras_model():
    """Build the (uncompiled) feed-forward classifier used by every participant.

    The input width is taken from the number of columns of the module-level
    ``X_train``. Two ReLU hidden layers (64 and 32 units) are followed by
    dropout at rate 0.2 and a 5-unit softmax output.

    Returns:
        A new ``tf.keras.Sequential`` model.
    """
    layers = tf.keras.layers
    feature_count = X_train.shape[1]
    stack = [
        layers.Input(shape=(feature_count,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(32, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(5, activation='softmax'),
    ]
    return tf.keras.Sequential(stack)

# Define TFF model
def model_fn():
    """Wrap a freshly built Keras model for use by TFF.

    A new Keras model, loss and metric object are constructed on every call.
    The input spec is read from the first client dataset in
    ``federated_train_data``.

    Returns:
        The object produced by ``tff.learning.models.from_keras_model``.
    """
    fresh_model = create_keras_model()
    batch_spec = federated_train_data[0].element_spec
    objective = tf.keras.losses.SparseCategoricalCrossentropy()
    tracked_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    return tff.learning.models.from_keras_model(
        fresh_model,
        input_spec=batch_spec,
        loss=objective,
        metrics=tracked_metrics,
    )

# Federated learning process
def build_federated_averaging_process(model_fn, client_learning_rate=0.01, server_learning_rate=0.4):
    """Build a weighted federated-averaging learning process with SGD optimizers.

    Args:
        model_fn: No-argument callable that returns a new TFF model on each call.
        client_learning_rate: SGD learning rate for the client optimizer.
            Defaults to 0.01, the value this notebook previously hard-coded.
        server_learning_rate: SGD learning rate for the server optimizer.
            Defaults to 0.4, the value this notebook previously hard-coded.

    Returns:
        The process returned by ``tff.learning.algorithms.build_weighted_fed_avg``.
    """
    # Optimizers are supplied as factories so a new instance is created each
    # time one is requested rather than one object being shared.
    return tff.learning.algorithms.build_weighted_fed_avg(
        model_fn,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=client_learning_rate),
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=server_learning_rate),
    )

# Build the federated averaging process around model_fn.
iterative_process = build_federated_averaging_process(model_fn)

# Create the initial process state that the training rounds below keep replacing.
state = iterative_process.initialize()

# Train for multiple rounds; every round is given the datasets of all clients
# in federated_train_data.
num_rounds = 25
for round_num in range(num_rounds):
    # next() takes the current state and returns the updated state plus that round's metrics.
    state, metrics = iterative_process.next(state, federated_train_data)
    # round_num is zero-based, so it is printed as round_num + 1.
    print(f'Round {round_num + 1}, Metrics: {metrics}')

# Extract the updated model weights after training
trained_model_weights = iterative_process.get_model_weights(state)

# Initialize the evaluation process (built from the same model_fn as training)
evaluation = tff.learning.algorithms.build_fed_eval(model_fn)
evaluation_state = evaluation.initialize()

# Update the evaluation state with the trained model weights
evaluation_state = evaluation.set_model_weights(evaluation_state, trained_model_weights)

# Create federated test dataset
def create_federated_test_data(X, y):
    """Build a batched, unshuffled ``tf.data.Dataset`` from held-out data.

    Args:
        X: Feature table, read through ``.values`` and cast to float32.
        y: Label column, read through ``.values`` and cast to int64.

    Returns:
        A dataset of ``(features, labels)`` batches of up to 64 examples in
        the original row order.
    """
    features = X.values.astype(np.float32)
    labels = y.values.astype(np.int64)
    examples = tf.data.Dataset.from_tensor_slices((features, labels))
    return examples.batch(64)

# The whole held-out split is wrapped as a single-element list, i.e. it is
# evaluated as one dataset.
federated_test_data = [create_federated_test_data(X_test, y_test)]

# Run the evaluation process once, starting from the state that carries the trained weights.
evaluation_output = evaluation.next(evaluation_state, federated_test_data)


# Only the metrics of the evaluation output are reported.
test_metrics = evaluation_output.metrics
print('Test metrics:', test_metrics)

ERROR:jax._src.xla_bridge:Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 438, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/usr/local/lib/python3.10/dist-packages/jax_plugins/xla_cuda12/__init__.py", line 85, in initialize
    options = xla_client.generate_pjrt_gpu_plugin_options()
AttributeError: module 'jaxlib.xla_client' has no attribute 'generate_pjrt_gpu_plugin_options'


Round 1, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.6870229), ('loss', 0.83555746), ('num_examples', 655278), ('num_batches', 10242)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
Round 2, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.72101307), ('loss', 0.7302479), ('num_examples', 655278), ('num_batches', 10242)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])
Round 3, Metrics: OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.74443823), ('loss', 0.67098796), ('num_examples', 655278), ('num_batches', 10242)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight',