In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms, models
import numpy as np
import random
from torchvision.utils import save_image
import os
from torch.optim.lr_scheduler import StepLR
import glob
import cv2

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

from torch.utils.data import Dataset

from torchvision.datasets import MNIST

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint locations for the two networks.
UNET_PATH = 'model_weight/unet.pth'
DNN_PATH = 'model_weight/dnn.pth'

# Epoch counts
num_epochs = 24
dnn_epoch = 50

# Hyperparameters
BATCH_SIZE = 256

# Dataset selection flags.
# NOTE(review): this assignment rebinds `MNIST`, shadowing the class imported via
# `from torchvision.datasets import MNIST`; use `datasets.MNIST` to reach the class.
MNIST = True
CIFAR10 = False

# Network training settings: a network is trained only when no checkpoint exists yet.
Train_BASE_DNN = not os.path.exists(DNN_PATH)
Train_Unet = not os.path.exists(UNET_PATH)

In [None]:
# Create the working directories when they are missing (in this order).
for directory in ("./output", "./model_weight", "./test_out"):
    if not os.path.exists(directory):
        os.mkdir(directory)

# One sub-folder per epoch under ./output; an existing folder has its .png files cleared.
for epoch in range(num_epochs):
    epoch_dir = "./output/%03d" % epoch
    if not os.path.exists(epoch_dir):
        os.mkdir(epoch_dir)
        continue
    for stale_png in glob.glob("%s/*.png" % epoch_dir):
        os.remove(stale_png)

In [None]:
# MNIST upscaled from 28x28 to 32x32 (presumably to suit the UNet below, whose four
# poolings need H and W divisible by 16), converted to tensors in [0, 1].
mnist_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST('data', train=True, download=True, transform=mnist_transform)
# NOTE(review): shuffle=False means the training batches come in the same order every epoch.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

test_dataset = datasets.MNIST('data', train=False, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)


In [None]:
class NoisyMNISTDataset(Dataset):
    """Wrap an ``(image, label)`` dataset and corrupt each image with Gaussian noise.

    Fresh noise is drawn on every access, so indexing the same item twice gives
    two different noisy images. The label is passed through unchanged.

    Args:
        mnist_dataset: indexable dataset returning ``(image_tensor, label)``.
        noise_schedule: indexable sequence of noise standard deviations.
        t: step whose level, ``noise_schedule[t]``, scales the added noise.
    """

    def __init__(self, mnist_dataset, noise_schedule, t):
        self.mnist_dataset = mnist_dataset
        self.noise_schedule = noise_schedule
        self.t = t

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, index):
        clean, target = self.mnist_dataset[index]
        sigma = self.noise_schedule[self.t]
        # noisy = clean + sigma * N(0, I), same shape as the clean image
        return clean + sigma * torch.randn_like(clean), target

def linear_noise_schedule(T, initial_noise=1.0, final_noise=0.0):
    """Return ``T`` noise levels evenly spaced from ``initial_noise`` to ``final_noise``.

    Both endpoints are included, so with the defaults index 0 is the noisiest
    level (1.0) and index ``T - 1`` is noise-free (0.0).
    """
    return torch.linspace(start=initial_noise, end=final_noise, steps=T)

# Training images for the denoiser.
# `datasets.MNIST` is used instead of the bare name `MNIST`, because the config
# cell rebinds `MNIST = True` (a dataset flag), shadowing the torchvision class.
# Images are resized to 32x32: the UNet pools four times and concatenates skip
# connections, so H and W must be multiples of 16 (raw 28x28 images do not fit).
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.Resize(32), transforms.ToTensor()]),
)
T = 100
noise_schedule = linear_noise_schedule(T)

# Dataset at step t=0, i.e. the highest noise level (1.0) of the default linear schedule.
noisy_mnist_dataset = NoisyMNISTDataset(mnist_dataset, noise_schedule, t=0)

# Unet

