In [1]:
import tensorflow as tf
# Report how many GPUs TensorFlow can see in this kernel.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

Num GPUs Available:  1


In [2]:
import torch
# Displayed as the cell output: whether PyTorch can use a CUDA device.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [3]:
from jax import jit
import numpy as np
import jax.numpy as jnp

# Benchmark input: a 1000x1000 host array of uniform samples in [0, 1)
# (random_sample is what np.random.rand wraps), plus its JAX device copy.
x = np.random.random_sample((1000, 1000))
y = jnp.array(x)

def f(x):
    """Apply the elementwise update ``x <- 0.5*x + 0.1*sin(x)`` ten times.

    Used as the benchmark body: eager (op-by-op) here, and compiled once
    wrapped with ``jit``.
    """
    result = x
    step = 0
    while step < 10:
        result = 0.5 * result + 0.1 * jnp.sin(result)
        step += 1
    return result

g = jit(f)

%timeit -n 5 -r 5 f(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop

%timeit -n 5 -r 5 g(y).block_until_ready()
# 5 loops, best of 5: 341 µs per loop

54.8 ms ± 4.52 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
10.2 ms ± 4.52 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


In [4]:
import flwr
# Installed Flower version (the client/server API below targets flwr 1.x).
installed_flwr_version = flwr.__version__
print(installed_flwr_version)

1.5.0


In [6]:
import flwr as fl


# Define Flower client for TensorFlow
class TensorFlowClient(fl.client.NumPyClient):
    """Flower NumPyClient that trains and evaluates a Keras model locally.

    Args:
        model: Keras model. It is compiled here with Adam and sparse categorical
            cross-entropy, so labels must be integer class ids.
        dataset: ``(features, labels)`` pair. The same data is used by both
            ``fit`` and ``evaluate``.
        epochs: local epochs per ``fit`` round.
        batch_size: mini-batch size for ``fit`` and ``evaluate``.
    """

    def __init__(self, model, dataset, epochs=1, batch_size=32):
        self.model = model
        self.dataset = dataset
        self.epochs = epochs
        self.batch_size = batch_size
        self.model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

    def get_parameters(self, config=None):
        """Return the model weights as a list of NumPy arrays.

        Flower 1.x calls ``get_parameters(config=...)``. ``config`` is accepted for
        that protocol and is otherwise unused. It defaults to ``None`` so the
        internal no-argument calls in ``fit`` keep working.
        """
        return [np.asarray(v) for v in self.model.get_weights()]

    def set_parameters(self, parameters):
        """Load a list of NumPy arrays (in ``get_parameters`` order) into the model."""
        self.model.set_weights(parameters)

    def fit(self, parameters, config):
        """Train locally starting from ``parameters``.

        Returns:
            ``(updated weights, number of training examples, metrics dict)``.
        """
        self.set_parameters(parameters)
        features, labels = self.dataset[0], self.dataset[1]
        self.model.fit(features, labels, epochs=self.epochs, batch_size=self.batch_size)
        return self.get_parameters(), len(features), {}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local dataset.

        Returns:
            ``(loss, number of examples, {"accuracy": accuracy})``.
        """
        self.set_parameters(parameters)
        features, labels = self.dataset[0], self.dataset[1]
        loss, accuracy = self.model.evaluate(features, labels, batch_size=self.batch_size)
        return float(loss), len(features), {"accuracy": float(accuracy)}

# Define Flower client for JAX


In [None]:
# Keras VGG16 with random initial weights (weights=None) and a 10-class output
# layer, expecting 224x224 RGB inputs.
tf_model = tf.keras.applications.VGG16(
    input_shape=(224, 224, 3),
    classes=10,
    weights=None,
)


# # For demonstration purposes, let's use a dummy dataset
# x_train = np.random.rand(1000, 28, 28)
# y_train = np.random.randint(10, size=1000)

# # Create an instance of the TensorFlowClient
# client = TensorFlowClient(tf_model, (x_train, y_train))

# # Start Flower server and client
# fl.server.start_server("[::]:8080", config={"num_rounds": 3}, strategy=fl.server.strategy.FedAvg())
# fl.client.start_numpy_client("[::]:8080", client)

In [8]:
import flwr as fl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class PyTorchClient(fl.client.NumPyClient):
    """Flower NumPyClient wrapping a PyTorch classification model.

    Args:
        model: ``torch.nn.Module`` producing class logits. It is moved to ``device``.
        device: device used for training and evaluation.
        train_loader: iterable of ``(inputs, targets)`` batches used by ``fit``.
            ``train_loader.dataset`` must support ``len``.
        test_loader: iterable of ``(inputs, targets)`` batches used by ``evaluate``.
        epochs: local epochs per ``fit`` round.
        batch_size: stored for reference only; the loaders determine batching.

    Targets are passed to ``F.cross_entropy``, so they are expected to be class
    indices (or class probabilities) matching the model's logits.
    """

    def __init__(self, model, device, train_loader, test_loader, epochs=1, batch_size=32):
        self.device = device
        # Move the model once so its parameters, the inputs and the optimizer
        # state all live on the same device.
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.epochs = epochs
        self.batch_size = batch_size
        self.optimizer = torch.optim.Adam(self.model.parameters())

    def get_parameters(self, config=None):
        """Return ``model.parameters()`` as a list of NumPy arrays.

        ``config`` is what Flower 1.x passes (``get_parameters(config=...)``). It is
        unused and defaults to ``None`` so internal calls without it keep working.
        """
        # detach(): parameters require grad, and .numpy() refuses such tensors.
        return [param.detach().cpu().numpy() for param in self.model.parameters()]

    def set_parameters(self, parameters):
        """Copy NumPy arrays (in ``get_parameters`` order) into the model in place.

        Raises:
            ValueError: if the number of arrays does not match the model.
        """
        model_params = list(self.model.parameters())
        if len(parameters) != len(model_params):
            raise ValueError(
                f"expected {len(model_params)} parameter arrays, got {len(parameters)}"
            )
        # In-place copy keeps the same Parameter objects the optimizer holds,
        # and keeps each tensor's dtype/device.
        with torch.no_grad():
            for param, new_value in zip(model_params, parameters):
                param.copy_(torch.as_tensor(new_value, dtype=param.dtype, device=param.device))

    def fit(self, parameters, config):
        """Train for ``self.epochs`` epochs from ``parameters``.

        Returns:
            ``(updated parameters, len(train_loader.dataset), metrics dict)``.
        """
        self.set_parameters(parameters)
        self.model.train()
        for _ in range(self.epochs):
            for data, target in self.train_loader:
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                loss = F.cross_entropy(self.model(data), target)
                loss.backward()
                self.optimizer.step()
        return self.get_parameters(), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on ``test_loader``.

        Returns:
            ``(mean cross-entropy loss, number of examples, {"accuracy": accuracy})``.

        Raises:
            ValueError: if ``test_loader`` yields no examples.
        """
        self.set_parameters(parameters)
        self.model.eval()
        total_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                # Sum per batch so the final mean is weighted by example count.
                total_loss += F.cross_entropy(output, target, reduction="sum").item()
                correct += (output.argmax(dim=1) == target).sum().item()
                total += target.size(0)
        if total == 0:
            raise ValueError("test_loader yielded no examples")
        return total_loss / total, total, {"accuracy": correct / total}


In [None]:
# Initialize VGG16 for PyTorch: ImageNet-pretrained weights (downloaded on first
# use) with a new 10-class output layer.
# `models` is not imported by the imports cell above, so import it here.
from torchvision import models

# The commented-out example below uses `device`; define it alongside the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch_model = models.vgg16(pretrained=True)
# classifier[6] is VGG16's final Linear(4096 -> 1000); replace it to have 10 classes.
torch_model.classifier[6] = torch.nn.Linear(4096, 10)
torch_model = torch_model.to(device)

# # Assuming your PyTorch model is named 'PyTorchModel'
# model = PyTorchModel().to(device)

# # Prepare dataset
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# # Create an instance of the PyTorchClient
# client = PyTorchClient(torch_model, device, train_loader, test_loader)

# # Start Flower server and client
# fl.server.start_server("[::]:8080", config={"num_rounds": 3}, strategy=fl.server.strategy.FedAvg())
# fl.client.start_numpy_client("[::]:8080", client)


In [None]:
import flwr as fl
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
from flax import linen as nn
from flax.training import train_state




class JAXClient(fl.client.NumPyClient):
    """Flower NumPyClient that trains a Flax linen model with momentum SGD.

    Args:
        model: Flax linen module, applied as ``model.apply({'params': params}, x)``.
        params: the module's parameter pytree (the value stored under ``'params'``).
        dataset: ``(features, labels)`` pair of arrays sharing the first axis.
            Labels may be integer class ids (one fewer dimension than the logits)
            or one-hot / probability rows (same rank as the logits).
        rng: ``jax.random`` key. Used only to shuffle examples each epoch; the
            model is applied without RNG streams.
        epochs: local epochs per ``fit`` round.
        batch_size: mini-batch size for training and evaluation.
        learning_rate, beta: momentum-SGD step size and momentum coefficient.
    """

    def __init__(self, model, params, dataset, rng, epochs=1, batch_size=32,
                 learning_rate=0.01, beta=0.9):
        self.model = model
        self.params = params
        self.dataset = dataset
        self.rng = rng
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.beta = beta
        self.opt_state = self.create_optimizer()

        # jax.jit cannot trace `self`, so jit pure functions that close over the
        # Flax module instead of decorating the methods themselves.
        def loss_fn(params, x, y):
            logits = model.apply({'params': params}, x)
            return jnp.mean(self._cross_entropy(logits, y))

        def eval_fn(params, x, y):
            logits = model.apply({'params': params}, x)
            loss_sum = jnp.sum(self._cross_entropy(logits, y))
            hits = jnp.sum(jnp.argmax(logits, axis=-1) == self._class_ids(y, logits))
            return loss_sum, hits

        self._grad_fn = jax.jit(jax.grad(loss_fn))
        self._update_fn = jax.jit(self._momentum_step)
        self._eval_fn = jax.jit(eval_fn)

    @staticmethod
    def _cross_entropy(logits, labels):
        """Per-example softmax cross-entropy for integer or one-hot ``labels``."""
        log_probs = jax.nn.log_softmax(logits)
        if labels.ndim == logits.ndim:
            return -jnp.sum(labels * log_probs, axis=-1)
        ids = labels.astype(jnp.int32)[..., None]
        return -jnp.take_along_axis(log_probs, ids, axis=-1)[..., 0]

    @staticmethod
    def _class_ids(labels, logits):
        """Integer class ids from integer or one-hot ``labels``."""
        return jnp.argmax(labels, axis=-1) if labels.ndim == logits.ndim else labels

    @staticmethod
    def _momentum_step(params, momentum, grads, learning_rate, beta):
        """One momentum step: ``m <- beta*m + g``, then ``p <- p - lr*m``."""
        new_momentum = jax.tree_util.tree_map(lambda m, g: beta * m + g, momentum, grads)
        new_params = jax.tree_util.tree_map(
            lambda p, m: p - learning_rate * m, params, new_momentum)
        return new_params, new_momentum

    def _batches(self, x, y, order=None):
        """Yield ``(x, y)`` chunks of at most ``batch_size`` rows, optionally reordered."""
        for start in range(0, x.shape[0], self.batch_size):
            stop = start + self.batch_size
            if order is None:
                yield x[start:stop], y[start:stop]
            else:
                idx = order[start:stop]
                yield x[idx], y[idx]

    def create_optimizer(self):
        """Return a zero momentum buffer with the same pytree structure as ``params``."""
        return jax.tree_util.tree_map(jnp.zeros_like, self.params)

    def get_parameters(self, config=None):
        """Return the parameter leaves as host NumPy arrays, in ``tree_leaves`` order.

        ``config`` is what Flower 1.x passes (``get_parameters(config=...)``). It is unused.
        """
        return jax.device_get(jax.tree_util.tree_leaves(self.params))

    def set_parameters(self, parameters):
        """Load arrays produced by ``get_parameters`` back into the params pytree.

        Raises:
            ValueError: if the number or shapes of the arrays do not match.
        """
        leaves, treedef = jax.tree_util.tree_flatten(self.params)
        if len(parameters) != len(leaves):
            raise ValueError(
                f"expected {len(leaves)} parameter arrays, got {len(parameters)}")
        for i, (old, new) in enumerate(zip(leaves, parameters)):
            if jnp.shape(old) != jnp.shape(new):
                raise ValueError(
                    f"parameter {i}: expected shape {jnp.shape(old)}, got {jnp.shape(new)}")
        self.params = jax.tree_util.tree_unflatten(
            treedef, [jnp.asarray(p) for p in parameters])

    def fit(self, parameters, config):
        """Train ``self.epochs`` shuffled epochs from ``parameters``.

        Returns:
            ``(updated parameters, number of examples, metrics dict)``.
        """
        self.set_parameters(parameters)
        x, y = jnp.asarray(self.dataset[0]), jnp.asarray(self.dataset[1])
        num_examples = x.shape[0]
        if num_examples == 0:
            return self.get_parameters(), 0, {}
        for _ in range(self.epochs):
            self.rng, shuffle_key = random.split(self.rng)
            order = random.permutation(shuffle_key, num_examples)
            for batch in self._batches(x, y, order):
                grads = self.get_grads(shuffle_key, batch)
                self.opt_state = self.update(self.opt_state, grads)
        return self.get_parameters(), num_examples, {}

    def _loss_and_accuracy(self, dataset):
        """Mean cross-entropy and accuracy of ``self.params`` on ``(x, y)``.

        Raises:
            ValueError: if the dataset is empty.
        """
        x, y = jnp.asarray(dataset[0]), jnp.asarray(dataset[1])
        num_examples = x.shape[0]
        if num_examples == 0:
            raise ValueError("cannot evaluate on an empty dataset")
        loss_sum, hits = 0.0, 0
        for bx, by in self._batches(x, y):
            batch_loss, batch_hits = self._eval_fn(self.params, bx, by)
            loss_sum += float(batch_loss)
            hits += int(batch_hits)
        return loss_sum / num_examples, hits / num_examples

    def evaluate(self, parameters, config):
        """Evaluate ``parameters`` on the local dataset.

        Returns:
            ``(mean loss, number of examples, {"accuracy": accuracy})``.
        """
        self.set_parameters(parameters)
        loss, accuracy = self._loss_and_accuracy(self.dataset)
        return float(loss), len(self.dataset[0]), {"accuracy": float(accuracy)}

    def get_grads(self, rng, batch):
        """Gradient of the mean cross-entropy on ``batch = (x, y)`` w.r.t. ``self.params``.

        ``rng`` is accepted for interface compatibility and is unused.
        """
        return self._grad_fn(self.params, batch[0], batch[1])

    def update(self, opt_state, grads):
        """Apply one momentum-SGD step to ``self.params``; return the new momentum state."""
        self.params, new_state = self._update_fn(
            self.params, opt_state, grads, self.learning_rate, self.beta)
        return new_state

    def compute_accuracy(self, dataset):
        """Fraction of ``dataset`` examples whose argmax logit matches the label."""
        return self._loss_and_accuracy(dataset)[1]


In [None]:
class VGG16Block(nn.Module):
    """One VGG stage: ``repetitions`` x (3x3 SAME conv -> ReLU), then 2x2/stride-2 max-pool.

    Attributes:
        filters: output channels of every conv in the stage.
        repetitions: number of conv -> ReLU pairs.
    """

    filters: int
    repetitions: int

    @nn.compact
    def __call__(self, x):
        activations = x
        for _ in range(self.repetitions):
            conv = nn.Conv(features=self.filters, kernel_size=(3, 3), padding="SAME")
            activations = nn.relu(conv(activations))
        return nn.max_pool(activations, window_shape=(2, 2), strides=(2, 2))

class VGG16(nn.Module):
    """VGG-16 classifier: five conv stages (13 convs) followed by three dense layers.

    Attributes:
        num_classes: width of the final (logit) layer.
    """

    num_classes: int = 1000

    @nn.compact
    def __call__(self, x):
        # (filters, repetitions) for each stage, applied in order.
        stage_specs = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))
        for filters, repetitions in stage_specs:
            x = VGG16Block(filters=filters, repetitions=repetitions)(x)

        hidden = x.reshape((x.shape[0], -1))  # flatten every axis after the first
        for width in (4096, 4096):
            hidden = nn.relu(nn.Dense(features=width)(hidden))
        return nn.Dense(features=self.num_classes)(hidden)

