# 1. Data Load

In [None]:
# Mount Google Drive so the project folder (dataset and results) is reachable under /content/drive.
# Colab-only: `google.colab` does not exist outside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

# Absolute path so this cell is idempotent: the relative `cd drive/MyDrive/...` only worked
# while the kernel was still in /content and failed when the cell was re-run.
PROJECT_DIR = '/content/drive/MyDrive/[한이음] 적대적 AI 공격에 대한 인공지능 보안기술 연구/3. 소스코드/attacked0.02_GTSRB'
os.chdir(PROJECT_DIR)
print(os.getcwd())

/content/drive/MyDrive/[한이음] 적대적 AI 공격에 대한 인공지능 보안기술 연구/3. 소스코드/attacked0.02_GTSRB


In [None]:
# pip install denoising_diffusion_pytorch

# 2. denoising_diffusion_pytorch.py

## 2.1. pip & import & constants

In [None]:
# Pinned to the versions this notebook was last run with (see the install log). The code below was
# written against these releases (e.g. `EMA(..., update_every=...)`, `Accelerator.native_amp`).
# `%pip` installs into the running kernel's environment.
%pip install einops==0.4.1 ema_pytorch==0.0.9 accelerate==0.12.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Collecting ema_pytorch
  Downloading ema_pytorch-0.0.9-py3-none-any.whl (4.1 kB)
Collecting accelerate
  Downloading accelerate-0.12.0-py3-none-any.whl (143 kB)
