In [2]:
import copy
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

import logging

In [3]:
import os
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
class SelfAttention(nn.Module):
    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)


In [5]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        if self.residual:
            return F.gelu(x + self.double_conv(x))
        else:
            return self.double_conv(x)

In [6]:
class Down(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_dim,
                out_channels
            ),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb

class Up(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_dim,
                out_channels
            ),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb

In [7]:
class UNet(nn.Module):
    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, 32)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, 8)

        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def pos_encoding(self, t, channels):
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y=None):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            t += self.label_emb(y)

        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

In [8]:
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())

In [9]:
class Diffusion:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, type="unconditional", device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device
        self.type = type

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def unconditional_sample(self, model, n, labels, cfg_scale=3):
        if self.type != "unconditional" and labels is None:
            raise ValueError('Labels must be passed to perform conditional sampling.')
        if self.type != "unconditional" and cfg_scale <= 0:
            raise ValueError('For conditional sampling, make sure the classifier-free guidance scale must be '
                             'greater than 0.')
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                if self.type == "unconditional":
                    predicted_noise = model(x, t)
                else:
                    predicted_noise = model(x, t, labels)
                    if cfg_scale > 0:
                        uncond_predicted_noise = model(x, t, None)
                        predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x

# Unconditional Model train

In [None]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = [os.path.join(root_dir, img_name) for img_name in os.listdir(root_dir)]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

