#### Import Dependencies

In [None]:
import pandas as pd
import numpy as np
import math
import torch
import torch.nn as nn
import os
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from collections import deque
from sklearn.model_selection import train_test_split
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr


#### Create numpy grid 
Rasterize the point data from the CSV onto a regular latitude/longitude grid so that it can be handled as a multi-channel image.

In [None]:
#### FOR ALL COLUMNS

# Run tensor work on the first GPU when one is present.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Point data: one row per coordinate, followed by the feature columns.
df = pd.read_csv('/home/s184310/3.Project/data/final_merged_data/50_all_data.csv')

# Grid resolution in degrees (approx. 50 m per cell at Copenhagen's latitude).
cell_size_50m_lat = 0.000450
cell_size_50m_lon = 0.000799

# Bounding box as [min_lat, min_lon, max_lat, max_lon] (the order generate_grid expects).
# At the 50 m cell sizes above it spans roughly 587 x 714 cells (~29 km x ~36 km).
copenhagen_bbox = [55.545, 12.175, 55.809, 12.745]

# Generate grid cells and their centers
def generate_grid(min_lat, min_lon, max_lat, max_lon, cell_size_lat, cell_size_lon):
    """Return the centre of every cell of a regular lat/lon grid.

    Cell origins run from the minimum corner in steps of the cell size
    (``np.arange``, so the maximum itself is excluded). Centres are listed
    row-major: latitude is the outer loop and longitude the inner one.

    Returns
    -------
    list of tuple
        ``(center_lat, center_lon)`` pairs.
    """
    lat_origins = np.arange(min_lat, max_lat, cell_size_lat)
    lon_origins = np.arange(min_lon, max_lon, cell_size_lon)
    half_lat = cell_size_lat / 2
    half_lon = cell_size_lon / 2
    return [
        (lat + half_lat, lon + half_lon)
        for lat in lat_origins
        for lon in lon_origins
    ]

# Cell centres in row-major order (latitude outer, longitude inner).
centers = generate_grid(*copenhagen_bbox, cell_size_50m_lat, cell_size_50m_lon)

# Same centres as a float32 tensor living on `device`.
centers_tensor = torch.tensor(centers, dtype=torch.float32, device=device)

# Map coordinates to the nearest center point
def find_nearest_center(lat, lon, centers):
    """Return the row index of the centre closest (Euclidean, in degrees) to (lat, lon).

    Parameters
    ----------
    lat, lon : float
        Query coordinate.
    centers : torch.Tensor
        (N, 2) tensor of (lat, lon) centres.

    Returns
    -------
    int
        Index of the nearest row of ``centers`` (the first one on ties,
        following ``torch.argmin``).
    """
    # Build the query on the same device as `centers` (instead of the global
    # `device`) so the subtraction never mixes CPU and CUDA tensors.
    lat_lon_tensor = torch.tensor([lat, lon], dtype=torch.float32, device=centers.device)
    distances = torch.sqrt(torch.sum((centers - lat_lon_tensor) ** 2, dim=1))
    return int(torch.argmin(distances).item())

# Map every CSV row to the flat index of its nearest grid centre.
# `tqdm` is not imported in this notebook (and `progress_apply` also needs
# `tqdm.pandas()` to be registered), so plain `DataFrame.apply` is used.
print("Mapping coordinates to the nearest center point...")
df['Center_Index'] = df.apply(
    lambda row: find_nearest_center(row['Latitude'], row['Longitude'], centers_tensor),
    axis=1,
)

# A flat centre index is decoded as (index // num_lon_cells, index % num_lon_cells),
# so the grid dimensions must equal the np.arange lengths used by generate_grid.
# int(span / size) + 1 only agrees with those when the span is not an exact
# multiple of the cell size, so take the lengths from np.arange directly.
num_lat_cells = len(np.arange(copenhagen_bbox[0], copenhagen_bbox[2], cell_size_50m_lat))
num_lon_cells = len(np.arange(copenhagen_bbox[1], copenhagen_bbox[3], cell_size_50m_lon))
if num_lat_cells * num_lon_cells != len(centers):
    raise ValueError("Grid dimensions do not match the generated centre list")

# Features are every column except Latitude and Longitude (the first two) and
# the Center_Index column appended above (the last one).
num_features = df.shape[1] - 3
feature_columns = list(df.columns[2:2 + num_features])

# 3-D array (lat cell, lon cell, feature); cells without data stay NaN.
grid_array = np.full((num_lat_cells, num_lon_cells, num_features), np.nan)

