#### Import Dependencies

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import numpy as np
import random
from collections import deque
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import math


Set Flags

In [2]:
# Set device: use the first CUDA GPU when available, otherwise fall back to CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Parameters
# Directory holding the .npy feature grids loaded below
# NOTE(review): absolute, user-specific path — consider making it configurable
project_path = "/home/s184310/3.Project/data/final_merged_data"
patch_size = 50  # Side length (in grid cells) of the square patches; also the base channel width of SimpleUnet
patch_stride = 10  # Step between neighbouring patch origins
channels = 26  # Number of feature channels fed to / produced by the network
BATCH_SIZE = 4  # Batch size used by every DataLoader
# Rectangle (in grid coordinates) that is cut out of the grid and held back as test data
rect_x_start, rect_x_end = 100, 290 
rect_y_start, rect_y_end = 200, 390
channel_to_display = 6  # Channel index shown in the plots
fixed_channel = 13  # Specify which channel to fix, which is vegetation percentage
fixed_value_range = (0, 100)  # Range of values to sample from for the fixed channel
test_size = 0.2  # Fraction of the training patches that train_test_split moves to validation
random_state = 42  # Seed for train_test_split
mask_count = 140  # Default number of pixels PatchDataset masks per patch
T = 1000  # Number of diffusion timesteps
epochs = 300  # Number of training epochs
learning_rate = 0.001  # Learning rate intended for the optimizer
num_images = 5  # Number of timesteps plotted by inpainting_plot_image

In [None]:
grid_array = np.load(f'{project_path}/50_all_data.npy')
print(grid_array.shape)

# For every channel (last axis), count the cells that carry real data,
# i.e. cells that are neither 0 nor the -1 "missing" marker.
valid_data_counts = [
    np.sum((grid_array[:, :, channel] != 0) & (grid_array[:, :, channel] != -1))
    for channel in range(grid_array.shape[-1])
]

# The channel with the largest count is the most densely populated one
max_valid_channel = np.argmax(valid_data_counts)
max_valid_count = valid_data_counts[max_valid_channel]

print(f"Channel with the most valid data: {max_valid_channel}")
print(f"Number of valid data points in this channel: {max_valid_count}")

In [None]:
# Visualise one channel of the grid loaded in the previous cell.

# Replace the -1 "missing" marker with -2 (in place) so the sentinel matches
# the convention used by the normalisation code further down.
grid_array[grid_array == -1] = -2

# Hide empty (0) and missing (-2) cells from the plot by turning them into NaN
channel_values = grid_array[:, :, channel_to_display]
channel_data = np.where((channel_values == 0) | (channel_values == -2), np.nan, channel_values)

# Column names are only available when a dataframe `df` was loaded by the
# user (the CSV load is not part of this notebook); otherwise fall back to a
# generic title instead of failing with a NameError.
try:
    columns = df.columns
except NameError:
    columns = []

# Calculate the actual column index (accounting for starting from the third column)
column_index = channel_to_display + 2

if column_index < len(columns):
    column_name = columns[column_index]
    title = f"Original data for {column_name} feature"
elif len(columns) == 0:
    title = f"Original data for channel {channel_to_display}"
else:
    title = f"Original data for channel {channel_to_display}: Column index out of range"

# Plot the data with a perceptually uniform colormap
plt.figure(figsize=(10, 8))  # Increase figure size for better coverage visibility
plt.imshow(channel_data, cmap='viridis', interpolation='nearest', aspect='auto')
plt.colorbar()
plt.title(title)
plt.axis('off')
plt.show()

# Print the min and max values excluding NaNs to understand the color range
min_val = np.nanmin(channel_data)
max_val = np.nanmax(channel_data)
print(f"Data range for displayed channel: {min_val} to {max_val}")


Load and normalise

In [3]:
# Read the feature grid from disk and report its dimensions
features_file = f'{project_path}/50m_26_features_array.npy'
grid_array = np.load(features_file)
print("Shape of loaded grid array:", grid_array.shape)

# Re-label the -1 "missing" marker as -2 (in place); -2 is the sentinel the
# normalisation helpers below look for
grid_array[np.equal(grid_array, -1)] = -2