def get_data(dataset_path, batch_size, image_size):
    train_transform = transforms.Compose([
        transforms.Resize(80),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    test_transform = transforms.Compose([
        transforms.Resize(80),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_data = CustomDataset(root_dir=os.path.join(dataset_path, 'train_label_img'), transform=train_transform)
    test_data = CustomDataset(root_dir=os.path.join(dataset_path, 'test_label_img'), transform=test_transform)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [None]:
def setup_logging(run_name):
    logging.basicConfig(filename=f'{run_name}.log', level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

def train(run_name, epochs=5, batch_size=16, image_size=64, dataset_path= '/content/drive/My Drive/', device="cuda", lr=3e-4):
    setup_logging(run_name)
    device = torch.device(device)
    train_loader,test_loader  = get_data(dataset_path, batch_size, image_size)
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=image_size, type='unconditional', device=device)
    l = len(train_loader)

    for epoch in range(epochs):
        print(f"Starting epoch {epoch}:")
        pbar = tqdm(train_loader)
        for images in pbar:
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(MSE=loss.item())

        if epoch % 5 == 0:
            torch.save(model.state_dict(), os.path.join("/content/drive/My Drive/", run_name, f"DDPM_unconditional_best_weights.pt"))

In [None]:
def launch():
    run_name = 'SSDD'
    epochs = 15
    batch_size = 4
    image_size = 64
    dataset_path = "/content/drive/My Drive/SSDD"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lr = 3e-3
    train(run_name, epochs, batch_size, image_size, dataset_path, device, lr)

In [None]:
launch()

Starting epoch 0:


100%|██████████| 235/235 [00:55<00:00,  4.24it/s, MSE=0.0472]


Starting epoch 1:


100%|██████████| 235/235 [00:43<00:00,  5.36it/s, MSE=0.0158]


Starting epoch 2:


100%|██████████| 235/235 [00:45<00:00,  5.17it/s, MSE=0.0161]


Starting epoch 3:


100%|██████████| 235/235 [00:43<00:00,  5.39it/s, MSE=0.00823]


Starting epoch 4:


100%|██████████| 235/235 [00:43<00:00,  5.36it/s, MSE=0.0212]


Starting epoch 5:


100%|██████████| 235/235 [00:43<00:00,  5.38it/s, MSE=0.0317]


Starting epoch 6:


100%|██████████| 235/235 [00:43<00:00,  5.36it/s, MSE=0.00945]


Starting epoch 7:


100%|██████████| 235/235 [00:43<00:00,  5.39it/s, MSE=0.00358]


Starting epoch 8:


100%|██████████| 235/235 [00:43<00:00,  5.37it/s, MSE=0.0115]


Starting epoch 9:


100%|██████████| 235/235 [00:43<00:00,  5.38it/s, MSE=0.0158]


Starting epoch 10:


100%|██████████| 235/235 [00:43<00:00,  5.39it/s, MSE=0.00586]


Starting epoch 11:


100%|██████████| 235/235 [00:44<00:00,  5.34it/s, MSE=0.00935]


Starting epoch 12:


100%|██████████| 235/235 [00:43<00:00,  5.39it/s, MSE=0.00455]


Starting epoch 13:


100%|██████████| 235/235 [00:43<00:00,  5.39it/s, MSE=0.00295]


Starting epoch 14:


100%|██████████| 235/235 [00:43<00:00,  5.38it/s, MSE=0.00529]


In [None]:
!nvidia-smi

Wed Apr 24 13:58:13 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   76C    P0              40W /  70W |   6973MiB / 15360MiB |     81%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Denoise Generation Test

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
model.load_state_dict(torch.load("/content/drive/My Drive/Detecting Small Vessels or Boats with Satellite Imagray using YOLOv9/model_scratch/DDPM/DDPM_unconditional_best_weights_1.pt", map_location=device))
diffusion = Diffusion(img_size=64, device=device)

# Sample random timesteps for denoising
timesteps = torch.randint(0, diffusion.n_timesteps, (1,)).long().to(device)

preprocess = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
postprocess = transforms.Compose([
    transforms.Normalize((-1, -1, -1), (2, 2, 2)),
    transforms.ToPILImage()
])

input_image = Image.open('/content/drive/My Drive/Detecting Small Vessels or Boats with Satellite Imagray using YOLOv9/Datasets/SSDD/test_img.jpg')
input_tensor = preprocess(input_image).unsqueeze(0)

with torch.no_grad():
    denoised_image = model(input_tensor.to(device), timesteps)

denoised_image = postprocess(denoised_image.squeeze())

AttributeError: 'Diffusion' object has no attribute 'n_timesteps'

In [None]:
original_image = Image.open('original_image.jpg')
denoised_image = Image.open('denoised_image.jpg')

fig, axes = plt.subplots(1, 2)

axes[0].imshow(original_image)
axes[0].set_title('Original Image')
axes[1].imshow(denoised_image)
axes[1].set_title('Denoised Image')

for ax in axes:
    ax.axis('off')

plt.show()

## Testing Purpose

In [None]:
def plot_images(images):
    plt.figure(figsize=(32, 32))
    plt.imshow(torch.cat([
        torch.cat([i for i in images.cpu()], dim=-1),
    ], dim=-2).permute(1, 2, 0).cpu())
    plt.show()


def save_images(images, path, **kwargs):
    grid = torchvision.utils.make_grid(images, **kwargs)
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
    im = Image.fromarray(ndarr)
    im.save(path)

In [None]:
def get_data(args):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(80),
        torchvision.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return dataloader


In [None]:
def setup_logging(run_name):
    os.makedirs("models", exist_ok=True)
    os.makedirs("results", exist_ok=True)
    os.makedirs(os.path.join("models", run_name), exist_ok=True)
    os.makedirs(os.path.join("results", run_name), exist_ok=True)

In [None]:
def train(args):
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    model = UNet().to(device) if args.sampling_type == 'unconditional' else UNet(num_classes=args.num_classes).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, type=args.sampling_type, device=device)
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)
    if args.sampling_type != "unconditional":
        ema = EMA(0.995)
        ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, labels) in enumerate(pbar):
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            if args.sampling_type == 'unconditional':
                predicted_noise = model(x_t, t)
            else:
                labels = labels.to(device)
                if np.random.random() < 0.1: labels = None
                predicted_noise = model(x_t, t, labels)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if args.sampling_type != 'unconditional':
                ema.step_ema(ema_model, model)

            pbar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)

        if epoch % 10 == 0:
            if args.sampling_type != "unconditional":
                labels = torch.arange(10).long().to(device)
                sampled_images = diffusion.sample(model, n=len(labels), labels=labels)
                ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels)
                plot_images(sampled_images)
                save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
                save_images(ema_sampled_images, os.path.join("results", args.run_name, f"{epoch}_ema.jpg"))
                torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt"))
                torch.save(ema_model.state_dict(), os.path.join("models", args.run_name, f"ema_ckpt.pt"))
                torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, f"optim.pt"))
            else:
                sampled_images = diffusion.sample(model, n=images.shape[0])
                save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
                torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt"))


In [None]:
def launch(sampling_type='unconditional'):
    import argparse
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.sampling_type = sampling_type
    args.run_name = "DDPM_"+args.sampling_type
    args.epochs = 500
    args.batch_size = 12
    args.image_size = 64
    args.num_classes = 10
    args.dataset_path = r"<path to Landscape dataset>"
    args.device = "cuda"
    args.lr = 3e-4
    train(args)

In [None]:
def generate_images(sampling_type='unconditional'):
    device = "cuda"
    if sampling_type != 'unconditional':
        model = UNet().to(device)
        model.load_state_dict(torch.load("<path to model checkpoint file>"))
        diffusion = Diffusion(img_size=64, device=device)
        x = diffusion.sample(model, 8, torch.Tensor([6] * 8).long().to(device), cfg_scale=0)
    else:
        model = UNet(num_classes=10).to(device)
        model.load_state_dict(torch.load("<path to model checkpoint file>"))
        diffusion = Diffusion(img_size=64, type=sampling_type, device=device)
        x = diffusion.sample(model, 8)
    plot_images(x)

In [None]:
if __name__ == '__main__':
    launch()
    generate_images()
    launch('conditional')
    generate_images('conditional')

usage: colab_kernel_launcher.py [-h]
colab_kernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-e9bf3af7-60c6-40e3-af0e-09d9ba0c1c5d.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