In [None]:
class UNet(nn.Module):
    """Label-conditioned U-Net mapping a 1-channel image to a 1-channel image in (0, 1).

    Architecture (NCHW shapes for an input of spatial size H x W):
      * encoder: four double-conv blocks (64, 128, 256, 512 channels), each
        followed by 2x2 max-pooling and dropout, ending at (N, 512, H/16, W/16);
      * bottleneck: the embedding of ``input_labels`` is broadcast over the
        spatial grid and concatenated (-> 1024 channels), passed through two
        3x3 conv + ReLU layers, then the embedding of ``target_labels`` is
        concatenated (-> 1536 channels);
      * decoder: four stages of stride-2 transposed conv, concatenation with the
        matching encoder feature map, dropout and two 3x3 conv + ReLU layers;
      * head: 1x1 conv to one channel followed by a sigmoid.

    Every pooled map is upsampled by exactly 2x and concatenated with its skip
    connection, so H and W must be positive multiples of 16 (e.g. 32x32; raw
    28x28 MNIST does not fit).
    """

    def __init__(self):
        super().__init__()
        # Modules are created in the same order as before so that, for a given
        # random seed, parameter initialisation and state_dict keys are unchanged.
        self.activate = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d((2, 2))
        self.dropout = nn.Dropout(p=0.5)
        self.sigmod = nn.Sigmoid()  # (sic) attribute name kept for compatibility
        # 10 label ids -> 512-d vectors, the same width as the deepest encoder features.
        self.label_embedding = nn.Embedding(10, 512)

        self.encoder_1 = self._double_conv(1, 64)
        self.encoder_2 = self._double_conv(64, 128)
        self.encoder_3 = self._double_conv(128, 256)
        self.encoder_4 = self._double_conv(256, 512)

        # Bottleneck input: 512 encoder channels + 512 input-label channels.
        self.middle_1_0 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.middle_1_1 = nn.Conv2d(1024, 1024, 3, padding=1)

        # kernel 3, stride 2, padding 1, output_padding 1 => output is exactly 2x the input size.
        # 1536 = 1024 bottleneck channels + 512 target-label channels.
        self.deconv4_0 = nn.ConvTranspose2d(1536, 512, 3, stride=(2, 2), padding=1, output_padding=1)
        self.uconv4_1 = nn.Conv2d(1024, 512, 3, padding=1)   # 512 upsampled + 512 skip (conv4)
        self.uconv4_2 = nn.Conv2d(512, 512, 3, padding=1)

        self.deconv3_0 = nn.ConvTranspose2d(512, 512, 3, stride=(2, 2), padding=1, output_padding=1)
        self.uconv3_1 = nn.Conv2d(768, 256, 3, padding=1)    # 512 upsampled + 256 skip (conv3)
        self.uconv3_2 = nn.Conv2d(256, 256, 3, padding=1)

        self.deconv2_0 = nn.ConvTranspose2d(256, 512, 3, stride=(2, 2), padding=1, output_padding=1)
        self.uconv2_1 = nn.Conv2d(640, 128, 3, padding=1)    # 512 upsampled + 128 skip (conv2)
        self.uconv2_2 = nn.Conv2d(128, 128, 3, padding=1)

        self.deconv1_0 = nn.ConvTranspose2d(128, 512, 3, stride=(2, 2), padding=1, output_padding=1)
        self.uconv1_1 = nn.Conv2d(576, 192, 3, padding=1)    # 512 upsampled + 64 skip (conv1)
        self.uconv1_2 = nn.Conv2d(192, 192, 3, padding=1)

        self.out_layer = nn.Conv2d(192, 1, 1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _check_spatial_size(x):
        """Raise ValueError unless ``x`` is 4-D with H and W positive multiples of 16."""
        if x.dim() != 4:
            raise ValueError(f"UNet expects a 4-D (N, C, H, W) tensor, got shape {tuple(x.shape)}")
        height, width = x.shape[-2], x.shape[-1]
        if height <= 0 or width <= 0 or height % 16 or width % 16:
            raise ValueError(
                f"UNet needs H and W to be positive multiples of 16 (four 2x poolings with "
                f"skip connections), got {height}x{width}; e.g. resize MNIST to 32x32"
            )

    def _downsample(self, feat):
        """Encoder transition: 2x2 max-pool followed by dropout."""
        return self.dropout(self.pool(feat))

    def _append_label(self, feat, labels):
        """Concatenate the label embedding, broadcast over H and W, to ``feat`` along channels."""
        emb = self.label_embedding(labels).view(labels.size(0), self.label_embedding.embedding_dim, 1, 1)
        return torch.cat([feat, emb.expand(-1, -1, feat.size(2), feat.size(3))], dim=1)

    def _upsample(self, feat, skip, deconv, conv_a, conv_b):
        """Decoder stage: 2x transposed conv, concat skip, dropout, two conv + ReLU."""
        merged = torch.cat([deconv(feat), skip], 1)
        merged = self.dropout(merged)
        merged = self.activate(conv_a(merged))
        return self.activate(conv_b(merged))

    def forward(self, x, input_labels, target_labels):
        """Run the conditional U-Net.

        Args:
            x: image batch of shape (N, 1, H, W); H and W must be positive multiples of 16.
            input_labels: int64 tensor of shape (N,) with values in [0, 10); its
                embedding is injected at the bottleneck before the middle convs.
            target_labels: int64 tensor of shape (N,) with values in [0, 10); its
                embedding is injected after the middle convs, before decoding.

        Returns:
            Tensor of shape (N, 1, H, W) with values in (0, 1).

        Raises:
            ValueError: if ``x`` is not 4-D or its spatial size is not a positive
                multiple of 16.
        """
        self._check_spatial_size(x)

        # Encoder (NCHW shapes for an H x W input).
        conv1 = self.encoder_1(x)                            # (N, 64, H, W)
        conv2 = self.encoder_2(self._downsample(conv1))      # (N, 128, H/2, W/2)
        conv3 = self.encoder_3(self._downsample(conv2))      # (N, 256, H/4, W/4)
        conv4 = self.encoder_4(self._downsample(conv3))      # (N, 512, H/8, W/8)
        encoder_out = self._downsample(conv4)                # (N, 512, H/16, W/16)

        # Bottleneck conditioned on the input label.
        x1 = self._append_label(encoder_out, input_labels)   # (N, 1024, H/16, W/16)
        x2 = self.activate(self.middle_1_0(x1))
        x2 = self.activate(self.middle_1_1(x2))
        # Condition the decoder on the target label.
        x2 = self._append_label(x2, target_labels)           # (N, 1536, H/16, W/16)

        # Decoder with skip connections.
        uconv4 = self._upsample(x2, conv4, self.deconv4_0, self.uconv4_1, self.uconv4_2)      # (N, 512, H/8, W/8)
        uconv3 = self._upsample(uconv4, conv3, self.deconv3_0, self.uconv3_1, self.uconv3_2)  # (N, 256, H/4, W/4)
        uconv2 = self._upsample(uconv3, conv2, self.deconv2_0, self.uconv2_1, self.uconv2_2)  # (N, 128, H/2, W/2)
        uconv1 = self._upsample(uconv2, conv1, self.deconv1_0, self.uconv1_1, self.uconv1_2)  # (N, 192, H, W)

        return self.sigmod(self.out_layer(uconv1))           # (N, 1, H, W)

# Training

In [None]:
# Conditional denoiser (UNet() takes no constructor arguments).
denoising_model = UNet().to(DEVICE)
mse_loss = nn.MSELoss()
optimizer = optim.Adam(denoising_model.parameters(), lr=1e-3)

# Noise schedule: sigma_t = noise_schedule[t] falls linearly from 1.0 (t = 0) to 0.0 (t = T - 1).
T = 100
noise_schedule = linear_noise_schedule(T)

# Training settings. NOTE: this rebinds the num_epochs value set in the config cell.
num_epochs = 10
batch_size = 64

if Train_Unet:
    # train_dataset holds MNIST resized to 32x32, which the UNet requires
    # (H and W must be multiples of 16).
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    denoising_model.train()

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')

        for i, (clean_images, labels) in enumerate(train_loader):
            clean_images = clean_images.to(DEVICE)
            labels = labels.to(DEVICE)

            # Draw an independent diffusion step per image and apply the same
            # corruption as NoisyMNISTDataset: noisy = clean + sigma_t * N(0, I).
            # One pass over the data thus covers all noise levels, instead of a
            # full pass per step.
            t = torch.randint(0, T, (clean_images.size(0),))
            sigma = noise_schedule[t].to(DEVICE).view(-1, 1, 1, 1)
            noisy_images = clean_images + torch.randn_like(clean_images) * sigma

            # The UNet ends in a sigmoid, so it is trained to predict the clean
            # image (ToTensor values in [0, 1]) from the noisy one, conditioned on
            # the digit label. This also avoids dividing by sigma_t, which is 0 at t = T - 1.
            denoised_images = denoising_model(noisy_images, labels, labels)
            loss = mse_loss(denoised_images, clean_images)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print(f'Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    torch.save(denoising_model.state_dict(), UNET_PATH)
else:
    # A checkpoint already existed when the config cell ran (Train_Unet is False): reuse it.
    denoising_model.load_state_dict(torch.load(UNET_PATH, map_location=DEVICE))


In [None]:
def generate_image_with_label(denoising_model, label, noise_schedule, T, label_embedding=None, image_size=32):
    """Sample one image of class ``label`` by iteratively denoising random noise.

    The corruption used for training data (see NoisyMNISTDataset) is
    ``noisy = clean + noise_schedule[t] * randn``. Sampling therefore walks the
    noise levels from largest to smallest. At each step, the model predicts a
    clean image, which is then re-noised to the next, smaller level. The last
    step adds no noise, so the result is the model's final prediction.

    Args:
        denoising_model: module called as ``model(x, input_labels, target_labels)``
            (the UNet in this notebook); it must have at least one parameter,
            whose device is used for sampling.
        label: integer class to generate; used as both input and target label.
        noise_schedule: 1-D tensor of noise standard deviations.
        T: number of denoising steps; the first ``T`` schedule entries are used.
        label_embedding: ignored. It is kept for backward compatibility because
            the model embeds ``label`` itself.
        image_size: height/width of the generated image. The UNet needs a
            multiple of 16, hence the default of 32 (MNIST resized to 32x32).

    Returns:
        CPU tensor of shape ``(1, image_size, image_size)``.

    Raises:
        ValueError: if ``T`` is not between 1 and ``len(noise_schedule)``.
    """
    if not 1 <= T <= len(noise_schedule):
        raise ValueError(f"T must be between 1 and len(noise_schedule)={len(noise_schedule)}, got {T}")

    device = next(denoising_model.parameters()).device
    label_tensor = torch.tensor([label], dtype=torch.long, device=device)

    # Noise levels visited from largest to smallest, whatever order the schedule stores them in.
    sigmas = torch.sort(noise_schedule[:T], descending=True).values.tolist()

    was_training = denoising_model.training
    denoising_model.eval()  # disable dropout while sampling
    try:
        with torch.no_grad():
            image = torch.randn(1, 1, image_size, image_size, device=device) * sigmas[0]
            for step in range(len(sigmas)):
                denoised = denoising_model(image, label_tensor, label_tensor)
                next_sigma = sigmas[step + 1] if step + 1 < len(sigmas) else 0.0
                image = denoised + torch.randn_like(denoised) * next_sigma
    finally:
        denoising_model.train(was_training)  # restore the caller's mode

    return image.squeeze(0).detach().cpu()

In [None]:
desired_label = 3
generated_image = generate_image_with_label(
    denoising_model, desired_label, noise_schedule, T, denoising_model.label_embedding
)

# Keep and show the sample instead of discarding it.
save_image(generated_image, "./test_out/generated_label_%d.png" % desired_label)

plt.imshow(generated_image.squeeze(0).numpy(), cmap="gray")
plt.title(f"Generated digit (label {desired_label})")
plt.axis("off")
plt.show()