def normalize_data(data):
    """Scale every channel (last axis) to [-1, 1], leaving missing cells at -2.

    Cells equal to -2 or NaN are treated as missing: they are ignored when the
    per-channel minimum/maximum is computed and come out as -2.

    Returns:
        (normalized_data, min_val, max_val) where min_val / max_val keep the
        reduced axes (keepdims) and are 0.0 / 1.0 for channels without any
        valid cell.
    """
    valid = ~np.isnan(data) & (data != -2)
    candidates = np.where(valid, data, np.nan)

    # Reduce over the two spatial axes; invalid cells are NaN and get skipped
    with np.errstate(all='ignore'):
        lowest = np.nanmin(candidates, axis=(0, 1), keepdims=True)
        highest = np.nanmax(candidates, axis=(0, 1), keepdims=True)

    # A channel with no valid cell yields NaN bounds; substitute neutral ones
    lowest = np.nan_to_num(lowest, nan=0.0)
    highest = np.nan_to_num(highest, nan=1.0)

    # Constant channels would divide by zero, so they use a unit span instead
    span = np.where(highest != lowest, highest - lowest, 1)

    rescaled = ((data - lowest) / span) * 2 - 1
    return np.where(valid, rescaled, -2), lowest, highest

def denormalize_data(data, min_val, max_val):
    """Map values from [-1, 1] back to their original per-channel range.

    Args:
        data: array whose second axis is the channel axis (the bounds are
            reshaped to (1, channels, 1, 1) for broadcasting).
        min_val, max_val: per-channel bounds as numpy arrays or torch tensors.

    Returns:
        Array with the same shape as `data`; cells equal to -2 or NaN come
        out as -2.
    """
    def _channel_bounds(bound, fallback):
        # Accept torch tensors, patch NaN bounds and shape for broadcasting
        if isinstance(bound, torch.Tensor):
            bound = bound.cpu().numpy()
        return np.nan_to_num(bound, nan=fallback).reshape(1, -1, 1, 1)

    lowest = _channel_bounds(min_val, 0.0)
    highest = _channel_bounds(max_val, 1.0)

    # Constant channels use a unit span, mirroring the normalisation
    span = np.where(highest != lowest, highest - lowest, 1)

    valid = ~np.isnan(data) & (data != -2)
    restored = (data + 1) / 2 * span + lowest
    return np.where(valid, restored, -2)

normalized_grid_array, min_val, max_val = normalize_data(grid_array)

# The held-out rectangle is copied out as test data first and only then
# overwritten with -3 in the training grid, so patches touching it can be
# recognised and skipped later.
held_out = (slice(rect_y_start, rect_y_end), slice(rect_x_start, rect_x_end))
test_data = normalized_grid_array[held_out].copy()
normalized_grid_array[held_out] = -3
train_data = normalized_grid_array

FileNotFoundError: [Errno 2] No such file or directory: '/home/s184310/3.Project/data/final_merged_data/50m_26_features_array.npy'

In [1]:
def create_input_output_pairs(data, patch_size=patch_size, stride=patch_stride):
    """Cut a (rows, cols, channels) grid into square sliding-window patches.

    A patch is dropped when it contains any -3 cell (the marker written over
    the held-out test rectangle) or when it consists only of -2 (missing)
    cells.

    Args:
        data: array whose first two axes are the spatial axes.
        patch_size: side length of each patch.
        stride: step between neighbouring patch origins.

    Returns:
        Array of shape (num_patches, patch_size, patch_size, ...). When no
        patch qualifies, an empty array with the same trailing dimensions is
        returned so that `.shape[1:]` stays meaningful for callers.
    """
    patches = []
    # "+ 1" so the last window that still fits (origin == size - patch_size)
    # is included; without it the final row/column of windows is lost.
    for i in range(0, data.shape[0] - patch_size + 1, stride):
        for j in range(0, data.shape[1] - patch_size + 1, stride):
            current_patch = data[i:i + patch_size, j:j + patch_size]
            if np.any(current_patch == -3) or np.all(current_patch == -2):
                continue
            # Copy only the patches that are kept so they own their memory
            patches.append(current_patch.copy())

    if not patches:
        empty_shape = (0, patch_size, patch_size) + tuple(data.shape[2:])
        return np.empty(empty_shape, dtype=data.dtype)
    return np.array(patches)

# Patches from the held-out rectangle form the test set
test_patches = create_input_output_pairs(test_data)

