In [1]:
from contextlib import contextmanager
from copy import deepcopy
import math
import os
import glob
from PIL import Image
from IPython import display
from matplotlib import pyplot as plt
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils import data
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm, trange
from torchvision.models import resnet50
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Echo the selected device as this cell's output.
device

device(type='cuda')

In [None]:
# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

In [4]:
def get_alphas_sigmas(t):
    """
    Map timesteps to the (alpha, sigma) pair of the noise schedule.

    alpha = cos(t * pi / 2) is the scale applied to the clean image and
    sigma = sin(t * pi / 2) is the scale applied to the noise, so that
    alpha**2 + sigma**2 == 1 for every timestep.

    Args:
        t: tensor of timesteps.

    Returns:
        A tuple (alphas, sigmas) of tensors with the same shape as t.
    """
    angle = t * math.pi / 2
    return torch.cos(angle), torch.sin(angle)

# Make sure the output directory exists. exist_ok=True replaces the manual
# "exists() == False" check and is safe to re-run.
os.makedirs("./output", exist_ok=True)

# Start every run from a clean slate: delete the sample grids left over from
# previous runs so stale images are never mistaken for fresh results.
files = glob.glob("./output/*.png")

for f in files:
    os.remove(f)

# Data loader

In [5]:
# Number of images per training batch (passed to load_data_set).
batch_size = 200
# Presumably the intended number of training epochs -- confirm against the training loop.
epoches = 200
# Presumably the decay rate for an exponential moving average of the weights -- TODO confirm.
ema_decay = 0.999
# Number of steps used by sample() when generating images.
steps = 500
# Scales the fresh noise injected at each sampling step; sample() adds none when this is 0.
eta = 1.

# Classifier-free guidance weight: sample() computes
# v_uncond + guidance_scale * (v_cond - v_uncond).
guidance_scale = 2.

def load_data_set(batch_size=64):
    """
    Build the shuffled training DataLoader over the CIFAR-10 training split.

    Images are converted to tensors and normalised with mean 0.5 and std 0.5,
    which maps pixel values from [0, 1] to [-1, 1]. The dataset is downloaded
    to ./data when it is not already present.

    Args:
        batch_size: number of images per batch. Incomplete final batches are
            dropped (drop_last=True).

    Returns:
        A DataLoader yielding (images, labels) batches.
    """
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Load the entire CIFAR10 training split
    full_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=tf)

    # Keep the samples whose label is <= 9. Read the labels from the dataset's
    # `targets` list: enumerating the dataset itself would decode and transform
    # every image just to look at its label.
    source_indices = [i for i, label in enumerate(full_dataset.targets) if label <= 9]
    source_set = Subset(full_dataset, source_indices)

    source_dl = DataLoader(
        source_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )

    return source_dl

# Models

## Sub Models

In [6]:
class ResidualBlock(nn.Module):
    """
    Residual wrapper: output = main(input) + skip(input).

    Args:
        main: iterable of modules, applied in order as an nn.Sequential.
        skip: optional module for the shortcut branch. When None, the
            shortcut is the identity.
    """
    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        # Compare against None explicitly. Testing the module's truthiness
        # would silently discard container modules whose len() is 0
        # (e.g. an empty nn.Sequential).
        self.skip = nn.Identity() if skip is None else skip

    def forward(self, input):
        return self.main(input) + self.skip(input)

