#### Import Dependencies

In [None]:
import numpy as np
import math
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from collections import deque

#### Set Flags

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
project_path = "/home/s184310/3.Project/data" # Path to your project data root (an earlier value was "/home/s184310/3.Project")
patch_size = 40 # The size of the patches to use (You will probably need to experiment with this parameter)
patch_stride = 20 # The stride used when generating data
channels = 1 # Number of image channels the model takes as input and produces as output
BATCH_SIZE = 5
# Define the area we use as our test data (used as slice bounds, so the end index is exclusive)
rect_x_start, rect_x_end = 100, 290 
rect_y_start, rect_y_end = 200, 390

test_size=0.2 # Fraction of the training patches handed to train_test_split as the validation share
random_state=42 # Seed passed to train_test_split

mask_count = 140 # Default number of pixels PatchDataset masks out per patch
T = 1000 # Number of diffusion timesteps


epochs = 100 # Try more!
learning_rate = 0.001

num_images = 4 # Number of images shown in plots

#### Load Data

In [None]:
# Load the grid stored under the project's merged-data directory.
grid_array = np.load(f'{project_path}/final_merged_data/50_mapiveg_percent.npy')
print("Shape of loaded grid array:", grid_array.shape)

# Re-code every -1 as -2 (in place). -2 is the value the masked normalize_data and
# get_loss treat as "no data"; presumably -1 marks missing cells in the file - TODO confirm.
grid_array[grid_array == -1] = -2

# Quick visual check of the raw grid.
plt.imshow(grid_array, cmap='gray')
plt.colorbar()
plt.axis('off')
plt.show()

#### Preprocess Data

In [None]:
def normalize_data(data):
    """Linearly rescale ``data`` to the range [-1, 1].

    The minimum and maximum are computed over the first two axes, ignoring
    NaNs, and are kept with singleton dimensions so they broadcast against
    ``data``.

    Returns:
        Tuple ``(normalized_data, min_val, max_val)``.
    """
    lo = np.nanmin(data, axis=(0, 1), keepdims=True)
    hi = np.nanmax(data, axis=(0, 1), keepdims=True)
    # First map to [0, 1], then stretch and shift to [-1, 1].
    unit_scaled = (data - lo) / (hi - lo)
    return unit_scaled * 2 - 1, lo, hi

def denormalize_data(data, min_val, max_val):
    """Map values from the normalised range [-1, 1] back to [min_val, max_val].

    This is the inverse of the min/max scaling to [-1, 1] performed by
    ``normalize_data``: -1 maps to ``min_val`` and +1 maps to ``max_val``.

    Args:
        data: Normalised values.
        min_val: Minimum returned by ``normalize_data``.
        max_val: Maximum returned by ``normalize_data``.

    Returns:
        The values on their original scale.
    """
    # Undo "* 2 - 1" first, then undo the min/max scaling.
    return (data + 1) / 2 * (max_val - min_val) + min_val

In [None]:
def normalize_data(data):
    """Rescale the valid entries of ``data`` to [-1, 1], keeping -2 as a sentinel.

    Entries equal to -2 are excluded when the global minimum and maximum are
    computed and are written back as -2 in the result.

    Returns:
        Tuple ``(normalized_data, min_val, max_val)`` where ``min_val`` and
        ``max_val`` are the extremes of the non-sentinel entries.
    """
    is_valid = data != -2
    # Hide the sentinel behind NaN so nanmin/nanmax skip it.
    observed = np.where(is_valid, data, np.nan)
    lo = np.nanmin(observed)
    hi = np.nanmax(observed)
    rescaled = ((data - lo) / (hi - lo)) * 2 - 1
    return np.where(is_valid, rescaled, -2), lo, hi