# Patches from the remaining grid are divided into training and validation
candidate_patches = create_input_output_pairs(train_data)
train_patches, val_patches = train_test_split(
    candidate_patches,
    test_size=test_size,
    random_state=random_state,
)

print("Train inputs shape:", train_patches.shape)
print("Val inputs shape:", val_patches.shape)
print("Test inputs shape:", test_patches.shape)

class PatchDataset(Dataset):
    """Dataset of channels-last patches with a randomly grown masked region.

    `data` is indexed as (patch, row, col, channel). Each item hides a
    connected blob of `mask_pixels` pixels (grown breadth-first from a random
    seed pixel) in every channel except `fixed_channel`, which stays visible
    and acts as the conditioning signal.

    Args:
        data: array of shape (num_patches, height, width, channels).
        mask_pixels: number of pixels to hide per patch.
        fixed_channel: index of the channel that is never hidden.
        fixed_value: when given, the fixed channel of the masked image is
            overwritten with this constant.
    """

    def __init__(self, data, mask_pixels=mask_count, fixed_channel=fixed_channel, fixed_value=None):
        self.data = data
        self.mask_pixels = mask_pixels
        self.height, self.width = data.shape[1], data.shape[2]
        self.fixed_channel = fixed_channel
        self.fixed_value = fixed_value

    def __len__(self):
        return len(self.data)

    def _grow_region(self):
        """Return an (height, width) bool array that is False on hidden pixels."""
        visible = np.ones((self.height, self.width), dtype=bool)
        remaining = self.mask_pixels

        # Random seed pixel (row drawn first, then column)
        seed = (np.random.randint(0, self.height), np.random.randint(0, self.width))
        visible[seed] = False
        remaining -= 1

        # Breadth-first growth over the 4-neighbourhood. The grid is fully
        # connected, so the search can only run dry once every pixel is
        # hidden; no separate fill-up pass is required.
        directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
        queue = deque([seed])
        while queue and remaining > 0:
            ci, cj = queue.popleft()
            for di, dj in directions:
                ni, nj = ci + di, cj + dj
                if 0 <= ni < self.height and 0 <= nj < self.width and visible[ni, nj]:
                    visible[ni, nj] = False
                    remaining -= 1
                    queue.append((ni, nj))
                    if remaining == 0:
                        break
        return visible

    def __getitem__(self, idx):
        """Return (masked_image, image, mask, condition_channel, condition_value).

        The three image-like tensors are channels-first (channels, height,
        width). `image` is the untouched patch, `masked_image` has the hidden
        region zeroed in all channels but the fixed one, and `mask` is 1 on
        visible and 0 on hidden pixels (identical for every channel).
        `condition_channel` and `condition_value` both have shape (1,).
        """
        image = torch.tensor(self.data[idx], dtype=torch.float32)  # (H, W, C)
        num_channels = image.shape[-1]

        visible = self._grow_region()
        mask = torch.from_numpy(visible.astype(np.float32)).unsqueeze(-1).repeat(1, 1, num_channels)

        # Hide the region in every channel, then restore the fixed channel.
        # Channels live on the LAST axis here, so it must be indexed as
        # [:, :, ch]; indexing axis 0 would select image rows instead.
        masked_image = image * mask
        masked_image[:, :, self.fixed_channel] = image[:, :, self.fixed_channel]

        # Set the fixed channel to the known value if specified
        if self.fixed_value is not None:
            masked_image[:, :, self.fixed_channel] = self.fixed_value

        condition_value = masked_image[:, :, self.fixed_channel].mean().unsqueeze(0)
        condition_channel = torch.tensor([self.fixed_channel], dtype=torch.float32)

        return (
            masked_image.permute(2, 0, 1),
            image.permute(2, 0, 1),
            mask.permute(2, 0, 1),
            condition_channel,
            condition_value,
        )

# One dataset per split
train_dataset = PatchDataset(train_patches)
val_dataset = PatchDataset(val_patches)
test_dataset = PatchDataset(test_patches)

# Shuffled mini-batch loaders (all three splits are shuffled)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=BATCH_SIZE)

# Sanity check: print the shape of every tensor in the first training batch
for masked_image, image, mask, condition_channel, condition_value in train_dataloader:
    for batch_tensor in (masked_image, image, mask, condition_channel, condition_value):
        print(batch_tensor.shape)
    break

