<a href="https://colab.research.google.com/github/bochendong/diffusion-model/blob/main/%E2%80%9CUntitled12_ipynb%E2%80%9D%E7%9A%84%E5%89%AF%E6%9C%AC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Imports

from contextlib import contextmanager
from copy import deepcopy
import math
import os
import glob
from IPython import display
from matplotlib import pyplot as plt
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from torchvision import datasets, transforms, utils, models
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm, trange
from torchvision.utils import save_image
import numpy as np

In [None]:
# Global configuration and data pipeline for the notebook.

# Images per batch. Other cells build per-batch class layouts with
# `torch.arange(10).repeat_interleave(10)` (100 entries) and plot 10x10 grids,
# so they implicitly rely on this being 100.
batch_size = 100
# NOTE(review): not referenced by any code visible in this notebook; the
# training driver loops until interrupted.
epoches = 30

# EMA decay rate. Only referenced from a commented-out `ema_update` call in
# `train`, so it currently has no effect.
ema_decay = 0.999

# The number of timesteps to use when sampling
steps = 500

# The amount of noise to add each timestep when sampling
eta = 1.

# Classifier-free guidance scale (0 is unconditional, 1 is conditional)
guidance_scale = 2.


# Convert to tensors and normalize with mean 0.5 / std 0.5; `save_image`
# undoes this with `(x + 1) / 2` before plotting.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

train_set = datasets.CIFAR10('data', train=True, download=True, transform=tf)
# shuffle=False is load-bearing: `train` looks up precomputed targets with
# `nearest_neighbor[i]`, i.e. by batch position, so batch order must be stable.
train_dl = data.DataLoader(train_set, batch_size, shuffle=False, num_workers=4, persistent_workers=True, pin_memory=True)

val_set = datasets.CIFAR10('data', train=False, download=True, transform=tf)
val_dl = data.DataLoader(val_set, batch_size, num_workers=4, persistent_workers=True, pin_memory=True)

# Class index -> human-readable name, used for row titles in saved grids.
cifar_10_classes = {
    0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer",
    5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck",
}

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Make sure the output directory exists. `exist_ok=True` replaces the manual
# "check, then create" sequence and is safe if the directory is already there.
os.makedirs("output", exist_ok=True)

# Remove PNGs left over from a previous run so the directory only ever
# contains images produced by the current run.
for stale_png in glob.glob("./output/*.png"):
    os.remove(stale_png)

In [None]:
def get_alphas_sigmas(t):
    """
    Map timesteps to the (alpha, sigma) pair of the noise schedule.

    alpha = cos(t * pi / 2) is the scale applied to the clean image and
    sigma = sin(t * pi / 2) is the scale applied to the noise, so t = 0 gives
    (1, 0) and t = 1 gives (0, 1). Both outputs have the same shape as `t`.
    """
    angle = t * math.pi / 2
    alpha = torch.cos(angle)
    sigma = torch.sin(angle)
    return alpha, sigma

In [None]:
class ResidualBlock(nn.Module):
    """
    Residual wrapper: computes ``main(input) + skip(input)``.

    Args:
        main: iterable of modules, chained in order with ``nn.Sequential``.
        skip: optional module applied on the shortcut path. When ``None``
            the shortcut is the identity.
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        # Compare against None explicitly: container modules such as
        # nn.Sequential define __len__, so an empty one is falsy and a plain
        # truthiness test would silently swap it for nn.Identity.
        self.skip = nn.Identity() if skip is None else skip

    def forward(self, input):
        return self.main(input) + self.skip(input)


class ResConvBlock(ResidualBlock):
    """
    Two 3x3 convolutions (c_in -> c_mid -> c_out) with a residual shortcut.

    Each convolution is followed by Dropout2d(0.1) and ReLU, except that the
    dropout/ReLU after the second convolution are replaced by identities when
    `is_last` is true. The shortcut is a bias-free 1x1 convolution when the
    channel count changes, and the identity otherwise.
    """

    def __init__(self, c_in, c_mid, c_out, is_last=False):
        # Shortcut first, so parameter creation order matches the layer list
        # that follows.
        if c_in == c_out:
            shortcut = None
        else:
            shortcut = nn.Conv2d(c_in, c_out, 1, bias=False)

        layers = [
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.Dropout2d(0.1, inplace=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_mid, c_out, 3, padding=1),
        ]
        # Placeholders keep the Sequential the same length in both variants.
        if is_last:
            layers.append(nn.Identity())
            layers.append(nn.Identity())
        else:
            layers.append(nn.Dropout2d(0.1, inplace=True))
            layers.append(nn.ReLU(inplace=True))

        super().__init__(layers, shortcut)


class SelfAttention2d(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map.

    The input is group-normalized, projected to queries/keys/values with a
    1x1 convolution, attended over all h*w positions per head, projected back
    with another 1x1 convolution, passed through dropout and added to the
    input (residual connection). `c_in` must be divisible by `n_head`.
    """

    def __init__(self, c_in, n_head=1, dropout_rate=0.1):
        super().__init__()
        assert c_in % n_head == 0
        self.n_head = n_head
        self.norm = nn.GroupNorm(1, c_in)
        self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv2d(c_in, c_in, 1)
        self.dropout = nn.Dropout2d(dropout_rate, inplace=True)

    def forward(self, input):
        batch, channels, height, width = input.shape
        head_dim = channels // self.n_head
        positions = height * width

        # [batch, 3 * heads, positions, head_dim] after the transpose.
        projected = self.qkv_proj(self.norm(input))
        projected = projected.view([batch, self.n_head * 3, head_dim, positions])
        projected = projected.transpose(2, 3)
        query, key, value = projected.chunk(3, dim=1)

        # Scale queries and keys by head_dim**-0.25 each, so their product is
        # scaled by head_dim**-0.5 overall.
        scale = head_dim**-0.25
        weights = (query * scale) @ (key.transpose(2, 3) * scale)
        weights = weights.softmax(3)

        # Back to an image-shaped tensor: [batch, channels, height, width].
        attended = (weights @ value).transpose(2, 3).contiguous()
        attended = attended.view([batch, channels, height, width])
        return input + self.dropout(self.out_proj(attended))


