In [9]:
import torch
import os
from PIL import Image, ImageDraw, ImageFont
import random
import shutil
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt

print(f"CUDA is available: {torch.cuda.is_available()}")
print(f"Amount of CUDA devices available: {torch.cuda.device_count()}")
print(f"Index of current CUDA device: {torch.cuda.current_device()}")
print(f"Name of current CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
print(f"Amount of CPU cores available: {os.cpu_count()}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cores = os.cpu_count()

CUDA is available: True
Amount of CUDA devices available: 1
Index of current CUDA device: 0
Name of current CUDA device: Tesla V100-SXM3-32GB
Amount of CPU cores available: 96


In [10]:
def add_date_stamp_to_images(input_dir, output_dir):
    try:
        shutil.rmtree(output_dir)
        print(f"Directory '{output_dir}' has been removed successfully.")
    except OSError as error:
        print(error)

    # Create directories for clean and noisy (stamped) images
    clean_dir = os.path.join(output_dir, "clean")
    stamped_dir = os.path.join(output_dir, "stamped")
    os.makedirs(clean_dir, exist_ok=True)
    os.makedirs(stamped_dir, exist_ok=True)

    for subdir in ['training', 'validation', 'evaluation']:
        input_subdir = os.path.join(input_dir, subdir)
        clean_subdir = os.path.join(clean_dir, subdir)
        stamped_subdir = os.path.join(stamped_dir, subdir)
        os.makedirs(clean_subdir, exist_ok=True)
        os.makedirs(stamped_subdir, exist_ok=True)

        i = 0

        for image_name in os.listdir(input_subdir):
            if i >= 10:
                break
                # break or pass to limit or not

            # Load image
            image_path = os.path.join(input_subdir, image_name)
            image = Image.open(image_path)

            # Save clean version
            clean_output_path = os.path.join(clean_subdir, image_name)
            image.save(clean_output_path)

            # Add date stamp to a copy
            draw = ImageDraw.Draw(image)
            date_text = f"{random.choice(months)} {random.randint(1, 28)}, 2024"
            font = ImageFont.load_default(32)  # Use default font
            text_width = draw.textlength(date_text, font=font)
            x, y = image.width - text_width - 10, image.height - 50
            draw.text((x, y), date_text, fill="white", font=font)

            # Save stamped version
            stamped_output_path = os.path.join(stamped_subdir, image_name)
            image.save(stamped_output_path)

            if i % 500 == 0:
                print(f"Date stamped: {stamped_output_path}")
            i += 1

    print(f"Paired dataset created at '{output_dir}'")


months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November',
          'December']

add_date_stamp_to_images("dataset", "dataset_with_paired_images")

Directory 'dataset_with_paired_images' has been removed successfully.
Date stamped: dataset_with_paired_images/stamped/training/2_394.jpg
Date stamped: dataset_with_paired_images/stamped/validation/2_394.jpg
Date stamped: dataset_with_paired_images/stamped/evaluation/2_394.jpg
Paired dataset created at 'dataset_with_paired_images'


In [11]:
class PreloadedDateStampedDataset(Dataset):
    def __init__(self, stamped_dir, clean_dir, transform):

        self.stamped_images = []
        self.clean_images = []

        file_count = sum([len(files) for _, _, files in os.walk("dataset_with_paired_images/stamped/")])
        print(f"Total files in dataset_with_paired_images/stamped/: {file_count}")

        file_count = sum([len(files) for _, _, files in os.walk("dataset_with_paired_images/clean/")])
        print(f"Total files in dataset_with_paired_images/clean/: {file_count}")

        i = 0

        # Apply transformations and load stamped images
        for stamped_file in sorted(os.listdir(stamped_dir)):
            stamped_path = os.path.join(stamped_dir, stamped_file)
            stamped_image = Image.open(stamped_path).convert("RGB")
            if transform:
                stamped_image = transform(stamped_image)
            self.stamped_images.append(stamped_image)
            i += 1
            if i % 1000 == 0:
                print(f"Loaded {i} stamped images into memory.")

        i = 0
        # Apply transformations and load clean images
        for clean_file in sorted(os.listdir(clean_dir)):
            clean_path = os.path.join(clean_dir, clean_file)
            clean_image = Image.open(clean_path).convert("RGB")
            if transform:
                clean_image = transform(clean_image)
            self.clean_images.append(clean_image)
            i += 1
            if i % 1000 == 0:
                print(f"Loaded {i} clean images into memory.")

        # Ensure equal number of stamped and clean images
        assert len(self.stamped_images) == len(self.clean_images), "Mismatch in paired dataset size"

    def __len__(self):
        return len(self.stamped_images)

    def __getitem__(self, idx):
        # Preprocessed images are already in memory
        stamped_image = self.stamped_images[idx]
        clean_image = self.clean_images[idx]
        return stamped_image, clean_image


# Transformations for input images
transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize to [-1, 1]
])

