<a href="https://colab.research.google.com/github/fdsig/defuse/blob/main/simple_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install wandb -qqq
!pip install pytorch-ssim

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m68.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.2/199.2 KB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 KB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 KB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-ssim
  Downloading pytorch_ssim-0.1.tar.gz (1.4 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytorch-ssim
  Building wheel for pytorch-ssim (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-ssim: filename=pytorc

In [4]:
import os
from pathlib import Path
import zipfile
import requests
from tqdm import tqdm
import tarfile
import wandb

def download_and_log_dataset(url, dataset_name, local_filename, wandb_project, wandb_entity):
    # Download and unzip the dataset
    path = Path('DIV2K')
    path.mkdir(exist_ok=True)
    chunk_size = 1024
    local_dir_name = path/dataset_name
    local_filename = Path(local_filename)
    print(local_dir_name,local_filename)
    parent_local_filename = local_filename.parent.parent/local_filename.name
    if not local_dir_name.exists():
        if not local_filename.exists():
            response = requests.get(url, stream=True)
            file_size = int(response.headers.get('Content-Length', 0))
            with open(local_filename, 'wb') as f:
                for data in tqdm(response.iter_content(chunk_size), total=file_size // chunk_size, unit='KB'):
                    f.write(data)
    if local_filename.exists() and local_filename.suffix=='.zip':
        print(local_filename.suffix)
        try:
            with zipfile.ZipFile(local_filename, 'r') as zip_ref:
                zip_ref.extractall(path/dataset_name)
        
        except:
            with zipfile.ZipFile(parent_local_filename, 'r') as zip_ref:
                zip_ref.extractall(path/dataset_name)
                parent_local_filename.rename(path/parent_local_filename.name)
    if local_filename.exists() and local_filename.suffix=='.tar':
        print(local_filename)
        my_tar = tarfile.open(local_filename)
        my_tar.extractall(path/dataset_name) # specify which folder to extract to
        my_tar.close()
   



    # Initialize wandb and log dataset as artifact
    run = wandb.init(project=wandb_project, entity=wandb_entity)
    artifact = wandb.Artifact(dataset_name, type='dataset')
    artifact.add_reference(f'file://{path.resolve()}')
    run.log_artifact(artifact)

    run.finish()


# URLs for the train, validation, and test sets
train_url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
val_url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"
test_url = "https://cv.snu.ac.kr/research/EDSR/DIV2K.tar"

# Local file names for the downloaded ZIP files
train_filename = "DIV2K/DIV2K_train_HR.zip"
val_filename = "DIV2K/DIV2K_valid_HR.zip"
test_filename = "DIV2K/DIV2K_test_HR.tar"

# Dataset names and subdirectories
train_dataset = "DIV2K_train_HR"
val_dataset = "DIV2K_valid_HR"
test_dataset = "DIV2K_test_HR"

# W&B project and entity (team)
wandb_project = "gpt_4_denoising"
wandb_entity = "demonstrations"

# Download and log train, validation, and test sets
download_and_log_dataset(train_url, train_dataset, train_filename, wandb_project, wandb_entity)
download_and_log_dataset(val_url, val_dataset, val_filename, wandb_project, wandb_entity)
download_and_log_dataset(test_url, test_dataset, test_filename, wandb_project, wandb_entity)


DIV2K/DIV2K_train_HR DIV2K/DIV2K_train_HR.zip
.zip


[34m[1mwandb[0m: Generating checksum for up to 10000 files in "/content/DIV2K"...
[34m[1mwandb[0m: Done. 26.6s


DIV2K/DIV2K_valid_HR DIV2K/DIV2K_valid_HR.zip
.zip


[34m[1mwandb[0m: Generating checksum for up to 10000 files in "/content/DIV2K"...
[34m[1mwandb[0m: Done. 24.5s


DIV2K/DIV2K_test_HR DIV2K/DIV2K_test_HR.tar


100%|██████████| 7469400/7469400 [10:53<00:00, 11425.69KB/s]


DIV2K/DIV2K_test_HR.tar


[34m[1mwandb[0m: Generating checksum for up to 10000 files in "/content/DIV2K"...
[34m[1mwandb[0m: Done. 105.4s


In [None]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        middle = self.middle(encoded)
        decoded = self.decoder(torch.cat([middle, encoded], dim=1))
        return decoded

model = UNet()


In [6]:
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
            SelfAttention(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        key = self.key(x).view(batch_size, -1, height * width)
        energy = torch.bmm(query, key)
        attention = torch.softmax(energy, dim=-1)
        value = self.value(x).view(batch_size, -1, height * width)
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)
        out = self.gamma * out + x
        return out

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
            nn.PixelShuffle(2),
            nn.ReLU(inplace=True)
        )
        self.conv_block = ConvBlock(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1, x2], dim=1)
        return self.conv_block(x)

class ImprovedUNet(nn.Module):
    def __init__(self):
        super(ImprovedUNet, self).__init__()

        self.enc1 = ConvBlock(3, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.enc4 = ConvBlock(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.up1 = UpBlock(512, 256)
        self.up2 = UpBlock(256, 128)
        self.up3 = UpBlock(128, 64)

        self.predict_noise = nn.Conv2d(64, 3, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        d1 = self.up1(e4, e3)
        d2 = self.up2(d1, e2)
        d3 = self.up3(d2, e1)

        noise = self.predict_noise(d3)
        denoised_image = x - noise
        return denoised_image, noise

# Create an instance of the ImprovedUNet model
model = ImprovedUNet()

   

In [None]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        middle = self.middle(encoded)
        decoded = self.decoder(torch.cat([middle, encoded], dim=1))
        return decoded

model = UNet()

In [None]:
# Denoising Autoencoder Model
class RGBDenoisingModel(nn.Module):
    def __init__(self):
        super(RGBDenoisingModel, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),  # Change the input channels to 3
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_ssim
import wandb
import os
from wandb import Table

os.environ["WANDB_ENTITY"] = "demonstrations"
os.environ["WANDB_PROJECT"] = "gpt_4_denoising"

# Hyperparameters
batch_size = 64
epochs = 30
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
noise_std = 0.1

# Data preprocessing
# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load the datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Split the train dataset into train and validation sets
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


import os
import glob
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor
import torch.utils.data as data

class DIV2KDataset(data.Dataset):
    def __init__(self, root_dir, train=True, transform=None, noise_std=30):
        self.train = train
        self.transform = transform
        self.noise_std = noise_std

        if self.train:
            self.target_dir = os.path.join(root_dir, 'DIV2K_train_HR/DIV2K_train_HR')
        else:
            self.target_dir = os.path.join(root_dir, 'DIV2K_valid_HR/DIV2K_valid_HR')

        self.target_images = sorted(glob.glob(os.path.join(self.target_dir, '*.png')))

    def __len__(self):
        return len(self.target_images)

    def __getitem__(self, idx):
        target_image = Image.open(self.target_images[idx]).convert('RGB')
        
        if self.transform:
            target_image = self.transform(target_image)

        # Add Gaussian noise
        noisy_image = target_image + torch.normal(mean=0.0, std=self.noise_std / 255, size=target_image.shape)

        # Clip the pixel values to the valid range [0, 1]
        noisy_image = torch.clamp(noisy_image, min=0, max=1)

        return noisy_image, target_image


# Image transforms
from torchvision.transforms import Resize, ToTensor, Compose

root_dir = "./DIV2K"
transform = Compose([Resize((256, 256)), ToTensor()])
noise_std = 30


train_dataset = DIV2KDataset(root_dir, train=True, transform=transform, noise_std=noise_std)
val_dataset = DIV2KDataset(root_dir, train=False, transform=transform, noise_std=noise_std)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)




# Assuming you've already implemented the dataset loading and transformation
# Replace MyDataset with your custom dataset class
# Initialize WandB
wandb.init(project="image-denoising",entity='demonstrations')

# Create the model, loss, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedUNet().to(device)
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=0.001)

# Training function
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0.0

    for noisy_images, target_images in dataloader:
        noisy_images, target_images = noisy_images.to(device), target_images.to(device)
        optimizer.zero_grad()

        denoised_images, predicted_noise = model(noisy_images)
        loss = criterion(denoised_images, target_images)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(dataloader)

# Validation function
def validate_epoch(model, dataloader, criterion, device):
    model.eval()
    epoch_loss = 0.0

    with torch.no_grad():
        for noisy_images, target_images in dataloader:
            noisy_images, target_images = noisy_images.to(device), target_images.to(device)

            denoised_images, predicted_noise = model(noisy_images)
            loss = criterion(denoised_images, target_images)

            epoch_loss += loss.item()

            # Log images to WandB
            if wandb.run.step % 100 == 0:
                wandb.log({
                    "Noisy Images": wandb.Image(make_grid(noisy_images.cpu().detach())),
                    "Target Images": wandb.Image(make_grid(target_images.cpu().detach())),
                    "Predicted Noise": wandb.Image(make_grid(predicted_noise.cpu().detach())),
                    "Denoised Images": wandb.Image(make_grid(denoised_images.cpu().detach()))
                })

    return epoch_loss / len(dataloader)

# Main training loop
num_epochs = 50
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss = validate_epoch(model, val_loader, criterion, device)

    # Log metrics to WandB
    wandb.log({
        "Train Loss": train_loss,
        "Validation Loss": val_loss
    })

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

# Save the final model
torch.save(model.state_dict(), "improved_unet.pth")
wandb.save("improved_unet.pth")

# Finish WandB run
wandb.finish()


Files already downloaded and verified
Files already downloaded and verified


VBox(children=(Label(value='0.001 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.124152…

OutOfMemoryError: ignored

In [None]:
import wandb
from wandb.sdk.interface.artifacts import Artifact

api = wandb.Api()

# Set your project and user name
project_name = "your_project_name"
user_name = "demonstrations"

# Retrieve the desired run
run_path = f"{user_name}/{project_name}/{run.id}"
run = api.run(run_path)

# Create the report
report = wandb.Report("Model and Training Overview", project_name, user_name)
report.add_description("This report provides a detailed overview of the model architecture, training process, performance, and debugging steps for a denoising image model based on a stable diffusion process.")

# Add a text block with an explanation of the model
report.add_block("text", {
    "text": """
This project aims to build a denoising model based on a stable diffusion process for image restoration tasks. The model uses a custom PyTorch architecture with multiple Convolutional layers and ReLU activation functions, along with an SSIM loss function to optimize its parameters. The model is trained on the CIFAR-10 dataset, with images perturbed by Gaussian noise to simulate noisy input data. The objective of the model is to reconstruct the original image as closely as possible, with minimal artifacts or noise.

Throughout the development process, we encountered several issues that required debugging, including:

1. Issues with tensor shapes in the data loader.
2. Errors related to the model's input channels not matching the data.
3. Errors caused by device mismatches in the add_noise function.

We were able to resolve these issues by refining the data loader, updating the model architecture to accept RGB images, and ensuring tensors were created on the correct device.

I would like to express my gratitude to the Engineering Manager at NVIDIA, @ptrblck_de, the academic authors who contributed to the development of stable diffusion models, and the amazing team at OpenAI for their invaluable resources and support throughout this learning journey.
"""})

# Add a code block with the relevant model and training code snippets
report.add_block("code", {"code": '''
# Model Architecture
class DenoisingModel(nn.Module):
    ...

# Loss Function
criterion = -1 * pytorch_ssim.SSIM(window_size=11)

# Training Function
def train(model, train_loader, criterion, optimizer, device):
    ...

# Validation Function
def validate(model, val_loader, criterion, device):
    ...

# Training Loop
for epoch in range(epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss = validate(model, val_loader, criterion, device)
    ...
'''})
# Add an image block with logged images
logged_images = run.history(samples=5, keys=["input_image_noisy", "input_image_original", "output_image"])
for i, row in logged_images.iterrows():
    input_image_noisy = Artifact.get_path(row["input_image_noisy"]).download_file()
    input_image_original = Artifact.get_path(row["input_image_original"]).download_file()
    output_image = Artifact.get_path(row["output_image"]).download_file()

    report.add_block("wandb-images", {
        "images": [
            {"src": input_image_noisy, "caption": f"Input Image Noisy {i+1}"},
            {"src": input_image_original, "caption": f"Input Image Original {i+1}"},
            {"src": output_image, "caption": f"Output Image {i+1}"}
        ]
    })

# Save and publish the report
report.save()
report.publish()


CommError: ignored