In [4]:
!pip install sympy

[0m

In [1]:
from PIL import Image
import os

from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
class AnimeDataset(Dataset):
    """Unlabelled image dataset serving every regular file of a directory.

    Each item is opened with PIL, converted to three-channel RGB, resized and
    centre-cropped to a square, randomly mirrored horizontally, turned into a
    tensor and normalised with mean 0.5 / std 0.5 per channel.
    """

    def __init__(self, path, image_size=64):
        """Index the image files found directly inside ``path``.

        Args:
            path: Directory whose regular files are all treated as images.
            image_size: Side length in pixels of the square output images.
        """
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is reproducible between runs.
        candidates = (os.path.join(path, name) for name in sorted(os.listdir(path)))
        # Skip sub-directories (e.g. ".ipynb_checkpoints" created by Jupyter),
        # which would otherwise make Image.open fail in __getitem__.
        self.file_list = [p for p in candidates if os.path.isfile(p)]
        self.transform = transforms.Compose(
            [
                # Resize(int) only fixes the shorter edge; the centre crop
                # guarantees a square result so samples can be batched.
                # For already-square images the crop is a no-op.
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.file_list)

    def __getitem__(self, i):
        """Load, convert to RGB and transform the ``i``-th image."""
        # The context manager closes the underlying file promptly.
        # convert("RGB") produces an independent in-memory image, so it stays
        # valid after the file is closed, and it guarantees the three channels
        # that Normalize expects even for grayscale / RGBA / palette files.
        with Image.open(self.file_list[i]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

In [2]:
from typing import Dict, Optional, Tuple
from sympy import Ci
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.utils import save_image, make_grid

from mindiffusion.unet import NaiveUnet
from mindiffusion.ddpm import DDPM

def train_anime(
    n_epoch: int = 100,
    device: str = "cuda:0",
    load_pth: Optional[str] = None,
    n_feat: int = 640,
    batch_size: int = 128,
    lr: float = 5e-5,
    dataset_dir: str = "/storage/animeface/images/",
) -> None:
    """Train a DDPM on the images in ``dataset_dir``.

    After every epoch a grid of 8 generated images followed by up to 8 real
    images from the last training batch is written to ``./contents``, and the
    model weights are saved to ``./ddpm_anime{epoch % 3}.pth`` (three rotating
    checkpoint files).

    Args:
        n_epoch: Number of passes over the dataset.
        device: Torch device string used for training and sampling.
        load_pth: Optional path of a state-dict checkpoint to resume from.
        n_feat: ``n_feat`` argument forwarded to ``NaiveUnet``.
        batch_size: Mini-batch size of the data loader.
        lr: Learning rate of the Adam optimiser.
        dataset_dir: Directory containing the training images.
    """
    import os  # local import keeps this block self-contained

    sample_dir = "./contents"

    ddpm = DDPM(eps_model=NaiveUnet(3, 3, n_feat=n_feat), betas=(1e-4, 0.02), n_T=1000)

    if load_pth is not None:
        # Deserialise onto the CPU so a checkpoint written on one GPU can be
        # resumed on any device; the model is moved to `device` right after.
        ddpm.load_state_dict(torch.load(load_pth, map_location="cpu"))

    ddpm.to(device)

    dataset = AnimeDataset(dataset_dir)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lr)

    # save_image does not create missing directories.
    os.makedirs(sample_dir, exist_ok=True)

    for i in range(n_epoch):
        print(f"Epoch {i} : ")
        ddpm.train()

        pbar = tqdm(dataloader)
        loss_ema = None
        for x in pbar:
            optim.zero_grad()
            x = x.to(device)
            loss = ddpm(x)
            loss.backward()
            # Exponential moving average of the loss, shown on the progress bar.
            loss_value = loss.item()
            if loss_ema is None:
                loss_ema = loss_value
            else:
                loss_ema = 0.9 * loss_ema + 0.1 * loss_value
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        ddpm.eval()
        with torch.no_grad():
            # Sample at the same (C, H, W) as the training batch; a mismatching
            # shape (previously a fixed 3x32x32) makes torch.cat below fail.
            xh = ddpm.sample(8, tuple(x.shape[1:]), device)
            xset = torch.cat([xh, x[:8]], dim=0)
            grid = make_grid(xset, normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid, f"./contents/ddpm_sample_anime{str(i).zfill(3)}.png")

            # Save the model, rotating over three checkpoint files.
            torch.save(ddpm.state_dict(), f"./ddpm_anime{i%3}.pth")

In [None]:
# Launch training for 100 epochs with all other settings at their defaults.
train_anime(n_epoch=100)

Epoch 0 : 


loss: 1.1019:   3%|▎         | 13/497 [00:36<20:34,  2.55s/it]