# Per cell and feature, keep the first non-NaN value in CSV row order
# (GroupBy.first skips nulls), i.e. "fill only if still NaN".
print("Filling the numpy array with feature values...")
first_values = df.groupby('Center_Index')[feature_columns].first()
cell_ids = first_values.index.to_numpy(dtype=np.int64)
grid_array[cell_ids // num_lon_cells, cell_ids % num_lon_cells, :] = first_values.to_numpy(dtype=np.float64)

# Save the numpy array to a file
np.save('/home/s184310/3.Project/data/final_merged_data/50_all_data.npy', grid_array)

print("Grid array saved successfully.")

In [None]:
# Reload the saved grid to confirm it round-trips and to inspect its shape.
grid_file = '/home/s184310/3.Project/data/final_merged_data/50_all_data.npy'
loaded_grid_array = np.load(grid_file)
print("Shape of loaded grid array:", loaded_grid_array.shape)

#### Set Flags

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
project_path = "/home/s184310/3.Project/data/final_merged_data"  # directory holding 50_all_data.npy / .csv
patch_size = 50  # side length (grid cells) of the square training patches; worth experimenting with
patch_stride = 10  # step (grid cells) between neighbouring patch origins
channels = 51  # number of feature channels in the grid
BATCH_SIZE = 4
# Held-out test window in grid indices: y indexes rows (axis 0), x indexes columns (axis 1)
rect_x_start, rect_x_end = 100, 290
rect_y_start, rect_y_end = 200, 390

channel_to_display = 0  # channel shown in the plots; valid range 0 .. channels - 1 (0-50)

test_size = 0.2  # fraction of the training patches split off for validation
random_state = 42  # seed for the train/validation split

mask_count = 140  # pixels hidden per patch by PatchDataset
T = 1000  # number of diffusion timesteps


epochs = 350  # NOTE: the training cell later overrides this with epochs = 100
learning_rate = 0.001

num_images = 5  # number of images shown in plots

#### Load Data

In [None]:
grid_array = np.load(f'{project_path}/50_all_data.npy')
print(grid_array.shape)

# Count, per channel, the cells that hold real data. 0 and -1 mean "no value",
# and NaN must be excluded as well: the grid-building cell initialises every
# cell to NaN, and NaN != 0 / NaN != -1 are both True, so without the isnan
# term every empty cell would be counted as valid.
valid_mask = (grid_array != 0) & (grid_array != -1) & ~np.isnan(grid_array)
valid_data_counts = valid_mask.sum(axis=(0, 1)).tolist()

# Find the channel with the maximum valid data points
max_valid_channel = np.argmax(valid_data_counts)
max_valid_count = valid_data_counts[max_valid_channel]

print(f"Channel with the most valid data: {max_valid_channel}")
print(f"Number of valid data points in this channel: {max_valid_count}")

In [None]:
df = pd.read_csv('/home/s184310/3.Project/data/final_merged_data/50_all_data.csv')

# Feature channel k lives in CSV column k + 2 (after Latitude and Longitude).
columns = df.columns
column_index = 2 + channel_to_display
in_range = column_index < len(columns)
if in_range:
    column_name = columns[column_index]
    print(f"The column to be displayed is: {column_name}")

In [None]:
# List every feature column with its channel number (channel = column index - 2).
column_indices = list(range(2, len(columns)))
for channel_number, name in enumerate(columns[2:]):
    print(f"Column {channel_number}: {name}")

In [None]:
# Raw (un-normalised) view of one feature channel.
channel_to_display = 33

# Standardise the missing-data sentinel: every -1 in the grid becomes -2.
# This is done in place, so the normalisation cells below also see -2.
grid_array[grid_array == -1] = -2

# Hide zeros and -2 sentinels so only informative cells get a colour.
raw_channel = grid_array[:, :, channel_to_display]
hidden = (raw_channel == 0) | (raw_channel == -2)
channel_data = np.where(hidden, np.nan, raw_channel)

# Feature channel k lives in CSV column k + 2 (after the two coordinate columns).
columns = df.columns
column_index = channel_to_display + 2
if column_index >= len(columns):
    title = f"Original data for channel {channel_to_display}: Column index out of range"
else:
    column_name = columns[column_index]
    title = f"Original data for {column_name} feature"

# Perceptually uniform colormap on a large canvas for coverage visibility.
plt.figure(figsize=(10, 8))
plt.imshow(channel_data, cmap='viridis', interpolation='nearest', aspect='auto')
plt.colorbar()
plt.title(title)
plt.axis('off')
plt.show()

# Report the finite value range behind the colour scale.
min_val, max_val = np.nanmin(channel_data), np.nanmax(channel_data)
print(f"Data range for displayed channel: {min_val} to {max_val}")


#### Preprocess Data

In [None]:
## Original
# Normalization functions
#def normalize_data(data):
#    min_val = np.nanmin(data, axis=(0, 1), keepdims=True)
#    max_val = np.nanmax(data, axis=(0, 1), keepdims=True)
#    normalized_data = (((data - min_val) / (max_val - min_val))*2)-1
#    return normalized_data, min_val, max_val

#def denormalize_data(data, min_val, max_val):
#    return data * (max_val - min_val) + min_val

In [None]:
def normalize_data(data):
    """Scale each channel of a grid to [-1, 1], keeping -2 as the no-data marker.

    Entries equal to -2 or NaN are treated as missing. They are ignored when
    computing the per-channel bounds and are written as -2 in the output.

    Parameters
    ----------
    data : np.ndarray
        Grid whose first two axes are spatial (e.g. (H, W, C)); the bounds are
        reduced over axes (0, 1).

    Returns
    -------
    normalized_data : np.ndarray
        Same shape as ``data``. Valid entries are mapped linearly to [-1, 1]
        and missing entries are set to -2. A channel with a single distinct
        value maps it to -1.
    min_val, max_val : np.ndarray
        Per-channel bounds with the reduced axes kept (shape (1, 1, C) for a
        3-D grid). Channels without any valid entry get 0.0 and 1.0.
    """
    import warnings

    mask = (data != -2) & ~np.isnan(data)
    valid_only = np.where(mask, data, np.nan)

    # np.nanmin / np.nanmax report an all-NaN channel with a RuntimeWarning
    # ("All-NaN slice encountered") issued through the `warnings` module.
    # np.errstate does not control that warning, so silence it explicitly.
    # Such channels are handled by the nan_to_num fallbacks below.
    with np.errstate(all='ignore'), warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)
        min_val = np.nanmin(valid_only, axis=(0, 1), keepdims=True)
        max_val = np.nanmax(valid_only, axis=(0, 1), keepdims=True)

    # Channels with no valid data points fall back to [0, 1].
    min_val = np.nan_to_num(min_val, nan=0.0)
    max_val = np.nan_to_num(max_val, nan=1.0)

    # Guard against division by zero for constant channels.
    scale = np.where(max_val != min_val, max_val - min_val, 1)

    normalized_data = np.where(mask, (((data - min_val) / scale) * 2) - 1, -2)

    return normalized_data, min_val, max_val

def denormalize_data(data, min_val, max_val):
    """Map [-1, 1] values back to the original per-channel range, keeping -2 as missing.

    Parameters
    ----------
    data : np.ndarray
        Expected channel-first (N, C, H, W): the bounds are reshaped to
        (1, C, 1, 1) and broadcast over axis 1.
    min_val, max_val : np.ndarray or torch.Tensor
        Per-channel bounds with C elements in any shape, e.g. the (1, 1, C)
        arrays returned by ``normalize_data``. NaNs fall back to 0.0 and 1.0.

    Returns
    -------
    np.ndarray
        ``(data + 1) / 2 * (max - min) + min`` where ``data`` is neither -2
        nor NaN, and -2 elsewhere. Channels with max == min use a scale of 1.
    """
    def _channel_bounds(bound, nan_fill):
        # Bring tensor bounds to NumPy and reshape to (1, C, 1, 1) for broadcasting.
        if isinstance(bound, torch.Tensor):
            bound = bound.cpu().numpy()
        return np.nan_to_num(bound, nan=nan_fill).reshape(1, -1, 1, 1)

    lo = _channel_bounds(min_val, 0.0)
    hi = _channel_bounds(max_val, 1.0)

    keep = (data != -2) & ~np.isnan(data)
    span = np.where(hi != lo, hi - lo, 1)
    return np.where(keep, (data + 1) / 2 * span + lo, -2)

In [None]:
# Normalise every channel and keep the bounds for later denormalisation.
normalization = normalize_data(grid_array)
normalized_grid_array, min_val, max_val = normalization
print(*normalization[1:])

#### Generate Train, Validation, and Test Data

In [None]:
# Plot the normalised channel with the held-out test window outlined in red.
channel_data = normalized_grid_array[:,:,channel_to_display]
# After normalize_data, -2 marks cells without data; 0 is a valid mid-range
# value and must stay visible.
channel_data = np.where(channel_data == -2, np.nan, channel_data)
plt.imshow(channel_data, cmap='gray')
plt.colorbar()
plt.title(f"Normalised {column_name} with test area marked")
rectangle = plt.Rectangle((rect_x_start, rect_y_start), 
                          rect_x_end - rect_x_start, 
                          rect_y_end - rect_y_start, 
                          edgecolor='red', 
                          facecolor='none', 
                          linewidth=2)
plt.gca().add_patch(rectangle)
plt.axis('off')
plt.show()

In [None]:
# Copy out the test window, then overwrite it with the -3 sentinel in the grid.
window = (slice(rect_y_start, rect_y_end), slice(rect_x_start, rect_x_end), slice(None))
test_data = normalized_grid_array[window].copy()
normalized_grid_array[window] = -3
train_data = normalized_grid_array  # same array object, now containing the -3 block

plt.imshow(train_data[:,:,channel_to_display], cmap='gray')
plt.colorbar()
plt.title(f"Test data for {channel_to_display}: {column_name} blacked out in training data")
plt.axis('off')
plt.show()

In [None]:
# Zoom in on the held-out test window for the selected channel.
plt.imshow(test_data[:,:,channel_to_display], cmap='gray')
plt.colorbar()
plt.title(f"Close up on {column_name} test data")
plt.axis('off')
plt.show()

In [None]:
def create_input_output_pairs(data, patch_size=patch_size, stride=patch_stride):
    """Cut square patches from a grid with a sliding window.

    Windows start every ``stride`` cells along the first two axes and may end
    exactly at the array border. A window is dropped if it contains any -3
    (held-out test area) or if every entry is -2 (no recorded data).

    Returns
    -------
    np.ndarray
        Stacked patches of shape (N, patch_size, patch_size, ...). When no
        window qualifies, an empty array with that trailing shape is returned
        instead of a shapeless ``np.array([])``.
    """
    patches = []
    # "+ 1" so the last window that fits (start == size - patch_size) is
    # included; without it the bottom row and right column of patches are lost.
    for i in range(0, data.shape[0] - patch_size + 1, stride):
        for j in range(0, data.shape[1] - patch_size + 1, stride):
            current_patch = data[i:i + patch_size, j:j + patch_size].copy()
            if np.any(current_patch == -3):
                continue  # overlaps the held-out test area
            if np.all(current_patch == -2):
                continue  # no data recorded anywhere in the window
            patches.append(current_patch)
    if not patches:
        return np.empty((0, patch_size, patch_size) + data.shape[2:], dtype=data.dtype)
    return np.array(patches)