class ResConvBlock(ResidualBlock):
    """
    Residual block of two 3x3 convolutions (c_in -> c_mid -> c_out).

    Each convolution is followed by Dropout2d(0.1) and ReLU, except that the
    dropout and ReLU after the second convolution are replaced by identities
    when is_last is True. The shortcut is a bias-free 1x1 convolution when the
    channel count changes, and the identity otherwise.
    """
    def __init__(self, c_in, c_mid, c_out, is_last=False):
        # A projection on the shortcut is only needed when the widths differ.
        if c_in == c_out:
            shortcut = None
        else:
            shortcut = nn.Conv2d(c_in, c_out, 1, bias=False)

        layers = [
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.Dropout2d(0.1, inplace=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(c_mid, c_out, 3, padding=1),
        ]
        if is_last:
            # Keep two placeholder slots so the Sequential has the same length
            # in both configurations.
            layers += [nn.Identity(), nn.Identity()]
        else:
            layers += [nn.Dropout2d(0.1, inplace=True), nn.ReLU(inplace=True)]

        super().__init__(layers, shortcut)

class SelfAttention2d(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to queries/keys/values with a
    1x1 convolution, attended across all h * w positions, projected back and
    added to the input as a residual (after Dropout2d).

    Args:
        c_in: number of input (and output) channels; must be divisible by n_head.
        n_head: number of attention heads.
        dropout_rate: dropout probability applied to the projected output.
    """
    def __init__(self, c_in, n_head=1, dropout_rate=0.1):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv2d(c_in, c_in, 1)
        self.dropout = nn.Dropout2d(dropout_rate, inplace=True)

    def forward(self, input):
        batch, channels, height, width = input.shape
        head_dim = channels // self.n_head
        positions = height * width

        projected = self.qkv_proj(self.norm(input))
        # -> (batch, 3 * heads, positions, head_dim)
        projected = projected.view(batch, 3 * self.n_head, head_dim, positions).transpose(2, 3)
        query, key, value = projected.chunk(3, dim=1)

        # Apply the 1/sqrt(head_dim) scaling as head_dim**-0.25 on each of
        # the query and the key before the matrix product.
        scale = head_dim ** -0.25
        weights = torch.matmul(query * scale, key.transpose(2, 3) * scale).softmax(3)

        attended = torch.matmul(weights, value).transpose(2, 3).contiguous()
        attended = attended.view(batch, channels, height, width)
        return input + self.dropout(self.out_proj(attended))

class FourierFeatures(nn.Module):
    """
    Random Fourier feature embedding.

    Projects the input with a learnable weight matrix (initialised from a
    normal distribution scaled by std) and returns the cosines and sines of
    2 * pi times the projection, concatenated along the last dimension.

    Args:
        in_features: size of the input's last dimension.
        out_features: size of the output's last dimension; must be even.
        std: standard deviation of the initial weights.
    """
    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        n_frequencies = out_features // 2
        self.weight = nn.Parameter(torch.randn([n_frequencies, in_features]) * std)

    def forward(self, input):
        phase = (2 * math.pi * input) @ self.weight.T
        return torch.cat((torch.cos(phase), torch.sin(phase)), dim=-1)

class FeatureEmbedding(nn.Module):
    """
    Embed an image as a 3x32x32 map via a frozen ResNet-50 and a linear head.

    The ResNet-50 produces 1000 outputs per image; a trainable
    Linear(1000, 3072) + Tanh head maps them to 3072 values, which are
    reshaped to (-1, 3, 32, 32).

    The backbone is only ever run under torch.no_grad(), so it is treated as
    a fixed feature extractor: its parameters are frozen and it is kept in
    eval mode even when the parent module is switched to training mode.
    Otherwise its BatchNorm running statistics would still be updated by every
    training-mode forward pass.
    """
    def __init__(self):
        super().__init__()

        self.fe = nn.Sequential(
            nn.Linear(1000, 3072),
            nn.Tanh(),
        )

        # NOTE(review): `pretrained=True` is deprecated in recent torchvision
        # releases in favour of the `weights=` argument. It is kept here so
        # the notebook still runs on older torchvision versions.
        self.pretrained_model = resnet50(pretrained=True)
        self.pretrained_model = self.pretrained_model.to(device)

        # Freeze the backbone: no gradients, no BatchNorm statistic updates.
        self.pretrained_model.requires_grad_(False)
        self.pretrained_model.eval()

    def train(self, mode=True):
        """Set the training mode of the head while keeping the backbone in eval mode."""
        super().train(mode)
        self.pretrained_model.eval()
        return self

    def forward(self, x):
        with torch.no_grad():
            x = self.pretrained_model(x)
        x = self.fe(x)
        x = x.view(-1, 3, 32, 32)
        return x

class ReverseLayerF(Function):
    """
    Gradient reversal layer.

    The forward pass returns the input unchanged; the backward pass negates
    the incoming gradient and multiplies it by alpha.
    """
    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scaling factor for the backward pass.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg().mul(ctx.alpha)
        # One gradient per forward argument: alpha receives none.
        return reversed_grad, None

class DomainClassifier(nn.Module):
    """
    Two-way classifier over flattened 256 x 8 x 8 feature maps.

    The input is flattened to 256 * 8 * 8 features per sample and passed
    through Linear -> BatchNorm1d -> ReLU -> Linear -> LogSoftmax, giving
    log-probabilities over two classes.
    """
    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(256 * 8 * 8, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(True),
            nn.Linear(1024, 2),
            nn.LogSoftmax(dim=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        flattened = input.view(-1, 256 * 8 * 8)
        return self.net(flattened)

def expand_to_planes(input, shape):
    """
    Tile each value of a 2-D tensor into a constant spatial plane.

    Two trailing singleton dimensions are appended to input and the result is
    repeated shape[2] x shape[3] times along them.

    Args:
        input: 2-D tensor whose values become planes.
        shape: sequence whose entries at index 2 and 3 give the plane height
            and width.

    Returns:
        A tensor with two extra trailing dimensions of size shape[2] and shape[3].
    """
    height, width = shape[2], shape[3]
    planes = input.unsqueeze(-1).unsqueeze(-1)
    return planes.repeat(1, 1, height, width)

## Diffusion Model

In [7]:
class DownSample(nn.Module):
    """
    Encoder half of the U-Net.

    Three stages widen the features from 23 input channels to 64, 128 and 256
    channels. The second and third stages are each followed by a 2x average
    pooling. forward() returns the final features together with the three
    tensors that are later concatenated back in as skip connections.
    """
    def __init__(self):
        super().__init__()
        c = 64
        # 3 image channels + 16 timestep-embedding planes + 4 class-embedding planes.
        in_channels = 3 + 16 + 4
        heads = c * 4 // 64

        self.avg_1 = nn.AvgPool2d(2)
        self.avg_2 = nn.AvgPool2d(2)

        self.down_0 = nn.Sequential(
            ResConvBlock(in_channels, c, c),
            ResConvBlock(c, c, c),
        )

        self.down_1 = nn.Sequential(
            ResConvBlock(c, c * 2, c * 2),
            ResConvBlock(c * 2, c * 2, c * 2),
        )

        self.down_2 = nn.Sequential(
            ResConvBlock(c * 2, c * 4, c * 4),
            SelfAttention2d(c * 4, heads),
            ResConvBlock(c * 4, c * 4, c * 4),
            SelfAttention2d(c * 4, heads),
        )

    def forward(self, input):
        # Output of the first stage (not pooled).
        skip_0 = self.down_0(input)
        # Output of the second stage after pooling.
        skip_1 = self.avg_1(self.down_1(skip_0))
        # Output of the third stage after pooling; this is both the encoder
        # output and the deepest skip connection.
        skip_2 = self.avg_2(self.down_2(skip_1))

        return skip_2, skip_0, skip_1, skip_2

class MidSample(nn.Module):
    """
    Bottleneck of the U-Net.

    Four ResConvBlock + SelfAttention2d pairs take the features from 256 to
    512 channels and back to 256. The result is then bilinearly upsampled by
    2 and average-pooled by 2, so the spatial size of the output equals that
    of the input.
    """
    def __init__(self):
        super().__init__()
        c = 64
        wide, narrow = c * 8, c * 4

        self.avg_3 = nn.AvgPool2d(2)

        self.mid_1 = nn.Sequential(
            ResConvBlock(narrow, wide, wide),
            SelfAttention2d(wide, wide // 64),
            ResConvBlock(wide, wide, wide),
            SelfAttention2d(wide, wide // 64),
            ResConvBlock(wide, wide, wide),
            SelfAttention2d(wide, wide // 64),
            ResConvBlock(wide, wide, narrow),
            SelfAttention2d(narrow, narrow // 64),
        )

        self.mid_2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, down_sample):
        features = self.mid_1(down_sample)
        # Upsample by 2, then pool by 2: the resolution is unchanged overall.
        enlarged = self.mid_2(features)
        return self.avg_3(enlarged)

class UpSample(nn.Module):
    """
    Decoder half of the U-Net.

    Each stage first concatenates the matching skip tensor along the channel
    dimension. The first two stages end with a 2x bilinear upsampling; the
    last stage produces the 3-channel output without a final dropout/ReLU.
    """
    def __init__(self):
        super().__init__()
        c = 64

        def upscale():
            # Fresh 2x bilinear upsampling layer for the end of a stage.
            return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        self.up_1 = nn.Sequential(
            ResConvBlock(c * 8, c * 4, c * 4),
            SelfAttention2d(c * 4, c * 4 // 64),
            ResConvBlock(c * 4, c * 4, c * 2),
            SelfAttention2d(c * 2, c * 2 // 64),
            upscale(),
        )

        self.up_2 = nn.Sequential(
            ResConvBlock(c * 4, c * 2, c * 2),
            ResConvBlock(c * 2, c * 2, c),
            upscale(),
        )

        self.up_3 = nn.Sequential(
            ResConvBlock(c * 2, c, c),
            ResConvBlock(c, c, 3, is_last=True),
        )

    def forward(self, middle_sample, identity_0, identity_1, identity_2):
        # Deepest level: bottleneck features + deepest skip.
        features = self.up_1(torch.cat([middle_sample, identity_2], dim=1))
        # Middle level.
        features = self.up_2(torch.cat([features, identity_1], dim=1))
        # Top level: produces the 3-channel output.
        return self.up_3(torch.cat([features, identity_0], dim=1))

class Diffusion(nn.Module):
    """
    Class-conditional U-Net used as the diffusion model.

    forward() conditions the network by appending a 16-plane Fourier embedding
    of the timestep and a 4-plane embedding of the class label to the input
    image as extra channels, then runs the down / middle / up sub-networks.

    The feature_embed and domain_classifier sub-modules are constructed here
    but are not used by forward().
    """
    def __init__(self):
        super().__init__()
        self.feature_embed = FeatureEmbedding()
        self.domain_classifier = DomainClassifier()
        self.timestep_embed = FourierFeatures(1, 16)
        # 11 rows: forward() looks up `cond + 1`, so label -1 maps to row 0.
        self.class_embed = nn.Embedding(11, 4)

        self.down_sample_net = DownSample()
        self.mid_sample_net = MidSample()
        self.up_sample_net = UpSample()

    def forward(self, input, t, cond):
        time_planes = expand_to_planes(self.timestep_embed(t.unsqueeze(1)), input.shape)
        class_planes = expand_to_planes(self.class_embed(cond + 1), input.shape)
        net_input = torch.cat([input, time_planes, class_planes], dim=1)

        bottleneck, skip_0, skip_1, skip_2 = self.down_sample_net(net_input)
        middle = self.mid_sample_net(bottleneck)

        return self.up_sample_net(middle, skip_0, skip_1, skip_2)

In [8]:
@torch.no_grad()
def sample(model, x, steps, eta, classes, guidance_scale=1.):
    """
    Draws samples from a model given starting noise, using classifier-free
    guidance.

    Args:
        model: callable as model(x, t, classes). It must return the predicted
            velocity v for every element of the batch, either directly as a
            tensor or as the first element of a tuple/list.
        x: starting noise; its first dimension is the batch size.
        steps: number of sampling steps (must be >= 1).
        eta: scale of the fresh noise injected at each step; 0 injects none.
        classes: one class label per sample. The label -1 is used for the
            unconditional half of the batch.
        guidance_scale: weight of the conditional direction,
            v = v_uncond + guidance_scale * (v_cond - v_uncond).

    Returns:
        The denoised prediction computed at the final step.

    Raises:
        ValueError: if steps is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")

    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule
    t = torch.linspace(1, 0, steps + 1)[:-1]
    alphas, sigmas = get_alphas_sigmas(t)

    # These do not change between steps, so build them once. The batch is
    # doubled: the first half is unconditional (label -1), the second half
    # uses the requested classes.
    ts_in = torch.cat([ts, ts])
    classes_in = torch.cat([-torch.ones_like(classes), classes])

    # The sampling loop
    for i in range(steps):

        # Get the model output (v, the predicted velocity)
        with torch.cuda.amp.autocast():
            x_in = torch.cat([x, x])
            output = model(x_in, ts_in * t[i], classes_in)
            # Only unpack when the model really returns several outputs.
            # Indexing a plain tensor with [0] would select the first batch
            # element instead of the prediction for the whole batch.
            if isinstance(output, (tuple, list)):
                output = output[0]
            v_uncond, v_cond = output.float().chunk(2)
        v = v_uncond + guidance_scale * (v_cond - v_uncond)

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < steps - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # After the last timestep, output the denoised image
    return pred

# Train

In [9]:
# Scrambled Sobol (quasi-random) generator with one dimension, used to draw
# the training timesteps.
rng = torch.quasirandom.SobolEngine(1, scramble=True)
torch.cuda.empty_cache()

# steps, eta, ema_decay and guidance_scale are defined once, in the
# configuration cell next to batch_size. They are deliberately not repeated
# here: a second copy would silently override any value edited in that cell.

In [10]:
def generate_diffussion_target(images, labels):
    """
    Noise a batch of images and build the matching regression targets.

    One timestep per sample is drawn from the module-level Sobol generator
    `rng` and moved to `device`. With (alpha, sigma) from get_alphas_sigmas:

        noised  = images * alpha + noise * sigma
        targets = noise * alpha - images * sigma

    Args:
        images: batch of clean images.
        labels: batch of labels; only its length is used, to size the draw.

    Returns:
        A tuple (t, noised_reals, targets).
    """
    n_samples = labels.shape[0]
    t = rng.draw(n_samples)[:, 0].to(device)

    alphas, sigmas = get_alphas_sigmas(t)
    # Reshape to (n, 1, 1, 1) so the per-sample factors broadcast over each image.
    alphas = alphas.view(-1, 1, 1, 1)
    sigmas = sigmas.view(-1, 1, 1, 1)

    noise = torch.randn_like(images)

    noised_reals = images * alphas + noise * sigmas
    targets = noise * alphas - images * sigmas

    return t, noised_reals, targets

def train_diffussion(epoch, model, train_dl, optimizer, scheduler):
    """
    Train the model for one epoch, then save a grid of generated samples.

    For every batch the images are noised with generate_diffussion_target and
    the model is trained with an MSE loss to predict the target. Each label
    is replaced by -1 with probability 0.2 so the model also learns the
    unconditional case used by classifier-free guidance.

    After the epoch the scheduler is stepped, the loss of the last batch is
    printed, and one image per class (0-9) is sampled and written to
    ./output/<epoch as 3 digits>_train.png.

    Reads the module-level names device, steps, eta and guidance_scale.

    Args:
        epoch: zero-based epoch index, used in the log line and the file name.
        model: the network to train.
        train_dl: iterable of (images, labels) batches.
        optimizer: optimizer over the model parameters.
        scheduler: learning-rate scheduler, stepped once per call.
    """
    model.train()
    loss = None  # stays None when train_dl yields no batches
    for src_images, src_labels in train_dl:
        src_images, src_labels = src_images.to(device), src_labels.to(device)

        t, noised_src, src_recon_targets = generate_diffussion_target(src_images, src_labels)

        optimizer.zero_grad()

        # Drop the label (replace it by -1) for about 20% of the samples.
        to_drop = torch.rand(src_labels.shape, device=src_labels.device).le(0.2)
        classes_drop = torch.where(to_drop, -torch.ones_like(src_labels), src_labels)

        output = model(noised_src, t, classes_drop)
        diffused_loss = F.mse_loss(output, src_recon_targets)

        loss = diffused_loss

        loss.backward()
        optimizer.step()
    scheduler.step()

    print(f"Epoch {epoch+1}:")
    if loss is not None:
        print('diff', loss.item())

    noise = torch.randn([10, 3, 32, 32], device=device)
    fakes_classes = torch.arange(10, device=device)

    # Sample in eval mode so dropout is disabled while generating images, and
    # always hand the model back in training mode.
    model.eval()
    try:
        fakes = sample(model, noise, steps, eta, fakes_classes, guidance_scale)
    finally:
        model.train()

    # Map from [-1, 1] back to [0, 1] for saving.
    fakes = (fakes + 1) / 2
    fakes = torch.clamp(fakes, min=0, max=1)
    save_image(fakes, './output/%03d_train.png' % epoch)

In [11]:
# Build the CIFAR-10 training loader with the configured batch size.
train_dl = load_data_set(batch_size = batch_size)
# NOTE(review): this criterion is not passed to train_diffussion, which computes an MSE loss itself.
criterion = nn.NLLLoss()
# Instantiate the network and move it to the selected device.
model = Diffusion().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Multiply the learning rate by 0.98 on every scheduler.step(); train_diffussion calls it once per epoch.
scheduler = StepLR(optimizer, step_size = 1, gamma = 0.98)

Files already downloaded and verified


In [12]:
epoch = 0
active_domian_loss = 50
src_domain_label = torch.zeros(batch_size).long().to(device)
tgt_domain_label = torch.ones(batch_size).long().to(device)

# Train for at most `epoches` epochs. The epoch index must advance on every
# iteration: it is used in the log line and in the name of the saved sample
# grid, so a constant index would overwrite the same file each time.
# Interrupting the kernel stops training early without raising an error.
try:
    for epoch in range(epoches):
        train_diffussion(epoch, model, train_dl, optimizer, scheduler)
except KeyboardInterrupt:
    print(f"Training interrupted during epoch {epoch + 1}")

# Result

In [None]:
# Show the sample grid saved for epoch index 9. Files are written with a
# zero-padded, three-digit index ('%03d_train.png'), so the name is
# 009_train.png, not 09_train.png.
image = Image.open("./output/009_train.png")
plt.imshow(image)
plt.show()

In [None]:
# Show the sample grid saved for epoch index 20.
sample_grid_path = "./output/020_train.png"
image = Image.open(sample_grid_path)
plt.imshow(image)
plt.show()

In [None]:
# Show the sample grid saved for epoch index 29.
sample_grid_path = "./output/029_train.png"
image = Image.open(sample_grid_path)
plt.imshow(image)
plt.show()

In [None]:
# Show the sample grid saved for epoch index 39.
sample_grid_path = "./output/039_train.png"
image = Image.open(sample_grid_path)
plt.imshow(image)
plt.show()

In [None]:
# Show the sample grid saved for epoch index 49.
sample_grid_path = "./output/049_train.png"
image = Image.open(sample_grid_path)
plt.imshow(image)
plt.show()