In [1]:
from flwr_datasets.partitioner import DirichletPartitioner
from flwr_datasets import FederatedDataset

In [2]:
# Federated setup: number of simulated clients, and the Dirichlet concentration
# parameter handed to DirichletPartitioner below.
num_partitions, alpha_dir = 4, 0.1

In [3]:
# Non-IID split of the training data into `num_partitions` clients, grouping samples
# by their "label" column; no minimum partition size is enforced and self-balancing
# is turned off.
partitioner = DirichletPartitioner(
    partition_by="label",
    num_partitions=num_partitions,
    alpha=alpha_dir,
    self_balancing=False,
    min_partition_size=0,
)

In [4]:
# Federated view of MNIST in which only the "train" split is partitioned.
fds = FederatedDataset(partitioners={"train": partitioner}, dataset="mnist")

In [5]:
# One training partition per client, indexed 0 .. num_partitions - 1.
train_partitions = [
    fds.load_partition(partition_id, split="train")
    for partition_id in range(num_partitions)
]

##### Rodar a próxima célula somente se quiser testar com o dataset reduzido

In [6]:
# Optional: shrink every client's partition to 1/reduction_factor of its size for
# quick experiments, keeping the first rows of each partition.
# The partitions are reloaded from `fds` first so re-running this cell is idempotent
# (reassigning `train_partitions` from itself used to shrink it again on every run).
reduction_factor = 10
train_partitions = [fds.load_partition(i, split="train") for i in range(num_partitions)]
# Per-client subset sizes (a list, one entry per client).
num_samples = [len(train_partition) // reduction_factor for train_partition in train_partitions]
train_partitions = [train_partition.select(range(n)) for train_partition, n in zip(train_partitions, num_samples)]

In [7]:
from torchvision.transforms import Compose, ToTensor, Normalize

In [8]:
# Image -> tensor, then Normalize(0.5, 0.5), i.e. (x - 0.5) / 0.5 per channel.
pytorch_transforms = Compose([ToTensor(), Normalize(mean=(0.5,), std=(0.5,))])


def apply_transforms(batch):
    """Run ``pytorch_transforms`` over every entry of ``batch["image"]``.

    The batch dict is modified in place (its "image" entry is replaced by the list
    of transformed tensors) and returned, as expected by ``Dataset.with_transform``.
    """
    transformed_images = list(map(pytorch_transforms, batch["image"]))
    batch["image"] = transformed_images
    return batch

In [9]:
# Attach `apply_transforms` to every client partition; `with_transform` applies it
# on the fly when rows are accessed.
train_partitions = list(map(lambda partition: partition.with_transform(apply_transforms), train_partitions))

In [None]:
from torch.utils.data import DataLoader

In [11]:
batch_size = 128  # per-client mini-batch size
# One shuffled DataLoader per client partition.
trainloaders = [
    DataLoader(partition, shuffle=True, batch_size=batch_size)
    for partition in train_partitions
]

In [13]:
import numpy as np
import torch.nn as nn
import torch as torch

In [14]:
class CGAN(nn.Module):
    """Conditional GAN holding both the generator and the discriminator.

    ``forward`` dispatches on the rank of its first argument: a 2-D latent batch is
    run through the generator and reshaped to images, a 4-D image batch is flattened
    and run through the discriminator. Both networks are conditioned on the class
    label through a shared ``nn.Embedding``.

    Args:
        dataset: Only ``"mnist"`` (10 classes, 1 channel) is supported.
        img_size: Height and width of the square images.
        latent_dim: Size of the latent noise vector fed to the generator.

    Raises:
        ValueError: If ``dataset`` is not supported.
    """

    def __init__(self, dataset="mnist", img_size=28, latent_dim=100):
        super().__init__()
        if dataset == "mnist":
            self.classes = 10
            self.channels = 1
        else:
            # Previously fell through and crashed later with an AttributeError on
            # self.channels / self.classes.
            raise ValueError(f"Unsupported dataset: {dataset!r} (only 'mnist' is supported)")
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        # Label embedding of width `classes`, shared by generator and discriminator.
        self.label_embedding = nn.Embedding(self.classes, self.classes)
        self.adv_loss = nn.BCELoss()

        # Generator: [label embedding, z] -> flat image in [-1, 1] (Tanh).
        self.generator = nn.Sequential(
            *self._create_layer_gen(self.latent_dim + self.classes, 128, False),
            *self._create_layer_gen(128, 256),
            *self._create_layer_gen(256, 512),
            *self._create_layer_gen(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh()
        )

        # Discriminator: [flat image, label embedding] -> probability in (0, 1).
        # NOTE(review): the 256->128 and 128->1 layers use act_func=False, so they are
        # two consecutive Linear layers with no non-linearity in between; confirm intended.
        self.discriminator = nn.Sequential(
            *self._create_layer_disc(self.classes + int(np.prod(self.img_shape)), 1024, False, True),
            *self._create_layer_disc(1024, 512, True, True),
            *self._create_layer_disc(512, 256, True, True),
            *self._create_layer_disc(256, 128, False, False),
            *self._create_layer_disc(128, 1, False, False),
            nn.Sigmoid()
        )

    def _create_layer_gen(self, size_in, size_out, normalize=True):
        """Return [Linear, (BatchNorm1d), LeakyReLU(0.2)] for one generator stage."""
        layers = [nn.Linear(size_in, size_out)]
        if normalize:
            layers.append(nn.BatchNorm1d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def _create_layer_disc(self, size_in, size_out, drop_out=True, act_func=True):
        """Return [Linear, (Dropout(0.4)), (LeakyReLU(0.2))] for one discriminator stage."""
        layers = [nn.Linear(size_in, size_out)]
        if drop_out:
            layers.append(nn.Dropout(0.4))
        if act_func:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, input, labels):
        """Run the generator (2-D ``input``) or the discriminator (4-D ``input``).

        Args:
            input: Latent batch ``(batch, latent_dim)`` for the generator, or image
                batch ``(batch, C, H, W)`` for the discriminator.
            labels: Integer class labels, one per row of ``input``.

        Returns:
            Images of shape ``(batch, *img_shape)`` for a 2-D input, or the
            discriminator's ``(batch, 1)`` probabilities for a 4-D input.

        Raises:
            ValueError: If ``input`` is neither 2-D nor 4-D (previously returned None).
        """
        if input.dim() == 2:
            z = torch.cat((self.label_embedding(labels), input), -1)
            x = self.generator(z)
            # Reshape the flat generator output to (batch, channels, H, W).
            return x.view(x.size(0), *self.img_shape)
        if input.dim() == 4:
            x = torch.cat((input.view(input.size(0), -1), self.label_embedding(labels)), -1)
            return self.discriminator(x)
        raise ValueError(
            f"Expected a 2-D latent batch or a 4-D image batch, got a {input.dim()}-D input"
        )

    def loss(self, output, label):
        """Binary cross-entropy between discriminator outputs and real/fake targets."""
        return self.adv_loss(output, label)

In [15]:
# One local CGAN per client.
models = [CGAN() for _ in range(num_partitions)]

In [16]:
# Per-client Adam optimisers for the discriminators (lr=1e-4, betas=(0.5, 0.999)).
# NOTE(review): not referenced by the federated training loop, which builds its own
# `optim_D` for every client each round; candidate for removal.
optim_Ds = [
    torch.optim.Adam(client_model.discriminator.parameters(), betas=(0.5, 0.999), lr=0.0001)
    for client_model in models
]

In [17]:
def generate_plot(net, device, round_number, client_id = None, examples_per_class: int=5, classes: int=10, latent_dim: int=100, server: bool=False):
    """Sample a class-by-class image grid from a conditional generator and save it as PNG.

    The grid has ``classes`` rows (row r shows class r) and ``examples_per_class``
    columns. It is written to ``mnist_CGAN_r{round_number}_f2a.png`` in the current
    working directory and the figure is closed.

    Args:
        net: Either a model whose class is named ``Generator`` (called with a single
            tensor ``[z, one_hot(label)]``) or a conditional model called as
            ``net(z, labels)`` such as ``CGAN``. It is moved to ``device`` and left in
            eval mode (side effect on the caller's model).
        device: Device used for sampling.
        round_number: Round shown in the title and used in the file name.
        client_id: Optional client id shown in the title; 0 is a valid id.
        examples_per_class: Number of samples per class (grid columns).
        classes: Number of classes to plot (grid rows).
        latent_dim: Size of the noise vector; must match what ``net`` expects.
        server: If True, switch matplotlib to the non-interactive "Agg" backend.

    Returns:
        None.
    """
    if server:
        # Headless rendering (no display available on a server).
        import matplotlib
        matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    net_type = type(net).__name__
    net.to(device)
    net.eval()
    batch_size = examples_per_class * classes

    latent_vectors = torch.randn(batch_size, latent_dim, device=device)
    # Class-major labels: examples_per_class copies of 0, then of 1, ...
    labels = torch.tensor([i for i in range(classes) for _ in range(examples_per_class)], device=device)

    with torch.no_grad():
        if net_type == "Generator":
            # NOTE(review): one-hot width is fixed at 10 (MNIST); presumably it must match
            # the generator's class count rather than `classes` — confirm.
            labels_one_hot = torch.nn.functional.one_hot(labels, 10).float().to(device)
            generated_images = net(torch.cat([latent_vectors, labels_one_hot], dim=1))
        else:
            generated_images = net(latent_vectors, labels)
    # matplotlib needs CPU data; the "Generator" branch previously stayed on `device`,
    # which fails for CUDA tensors.
    generated_images = generated_images.cpu()

    # `squeeze=False` keeps `axes` 2-D even when classes == 1 or examples_per_class == 1,
    # so `axes[row, 0]` below always works.
    fig, axes = plt.subplots(classes, examples_per_class, figsize=(5, 9), squeeze=False)

    # Title at the top of the figure; `is not None` so that client 0 is still shown.
    if client_id is not None:
        fig.text(0.5, 0.98, f"Round: {round_number} | Client: {client_id}", ha="center", fontsize=12)
    else:
        fig.text(0.5, 0.98, f"Round: {round_number}", ha="center", fontsize=12)

    # One sample per subplot; assumes 4-D (batch, C, H, W) output, as CGAN.forward returns.
    for i, ax in enumerate(axes.flat):
        ax.imshow(generated_images[i, 0, :, :], cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])

    # Finalise the layout before reading subplot positions; leave room on the left
    # for the row labels and at the top for the title.
    plt.tight_layout(rect=[0.05, 0, 1, 0.96])

    # Render once so subplot extents are final, then place each class label at the
    # vertical centre of its row, in figure coordinates.
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()
    for row in range(classes):
        bbox = axes[row, 0].get_window_extent(renderer)
        pos = fig.transFigure.inverted().transform([(bbox.x0, bbox.y0), (bbox.x1, bbox.y1)])
        center_y = (pos[0, 1] + pos[1, 1]) / 2
        fig.text(0.04, center_y, str(row), va='center', fontsize=12, color='black')

    fig.savefig(f"mnist_CGAN_r{round_number}_f2a.png")
    plt.close(fig)
    return

In [18]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [19]:
# Stand-alone CGAN plus an Adam optimiser for its generator (lr=1e-4, betas=(0.5, 0.999)).
# NOTE(review): neither `gen` nor this `optim_G` is used by the federated loop, which
# rebuilds `optim_G` for every client; candidate for removal.
gen = CGAN()
gen = gen.to(device)
optim_G = torch.optim.Adam(
    gen.generator.parameters(),
    betas=(0.5, 0.999),
    lr=0.0001,
)

In [None]:
from tqdm.notebook import tqdm

In [29]:
from flwr.server.strategy.aggregate import aggregate_inplace
from flwr.common import FitRes, Status, Code, ndarrays_to_parameters
from collections import OrderedDict

ModuleNotFoundError: No module named 'flwr'

In [None]:
from torch.utils.data import Dataset

class GeneratedDataset(Dataset):
    """In-memory dataset of images sampled once from a conditional generator.

    Labels are assigned class by class (``num_samples // num_classes`` of each class).
    When ``num_samples`` is not a multiple of ``num_classes``, the leftover samples get
    labels ``0, 1, ...`` so that exactly ``num_samples`` items are produced (previously
    labels and noise had different lengths, which broke generation or indexing).

    Args:
        generator: Model whose class is named ``Generator`` (called with one tensor
            ``[z, one_hot(label)]``), or ``CGAN`` / ``F2U_GAN`` (called as
            ``generator(z, labels)``). It is put in eval mode. It is assumed to already
            live on ``device`` — TODO confirm at call sites.
        num_samples: Number of samples to generate (int).
        latent_dim: Size of the latent noise vector.
        num_classes: Number of classes.
        device: Device on which noise and labels are created.

    Raises:
        ValueError: If the generator's class name is not one of the supported ones.
    """

    _SUPPORTED_MODELS = ("Generator", "CGAN", "F2U_GAN")

    def __init__(self, generator, num_samples, latent_dim, num_classes, device):
        self.generator = generator
        self.num_samples = num_samples
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.device = device
        # Dispatch key: how the generator must be called.
        self.model = type(self.generator).__name__
        self.images, self.labels = self.generate_data()
        self.classes = list(range(self.num_classes))

    def generate_data(self):
        """Sample ``num_samples`` images; return ``(images, labels)`` on the CPU."""
        if self.model not in self._SUPPORTED_MODELS:
            # Previously ended in UnboundLocalError on `gen_imgs`.
            raise ValueError(
                f"Unsupported generator type {self.model!r}; expected one of {self._SUPPORTED_MODELS}"
            )
        self.generator.eval()

        per_class, remainder = divmod(self.num_samples, self.num_classes)
        label_list = [c for c in range(self.num_classes) for _ in range(per_class)]
        label_list += list(range(remainder))  # top up so len(labels) == num_samples
        # Explicit long dtype so an empty list still yields valid class indices.
        labels = torch.tensor(label_list, dtype=torch.long, device=self.device)

        z = torch.randn(self.num_samples, self.latent_dim, device=self.device)
        with torch.no_grad():
            if self.model == "Generator":
                # `F` was never imported; use the fully qualified function.
                labels_one_hot = torch.nn.functional.one_hot(labels, self.num_classes).float().to(self.device)
                gen_imgs = self.generator(torch.cat([z, labels_one_hot], dim=1))
            else:  # "CGAN" or "F2U_GAN"
                gen_imgs = self.generator(z, labels)

        return gen_imgs.cpu(), labels.cpu()

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

In [None]:
# Federated CGAN training. Each round, every client starts from the global weights and
# trains its local CGAN for `epochs` epochs. The client weights are then combined with
# Flower's `aggregate_inplace`, weighted by local dataset size. The aggregated model is
# used to plot samples, build a synthetic dataset, and train/evaluate a classifier on it.
rounds = 2
epochs = 2
latent_dim = 100
# Size of the synthetic dataset sampled from the global generator each round.
# (`num_samples` was passed before, but it holds a per-client *list* of subset sizes,
# or is undefined when the reduction cell is skipped.)
num_generated_samples = 1000
eval_device = "cpu"
accuracies = []
g_losses_e = [[] for _ in range(num_partitions)]
d_losses_e = [[] for _ in range(num_partitions)]
g_losses_b = [[] for _ in range(num_partitions)]
d_losses_b = [[] for _ in range(num_partitions)]
criterion = torch.nn.CrossEntropyLoss()
global_net = CGAN().to(device)

round_bar = tqdm(range(rounds), desc="Rodadas", leave=True, position=0)

# `server_round` instead of `round` so the builtin round() is not shadowed in the kernel.
for server_round in round_bar:

    print(f"\n🔸 Round {server_round+1}/{rounds}")

    params = []
    results = []

    client_bar = tqdm(
        enumerate(zip(models, trainloaders, g_losses_e, d_losses_e, g_losses_b, d_losses_b)),
        total=num_partitions, desc="Clientes", leave=False, position=1,
    )

    for i, (model, trainloader, g_loss_e, d_loss_e, g_loss_b, d_loss_b) in client_bar:
        print(f"\n🔹 Client {i+1}/{num_partitions}")
        # Every client starts the round from the current global weights.
        model.load_state_dict(global_net.state_dict(), strict=True)
        model.to(device)
        model.train()
        # Fresh optimisers per client and round: Adam state is not carried over.
        optim_G = torch.optim.Adam(model.generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
        optim_D = torch.optim.Adam(model.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

        epoch_bar = tqdm(range(epochs), desc="Epocas locais", leave=False, position=2)

        for epoch in epoch_bar:

            print(f"\n🔹 Epoch {epoch+1}/{epochs}")
            # Per-epoch losses. Slicing the cumulative per-batch lists by `epoch`
            # picked the wrong window after round 0 or after skipped batches.
            epoch_g_losses = []
            epoch_d_losses = []

            # (The per-batch print used `sampler.indices`, which a shuffling
            # DataLoader's RandomSampler does not have; tqdm already shows progress.)
            for batch in trainloader:
                images = batch["image"].to(device)
                labels = batch["label"].to(device)

                # Local name so the global `batch_size` used for the DataLoaders is not clobbered.
                current_batch_size = images.size(0)
                if current_batch_size == 1:
                    # Presumably skipped because the generator's BatchNorm1d layers reject
                    # a single-sample batch in training mode.
                    print(f"Batch size is 1, skipping...")
                    continue
                real_ident = torch.full((current_batch_size, 1), 1., device=device)
                fake_ident = torch.full((current_batch_size, 1), 0., device=device)

                # --- Train discriminator ---
                optim_D.zero_grad()

                # Real images
                y_real = model(images, labels)
                d_real_loss = model.loss(y_real, real_ident)

                # Fake images, sampled here. Previously `x_fake` was used before being
                # assigned on the first batch (NameError) and was stale afterwards.
                z_noise = torch.randn(current_batch_size, latent_dim, device=device)
                x_fake_labels = torch.randint(0, 10, (current_batch_size,), device=device)
                x_fake = model(z_noise, x_fake_labels)
                # detach(): the discriminator update must not backprop into the generator.
                y_fake_d = model(x_fake.detach(), x_fake_labels)
                d_fake_loss = model.loss(y_fake_d, fake_ident)

                d_loss = (d_real_loss + d_fake_loss) / 2
                d_loss.backward()
                optim_D.step()

                # --- Train generator ---
                optim_G.zero_grad()

                z_noise = torch.randn(current_batch_size, latent_dim, device=device)
                x_fake_labels = torch.randint(0, 10, (current_batch_size,), device=device)
                x_fake = model(z_noise, x_fake_labels)
                y_fake_g = model(x_fake, x_fake_labels)

                # The generator is trained to make the discriminator output "real".
                g_loss = model.loss(y_fake_g, real_ident)
                g_loss.backward()
                optim_G.step()

                g_loss_b.append(g_loss.item())
                d_loss_b.append(d_loss.item())
                epoch_g_losses.append(g_loss.item())
                epoch_d_losses.append(d_loss.item())

            g_loss_e.append(np.mean(epoch_g_losses) if epoch_g_losses else float("nan"))
            d_loss_e.append(np.mean(epoch_d_losses) if epoch_d_losses else float("nan"))

        # The first tuple element is just the client index used as a placeholder for a
        # ClientProxy; presumably aggregate_inplace only reads the FitRes — verify.
        params.append(ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()]))
        results.append((i, FitRes(status=Status(code=Code.OK, message="Success"), parameters=params[i], num_examples=len(trainloader.dataset), metrics={})))

    # Aggregate the client models (weighted by num_examples) into the global model.
    aggregated_ndarrays = aggregate_inplace(results)

    params_dict = zip(global_net.state_dict().keys(), aggregated_ndarrays)
    state_dict = OrderedDict({k: torch.tensor(v).to(device) for k, v in params_dict})
    global_net.load_state_dict(state_dict, strict=True)

    # Pass the whole CGAN: the plotting/dataset helpers call `net(z, labels)`, which the
    # bare `global_net.generator` (an nn.Sequential) does not accept. latent_dim must
    # match the model (it was 128 vs. 100).
    generate_plot(net=global_net, device=eval_device, round_number=server_round, latent_dim=latent_dim)

    # Synthetic dataset sampled from the aggregated generator (kept on the same device).
    global_net.to(eval_device)
    generated_dataset = GeneratedDataset(generator=global_net, num_samples=num_generated_samples, latent_dim=latent_dim, num_classes=10, device=eval_device)
    generated_dataloader = DataLoader(generated_dataset, batch_size=64, shuffle=True)

    # NOTE(review): `Net` (the classifier) and `testloader` are not defined in any cell of
    # this notebook; they must be provided before running this cell — TODO.
    net = Net()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
    net.train()
    for _ in range(5):
        # Train on the synthetic data. Previously this iterated the last client's
        # `trainloader`, whose dict batches unpack to the keys "image"/"label".
        for inputs, labels in generated_dataloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    correct, loss = 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images = batch[0]
            labels = batch[1]
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    accuracy = correct / len(testloader.dataset)
    accuracies.append(accuracy)