[K     |████████████████████████████████| 143 kB 60.2 MB/s 
Installing collected packages: ema-pytorch, einops, accelerate
Successfully installed accelerate-0.12.0 einops-0.4.1 ema-pytorch-0.0.9


In [None]:
import math
import copy
import torch
from torch import nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from collections import namedtuple
from functools import partial

from torch.utils.data import Dataset, DataLoader
from multiprocessing import cpu_count

from pathlib import Path
from torch.optim import Adam
from torchvision import transforms as T, utils
from PIL import Image

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

from tqdm.auto import tqdm
from ema_pytorch import EMA

from accelerate import Accelerator

In [None]:
# constants

# Pair returned by GaussianDiffusion.model_predictions: (predicted noise, predicted clean image).
ModelPrediction = namedtuple('ModelPrediction', 'pred_noise pred_x_start')

## 2.2. helper functions

In [None]:
# helpers functions

def exists(x):
    """Return True when ``x`` is anything other than ``None`` (falsy values like 0 count as existing)."""
    return not (x is None)

def default(val, d):
    """Return ``val`` if it is not None, otherwise the fallback ``d``.

    When ``d`` is a plain Python function (e.g. a lambda) it is called lazily, so the
    fallback is only computed when it is actually needed.
    """
    if exists(val):
        return val
    if isfunction(d):
        return d()
    return d

def cycle(dl):
    """Yield items from ``dl`` forever, restarting iteration each time it is exhausted.

    Raises:
        ValueError: if a full pass over ``dl`` yields nothing (an empty iterable, or a one-shot
            iterator that is already exhausted). The previous version spun forever in that case.
    """
    while True:
        produced_any = False
        for data in dl:
            produced_any = True
            yield data
        if not produced_any:
            raise ValueError('cycle() got an iterable that produced no items; nothing to repeat')

def has_int_squareroot(num):
    """Return True if ``num`` is a perfect square, i.e. its square root is an integer.

    Uses exact integer arithmetic. The previous check ``math.sqrt(num) ** 2 == num`` relied on a
    float round-trip. It rejected large perfect squares because of rounding, and accepted
    non-integral floats such as 2.25. Negative and non-integral inputs return False;
    integral floats (e.g. 25.0) are accepted.
    """
    if isinstance(num, float):
        if not num.is_integer():
            return False
        num = int(num)
    if num < 0:
        return False
    # Integer Newton iteration for floor(sqrt(num)); same result as math.isqrt (Python 3.8+).
    root = num
    guess = (root + 1) // 2
    while guess < root:
        root = guess
        guess = (root + num // root) // 2
    return root * root == num

def num_to_groups(num, divisor):
    """Split ``num`` into chunks of size ``divisor``, with any remainder as a final smaller chunk.

    Example: num_to_groups(25, 16) -> [16, 9]
    """
    full_chunks, leftover = divmod(num, divisor)
    tail = [leftover] if leftover > 0 else []
    return [divisor] * full_chunks + tail

def convert_image_to(img_type, image):
    """Return ``image`` in mode ``img_type``; ``image.convert`` is only called when the mode differs."""
    return image if image.mode == img_type else image.convert(img_type)

def l2norm(t):
    """Scale ``t`` to unit Euclidean (p=2) norm along its last dimension."""
    return F.normalize(t, p = 2, dim = -1)

## 2.3. normalization functions

In [None]:
# normalization functions
def normalize_to_neg_one_to_one(img):
    """Map values from [0, 1] to [-1, 1] via ``2 * img - 1``."""
    return 2 * img - 1

def unnormalize_to_zero_to_one(t):
    """Inverse of the [-1, 1] normalisation: map values back to [0, 1] via ``(t + 1) / 2``."""
    return 0.5 * (t + 1)

## 2.4. small helper modules

In [None]:
# small helper modules
class Residual(nn.Module):
    """Skip-connection wrapper: ``forward(x, ...)`` returns ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x

def Upsample(dim, dim_out = None):
    """Double H and W with nearest-neighbour upsampling, then apply a 3x3 conv ``dim`` -> ``dim_out``.

    ``dim_out`` defaults to ``dim``.
    """
    out_channels = default(dim_out, dim)
    layers = [
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, out_channels, 3, padding = 1),
    ]
    return nn.Sequential(*layers)

def Downsample(dim, dim_out = None):
    """Strided 4x4 conv (stride 2, padding 1) mapping ``dim`` -> ``dim_out`` (defaults to ``dim``)."""
    out_channels = default(dim_out, dim)
    return nn.Conv2d(dim, out_channels, kernel_size = 4, stride = 2, padding = 1)

class LayerNorm(nn.Module):
    """Normalise over the channel axis (dim 1) with a learned per-channel gain and no bias.

    The gain has shape (1, dim, 1, 1), so inputs are expected to be channels-first 4-D tensors.
    """

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        # use a larger epsilon for any dtype other than float32 (e.g. half precision)
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        mean = torch.mean(x, dim = 1, keepdim = True)
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        centered = x - mean
        return centered * (var + eps).rsqrt() * self.g

class PreNorm(nn.Module):
    """Apply the channel-wise ``LayerNorm`` to the input first, then the wrapped module ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

## 2.5. sinusoidal positional embeds

In [None]:
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D batch of timesteps.

    Maps ``x`` of shape (B,) to (B, dim). The first half of the features is ``sin`` and the second
    half ``cos`` of ``x`` times ``dim // 2`` geometrically spaced frequencies, from 1 down to ``1 / theta``.

    Args:
        dim: output feature size; must be even and >= 4 (``dim // 2 - 1`` is used as a divisor).
        theta: base of the frequency progression (10000 reproduces the previous behaviour).

    Raises:
        ValueError: if ``dim`` is odd or smaller than 4.
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        # An odd dim silently produced dim - 1 features, mismatching the Linear layer that consumes
        # them, and dim in {2, 3} divided by zero in forward. Fail early with a clear message.
        if dim % 2 != 0 or dim < 4:
            raise ValueError(f'SinusoidalPosEmb dim must be even and >= 4, got {dim}')
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(self.theta) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = x[:, None] * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim = -1)

class LearnedSinusoidalPosEmb(nn.Module):
    """Learned sinusoidal positional embedding, following @crowsonkb's lead.

    Reference: https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    Maps ``x`` of shape (B,) to (B, dim + 1) as ``[x, sin(2*pi*x*w), cos(2*pi*x*w)]``, where
    ``w`` are ``dim // 2`` learned frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        col = rearrange(x, 'b -> b 1')
        freqs = col * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        return torch.cat((col, freqs.sin(), freqs.cos()), dim = -1)

## 2.6. building block modules

In [None]:
# building block modules
class Block(nn.Module):
    """3x3 conv -> GroupNorm -> optional (scale, shift) modulation -> SiLU."""

    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        h = self.norm(self.proj(x))
        if exists(scale_shift):
            scale, shift = scale_shift
            # (scale + 1): a zero scale leaves the normalised features unchanged
            h = h * (scale + 1) + shift
        return self.act(h)

class ResnetBlock(nn.Module):
    """Two ``Block``s plus a residual connection, optionally conditioned on a time embedding.

    When ``time_emb_dim`` is given, the time embedding goes through SiLU and a Linear layer to
    ``2 * dim_out`` features. These are split into a (scale, shift) pair that modulates the first
    ``Block``. A 1x1 conv matches channels on the skip path when ``dim != dim_out``.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        if exists(time_emb_dim):
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            cond = rearrange(self.mlp(time_emb), 'b c -> b c 1 1')
            scale_shift = cond.chunk(2, dim = 1)

        h = self.block2(self.block1(x, scale_shift = scale_shift))
        return h + self.res_conv(x)

class LinearAttention(nn.Module):
    """Multi-head attention over spatial positions whose cost grows linearly with H * W.

    It never builds an (n x n) attention matrix over the n = H * W positions. Instead it builds a
    per-head (dim_head x dim_head) context by summing ``k * v`` over the positions, then applies that
    context to the queries. The output is projected back to ``dim`` channels (1x1 conv + LayerNorm).

    Args:
        dim: input/output channels.
        heads: number of attention heads.
        dim_head: channels per head.
    """
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # one 1x1 conv produces q, k and v stacked along the channel axis
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            LayerNorm(dim)
        )

    def forward(self, x):
        # h, w here are spatial height/width; the `h = self.heads` below is an einops axis name
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        # each of q, k, v: (b, heads, dim_head, h * w)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        # queries: softmax over the feature axis; keys: softmax over the positions axis
        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        v = v / (h * w)

        # context[d, e] = sum over positions n of k[d, n] * v[e, n] -> (b, heads, dim_head, dim_head)
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        # out[e, n] = sum over d of context[d, e] * q[d, n] -> (b, heads, dim_head, h * w)
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)