# Slice overlapping patches from the training grid and from the held-out window.
train_patches, test_patches = map(create_input_output_pairs, (train_data, test_data))

In [None]:
# Split the dataset into training and validation
train_patches, val_patches = train_test_split(train_patches, test_size=test_size, random_state=random_state)

print("Train inputs shape:", train_patches.shape)
print("Val inputs shape:", val_patches.shape)
print("Test inputs shape:", test_patches.shape)

In [None]:
class PatchDataset(Dataset):
    """Serve (masked_image, image, mask) triples for inpainting training.

    ``data[idx]`` is used as an (H, W, C) patch; ``height`` and ``width`` are
    read from ``data.shape[1:3]``. For every item a hole of ``mask_pixels``
    pixels is grown by breadth-first search from a random start pixel and
    zeroed in all channels. The three returned tensors are channel-first
    (C, H, W). ``mask`` is 1 on kept pixels and 0 inside the hole.
    """

    def __init__(self, data, mask_pixels=mask_count):
        self.data = data
        self.mask_pixels = mask_pixels
        self.height = data.shape[1]
        self.width = data.shape[2]

    def __len__(self):
        return len(self.data)

    def _grow_hole(self, template):
        """Return a ones-tensor shaped like ``template`` with a BFS-grown hole of zeros."""
        mask = torch.ones_like(template)
        remaining = self.mask_pixels

        # Seed pixel, drawn with NumPy's global RNG (row first, then column).
        row = np.random.randint(0, self.height)
        col = np.random.randint(0, self.width)
        frontier = deque([(row, col)])
        mask[row, col, :] = 0
        remaining -= 1

        # Breadth-first growth through the 4-connected neighbourhood.
        while frontier and remaining > 0:
            r, c = frontier.popleft()
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nr, nc = r + dr, c + dc
                if not (0 <= nr < self.height and 0 <= nc < self.width):
                    continue
                if mask[nr, nc, 0] != 1:
                    continue
                mask[nr, nc, :] = 0
                remaining -= 1
                frontier.append((nr, nc))
                if remaining == 0:
                    break

        # Safety net: if the search stopped early, top up in row-major order.
        flat = 0
        total = self.height * self.width
        while remaining > 0 and flat < total:
            r, c = divmod(flat, self.width)
            if mask[r, c, 0] == 1:
                mask[r, c, :] = 0
                remaining -= 1
            flat += 1
        return mask

    def __getitem__(self, idx):
        image = torch.tensor(self.data[idx], dtype=torch.float32)
        mask = self._grow_hole(image)
        masked_image = image * mask
        return tuple(t.permute(2, 0, 1) for t in (masked_image, image, mask))


# One masking dataset per split; every loader reshuffles each epoch.
train_dataset, val_dataset, test_dataset = (
    PatchDataset(patches) for patches in (train_patches, val_patches, test_patches)
)
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for dataset in (train_dataset, val_dataset, test_dataset)
)

# Smoke test: print the shapes of the first training batch.
for masked_image, image, mask in train_dataloader:
    for tensor in (masked_image, image, mask):
        print(tensor.shape)
    break


#### Define Model

##### The forward process