# Initialize DataLoaders
train_loader = DataLoader(
    PreloadedDateStampedDataset(
        stamped_dir='dataset_with_paired_images/stamped/training',
        clean_dir='dataset_with_paired_images/clean/training',
        transform=transform
    ),
    batch_size=4, shuffle=True, num_workers=0, pin_memory=True
)

val_loader = DataLoader(
    PreloadedDateStampedDataset(
        stamped_dir='dataset_with_paired_images/stamped/validation',
        clean_dir='dataset_with_paired_images/clean/validation',
        transform=transform
    ),
    batch_size=4, shuffle=False, num_workers=0, pin_memory=True
)

eval_loader = DataLoader(
    PreloadedDateStampedDataset(
        stamped_dir='dataset_with_paired_images/stamped/evaluation',
        clean_dir='dataset_with_paired_images/clean/evaluation',
        transform=transform
    ),
    batch_size=4, shuffle=False, num_workers=0, pin_memory=True
)

for stamped_batch, clean_batch in train_loader:
    print(stamped_batch.shape, clean_batch.shape)
    break

Total files in dataset_with_paired_images/stamped/: 30
Total files in dataset_with_paired_images/clean/: 30
Total files in dataset_with_paired_images/stamped/: 30
Total files in dataset_with_paired_images/clean/: 30
Total files in dataset_with_paired_images/stamped/: 30
Total files in dataset_with_paired_images/clean/: 30
torch.Size([4, 3, 448, 448]) torch.Size([4, 3, 448, 448])


In [12]:
class Diffusion:
    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02):
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, num_steps).to(device)
        self.alphas = 1.0 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)

    def forward_process(self, x0, t):
        noise = torch.randn_like(x0).to(device)
        sqrt_alpha_bar_t = torch.sqrt(self.alpha_bar[t])[:, None, None, None]
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - self.alpha_bar[t])[:, None, None, None]
        xt = sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise
        return xt, noise

    def reverse_process(self, model, xt, t):
        pred_noise = model(xt, t)
        sqrt_alpha_bar_t = torch.sqrt(self.alpha_bar[t])[:, None, None, None]
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - self.alpha_bar[t])[:, None, None, None]
        x0_pred = (xt - sqrt_one_minus_alpha_bar_t * pred_noise) / sqrt_alpha_bar_t
        return x0_pred


In [13]:
class UNet(nn.Module):
    def __init__(self, num_channels=3):
        super(UNet, self).__init__()
        self.encoder1 = self.conv_block(num_channels, 64)
        self.encoder2 = self.conv_block(64, 128)
        self.encoder3 = self.conv_block(128, 256)
        self.encoder4 = self.conv_block(256, 512)
        self.middle = self.conv_block(512, 512)
        self.decoder4 = self.conv_block(512, 256)
        self.decoder3 = self.conv_block(256, 128)
        self.decoder2 = self.conv_block(128, 64)
        self.decoder1 = nn.Conv2d(64, num_channels, kernel_size=3, padding=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x, t):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        m = self.middle(e4)
        d4 = self.decoder4(m + e4)
        d3 = self.decoder3(d4 + e3)
        d2 = self.decoder2(d3 + e2)
        d1 = self.decoder1(d2 + e1)
        return d1


