In [None]:
# Install the Kaggle command-line client used by the cells below.
!pip install kaggle

In [None]:
# Create ~/.kaggle, the directory the next cells move kaggle.json into.
!mkdir -p ~/.kaggle

KaggleのAPIトークン（kaggle.json）を取得し、Colabの /content にアップロードする

In [None]:
# Move the uploaded API token from /content into ~/.kaggle and restrict it
# to owner read/write (chmod 600).
!mv /content/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
# Presumably a quick check that the credentials work (lists public datasets).
!kaggle datasets list

In [None]:
!kaggle datasets download -d splcher/animefacedataset

In [None]:
!unzip /content/animefacedataset.zip

In [None]:
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image


class AnimeFaceDataset(Dataset):
    """Unlabelled image dataset over the ``*.jpg`` files directly inside ``root_dir``.

    Each item is ``(image, 0)``. The image is loaded with PIL, converted to
    RGB and passed through ``transform`` when one is given. The constant ``0``
    is a dummy target so items fit the usual ``(input, target)`` convention.

    Args:
        root_dir: Directory scanned (non-recursively) for ``*.jpg`` files.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.transform = transform
        # glob's result order depends on the filesystem; sort so that the
        # index -> file mapping is deterministic across runs and machines.
        self.image_paths = sorted(glob.glob(os.path.join(root_dir, "*.jpg")))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # The context manager releases the file handle even if decoding
        # fails; convert() fully loads the pixels before the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, 0  # labels are not needed for GAN training: dummy target

# ---------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------
batch_size = 32   # images per mini-batch
image_size = 64   # training images are resized/cropped to image_size x image_size
nz = 100          # dimensionality of the latent (noise) vector
num_epochs = 20
lr = 2e-4         # Adam learning rate
beta1 = 0.5       # Adam beta1

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------------------------------------------------------------
# Data pipeline
# ---------------------------------------------------------------
# Resize the shorter side to image_size, center-crop to a square, convert to a
# tensor, then normalize each channel with mean 0.5 / std 0.5 so pixel values
# land in [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

dataset = AnimeFaceDataset(root_dir="/content/images", transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

############################
# Generator & Discriminator (DCGAN)
############################
class Generator(nn.Module):
    """DCGAN generator.

    Maps latent tensors of shape (N, nz, 1, 1) to images of shape
    (N, nc, 64, 64) whose values lie in [-1, 1] (final Tanh).

    Args:
        nz: Number of latent channels.
        ngf: Base feature width; the hidden layers use 8x, 4x, 2x, 1x of it.
        nc: Number of output image channels.
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        layers = []
        in_ch = nz
        for idx, out_ch in enumerate((ngf * 8, ngf * 4, ngf * 2, ngf)):
            # The first transposed conv projects 1x1 -> 4x4 (stride 1, no
            # padding); every later one doubles the spatial size (stride 2, pad 1).
            stride, padding = (1, 0) if idx == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
            in_ch = out_ch
        # Output layer: 32x32 -> 64x64 with nc channels, squashed into [-1, 1].
        layers += [nn.ConvTranspose2d(in_ch, nc, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        images = self.main(x)
        return images

class Discriminator(nn.Module):
    """DCGAN discriminator.

    Scores a batch of images of shape (N, nc, 64, 64) and returns a flat
    tensor of N values in (0, 1) (final Sigmoid).

    Args:
        nc: Number of input image channels.
        ndf: Base feature width; the conv layers use 1x, 2x, 4x, 8x of it.
    """

    def __init__(self, nc=3, ndf=64):
        super().__init__()

        def down(c_in, c_out, norm):
            # 4x4 conv, stride 2, padding 1: halves the spatial size.
            stage = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
            if norm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        layers = down(nc, ndf, norm=False)  # no BatchNorm on the input stage
        for mult in (1, 2, 4):
            layers += down(ndf * mult, ndf * mult * 2, norm=True)
        # 4x4 "valid" conv collapses the remaining 4x4 map to one score per image.
        layers += [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.main(x)
        return scores.flatten()

# Instantiate the generator/discriminator pair (generator first) and move
# each to `device`; Module.to returns the module itself.
netG, netD = (
    Generator(nz=nz).to(device),
    Discriminator().to(device),
)

# DCGAN-recommended weight initialisation, meant to be passed to Module.apply().
def weights_init(m):
    """Initialise one module in place following the DCGAN recipe.

    Modules are still matched by class name, as before:
      * name contains ``"Conv"`` (e.g. ``Conv2d``, ``ConvTranspose2d``):
        ``weight ~ N(0, 0.02)``;
      * name contains ``"BatchNorm"`` (e.g. ``BatchNorm2d``):
        ``weight ~ N(1, 0.02)`` and ``bias = 0``.
    A matched module that has no such parameter (e.g. a container class named
    ``ConvBlock``, or ``BatchNorm2d(affine=False)`` whose weight/bias are
    ``None``) is skipped instead of raising ``AttributeError``/``TypeError``.

    Args:
        m: The ``nn.Module`` currently visited by ``Module.apply``.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    # nn.init functions already run under torch.no_grad(), so the parameters
    # are passed directly rather than through the legacy ``.data`` attribute.
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

# Initialise every submodule of both networks with weights_init
# (generator first, then discriminator).
for _net in (netG, netD):
    _net.apply(weights_init)
del _net

# ---------------------------------------------------------------
# Loss function & optimisers
# ---------------------------------------------------------------
# Binary cross-entropy on the discriminator's output probabilities.
criterion = nn.BCELoss()

# Both networks use Adam with the same learning rate and betas (beta1, 0.999).
_adam_kwargs = {"lr": lr, "betas": (beta1, 0.999)}
optimizerD = optim.Adam(netD.parameters(), **_adam_kwargs)
optimizerG = optim.Adam(netG.parameters(), **_adam_kwargs)
del _adam_kwargs

############################
# Training loop & loss logging
############################
lossD_list = []  # discriminator loss (real + fake) for every iteration
lossG_list = []  # generator loss for every iteration

# Latent batch for the per-epoch snapshots. It is drawn ONCE, before training,
# so every epoch_*.png is rendered from the same 64 latent points and the
# images show how the generator evolves on identical inputs. (Re-drawing it
# at the end of each epoch made the snapshots incomparable across epochs.)
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        # ============================
        # (1) Discriminator update: BCE with target 1 for real images and
        #     target 0 for generated ones.
        # ============================
        netD.zero_grad()
        real_images = real_images.to(device)
        b_size = real_images.size(0)  # the last batch can be smaller than batch_size

        labels_real = torch.ones(b_size, device=device)
        output_real = netD(real_images)
        lossD_real = criterion(output_real, labels_real)
        lossD_real.backward()

        # Loss on generated (fake) images
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake_images = netG(noise)

        labels_fake = torch.zeros(b_size, device=device)
        # detach(): this step only updates D, so don't backprop into G
        output_fake = netD(fake_images.detach())
        lossD_fake = criterion(output_fake, labels_fake)
        lossD_fake.backward()  # accumulates onto the real-image gradients

        optimizerD.step()

        # ============================
        # (2) Generator update: push the (just updated) discriminator's score
        #     on the fakes towards "real" (target 1).
        # ============================
        netG.zero_grad()
        labels_gen = torch.ones(b_size, device=device)
        output_gen = netD(fake_images)
        lossG = criterion(output_gen, labels_gen)
        # This also writes gradients into netD's parameters; they are cleared
        # by netD.zero_grad() at the start of the next iteration.
        lossG.backward()
        optimizerG.step()

        # ============================
        # Record losses
        # ============================
        lossD_val = (lossD_real + lossD_fake).item()
        lossG_val = lossG.item()
        lossD_list.append(lossD_val)
        lossG_list.append(lossG_val)

        if i % 200 == 0:
            print(f"[Epoch {epoch+1}/{num_epochs}] [Batch {i}/{len(dataloader)}] "
                  f"Loss_D: {lossD_val:.4f} Loss_G: {lossG_val:.4f}")

    # Save a grid of samples generated from the fixed latent batch each epoch.
    # NOTE(review): netG stays in train mode here, so its BatchNorm layers use
    # this batch's statistics (and update their running stats).
    with torch.no_grad():
        fake_sample = netG(fixed_noise).cpu()
    vutils.save_image(fake_sample, f"/content/epoch_{epoch+1}.png", normalize=True)
    print(f"=> Saved generated samples at /content/epoch_{epoch+1}.png")

# ---------------------------------------------------------------
# Plot the recorded per-iteration losses (matplotlib)
# ---------------------------------------------------------------
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_title("Training Loss")
ax.plot(lossD_list, label="Discriminator Loss")
ax.plot(lossG_list, label="Generator Loss")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.legend()
plt.show()