def denormalize_data(data, min_val, max_val):
    """Map values from [-1, 1] back to [min_val, max_val], keeping -2 untouched.

    Entries equal to -2 are treated as the sentinel and returned as -2; every
    other entry is rescaled so that -1 becomes ``min_val`` and +1 becomes
    ``max_val``.
    """
    is_valid = data != -2
    span = max_val - min_val
    midpoint = (max_val + min_val) / 2
    return np.where(is_valid, data * span / 2 + midpoint, -2)

In [None]:
# Normalise the whole grid; min_val and max_val are the arguments denormalize_data expects.
normalized_grid_array, min_val, max_val = normalize_data(grid_array)

#### Generate Train, Validation, and Test Data

In [None]:
# Show the normalised grid and outline the held-out test region in red.
plt.imshow(normalized_grid_array, cmap='gray')
plt.colorbar()
plt.title("Data with test set marked")
# Rectangle takes the (x, y) anchor corner followed by width and height.
rectangle = plt.Rectangle((rect_x_start, rect_y_start), 
                          rect_x_end - rect_x_start, 
                          rect_y_end - rect_y_start, 
                          edgecolor='red', 
                          facecolor='none', 
                          linewidth=2)
plt.gca().add_patch(rectangle)
plt.show()

In [None]:
# Copy the rectangular test region out of the grid, then overwrite that region
# with the sentinel -3 so create_input_output_pairs can drop every training
# patch that overlaps it.
test_data = normalized_grid_array[rect_y_start:rect_y_end,rect_x_start:rect_x_end].copy()
normalized_grid_array[rect_y_start:rect_y_end,rect_x_start:rect_x_end] = -3
# train_data is the same array object as normalized_grid_array (no copy is made).
train_data = normalized_grid_array
plt.imshow(train_data, cmap='gray')
plt.colorbar()
# This figure shows the training grid (test region blanked out), so label it as such.
plt.title("Train Data")
plt.show()

In [None]:
# Show the held-out test region on its own.
plt.imshow(test_data, cmap='gray')
plt.colorbar()
plt.title("Test Data")
plt.show()

In [None]:
def create_input_output_pairs(data, patch_size=patch_size, stride=patch_stride):
    """Cut ``data`` into square patches on a regular grid.

    A ``patch_size`` x ``patch_size`` window is moved over the first two axes
    of ``data`` in steps of ``stride``. Every window position that fits
    entirely inside ``data`` is visited, including the one that ends exactly
    on the border. Windows containing the value -3 (the marker written over
    the held-out test region) are skipped. Windows consisting only of the
    no-data value -2 are kept.

    Args:
        data: Array whose first two axes are the spatial axes.
        patch_size: Side length of each patch.
        stride: Step between consecutive window positions.

    Returns:
        Array of shape ``(n_patches, patch_size, patch_size, ...)``; when no
        patch qualifies the array is empty but keeps that shape.
    """
    patches = []
    # "+ 1" so the last start offset whose window still fits is included.
    row_starts = range(0, data.shape[0] - patch_size + 1, stride)
    col_starts = range(0, data.shape[1] - patch_size + 1, stride)
    for i in row_starts:
        for j in col_starts:
            current_patch = data[i:i + patch_size, j:j + patch_size]
            if np.any(current_patch == -3):  # overlaps the held-out test region
                continue
            patches.append(current_patch.copy())
    if not patches:
        # Keep a well-formed shape so callers can still read shape[1] and shape[2].
        empty_shape = (0, patch_size, patch_size) + tuple(data.shape[2:])
        return np.empty(empty_shape, dtype=data.dtype)
    return np.array(patches)

# Patch both grids with the default patch size and stride. Patches touching the
# -3 test region are dropped inside create_input_output_pairs.
train_patches = create_input_output_pairs(train_data)
test_patches = create_input_output_pairs(test_data)

In [None]:
# Split the dataset into training and validation
# Split the dataset into training and validation (test_size is the validation
# fraction; random_state makes the split reproducible).
train_patches, val_patches = train_test_split(train_patches, test_size=test_size, random_state=random_state)

