In [1]:
from PIL import Image
import os

from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
class AnimeDataset(Dataset):
    """Folder-of-images dataset producing normalised 128x128 RGB tensors.

    Every regular file directly inside ``path`` is treated as an image.
    Sub-directories (for example a notebook checkpoint folder) are skipped,
    and the file list is sorted so that an index always refers to the same
    file regardless of the order the filesystem reports entries in.
    """

    def __init__(self, path):
        """Index the files in ``path`` and build the preprocessing pipeline.

        Args:
            path: Directory whose regular files are the images to serve.
        """
        self.file_list = self._list_files(path)
        self.transform = transforms.Compose(
            [
                transforms.Resize((128, 128)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                # Three-channel mean/std: inputs must be RGB (see __getitem__).
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    @staticmethod
    def _list_files(path):
        """Return the sorted paths of the regular files directly inside ``path``."""
        candidates = (os.path.join(path, name) for name in sorted(os.listdir(path)))
        return [candidate for candidate in candidates if os.path.isfile(candidate)]

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.file_list)

    def __getitem__(self, i):
        """Load the ``i``-th image and return it after the transform pipeline.

        The image is converted to RGB first so that grayscale, palette or
        RGBA files still match the three-channel normalisation, and the file
        handle is closed as soon as the pixel data has been read.
        """
        with Image.open(self.file_list[i]) as img:
            # convert() forces the lazy loader to read the pixels before close.
            rgb = img.convert("RGB")
        return self.transform(rgb)

In [9]:
from typing import Dict, Optional, Tuple
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image, make_grid

from mindiffusion.vae import VAE

def train_anime(
    n_epoch: int = 100, device: str = "cuda:0", load_pth: Optional[str] = None
) -> None:
    """Train the VAE on the anime-face image folder, checkpointing every epoch.

    After each epoch a grid holding reconstructions followed by their inputs
    is written to ``./contents/vaeNNN.png`` and the model weights are saved to
    ``./vae{epoch % 3}.pth``, i.e. a rotation over three checkpoint files.

    Args:
        n_epoch: Number of passes over the dataset.
        device: Torch device string the model and the batches are moved to.
        load_pth: Checkpoint to start from. When ``None``, the default
            ``"vae1.pth"`` is loaded if that file exists; otherwise training
            starts from freshly initialised weights.

    Raises:
        FileNotFoundError: If ``load_pth`` is given but cannot be opened
            (raised by ``torch.load``).
    """

    # --- settings ---
    n_feat = 64
    batch_size = 256
    downs = 3
    lr = 1e-4
    dataset_dir = "/storage/animeface/images/"
    default_resume = "vae1.pth"
    sample_dir = "./contents"
    # ----------------

    vae = VAE(3, n_feat, downs)
    vae.to(device)

    # An explicit load_pth always wins; the default checkpoint is optional so
    # that a first run without any saved weights does not crash.
    resume = load_pth
    if resume is None and os.path.isfile(default_resume):
        resume = default_resume

    if resume is not None:
        # map_location lets a checkpoint saved on another device load here.
        vae.load_state_dict(torch.load(resume, map_location=device))
    else:
        print(f"No checkpoint found at {default_resume}; training from scratch.")

    # Create the sample directory up front so that save_image cannot fail on a
    # missing folder after a whole epoch has already been computed.
    os.makedirs(sample_dir, exist_ok=True)

    dataset = AnimeDataset(dataset_dir)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)
    optim = torch.optim.Adam(vae.parameters(), lr=lr)

    for i in range(n_epoch):
        print(f"Epoch {i} : ")
        vae.train()

        pbar = tqdm(dataloader)
        loss_ema = None
        for x in pbar:
            optim.zero_grad()
            x = x.to(device)
            x_pred, mu, log_var, _ = vae(x)
            loss, _, _ = vae.loss_function(x, x_pred, mu, log_var)
            loss.backward()

            # Read the scalar once; it feeds the smoothed loss in the progress bar.
            loss_value = loss.item()
            if loss_ema is None:
                loss_ema = loss_value
            else:
                loss_ema = 0.9 * loss_ema + 0.1 * loss_value
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        vae.eval()
        with torch.no_grad():
            # Reconstruct the first images of the last training batch and
            # stack them on top of the originals for a visual comparison.
            xh, _, _, _ = vae(x[:8])
            xset = torch.cat([xh, x[:8]], dim=0)
            grid = make_grid(xset, normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid, f"{sample_dir}/vae{i:03d}.png")

        # Rotating file name: only three checkpoint files are ever kept.
        torch.save(vae.state_dict(), f"./vae{i%3}.pth")

In [10]:
# Kick off training for 100 epochs with the default device.
train_anime(n_epoch=100)

Epoch 0 : 


loss: -1.3733: 100%|██████████| 249/249 [03:37<00:00,  1.14it/s]


Epoch 1 : 


loss: -2.7952: 100%|██████████| 249/249 [03:37<00:00,  1.14it/s]


Epoch 2 : 


loss: -3.2604: 100%|██████████| 249/249 [03:38<00:00,  1.14it/s]


Epoch 3 : 


loss: -3.7352: 100%|██████████| 249/249 [03:37<00:00,  1.14it/s]


Epoch 4 : 


loss: -3.8830: 100%|██████████| 249/249 [03:38<00:00,  1.14it/s]


Epoch 5 : 


loss: -4.1223: 100%|██████████| 249/249 [03:38<00:00,  1.14it/s]


Epoch 6 : 


loss: -4.1879: 100%|██████████| 249/249 [03:38<00:00,  1.14it/s]


Epoch 7 : 


loss: -4.1972: 100%|██████████| 249/249 [03:38<00:00,  1.14it/s]


Epoch 8 : 


loss: -4.2617:  11%|█         | 27/249 [00:25<03:27,  1.07it/s]


KeyboardInterrupt: 