Playground

In [1]:
import torch
import torch.nn as nn

In [2]:
class F2U_GAN(nn.Module):
    """Generator and discriminator of a (conditional) DCGAN packed into one module.

    ``forward`` serves both networks and dispatches on the rank of ``input``:

    * 2-D ``(batch, latent_dim)`` latent codes -> generator -> images of shape
      ``(batch, channels, img_size, img_size)`` in ``[-1, 1]`` (Tanh output).
    * 4-D ``(batch, channels, img_size, img_size)`` images -> discriminator ->
      raw logits of shape ``(batch, 1)`` (no sigmoid; pair with ``self.loss``,
      which is ``BCEWithLogitsLoss``).

    When ``condition`` is True, both networks are conditioned on integer class
    labels through one shared ``nn.Embedding(classes, classes)``.

    Args:
        dataset: Only ``"mnist"`` is supported (10 classes, 1 channel).
        img_size: Side length of the square images. The layer stack is built
            for 28x28 only.
        latent_dim: Size of the generator's latent vector.
        condition: Whether to condition both networks on class labels.

    Raises:
        NotImplementedError: For an unsupported ``dataset`` or ``img_size``.
    """

    def __init__(self, dataset="mnist", img_size=28, latent_dim=128, condition=True):
        super().__init__()
        if dataset == "mnist":
            self.classes = 10
            self.channels = 1
        else:
            raise NotImplementedError("Only MNIST is supported")
        # The generator upsamples 7 -> 14 -> 28 and the discriminator reduces
        # 28 -> 13 -> 7 -> 3 -> 1, so any other size yields mis-shaped tensors
        # (or a silently wrong batch size in the final ``view``).
        if img_size != 28:
            raise NotImplementedError("Only 28x28 images are supported")

        self.condition = condition
        # NOTE: submodule attribute names and layer order define the state_dict
        # keys of saved checkpoints; keep them stable.
        self.label_embedding = nn.Embedding(self.classes, self.classes) if condition else None
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        # Generator input width: latent vector, plus the label embedding when conditioned.
        self.input_shape_gen = self.latent_dim + self.label_embedding.embedding_dim if condition else self.latent_dim
        # Discriminator input channels: image channels, plus one broadcast label
        # channel per embedding dimension when conditioned.
        self.input_shape_disc = self.channels + self.classes if condition else self.channels

        self.adv_loss = torch.nn.BCEWithLogitsLoss()

        # Generator. ConvTranspose2d output size (PyTorch docs):
        #   out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
        self.generator = nn.Sequential(
            nn.Linear(self.input_shape_gen, 256 * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # (256,7,7) -> (128,14,14)
            nn.BatchNorm2d(128, momentum=0.1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # (128,14,14) -> (64,28,28)
            nn.BatchNorm2d(64, momentum=0.1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, self.channels, kernel_size=3, stride=1, padding=1), # (64,28,28) -> (channels,28,28)
            nn.Tanh()
        )

        # Discriminator. Conv2d output size (PyTorch docs):
        #   out = floor((in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
        # Spectral normalization on every weight layer constrains the critic's Lipschitz constant.
        self.discriminator = nn.Sequential(
            # Layer 1: (input_shape_disc,28,28) -> (32,13,13)
            nn.utils.spectral_norm(nn.Conv2d(self.input_shape_disc, 32, kernel_size=3, stride=2, padding=0)),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 2: (32,13,13) -> (64,7,7)
            nn.utils.spectral_norm(nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 3: (64,7,7) -> (128,3,3)
            nn.utils.spectral_norm(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0)),
            nn.LeakyReLU(0.2, inplace=True),

            # Layer 4: (128,3,3) -> (256,1,1); padding=0 is what collapses the map to 1x1.
            nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=0)),
            nn.LeakyReLU(0.2, inplace=True),

            # Flatten the (256,1,1) features and map them to a single raw logit.
            # (Labels were already concatenated as input channels in forward.)
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(256 * 1 * 1, 1)),
        )

    def _embed_labels(self, labels):
        """Return the label embedding of ``labels``; requires ``condition=True``.

        Raises:
            ValueError: If ``labels`` is None.
        """
        if labels is None:
            raise ValueError("labels are required when condition=True")
        return self.label_embedding(labels)

    def forward(self, input, labels=None):
        """Run the generator (2-D input) or the discriminator (4-D input).

        Args:
            input: Latent codes ``(batch, latent_dim)`` or images
                ``(batch, channels, img_size, img_size)``.
            labels: Integer class labels ``(batch,)``; required when conditioned.

        Returns:
            Generated images for 2-D input, raw logits ``(batch, 1)`` for 4-D input.

        Raises:
            ValueError: If ``input`` is neither 2-D nor 4-D, or labels are missing
                for a conditional model.
        """
        if input.dim() == 2:
            # Generator: append the label embedding to the latent vector.
            if self.condition:
                x = torch.cat((input, self._embed_labels(labels)), dim=1)
            else:
                x = input
            return self.generator(x).view(-1, *self.img_shape)

        if input.dim() == 4:
            # Discriminator: broadcast each embedding dimension to a full
            # img_size x img_size plane and stack it onto the image channels.
            if self.condition:
                embedded_labels = self._embed_labels(labels)
                image_labels = embedded_labels.view(
                    embedded_labels.size(0), self.label_embedding.embedding_dim, 1, 1
                ).expand(-1, -1, self.img_size, self.img_size)
                x = torch.cat((input, image_labels), dim=1)
            else:
                x = input
            return self.discriminator(x)

        raise ValueError(
            f"Expected a 2-D latent batch or a 4-D image batch, got a {input.dim()}-D tensor"
        )

    def loss(self, output, label):
        """Binary cross-entropy on raw discriminator logits."""
        return self.adv_loss(output, label)