print("Train inputs shape:", train_patches.shape)
print("Val inputs shape:", val_patches.shape)
print("Test inputs shape:", test_patches.shape)

In [None]:
class PatchDataset(Dataset):
    """Dataset of image patches, each returned with a random connected mask.

    Every item is a triple ``(masked_image, image, mask)`` of channel-first
    tensors. ``mask`` is 1 where the pixel is kept and 0 where it is masked
    out; ``masked_image`` is ``image * mask``.

    Args:
        data: Indexable collection of patches with ``data.shape[1]`` rows and
            ``data.shape[2]`` columns. Each patch may be channel-last
            ``(H, W, C)`` or channel-less ``(H, W)``; the latter is treated
            as a single-channel image.
        mask_pixels: Number of pixels to mask out in every patch.
    """

    def __init__(self, data, mask_pixels=mask_count):
        self.data = data
        self.mask_pixels = mask_pixels
        self.height, self.width = data.shape[1], data.shape[2]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = torch.tensor(self.data[idx], dtype=torch.float32)
        if image.dim() == 2:
            # Channel-less patch: add a trailing channel axis so the
            # channel-last handling and the permute below work.
            image = image.unsqueeze(-1)

        # Grow the masked region on a plain boolean grid; per-element tensor
        # indexing would be far slower. True means "pixel is kept".
        keep = np.ones((self.height, self.width), dtype=bool)
        remaining = self.mask_pixels

        # Random seed pixel; it is always masked.
        i = np.random.randint(0, self.height)
        j = np.random.randint(0, self.width)
        keep[i, j] = False
        remaining -= 1

        # Breadth-first growth over the 4-neighbourhood until enough pixels
        # are masked. The grid is fully connected, so the queue can only run
        # dry once every pixel is masked; no extra fill pass is needed.
        directions = ((-1, 0), (1, 0), (0, -1), (0, 1))
        queue = deque([(i, j)])
        while queue and remaining > 0:
            ci, cj = queue.popleft()
            for di, dj in directions:
                ni, nj = ci + di, cj + dj
                inside = 0 <= ni < self.height and 0 <= nj < self.width
                if inside and keep[ni, nj]:
                    keep[ni, nj] = False
                    remaining -= 1
                    queue.append((ni, nj))
                    if remaining == 0:
                        break

        # Broadcast the 2-D keep grid across all channels of the image.
        keep_plane = torch.tensor(keep, dtype=image.dtype).unsqueeze(-1)
        mask = torch.ones_like(image) * keep_plane
        masked_image = image * mask

        # Channel-last (H, W, C) -> channel-first (C, H, W).
        return masked_image.permute(2, 0, 1), image.permute(2, 0, 1), mask.permute(2, 0, 1)


# Wrap each split in a PatchDataset (default number of masked pixels).
train_dataset = PatchDataset(train_patches)
val_dataset = PatchDataset(val_patches)
test_dataset = PatchDataset(test_patches)

# All three loaders shuffle and use the same batch size.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Test the dataloader: print the shapes of the first batch only.
for masked_image, image, mask in train_dataloader:
    print(masked_image.shape)
    print(image.shape)
    print(mask.shape)
    break


#### Define Model

##### The forward process