In [None]:
def cosine_beta_schedule(timesteps, s=0.008):
    """Return `timesteps` beta values following a squared-cosine schedule.

    The cumulative alpha curve is cos^2 of a linearly increasing phase
    (offset by `s`), normalised by its value at step 0. Betas are derived
    from the ratio of consecutive cumulative alphas and clipped to
    [0.0001, 0.999]. The result lives on the global `device`.
    """
    def alpha_bar(step):
        phase = (step / timesteps + s) / (1 + s)
        return torch.cos(phase * 0.5 * torch.pi) ** 2

    steps = torch.linspace(0, timesteps, timesteps + 1).to(device)
    origin = torch.tensor([0]).to(device)
    cumulative = alpha_bar(steps) / alpha_bar(origin)

    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    raw_betas = 1 - cumulative[1:] / cumulative[:-1]
    return torch.clip(raw_betas, 0.0001, 0.999).to(device)

def get_index_from_list(vals, t, x_shape):
    """Pick `vals[t]` per batch entry and shape it for broadcasting.

    Args:
        vals: 1-D tensor of per-timestep values.
        t: integer tensor of shape (batch,) holding timestep indices.
        x_shape: shape of the tensor the result will be combined with.

    Returns:
        Tensor of shape (batch, 1, ..., 1) with as many dimensions as
        `x_shape`, placed on the device of `t`.
    """
    num_entries = t.shape[0]
    picked = vals.gather(-1, t).to(t.device)
    singleton_dims = [1] * (len(x_shape) - 1)
    return picked.reshape(num_entries, *singleton_dims).to(t.device)

def forward_diffusion_sample(x_0, t, device=device):
    """Noise `x_0` to timestep `t` in a single step.

    Computes sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise using the
    precomputed module-level schedule tensors.

    Returns:
        (noisy sample, the Gaussian noise that was added), both on `device`.
    """
    x_0 = x_0.to(device)
    noise = torch.randn_like(x_0, device=device)

    # Per-sample scaling factors, shaped to broadcast over x_0
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape).to(device)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape).to(device)

    noisy = signal_scale * x_0 + noise_scale * noise
    return noisy, noise.to(device)

# Precompute the per-timestep quantities of the diffusion process once;
# all of them are 1-D tensors with T entries.
betas = cosine_beta_schedule(timesteps=T)
alphas = 1. - betas
# Running product of the alphas (alpha_bar_t)
alphas_cumprod = torch.cumprod(alphas, axis=0)
# alpha_bar shifted by one step, with 1.0 prepended for t = 0
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
# Scaling factors of the signal and noise terms in forward_diffusion_sample
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
# Variance of the noise added in each reverse step (0 at t = 0 because
# alphas_cumprod_prev[0] is 1.0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)


In [None]:
class Block(nn.Module):
    """Two 3x3 conv layers with a time embedding added in between, followed by resampling.

    With `up=False` the block halves the spatial size (strided convolution);
    with `up=True` it expects the input concatenated with a skip connection
    (2 * in_ch channels) and doubles the spatial size (transposed
    convolution).
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch).to(device)

        # Up blocks see twice the channels because of the concatenated skip
        first_conv_in = 2 * in_ch if up else in_ch
        resample_layer = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_conv_in, out_ch, 3, padding=1).to(device)
        self.transform = resample_layer(out_ch, out_ch, 4, 2, 1).to(device)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1).to(device)
        self.bnorm1 = nn.BatchNorm2d(out_ch).to(device)
        self.bnorm2 = nn.BatchNorm2d(out_ch).to(device)
        self.relu = nn.ReLU().to(device)

    def forward(self, x, t):
        """Apply the block to feature map `x` using the embedding `t`."""
        features = self.conv1(x)
        features = self.bnorm1(self.relu(features))

        # Project the embedding to out_ch and broadcast it over both spatial axes
        embedding = self.relu(self.time_mlp(t))
        features = features + embedding[..., None, None]

        features = self.conv2(features)
        features = self.bnorm2(self.relu(features))
        return self.transform(features)

class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    Produces, for each timestep, `dim // 2` sine values followed by
    `dim // 2` cosine values at geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return a (batch, 2 * (dim // 2)) tensor on the device of `time`."""
        num_freqs = self.dim // 2
        log_step = math.log(10000) / (num_freqs - 1)
        frequencies = torch.exp(-log_step * torch.arange(num_freqs, device=time.device))

        # Outer product: one row per timestep, one column per frequency
        angles = time.unsqueeze(1) * frequencies.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

