# FID

## Importacoes, classes e configuracoes

In [20]:
import torch
import os

In [21]:
batch_size = 50
num_cpus = os.cpu_count()
num_workers = min(8, num_cpus,0)
device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
dims = 2048

In [22]:
# --- Detectar Ambiente (Colab ou Local) ---
IN_COLAB = False
try:
    # Tenta importar um módulo específico do Colab
    from google.colab import drive
    import shutil # Usaremos para copiar, se necessário, mas salvar direto é melhor

    try:
        drive.mount('/content/drive')
        # Crie um diretório específico para salvar os resultados desta execução
        save_base_dir = "/content/drive/MyDrive/GAN_Training_Results" # Ajuste o caminho como desejar
        os.makedirs(save_base_dir, exist_ok=True)
        # Opcional: Crie um subdiretório único para esta execução específica (ex: baseado em timestamp)
        # import datetime
        # timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # save_dir = os.path.join(save_base_dir, f"run_{timestamp}")
        # os.makedirs(save_dir, exist_ok=True)
        # Por simplicidade, vamos usar o diretório base diretamente por enquanto
        save_dir = save_base_dir
        print(f"✅ Google Drive montado. Arquivos serão salvos em: {save_dir}")
    except Exception as e:
        print(f"⚠️ Erro ao montar o Google Drive: {e}")
        print("   Downloads diretos serão tentados, mas podem atrasar.")
        save_dir = "." # Salvar localmente se o Drive falhar
    IN_COLAB = True
    print("✅ Ambiente Google Colab detectado. Downloads automáticos (a cada 2 épocas) ativados.")
except ImportError:
    print("✅ Ambiente local detectado. Downloads automáticos desativados.")


✅ Ambiente local detectado. Downloads automáticos desativados.


In [23]:
import torchvision
from torch.utils.model_zoo import load_url as load_state_dict_from_url
import torch.nn.functional as F

