Playground

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [3]:
class Net(nn.Module):
    """LeNet-style convolutional classifier producing 10 raw logits.

    The fully connected head expects ``16 * 4 * 4`` features per sample, which
    is what the two conv/pool stages yield for a single-channel 28x28 image
    (28 -> 24 -> 12 -> 8 -> 4).

    Args:
        seed: If given, passed to ``torch.manual_seed`` before any layer is
            created so the initial weights are reproducible. This reseeds the
            global PyTorch RNG, not a private generator.
    """

    def __init__(self, seed=None):
        if seed is not None:
            torch.manual_seed(seed)
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16*4*4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for image batch ``x``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to
                ``16 * 4 * 4`` features per sample (raised by ``fc1``).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample (keep dim 0). A `view(-1, 256)` would silently
        # fold a wrongly sized input into a different batch size; this way the
        # mismatch surfaces as a shape error in fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [34]:
import numpy as np

In [35]:
class F2U_GAN(nn.Module):
    """GAN module bundling a generator and a discriminator in one ``nn.Module``.

    ``forward`` dispatches on the rank of its input: a 2-D tensor is treated
    as latent noise and run through the generator; a 4-D tensor is treated as
    an image batch and run through the discriminator. When ``condition`` is
    true, both paths use the same ``label_embedding`` to inject class labels.

    Args:
        dataset: Dataset name; only ``"mnist"`` is accepted (10 classes,
            1 channel).
        img_size: Image side length stored in ``img_shape`` and used to
            broadcast label maps for the discriminator.
        latent_dim: Size of the noise vector fed to the generator.
        condition: Whether to condition both networks on class labels.
        seed: If given, passed to ``torch.manual_seed`` before any layer is
            created (reseeds the global PyTorch RNG).

    Raises:
        NotImplementedError: If ``dataset`` is not ``"mnist"``.
    """

    def __init__(self, dataset="mnist", img_size=28, latent_dim=128, condition=True, seed=None):
        if seed is not None:
            torch.manual_seed(seed)
        super(F2U_GAN, self).__init__()
        if dataset == "mnist":
            self.classes = 10
            self.channels = 1
        else:
            raise NotImplementedError("Only MNIST is supported")

        self.condition = condition
        # One embedding shared by generator and discriminator paths; each
        # label maps to a vector of length `classes`.
        self.label_embedding = nn.Embedding(self.classes, self.classes) if condition else None
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        # Generator input: noise, plus the label embedding when conditional.
        self.input_shape_gen = self.latent_dim + self.label_embedding.embedding_dim if condition else self.latent_dim
        # Discriminator input channels: image channels, plus one label map
        # per embedding dimension when conditional.
        self.input_shape_disc = self.channels + self.classes if condition else self.channels

        self.adv_loss = torch.nn.BCEWithLogitsLoss()

        # Generator. Transposed-conv output size:
        #   out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
        # NOTE: the layer sizes below are fixed at 7 -> 14 -> 28 and do not
        # depend on `img_size`.
        self.generator = nn.Sequential(
            nn.Linear(self.input_shape_gen, 256 * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # (256,7,7) -> (128,14,14)
            nn.BatchNorm2d(128, momentum=0.1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # (128,14,14) -> (64,28,28)
            nn.BatchNorm2d(64, momentum=0.1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, self.channels, kernel_size=3, stride=1, padding=1),  # (64,28,28) -> (channels,28,28)
            nn.Tanh()
        )

        # Discriminator. Conv output size:
        #   out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
        # Shapes in the comments below are for a 28x28 input;
        # C_in = input_shape_disc.
        self.discriminator = nn.Sequential(
            # Layer 1: (C_in,28,28) -> (32,13,13)
            nn.utils.spectral_norm(nn.Conv2d(self.input_shape_disc, 32, kernel_size=3, stride=2, padding=0)),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 2: (32,13,13) -> (64,7,7)
            nn.utils.spectral_norm(nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 3: (64,7,7) -> (128,3,3)
            nn.utils.spectral_norm(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0)),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 4: (128,3,3) -> (256,1,1); padding must stay 0 here.
            nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=0)),
            nn.LeakyReLU(0.2, inplace=True),

            # Flatten to 256 features and map to a single logit. Labels are
            # already part of the input channels, so nothing is concatenated
            # at this point.
            nn.Flatten(),  # (256,1,1) -> (256,)
            nn.utils.spectral_norm(nn.Linear(256 * 1 * 1, 1))
        )

    def forward(self, input, labels=None):
        """Run the generator (2-D input) or the discriminator (4-D input).

        Args:
            input: Noise of shape ``(batch, latent_dim)`` or images of shape
                ``(batch, channels, img_size, img_size)``.
            labels: Integer class labels of shape ``(batch,)``; used only
                when the model is conditional.

        Returns:
            Generated images reshaped to ``(-1, *img_shape)`` for 2-D input,
            or the discriminator's raw logits of shape ``(batch, 1)`` for
            4-D input.

        Raises:
            ValueError: If ``input`` is neither 2-D nor 4-D.
        """
        rank = input.dim()

        if rank == 2:
            # Generator path: optionally append the label embedding to the noise.
            if self.condition:
                embedded_labels = self.label_embedding(labels)
                gen_input = torch.cat((input, embedded_labels), dim=1)
                x = self.generator(gen_input)
            else:
                x = self.generator(input)
            return x.view(-1, *self.img_shape)

        if rank == 4:
            # Discriminator path: broadcast each embedding entry to a full
            # img_size x img_size map and stack the maps as extra channels.
            if self.condition:
                embedded_labels = self.label_embedding(labels)
                image_labels = embedded_labels.view(embedded_labels.size(0), self.label_embedding.embedding_dim, 1, 1).expand(-1, -1, self.img_size, self.img_size)
                x = torch.cat((input, image_labels), dim=1)
            else:
                x = input
            return self.discriminator(x)

        # Fail loudly rather than implicitly returning None.
        raise ValueError(
            f"Expected a 2-D (noise) or 4-D (image) input, got a {rank}-D tensor"
        )

    def loss(self, output, label):
        """Return the ``BCEWithLogitsLoss`` between raw logits ``output`` and targets ``label``."""
        return self.adv_loss(output, label)