class SimpleUnet(nn.Module):
    """U-Net noise predictor conditioned on a timestep and a 2-value condition.

    The timestep embedding and the embedded condition are summed and handed
    to every down/up block. Channel widths are multiples (1, 2, 4, 8, 16) of
    the global `patch_size`.
    """

    def __init__(self):
        super().__init__()
        widths = tuple(patch_size * factor for factor in (1, 2, 4, 8, 16))
        down_channels = widths
        up_channels = widths[::-1]
        time_emb_dim = 32
        cond_dim = 2  # Channel number and value

        # Timestep -> sinusoidal features -> learned projection
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim).to(device),
            nn.Linear(time_emb_dim, time_emb_dim).to(device),
            nn.ReLU().to(device)
        )

        # Condition vector -> embedding of the same width as the time embedding
        self.cond_mlp = nn.Sequential(
            nn.Linear(cond_dim, time_emb_dim).to(device),
            nn.ReLU().to(device)
        )

        # Initial projection from the data channels to the first width
        self.conv0 = nn.Conv2d(channels, down_channels[0], 3, padding=1).to(device)

        self.downs = nn.ModuleList([
            Block(ch_in, ch_out, time_emb_dim).to(device)
            for ch_in, ch_out in zip(down_channels[:-1], down_channels[1:])
        ])

        self.ups = nn.ModuleList([
            Block(ch_in, ch_out, time_emb_dim, up=True).to(device)
            for ch_in, ch_out in zip(up_channels[:-1], up_channels[1:])
        ])

        # 1x1 convolution back to the number of data channels
        self.output = nn.Conv2d(up_channels[-1], channels, 1).to(device)

    def forward(self, x, timestep, condition):
        """Predict the noise in `x` for the given timesteps and condition."""
        embedding = self.time_mlp(timestep) + self.cond_mlp(condition)

        x = self.conv0(x)
        skips = []
        for down in self.downs:
            x = down(x, embedding)
            skips.append(x)

        for up in self.ups:
            skip = skips.pop()
            # Odd spatial sizes do not survive down/up-sampling exactly, so
            # resize to the skip connection before concatenating
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=(skip.size(2), skip.size(3)), mode='bilinear', align_corners=False)
            x = up(torch.cat((x, skip), dim=1), embedding)

        return self.output(x)

model = SimpleUnet().to(device)
num_params = sum(p.numel() for p in model.parameters())
print("Num params: ", num_params)