In [None]:
# Load the trained generator checkpoint (a state_dict) onto the CPU.
model_path = "gen_round54.pt"
# The checkpoint contains the label embedding, so the model must be built conditional.
model = F2U_GAN(condition=True)
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine;
# weights_only restricts unpickling to tensors/containers, which is all a state_dict needs.
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
model.eval()

In [4]:
def generate_plot(net, device, round_number, client_id = None, examples_per_class: int=5, classes: int=10, latent_dim: int=100, server: bool=False, save_dir=None):
    """Generate a grid of sample images (one row per class) and save it as a PNG.

    Args:
        net: Model to sample from. A model whose class is named exactly
            ``"Generator"`` is called with ``cat([z, one_hot(label, 10)])``; any
            other model is called as ``net(z, labels)``. ``net`` is moved to
            ``device`` and switched to eval mode in place.
        device: Device to sample on.
        round_number: Round number shown in the title and used in the file name.
        client_id: If an ``int``, the plot is titled and saved as a client plot;
            otherwise as the server plot.
        examples_per_class: Number of columns (samples per class).
        classes: Number of rows, one per label ``0 .. classes-1``.
        latent_dim: Size of the latent vector fed to ``net``.
        server: If True, switch matplotlib to the non-interactive ``Agg`` backend.
        save_dir: Directory to write the PNG to. When omitted, inside Colab a
            global ``save_dir`` is used (legacy behaviour); elsewhere the file
            goes to the current working directory.

    Raises:
        ValueError: In Colab when no ``save_dir`` is given and no global
            ``save_dir`` is defined.
    """
    import os

    if server:
        import matplotlib
        matplotlib.use("Agg")  # headless backend for machines without a display
    import matplotlib.pyplot as plt

    net_type = type(net).__name__
    is_client = isinstance(client_id, int)

    # Resolve the output path first so a misconfiguration fails before any work.
    try:
        import google.colab  # noqa: F401 -- importable only inside Colab
        in_colab = True
    except ImportError:
        in_colab = False
    if save_dir is None and in_colab:
        save_dir = globals().get("save_dir")
        if save_dir is None:
            raise ValueError("save_dir must be provided when running in Colab")
    if is_client:
        filename = f"mnist_{net_type}_r{round_number}_c{client_id}.png"
    else:
        filename = f"mnist_{net_type}_r{round_number}.png"
    out_path = os.path.join(save_dir, filename) if save_dir is not None else filename

    net.to(device)
    net.eval()
    batch_size = examples_per_class * classes

    latent_vectors = torch.randn(batch_size, latent_dim, device=device)
    # Labels ordered row by row: [0]*examples_per_class + [1]*examples_per_class + ...
    labels = torch.arange(classes, device=device).repeat_interleave(examples_per_class)

    with torch.no_grad():
        if net_type == "Generator":
            # NOTE(review): one-hot width is fixed at 10, presumably the model's
            # label space, even when fewer classes are plotted.
            labels_one_hot = torch.nn.functional.one_hot(labels, 10).float()
            generated_images = net(torch.cat([latent_vectors, labels_one_hot], dim=1))
        else:
            generated_images = net(latent_vectors, labels)
    # matplotlib can only draw host-memory data.
    generated_images = generated_images.cpu()

    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(classes, examples_per_class, figsize=(5, 9), squeeze=False)
    try:
        if is_client:
            title = f"Round: {round_number} | Client: {client_id}"
        else:
            title = f"Round: {round_number}"
        fig.text(0.5, 0.98, title, ha="center", fontsize=12)

        for i, ax in enumerate(axes.flat):
            ax.imshow(generated_images[i, 0, :, :], cmap='gray')
            ax.set_xticks([])
            ax.set_yticks([])

        # Leave a left margin for the row labels and a top margin for the title.
        plt.tight_layout(rect=[0.05, 0, 1, 0.96])

        # Render once so subplot positions are final, then place one class label
        # next to the vertical centre of each row.
        fig.canvas.draw()
        renderer = fig.canvas.get_renderer()
        to_figure = fig.transFigure.inverted()
        for row in range(classes):
            bbox = axes[row, 0].get_window_extent(renderer)
            pos = to_figure.transform([(bbox.x0, bbox.y0), (bbox.x1, bbox.y1)])
            center_y = (pos[0, 1] + pos[1, 1]) / 2
            fig.text(0.04, center_y, str(row), va='center', fontsize=12, color='black')

        fig.savefig(out_path)
        if in_colab:
            print("Imagem do cliente salva no drive" if is_client else "Imagem do servidor salva no drive")
        else:
            print("Imagem do cliente salva" if is_client else "Imagem do servidor salva")
    finally:
        plt.close(fig)