class Attention(nn.Module):
    """Full multi-head self-attention over spatial positions with cosine-similarity logits.

    Each query and key vector is L2-normalised over its head features. The logits are therefore
    cosine similarities multiplied by a fixed temperature ``scale`` (instead of ``dim_head ** -0.5``).

    Args:
        dim: input/output channels.
        heads: number of heads.
        dim_head: channels per head.
        scale: multiplier applied to the cosine similarities before the softmax.
    """

    def __init__(self, dim, heads = 4, dim_head = 32, scale = 16):
        super().__init__()
        self.scale = scale
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        # each of q, k, v: (b, heads, dim_head, h * w) -- features on dim -2, positions on dim -1
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        # Normalise every query/key vector over its feature axis (dim -2). The previous
        # `l2norm` normalises the LAST axis, which in this layout is the positions axis. The
        # logits were then not cosine similarities, and their magnitude depended on the
        # feature-map size. Checkpoints trained with the old normalisation will behave differently.
        q, k = map(lambda t: F.normalize(t, dim = -2), (q, k))

        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)



## 2.7. model

In [None]:
# model
class Unet(nn.Module):
    """U-Net denoiser used as the ``model`` inside ``GaussianDiffusion``.

    ``forward(x, time)`` takes an image batch ``x`` with ``channels`` channels and a 1-D batch of
    timesteps ``time``. It returns a tensor with ``out_dim`` channels at the input resolution.

    Args:
        dim: base channel width; stage ``i`` outputs ``dim * dim_mults[i]`` channels.
        init_dim: channels after the initial 7x7 conv (defaults to ``dim``).
        out_dim: output channels (defaults to ``channels``, or ``2 * channels`` when
            ``learned_variance`` is set).
        dim_mults: per-stage channel multipliers. Every stage except the last halves H and W, so
            H and W should be divisible by ``2 ** (len(dim_mults) - 1)`` for the skip concatenations to line up.
        channels: input image channels.
        resnet_block_groups: GroupNorm groups inside every ``ResnetBlock``.
        learned_variance: only changes the default ``out_dim``.
        learned_sinusoidal_cond: use ``LearnedSinusoidalPosEmb`` instead of the fixed
            ``SinusoidalPosEmb`` for the timestep embedding.
        learned_sinusoidal_dim: feature size of the learned sinusoidal embedding.
    """

    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        resnet_block_groups = 8,
        learned_variance = False,
        learned_sinusoidal_cond = False,
        learned_sinusoidal_dim = 16
    ):
        super().__init__()

        # determine dimensions

        self.channels = channels

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding = 3)

        # channel widths per resolution: init_dim, then dim * m for each multiplier
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # time embeddings

        time_dim = dim * 4

        self.learned_sinusoidal_cond = learned_sinusoidal_cond

        if learned_sinusoidal_cond:
            sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim)
            # the learned embedding also concatenates the raw timestep -> one extra feature
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            fourier_dim = dim

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                # the last stage changes the channel count without reducing resolution
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(nn.ModuleList([
                # each block's input is concatenated with a skip tensor (dim_in channels) from the down path
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else  nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))

        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)

        # The final skip concatenates the last up-stage output with `r` (the init_conv output).
        # The up-stage output has dims[0] == init_dim channels (last up layer maps to dim_in) and
        # `r` also has init_dim channels, so this block's input is init_dim * 2 wide. It used to be
        # declared as dim * 2, which crashed whenever init_dim != dim. For the default
        # init_dim == dim the layer (and its state_dict) is unchanged.
        self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time):
        x = self.init_conv(x)
        # kept for the final skip connection
        r = x.clone()

        t = self.time_mlp(time)

        # skip tensors from the down path, consumed in reverse order by the up path
        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)