In [36]:
# Build every model first (classifier, generator holder, four discriminator
# holders), in the same order as before so RNG consumption is unchanged.
global_net = Net()
gen = F2U_GAN()
models = [F2U_GAN() for _ in range(4)]

# Then attach one Adam optimizer (default hyper-parameters) to each of them.
optims = [torch.optim.Adam(global_net.parameters())]
# NOTE(review): only the `generator` sub-module is optimised here; the shared
# `label_embedding` is not part of this parameter set.
optim_G = torch.optim.Adam(gen.generator.parameters())
optim_Ds = [
    torch.optim.Adam(gan.discriminator.parameters())
    for gan in models
]

In [7]:
# Bundle the state of every model and optimizer into one dictionary and
# write it to disk.
checkpoint = {
    'epoch': 0,  # number of the last completed epoch
    'alvo_state_dict': global_net.state_dict(),
    'optimizer_alvo_state_dict': [optim.state_dict() for optim in optims],
    'gen_state_dict': gen.state_dict(),
    'optim_G_state_dict': optim_G.state_dict(),
    'discs_state_dict': [model.state_dict() for model in models],
    # Plain key name: no trailing colon inside the string.
    'optim_Ds_state_dict': [optim_d.state_dict() for optim_d in optim_Ds],
}
# Name the file after the epoch actually stored in the checkpoint.
checkpoint_file = f"checkpoint_epoch{checkpoint['epoch']}.pth"
torch.save(checkpoint, checkpoint_file)

In [37]:
# Load a previously saved checkpoint, remapping all tensors to the CPU.
# NOTE(review): torch.load unpickles the file, so only open checkpoints that
# come from a trusted source.
checkpoint_path = "../Experimentos/Flwr_run/GeraFed_F2U_4c_NIID_Class/checkpoint_epoch1.pth"
check = torch.load(checkpoint_path, map_location="cpu")

In [44]:
# Restore the classifier, the generator holder and its optimizer.
global_net.load_state_dict(check['classifier_state_dict'])

gen.load_state_dict(check["gen_state_dict"])
optim_G.load_state_dict(check["optim_G_state_dict"])

# The checkpoint stores discriminator and optimizer states as mappings;
# they are paired with the local models by iteration order.
disc_states = list(check["discs_state_dict"].values())
optim_d_states = list(check["optimDs_state_dict"].values())

# zip() would silently stop at the shortest sequence and leave the remaining
# models at their random initialisation, so check the counts up front.
if not (len(models) == len(optim_Ds) == len(disc_states) == len(optim_d_states)):
    raise ValueError(
        "Checkpoint/model count mismatch: "
        f"{len(models)} models, {len(optim_Ds)} optimizers, "
        f"{len(disc_states)} saved model states, "
        f"{len(optim_d_states)} saved optimizer states"
    )

for model, optim_d, state_model, state_optim in zip(models, optim_Ds, disc_states, optim_d_states):
    model.load_state_dict(state_model)
    optim_d.load_state_dict(state_optim)

In [16]:
# Show the parameter names stored in the classifier's state dict.
global_net.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

In [47]:
# Inspect the Adam state (step, exp_avg, exp_avg_sq) kept for parameter 0 of
# the last discriminator optimizer. Index the list explicitly instead of
# relying on a loop variable left over from another cell.
optim_Ds[-1].state_dict()['state'][0]

{'step': tensor(29.),
 'exp_avg': tensor([-0.0003, -0.0047, -0.0022, -0.0022, -0.0015, -0.0020, -0.0012, -0.0028,
          0.0185, -0.0029,  0.0089, -0.0002,  0.0081, -0.0070,  0.0069,  0.0002,
         -0.0055, -0.0011, -0.0154,  0.0005,  0.0144,  0.0142,  0.0084,  0.0154,
         -0.0352,  0.0064,  0.0239,  0.0103, -0.0247, -0.0012, -0.0229, -0.0190]),
 'exp_avg_sq': tensor([5.8063e-08, 7.0552e-07, 2.1241e-06, 1.0153e-06, 1.8042e-07, 9.2494e-08,
         5.3453e-07, 1.7549e-06, 4.7928e-06, 3.6051e-07, 1.5378e-06, 1.5724e-07,
         1.4774e-06, 3.7966e-06, 1.4055e-06, 1.0009e-07, 1.9653e-06, 1.4337e-06,
         5.3053e-06, 1.9597e-07, 2.4811e-06, 3.1176e-06, 2.0234e-06, 3.1287e-06,
         1.3793e-05, 6.2629e-07, 5.7887e-06, 4.2533e-06, 5.7951e-06, 5.5780e-07,
         7.9965e-06, 3.5455e-06])}

In [25]:
# Peek at two bias entries of the discriminator's second conv layer in the
# last model. Index the list explicitly instead of relying on a loop variable
# left over from another cell.
models[-1].state_dict()['discriminator.2.bias'][4:6]

tensor([ 0.0086, -0.0516])

In [32]:
# Scratch check: remainder of 1 divided by 1.
1 % 1

0