class FourierFeatures(nn.Module):
    """
    Learnable random Fourier feature embedding.

    Projects the last dimension of the input (size `in_features`) onto
    `out_features // 2` frequencies and returns the cosines followed by the
    sines, giving `out_features` values. The frequency matrix is initialised
    from a normal distribution scaled by `std` and is a trainable parameter.
    `out_features` must be even.
    """

    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        n_frequencies = out_features // 2
        self.weight = nn.Parameter(torch.randn([n_frequencies, in_features]) * std)

    def forward(self, input):
        phase = (2 * math.pi * input) @ self.weight.T
        cosines = phase.cos()
        sines = phase.sin()
        return torch.cat([cosines, sines], dim=-1)


def expand_to_planes(input, shape):
    """
    Tile per-sample feature vectors into constant feature planes.

    Two trailing singleton axes are appended to `input` and then repeated
    `shape[2]` x `shape[3]` times, so a [batch, features] tensor becomes
    [batch, features, shape[2], shape[3]].
    """
    height, width = shape[2], shape[3]
    planes = input.unsqueeze(-1).unsqueeze(-1)
    return planes.repeat([1, 1, height, width])



class Diffusion(nn.Module):
    """
    U-Net-style denoising network conditioned on a timestep and a class label.

    The timestep is embedded with 16 Fourier features and the class with a
    4-dimensional embedding; both are broadcast to feature planes and
    concatenated with the 3-channel input. Three average-pooling stages
    reduce resolution, three bilinear upsampling stages restore it, and the
    output of each encoder stage is concatenated onto the matching decoder
    stage. The output has 3 channels.

    The resolution comments below assume a 32x32 input.
    """

    def __init__(self):
        super().__init__()
        # The base channel count (currently unused: the widths below are
        # written out literally).
        c = 64

        self.timestep_embed = FourierFeatures(1, 16)
        # 11 rows: `forward` looks up `cond + 1`, so labels -1..9 map to rows 0..10.
        self.class_embed = nn.Embedding(11, 4)
        # Pass-through applied to encoder outputs before they are concatenated
        # onto the decoder path.
        self.skip = nn.Identity()
        self.block_1 = nn.Sequential(                   # 32x32, 3 image + 16 timestep + 4 class channels in
            ResConvBlock(3 + 16 + 4, 64, 64),
            ResConvBlock(64, 64, 64)
        )

        self.block_2 = nn.Sequential(
            nn.AvgPool2d(2),                            # 32x32 -> 16x16
            ResConvBlock(64, 128, 128),
            ResConvBlock(128, 128, 128)
        )


        self.block_3 = nn.Sequential(
            nn.AvgPool2d(2),                            # 16x16 -> 8x8
            ResConvBlock(128, 256, 256),
            SelfAttention2d(256, 256 // 64),
            ResConvBlock(256, 256, 256),
            SelfAttention2d(256, 256 // 64),
        )

        # Bottleneck: works at 4x4 and upsamples back to 8x8 at the end.
        self.block_4 =  nn.Sequential(
            nn.AvgPool2d(2),                            # 8x8 -> 4x4
            ResConvBlock(256, 512, 512),
            SelfAttention2d(512, 512 // 64),
            ResConvBlock(512, 512, 512),
            SelfAttention2d(512, 512// 64),
            ResConvBlock(512, 512, 512),
            SelfAttention2d(512, 512 // 64),
            ResConvBlock(512, 512, 256),
            SelfAttention2d(256, 256// 64),
            nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False),  # 4x4 -> 8x8
        )

        # Input is block_4 output (256) concatenated with block_3 output (256).
        self.block_5 = nn.Sequential(                   # 8x8 -> 16x16
            ResConvBlock(512, 256, 256),
            SelfAttention2d(256, 256 // 64),
            ResConvBlock(256, 256, 128),
            SelfAttention2d(128, 128// 64),
            nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False),
        )

        # Input is block_5 output (128) concatenated with block_2 output (128).
        self.block_6 = nn.Sequential(                   # 16x16 -> 32x32
            ResConvBlock(256, 128, 128),
            ResConvBlock(128, 128, 64),
            nn.Upsample(scale_factor = 2, mode='bilinear', align_corners=False),
        )

        # Input is block_6 output (64) concatenated with block_1 output (64);
        # the final block has no trailing dropout/ReLU (is_last=True).
        self.block_7 = nn.Sequential(
            ResConvBlock(128, 64, 64),
            ResConvBlock(64, 64, 3, is_last=True)
        )
    def forward(self, input, t, cond):
        """
        Run the network.

        Args:
            input: image batch with 3 channels (block_1 expects 3 + 16 + 4
                channels after the embeddings are concatenated).
            t: 1-D tensor of timesteps, one per sample.
            cond: integer class labels, one per sample; shifted by +1 before
                the embedding lookup, so -1 selects row 0.

        Returns:
            A 3-channel tensor produced by block_7.
        """
        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
        class_embed = expand_to_planes(self.class_embed(cond + 1), input.shape)

        # Channel order: image, class planes, timestep planes.
        x = torch.cat([input, class_embed, timestep_embed], dim=1)

        x_1 = self.block_1(x)                                # 32x32

        x_2 = self.block_2(x_1)                              # 16x16

        x_3 = self.block_3(x_2)                              # 8x8

        x_4 = self.block_4(x_3)                              # 8x8 -> 4x4 -> 8x8
        x_4 = torch.cat([x_4, self.skip(x_3)], dim=1)

        x_5 = self.block_5(x_4)                              # 8x8 -> 16x16
        x_5 = torch.cat([x_5, self.skip(x_2)], dim=1)

        x_6 = self.block_6(x_5)                              # 16x16 -> 32x32
        x_6 = torch.cat([x_6, self.skip(x_1)], dim=1)

        x_7 = self.block_7(x_6)

        return x_7


In [None]:
# Seed for torch's global RNG (weight initialisation, noise draws).
seed = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
# Use the `seed` variable rather than a second hard-coded literal so that
# changing it above actually changes the run.
torch.manual_seed(seed)

model = Diffusion().to(device)
# Copy of the model reserved for exponential-moving-average weights.
# NOTE: the `ema_update` call in `train` is commented out, so this copy keeps
# the initial weights.
model_ema = deepcopy(model)

opt = optim.Adam(model.parameters(), lr=2e-4)
# Loss scaler paired with the autocast regions in `train`.
scaler = torch.cuda.amp.GradScaler()
# Scrambled Sobol sequence; `train` draws one timestep per sample from it.
rng = torch.quasirandom.SobolEngine(1, scramble=True)

Using device: cuda


In [None]:
@torch.no_grad()
def sample(model, x, steps, eta, classes, guidance_scale=1.):
    """
    Draws samples from a model given starting noise.

    Args:
        model: callable ``model(x, t, classes)`` returning the predicted
            velocity ``v`` with the same shape as ``x``.
        x: starting tensor, one entry per sample along dim 0.
        steps: number of sampling steps; must be at least 1.
        eta: amount of fresh noise added at each step (0 disables it).
        classes: integer class labels, one per sample.
        guidance_scale: classifier-free guidance weight; the result is
            ``v_uncond + guidance_scale * (v_cond - v_uncond)``.

    Returns:
        The denoised prediction computed at the final step.

    Raises:
        ValueError: if ``steps`` is smaller than 1. Without this check the
            loop body never runs and the final ``return`` would fail with an
            UnboundLocalError.
    """
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule: `steps` timesteps from 1 down towards 0,
    # excluding 0 itself.
    t = torch.linspace(1, 0, steps + 1)[:-1]
    alphas, sigmas = get_alphas_sigmas(t)

    # The sampling loop
    for i in trange(steps):

        # Get the model output (v, the predicted velocity). The batch is
        # doubled so the unconditional (label -1) and conditional predictions
        # come from a single forward pass.
        with torch.cuda.amp.autocast():
            x_in = torch.cat([x, x])
            ts_in = torch.cat([ts, ts])
            classes_in = torch.cat([-torch.ones_like(classes), classes])
            v_uncond, v_cond = model(x_in, ts_in * t[i], classes_in).float().chunk(2)
        v = v_uncond + guidance_scale * (v_cond - v_uncond)

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < steps - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # If we are on the last timestep, output the denoised image
    return pred

In [None]:
def extract_features(images, pre_trian_model):
    """
    Run `pre_trian_model` on `images` without tracking gradients and return
    the result flattened to one row per image: [images.size(0), -1].
    """
    n_images = images.size(0)
    with torch.no_grad():
        raw = pre_trian_model(images)
    flat = raw.view(n_images, -1)
    return flat

def generate_synthetic_images(dataset, label, count):
    """
    Sample `count` images of class `label` from `dataset`, with replacement.

    Despite the name, nothing is synthesised: images are drawn from the
    dataset itself. `dataset.targets` supplies the labels and `dataset[i][0]`
    the image tensors; the picks are stacked along a new leading dimension.

    Raises:
        ValueError: if the dataset contains no sample with the given label.
    """
    all_labels = np.array(dataset.targets)
    matching = np.nonzero(all_labels == label)[0]

    if matching.size == 0:
        raise ValueError(f"No samples found for label {label}")

    chosen = np.random.choice(matching, count, replace=True)
    picked = []
    for idx in chosen:
        image, _ = dataset[idx]
        picked.append(image)
    return torch.stack(picked)

def build_features_dict(dataset, pre_trian_model, BATCH_SIZE, device):
    """
    For each of the labels 0..9, sample `BATCH_SIZE` images of that class,
    move them to `device`, and extract their features.

    Returns a dict mapping label -> {'images': tensor, 'features': tensor}.
    """
    per_class = {}
    for class_id in range(10):
        class_images = generate_synthetic_images(dataset, class_id, BATCH_SIZE).to(device)
        class_features = extract_features(class_images, pre_trian_model)
        per_class[class_id] = {
            'images': class_images,
            'features': class_features,
        }
    return per_class

def find_nearest_neighbor(images, target_label, dataset, pre_trian_vgg, device):
    """
    Pick, from a fresh random batch of `target_label` images, the one that is
    closest in feature space to any of the input `images`.

    Args:
        images: batch of input images, already on `device`.
        target_label: 0-dim tensor holding the class to sample candidates from.
        dataset: dataset passed through to `generate_synthetic_images`.
        pre_trian_vgg: feature extractor passed to `extract_features`.
        device: device the candidate images are moved to.

    Returns:
        A single candidate image: the candidate half of the
        (input, candidate) pair with the smallest feature distance.
    """
    # Draw a new batch of candidate images of the target class on every call
    # (the global `batch_size` sets how many).
    candidates = generate_synthetic_images(dataset, target_label.item(), batch_size)
    candidates = candidates.to(device)
    candidate_features = extract_features(candidates, pre_trian_vgg)

    # Extract features for input images
    input_features = extract_features(images, pre_trian_vgg)

    # Pairwise L2 distances, shape [n_inputs, n_candidates].
    distances = (input_features.unsqueeze(1) - candidate_features.unsqueeze(0)).norm(dim=2)

    # For every input: distance to, and index of, its nearest candidate.
    min_distances, min_indices = distances.min(dim=1)

    # The input with the overall smallest distance identifies the closest
    # pair; `min_indices` at that position is the candidate index. The earlier
    # code indexed `min_indices` a second time with that candidate index,
    # which mixes up the input and candidate axes.
    closest_input = min_distances.argmin()
    return candidates[min_indices[closest_input]]

In [None]:
@torch.no_grad()
def generate_nearest_neighbor(dataset, data_loader, model, device):
    """
    Precompute one target-image batch per batch of `data_loader`.

    Position j of every batch is assigned the class `j // 10` (classes 0..9,
    ten consecutive positions each), and its target image is whatever
    `find_nearest_neighbor` returns for the current batch and that class.

    Args:
        dataset: dataset the candidate images are sampled from.
        data_loader: loader yielding (images, labels) batches of at most 100
            samples; the labels themselves are not used.
        model: feature extractor forwarded to `find_nearest_neighbor`.
        device: device images are moved to.

    Returns:
        A list with one tensor per batch, each shaped like the image batch.
    """
    # Fixed class layout, identical for every batch, so build it once.
    target_labels = torch.arange(10, device=device).repeat_interleave(10, 0)

    arr = []
    for img, _ in tqdm(data_loader):
        img = img.to(device)

        # Start from a copy so the result has the batch's shape and dtype.
        target_image = img.clone()
        for j in range(img.size(0)):
            target_image[j] = find_nearest_neighbor(img, target_labels[j], dataset, model, device)

        arr.append(target_image)

    return arr

In [None]:
# Convolutional feature stack of a pretrained VGG16, in eval mode, used only
# as a fixed feature extractor for nearest-neighbour lookups.
pre_trian_vgg = models.vgg16(pretrained=True).features.eval().to(device)
# One precomputed target batch per training batch; `train` reads these by
# batch position, so this must be built from the same loader it iterates.
nearest_neighbor = generate_nearest_neighbor(train_set, train_dl, pre_trian_vgg, device)

In [None]:
def _save_image_grid(batch, path, row_titles=None):
    """Plot the first 100 images of `batch` as a 10x10 grid and save it to `path`."""
    fig, axs = plt.subplots(10, 10, figsize=(15, 20))
    for c in range(10):
        for r in range(10):
            axs[c][r].imshow(batch[c * 10 + r].permute(1, 2, 0).to("cpu").numpy())
        # Optional title on the first cell of each row.
        if row_titles is not None:
            axs[c][0].title.set_text(row_titles[c])

    fig.savefig(path)
    plt.close(fig)


@torch.no_grad()
def save_image(imgs, targets, epoch, step):
    """
    Sample from the current model and write three 10x10 image grids to
    ./output: generated images (`*_fake.png`), the input batch (`*_real.png`)
    and the target batch (`*_target.png`).

    Note: this definition rebinds the name `save_image` that the import cell
    takes from `torchvision.utils`.

    Args:
        imgs: batch of at least 100 normalized images; noised to t = 0.8 and
            used as the sampler's starting point.
        targets: batch of at least 100 normalized target images (plotted only).
        epoch: epoch number, used in the file names.
        step: iteration number, used in the file names.
    """
    # One timestep per sample, broadcast over channels and pixels. A
    # fixed-length vector of 32 would only broadcast against the image width
    # by coincidence.
    t = imgs.new_full([imgs.shape[0]], 0.8)
    alphas, sigmas = get_alphas_sigmas(t)
    alphas, sigmas = alphas[:, None, None, None], sigmas[:, None, None, None]
    noise = torch.randn_like(imgs)
    noised_imgs = imgs * alphas + noise * sigmas

    # Ten consecutive samples per class, matching the rows of the grid.
    fakes_classes = torch.arange(10, device=imgs.device).repeat_interleave(10, 0)

    # Sample in eval mode so the Dropout layers are inactive, then restore
    # whatever mode the model was in.
    was_training = model.training
    model.eval()
    try:
        fakes = sample(model, noised_imgs, steps, eta, fakes_classes, guidance_scale)
    finally:
        model.train(was_training)

    # Undo the mean-0.5 / std-0.5 normalization and clip to the displayable range.
    fakes = torch.clamp((fakes + 1) / 2, min=0, max=1)
    imgs = torch.clamp((imgs + 1) / 2, min=0, max=1)
    targets = torch.clamp((targets + 1) / 2, min=0, max=1)

    _save_image_grid(fakes, './output/%03d_%03d_fake.png' % (epoch, step), cifar_10_classes)
    _save_image_grid(imgs, './output/%03d_%03d_real.png' % (epoch, step))
    _save_image_grid(targets, './output/%03d_%03d_target.png' % (epoch, step))

In [None]:
def train(epoch):
    """
    Run one pass over `train_dl`, updating the global `model` with `opt`.

    Each batch is noised at Sobol-drawn timesteps and the model is trained
    with an MSE loss against `noise * alpha - target_image * sigma`, where
    `target_image` is the precomputed `nearest_neighbor` batch for this batch
    position and the conditioning labels are a fixed 10-classes-by-10 layout
    rather than the dataset labels.

    Args:
        epoch: epoch number; used only in log lines and output file names.
    """
    for i, (imgs, labels) in enumerate(tqdm(train_dl)):
        opt.zero_grad()
        imgs, labels = imgs.to(device), labels.to(device)

        # Draw one continuous timestep per sample from the Sobol sequence
        t = rng.draw(imgs.shape[0])[:, 0].to(device)

        # Calculate the noise schedule parameters for those timesteps
        alphas, sigmas = get_alphas_sigmas(t)

        # Reshape to [batch, 1, 1, 1] so they broadcast over channels and pixels
        alphas, sigmas  = alphas[:, None, None, None], sigmas[:, None, None, None]
        noise = torch.randn_like(imgs)

        # As t increases, alpha decreases and sigma increases, so the noised
        # image contains more noise and less of the original image.
        noised_imgs = imgs * alphas + noise * sigmas

        # Look up the precomputed target batch by batch position (this relies
        # on the loader yielding batches in a fixed order) and rebuild the
        # fixed class layout used for conditioning.
        with torch.no_grad():
            target_image = nearest_neighbor[i]
            target_labels = torch.arange(10, device=device).repeat_interleave(10, 0)

        # Regression target built from the target image rather than `imgs`:
        # as t increases it moves from the noise towards -target_image.

        targets = noise * alphas - target_image * sigmas

        '''
        # Drop out the class on 20% of the examples
        to_drop = torch.rand(labels.shape, device=labels.device).le(0.2)
        labels_drop = torch.where(to_drop, -torch.ones_like(labels), labels)
        '''

        # Compute the model output and the loss.
        with torch.cuda.amp.autocast():
            v = model(noised_imgs, t, target_labels)
            loss = F.mse_loss(v, targets)

        # Backpropagate the scaled loss and take the optimizer step
        scaler.scale(loss).backward()
        scaler.step(opt)

        # EMA update is currently disabled:
        # ema_update(model, model_ema, 0.95 if epoch < 20 else ema_decay)
        scaler.update()
        with torch.no_grad():
            # Log the loss every 50 iterations
            if i % 50 == 0:
                tqdm.write(f'Epoch: {epoch}, iteration: {i}, loss: {loss.item():g}')

            # Write sample grids every 200 iterations
            if i % 200 == 0:
                save_image(imgs, target_image, epoch, i)


In [None]:
# Train epoch after epoch (numbered from 1) until the cell is interrupted.
# NOTE(review): the `epoches` setting is not consulted here, so there is no
# upper bound on the number of epochs — confirm this is intended.
try:
    i = 0
    while True:
        i += 1
        train(i)
except KeyboardInterrupt:
    # Interrupting the cell is the intended way to stop training.
    pass

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, iteration: 0, loss: 0.639112


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, iteration: 50, loss: 0.325639
Epoch: 1, iteration: 100, loss: 0.23052
Epoch: 1, iteration: 150, loss: 0.269523
Epoch: 1, iteration: 200, loss: 0.230026


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, iteration: 250, loss: 0.219097
Epoch: 1, iteration: 300, loss: 0.240443
Epoch: 1, iteration: 350, loss: 0.216763
Epoch: 1, iteration: 400, loss: 0.217098


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1, iteration: 450, loss: 0.210034


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 2, iteration: 0, loss: 0.215162


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 2, iteration: 50, loss: 0.201808
Epoch: 2, iteration: 100, loss: 0.19382
Epoch: 2, iteration: 150, loss: 0.239814
Epoch: 2, iteration: 200, loss: 0.215531


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 2, iteration: 250, loss: 0.199768
Epoch: 2, iteration: 300, loss: 0.214811
Epoch: 2, iteration: 350, loss: 0.191664
Epoch: 2, iteration: 400, loss: 0.193257


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 2, iteration: 450, loss: 0.200713


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 3, iteration: 0, loss: 0.197201


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 3, iteration: 50, loss: 0.199351
Epoch: 3, iteration: 100, loss: 0.182182
Epoch: 3, iteration: 150, loss: 0.233201
Epoch: 3, iteration: 200, loss: 0.199819


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 3, iteration: 250, loss: 0.186755
Epoch: 3, iteration: 300, loss: 0.204473
Epoch: 3, iteration: 350, loss: 0.195517
Epoch: 3, iteration: 400, loss: 0.194255


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 3, iteration: 450, loss: 0.200657


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 4, iteration: 0, loss: 0.198116


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 4, iteration: 50, loss: 0.195892
Epoch: 4, iteration: 100, loss: 0.177049
Epoch: 4, iteration: 150, loss: 0.21525
Epoch: 4, iteration: 200, loss: 0.190704


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 4, iteration: 250, loss: 0.181123
Epoch: 4, iteration: 300, loss: 0.200498
Epoch: 4, iteration: 350, loss: 0.187362
Epoch: 4, iteration: 400, loss: 0.196207


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 4, iteration: 450, loss: 0.190229


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 5, iteration: 0, loss: 0.194052


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 5, iteration: 50, loss: 0.186765
Epoch: 5, iteration: 100, loss: 0.173807
Epoch: 5, iteration: 150, loss: 0.218617
Epoch: 5, iteration: 200, loss: 0.197386


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 5, iteration: 250, loss: 0.180376
Epoch: 5, iteration: 300, loss: 0.208493
Epoch: 5, iteration: 350, loss: 0.187955
Epoch: 5, iteration: 400, loss: 0.187599


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 5, iteration: 450, loss: 0.183585


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 6, iteration: 0, loss: 0.18729


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 6, iteration: 50, loss: 0.185407
Epoch: 6, iteration: 100, loss: 0.167035
Epoch: 6, iteration: 150, loss: 0.222346
Epoch: 6, iteration: 200, loss: 0.194611


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 6, iteration: 250, loss: 0.184077
Epoch: 6, iteration: 300, loss: 0.19495
Epoch: 6, iteration: 350, loss: 0.183257
Epoch: 6, iteration: 400, loss: 0.185387


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 6, iteration: 450, loss: 0.1929


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 7, iteration: 0, loss: 0.186232


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 7, iteration: 50, loss: 0.189119
Epoch: 7, iteration: 100, loss: 0.171711
Epoch: 7, iteration: 150, loss: 0.218847
Epoch: 7, iteration: 200, loss: 0.186774


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 7, iteration: 250, loss: 0.170118
Epoch: 7, iteration: 300, loss: 0.192095
Epoch: 7, iteration: 350, loss: 0.1806
Epoch: 7, iteration: 400, loss: 0.187235


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 7, iteration: 450, loss: 0.185476


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 8, iteration: 0, loss: 0.195676


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 8, iteration: 50, loss: 0.1838
Epoch: 8, iteration: 100, loss: 0.173493
Epoch: 8, iteration: 150, loss: 0.210627
Epoch: 8, iteration: 200, loss: 0.18701


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 8, iteration: 250, loss: 0.176671
Epoch: 8, iteration: 300, loss: 0.198406
Epoch: 8, iteration: 350, loss: 0.179646
Epoch: 8, iteration: 400, loss: 0.191742


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 8, iteration: 450, loss: 0.184664


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 9, iteration: 0, loss: 0.18499


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 9, iteration: 50, loss: 0.174006
Epoch: 9, iteration: 100, loss: 0.162154
Epoch: 9, iteration: 150, loss: 0.213001
Epoch: 9, iteration: 200, loss: 0.185564


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 9, iteration: 250, loss: 0.179072
Epoch: 9, iteration: 300, loss: 0.195123
Epoch: 9, iteration: 350, loss: 0.186098
Epoch: 9, iteration: 400, loss: 0.184821


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 9, iteration: 450, loss: 0.185968


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 10, iteration: 0, loss: 0.186276


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 10, iteration: 50, loss: 0.179786
Epoch: 10, iteration: 100, loss: 0.165015
Epoch: 10, iteration: 150, loss: 0.218414
Epoch: 10, iteration: 200, loss: 0.187189


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 10, iteration: 250, loss: 0.182493
Epoch: 10, iteration: 300, loss: 0.188537
Epoch: 10, iteration: 350, loss: 0.173693
Epoch: 10, iteration: 400, loss: 0.178078


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 10, iteration: 450, loss: 0.180698


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 11, iteration: 0, loss: 0.18093


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 11, iteration: 50, loss: 0.179435
Epoch: 11, iteration: 100, loss: 0.174203
Epoch: 11, iteration: 150, loss: 0.209931
Epoch: 11, iteration: 200, loss: 0.19083


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 11, iteration: 250, loss: 0.168665
Epoch: 11, iteration: 300, loss: 0.193496
Epoch: 11, iteration: 350, loss: 0.176829
Epoch: 11, iteration: 400, loss: 0.186412


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 11, iteration: 450, loss: 0.185167


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 12, iteration: 0, loss: 0.192857


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 12, iteration: 50, loss: 0.178681
Epoch: 12, iteration: 100, loss: 0.168151
Epoch: 12, iteration: 150, loss: 0.203759
Epoch: 12, iteration: 200, loss: 0.179511


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 12, iteration: 250, loss: 0.174846
Epoch: 12, iteration: 300, loss: 0.188093
Epoch: 12, iteration: 350, loss: 0.181438
Epoch: 12, iteration: 400, loss: 0.185659


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 12, iteration: 450, loss: 0.185928


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 13, iteration: 0, loss: 0.182659


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 13, iteration: 50, loss: 0.180659
Epoch: 13, iteration: 100, loss: 0.16342
Epoch: 13, iteration: 150, loss: 0.21362
Epoch: 13, iteration: 200, loss: 0.184443


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 13, iteration: 250, loss: 0.174532
Epoch: 13, iteration: 300, loss: 0.196528
Epoch: 13, iteration: 350, loss: 0.185077
Epoch: 13, iteration: 400, loss: 0.180165


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 13, iteration: 450, loss: 0.176232


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 14, iteration: 0, loss: 0.179794


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 14, iteration: 50, loss: 0.175239
Epoch: 14, iteration: 100, loss: 0.164637
Epoch: 14, iteration: 150, loss: 0.211846
Epoch: 14, iteration: 200, loss: 0.190142


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 14, iteration: 250, loss: 0.175944
Epoch: 14, iteration: 300, loss: 0.194714
Epoch: 14, iteration: 350, loss: 0.172731
Epoch: 14, iteration: 400, loss: 0.181982


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 14, iteration: 450, loss: 0.180168


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 15, iteration: 0, loss: 0.184607


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 15, iteration: 50, loss: 0.172814
Epoch: 15, iteration: 100, loss: 0.169357
Epoch: 15, iteration: 150, loss: 0.209267
Epoch: 15, iteration: 200, loss: 0.183364


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 15, iteration: 250, loss: 0.162863
Epoch: 15, iteration: 300, loss: 0.184924
Epoch: 15, iteration: 350, loss: 0.176695
Epoch: 15, iteration: 400, loss: 0.181449


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 15, iteration: 450, loss: 0.184248


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 16, iteration: 0, loss: 0.18808


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 16, iteration: 50, loss: 0.18101
Epoch: 16, iteration: 100, loss: 0.163587
Epoch: 16, iteration: 150, loss: 0.209076
Epoch: 16, iteration: 200, loss: 0.179851


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 16, iteration: 250, loss: 0.173872
Epoch: 16, iteration: 300, loss: 0.188638
Epoch: 16, iteration: 350, loss: 0.179286
Epoch: 16, iteration: 400, loss: 0.180615


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 16, iteration: 450, loss: 0.18297


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 17, iteration: 0, loss: 0.179337


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 17, iteration: 50, loss: 0.169925
Epoch: 17, iteration: 100, loss: 0.157545
Epoch: 17, iteration: 150, loss: 0.20394
Epoch: 17, iteration: 200, loss: 0.182811


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 17, iteration: 250, loss: 0.1686
Epoch: 17, iteration: 300, loss: 0.195318
Epoch: 17, iteration: 350, loss: 0.176926
Epoch: 17, iteration: 400, loss: 0.183424


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 17, iteration: 450, loss: 0.179282


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 18, iteration: 0, loss: 0.185835


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 18, iteration: 50, loss: 0.171477
Epoch: 18, iteration: 100, loss: 0.166992
Epoch: 18, iteration: 150, loss: 0.208063
Epoch: 18, iteration: 200, loss: 0.188192


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 18, iteration: 250, loss: 0.172187
Epoch: 18, iteration: 300, loss: 0.189182
Epoch: 18, iteration: 350, loss: 0.170439
Epoch: 18, iteration: 400, loss: 0.175481


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 18, iteration: 450, loss: 0.177211


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 19, iteration: 0, loss: 0.178218


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 19, iteration: 50, loss: 0.180264
Epoch: 19, iteration: 100, loss: 0.166334
Epoch: 19, iteration: 150, loss: 0.20846
Epoch: 19, iteration: 200, loss: 0.1809


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 19, iteration: 250, loss: 0.166741
Epoch: 19, iteration: 300, loss: 0.189753
Epoch: 19, iteration: 350, loss: 0.179423
Epoch: 19, iteration: 400, loss: 0.179073


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 19, iteration: 450, loss: 0.184688


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 20, iteration: 0, loss: 0.185154


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 20, iteration: 50, loss: 0.178095
Epoch: 20, iteration: 100, loss: 0.16044
Epoch: 20, iteration: 150, loss: 0.199303
Epoch: 20, iteration: 200, loss: 0.178438


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 20, iteration: 250, loss: 0.167005
Epoch: 20, iteration: 300, loss: 0.188425
Epoch: 20, iteration: 350, loss: 0.173133
Epoch: 20, iteration: 400, loss: 0.186181


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 20, iteration: 450, loss: 0.17938


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 21, iteration: 0, loss: 0.182196


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 21, iteration: 50, loss: 0.171877
Epoch: 21, iteration: 100, loss: 0.160693
Epoch: 21, iteration: 150, loss: 0.206549
Epoch: 21, iteration: 200, loss: 0.184524


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 21, iteration: 250, loss: 0.166371
Epoch: 21, iteration: 300, loss: 0.197106
Epoch: 21, iteration: 350, loss: 0.178235
Epoch: 21, iteration: 400, loss: 0.177534


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 21, iteration: 450, loss: 0.174186


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 22, iteration: 0, loss: 0.178105


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 22, iteration: 50, loss: 0.173564
Epoch: 22, iteration: 100, loss: 0.159513
Epoch: 22, iteration: 150, loss: 0.212328
Epoch: 22, iteration: 200, loss: 0.183194


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 22, iteration: 250, loss: 0.17712
Epoch: 22, iteration: 300, loss: 0.186431
Epoch: 22, iteration: 350, loss: 0.175498
Epoch: 22, iteration: 400, loss: 0.176645


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 22, iteration: 450, loss: 0.178386


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 23, iteration: 0, loss: 0.178499


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 23, iteration: 50, loss: 0.177276
Epoch: 23, iteration: 100, loss: 0.162307
Epoch: 23, iteration: 150, loss: 0.212342
Epoch: 23, iteration: 200, loss: 0.180505


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 23, iteration: 250, loss: 0.163124
Epoch: 23, iteration: 300, loss: 0.181691
Epoch: 23, iteration: 350, loss: 0.173298
Epoch: 23, iteration: 400, loss: 0.179386


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 23, iteration: 450, loss: 0.180006


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 24, iteration: 0, loss: 0.189083


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 24, iteration: 50, loss: 0.173116
Epoch: 24, iteration: 100, loss: 0.165296
Epoch: 24, iteration: 150, loss: 0.20078
Epoch: 24, iteration: 200, loss: 0.182739


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 24, iteration: 250, loss: 0.169916
Epoch: 24, iteration: 300, loss: 0.189887
Epoch: 24, iteration: 350, loss: 0.172403
Epoch: 24, iteration: 400, loss: 0.186015


  0%|          | 0/500 [00:00<?, ?it/s]