In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class DIPModel(nn.Module):
    """Small fully-convolutional image-to-image network used as the DIP denoiser.

    Four 3x3 convolutions (stride 1, padding 1) with a ReLU between each pair,
    so the spatial size is preserved, e.g. (B, in_channels, H, W) ->
    (B, out_channels, H, W). The defaults reproduce the grayscale network
    1 -> 64 -> 128 -> 64 -> 1.

    Args:
        in_channels: number of channels of the input image (1 for grayscale).
        out_channels: number of channels of the output image; defaults to
            ``in_channels``.
        hidden_channels: width of the first and third conv layers; the middle
            layer uses twice this width.
    """

    def __init__(self, in_channels=1, out_channels=None, hidden_channels=64):
        super(DIPModel, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        wide_channels = 2 * hidden_channels
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, wide_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(wide_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Map an image batch to an image batch of the same spatial size."""
        return self.layers(x)

def train_dip(model, noisy_image, clean_image, num_epochs, optimizer, criterion):
    """Fit ``model`` so that ``model(noisy_image)`` matches ``clean_image``.

    Performs ``num_epochs`` full-batch optimisation steps (one forward/backward
    pass each) and prints the current loss every 100 steps.

    Args:
        model: network to train; put into training mode.
        noisy_image: network input.
        clean_image: regression target.
        num_epochs: number of optimisation steps.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(output, target)``.

    Returns:
        The same ``model`` instance, trained in place.
    """
    model.train()
    for step in range(1, num_epochs + 1):
        optimizer.zero_grad()
        prediction = model(noisy_image)
        loss = criterion(prediction, clean_image)
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Epoch [{step}/{num_epochs}], Loss: {loss.item()}")

    return model

def add_noise(image, noise_factor=0.5, clip_min=0.0, clip_max=1.0):
    """Return a copy of ``image`` corrupted with additive Gaussian noise.

    Args:
        image: input tensor; it is not modified.
        noise_factor: standard deviation of the zero-mean Gaussian noise.
        clip_min, clip_max: range the noisy result is clamped to. The default
            ``[0, 1]`` matches ``transforms.ToTensor()`` output. For images
            normalised to ``[-1, 1]`` pass ``clip_min=-1.0``; otherwise every
            negative pixel is clipped to 0. Pass ``None`` for both bounds to
            skip clamping (one ``None`` leaves that side unbounded).

    Returns:
        A new tensor with the same shape, dtype and device as ``image``.
    """
    noisy_image = image + noise_factor * torch.randn_like(image)
    if clip_min is None and clip_max is None:
        # torch.clamp requires at least one bound.
        return noisy_image
    return torch.clamp(noisy_image, min=clip_min, max=clip_max)

def show_image(tensor, title="", cmap=None):
    """Display a single image tensor with matplotlib and show the figure.

    Args:
        tensor: image tensor. Size-1 dimensions are squeezed away, so
            (1, 1, H, W), (1, H, W) and (H, W) all display as an H x W image.
            A 3- or 4-channel (C, H, W) result is transposed to (H, W, C),
            because ``imshow`` expects channels-last colour images.
        title: figure title.
        cmap: colormap for single-channel images. Defaults to ``"gray"`` so
            grayscale images are not rendered with matplotlib's default
            (viridis) colormap.
    """
    image = tensor.detach().cpu().squeeze().numpy()
    if image.ndim == 3 and image.shape[0] in (3, 4):
        image = image.transpose(1, 2, 0)
    if cmap is None and image.ndim == 2:
        cmap = "gray"
    plt.imshow(image, cmap=cmap)
    plt.title(title)
    plt.axis('off')
    plt.show()

In [5]:
class DDPM(nn.Module):
    """Convolutional denoiser whose output is scaled down linearly with the timestep.

    The timestep ``t`` is not fed into the network. It only enters as a scalar
    multiplier ``1 - t / timesteps`` on the conv-stack output. The output is
    therefore unscaled at ``t = 0`` and identically zero at ``t = timesteps``.

    Args:
        timesteps: number of timesteps ``T`` used to normalise ``t``.
    """

    def __init__(self, timesteps=1000):
        super().__init__()
        self.timesteps = timesteps
        # Channel widths of the 3x3 conv stack; every conv except the last
        # is followed by a ReLU.
        widths = (1, 64, 128, 64, 1)
        last_conv = len(widths) - 2
        stages = []
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            if idx != last_conv:
                stages.append(nn.ReLU())
        self.model = nn.Sequential(*stages)

    def forward(self, x, t):
        """Apply the conv stack to ``x`` and scale the result by ``1 - t / timesteps``."""
        scale = 1 - t / self.timesteps
        return self.model(x) * scale

def train_ddpm(model, initial_prior, num_epochs, optimizer, criterion):
    """Train ``model`` to reconstruct ``initial_prior`` from noisy copies of it.

    Each epoch sweeps ``t = 0 .. model.timesteps - 1`` in order and takes one
    optimizer step per timestep. At step ``t`` the input is ``initial_prior``
    plus Gaussian noise with standard deviation ``t / model.timesteps``; the
    target is always the clean ``initial_prior``.

    Args:
        model: module exposing a ``timesteps`` attribute and called as
            ``model(x, t)``.
        initial_prior: clean target tensor (in this notebook, the DIP output).
        num_epochs: number of full sweeps over the timesteps.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(output, target)``.

    Returns:
        The same ``model`` instance, trained in place.
    """
    model.train()
    timesteps = model.timesteps
    for epoch in range(num_epochs):
        # Accumulate on-device; calling .item() every step would force a
        # host/device sync per timestep.
        running_loss = 0.0
        for t in range(timesteps):
            optimizer.zero_grad()
            noise_std = t / timesteps
            noisy_input = initial_prior + torch.randn_like(initial_prior) * noise_std
            output = model(noisy_input, t)
            loss = criterion(output, initial_prior)
            loss.backward()
            optimizer.step()
            running_loss = running_loss + loss.detach()

        # Report the mean over all timesteps. The last step's loss alone
        # (t = T-1) is not representative of the epoch.
        mean_loss = float(running_loss) / timesteps if timesteps else float("nan")
        print(f"Epoch [{epoch+1}/{num_epochs}], Mean loss: {mean_loss:.6f}")

    return model

In [None]:
# Mount Google Drive when running on Colab. Elsewhere the import fails, and
# dataset_dir must point at a local copy of the data instead.
try:
    from google.colab import drive
except ImportError:
    drive = None
if drive is not None:
    drive.mount('/content/drive')

# Training parameters
img_size = 64
num_epochs_dip = 1000
num_epochs_ddpm = 60
learning_rate = 0.001
seed = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch so the shuffled sample, the added noise and the weight init are reproducible.
torch.manual_seed(seed)

# Define data transformation.
# ToTensor() already yields pixels in [0, 1], which is the range add_noise() clamps to.
# Normalizing to [-1, 1] here would make add_noise() clip every below-mid-gray pixel to 0.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
])

# Load dataset
dataset_dir = '/content/drive/My Drive/Colab_Notebooks/P4_cars_trains/'
dataset = datasets.ImageFolder(root=dataset_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Get one sample image
clean_image, _ = next(iter(dataloader))
clean_image = clean_image.to(device)

# Add noise (randn_like keeps the noise on clean_image's device)
noisy_image = add_noise(clean_image)

# Initialize DIP model
dip_model = DIPModel().to(device)
optimizer_dip = optim.Adam(dip_model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Train DIP
trained_dip_model = train_dip(dip_model, noisy_image, clean_image, num_epochs_dip, optimizer_dip, criterion)

# DIP output = initial prior for DDPM
trained_dip_model.eval()
with torch.no_grad():
    initial_prior = trained_dip_model(noisy_image)

# Initialize DDPM model
ddpm_model = DDPM().to(device)
optimizer_ddpm = optim.Adam(ddpm_model.parameters(), lr=learning_rate)

# Train DDPM
trained_ddpm_model = train_ddpm(ddpm_model, initial_prior, num_epochs_ddpm, optimizer_ddpm, criterion)

# DDPM.forward scales its output by (1 - t / timesteps), so t = timesteps would
# yield an all-zero image. The prior is clean (no noise added), so evaluate at t = 0.
trained_ddpm_model.eval()
with torch.no_grad():
    ddpm_output = trained_ddpm_model(initial_prior, 0)

# Display results
show_image(noisy_image, title='Noisy Image')
show_image(initial_prior, title='DIP Initial Prior')
show_image(ddpm_output, title='DDPM Final Output')
show_image(clean_image, title='Original Image')
