In [1]:
!git clone https://github.com/marina-shesha/ddpm_hw_clean.git

Cloning into 'ddpm_hw_clean'...
remote: Enumerating objects: 171, done.[K
remote: Counting objects: 100% (171/171), done.[K
remote: Compressing objects: 100% (118/118), done.[K
remote: Total 171 (delta 95), reused 126 (delta 50), pack-reused 0[K
Receiving objects: 100% (171/171), 24.50 MiB | 18.52 MiB/s, done.
Resolving deltas: 100% (95/95), done.


In [2]:
%cd ddpm_hw_clean

/content/ddpm_hw_clean


In [3]:
!pip install wandb
!pip install ml_collections
!pip install pytorch_fid

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.15.0-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m77.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.20.0-py2.py3-none-any.whl (198 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m198.8/198.8 kB[0m [31m27.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting GitPython!=3.1.29,>=1.0.0
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting setproctitle
  Downloading setproctitle

In [4]:
import torch
import numpy as np

from skimage.io import imread, imsave
from tqdm.auto import trange, tqdm
from torchvision.datasets import MNIST
from pytorch_fid import fid_score

from data_generator import DataGenerator
from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet

from matplotlib import pyplot as plt

from torchvision.transforms import Compose, Resize

import os

#os.environ['CUDA_VISIBLE_DEVICES'] = '1'

#### Определим папку с настоящими картинками

In [5]:
def create_dir(path: str):
    if not os.path.exists(path):
        os.makedirs(path)

In [17]:
create_dir('../real_images_MNIST')

real_dataset = MNIST(root='../data', download=True, train=True, transform=Compose([Resize((32, 32))]))
for idx, (image_mnist, label) in enumerate(tqdm(real_dataset, total=len(real_dataset))):
    image = np.array(image_mnist)
    imsave("../real_images_MNIST/{}.png".format(idx), image)

  0%|          | 0/60000 [00:00<?, ?it/s]

#### Определим папку для синтетических картинок и сгенерируем 60к картинок

In [8]:
device = torch.device('cuda')
uncond_diff = DiffusionRunner(create_default_mnist_config(), eval=True)
uncond_diff.model.eval();
uncond_diff.model.to(device);

In [16]:
create_dir('../uncond_mnist')

TOTAL_IMAGES_COUNT = 60_000
BATCH_SIZE = 200
NUM_ITERS = TOTAL_IMAGES_COUNT // BATCH_SIZE

global_idx = 0
for idx in trange(NUM_ITERS):
    images: torch.Tensor = uncond_diff.sample_images(batch_size=BATCH_SIZE).cpu()
    images = images.permute(0, 2, 3, 1).data.numpy().astype(np.uint8)

    for i in range(len(images)):
        imsave(os.path.join('../uncond_mnist', f'{global_idx}.png'), images[i])
        global_idx += 1

  0%|          | 0/300 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [11]:
fid_value = fid_score.calculate_fid_given_paths(
    paths=['../real_images_MNIST', '../uncond_mnist'],
    batch_size=200,
    device=device,
    dims=2048
)
fid_value

Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 385MB/s]
100%|██████████| 300/300 [00:29<00:00, 10.19it/s]
100%|██████████| 300/300 [00:29<00:00, 10.18it/s]


15.623740579105885

In [21]:
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
noisy_classifier = ConditionalResNet(**classifier_args)
noisy_classifier.to(device)

noisy_classifier.load_state_dict(torch.load('./ddpm_checkpoints/classifier.pth'))

noisy_classifier.eval();

conditional_diffusion = DiffusionRunner(create_default_mnist_config(), eval=True)
conditional_diffusion.set_classifier(noisy_classifier, T=1.)

conditional_diffusion.model.eval();
conditional_diffusion.model.to(device)

global_idx = 0
create_dir(f'../cond_MNIST')
for i in range(10):
  TOTAL_IMAGES_COUNT = 6000
  BATCH_SIZE = 200
  NUM_ITERS = TOTAL_IMAGES_COUNT // BATCH_SIZE
  labels = i*torch.ones(BATCH_SIZE).long().to(device)
  for idx in trange(NUM_ITERS):
      images: torch.Tensor = conditional_diffusion.sample_images(batch_size=BATCH_SIZE, labels=labels).cpu()
      images = images.permute(0, 2, 3, 1).data.numpy().astype(np.uint8)

      for j in range(len(images)):
          imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])
          global_idx += 1

  0%|          | 0/30 [00:00<?, ?it/s]

  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])


  0%|          | 0/30 [00:00<?, ?it/s]

  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])
  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])
  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])


  0%|          | 0/30 [00:00<?, ?it/s]

  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])


  0%|          | 0/30 [00:00<?, ?it/s]

  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])


  0%|          | 0/30 [00:00<?, ?it/s]

  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])
  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])


  0%|          | 0/30 [00:00<?, ?it/s]

  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])
  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])


  0%|          | 0/30 [00:00<?, ?it/s]

  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])
  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])


  0%|          | 0/30 [00:00<?, ?it/s]

  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])
  imsave(os.path.join(f'../cond_MNIST', f'{global_idx}.png'), images[j])


In [22]:
fid_value = fid_score.calculate_fid_given_paths(
    paths=['../real_images_MNIST', '../cond_MNIST'],
    batch_size=200,
    device=device,
    dims=2048
)
fid_value

100%|██████████| 300/300 [00:29<00:00, 10.17it/s]
100%|██████████| 300/300 [00:29<00:00, 10.23it/s]


15.756938835486238

> Какой фид получился? Сравните FID для безусловной генерации и для условной. Сгенерируйте для каждого класса по 6к картинок и посчитайте FID между реальными и условно сгенерированными картинками.

На сколько я понимаю мы получили вполне приличный фид, при этом для условной и безусловной генерации он примерно одинаковый, что означает, что наш обученный классификатор неплох.