In [14]:
# Initialize diffusion process and model
diffusion = Diffusion(num_steps=200)
model = UNet(num_channels=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Gradient scaler for mixed precision
scaler = torch.cuda.amp.GradScaler()

# Early stopping parameters
patience = 5  # Number of epochs to wait for improvement
best_val_loss = float("inf")
early_stop_counter = 0

In [None]:


# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for stamped_images, clean_images in train_loader:
        stamped_images, clean_images = stamped_images.to(device), clean_images.to(device)

        # Sample random timesteps
        t = torch.randint(0, diffusion.num_steps, (stamped_images.size(0),)).to(device)

        # Forward process: add noise
        xt, noise = diffusion.forward_process(clean_images, t)

        # Mixed precision forward pass
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            pred_noise = model(xt, t)  # Predict noise
            loss = nn.MSELoss()(pred_noise, noise)  # Compute loss

        # Backpropagation with mixed precision
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    train_loss = epoch_loss / len(train_loader)

    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for stamped_images, clean_images in val_loader:
            stamped_images, clean_images = stamped_images.to(device), clean_images.to(device)

            # Sample random timesteps
            t = torch.randint(0, diffusion.num_steps, (stamped_images.size(0),)).to(device)

            # Forward process: add noise
            xt, noise = diffusion.forward_process(clean_images, t)

            # Mixed precision forward pass
            with torch.cuda.amp.autocast():
                pred_noise = model(xt, t)  # Predict noise
                loss = nn.MSELoss()(pred_noise, noise)  # Compute loss
                val_loss += loss.item()

    val_loss /= len(val_loader)

    print(f"Epoch [{epoch + 1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0

        # Save the best model
        torch.save(model.state_dict(), "best_diffusion_model.pth")
        print(f"Validation loss improved. Model saved.")
    else:
        early_stop_counter += 1
        print(f"Early stopping counter: {early_stop_counter}/{patience}")

        if early_stop_counter >= patience:
            print("Early stopping triggered. Training stopped.")
            break


Epoch [1/50] - Train Loss: 0.3645, Val Loss: 0.1571
Validation loss improved. Model saved.
Epoch [2/50] - Train Loss: 0.1108, Val Loss: 0.0919
Validation loss improved. Model saved.
Epoch [3/50] - Train Loss: 0.0802, Val Loss: 0.0791
Validation loss improved. Model saved.
Epoch [4/50] - Train Loss: 0.0753, Val Loss: 0.0663
Validation loss improved. Model saved.
Epoch [5/50] - Train Loss: 0.0627, Val Loss: 0.0556
Validation loss improved. Model saved.
Epoch [6/50] - Train Loss: 0.0566, Val Loss: 0.0886
Early stopping counter: 1/5
Epoch [7/50] - Train Loss: 0.0581, Val Loss: 0.0583
Early stopping counter: 2/5
Epoch [8/50] - Train Loss: 0.0498, Val Loss: 0.0473
Validation loss improved. Model saved.
Epoch [9/50] - Train Loss: 0.0528, Val Loss: 0.0481
Early stopping counter: 1/5
Epoch [10/50] - Train Loss: 0.0425, Val Loss: 0.0469
Validation loss improved. Model saved.
Epoch [11/50] - Train Loss: 0.0456, Val Loss: 0.0448
Validation loss improved. Model saved.
Epoch [12/50] - Train Loss: 0.

In [15]:
# Load the trained model
model.load_state_dict(torch.load("best_diffusion_model.pth", map_location=device))


def show_images(input_images, reconstructed, n=4):
    plt.figure(figsize=(12, 4))
    for i in range(n):
        plt.subplot(2, n, i + 1)
        input_img = input_images[i].permute(1, 2, 0).cpu().numpy()
        input_img = (input_img + 1) / 2
        plt.imshow(np.clip(input_img, 0, 1))
        plt.axis('off')
        plt.title("Stamped (Input)")
#
        plt.subplot(2, n, i + 1 + n)
        reconstructed_img = reconstructed[i].permute(1, 2, 0).cpu().numpy()
        reconstructed_img = (reconstructed_img + 1) / 2
        plt.imshow(np.clip(reconstructed_img, 0, 1))
        plt.axis('off')
        plt.title("Denoised (Output)")

    plt.tight_layout()
    plt.show()

# Evaluate and visualize results
model.eval()
with torch.no_grad():
    for stamped_images, clean_images in val_loader:
        stamped_images = stamped_images.to(device)
        xt, _ = diffusion.forward_process(stamped_images, t=torch.tensor([999]).to(device))
        for t_step in range(999, -1, -1):
            xt = diffusion.reverse_process(model, xt, torch.tensor([t_step]).to(device))

        show_images(stamped_images.cpu(), xt.cpu())
        break


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


/opt/pytorch/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [0,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