In [None]:
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` values spaced evenly from ``start`` to ``end``.

    The tensor is created on the notebook-level ``device``.
    """
    schedule = torch.linspace(start, end, steps=timesteps, device=device)
    return schedule

def get_index_from_list(vals, t, x_shape):
    """
    Look up ``vals[t]`` for every timestep in the batch ``t``.

    The result has one entry per batch element, followed by as many
    singleton dimensions as needed to broadcast against a tensor of shape
    ``x_shape``.
    """
    picked = vals.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape(t.shape[0], *trailing_ones)

def forward_diffusion_sample(x_0, t, device=device):
    """
    Produce the noisy version of ``x_0`` at timestep(s) ``t`` in one step.

    Fresh Gaussian noise is drawn and mixed with ``x_0`` using the
    precomputed ``sqrt_alphas_cumprod`` and ``sqrt_one_minus_alphas_cumprod``
    coefficients for ``t``.

    Returns:
        Tuple ``(noisy_image, noise)``, both on ``device``.
    """
    x_0 = x_0.to(device)
    noise = torch.randn_like(x_0, device=device)

    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape).to(device)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape).to(device)

    x_t = signal_scale * x_0 + noise_scale * noise
    return x_t, noise


# Define beta schedule

betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
# Running product of the alphas up to and including each timestep.
alphas_cumprod = torch.cumprod(alphas, axis=0)
# Same running product shifted right by one step, with 1.0 padded in at the front.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# Coefficients used by forward_diffusion_sample for the signal and the noise term.
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Variance of the noise that sample_timestep adds back at every step except t == 0.
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

In [None]:
# Simulate forward diffusion
image_masked,image,mask = next(iter(train_dataloader))
plt.figure(figsize=(15,15))
stepsize = int(T/num_images)

def show_tensor_image(image):
    """Draw a channel-first image tensor on the current axes in grayscale.

    When a 4-dimensional batch is passed, only its first image is shown.
    The axes decorations are switched off.
    """
    single = image[0, :, :, :] if len(image.shape) == 4 else image
    # matplotlib expects channel-last data on the CPU.
    channel_last = single.permute(1, 2, 0).detach().cpu()
    plt.imshow(channel_last, cmap='gray')
    plt.axis('off')

# Plot the batch's first image after noising it to every stepsize-th timestep.
for idx in range(0, T, stepsize):
    # Single-element int64 timestep tensor on the compute device.
    t = torch.Tensor([idx]).type(torch.int64).to(device)
    plt.subplot(1, num_images + 1, int(idx / stepsize) + 1)
    img, noise = forward_diffusion_sample(image, t)
    show_tensor_image(img)

In [None]:
# Draw a fresh batch; sample is (masked image, original image, mask).
sample = next(iter(train_dataloader))
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Show the original image (sample[1]) on the first subplot
plt.sca(axes[0])
show_tensor_image(sample[1])
axes[0].set_title('Original Image')

# Show the masked image (sample[0]) on the second subplot
plt.sca(axes[1])
show_tensor_image(sample[0])
axes[1].set_title('Masked Image')

# Display the figure
plt.show()

In [None]:
# Show the mask (sample[2]) that belongs to the first image of the batch.
show_tensor_image(sample[2])

##### The backward process

In [None]:
class Block(nn.Module):
    """One U-Net stage: two 3x3 convolutions with a time embedding, then resampling.

    With ``up=False`` the stage ends in a strided convolution; with
    ``up=True`` the first convolution expects ``2 * in_ch`` input channels
    and the stage ends in a transposed convolution. All layers are moved to
    the notebook-level ``device`` on construction.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch).to(device)

        # The up path receives the skip tensor concatenated on the channel axis.
        conv1_in = 2 * in_ch if up else in_ch
        resample_layer = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(conv1_in, out_ch, 3, padding=1).to(device)
        self.transform = resample_layer(out_ch, out_ch, 4, 2, 1).to(device)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1).to(device)
        self.bnorm1 = nn.BatchNorm2d(out_ch).to(device)
        self.bnorm2 = nn.BatchNorm2d(out_ch).to(device)
        self.relu = nn.ReLU().to(device)

    def forward(self, x, t):
        # First convolution -> ReLU -> batch norm.
        h = self.conv1(x)
        h = self.bnorm1(self.relu(h))

        # Project the time embedding to out_ch and append two singleton
        # dimensions so it broadcasts over the spatial axes.
        time_emb = self.relu(self.time_mlp(t))[..., None, None]
        h = h + time_emb

        # Second convolution -> ReLU -> batch norm.
        h = self.conv2(h)
        h = self.bnorm2(self.relu(h))

        # Down- or upsample.
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    """Encode a batch of timesteps as sine/cosine features.

    The output has ``2 * (dim // 2)`` columns: all sine terms first, then all
    cosine terms, over ``dim // 2`` geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        # Frequencies run from 1 down to 1/10000 on a log scale.
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=time.device) * -log_step)
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        # TODO: Double check the ordering here
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.

    Channel widths are multiples of the notebook-level ``patch_size``
    (x1, x2, x4, x8, x16 on the way down, mirrored on the way up). The input
    and output both have ``channels`` channels, and the timestep is injected
    into every stage through a 32-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        image_channels = channels
        out_dim = channels
        time_emb_dim = 32
        down_channels = tuple(patch_size * 2 ** k for k in range(5))
        up_channels = down_channels[::-1]

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim).to(device),
            nn.Linear(time_emb_dim, time_emb_dim).to(device),
            nn.ReLU().to(device),
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1).to(device)

        # Downsample: one Block per pair of consecutive channel widths.
        self.downs = nn.ModuleList([
            Block(c_in, c_out, time_emb_dim).to(device)
            for c_in, c_out in zip(down_channels, down_channels[1:])
        ])
        # Upsample
        self.ups = nn.ModuleList([
            Block(c_in, c_out, time_emb_dim, up=True).to(device)
            for c_in, c_out in zip(up_channels, up_channels[1:])
        ])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1).to(device)

    def forward(self, x, timestep):
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)

        # Encoder: remember every stage output for the skip connections.
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)

        # Decoder: consume the skips in reverse order.
        for up in self.ups:
            skip = skips.pop()
            target_hw = (skip.size(2), skip.size(3))
            if (x.size(2), x.size(3)) != target_hw:
                # Spatial sizes can drift apart after repeated halving and
                # doubling; resize so the two can be concatenated.
                x = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=False)
            # Add the skip tensor as additional channels
            x = up(torch.cat((x, skip), dim=1), t)
        return self.output(x)

# Build the network on the selected device and report its total parameter count.
model = SimpleUnet().to(device)
print("Num params: ", sum(p.numel() for p in model.parameters()))



**Further improvements that can be implemented:**
- Residual connections
- Different activation functions like SiLU, GELU, ...
- BatchNormalization 
- GroupNormalization
- Attention
- ...

#### Train Model

##### Loss

In [None]:
def get_loss(model, x_0, t):
    """Return the MSE between the model's noise prediction and the true noise.

    ``x_0`` is noised to timestep ``t`` with ``forward_diffusion_sample`` and
    the model is asked to predict that noise from the noisy input.
    (``F.l1_loss`` was tried as an alternative.)
    """
    noisy_input, true_noise = forward_diffusion_sample(x_0, t, device)
    predicted_noise = model(noisy_input, t)
    return F.mse_loss(predicted_noise, true_noise)

In [None]:
def get_loss(model, x_0, t):
    """Noise-prediction MSE restricted to pixels that carry data.

    ``x_0`` is noised to timestep ``t`` with ``forward_diffusion_sample`` and
    the model predicts the added noise. Only positions where ``x_0`` differs
    from the no-data value -2 contribute to the loss.

    Args:
        model: Callable taking ``(noisy_image, t)`` and returning a noise
            prediction of the same shape as the noisy image.
        x_0: Batch of clean images.
        t: Batch of integer timesteps, one per image.

    Returns:
        Scalar loss tensor. When the batch contains no valid pixel at all the
        loss is zero (still attached to the graph) rather than NaN.
    """
    # Apply the forward diffusion process to generate noisy data
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)

    # Predict the noise from the noisy data using the model
    noise_pred = model(x_noisy, t)

    # True where the clean image holds real data (anything other than -2).
    mask = (x_0 != -2).to(noise.device)

    if not mask.any():
        # mse_loss over an empty selection yields NaN, which would poison the
        # weights on backward(). Return a graph-connected zero instead.
        return noise_pred.sum() * 0.0

    # Compute the mean squared error loss only for valid values
    return F.mse_loss(noise_pred[mask], noise[mask])

##### Sampling

In [None]:
@torch.no_grad()
def sample_timestep(x, t):
    """
    Run one reverse-diffusion step on ``x`` at timestep ``t``.

    The notebook-level ``model`` predicts the noise contained in ``x``; that
    prediction is used to compute the mean of the denoised image. For every
    timestep except 0, Gaussian noise scaled by the posterior variance is
    added to that mean before it is returned.
    """
    beta_t = get_index_from_list(betas, t, x.shape)
    noise_scale_t = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x.shape)
    recip_sqrt_alpha_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Mean of the denoised image: current image minus the scaled noise prediction.
    predicted_noise = model(x, t)
    model_mean = recip_sqrt_alpha_t * (x - beta_t * predicted_noise / noise_scale_t)
    variance_t = get_index_from_list(posterior_variance, t, x.shape)

    # The final step returns the mean without adding noise.
    if t == 0:
        return model_mean

    fresh_noise = torch.randn_like(x, device=device)
    return model_mean + torch.sqrt(variance_t) * fresh_noise

@torch.no_grad()
def inpainting_plot_image(image, mask):
    """Plot the forward noising of a masked image, then a mask-guided reverse pass.

    First figure: the original image followed by snapshots of the masked
    image (``image * mask``) noised to every ``stepsize``-th timestep. A
    noisy version is generated and stored for every timestep 0..T-1.

    Second figure: the stored noisy images are walked from the largest
    timestep down to 0. At every step the input to ``sample_timestep`` takes
    its pixels from the stored noisy image where ``mask`` is 1 and from the
    previous ``sample_timestep`` output where ``mask`` is 0. The outputs at
    the timesteps plotted in the first figure are shown.

    Args:
        image: Batch of clean images (assumed channel-first, on ``device`` -
            TODO confirm); only the first image of the batch is displayed.
        mask: Tensor multiplied element-wise with ``image``; 1 keeps a pixel,
            0 marks it for inpainting.
    """
    plt.figure(figsize=(15, 15))
    stepsize = int(T / num_images)
    
    # Display the original image
    plt.subplot(1, num_images + 1, 1)
    show_tensor_image(image.detach().cpu())
    plt.title(f"Original Image")  # Set the title of the subplot

    # Zero out the pixels that are to be inpainted.
    masked_image = image*mask

    image_list = []  # (noisy masked image, timestep tensor) for every timestep
    image_idx_plot_list = []  # timestep tensors that get plotted
    
    for idx in range(0, T):
        t = torch.Tensor([idx]).type(torch.int64).to(device)
        # Each timestep is noised independently from the masked image.
        img, noise = forward_diffusion_sample(masked_image, t)
        if idx % stepsize == 0:
            image_idx_plot_list.append(t)            
            # Panel 1 holds the original image, so snapshots start at panel 2.
            plt.subplot(1, num_images+1, 1+int(idx/stepsize) + 1)
            show_tensor_image(img.detach().cpu())
            #plt.title(f"Timestep {idx}")
        image_list.append((img,t))
    plt.show()


    # plot backward process
    plt.figure(figsize=(15, 15))
    plt.subplot(1, num_images, 1)  # selects the first panel; the loop below picks its own panel for every plotted step
    # With tmp_mask == 1 the first iteration uses the stored noisy image unchanged.
    tmp_mask = 1 #
    tmp_img = image_list[-1][0]
    for idx, (noisy_img,t) in enumerate(image_list[::-1]):
        # Where tmp_mask is 1 take the stored noisy image, elsewhere the previous denoising output.
        img = (noisy_img*tmp_mask)+(tmp_img*(1-tmp_mask))
        tmp_mask = mask
        tmp_img = sample_timestep(img,t)
        if t in image_idx_plot_list:
            # idx counts up while t counts down, so panels are filled from right to left.
            plt.subplot(1, num_images, num_images-(int(idx/stepsize)))
            show_tensor_image(tmp_img.detach().cpu())
            #plt.title(f"Timestep {idx}")
    plt.show()

##### Training Loop

In [None]:
# Adam over all model parameters; train_epoch reads this notebook-level optimizer.
optimizer = Adam(model.parameters(), lr=learning_rate)

def train_epoch(model):
    """Run one optimisation pass over ``train_dataloader``.

    For every batch a random timestep is drawn per image, the loss on the
    unmasked images (element 1 of the batch) is back-propagated, and the
    notebook-level ``optimizer`` takes a step.

    Returns:
        The mean of the per-batch losses.
    """
    running_loss = 0
    for batch in train_dataloader:
        clean_images = batch[1]
        optimizer.zero_grad()
        # One independent timestep in [0, T) per image in the batch.
        t = torch.randint(0, T, (clean_images.shape[0],), device=device).long()
        loss = get_loss(model, clean_images.to(device), t)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_dataloader)

In [None]:
@torch.no_grad()
def val_epoch(model):
    """Compute the mean loss over ``val_dataloader`` without updating the model.

    The model is switched to evaluation mode for the duration of the pass so
    that layers such as BatchNorm neither use batch statistics nor update
    their running statistics from validation data. The previous mode is
    restored afterwards, even if an error occurs.

    Returns:
        The mean of the per-batch validation losses.
    """
    was_training = model.training
    model.eval()
    try:
        loss_value = 0
        for batch in val_dataloader:
            # One independent timestep in [0, T) per image in the batch.
            t = torch.randint(0, T, (batch[1].shape[0],), device=device).long()
            loss = get_loss(model, batch[1].to(device), t)
            loss_value += loss.item()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    return loss_value/len(val_dataloader)

In [None]:
# (train_loss, val_loss) for every epoch of this run.
loss_list = []
for epoch in range(epochs):
    train_loss = train_epoch(model)
    val_loss = val_epoch(model)
    print(f"Epoch {epoch}/{epochs}: Train Loss: {train_loss}, Validation Loss: {val_loss}")
    loss_list.append((train_loss,val_loss))
    # From epoch index 11 onwards, visualise inpainting on a fresh training batch.
    if epoch > 10:
        sample = next(iter(train_dataloader))
        # Called twice on the same batch; each call draws new random noise, so the two figures can differ.
        inpainting_plot_image(sample[1].to(device),sample[2].to(device))
        inpainting_plot_image(sample[1].to(device),sample[2].to(device))

In [None]:
# Plot the per-epoch training and validation losses on one figure.
plt.figure()

# Each loss_list entry is a (train_loss, val_loss) pair.
plt.plot([train for train, _ in loss_list], label='Train Loss')
plt.plot([val for _, val in loss_list], color='orange', label='Validation Loss')

# Axis labels, title and legend.
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.show()

In [None]:
# Write the model's current weights (state_dict only) to the working directory.
torch.save(model.state_dict(), 'best_model_singlefeature.pt')

In [None]:
# Another training run on the same model object (it is not re-created here);
# loss_list is reset, so it only holds this run's (train_loss, val_loss) pairs.
loss_list = []
for epoch in range(epochs):
    train_loss = train_epoch(model)
    val_loss = val_epoch(model)
    print(f"Epoch {epoch}/{epochs}: Train Loss: {train_loss}, Validation Loss: {val_loss}")
    loss_list.append((train_loss,val_loss))
    # From epoch index 5 onwards, visualise inpainting on a fresh training batch.
    if epoch > 4:
        sample = next(iter(train_dataloader))
        # Called twice on the same batch; each call draws new random noise, so the two figures can differ.
        inpainting_plot_image(sample[1].to(device),sample[2].to(device))
        inpainting_plot_image(sample[1].to(device),sample[2].to(device))


In [None]:
# Plot the per-epoch training and validation losses on one figure.
plt.figure()

# Each loss_list entry is a (train_loss, val_loss) pair.
plt.plot([train for train, _ in loss_list], label='Train Loss')
plt.plot([val for _, val in loss_list], color='orange', label='Validation Loss')

# Axis labels, title and legend.
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.show()


In [None]:
# Another training run on the same model object (it is not re-created here);
# loss_list is reset, so it only holds this run's (train_loss, val_loss) pairs.
loss_list = []
for epoch in range(epochs):
    train_loss = train_epoch(model)
    val_loss = val_epoch(model)
    print(f"Epoch {epoch}/{epochs}: Train Loss: {train_loss}, Validation Loss: {val_loss}")
    loss_list.append((train_loss,val_loss))
    # From epoch index 56 onwards, visualise inpainting on a fresh training batch.
    if epoch > 55:
        sample = next(iter(train_dataloader))
        inpainting_plot_image(sample[1].to(device),sample[2].to(device))

In [None]:
# Plot the per-epoch training and validation losses on one figure.
plt.figure()

# Each loss_list entry is a (train_loss, val_loss) pair.
plt.plot([train for train, _ in loss_list], label='Train Loss')
plt.plot([val for _, val in loss_list], color='orange', label='Validation Loss')

# Axis labels, title and legend.
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.show()

In [None]:
# Evaluation function
@torch.no_grad()
def evaluate_model_on_test(model, test_dataloader):
    """Report the mean per-image MSE between model output and original test images.

    The model is put in evaluation mode and called on each masked test batch
    with a random timestep per image. For every image, the squared error is
    averaged over the pixels that carry data (original value different from
    the no-data value -2); images without any such pixel are skipped. The
    result is the mean of those per-image errors, printed together with the
    first values of up to five images from the last batch.

    NOTE(review): get_loss trains the model to predict noise, while this
    function compares the raw model output against the original image -
    confirm that this is the intended metric.

    Args:
        model: The network to evaluate.
        test_dataloader: Loader yielding ``(masked_image, original_image, mask)``.
    """
    model.eval()
    mse_per_channel = np.zeros(channels)
    images_scored = 0
    original_image = generated_images = None

    for masked_image, original_image, _ in test_dataloader:
        masked_image, original_image = masked_image.to(device), original_image.to(device)
        t = torch.randint(0, T, (original_image.shape[0],), device=device).long()
        generated_images = model(masked_image, t)

        # Calculate MSE for the single channel, excluding no-data (-2) pixels.
        # -1 is a legitimate normalised value (the data minimum), so it must
        # not be used as the exclusion marker.
        for i in range(original_image.shape[0]):
            valid_mask = (original_image[i, 0, :, :] != -2)
            if not valid_mask.any():
                # Nothing to score; an empty selection would yield NaN.
                continue
            mse_per_channel[0] += F.mse_loss(
                generated_images[i, 0, :, :][valid_mask],
                original_image[i, 0, :, :][valid_mask]
            ).item()
            images_scored += 1

    if images_scored == 0:
        print("No valid test pixels found; nothing to evaluate.")
        return

    # Errors were accumulated per image, so average over images (not batches).
    mse_per_channel /= images_scored
    print("MSE per channel:", np.round(mse_per_channel, 3))

    # Print up to 5 examples of original and generated values from the last batch
    num_examples = min(5, original_image.shape[0])
    for i in range(num_examples):
        print(f"Example {i+1}:")
        print("Original:", np.round(original_image[i, 0, :, :].cpu().numpy().flatten()[:10], 3))
        print("Generated:", np.round(generated_images[i, 0, :, :].cpu().numpy().flatten()[:10], 3))

# Evaluate on test_dataloader, which is built from the held-out test patches.
evaluate_model_on_test(model, test_dataloader)