In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [9]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"device: {device}")
torch.set_default_device(device)

device: cuda


In [10]:
class Generator(nn.Module):
    def __init__(self, noise_dim=100, img_dim=784):  # 28*28=784 cho MNIST
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)

class Discriminator(nn.Module):
    def __init__(self, img_dim=784):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)

In [11]:
# Kích thước ngẫu nhiên cho noise vector
noise_dim = 100

# Khởi tạo Generator và Discriminator
gen = Generator(noise_dim).to(device=device)
disc = Discriminator().to(device=device)

# Hàm loss và optimizers
criterion = nn.BCELoss().to(device=device)
optimizer_gen = optim.Adam(gen.parameters(), lr=0.0002)
optimizer_disc = optim.Adam(disc.parameters(), lr=0.0002)

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalization để đưa ảnh về khoảng [-1, 1]
])

dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, generator=torch.Generator(device=device))

TypeError: Compose.__init__() got an unexpected keyword argument 'device'

In [None]:
num_epochs = 100 # 50
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(dataloader):
        batch_size = real.size(0)
        real = real.view(batch_size, -1)

        # Labels cho hình thật và giả
        labels_real = torch.ones(batch_size, 1)
        labels_fake = torch.zeros(batch_size, 1)

        # Huấn luyện Discriminator với hình thật
        disc_real = disc(real)
        loss_disc_real = criterion(disc_real, labels_real)

        # Huấn luyện Discriminator với hình giả
        noise = torch.randn(batch_size, noise_dim)
        fake_images = gen(noise)
        disc_fake = disc(fake_images.detach())
        loss_disc_fake = criterion(disc_fake, labels_fake)

        # Tổng loss của Discriminator
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        optimizer_disc.zero_grad()
        loss_disc.backward()
        optimizer_disc.step()

        # Huấn luyện Generator để đánh lừa Discriminator
        output = disc(fake_images)
        loss_gen = criterion(output, labels_real)
        optimizer_gen.zero_grad()
        loss_gen.backward()
        optimizer_gen.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] | Loss D: {loss_disc:.4f}, Loss G: {loss_gen:.4f}")

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [None]:
def generate_and_plot_images(generator, noise_dim, num_images=25):
    noise = torch.randn(num_images, noise_dim)
    fake_images = generator(noise).view(-1, 1, 28, 28)
    fake_images = (fake_images + 1) / 2  # Đưa ảnh về khoảng [0, 1]

    grid = torchvision.utils.make_grid(fake_images, nrow=5)
    plt.figure(figsize=(10, 10))
    plt.imshow(np.transpose(grid.detach().cpu().numpy(), (1, 2, 0)))
    plt.axis("off")
    plt.show()

# Hiển thị ảnh mẫu từ Generator sau khi huấn luyện
generate_and_plot_images(gen, noise_dim)