In [1]:
!git clone https://github.com/marina-shesha/ddpm_hw_clean.git

Cloning into 'ddpm_hw_clean'...
remote: Enumerating objects: 171, done.[K
remote: Counting objects: 100% (171/171), done.[K
remote: Compressing objects: 100% (118/118), done.[K
remote: Total 171 (delta 95), reused 126 (delta 50), pack-reused 0[K
Receiving objects: 100% (171/171), 24.50 MiB | 23.92 MiB/s, done.
Resolving deltas: 100% (95/95), done.


In [2]:
!pip install wandb
!pip install ml_collections

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.15.0-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m19.0 MB/s[0m eta [36m0:00:00[0m
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.20.0-py2.py3-none-any.whl (198 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m198.8/198.8 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting GitPython!=3.1.29,>=1.0.0
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━

In [3]:
%cd ddpm_hw_clean

/content/ddpm_hw_clean


In [4]:
import torch
import wandb

from default_mnist_config import create_default_mnist_config
from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet
from data_generator import DataGenerator

from tqdm.auto import trange

import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [5]:
device = torch.device('cuda')
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
model = ResNet(**classifier_args)
model.to(device)

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_func = torch.nn.CrossEntropyLoss()

In [6]:
datagen = DataGenerator(create_default_mnist_config())
train_generator = datagen.sample_train()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 108246966.13it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 83599512.65it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 43739628.73it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5868924.45it/s]


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



In [7]:
TOTAL_ITERS = 2_000
EVAL_FREQ = 500

### Обучите классификатор только на чистых картинках. Он понадобится нам для классификации условно сгенерированных картинок

In [8]:
def log_metric(metric_name, loader_name, value, step):
    wandb.log({f'{metric_name}/{loader_name}': value}, step=step)

In [9]:
model.train()
wandb.init(project='sde', name='clean_classifier')
for iter_idx in trange(1, 1 + TOTAL_ITERS):
    X, y = next(train_generator)
    X = X.to(device)
    y = y.to(device)
    step = iter_idx*y.shape[0]
    logits = model(X)
    pred_labels = torch.argmax(logits, dim=-1)
    loss = loss_func(logits, y)

    accuracy = (pred_labels == y).sum() / pred_labels.shape[0]
    log_metric('loss', 'train', loss.item(), step)
    log_metric('accuracy', 'train', accuracy, step)
    loss.backward()

    optim.step()
    optim.zero_grad()
    
    if iter_idx % EVAL_FREQ == 0:
        """
        validate
        """
        valid_loss = 0
        valid_accuracy = 0
        valid_count = 0
        model.eval()
        with torch.no_grad():
          for X,y in datagen.valid_loader:
              X = X.to(device)
              y = y.to(device)
              valid_count += X.shape[0]
              logits = model(X)
              pred_labels = torch.argmax(logits, dim=-1)
              loss = loss_func(logits, y)
              valid_loss += loss * X.shape[0]
              valid_accuracy += (pred_labels == y).sum()

        valid_loss = valid_loss / valid_count
        valid_accuracy = valid_accuracy / valid_count
        log_metric('loss', 'valid', valid_loss, step)
        log_metric('accuracy', 'valid', valid_accuracy, step)
        model.train()
        print('Clean MNIST classifier\'s accuracy:', valid_accuracy)
model.eval()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


  0%|          | 0/2000 [00:00<?, ?it/s]

Clean MNIST classifier's accuracy: tensor(0.9809, device='cuda:0')
Clean MNIST classifier's accuracy: tensor(0.9868, device='cuda:0')
Clean MNIST classifier's accuracy: tensor(0.9912, device='cuda:0')
Clean MNIST classifier's accuracy: tensor(0.9914, device='cuda:0')


ResNet(
  (conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResidualBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=

In [11]:
torch.save(model.state_dict(), './ddpm_checkpoints/clean_classifier.pth')