In [24]:
def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`"""
    try:
        version = tuple(map(int, torchvision.__version__.split(".")[:2]))
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    # Skips default weight inititialization if supported by torchvision
    # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
    if version >= (0, 6):
        kwargs["init_weights"] = False

    # Backwards compatibility: `weights` argument was handled by `pretrained`
    # argument prior to version 0.13.
    if version < (0, 13) and "weights" in kwargs:
        if kwargs["weights"] == "DEFAULT":
            kwargs["pretrained"] = True
        elif kwargs["weights"] is None:
            kwargs["pretrained"] = False
        else:
            raise ValueError(
                "weights=={} not supported in torchvision {}".format(
                    kwargs["weights"], torchvision.__version__
                )
            )
        del kwargs["weights"]

    return torchvision.models.inception_v3(*args, **kwargs)


def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url("https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth", progress=True)
    inception.load_state_dict(state_dict)
    return inception


class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation"""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zero's in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

In [25]:
import torch.nn as nn

In [26]:
class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,  # First max pooling features
        192: 1,  # Second max pooling featurs
        768: 2,  # Pre-aux classifier features
        2048: 3,  # Final average pooling features
    }

    def __init__(
        self,
        output_blocks=(DEFAULT_BLOCK_INDEX,),
        resize_input=True,
        normalize_input=True,
        requires_grad=False,
        use_fid_inception=True,
    ):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1)
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly useful
            for finetuning the network
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, "Last possible output block index is 3"

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(weights="DEFAULT")

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.autograd.Variable
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.autograd.Variable, corresponding to the selected output
        block, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp

## Calculo da distribuicao gerada

In [27]:
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]

In [28]:
model = InceptionV3([block_idx]).to(device)

In [29]:
model.eval()

InceptionV3(
  (blocks): ModuleList(
    (0): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): BasicConv2d(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      

In [30]:
import numpy as np

In [31]:
import pathlib

In [32]:
class ImagePathDataset(torch.utils.data.Dataset):
    def __init__(self, files, transforms=None):
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        img = Image.open(path).convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)
        return img

### Por imagens geradas prontas

In [None]:
path = "../imagens geradas/cgan_samples"
path = pathlib.Path(path)
files = sorted(file for file in path.glob("*.png"))

In [None]:
pred_arr = np.empty((len(files), dims))

In [None]:
dataset = ImagePathDataset(files, transforms=torchvision.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

### Por modelo pre-treinado gerando imagens

#### CGAN

In [27]:
class CGAN(nn.Module):
    def __init__(self, dataset="mnist", img_size=28, latent_dim=100):
        super(CGAN, self).__init__()
        if dataset == "mnist":
            self.classes = 10
            self.channels = 1
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        self.label_embedding = nn.Embedding(self.classes, self.classes)
        self.adv_loss = torch.nn.BCELoss()


        self.generator = nn.Sequential(
            *self._create_layer_gen(self.latent_dim + self.classes, 128, False),
            *self._create_layer_gen(128, 256),
            *self._create_layer_gen(256, 512),
            *self._create_layer_gen(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh()
        )

        self.discriminator = nn.Sequential(
            *self._create_layer_disc(self.classes + int(np.prod(self.img_shape)), 1024, False, True),
            *self._create_layer_disc(1024, 512, True, True),
            *self._create_layer_disc(512, 256, True, True),
            *self._create_layer_disc(256, 128, False, False),
            *self._create_layer_disc(128, 1, False, False),
            nn.Sigmoid()
        )

        #self._initialize_weights()

    def _create_layer_gen(self, size_in, size_out, normalize=True):
        layers = [nn.Linear(size_in, size_out)]
        if normalize:
            layers.append(nn.BatchNorm1d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def _create_layer_disc(self, size_in, size_out, drop_out=True, act_func=True):
        layers = [nn.Linear(size_in, size_out)]
        if drop_out:
            layers.append(nn.Dropout(0.4))
        if act_func:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def _initialize_weights(self):
        # Itera sobre todos os módulos da rede geradora
        for m in self.generator:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, input, labels):
        if input.dim() == 2:
            z = torch.cat((self.label_embedding(labels), input), -1)
            x = self.generator(z)
            x = x.view(x.size(0), *self.img_shape) #Em
            return x
        elif input.dim() == 4:
            x = torch.cat((input.view(input.size(0), -1), self.label_embedding(labels)), -1)
            return self.discriminator(x)

    def loss(self, output, label):
        return self.adv_loss(output, label)

In [28]:
cgan = CGAN()
cgan.load_state_dict(torch.load("CGAN_50epochs.pth"))

<All keys matched successfully>

#### F2U_GAN

In [33]:
class F2U_GAN(nn.Module):
    def __init__(self, dataset="mnist", img_size=28, latent_dim=128, condition=True):
        super(F2U_GAN, self).__init__()
        if dataset == "mnist":
            self.classes = 10
            self.channels = 1
        else:
            raise NotImplementedError("Only MNIST is supported")

        self.condition = condition
        self.label_embedding = nn.Embedding(self.classes, self.classes) if condition else None
        #self.label_embedding_disc = nn.Embedding(self.classes, self.img_size*self.img_size) if condition else None
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        self.input_shape_gen = self.latent_dim + self.label_embedding.embedding_dim if condition else self.latent_dim
        self.input_shape_disc = self.channels + self.classes if condition else self.channels

        self.adv_loss = torch.nn.BCEWithLogitsLoss()

        # Generator (unchanged) To calculate output shape of convtranspose layers, we can use the formula:
        # output_shape = (input_shape - 1) * stride - 2 * padding + kernel_size + output_padding (or dilation * (kernel_size - 1) + 1 inplace of kernel_size if using dilation)
        self.generator = nn.Sequential(
            nn.Linear(self.input_shape_gen, 256 * 7 * 7),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # (256,7,7) -> (128,14,14)
            nn.BatchNorm2d(128, momentum=0.1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # (128,14,14) -> (64,28,28)
            nn.BatchNorm2d(64, momentum=0.1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, self.channels, kernel_size=3, stride=1, padding=1), # (64,28,28) -> (1,28,28)
            nn.Tanh()
        )

        # Discriminator (corrected) To calculate output shape of conv layers, we can use the formula:
        # output_shape = ⌊(input_shape - kernel_size + 2 * padding) / stride + 1⌋ (or (dilation * (kernel_size - 1) - 1) inplace of kernel_size if using dilation)
        self.discriminator = nn.Sequential(
        # Camada 1: (1,28,28) -> (32,13,13)
        nn.utils.spectral_norm(nn.Conv2d(self.input_shape_disc, 32, kernel_size=3, stride=2, padding=0)),
        nn.LeakyReLU(0.2, inplace=True),

        # Camada 2: (32,14,14) -> (64,7,7)
        nn.utils.spectral_norm(nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)),
        nn.LeakyReLU(0.2, inplace=True),

        # Camada 3: (64,7,7) -> (128,3,3)
        nn.utils.spectral_norm(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0)),
        nn.LeakyReLU(0.2, inplace=True),

        # Camada 4: (128,3,3) -> (256,1,1)
        nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=0)),  # Padding 0 aqui!
        nn.LeakyReLU(0.2, inplace=True),

        # Achata e concatena com as labels
        nn.Flatten(), # (256,1,1) -> (256*1*1,)
        nn.utils.spectral_norm(nn.Linear(256 * 1 * 1, 1))  # 256 (features)
        )

    def forward(self, input, labels=None):
        if input.dim() == 2:
            # Generator forward pass (unchanged)
            if self.condition:
                embedded_labels = self.label_embedding(labels)
                gen_input = torch.cat((input, embedded_labels), dim=1)
                x = self.generator(gen_input)
            else:
                x = self.generator(input)
            return x.view(-1, *self.img_shape)

        elif input.dim() == 4:
            # Discriminator forward pass
            if self.condition:
                embedded_labels = self.label_embedding(labels)
                image_labels = embedded_labels.view(embedded_labels.size(0), self.label_embedding.embedding_dim, 1, 1).expand(-1, -1, self.img_size, self.img_size)
                x = torch.cat((input, image_labels), dim=1)
            else:
                x = input
            return self.discriminator(x)

    def loss(self, output, label):
        return self.adv_loss(output, label)

In [35]:
f2u_gan = F2U_GAN(condition=True)
f2u_gan.load_state_dict(torch.load("gen_round54.pt"))

<All keys matched successfully>

### Geração das imagens

In [None]:
if IN_COLAB:
  !pip install datasets

In [36]:
from torchvision.transforms.functional import to_pil_image
from datasets import Features, ClassLabel
from datasets import Dataset as hf_dataset
from datasets import Image as IMG

#### Condicional

In [56]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

class GeneratedDataset(Dataset):
    """
    PyTorch Dataset that generates images on the fly from a GAN generator.

    Args:
        generator (torch.nn.Module): GAN generator model.
        num_samples (int): Number of samples per class (if balanced) or total samples (if random).
        latent_dim (int): Dimensionality of the latent vector z.
        num_classes (int): Number of classes.
        device (torch.device or str): Device to run generation on.
        balanced (bool): If True, generates num_samples per class (balanced dataset).
                         If False, generates num_samples samples with random class labels.
    """
    def __init__(self, generator, num_samples, latent_dim, num_classes, device, balanced=True):
        self.generator = generator
        self.num_samples = num_samples
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.device = device
        self.balanced = balanced
        self.model_name = type(self.generator).__name__
        self.classes = [i for i in range(self.num_classes)]

        # Generate data once at initialization
        if self.balanced:
            # Balanced: num_samples per class
            self.images = self._generate_balanced()
            self.total_len = num_samples * num_classes
        else:
            # Random: total num_samples images with random labels
            self.images = self._generate_random()
            self.total_len = num_samples

    def _generate_balanced(self):
        """
        Generate num_samples images for each class.
        Returns a dict mapping class_idx to tensor of images.
        """
        self.generator.eval()
        gen_imgs = {}
        # Prepare one-hot labels if needed
        for class_idx in range(self.num_classes):
            # Create labels tensor
            labels = torch.full((self.num_samples,), class_idx, dtype=torch.long, device=self.device)

            # Prepare latent vectors
            z = torch.randn(self.num_samples, self.latent_dim, device=self.device)

            with torch.no_grad():
                if self.model_name == 'Generator':
                    # One-hot encoding
                    labels_one_hot = F.one_hot(labels, num_classes=self.num_classes).float().to(self.device)
                    gen = self.generator(torch.cat([z, labels_one_hot], dim=1))
                else:
                    gen = self.generator(z, labels)

            gen_imgs[class_idx] = gen

        return gen_imgs

    def _generate_random(self):
        """
        Generate num_samples images with random class labels.
        Returns a list of generated images.
        """
        self.generator.eval()
        images = []

        # Sample random class labels
        labels = torch.randint(0, self.num_classes, (self.num_samples,), device=self.device)
        z = torch.randn(self.num_samples, self.latent_dim, device=self.device)

        with torch.no_grad():
            if self.model_name == 'Generator':
                labels_one_hot = F.one_hot(labels, num_classes=self.num_classes).float().to(self.device)
                gen = self.generator(torch.cat([z, labels_one_hot], dim=1))
            else:
                gen = self.generator(z, labels)

        # gen shape: [num_samples, channels, height, width]
        # Split into list
        for i in range(self.num_samples):
            images.append(gen[i])

        return images

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        """
        Returns only the image at index idx.
        """
        if self.balanced:
            class_idx = idx // self.num_samples
            sample_idx = idx % self.num_samples
            return self.images[class_idx][sample_idx]
        else:
            return self.images[idx]


In [60]:
# Parameters
num_samples = 1000  # Número de amostras por classe
latent_dim = 128

# Create the dataset and dataloader
generated_dataset = GeneratedDataset(generator=f2u_gan, num_samples=num_samples, latent_dim=latent_dim, num_classes=10, device="cpu", balanced=True)
gen_dataset = generated_dataset.images

In [61]:
for c in gen_dataset.keys():
    gen_dataset[c] = (gen_dataset[c] + 1) / 2
    gen_dataset[c] = gen_dataset[c].repeat(1, 3, 1, 1)
# # Ajustar para o intervalo [0, 1]
# gen_dataset = (gen_dataset + 1) / 2
# Expandir o canal para RGB (replicando o canal 1 para 3)
# gen_dataset = gen_dataset.repeat(1, 3, 1, 1)  # Agora tem shape [2050, 3, 28, 28]

In [63]:
dataloaders = [torch.utils.data.DataLoader(gen_dataset[c], batch_size=batch_size, num_workers=num_workers, shuffle=False) for c in range(10)]

#### Não condicional

In [None]:
# Parameters
num_samples = 10000  # Número de amostras totais
latent_dim = 128

# Create the dataset and dataloader
generated_dataset_unc = GeneratedDataset(generator=f2u_gan, num_samples=num_samples, latent_dim=latent_dim, num_classes=10, device="cpu", balanced=False)
gen_dataset_unc = generated_dataset_unc.images

In [39]:
class UnconditionalGeneratedDataset(Dataset):
    def __init__(self,
                 generator,
                 num_samples,
                 latent_dim=128,
                 device="cpu",
                 image_col_name="image"):
        """
        Generates a dataset using an unconditional generative model.

        Args:
            generator: The pre-trained unconditional generative model.
            num_samples (int): Total number of images to generate.
            latent_dim (int): Dimension of the latent space vector (z).
            device (str): Device to run generation on ('cpu' or 'cuda').
            image_col_name (str): Name for the image column in the output dictionary.
        """
        self.generator = generator
        self.num_samples = num_samples
        self.latent_dim = latent_dim
        self.device = device
        self.image_col_name = image_col_name

        if self.num_samples < 0:
            raise ValueError("num_samples must be non-negative")
        elif self.num_samples == 0:
            print("Warning: num_samples is 0. Dataset will be empty.")
            self.images = torch.empty(0)
        else:
            self.images = self._generate_images()

    def _generate_images(self):
        self.generator.eval()
        self.generator.to(self.device)

        # Create latent noise
        z = torch.randn(self.num_samples, self.latent_dim, device=self.device)

        # Generate images in batches
        generated_images = []
        batch_size = min(1024, self.num_samples)
        with torch.no_grad():
            for i in range(0, self.num_samples, batch_size):
                z_batch = z[i : min(i + batch_size, self.num_samples)]
                gen_imgs = self.generator(z_batch)
                generated_images.append(gen_imgs.cpu())

        self.generator.cpu()
        return torch.cat(generated_images, dim=0)

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError("Dataset index out of range")
        return { self.image_col_name: self.images[idx] }


In [89]:
# Parameters
num_samples = 100
latent_dim = 128

# Create the dataset and dataloader
generated_dataset_unc = UnconditionalGeneratedDataset(generator=f2u_gan, num_samples=num_samples, latent_dim=latent_dim, device="cpu")

In [90]:
gen_dataset_unc = generated_dataset_unc.images
N, _, H, W = gen_dataset_unc.shape
out = torch.empty((N, 3, H, W), device=gen_dataset_unc.device)

for i in range(N):
    img = gen_dataset_unc[i]           # shape [1, H, W]
    img = (img + 1) / 2            # normalize to [0,1]
    out[i] = img.repeat(3, 1, 1)   # shape [3, H, W]

gen_dataset_rgb = out


In [91]:
dataloader = torch.utils.data.DataLoader(gen_dataset_rgb, batch_size=batch_size, num_workers=num_workers, shuffle=False)

### Calculo

In [65]:
from tqdm import tqdm
from PIL import Image

#### Cond

In [66]:
mus_gen = []
sigmas_gen = []

In [94]:
for c in range(10):
  pred_arr = np.empty((len(gen_dataset[c]), dims))
  start_idx = 0
  for batch in tqdm(dataloaders[c]):
          batch = batch.to(device)

          with torch.no_grad():
              pred = model(batch)[0]

          # If model output is not scalar, apply global spatial average pooling.
          # This happens if you choose a dimensionality not equal 2048.
          if pred.size(2) != 1 or pred.size(3) != 1:
              pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

          pred = pred.squeeze(3).squeeze(2).cpu().numpy()

          pred_arr[start_idx : start_idx + pred.shape[0]] = pred

          start_idx = start_idx + pred.shape[0]
  mus_gen.append(np.mean(pred_arr, axis=0))
  sigmas_gen.append(np.cov(pred_arr, rowvar=False))

100%|██████████| 2/2 [00:01<00:00,  1.33it/s]
100%|██████████| 2/2 [00:00<00:00,  4.85it/s]
100%|██████████| 2/2 [00:00<00:00,  4.69it/s]
100%|██████████| 2/2 [00:00<00:00,  4.74it/s]
100%|██████████| 2/2 [00:00<00:00,  4.67it/s]
100%|██████████| 2/2 [00:00<00:00,  4.62it/s]
100%|██████████| 2/2 [00:00<00:00,  4.65it/s]
100%|██████████| 2/2 [00:00<00:00,  4.73it/s]
100%|██████████| 2/2 [00:00<00:00,  4.73it/s]
100%|██████████| 2/2 [00:00<00:00,  4.70it/s]


#### Uncond

In [96]:
pred_arr = np.empty((len(gen_dataset_rgb), dims))
start_idx = 0
for batch in tqdm(dataloader):
      batch = batch.to(device)

      with torch.no_grad():
          pred = model(batch)[0]

      # If model output is not scalar, apply global spatial average pooling.
      # This happens if you choose a dimensionality not equal 2048.
      if pred.size(2) != 1 or pred.size(3) != 1:
          pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

      pred = pred.squeeze(3).squeeze(2).cpu().numpy()

      pred_arr[start_idx : start_idx + pred.shape[0]] = pred

      start_idx = start_idx + pred.shape[0]
mu_gen = np.mean(pred_arr, axis=0)
sigma_gen = np.cov(pred_arr, rowvar=False)

100%|██████████| 2/2 [00:00<00:00,  4.18it/s]


## Calculo da distribuicao real

In [102]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [103]:
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load the training and test datasets
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
trainset_reduzido = torch.utils.data.random_split(trainset, [1000, len(trainset) - 1000])[0]
# Create data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
trainloader_reduzido = torch.utils.data.DataLoader(trainset_reduzido, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=128)

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 495kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.55MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.2MB/s]


### Pegando imagens sem salvar

#### Cond

In [105]:
def select_samples_per_class(dataset, num_samples):
    """
    Selects a specified number of samples per class from the dataset and returns them as tensors.

    Parameters:
    dataset (torch.utils.data.Dataset): The dataset to select samples from.
    num_samples (int): The number of samples to select per class.

    Returns:
    dict: A dictionary where each key corresponds to a class and the value is a tensor of shape [num_samples, 1, 28, 28].
    """
    class_samples = {i: [] for i in range(len(dataset.classes))}
    class_counts = {i: 0 for i in range(len(dataset.classes))}

    for img, label in dataset:
        if class_counts[label] < num_samples:
            class_samples[label].append(img)
            class_counts[label] += 1
        if all(count >= num_samples for count in class_counts.values()):
            break
    else:
        print("Warning: Not all classes have the requested number of samples.")

    # Convert lists of tensors to a single tensor per class
    for label in class_samples:
        if class_samples[label]:  # Check if the list is not empty
            class_samples[label] = torch.stack(class_samples[label], dim=0)
            class_samples[label] = (class_samples[label] + 1) / 2
            class_samples[label] = class_samples[label].repeat(1, 3, 1, 1)
        else:
            # Handle empty classes if necessary; here we leave an empty tensor
            class_samples[label] = torch.Tensor()

    return class_samples

In [107]:
img_reais = select_samples_per_class(testset, 100)

In [108]:
dataloaders = [torch.utils.data.DataLoader(img_reais[c], batch_size=batch_size, num_workers=num_workers, shuffle=False) for c in range(10)]

#### Uncond

In [120]:
def select_random_samples(dataset, num_samples):
    """
    Randomly selects a specified number of samples from a dataset,
    normalizes to [0,1], and converts single-channel to RGB.

    Args:
        dataset (torch.utils.data.Dataset): The source dataset with items as dicts or tensors.
        num_samples (int): Total number of samples to select.

    Returns:
        torch.Tensor: Tensor of shape [num_samples, 3, H, W], values in [0,1].
    """
    import random
    total = len(dataset)
    indices = random.sample(range(total), min(num_samples, total))

    samples = []
    for idx in indices:
        item = dataset[idx][0] # dataset elements are tuples (img, label)
        # item might be dict or tensor
        img = item['image'] if isinstance(item, dict) else item
        # Normalize from [-1,1] to [0,1]
        img = (img + 1) / 2
        # Expand to 3 channels
        if img.ndim == 3 and img.shape[0] == 1:
            img = img.repeat(3, 1, 1)
        elif img.ndim == 2:
            img = img.unsqueeze(0).repeat(3, 1, 1)
        samples.append(img)

    return torch.stack(samples, dim=0)

In [121]:
img_reais_rand = select_random_samples(testset, 100)

In [122]:
dataloader = torch.utils.data.DataLoader(img_reais_rand, batch_size=batch_size, num_workers=num_workers, shuffle=False)

### Salvando imagens

In [None]:
import random

In [None]:
# Function to save a random sample of images
def save_random_samples(dataset, num_samples=10, folder='Imagens Testes/mnist_samples', balanced=False, classes=None):
    if not os.path.exists(folder):
        os.makedirs(folder)

    if classes is None:
        classes = [int(c.split()[0]) for c in dataset.classes]  # Use all classes if none are specified

    if balanced:
        # Get the number of classes
        num_classes = len(classes)
        samples_per_class = -(-num_samples // num_classes)  # Round up division
        indices = []
        class_counts = {i: 0 for i in classes}

        # Shuffle the dataset
        shuffled_indices = list(range(len(dataset)))
        random.shuffle(shuffled_indices)

        for idx in shuffled_indices:
            img = dataset[idx][0]
            label = int(dataset[idx][1])
            if label in classes and class_counts[label] < samples_per_class:
                indices.append(idx)
                class_counts[label] += 1
            if len(indices) >= num_samples:
                break
    else:
        indices = []
        while len(indices) < num_samples:
            idx = random.randint(0, len(dataset) - 1)
            if int(dataset[idx][1]) in classes:
                indices.append(idx)
    for i, idx in enumerate(indices):
        img, label = dataset[idx]
        img = (img * 0.5 + 0.5) * 255  # Denormalize the image
        img = img.byte().numpy().transpose(1, 2, 0).squeeze()  # Convert to numpy array
        img = Image.fromarray(img)
        img.save(os.path.join(folder, f'mnist_sample_{i}_label_{label}.png'))

In [None]:
for i in range(10):
  save_random_samples(trainset, num_samples=2050, folder=f'Imagens Testes/mnist_samples_{i}', balanced=True, classes=[i])

In [None]:
pathes = [f"Imagens Testes/mnist_samples_{i}" for i in range(10)]
pathes = [pathlib.Path(path) for path in pathes]
files = [sorted(file for file in path.glob("*.png")) for path in pathes]

In [None]:
datasets = [ImagePathDataset(file, transforms=torchvision.transforms.ToTensor()) for file in files]
dataloaders = [torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) for dataset in datasets]

In [None]:
import time

### Calculo

#### Cond

In [113]:
mus_real = []
sigmas_real = []
for c in range(10):
  model = InceptionV3([block_idx]).to(device)
  model.eval()
  pred_arr = np.empty((len(img_reais[c]), dims))
  start_idx = 0
  for batch in tqdm(dataloaders[c]):
          batch = batch.to(device)

          with torch.no_grad():
              pred = model(batch)[0]

          # If model output is not scalar, apply global spatial average pooling.
          # This happens if you choose a dimensionality not equal 2048.
          if pred.size(2) != 1 or pred.size(3) != 1:
              pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

          pred = pred.squeeze(3).squeeze(2).cpu().numpy()

          pred_arr[start_idx : start_idx + pred.shape[0]] = pred

          start_idx = start_idx + pred.shape[0]
  mus_real.append(np.mean(pred_arr, axis=0))
  sigmas_real.append(np.cov(pred_arr, rowvar=False))

100%|██████████| 2/2 [00:00<00:00,  4.31it/s]
100%|██████████| 2/2 [00:00<00:00,  4.78it/s]
100%|██████████| 2/2 [00:00<00:00,  4.73it/s]
100%|██████████| 2/2 [00:00<00:00,  4.81it/s]
100%|██████████| 2/2 [00:00<00:00,  4.77it/s]
100%|██████████| 2/2 [00:00<00:00,  4.74it/s]
100%|██████████| 2/2 [00:00<00:00,  4.68it/s]
100%|██████████| 2/2 [00:00<00:00,  4.72it/s]
100%|██████████| 2/2 [00:00<00:00,  4.69it/s]
100%|██████████| 2/2 [00:00<00:00,  4.68it/s]


#### Uncond

In [123]:
model = InceptionV3([block_idx]).to(device)
model.eval()
pred_arr = np.empty((len(img_reais_rand), dims))
start_idx = 0
for batch in tqdm(dataloader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        pred_arr[start_idx : start_idx + pred.shape[0]] = pred

        start_idx = start_idx + pred.shape[0]
mu_real = np.mean(pred_arr, axis=0)
sigma_real = np.cov(pred_arr, rowvar=False)

100%|██████████| 2/2 [00:00<00:00,  4.15it/s]


## Calculo FID

In [128]:
from scipy import linalg

In [129]:
mus_gen = [np.atleast_1d(mu_gen) for mu_gen in mus_gen]
mus_real = [np.atleast_1d(mu_real) for mu_real in mus_real]

sigmas_gen = [np.atleast_2d(sigma_gen) for sigma_gen in sigmas_gen]
sigmas_real = [np.atleast_2d(sigma_real) for sigma_real in sigmas_real]

for mu_gen, mu_real, sigma_gen, sigma_real in zip(mus_gen, mus_real, sigmas_gen, sigmas_real):
  assert (
      mu_gen.shape == mu_real.shape
  ), "Training and test mean vectors have different lengths"
  assert (
      sigma_gen.shape == sigma_real.shape
  ), "Training and test covariances have different dimensions"

diffs = [mu_gen - mu_real for mu_gen, mu_real in zip(mus_gen, mus_real)]

# Product might be almost singular
covmeans = [linalg.sqrtm(sigmas_gen.dot(sigmas_real), disp=False)[0] for sigmas_gen, sigmas_real in zip(sigmas_gen, sigmas_real)]
for covmean, sigma_gen, sigma_real in zip(covmeans, sigmas_gen, sigmas_real):
  if not np.isfinite(covmean).all():
    msg = (
        "fid calculation produces singular product; "
        "adding %s to diagonal of cov estimates"
    ) % 1e-6
    print(msg)
    offset = np.eye(sigma_gen.shape[0]) * 1e-6
    covmean = linalg.sqrtm((sigma_gen + offset).dot(sigma_real + offset))

# Numerical error might give slight imaginary component
for i, covmean in enumerate(covmeans):
  if np.iscomplexobj(covmean):
      if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
          m = np.max(np.abs(covmean.imag))
          raise ValueError("Imaginary component {}".format(m))
      covmeans[i] = covmean.real

tr_covmeans = [np.trace(covmean) for covmean in covmeans]

In [130]:
mu_gen = np.atleast_1d(mu_gen)
mu_real = np.atleast_1d(mu_real)

sigma_gen = np.atleast_2d(sigma_gen)
sigma_real = np.atleast_2d(sigma_real)

assert (
    mu_gen.shape == mu_real.shape
), "Training and test mean vectors have different lengths"
assert (
    sigma_gen.shape == sigma_real.shape
), "Training and test covariances have different dimensions"

diff = mu_gen - mu_real

# Product might be almost singular
covmean = linalg.sqrtm(sigma_gen.dot(sigma_real), disp=False)[0]

if not np.isfinite(covmean).all():
  msg = (
      "fid calculation produces singular product; "
      "adding %s to diagonal of cov estimates"
  ) % 1e-6
  print(msg)
  offset = np.eye(sigma_gen.shape[0]) * 1e-6
  covmean = linalg.sqrtm((sigma_gen + offset).dot(sigma_real + offset))

# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
    if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
        m = np.max(np.abs(covmean.imag))
        raise ValueError("Imaginary component {}".format(m))
    covmeans[i] = covmean.real

tr_covmean = np.trace(covmean)

In [131]:
fids = [diff.dot(diff) + np.trace(sigma_gen) + np.trace(sigma_real) - 2 * tr_covmean for diff, sigma_gen, sigma_real, tr_covmean in zip(diffs, sigmas_gen, sigmas_real, tr_covmeans)]

In [132]:
fid = diff.dot(diff) + np.trace(sigma_gen) + np.trace(sigma_real) - 2 * tr_covmean

In [133]:
fids

[np.float64(60.021444565132754),
 np.float64(48.42313957263667),
 np.float64(121.37066199800525),
 np.float64(51.123460579666784),
 np.float64(74.737113709185),
 np.float64(67.69499682509895),
 np.float64(68.71997970342036),
 np.float64(89.68777812501884),
 np.float64(70.17875276838706),
 np.float64(53.19441100532528)]

In [134]:
fid

np.complex128(53.19441100532529-1.1607598735405455e-06j)