In [5]:
# Plot 5 samples per class for round 54; read the latent size and class count
# from the model so they cannot drift from its architecture.
generate_plot(model, torch.device("cpu"), 54, server=True, examples_per_class=5, classes=model.classes, latent_dim=model.latent_dim)

Imagem do servidor salva


In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

class GeneratedDataset(Dataset):
    """
    PyTorch Dataset of images sampled from a (conditional) GAN generator.

    All images are generated once, eagerly, in ``__init__`` and kept on
    ``device``; indexing afterwards is a plain lookup.

    Args:
        generator (torch.nn.Module): GAN generator model. A model whose class is
            named exactly ``"Generator"`` is called with ``cat([z, one_hot(label)])``;
            any other model is called as ``generator(z, labels)``. Its
            train/eval mode is restored after generation.
        num_samples (int): Number of samples per class (if balanced) or total samples (if random).
        latent_dim (int): Dimensionality of the latent vector z.
        num_classes (int): Number of classes.
        device (torch.device or str): Device to run generation on.
        balanced (bool): If True, generates num_samples per class (balanced dataset).
                         If False, generates num_samples samples with random class labels.

    ``__getitem__`` returns only the image tensor, not its label. In balanced
    mode the class of index ``idx`` is ``idx // num_samples``.
    """
    def __init__(self, generator, num_samples, latent_dim, num_classes, device, balanced=True):
        self.generator = generator
        self.num_samples = num_samples
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.device = device
        self.balanced = balanced
        self.model_name = type(self.generator).__name__
        self.classes = list(range(self.num_classes))

        if self.balanced:
            # dict: class_idx -> tensor of num_samples images
            self.images = self._generate_balanced()
            self.total_len = num_samples * num_classes
        else:
            # list of num_samples individual image tensors
            self.images = self._generate_random()
            self.total_len = num_samples

    def _generate(self, z, labels):
        """Run the generator on latent vectors ``z`` conditioned on ``labels``.

        The generator is put in eval mode for sampling and its previous
        train/eval mode is restored afterwards, so the caller's model is not
        left silently changed.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                if self.model_name == 'Generator':
                    labels_one_hot = F.one_hot(labels, num_classes=self.num_classes).float()
                    return self.generator(torch.cat([z, labels_one_hot], dim=1))
                return self.generator(z, labels)
        finally:
            self.generator.train(was_training)

    def _generate_balanced(self):
        """
        Generate num_samples images for each class.
        Returns a dict mapping class_idx to a tensor of images.
        """
        gen_imgs = {}
        for class_idx in range(self.num_classes):
            labels = torch.full((self.num_samples,), class_idx, dtype=torch.long, device=self.device)
            z = torch.randn(self.num_samples, self.latent_dim, device=self.device)
            gen_imgs[class_idx] = self._generate(z, labels)
        return gen_imgs

    def _generate_random(self):
        """
        Generate num_samples images with uniformly random class labels.
        Returns a list of generated images (views into one batch tensor).
        """
        labels = torch.randint(0, self.num_classes, (self.num_samples,), device=self.device)
        z = torch.randn(self.num_samples, self.latent_dim, device=self.device)
        gen = self._generate(z, labels)
        # Split the batch along dim 0 into per-sample tensors.
        return list(gen.unbind(0))

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        """
        Return only the image at index ``idx`` (negative indices count from the end).

        Raises:
            IndexError: If ``idx`` is out of range. This also lets plain
                iteration (``for img in dataset``) stop correctly.
        """
        if not self.balanced:
            return self.images[idx]
        position = idx + self.total_len if idx < 0 else idx
        if not 0 <= position < self.total_len:
            raise IndexError(f"index {idx} is out of range for dataset of size {self.total_len}")
        class_idx = position // self.num_samples
        sample_idx = position % self.num_samples
        return self.images[class_idx][sample_idx]
