In [1]:
%%capture
!pip install torch
!pip install torchvision
!pip install torchmetrics pytorch-fid
!pip install torchmetrics[image]
!pip install torch-fidelity

In [2]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch import nn
from tqdm import tqdm
from torchsummary import summary
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision import models


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# transform = transforms.Compose([
#     transforms.ToTensor()
# ])
training_data = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 608kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 3.49MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 10.9MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [6]:
test_data[10]

(tensor([[[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.000

In [7]:
print("Number train set: ", len(training_data))
print("Number test set: ", len(test_data))

Number train set:  60000
Number test set:  10000


In [8]:
# VAE definition
class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 400),
            nn.ReLU(),
            nn.Linear(400, 2 * latent_dim)  # Mean and log-variance
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 28 * 28),
            nn.Sigmoid()
        )
        self.latent_dim = latent_dim

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        # Encode
        q = self.encoder(x)
        mean, logvar = torch.chunk(q, chunks=2, dim=1)
        z = self.reparameterize(mean, logvar)
        # Decode
        x_recon = self.decoder(z)
        return x_recon, mean, logvar

In [9]:
# Loss function
def vae_loss(recon_x, x, mean, logvar, beta=1.0):
    # Ensure recon_x is in [0, 1]
    recon_x = torch.sigmoid(recon_x)

    # Normalize x to [0, 1]
    x = (x - x.min()) / (x.max() - x.min())

    # Compute reconstruction loss
    recon_loss = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')

    # Compute KL divergence
    kl_div = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

    return recon_loss + beta * kl_div


latent_dim = 20
vae = VAE(latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=0.0002)

In [10]:
# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    epoch_progress = tqdm(train_dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)
    total_loss = 0
    beta = min(1.0, epoch / 10)  # Warm-up cho KL divergence

    for batch_idx, (x, _) in enumerate(epoch_progress):
        x = x.to(device)
        x = x.view(x.size(0), -1)  # Flatten
        optimizer.zero_grad()
        x_recon, mean, logvar = vae(x)
        loss = vae_loss(x_recon, x, mean, logvar, beta=beta)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        epoch_progress.set_postfix(Loss=f"{loss.item():.4f}")

    tqdm.write(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {total_loss / len(train_dataloader.dataset):.4f}")

    # Hiển thị hình ảnh được tạo
    with torch.no_grad():
            z = torch.randn(64, latent_dim).to(device)
            generated = vae.decoder(z)
            generated = generated.view(-1, 1, 28, 28)
            grid = torchvision.utils.make_grid(generated, nrow=8, normalize=True)
            plt.figure(figsize=(8, 8))
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap="gray")
            plt.title(f"Generated Images at Epoch {epoch+1}")
            plt.show()


Output hidden; open in https://colab.research.google.com to view.

In [11]:


num_epochs = 10
# Initialize FID and Inception Score
fid = FrechetInceptionDistance(feature=2048).to(device) # Default Feature is 2048
inception = InceptionScore().to(device)
for epoch in range(num_epochs):
    epoch_progress = tqdm(test_dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)

    for batch_idx, (real_images, _) in enumerate(epoch_progress):
      with torch.no_grad():
          batch_size = real_images.size(0)
          z = torch.randn(batch_size, latent_dim).to(device)
          fake_images = vae.decoder(z)

          # Convert to 3 channels
          fake_images_3channel = fake_images.repeat(1, 3, 1, 1)
          # real_images, _ = next(iter(test_dataloader))
          real_images_3channel = real_images.repeat(1, 3, 1, 1).to(device)

          fid.update(real_images_3channel.to(torch.uint8), real=True)
          fid.update(fake_images_3channel.to(torch.uint8), real=False)
          inception.update(fake_images_3channel.to(torch.uint8))

    fid_score = fid.compute()
    inception_score, inception_std = inception.compute()

    print(f"Epoch [{epoch+1}/{num_epochs}] | FID: {fid_score:.4f} | Inception Score: {inception_score:.4f} ± {inception_std:.4f}")

    fid.reset()
    inception.reset()

Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 347MB/s]


Epoch [1/10] | FID: 550.9489 | Inception Score: 1.0020 ± 0.0004




Epoch [2/10] | FID: 551.5508 | Inception Score: 1.0020 ± 0.0003




Epoch [3/10] | FID: 550.7186 | Inception Score: 1.0021 ± 0.0003




Epoch [4/10] | FID: 551.6161 | Inception Score: 1.0018 ± 0.0004




Epoch [5/10] | FID: 551.0883 | Inception Score: 1.0021 ± 0.0007




Epoch [6/10] | FID: 550.9278 | Inception Score: 1.0020 ± 0.0005




Epoch [7/10] | FID: 550.7907 | Inception Score: 1.0021 ± 0.0005




Epoch [8/10] | FID: 551.1342 | Inception Score: 1.0020 ± 0.0005




Epoch [9/10] | FID: 551.0695 | Inception Score: 1.0022 ± 0.0004




Epoch [10/10] | FID: 551.1095 | Inception Score: 1.0021 ± 0.0004