In [None]:
@torch.no_grad()
def sample_timestep(x, t, condition=None):
    """Run one reverse-diffusion step with the global `model`.

    Predicts the noise in `x`, removes it to obtain the posterior mean and,
    for every sample whose timestep is not 0, adds freshly drawn noise scaled
    by the posterior standard deviation.

    Args:
        x: current noisy batch, shape (batch, ...).
        t: integer tensor of shape (batch,) with the timestep of each sample.
        condition: optional conditioning tensor forwarded to the model as its
            third argument. When omitted the model is called as
            `model(x, t)`, which only works for unconditional models.

    Returns:
        The batch after one denoising step, same shape as `x`.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    if condition is None:
        noise_pred = model(x, t)
    else:
        noise_pred = model(x, t, condition)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * noise_pred / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    # `t` holds one entry per sample, so it cannot be used directly as a
    # Python boolean once the batch is larger than 1.
    if bool((t == 0).all()):
        return model_mean

    # Per-sample switch: no noise is added for samples already at t == 0
    add_noise = (t != 0).to(x.dtype).reshape(-1, *((1,) * (x.dim() - 1)))
    noise = torch.randn_like(x)
    return model_mean + add_noise * torch.sqrt(posterior_variance_t) * noise


@torch.no_grad()
def inpainting_plot_image(image, mask, condition=None):
    """Plot the forward noising and the masked backward denoising of one image.

    The first figure shows the original image followed by the masked image at
    up to `num_images` evenly spaced noise levels. The second figure shows
    the masked image followed by the combined (known pixels from the forward
    process, hidden pixels from the model) result at the same timesteps.

    Args:
        image: image tensor accepted by `show_tensor_image`.
        mask: tensor shaped like `image`, 1 on known and 0 on hidden pixels.
        condition: optional conditioning tensor passed on to
            `sample_timestep`; when omitted `sample_timestep` is called with
            two arguments exactly as before.
    """
    # Everything that is combined with model output has to live on `device`
    image = image.to(device)
    mask = mask.to(device)

    # At least 1 so the modulo below is defined even when T < num_images
    stepsize = max(1, T // num_images)

    plt.figure(figsize=(15, 15))

    # Display the original image
    plt.subplot(1, num_images + 1, 1)
    show_tensor_image(image.detach().cpu())
    plt.title(f"Original Image")  # Set the title of the subplot

    masked_image = image * mask

    image_list = []
    image_idx_plot_list = []

    # Forward diffusion process
    for idx in range(0, T):
        t = torch.Tensor([idx]).type(torch.int64).to(device)
        img, _ = forward_diffusion_sample(masked_image, t, device)
        # Never plot more panels than the figure has free slots (matters when
        # T is not a multiple of num_images)
        if idx % stepsize == 0 and len(image_idx_plot_list) < num_images:
            image_idx_plot_list.append(idx)
            plt.subplot(1, num_images + 1, len(image_idx_plot_list) + 1)
            show_tensor_image(img.detach().cpu())
            plt.title(f"Timestep {idx}")
        image_list.append((img, t))  # Save the noisy image and its timestep

    plt.show()

    # Backward denoising process
    plt.figure(figsize=(15, 15))
    plt.subplot(1, num_images + 1, 1)
    show_tensor_image(masked_image.detach().cpu())
    plt.title("Masked Image")

    # In the very first step the whole noisy image is kept; afterwards only
    # the known pixels come from the forward process
    tmp_mask = torch.ones_like(mask)

    for idx, (noisy_img, t) in enumerate(image_list[::-1]):
        t_idx = int(t.item())

        # Predict the denoised image at this timestep
        if condition is None:
            tmp_img = sample_timestep(noisy_img, t)
        else:
            tmp_img = sample_timestep(noisy_img, t, condition)

        # Known pixels from the noisy image, hidden pixels from the model
        img = (noisy_img * tmp_mask) + (tmp_img * (1 - tmp_mask))
        tmp_mask = mask

        if t_idx in image_idx_plot_list:
            # Diagnostics are limited to the plotted timesteps to keep the
            # cell output readable
            print(f"Backward process - Step {idx}, Timestep {t_idx}")
            print(f"noisy_img min: {noisy_img.min().item()}, max: {noisy_img.max().item()}")
            print(f"tmp_img min: {tmp_img.min().item()}, max: {tmp_img.max().item()}")
            print(f"combined_img min: {img.min().item()}, max: {img.max().item()}")

            # Largest timestep goes to slot 2, timestep 0 to the last slot
            # (num_images + 1); slot 1 holds the masked image.
            subplot_index = num_images - 1 - image_idx_plot_list.index(t_idx)
            print(f"INSIDE LOOP: Timestep {t_idx}")
            print(f"original idx: {subplot_index}")
            plt.subplot(1, num_images + 1, subplot_index + 2)
            show_tensor_image(img.detach().cpu())
            plt.title(f"Timestep {t_idx}")

    plt.show()

In [None]:
def get_loss(model, x_0, t, condition):
    """Mean squared error between true and predicted noise on valid cells.

    `x_0` is noised to timestep `t`, the model predicts the noise, and the
    error is averaged only over positions where `x_0` is not the -2
    missing-data sentinel.
    """
    noisy_input, target_noise = forward_diffusion_sample(x_0, t, device)
    predicted_noise = model(noisy_input, t, condition)

    # Cells marked as missing in the clean input do not contribute
    valid = x_0 != -2
    return F.mse_loss(predicted_noise[valid], target_noise[valid])

def train_epoch(model):
    """Train `model` for one pass over `train_dataloader`.

    Uses the module-level `optimizer`. For every sample a conditioning value
    is drawn uniformly from `fixed_value_range` and paired with the sample's
    condition channel.

    Returns:
        Mean training loss over the batches of this epoch.
    """
    # Make sure BatchNorm is in training mode even if an earlier cell
    # switched the model to eval mode
    model.train()
    loss_value = 0
    for masked_image, image, mask, condition_channel, _ in train_dataloader:
        optimizer.zero_grad()
        batch_size = image.shape[0]
        t = torch.randint(0, T, (batch_size,), device=device).long()

        # One value per sample, shape (batch, 1), so it can be concatenated
        # with the (batch, 1) condition channel; both tensors must be on the
        # same device for torch.cat.
        # NOTE(review): these values are drawn independently of the image
        # content, whereas validation uses the value computed by the dataset
        # — confirm this is intended.
        condition_value = torch.tensor(
            [[random.uniform(*fixed_value_range)] for _ in range(batch_size)],
            dtype=torch.float32,
            device=device,
        )
        condition = torch.cat((condition_channel.to(device), condition_value), dim=1)

        loss = get_loss(model, image.to(device), t, condition)
        loss.backward()
        optimizer.step()
        loss_value += loss.item()
    return loss_value / len(train_dataloader)

@torch.no_grad()
def val_epoch(model):
    """Compute the mean loss of `model` over `val_dataloader` without gradients.

    The model is evaluated in eval mode so BatchNorm uses its running
    statistics and does not update them from validation batches; the
    previous train/eval mode is restored afterwards.

    Returns:
        Mean validation loss over the batches.
    """
    was_training = model.training
    model.eval()
    try:
        loss_value = 0
        for masked_image, image, mask, condition_channel, condition_value in val_dataloader:
            t = torch.randint(0, T, (image.shape[0],), device=device).long()
            condition = torch.cat((condition_channel, condition_value), dim=1)
            loss = get_loss(model, image.to(device), t, condition.to(device))
            loss_value += loss.item()
        return loss_value / len(val_dataloader)
    finally:
        # Hand the model back in the mode the caller left it in
        model.train(was_training)



In [None]:
# train_epoch relies on a module-level `optimizer`, so it has to exist
# before the first epoch starts
optimizer = Adam(model.parameters(), lr=learning_rate)

# (train_loss, val_loss) per epoch
loss_list = []
for epoch in range(epochs):
    train_loss = train_epoch(model)
    val_loss = val_epoch(model)
    print(f"Epoch {epoch}/{epochs}: Train Loss: {train_loss}, Validation Loss: {val_loss}")
    loss_list.append((train_loss, val_loss))

In [None]:
# Testing function
@torch.no_grad()
def test_model(model, test_dataloader):
    """Sample from the model for one test batch and plot the first example.

    The batch is noised to the last timestep (T - 1), then denoised step by
    step down to timestep 0. The original, the masked and the finally
    generated image are shown side by side for `channel_to_display`.

    Note: the denoising itself goes through `sample_timestep`, which uses the
    module-level `model`; the `model` argument is only switched to eval mode
    here.
    """
    model.eval()
    for masked_image, image, mask, condition_channel, condition_value in test_dataloader:
        condition = torch.cat((condition_channel, condition_value), dim=1).to(device)
        batch_size = image.shape[0]

        # The reverse chain below starts at T - 1, so the input has to be
        # noised to exactly that level
        t_start = torch.full((batch_size,), T - 1, device=device, dtype=torch.long)
        x_noisy, _ = forward_diffusion_sample(image.to(device), t_start, device)

        for i in range(T - 1, -1, -1):
            t_i = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x_noisy = sample_timestep(x_noisy, t_i, condition)

        # Only the result of the last step (t = 0) is the generated image
        generated_batch = x_noisy.cpu().numpy()

        # Visualize the results for the first image in the batch
        idx = 0
        original_image = image[idx].cpu().numpy()
        masked_example = masked_image[idx].cpu().numpy()
        generated_image = generated_batch[idx]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(original_image[channel_to_display], cmap='gray')
        axes[0].set_title("Original Image")
        axes[1].imshow(masked_example[channel_to_display], cmap='gray')
        axes[1].set_title("Masked Image")
        axes[2].imshow(generated_image[channel_to_display], cmap='gray')
        axes[2].set_title("Generated Image")
        plt.show()

        break

test_model(model, test_dataloader)

## Testing on a single example for the thesis

In [None]:
def _grow_mask_region(height, width, mask_pixels):
    """Return a (height, width) bool array that is False on a random connected blob of hidden pixels."""
    visible = np.ones((height, width), dtype=bool)
    remaining = mask_pixels

    # Random seed pixel (row drawn first, then column)
    seed = (np.random.randint(0, height), np.random.randint(0, width))
    visible[seed] = False
    remaining -= 1

    # Breadth-first growth over the 4-neighbourhood; the grid is fully
    # connected, so the search only runs dry once every pixel is hidden
    directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    queue = deque([seed])
    while queue and remaining > 0:
        ci, cj = queue.popleft()
        for di, dj in directions:
            ni, nj = ci + di, cj + dj
            if 0 <= ni < height and 0 <= nj < width and visible[ni, nj]:
                visible[ni, nj] = False
                remaining -= 1
                queue.append((ni, nj))
                if remaining == 0:
                    break
    return visible


@torch.no_grad()
def generate_from_fixed_channel(model, patch, fixed_channel=13, fixed_value=25, mask_pixels=20):
    """
    Generate the other channels based on a fixed value for one channel in a given patch.

    Parameters:
    - model: The trained model to use for generation. Note that sampling goes
      through `sample_timestep`, which uses the module-level `model`.
    - patch: The input patch in channels-first layout (channels, height, width),
      as an array or tensor. It is not modified.
    - fixed_channel: The channel to fix to a specific value.
    - fixed_value: The value to set for the fixed channel.
    - mask_pixels: The number of pixels to mask.

    Returns:
    - original_patch: The original patch before modification (numpy array).
    - generated_patch: The patch produced by the last reverse-diffusion step (numpy array).
    """
    # Work on a private copy so the caller's data is never altered
    patch = torch.as_tensor(patch, dtype=torch.float32).to(device).clone()

    # Keep the untouched values for the comparison printed at the end
    original_patch = patch.detach().cpu().numpy().copy()

    # Hide a random connected region in every channel except the fixed one
    height, width = patch.shape[1], patch.shape[2]
    visible = _grow_mask_region(height, width, mask_pixels)
    mask = torch.from_numpy(visible.astype(np.float32)).to(device)
    mask = mask.unsqueeze(0).repeat(patch.shape[0], 1, 1)

    for ch in range(patch.shape[0]):
        if ch != fixed_channel:
            patch[ch] = patch[ch] * mask[ch]

    # Set the fixed channel to the known value
    patch[fixed_channel, :, :] = fixed_value

    # Condition vector of shape (1, 2): [channel index, value]
    condition_value = torch.tensor([fixed_value], dtype=torch.float32).to(device)
    condition_channel = torch.tensor([fixed_channel], dtype=torch.float32).to(device)
    condition = torch.cat((condition_channel, condition_value), dim=0).unsqueeze(0)

    # The reverse chain starts at T - 1, so noise the input to that level
    t_start = torch.full((1,), T - 1, device=device, dtype=torch.long)
    x_noisy, _ = forward_diffusion_sample(patch.unsqueeze(0), t_start, device)

    for i in range(T - 1, -1, -1):
        t_i = torch.full((1,), i, device=device, dtype=torch.long)
        x_noisy = sample_timestep(x_noisy, t_i, condition)

    # Only the output of the final step (t = 0) is the generated patch
    generated_patch = x_noisy[0].cpu().numpy()

    # Print mean values of masked area
    hidden = ~visible
    for ch in range(original_patch.shape[0]):
        masked_area_original = original_patch[ch][hidden]
        masked_area_generated = generated_patch[ch][hidden]
        print(f"Channel {ch}: Known mean value: {masked_area_original.mean()}, Generated mean value: {masked_area_generated.mean()}")

    return original_patch, generated_patch

In [None]:
# Assume model is your trained model and test_patches is your test dataset.
# The patches are stored channels-last (height, width, channels), whereas
# generate_from_fixed_channel indexes its input channels-first, so the axes
# are reordered before the call.
sample_patch = np.transpose(test_patches[0], (2, 0, 1))  # Select a patch from the test set

original, generated = generate_from_fixed_channel(model, sample_patch, fixed_channel=13, fixed_value=25, mask_pixels=20)

# Visualize the results
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(original[channel_to_display], cmap='gray')
axes[0].set_title("Original Patch")
axes[1].imshow(generated[channel_to_display], cmap='gray')
axes[1].set_title("Generated Patch")
plt.show()