# Initialize VGG16 model for JAX. Linen modules are initialised with
# `Module.init(rng, example_input)`; the returned variables hold the weights
# under "params". The example batch fixes the input shape to (1, 224, 224, 3).
_vgg16_variables = VGG16(num_classes=10).init(
    jax.random.PRNGKey(0), jnp.ones((1, 224, 224, 3), jnp.float32)
)
params = _vgg16_variables["params"]

In [None]:
# # For demonstration purposes, let's use a dummy dataset
# x_train = jnp.array(np.random.rand(1000, 28, 28))
# y_train = jnp.array(np.random.randint(10, size=1000))

# # Create an instance of the JAXClient
# client = JAXClient(VGG16(), params, (x_train, y_train), jax.random.PRNGKey(0))

# # Start Flower server and client
# fl.server.start_server("[::]:8080", config={"num_rounds": 3}, strategy=fl.server.strategy.FedAvg())
# fl.client.start_numpy_client("[::]:8080", client)

In [None]:

def train_with_flower(architecture, *client_args, server_address="[::]:8080",
                      num_rounds=3, startup_delay=3.0, **client_kwargs):
    """Run a one-client Flower federation (server and client) inside this process.

    Args:
        architecture: 1 -> ``PyTorchClient``, 2 -> ``TensorFlowClient``,
            3 -> ``JAXClient``.
        *client_args, **client_kwargs: forwarded to the chosen client's
            constructor, e.g. ``train_with_flower(2, tf_model, (x_train, y_train))``.
            The client classes have required constructor arguments.
        server_address: address the server binds to and the client connects to.
        num_rounds: number of federated rounds.
        startup_delay: seconds to wait after launching the server before the
            client connects, so the server is listening first.

    Returns:
        The ``History`` returned by ``fl.server.start_server``.

    Raises:
        ValueError: if ``architecture`` is not 1, 2 or 3.
    """
    import threading
    import time

    if architecture == 1:
        client_cls = PyTorchClient
    elif architecture == 2:
        client_cls = TensorFlowClient
    elif architecture == 3:
        client_cls = JAXClient
    else:
        raise ValueError("Invalid architecture choice")
    client = client_cls(*client_args, **client_kwargs)

    # FedAvg defaults to requiring two clients for fit/evaluate; with a single
    # client the server would wait indefinitely.
    strategy = fl.server.strategy.FedAvg(
        min_fit_clients=1,
        min_evaluate_clients=1,
        min_available_clients=1,
    )

    # start_server blocks until all rounds finish, so it runs on its own thread;
    # otherwise the client below would never start.
    outcome = {}

    def _serve():
        try:
            outcome["history"] = fl.server.start_server(
                server_address=server_address,
                config=fl.server.ServerConfig(num_rounds=num_rounds),
                strategy=strategy,
            )
        except BaseException as exc:  # re-raised on the calling thread below
            outcome["error"] = exc

    server_thread = threading.Thread(target=_serve, name="flower-server", daemon=True)
    server_thread.start()
    time.sleep(startup_delay)

    fl.client.start_numpy_client(server_address=server_address, client=client)
    server_thread.join()

    if "error" in outcome:
        raise outcome["error"]
    return outcome.get("history")