In [None]:
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` betas spaced evenly from ``start`` to ``end`` (both inclusive)."""
    schedule = torch.linspace(start=start, end=end, steps=timesteps)
    return schedule

def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Cosine noise schedule.

    alpha_bar(t) = f(t) / f(0) with f(t) = cos(((t / T + s) / (1 + s)) * pi / 2) ** 2,
    and beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped to [0.0001, 0.999].

    Parameters
    -----------
    timesteps : int
        The number of timesteps T.
    s : float
        Small offset applied to t / T.

    Returns
    ---------
    torch.Tensor
        ``timesteps`` beta values on the global ``device``.
    """
    def _squared_cosine(step):
        return torch.cos((step / timesteps + s) / (1 + s) * 0.5 * torch.pi) ** 2

    grid = torch.linspace(0, timesteps, timesteps + 1).to(device)
    alphas_cumprod = _squared_cosine(grid) / _squared_cosine(torch.tensor([0]).to(device))
    ratios = alphas_cumprod[1:] / alphas_cumprod[:-1]
    return torch.clip(1 - ratios, 0.0001, 0.999).to(device)

def get_index_from_list(vals, t, x_shape):
    """Gather ``vals[t]`` per batch element, shaped to broadcast against ``x_shape``.

    Returns a tensor of shape (len(t), 1, ..., 1) with ``len(x_shape) - 1``
    trailing singleton dims, placed on ``t.device``.
    """
    trailing_ones = (1,) * (len(x_shape) - 1)
    picked = vals.gather(-1, t).to(t.device)
    return picked.reshape(t.shape[0], *trailing_ones).to(t.device)

def forward_diffusion_sample(x_0, t, device=device):
    """Draw x_t from q(x_t | x_0) in closed form.

    Returns ``(x_t, noise)`` where
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`` and
    ``noise`` is standard normal with the shape of ``x_0``. Both are on ``device``.
    """
    x_0 = x_0.to(device)
    noise = torch.randn_like(x_0, device=device)
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape).to(device)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape).to(device)
    noisy = signal_scale * x_0 + noise_scale * noise
    return noisy, noise.to(device)


# Noise schedule: cosine betas plus the closed-form terms derived from them.
betas = cosine_beta_schedule(timesteps=T)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# Debugging output for the schedule.
for label, values in (("Betas:", betas), ("Alphas:", alphas), ("Alphas Cumprod:", alphas_cumprod)):
    print(label, values)


def _plot_schedule(curves, title, ylabel):
    """Plot (label, tensor) curves against the timestep axis in a new figure."""
    plt.figure(figsize=(10, 5))
    for curve_label, series in curves:
        plt.plot(series.cpu().numpy(), label=curve_label)
    plt.title(title)
    plt.xlabel('Timestep')
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


_plot_schedule((('Betas', betas), ('Alphas', alphas)), 'Beta and Alpha Schedule', 'Value')
_plot_schedule((('Alphas Cumprod', alphas_cumprod),), 'Alphas Cumprod Schedule', 'Alpha Cumprod Value')

In [None]:
# Simulate forward diffusion
# One training batch and a wide canvas for the forward-diffusion strip.
image_masked, image, mask = next(iter(train_dataloader))
stepsize = int(T / num_images)
plt.figure(figsize=(15, 15))

def show_tensor_image(image):
    """Draw a tensor as a grayscale image without axes.

    A 4-D batch is reduced to its first sample and the single channel
    ``channel_to_display`` (global). Otherwise the tensor is used as-is. It is
    presumably (C, H, W) and is permuted to (H, W, C) for matplotlib.
    """
    picture = image
    if image.dim() == 4:
        picture = image[0, channel_to_display:channel_to_display + 1]
    plt.imshow(picture.permute(1, 2, 0), cmap='gray')
    plt.axis('off')


# Forward-diffusion strip: one panel every `stepsize` timesteps.
for panel, idx in enumerate(range(0, T, stepsize)):
    t = torch.Tensor([idx]).type(torch.int64).to(device)
    plt.subplot(1, num_images + 1, panel + 1)
    img, noise = forward_diffusion_sample(image, t)
    show_tensor_image(img.detach().cpu())

In [None]:
# A batch is (masked, original, mask): show the original left and the masked version right.
sample = next(iter(train_dataloader))
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

panels = ((axes[0], sample[1], 'Original Image'), (axes[1], sample[0], 'Masked Image'))
for axis, tensor, caption in panels:
    plt.sca(axis)
    show_tensor_image(tensor)
    axis.set_title(caption)

plt.show()

In [None]:
# Visualise the mask of the same batch (1 = kept pixel, 0 = hidden).
mask_batch = sample[2]
show_tensor_image(mask_batch)

##### The backward process

In [None]:
class Block(nn.Module):
    """Convolutional U-Net stage conditioned on a time embedding (no attention).

    The stage applies conv3x3 -> ReLU -> BatchNorm, adds a projected time
    embedding, applies conv3x3 -> ReLU -> BatchNorm, and then resamples. It
    downsamples with a stride-2 conv, or upsamples with a stride-2 transposed
    conv when ``up``; an up block expects ``2 * in_ch`` input channels (skip
    concatenated).

    NOTE: a second ``class Block`` definition follows in this cell and rebinds
    the name, so this version is not the one SimpleUnet ends up using.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch).to(device)
        conv1_in = 2 * in_ch if up else in_ch
        self.conv1 = nn.Conv2d(conv1_in, out_ch, 3, padding=1).to(device)
        resample = nn.ConvTranspose2d if up else nn.Conv2d
        self.transform = resample(out_ch, out_ch, 4, 2, 1).to(device)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1).to(device)
        self.bnorm1 = nn.BatchNorm2d(out_ch).to(device)
        self.bnorm2 = nn.BatchNorm2d(out_ch).to(device)
        self.relu = nn.ReLU().to(device)

    def forward(self, x, t):
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Broadcast the (B, out_ch) time embedding over the spatial dims.
        h = h + self.relu(self.time_mlp(t))[..., None, None]
        h = self.bnorm2(self.relu(self.conv2(h)))
        return self.transform(h)

class Block(nn.Module):
    """U-Net stage with time conditioning and spatial self-attention.

    The stage applies conv3x3 -> ReLU -> BatchNorm and adds a projected time
    embedding. It then runs self-attention over the H*W positions, followed by
    a residual add and a LayerNorm over channels. Next comes conv3x3 -> ReLU ->
    BatchNorm, and finally a stride-2 conv (down) or transposed conv (up).

    Parameters
    ----------
    in_ch : int
        Input channels; an up block receives ``2 * in_ch`` because the skip
        connection is concatenated onto its input.
    out_ch : int
        Output channels, also the attention embedding size.
    time_emb_dim : int
        Size of the time embedding passed to ``forward``.
    up : bool
        Build an upsampling block instead of a downsampling one.
    num_heads : int
        Requested attention heads. ``nn.MultiheadAttention`` requires
        ``out_ch`` to be divisible by the head count, so
        ``gcd(out_ch, num_heads)`` heads are used. For example, out_ch=100 with
        num_heads=8 gives 4 heads.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False, num_heads=8):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch).to(device)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1).to(device)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1).to(device)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1).to(device)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1).to(device)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1).to(device)
        self.bnorm1 = nn.BatchNorm2d(out_ch).to(device)
        self.bnorm2 = nn.BatchNorm2d(out_ch).to(device)
        self.relu = nn.ReLU().to(device)

        # Use torch.nn.MultiheadAttention explicitly (the bare name
        # `MultiheadAttention` is not imported in the import cell), with a head
        # count that divides the embedding size as the layer requires.
        heads = math.gcd(out_ch, num_heads)
        self.attention = nn.MultiheadAttention(embed_dim=out_ch, num_heads=heads).to(device)
        self.attention_norm = nn.LayerNorm(out_ch).to(device)

    def forward(self, x, t):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))

        # Time embedding, broadcast over the spatial dims
        time_emb = self.relu(self.time_mlp(t))
        h = h + time_emb[(..., ) + (None, ) * 2]

        # Self-attention over spatial positions: (B, C, H, W) -> (H*W, B, C),
        # the default (seq, batch, embed) layout of nn.MultiheadAttention.
        B, C, H, W = h.shape
        tokens = h.flatten(2).permute(2, 0, 1)
        attn_output, _ = self.attention(tokens, tokens, tokens)
        # LayerNorm(out_ch) normalises the last dimension, so apply it while
        # channels are last rather than on the (B, C, H, W) tensor.
        tokens = self.attention_norm(tokens + attn_output)
        # permute() yields a non-contiguous tensor: reshape (not view) back.
        h = tokens.permute(1, 2, 0).reshape(B, C, H, W)

        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))

        # Down or Upsample
        return self.transform(h)

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D tensor of timesteps to sinusoidal features of width ``2 * (dim // 2)``.

    Frequencies are ``exp(-k * log(10000) / (dim // 2 - 1))`` for
    k = 0 .. dim // 2 - 1. The output is all sines followed by all cosines.
    TODO: confirm this sin-then-cos ordering is the one intended.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=time.device) * -scale)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class SimpleUnet(nn.Module):
    """Compact time-conditioned U-Net that maps (B, channels, H, W) to (B, channels, H', W').

    Channel widths are multiples of the global ``patch_size``:
    (p, 2p, 4p, 8p, 16p) on the way down and the reverse on the way up. When
    an upsampled feature map and its skip connection differ in spatial size,
    the feature map is resized bilinearly to match before concatenation.
    """

    def __init__(self):
        super().__init__()
        image_channels = channels
        out_dim = channels
        time_emb_dim = 32
        widths = tuple(patch_size * factor for factor in (1, 2, 4, 8, 16))
        down_channels = widths
        up_channels = widths[::-1]

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim).to(device),
            nn.Linear(time_emb_dim, time_emb_dim).to(device),
            nn.ReLU().to(device),
        )

        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1).to(device)

        self.downs = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim).to(device)
            for c_in, c_out in zip(down_channels[:-1], down_channels[1:])
        )
        self.ups = nn.ModuleList(
            Block(c_in, c_out, time_emb_dim, up=True).to(device)
            for c_in, c_out in zip(up_channels[:-1], up_channels[1:])
        )

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1).to(device)

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        x = self.conv0(x)
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)
        for up in self.ups:
            skip = skips.pop()
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=tuple(skip.shape[2:]), mode='bilinear', align_corners=False)
            x = up(torch.cat((x, skip), dim=1), t)
        return self.output(x)

model = SimpleUnet().to(device)
param_count = sum(p.numel() for p in model.parameters())
print("Num params: ", param_count)


#### Train Model

##### loss

In [None]:
#def get_loss(model, x_0, t):
#    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
#    noise_pred = model(x_noisy, t)
#    return F.mse_loss(noise_pred,noise)
    #return F.l1_loss(noise, noise_pred)

def get_loss(model, x_0, t):
    """Masked noise-prediction MSE for one diffusion training step.

    ``x_0`` is noised to timestep ``t`` with ``forward_diffusion_sample``, and
    the model predicts the added noise. The squared error is averaged only over
    entries where ``x_0`` is neither the -2 no-data marker nor NaN.
    """
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)

    observed = (x_0 != -2) & ~torch.isnan(x_0)
    return F.mse_loss(noise_pred[observed], noise[observed])

##### Sampling

In [None]:
@torch.no_grad()
def sample_timestep(x, t):
    """One reverse-diffusion step x_t -> x_{t-1} using the global ``model``.

    ``t`` is a 1-D long tensor with one timestep per batch element (a
    single-element ``t`` is broadcast over the batch). Elements whose timestep
    is 0 receive the predicted mean. All others get posterior noise
    ``sqrt(posterior_variance_t) * N(0, I)`` added.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    # `if t == 0` raises for a multi-element t (ambiguous truth value), so
    # decide per sample: no noise on the final step, noise everywhere else.
    is_last = t == 0
    if bool(is_last.all()):
        return model_mean
    noise = torch.randn_like(x)
    keep_noise = (~is_last).to(x.dtype).reshape(-1, *((1,) * (x.dim() - 1)))
    return model_mean + keep_noise * torch.sqrt(posterior_variance_t) * noise


@torch.no_grad()
def inpainting_plot_image(image, mask):
    """Visualise forward noising and mask-conditioned reverse diffusion for one batch.

    Parameters
    ----------
    image : torch.Tensor
        Clean batch on ``device``. Only the first sample and the global
        ``channel_to_display`` are drawn (via ``show_tensor_image``).
    mask : torch.Tensor
        Same shape as ``image``: 1 marks known pixels and 0 marks pixels to
        inpaint, as produced by ``PatchDataset``.

    The first figure shows ``image`` followed by forward-diffused versions of
    ``image * mask`` at evenly spaced timesteps.

    The second figure shows the masked input, then runs the reverse chain from
    t = T-1 down to 0. At each step the known pixels come from a fresh forward
    sample of the masked image at that timestep, while the unknown pixels keep
    the chain's current estimate. The combination is passed to
    ``sample_timestep``, and its output is fed back into the next step.
    """
    stepsize = max(1, int(T / num_images))
    # At most num_images panels so subplot indices stay within 1..num_images+1.
    plot_timesteps = list(range(0, T, stepsize))[:num_images]

    # --- Forward process -------------------------------------------------
    plt.figure(figsize=(15, 15))
    plt.subplot(1, num_images + 1, 1)
    show_tensor_image(image.detach().cpu())
    plt.title(f"Original Image")

    masked_image = image * mask

    for panel, idx in enumerate(plot_timesteps):
        t = torch.Tensor([idx]).type(torch.int64).to(device)
        img, _ = forward_diffusion_sample(masked_image, t, device)
        plt.subplot(1, num_images + 1, panel + 2)
        show_tensor_image(img.detach().cpu())
        plt.title(f"Timestep {idx}")

    plt.show()

    # --- Backward (inpainting) process ------------------------------------
    plt.figure(figsize=(15, 15))
    plt.subplot(1, num_images + 1, 1)
    show_tensor_image(masked_image.detach().cpu())
    plt.title("Masked Image")

    # Panels run left to right from the noisiest plotted timestep to t = 0.
    # The subplot for t = 0 is the rightmost one, at index len(plot_timesteps) + 1.
    plot_positions = {ts: len(plot_timesteps) - k + 1 for k, ts in enumerate(plot_timesteps)}

    known = torch.ones_like(mask)  # first step: take the forward sample everywhere
    x_t = torch.zeros_like(masked_image)
    for idx in reversed(range(T)):
        t = torch.Tensor([idx]).type(torch.int64).to(device)
        # Forward samples are drawn on the fly instead of storing all T of them.
        noisy_known, _ = forward_diffusion_sample(masked_image, t, device)
        combined = noisy_known * known + x_t * (1 - known)
        known = mask  # afterwards only the originally known pixels are replaced
        x_t = sample_timestep(combined, t)

        if idx in plot_positions:
            plt.subplot(1, num_images + 1, plot_positions[idx])
            show_tensor_image(x_t.detach().cpu())
            plt.title(f"Timestep {idx}")

    plt.show()

    #plt.figure(figsize=(15, 15))
    #plt.subplot(1, num_images, 1)
    #tmp_mask = 1
    #tmp_img = image_list[-1][0]
    #for idx, (noisy_img,t) in enumerate(image_list[::-1]):
        #img = (noisy_img*tmp_mask)+(tmp_img*(1-tmp_mask))
        #tmp_mask = mask
        #tmp_img = sample_timestep(img,t)
        #if t in image_idx_plot_list:
            #plt.subplot(1, num_images, num_images-(int(idx/stepsize)))
            #show_tensor_image(tmp_img.detach().cpu())
    #plt.show()

##### Training Loop

In [None]:
optimizer = Adam(model.parameters(), lr=learning_rate)

def train_epoch(model):
    """Run one optimisation pass over ``train_dataloader`` and return the mean batch loss.

    Uses the module-level ``optimizer``, ``T`` and ``device``. Only the clean
    target images (``batch[1]``) are used; a random timestep is drawn per sample.
    """
    # Ensure BatchNorm uses and updates batch statistics, even if a previous
    # call (e.g. validation or plotting) left the model in eval mode.
    model.train()
    loss_value = 0
    for batch in train_dataloader:
        x_0 = batch[1].to(device)
        optimizer.zero_grad()
        t = torch.randint(0, T, (x_0.shape[0],), device=device).long()
        loss = get_loss(model, x_0, t)
        loss.backward()
        optimizer.step()
        loss_value += loss.item()
    return loss_value / len(train_dataloader)

In [None]:
@torch.no_grad()
def val_epoch(model):
    """Return the mean diffusion loss over ``val_dataloader`` without updating weights.

    The model is put in eval mode for the pass, so BatchNorm uses its running
    statistics instead of updating them from validation batches. The previous
    train/eval mode is restored afterwards, even if an exception occurs.
    """
    was_training = model.training
    model.eval()
    loss_value = 0
    try:
        for batch in val_dataloader:
            t = torch.randint(0, T, (batch[1].shape[0],), device=device).long()
            loss = get_loss(model, batch[1].to(device), t)
            loss_value += loss.item()
    finally:
        model.train(was_training)
    return loss_value / len(val_dataloader)

In [None]:
epochs = 100  # overrides the value set in the flags cell
loss_list = []
for epoch in range(epochs):
    # Train first, then validate, exactly once per epoch.
    train_loss, val_loss = train_epoch(model), val_epoch(model)
    print(f"Epoch {epoch}/{epochs}: Train Loss: {train_loss}, Validation Loss: {val_loss}")
    loss_list.append((train_loss, val_loss))
    #if epoch > 0:
    #    sample = next(iter(train_dataloader))
    #    inpainting_plot_image(sample[1].to(device), sample[2].to(device))
        #inpainting_plot_image(sample[1].to(device),sample[2].to(device))

In [None]:
# Persist the trained weights (state_dict only).
checkpoint_path = 'best_model_mapillaryFeatures.pt'
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Learning curves: loss_list holds one (train, validation) pair per epoch.
plt.figure()

train_curve = [train for train, _ in loss_list]
val_curve = [val for _, val in loss_list]

plt.plot(train_curve, label='Train Loss')
plt.plot(val_curve, color='orange', label='Validation Loss')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()


In [None]:
from torchsummary import summary
from torchviz import make_dot
# Load the best model.
# NOTE(review): the training cell above saves 'best_model_mapillaryFeatures.pt';
# confirm that 'best_model.pt' is really the checkpoint intended here.
# map_location remaps tensors saved on a GPU onto `device`, so this also works on a
# CPU-only machine.
model.load_state_dict(torch.load('best_model.pt', map_location=device))

In [None]:
# Create a wrapper function
def model_summary_wrapper(model, input_size):
    """Adapt a ``model(x, timestep)`` network to a single-argument ``forward(x)``.

    Tools such as torchsummary call a module with one input tensor only; the wrapper
    supplies the timestep argument itself.

    Args:
        model: module whose forward takes ``(x, timestep)``; stored as ``.model``.
        input_size: accepted for call-site symmetry; not used by the wrapper.

    Returns:
        An ``nn.Module`` whose ``forward(x)`` returns ``model(x, steps)`` where ``steps``
        is a ``long`` tensor of zeros with ``x.size(0)`` entries on ``x.device``.
    """

    class ModelWrapper(nn.Module):
        def __init__(self, inner):
            super().__init__()
            self.model = inner

        def forward(self, x):
            # Dummy timestep: 0 for every sample in the batch.
            zero_steps = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
            return self.model(x, zero_steps)

    return ModelWrapper(model)

# Wrap your model so torchsummary can call it with a single input tensor
wrapped_model = model_summary_wrapper(model, (channels, patch_size, patch_size))

# Print the model summary. Pass the device explicitly: some torchsummary versions
# default to CUDA and fail on CPU-only machines; it accepts the strings "cuda"/"cpu"
# (not "cuda:0"), so derive the value from the available hardware.
summary(
    wrapped_model,
    (channels, patch_size, patch_size),
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [None]:
# Initialize the model
model = SimpleUnet().to(device)

# Load the saved model state. map_location remaps tensors saved on a GPU onto
# `device`, so the checkpoint also loads on a CPU-only machine.
model.load_state_dict(torch.load('best_model_mapillaryFeatures.pt', map_location=device))

# Set the model to evaluation mode
model.eval()

In [None]:
# Test function to evaluate the model's performance on the test dataset
def test_model(model, dataloader, min_val, max_val, num_examples=3, channels=51):
    """Compute per-channel MSE between model output and original on evaluated pixels.

    A pixel of channel ``c`` is evaluated when ``mask[:, c] == 0`` and the original
    value is not ``-2`` (printed below as "ignoring -2"). Each batch is passed through
    the model once with a random timestep ``t`` drawn from ``[0, T)``. Errors are
    reported in normalized space and after ``denormalize_data``; denormalized values
    that are NaN are excluded from both the denormalized error sum and its count.

    Args:
        model: network called as ``model(masked_image, t)``; put in eval mode.
        dataloader: yields ``(masked_image, original_image, mask)`` batches.
        min_val, max_val: forwarded to ``denormalize_data``.
        num_examples: number of batches whose per-channel means are printed.
        channels: number of channels to evaluate.

    Returns:
        ``(mse_normalized, mse_denormalized)``: two tensors of shape ``(channels,)`` on
        ``device``. Channels without any evaluated pixel come out as NaN (0 / 0).
    """
    model.eval()
    mse_per_channel_normalized = torch.zeros(channels, dtype=torch.float32).to(device)
    mse_per_channel_denormalized = torch.zeros(channels, dtype=torch.float32).to(device)
    count_per_channel = torch.zeros(channels, dtype=torch.float32).to(device)
    # NaN values are dropped from the denormalized error, so it needs its own count.
    count_per_channel_denormalized = torch.zeros(channels, dtype=torch.float32).to(device)
    examples_printed = 0

    for masked_image, original_image, mask in dataloader:
        masked_image = masked_image.to(device)
        original_image = original_image.to(device)
        mask = mask.to(device)

        with torch.no_grad():
            t = torch.randint(0, T, (masked_image.shape[0],), device=device).long()
            generated_image = model(masked_image, t)

        # Denormalize on the NumPy side.
        denormalized_original = denormalize_data(original_image.cpu().numpy(), min_val, max_val)
        denormalized_generated = denormalize_data(generated_image.cpu().numpy(), min_val, max_val)

        # Evaluated pixels per channel. Testing this combined selection (rather than
        # only mask == 0) matters: if every mask == 0 pixel holds -2, the selection is
        # empty, a 'mean' MSE over it is NaN, and NaN * 0 would poison the accumulator.
        valid_per_channel = [
            (mask[:, channel, :, :] == 0) & (original_image[:, channel, :, :] != -2)
            for channel in range(channels)
        ]

        if examples_printed < num_examples:
            original_means = []
            generated_means = []
            denormalized_original_means = []
            denormalized_generated_means = []
            for channel, valid in enumerate(valid_per_channel):
                if valid.any():
                    valid_np = valid.cpu().numpy()
                    original_means.append(original_image[:, channel, :, :][valid].mean().item())
                    generated_means.append(generated_image[:, channel, :, :][valid].mean().item())
                    denormalized_original_means.append(np.mean(denormalized_original[:, channel, :, :][valid_np]))
                    denormalized_generated_means.append(np.mean(denormalized_generated[:, channel, :, :][valid_np]))
                else:
                    original_means.append(float('nan'))
                    generated_means.append(float('nan'))
                    denormalized_original_means.append(float('nan'))
                    denormalized_generated_means.append(float('nan'))

            print(f"Example {examples_printed + 1}:")
            print("Original masked means (ignoring -2):", [round(mean, 3) for mean in original_means])
            print("Generated means:", [round(mean, 3) for mean in generated_means])
            print("Denormalized original means:", [round(mean, 3) for mean in denormalized_original_means])
            print("Denormalized generated means:", [round(mean, 3) for mean in denormalized_generated_means])
            print()

            examples_printed += 1

        for channel, valid in enumerate(valid_per_channel):
            if not valid.any():
                continue

            # Sum of squared errors; divided by the pixel count after the loop.
            mse_per_channel_normalized[channel] += F.mse_loss(
                generated_image[:, channel, :, :][valid],
                original_image[:, channel, :, :][valid],
                reduction='sum',
            ).item()
            count_per_channel[channel] += valid.sum().item()

            valid_np = valid.cpu().numpy()
            denorm_generated_tensor = torch.tensor(denormalized_generated[:, channel, :, :][valid_np], dtype=torch.float32)
            denorm_original_tensor = torch.tensor(denormalized_original[:, channel, :, :][valid_np], dtype=torch.float32)
            not_nan = ~(torch.isnan(denorm_generated_tensor) | torch.isnan(denorm_original_tensor))

            if not_nan.any():
                mse_per_channel_denormalized[channel] += F.mse_loss(
                    denorm_generated_tensor[not_nan],
                    denorm_original_tensor[not_nan],
                    reduction='sum',
                ).item()
                count_per_channel_denormalized[channel] += not_nan.sum().item()

    average_mse_per_channel_normalized = mse_per_channel_normalized / count_per_channel
    average_mse_per_channel_denormalized = mse_per_channel_denormalized / count_per_channel_denormalized

    return average_mse_per_channel_normalized, average_mse_per_channel_denormalized


# Evaluate on the test split and report the normalized per-channel MSE.
average_mse_per_channel_normalized, average_mse_per_channel_denormalized = test_model(
    model,
    test_dataloader,
    min_val,
    max_val,
    channels=51,
)
print(
    "Average MSE per channel (Normalized):",
    [round(value.item(), 3) for value in average_mse_per_channel_normalized],
)


In [None]:
def test_model(model, dataloader, min_val, max_val, num_examples=3, channels=26):
    """Compute per-channel reconstruction RMSE on the evaluated pixels.

    A pixel of channel ``c`` is evaluated when ``mask[:, c] == 0`` and the original
    value is not ``-2`` (printed below as "ignoring -2"). Each batch is passed through
    the model once with a random timestep ``t`` drawn from ``[0, T)``. Squared errors
    are summed per channel and divided by the number of evaluated pixels at the end,
    both in normalized space and after ``denormalize_data``. Denormalized values that
    are NaN are excluded from both the denormalized sum and its count.

    Args:
        model: network called as ``model(masked_image, t)``; put in eval mode.
        dataloader: yields ``(masked_image, original_image, mask)`` batches.
        min_val, max_val: forwarded to ``denormalize_data``.
        num_examples: number of batches whose per-channel means are printed.
        channels: number of channels to evaluate.

    Returns:
        ``(rmse_normalized, rmse_denormalized)``: two tensors of shape ``(channels,)``
        on ``device``. Channels without any evaluated pixel come out as NaN (0 / 0).
    """
    model.eval()
    mse_per_channel_normalized = torch.zeros(channels, dtype=torch.float32).to(device)
    mse_per_channel_denormalized = torch.zeros(channels, dtype=torch.float32).to(device)
    count_per_channel = torch.zeros(channels, dtype=torch.float32).to(device)
    # NaN values are dropped from the denormalized error, so it needs its own count.
    count_per_channel_denormalized = torch.zeros(channels, dtype=torch.float32).to(device)
    examples_printed = 0

    for masked_image, original_image, mask in dataloader:
        masked_image = masked_image.to(device)
        original_image = original_image.to(device)
        mask = mask.to(device)

        with torch.no_grad():
            t = torch.randint(0, T, (masked_image.shape[0],), device=device).long()
            generated_image = model(masked_image, t)

        # Denormalize on the NumPy side.
        denormalized_original = denormalize_data(original_image.cpu().numpy(), min_val, max_val)
        denormalized_generated = denormalize_data(generated_image.cpu().numpy(), min_val, max_val)

        # Evaluated pixels per channel: mask == 0 and original value != -2.
        valid_per_channel = [
            (mask[:, channel, :, :] == 0) & (original_image[:, channel, :, :] != -2)
            for channel in range(channels)
        ]

        if examples_printed < num_examples:
            original_means = []
            generated_means = []
            denormalized_original_means = []
            denormalized_generated_means = []
            for channel, valid in enumerate(valid_per_channel):
                if valid.any():
                    valid_np = valid.cpu().numpy()
                    original_means.append(original_image[:, channel, :, :][valid].mean().item())
                    generated_means.append(generated_image[:, channel, :, :][valid].mean().item())
                    denormalized_original_means.append(np.mean(denormalized_original[:, channel, :, :][valid_np]))
                    denormalized_generated_means.append(np.mean(denormalized_generated[:, channel, :, :][valid_np]))
                else:
                    original_means.append(float('nan'))
                    generated_means.append(float('nan'))
                    denormalized_original_means.append(float('nan'))
                    denormalized_generated_means.append(float('nan'))

            print(f"Example {examples_printed + 1}:")
            print("Original masked means (ignoring -2):", [round(mean, 3) for mean in original_means])
            print("Generated means:", [round(mean, 3) for mean in generated_means])
            print("Denormalized original means:", [round(mean, 3) for mean in denormalized_original_means])
            print("Denormalized generated means:", [round(mean, 3) for mean in denormalized_generated_means])
            print()

            examples_printed += 1

        for channel, valid in enumerate(valid_per_channel):
            if not valid.any():
                continue

            # Normalized space: accumulate the sum of squared errors and the pixel count.
            mse_per_channel_normalized[channel] += F.mse_loss(
                generated_image[:, channel, :, :][valid],
                original_image[:, channel, :, :][valid],
                reduction='sum',
            ).item()
            count_per_channel[channel] += valid.sum().item()

            # Denormalized space: same, skipping NaN values.
            valid_np = valid.cpu().numpy()
            denorm_generated_tensor = torch.tensor(
                denormalized_generated[:, channel, :, :][valid_np],
                dtype=torch.float32,
            )
            denorm_original_tensor = torch.tensor(
                denormalized_original[:, channel, :, :][valid_np],
                dtype=torch.float32,
            )
            not_nan = ~(torch.isnan(denorm_generated_tensor) | torch.isnan(denorm_original_tensor))

            if not_nan.any():
                mse_per_channel_denormalized[channel] += F.mse_loss(
                    denorm_generated_tensor[not_nan],
                    denorm_original_tensor[not_nan],
                    reduction='sum',
                ).item()
                count_per_channel_denormalized[channel] += not_nan.sum().item()

    # Turn the sums of squared errors into means before taking the square root.
    average_mse_per_channel = mse_per_channel_normalized / count_per_channel
    average_mse_per_channel_denormalized = mse_per_channel_denormalized / count_per_channel_denormalized

    rmse_per_channel = torch.sqrt(average_mse_per_channel)
    rmse_per_channel_denormalized = torch.sqrt(average_mse_per_channel_denormalized)

    return rmse_per_channel, rmse_per_channel_denormalized

# Test the model: test_model returns (normalized RMSE, denormalized RMSE) per channel.
rmse_per_channel_normalized, rmse_per_channel_denormalized = test_model(model, test_dataloader, min_val, max_val, channels=51)
print("RMSE per channel (Normalized):", [round(rmse.item(), 3) for rmse in rmse_per_channel_normalized])
print("RMSE per channel (Denormalized):", [round(rmse.item(), 3) for rmse in rmse_per_channel_denormalized])


In [None]:
def test_model(model, test_dataloader, min_val, max_val, device="cpu", num_examples=3):
    """Compare batch-level channel means of denormalized originals and predictions.

    For every batch the model is run once with a random timestep ``t`` in ``[0, T)``.
    Both the prediction and the original are denormalized and averaged over axes
    (0, 2, 3), giving one mean per channel per batch.

    NOTE(review): despite the printed labels ("Means of the Mask"), these means cover
    every pixel of the batch, not only the masked region; the mask from each batch is
    not used.

    Args:
        model: network called as ``model(masked_images, t)``; put in eval mode.
        test_dataloader: yields ``(masked_images, original_images, masks)`` batches.
        min_val, max_val: forwarded to ``denormalize_data``.
        device: device the batch tensors and timesteps are placed on.
        num_examples: number of per-batch mean rows to print.

    Returns:
        NumPy array with one entry per channel: the mean over batches of the squared
        difference between the original and generated batch means.
    """
    model.eval()

    original_means = []
    generated_means = []

    # Inference only: no_grad keeps each forward pass from building an autograd graph.
    with torch.no_grad():
        for batch in test_dataloader:
            masked_images, original_images, _ = batch
            masked_images = masked_images.to(device)
            original_images = original_images.to(device)

            # Generate predictions
            t = torch.randint(0, T, (masked_images.shape[0],), device=device).long()
            predictions = model(masked_images, t)

            # Print the range of the model's output for debugging
            print("Predictions range before denormalization:", predictions.min().item(), predictions.max().item())

            # Convert to NumPy arrays for denormalization
            predictions_np = predictions.detach().cpu().numpy()
            original_images_np = original_images.detach().cpu().numpy()

            denormalized_predictions = denormalize_data(predictions_np, min_val, max_val)
            denormalized_originals = denormalize_data(original_images_np, min_val, max_val)

            # One mean per channel for this batch
            original_means.append(np.mean(denormalized_originals, axis=(0, 2, 3)))
            generated_means.append(np.mean(denormalized_predictions, axis=(0, 2, 3)))

    # Shape (num_batches, channels)
    original_means = np.array(original_means)
    generated_means = np.array(generated_means)

    # Print some example means
    print("Original Means of the Mask (Examples):")
    print(original_means[:num_examples])
    print("Generated Means of the Mask (Examples):")
    print(generated_means[:num_examples])

    # Compute MSE for all channels individually
    mse_per_channel = np.mean((original_means - generated_means) ** 2, axis=0)

    return mse_per_channel

# Evaluate with the batch-mean comparison and show the per-channel result.
mse_per_channel = test_model(
    model, test_dataloader, min_val, max_val, device=device
)
print("MSE for each channel:", mse_per_channel)

#### Implementing Early Stopping

In [None]:
class EarlyStopping:
    """Signal when training should stop because validation loss stopped improving.

    Call the instance once per epoch with the current validation loss and model.
    A loss that improves on the best score so far by at least ``delta`` resets the
    counter and checkpoints the model. Any other loss increments the counter, and
    ``early_stop`` becomes True once the counter reaches ``patience``.
    """

    def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
            path (str): File the model's ``state_dict`` is written to on improvement.
                        Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.delta = delta
        self.path = path
        self.best_score = None
        self.early_stop = False
        self.counter = 0
        # Lowest validation loss checkpointed so far. ``np.Inf`` was removed in
        # NumPy 2.0; the lowercase ``np.inf`` works on every NumPy version.
        self.best_loss = np.inf

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state."""
        score = -val_loss  # negate so that a higher score is better

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save the model's state_dict to ``self.path`` if ``val_loss`` is a new minimum."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            torch.save(model.state_dict(), self.path)

### Hyperparameter Tuning

In [None]:
# Parameters
epochs = 100
num_images = 5
channel_to_display = 6

# Candidate values for each hyperparameter explored by the random search
hyperparameter_space = dict(
    learning_rate=[1e-3, 1e-4, 1e-5],
    batch_size=[4, 8, 16],
    mask_count=[100, 140, 180],
    T=[500, 1000, 1500],
    patch_size=[30, 40, 50],
    patch_stride=[10, 20, 25],
)


def sample_hyperparameters(hyperparameter_space, max_T):
    """Pick one candidate per hyperparameter uniformly at random.

    Args:
        hyperparameter_space: mapping of name -> list of candidate values; must contain
            a ``'T'`` entry.
        max_T: upper bound applied to the sampled ``'T'``.

    Returns:
        A new dict with the same keys as ``hyperparameter_space``.
    """
    sampled = {}
    for name, candidates in hyperparameter_space.items():
        sampled[name] = random.choice(candidates)
    # Clip T so it never exceeds max_T.
    sampled['T'] = min(sampled['T'], max_T)
    return sampled


In [None]:
# Perform random search
# Each iteration samples a configuration, builds a fresh model/optimizer/data loaders and
# diffusion schedule, trains for `epochs` epochs, and records the final validation loss.
n_iterations = 10
results = []

# Upper bound for sampled T. This reads the module-level T when the cell runs; because the
# loop below rebinds T, re-running this cell uses the T left over from the previous run.
max_T = T  # Set the maximum value of T based on initial value

for i in range(n_iterations):
    print(f"Iteration {i+1}/{n_iterations}")
    hyperparameters = sample_hyperparameters(hyperparameter_space, max_T)
    
    # Update hyperparameters
    learning_rate = hyperparameters['learning_rate']
    BATCH_SIZE = hyperparameters['batch_size']
    mask_count = hyperparameters['mask_count']
    T = hyperparameters['T']
    # NOTE(review): patch_size and patch_stride are rebound here, but the datasets below are
    # built from the existing train_patches/val_patches, so within this cell these two values
    # do not change the training data — confirm whether the patches should be regenerated.
    patch_size = hyperparameters['patch_size']
    patch_stride = hyperparameters['patch_stride']
    
    # Define the model, optimizer, and data loaders here based on the new hyperparameters
    # train_epoch()/val_epoch() look up the module-level `optimizer`, `train_dataloader` and
    # `val_dataloader` at call time, so rebinding them here is what makes each trial use
    # its own model and data.
    model = SimpleUnet().to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    train_dataset = PatchDataset(train_patches, mask_pixels=mask_count)
    val_dataset = PatchDataset(val_patches, mask_pixels=mask_count)
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
    
    # Redefine beta schedule and precomputed terms based on the new T
    # (presumably read as module-level names by get_loss / the diffusion helpers — confirm).
    betas = cosine_beta_schedule(timesteps=T)
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0).to(device)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas).to(device)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).to(device)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).to(device)
    posterior_variance = (betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)).to(device)
    
    # Train and validate the model
    loss_list = []
    for epoch in range(epochs):
        train_loss = train_epoch(model)
        val_loss = val_epoch(model)
        print(f"Epoch {epoch}/{epochs}: Train Loss: {train_loss}, Validation Loss: {val_loss}")
        loss_list.append((train_loss, val_loss))
    
    # Log results
    # The recorded score is the validation loss of the LAST epoch, not the best epoch.
    results.append({
        'hyperparameters': hyperparameters,
        'validation_loss': val_loss
    })

# Save the results to a DataFrame
# (the 'hyperparameters' column holds dicts, which to_csv writes as their string form)
results_df = pd.DataFrame(results)
results_df.to_csv('random_search_results.csv', index=False)

# Find the best hyperparameters (lowest recorded validation loss)
best_result = results_df.loc[results_df['validation_loss'].idxmin()]
print(f"Best hyperparameters: {best_result['hyperparameters']}")
print(f"Best validation loss: {best_result['validation_loss']}")


In [None]:
# Train the final model with the best hyperparameters on the combined training and validation dataset
best_hyperparameters = best_result['hyperparameters']
learning_rate = best_hyperparameters['learning_rate']
BATCH_SIZE = best_hyperparameters['batch_size']
mask_count = best_hyperparameters['mask_count']
T = best_hyperparameters['T']
# NOTE(review): patch_size/patch_stride are restored for bookkeeping, but the patches below
# come from the existing train_patches/val_patches, so they do not affect this run.
patch_size = best_hyperparameters['patch_size']
patch_stride = best_hyperparameters['patch_stride']

# Rebuild the diffusion schedule for the selected T, exactly as the random search does.
# The module-level schedule tensors still hold the last search iteration's values, whose
# T may differ from the selected one (timesteps drawn from [0, T) could then index past
# the end of the schedule, assuming the loss reads these names — as the search loop implies).
betas = cosine_beta_schedule(timesteps=T)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0).to(device)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas).to(device)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod).to(device)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod).to(device)
posterior_variance = (betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)).to(device)

final_model = SimpleUnet().to(device)
final_optimizer = Adam(final_model.parameters(), lr=learning_rate)

combined_train_val_patches = np.concatenate((train_patches, val_patches))
combined_train_val_dataset = PatchDataset(combined_train_val_patches, mask_pixels=mask_count)
combined_train_val_dataloader = DataLoader(combined_train_val_dataset, batch_size=BATCH_SIZE, shuffle=True)


def _train_final_epoch(net, opt, loader):
    """Run one training epoch of ``net`` with ``opt`` over ``loader``; return the mean batch loss."""
    net.train()
    total = 0.0
    for batch in loader:
        opt.zero_grad()
        t = torch.randint(0, T, (batch[1].shape[0],), device=device).long()
        loss = get_loss(net, batch[1].to(device), t)
        loss.backward()
        opt.step()
        total += loss.item()
    return total / len(loader)


# train_epoch(), called with only a model, takes its optimizer and batches from the
# module-level `optimizer` / `train_dataloader`. Those still point at the last search
# trial's model and the train-only split, so final_model would never be updated. Train
# explicitly with final_optimizer over the combined loader instead.
for epoch in range(epochs):
    train_loss = _train_final_epoch(final_model, final_optimizer, combined_train_val_dataloader)
    print(f"Epoch {epoch}/{epochs}: Train Loss: {train_loss}")

In [None]:
# Save `final_model`, the network trained with the selected hyperparameters
# (`model` at this point is the last random-search trial, not the tuned model).
torch.save(final_model.state_dict(), 'best_model_HyperDone.pt')

# NOTE: `loss_list` is not filled by the final training run (it reports training loss
# only); it still holds the curves of the last random-search iteration.
plt.figure()

# Plot the training loss
plt.plot([g[0] for g in loss_list], label='Train Loss')

# Plot the validation loss
plt.plot([g[1] for g in loss_list], color='orange', label='Validation Loss')

# Add labels and title
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss (last search iteration)')

# Add legend
plt.legend()

# Display the plot
plt.show()