## 2.8. Gaussian diffusion class (noise schedules + `GaussianDiffusion`)

In [None]:
# gaussian diffusion: schedule helpers and the GaussianDiffusion class
def extract(a, t, x_shape):
    """Gather ``a[t]`` for every batch element and reshape to (B, 1, ..., 1).

    The result broadcasts against a tensor of shape ``x_shape``.
    """
    batch_size, *_ = t.shape
    trailing_ones = (1,) * (len(x_shape) - 1)
    picked = a.gather(-1, t)
    return picked.reshape(batch_size, *trailing_ones)

def linear_beta_schedule(timesteps):
    """Linearly spaced betas (float64).

    The range 0.0001 -> 0.02 is defined for 1000 steps and scaled by ``1000 / timesteps``
    for other step counts.
    """
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype = torch.float64)

def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Builds alpha_bar(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), normalised so alpha_bar(0) = 1.
    It returns betas = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped to [0, 0.999], as a float64
    tensor of length ``timesteps``.
    """
    frac = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float64) / timesteps
    f = torch.cos((frac + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = f / f[0]
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(betas, 0, 0.999)

class GaussianDiffusion(nn.Module):
    """DDPM/DDIM Gaussian diffusion wrapper around a denoising network.

    ``forward(img)`` maps images with ``img * 2 - 1`` (so [0, 1] inputs become [-1, 1]), draws a
    random timestep per sample, and returns the scalar training loss. ``sample(batch_size)``
    generates images mapped back with ``(x + 1) / 2``. It uses ancestral sampling over all
    ``timesteps`` steps, or DDIM sampling over ``sampling_timesteps`` steps when that is smaller.

    Args:
        model: network called as ``model(x, t)``; must expose ``channels`` and ``out_dim``, which
            must be equal for this class.
        image_size: expected (square) image height and width.
        channels: image channels used when sampling.
        timesteps: number of diffusion steps used for training.
        sampling_timesteps: number of steps used when sampling (defaults to ``timesteps``).
        loss_type: ``'l1'`` or ``'l2'``.
        objective: ``'pred_noise'`` (model predicts the noise) or ``'pred_x0'`` (predicts the clean image).
        beta_schedule: ``'linear'`` or ``'cosine'``.
        p2_loss_weight_gamma: P2 loss reweighting exponent (https://arxiv.org/abs/2204.00227);
            0 weights all timesteps equally, 1. is recommended.
        p2_loss_weight_k: P2 reweighting offset.
        ddim_sampling_eta: DDIM noise scale (0 makes the DDIM update deterministic).
    """

    def __init__(
        self,
        model,
        *,
        image_size,
        channels = 3,
        timesteps = 1000,
        sampling_timesteps = None,
        loss_type = 'l1',
        objective = 'pred_noise',
        beta_schedule = 'cosine',
        p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
        p2_loss_weight_k = 1,
        ddim_sampling_eta = 1.
    ):
        super().__init__()
        assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)

        self.channels = channels
        self.image_size = image_size
        self.model = model
        self.objective = objective

        assert objective in {'pred_noise', 'pred_x0'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start)'

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # calculate p2 reweighting

        register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and the noise that produced it."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise from x_t and a predicted x_0 (inverse of predict_start_from_noise)."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t):
        """Run the model and return both predicted noise and predicted x_0, whatever the objective."""
        model_output = self.model(x, t)

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)

        elif self.objective == 'pred_x0':
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, clip_denoised: bool):
        """Parameters of p(x_{t-1} | x_t), optionally clamping the predicted x_0 to [-1, 1]."""
        preds = self.model_predictions(x, t)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t: int, clip_denoised = True):
        """One ancestral sampling step from the integer timestep ``t`` (no noise added at t == 0)."""
        b, *_, device = *x.shape, x.device
        batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = batched_times, clip_denoised = clip_denoised)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        return model_mean + (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape):
        """Full ancestral sampling from pure noise over all timesteps; returns images in [0, 1]."""
        batch, device = shape[0], self.betas.device

        img = torch.randn(shape, device=device)

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            img = self.p_sample(img, t)

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def ddim_sample(self, shape, clip_denoised = True):
        """DDIM sampling over ``sampling_timesteps`` evenly spaced steps; returns images in [0, 1]."""
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        img = torch.randn(shape, device = device)

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            alpha = self.alphas_cumprod_prev[time]
            alpha_next = self.alphas_cumprod_prev[time_next]

            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

            pred_noise, x_start, *_ = self.model_predictions(img, time_cond)

            if clip_denoised:
                x_start.clamp_(-1., 1.)

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = ((1 - alpha_next) - sigma ** 2).sqrt()

            noise = torch.randn_like(img) if time_next > 0 else 0.

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, batch_size = 16):
        """Generate ``batch_size`` images of shape (channels, image_size, image_size) in [0, 1]."""
        image_size, channels = self.image_size, self.channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size))

    @torch.no_grad()
    def interpolate(self, x1, x2, t = None, lam = 0.5):
        """Blend two image batches in noise space and denoise the mixture.

        Both inputs are noised to step ``t`` (default: the last step) with ``q_sample`` and mixed as
        ``(1 - lam) * xt1 + lam * xt2``. The mixture is then passed through ``p_sample`` for steps
        ``t - 1, ..., 0``.
        NOTE(review): no [0, 1] <-> [-1, 1] mapping is applied here, so ``x1``/``x2`` presumably
        must already be in [-1, 1], and the result is returned in that range — confirm with callers.
        """
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.full((b,), t, device = device, dtype = torch.long)
        xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
            # p_sample takes a Python int timestep: it builds the batched tensor itself and tests
            # `t > 0`. Passing a tensor here (as before) made torch.full / the comparison raise.
            img = self.p_sample(img, i)

        return img

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) in closed form (fresh Gaussian noise unless ``noise`` is given)."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    @property
    def loss_fn(self):
        """Elementwise loss selected by ``loss_type``: L1 or MSE."""
        if self.loss_type == 'l1':
            return F.l1_loss
        elif self.loss_type == 'l2':
            return F.mse_loss
        else:
            raise ValueError(f'invalid loss type {self.loss_type}')

    def p_losses(self, x_start, t, noise = None):
        """Training loss at timesteps ``t``: per-sample mean loss, P2-weighted, then averaged."""
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))

        x = self.q_sample(x_start = x_start, t = t, noise = noise)
        model_out = self.model(x, t)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = self.loss_fn(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')

        loss = loss * extract(self.p2_loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img, *args, **kwargs):
        """Compute the training loss for an image batch of shape (B, C, image_size, image_size)."""
        b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(img, t, *args, **kwargs)

## 2.9. dataset classes

In [None]:
# dataset classes
class Dataset(torch.utils.data.Dataset):
    """Image-folder dataset: every file under ``folder`` (searched recursively) with an extension in
    ``exts`` is one sample.

    Each image goes through the optional mode conversion, ``T.Resize(image_size)``, an optional
    random horizontal flip, ``T.CenterCrop(image_size)`` and ``T.ToTensor()``.

    Subclasses ``torch.utils.data.Dataset`` explicitly. The old ``class Dataset(Dataset)`` relied on
    the imported name, so re-running this cell made the class inherit from its own previous
    definition and the ``super().__init__()`` call then failed.

    Args:
        folder: root directory searched with ``**/*.{ext}``.
        image_size: size passed to ``T.Resize`` and ``T.CenterCrop``.
        exts: file extensions to include (matched case-sensitively by ``glob``).
        augment_horizontal_flip: apply ``T.RandomHorizontalFlip()`` when True.
        convert_image_to: optional PIL mode (e.g. ``'RGB'``) every image is converted to.
    """

    # Captured when the class is created. Inside __init__ the `convert_image_to` *argument* shadows
    # the module-level helper of the same name.
    _convert_image_fn = staticmethod(convert_image_to)

    def __init__(
        self,
        folder,
        image_size,
        exts = ('jpg', 'jpeg', 'png', 'tiff'),
        augment_horizontal_flip = False,
        convert_image_to = None
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        ########################################################################
        # debug output: a few of the found paths and the total count
        print(f"dataset path 몇 가지만 프린트:{self.paths[0:5]}")
        print(f"self.paths의 길이 : {len(self.paths)}")
        ########################################################################

        # The old `partial(convert_image_to, convert_image_to)` bound the mode *string* as the
        # callable and raised TypeError for any non-None `convert_image_to`.
        maybe_convert_fn = partial(self._convert_image_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()

        self.transform = T.Compose([
            T.Lambda(maybe_convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # The context manager closes the file handle. The transform (ending in ToTensor, which
        # copies the pixel data) runs inside it, so nothing reads from the file after it is closed.
        with Image.open(path) as img:
            return self.transform(img)

In [None]:
# Sanity check: build the dataset once to confirm images are found under ./original.
# The constructor prints a few of the found paths and the total number of files.
Dataset('./original', 32, augment_horizontal_flip=True, convert_image_to=None)

dataset path 몇 가지만 프린트:[PosixPath('original/4668.png'), PosixPath('original/4669.png'), PosixPath('original/4670.png'), PosixPath('original/4671.png'), PosixPath('original/4672.png')]
self.paths의 길이 : 5668


<__main__.Dataset at 0x7fc73b4a74d0>

## 2.10. trainer class

In [None]:
# trainer class
class Trainer(object):
    """Training loop for a ``GaussianDiffusion`` model on an image folder, built on ``accelerate``.

    Each step draws ``gradient_accumulate_every`` batches, backpropagates the diffusion loss and
    takes an Adam step. On the main process an EMA copy of the model is updated every step.
    Every ``save_and_sample_every`` steps the EMA model samples ``num_samples`` images, saved as a
    grid to ``results_folder/sample-{milestone}.png``.

    Args (beyond the obvious optimiser/data settings):
        num_samples: images per sample grid; must be a perfect square.
        results_folder: directory for sample grids (and checkpoints, if enabled).
        amp, fp16, split_batches: forwarded to / patched onto the ``Accelerator``.
        convert_image_to: optional PIL mode passed to ``Dataset``.
        save_checkpoints: when True, ``save()`` writes ``results_folder/model-{milestone}.pt``.
            Off by default because checkpoints take a lot of disk space (the checkpoint write was
            previously commented out for that reason).
    """
    def __init__(
        self,
        diffusion_model,
        folder,
        *,
        train_batch_size = 16,
        gradient_accumulate_every = 1,
        augment_horizontal_flip = True,
        train_lr = 1e-4,
        train_num_steps = 100000,
        ema_update_every = 10,
        ema_decay = 0.995,
        adam_betas = (0.9, 0.99),
        save_and_sample_every = 100, # originally 1000
        num_samples = 25,
        results_folder = './DDPMresults', # directory where results are saved
        amp = False,
        fp16 = False,
        split_batches = True,
        convert_image_to = None,
        save_checkpoints = False
    ):
        super().__init__()

        self.accelerator = Accelerator(
            split_batches = split_batches,
            mixed_precision = 'fp16' if fp16 else 'no'
        )

        # NOTE(review): patching `native_amp` appears intended to make `accelerator.autocast()` use
        # autocast even when mixed_precision == 'no' (no GradScaler is created then). The exact
        # effect depends on the installed accelerate version — confirm.
        self.accelerator.native_amp = amp

        self.model = diffusion_model

        assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
        self.num_samples = num_samples
        self.save_and_sample_every = save_and_sample_every
        self.save_checkpoints = save_checkpoints

        self.batch_size = train_batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.train_num_steps = train_num_steps
        self.image_size = diffusion_model.image_size

        # dataset and dataloader

        self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
        dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

        ########################################################################
        print(f"type of dl : {type(dl)}")
        ########################################################################

        dl = self.accelerator.prepare(dl)
        # infinite iterator over batches
        self.dl = cycle(dl)

        # optimizer

        self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)

        # Every process knows where results live (load() needs the path). Only the main process
        # creates the folder and keeps the EMA model used for sampling and checkpoints.
        self.results_folder = Path(results_folder)

        if self.accelerator.is_main_process:
            self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
            self.results_folder.mkdir(parents = True, exist_ok = True)

        # step counter state

        self.step = 0

        # prepare model, optimizer with accelerator (the dataloader was prepared above)

        self.model, self.opt = self.accelerator.prepare(self.model, self.opt)

    def save(self, milestone):
        """Write a checkpoint for ``milestone`` if ``save_checkpoints`` is enabled (main process only)."""
        # `self.ema` exists only on the main process (see __init__). The old check on
        # `is_local_main_process` would hit a missing attribute on other nodes' local-main process.
        # When saving is disabled, return before gathering the (potentially large) state dicts.
        if not self.accelerator.is_main_process or not self.save_checkpoints:
            return

        data = {
            'step': self.step,
            'model': self.accelerator.get_state_dict(self.model),
            'opt': self.opt.state_dict(),
            'ema': self.ema.state_dict(),
            'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None
        }

        torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))

    def load(self, milestone):
        """Restore model, optimizer, step, EMA (main process only) and grad scaler from a checkpoint."""
        # map_location lets a checkpoint written on another device be loaded here
        data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location = self.accelerator.device)

        model = self.accelerator.unwrap_model(self.model)
        model.load_state_dict(data['model'])

        self.step = data['step']
        self.opt.load_state_dict(data['opt'])
        # the EMA model only exists on the main process
        if self.accelerator.is_main_process:
            self.ema.load_state_dict(data['ema'])

        if exists(self.accelerator.scaler) and exists(data['scaler']):
            self.accelerator.scaler.load_state_dict(data['scaler'])

    def train(self):
        """Optimise until ``self.step`` reaches ``train_num_steps``, sampling (and saving) periodically."""
        accelerator = self.accelerator
        device = accelerator.device

        with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:

            while self.step < self.train_num_steps:

                total_loss = 0.

                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)

                    with self.accelerator.autocast():
                        loss = self.model(data)
                        # average the loss over the accumulated batches
                        loss = loss / self.gradient_accumulate_every
                        total_loss += loss.item()

                    self.accelerator.backward(loss)

                pbar.set_description(f'loss: {total_loss:.4f}')

                accelerator.wait_for_everyone()

                self.opt.step()
                self.opt.zero_grad()

                accelerator.wait_for_everyone()

                if accelerator.is_main_process:
                    self.ema.to(device)
                    self.ema.update()

                    if self.step != 0 and self.step % self.save_and_sample_every == 0:
                        self.ema.ema_model.eval()

                        with torch.no_grad():
                            milestone = self.step // self.save_and_sample_every
                            batches = num_to_groups(self.num_samples, self.batch_size)
                            all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))

                        all_images = torch.cat(all_images_list, dim = 0)
                        ##############################################################################
                        print(f"all_images 사이즈 : {len(all_images)}")
                        ##############################################################################

                        utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
                        self.save(milestone)

                self.step += 1
                pbar.update(1)
        print(f"self.step:{self.step}")

        accelerator.print('training complete')

# 3. Run (실행)

In [None]:
# from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer : not needed, because the library code is pasted directly into the cells above

In [None]:
# Denoising U-Net: base width 64, four resolution stages with channel multipliers 1/2/4/8.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
).cuda() # .cuda(): moves the module's parameters and buffers onto the GPU

In [None]:
# Wrap the U-Net in the diffusion process. Because sampling_timesteps (250) < timesteps (1000),
# sampling uses the DDIM sampler.
diffusion = GaussianDiffusion(
    model,
    image_size = 32, ## changed from 128 to 32
    timesteps = 1000,           # number of steps
    sampling_timesteps = 250,   # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
    loss_type = 'l1'            # L1 or L2
).cuda()

In [None]:
# Trainer over the images in ./original. Each optimizer step accumulates 2 batches of 32.
# Sample grids are written to the Trainer's results_folder (default './DDPMresults').
trainer = Trainer(
    diffusion,
    './original',
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True                        # turn on mixed precision (fp16 is left at its default, False)
)

dataset path 몇 가지만 프린트:[PosixPath('original/4668.png'), PosixPath('original/4669.png'), PosixPath('original/4670.png'), PosixPath('original/4671.png'), PosixPath('original/4672.png')]
self.paths의 길이 : 5668
type of dl : <class 'torch.utils.data.dataloader.DataLoader'>


In [None]:
# Run the training loop. Per the Trainer code above, it periodically samples images, writes
# sample-{milestone}.png to results_folder and calls self.save(milestone).
trainer.train()

  0%|          | 0/700000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

all_images 사이즈 : 25


sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

In [None]:
# Second call to train() on the same trainer object.
# Trainer.train() ends without a return statement, so it yields None; storing it in
# `result` would only bind None, so the call is made for its side effects alone.
# self.step is incremented inside train() and lives on the trainer, so this call
# presumably continues from the step count reached by the previous (interrupted) run
# rather than from 0. NOTE(review): confirm against the loop condition in Trainer.train.
trainer.train()

# 해야 할 것
* DiffPure
* 모델 저장 어떻게 하는지? (참고: 위 `Trainer.train()`은 milestone마다 `self.save(milestone)`을 호출함)
* 코드 정확히 이해할 것
* fgsm으로 돌려보고 classifier 넣어봐서 성능 보기. 그냥 diffusion만으로도 fgsm이 제거되는지 확인 필요.