# Content-aware image restoration

Fluorescence microscopy is constrained by the microscope's optics, fluorophore chemistry, and the sample's photon tolerance. These constraints force a trade-off between imaging speed, resolution, light exposure, and imaging depth. CARE (content-aware image restoration) demonstrates how deep learning can extend the range of biological phenomena observable by microscopy.
 

![image](nb_data/tradeoff.png)

### CARE
In this first exercise we will train a CARE model for a 2D denoising task.
We'll reuse the UNet model that we defined in the semantic segmentation exercise, so we
need to import it from that exercise's module.


 

![image](nb_data/img_intro.png) 

### Mandatory actions
<div class="alert alert-danger">
Set your python kernel to <code>regression</code> <br>
</div>

In [None]:
import importlib.util

unet_spec = importlib.util.spec_from_file_location(
    "segm", "/home/igor.zubarev/projects/dl_courses/01_segmentation/unet.py"
)
unet_module = importlib.util.module_from_spec(unet_spec)
unet_spec.loader.exec_module(unet_module)

In [None]:
from careamics_portfolio import PortfolioManager
import tifffile
import numpy as np
from pathlib import Path
from typing import Union, List, Tuple
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch import no_grad, cuda
from transforms import augment_batch, normalize, denormalize
import matplotlib.pyplot as plt

<hr style="height:2px;"><div class="alert alert-block alert-success"><h1>Checkpoint 1: Prepare data</h1>
</div>

CARE is a fully supervised algorithm, so we need image pairs for training. In practice this is best achieved by acquiring two interleaved stacks, i.e. different channels that correspond to different exposure/laser settings.

We'll be using high SNR (signal-to-noise ratio) images of human U2OS cells taken from the Broad Bioimage Benchmark Collection. The low SNR images were created by synthetically adding strong read-out and shot noise and applying 2x2 pixel binning, thus mimicking acquisitions at a very low light level.
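
As a rough illustration of this kind of degradation, the sketch below bins a clean image and adds Poisson (shot) and Gaussian (read-out) noise. It is not the procedure used to create this dataset: the function name `simulate_low_snr`, the order of operations, and all parameter values (`photon_scale`, `read_noise_std`) are assumptions made for the example.

```python
import numpy as np


def simulate_low_snr(
    clean: np.ndarray,
    photon_scale: float = 0.05,
    read_noise_std: float = 2.0,
    bin_factor: int = 2,
    seed: int = 0,
) -> np.ndarray:
    """Degrade a clean 2D image with binning, shot noise and read-out noise.

    All default values are hypothetical and chosen for illustration only.
    """
    rng = np.random.default_rng(seed)

    # Crop so that both dimensions are divisible by the binning factor
    height = clean.shape[0] - clean.shape[0] % bin_factor
    width = clean.shape[1] - clean.shape[1] % bin_factor
    cropped = clean[:height, :width].astype(np.float64)

    # Pixel binning: sum each bin_factor x bin_factor block into one pixel
    binned = cropped.reshape(
        height // bin_factor, bin_factor, width // bin_factor, bin_factor
    ).sum(axis=(1, 3))

    # Shot noise: photon counts follow a Poisson distribution whose mean is
    # the scaled-down signal, which mimics a very low light level
    photons = rng.poisson(binned * photon_scale)

    # Read-out noise: additive Gaussian noise from the camera electronics
    noisy = photons + rng.normal(0.0, read_noise_std, size=photons.shape)
    return noisy.astype(np.float32)


# Synthetic stand-in for a clean image (not real microscopy data)
clean_image = np.random.default_rng(1).uniform(0, 1000, size=(64, 64))
low_snr_image = simulate_low_snr(clean_image)
print(clean_image.shape, "->", low_snr_image.shape)
```

Lowering `photon_scale` or raising `read_noise_std` makes the simulated acquisition noisier.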

Since the image pairs were synthetically created in this example, they are already perfectly aligned. Note that when working with real paired acquisitions, the low and high SNR images are usually not aligned pixel-perfectly, so they typically need to be co-registered before training a CARE model.
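
To give an idea of what co-registration can look like, here is a minimal sketch based on phase correlation. It assumes the misalignment is a pure translation by whole pixels and treats the image as periodic; real acquisitions may need sub-pixel, rotational, or non-rigid registration. The function name `estimate_shift` is hypothetical and not part of this exercise's code.

```python
import numpy as np


def estimate_shift(reference: np.ndarray, moving: np.ndarray) -> tuple:
    """Estimate the integer shift that aligns `moving` to `reference`.

    The returned shift can be passed to np.roll to move `moving` back
    onto `reference` (translation only, circular boundary assumption).
    """
    # Normalized cross-power spectrum of the two images
    cross_power = np.fft.fft2(reference) * np.conj(np.fft.fft2(moving))
    cross_power /= np.abs(cross_power) + 1e-12

    # Its inverse transform peaks at the displacement between the images
    correlation = np.fft.ifft2(cross_power).real
    peak = np.unravel_index(np.argmax(correlation), correlation.shape)

    # Peaks past the midpoint correspond to negative shifts (wrap-around)
    return tuple(
        int(p) if p <= size // 2 else int(p) - size
        for p, size in zip(peak, correlation.shape)
    )


# Synthetic example: displace a random image by a known amount
rng = np.random.default_rng(0)
reference = rng.random((64, 64))
moving = np.roll(reference, shift=(3, -5), axis=(0, 1))

shift = estimate_shift(reference, moving)
aligned = np.roll(moving, shift=shift, axis=(0, 1))
print("Estimated corrective shift:", shift)
```

The estimated shift is the correction to apply to the displaced image, so it has the opposite sign of the displacement that was introduced.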

To train a denoising network, we will use the same UNet model we used in the semantic segmentation exercise.


<div class="alert alert-block alert-success"><h3>Download the data</h3>
</div>

For downloading the data, we will use the careamics-portfolio package. The package provides a collection of microscopy datasets and convenience functions for downloading the data.

In [None]:
# Download the data
portfolio = PortfolioManager()
print(portfolio.denoising)

root_path = Path("./data")
files = portfolio.denoising.CARE_U2OS.download(root_path)
print(f"List of downloaded files: {files}")

<div class="alert alert-block alert-success"><h3>Split the dataset into training and validation</h3>
</div>


In [None]:
# Define the paths
root_path = Path("./data/denoising-CARE_U2OS.unzip/data/U2OS")
assert root_path.exists(), f"Path {root_path} does not exist"

train_images_path = root_path / "train" / "low"
train_targets_path = root_path / "train" / "GT"
test_image_path = root_path / "test" / "low"
test_target_path = root_path / "test" / "GT"


image_files = list(Path(train_images_path).rglob("*.tif"))
target_files = list(Path(train_targets_path).rglob("*.tif"))
assert len(image_files) == len(
    target_files
), "Number of images and targets do not match"

print(f"Total size of train dataset: {len(image_files)}")

# Split the train data into train and validation
seed = 42
train_files_percentage = 0.8
np.random.seed(seed)
shuffled_indices = np.random.permutation(len(image_files))
image_files = np.array(image_files)[shuffled_indices]
target_files = np.array(target_files)[shuffled_indices]
assert all(
    [i.name == j.name for i, j in zip(image_files, target_files)]
), "Files do not match"

train_image_files = image_files[: int(train_files_percentage * len(image_files))]
train_target_files = target_files[: int(train_files_percentage * len(target_files))]
val_image_files = image_files[int(train_files_percentage * len(image_files)) :]
val_target_files = target_files[int(train_files_percentage * len(target_files)) :]
assert all(
    [i.name == j.name for i, j in zip(train_image_files, train_target_files)]
), "Train files do not match"
assert all(
    [i.name == j.name for i, j in zip(val_image_files, val_target_files)]
), "Val files do not match"

print(f"Train dataset size: {len(train_image_files)}")
print(f"Validation dataset size: {len(val_image_files)}")

# Read the test files
test_image_files = list(test_image_path.rglob("*.tif"))
test_target_files = list(test_target_path.rglob("*.tif"))
print(f"Number of test files: {len(test_image_files)}")

<div class="alert alert-block alert-success"><h3>Task 1.0: Define patching function</h3>
</div>

In the majority of cases microscopy images are too large to be processed at once and need to be divided into smaller patches. We will define a function that takes an image and divides it into patches of a given size.

In [None]:
def generate_patches(
    image_array: np.ndarray,
    target_array: np.ndarray,
    patch_size: Union[List[int], Tuple[int, ...]],
) -> Tuple[np.ndarray, np.ndarray]:
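    """
    Generate random patches from paired image and target arrays.

    For each sample, the function calculates how many patches the image can be
    divided into and then extracts that many patches at random positions, using
    the same crop coordinates for the image and its target.

    Parameters
    ----------
    image_array : np.ndarray
        Input image array, with samples along the first axis.
    target_array : np.ndarray
        Target array, aligned with `image_array`.
    patch_size : Union[List[int], Tuple[int, ...]]
        Patch sizes in each dimension.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Stacked image patches and stacked target patches.
    """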
    # random generator
    rng = np.random.default_rng()
    image_patches = []
    target_patches = []

    # iterate over the samples in the input array
    for s in range(image_array.shape[0]):
        # calculate the number of patches
        sample = image_array[s]
        target_sample = target_array[s]
        n_patches = np.ceil(np.prod(sample.shape) / np.prod(patch_size)).astype(int)
        # iterate over the number of patches
        for _ in range(n_patches):
            # get crop coordinates
            crop_coords = [
                rng.integers(0, sample.shape[i] - patch_size[i], endpoint=True)
                for i in range(len(patch_size))
            ]

            # extract patch
            patch = (
                sample[
                    (
                        ...,  # type: ignore
                        *[  # type: ignore
                            slice(c, c + patch_size[i])
                            for i, c in enumerate(crop_coords)
                        ],
                    )
                ]
                .copy()
                .astype(np.float32)
            )

            # same for target
            target_patch = (
                target_sample[
                    (
                        ...,  # type: ignore
                        *[  # type: ignore
                            slice(c, c + patch_size[i])
                            for i, c in enumerate(crop_coords)
                        ],
                    )
                ]
                .copy()
                .astype(np.float32)
            )
            # collect the patch and the target patch
            image_patches.append(patch)
            target_patches.append(target_patch)

    return np.stack(image_patches), np.stack(target_patches)

<div class="alert alert-block alert-success"><h3>Create patches</h3>
</div>

To train the network, we will use patches of size 128x128.

In [None]:
# Load images and stack them into arrays
train_images_array = np.stack([tifffile.imread(str(f)) for f in train_image_files])
train_targets_array = np.stack([tifffile.imread(str(f)) for f in train_target_files])
val_images_array = np.stack([tifffile.imread(str(f)) for f in val_image_files])
val_targets_array = np.stack([tifffile.imread(str(f)) for f in val_target_files])

test_images_array = np.stack([tifffile.imread(str(f)) for f in test_image_files])
test_targets_array = np.stack([tifffile.imread(str(f)) for f in test_target_files])


print(f"Train images array shape: {train_images_array.shape}")
print(f"Validation images array shape: {val_images_array.shape}")
print(f"Test array shape: {test_images_array.shape}")

In [None]:
# Calculate the mean and std of the train dataset
mean = train_images_array.mean()
std = train_images_array.std()

In [None]:
# Create patches
patch_size = (128, 128)

train_images_patches, train_targets_patches = generate_patches(
    train_images_array, train_targets_array, patch_size
)
assert (
    train_images_patches.shape[0] == train_targets_patches.shape[0]
), "Number of patches do not match"

val_images_patches, val_targets_patches = generate_patches(
    val_images_array, val_targets_array, patch_size
)
assert (
    val_images_patches.shape[0] == val_targets_patches.shape[0]
), "Number of patches do not match"

print(f"Train images patches shape: {train_images_patches.shape}")
print(f"Validation images patches shape: {val_images_patches.shape}")

<div class="alert alert-block alert-success"><h3>Visualize training patches</h3>
</div>

In [None]:
fig, ax = plt.subplots(3, 2, figsize=(15, 15))
ax[0, 0].imshow(train_images_patches[0], cmap="magma")
ax[0, 0].set_title("Train image")
ax[0, 1].imshow(train_targets_patches[0], cmap="magma")
ax[0, 1].set_title("Train target")
ax[1, 0].imshow(train_images_patches[1], cmap="magma")
ax[1, 0].set_title("Train image")
ax[1, 1].imshow(train_targets_patches[1], cmap="magma")
ax[1, 1].set_title("Train target")
ax[2, 0].imshow(train_images_patches[2], cmap="magma")
ax[2, 0].set_title("Train image")
ax[2, 1].imshow(train_targets_patches[2], cmap="magma")
ax[2, 1].set_title("Train target")
plt.tight_layout()

<div class="alert alert-block alert-success"><h3>Task 1.2: Define a dataset class</h3>
</div>

In [None]:
# Define a Dataset
class CAREDataset(Dataset):
    def __init__(
        self, image_data: np.ndarray, target_data: np.ndarray, apply_augmentations=False
    ):
        self.image_data = image_data
        self.target_data = target_data
        self.patch_augment = apply_augmentations

    def __len__(self):
        # Your code here, define the total number of patches
        return self.image_data.shape[0]

    def __getitem__(self, index):
        # Your code here, return the patch and target patch, 
        # apply augmentations with a condition. Hint: use the augment_batch function
        # apply the normalize function to the patch and target patch
        # return the patch and target patch. Hint: check the dimensions and the datatype

        # get patch
        patch = self.image_data[index]

        # get target
        target = self.target_data[index]

        # Apply transforms
        if self.patch_augment:
            patch, target = augment_batch(patch=patch, target=target)

        # Normalize the patch
        patch = normalize(patch, mean, std)
        target = normalize(target, mean, std)

        return patch[np.newaxis].astype(np.float32), target[np.newaxis].astype(
            np.float32
        )

In [None]:
# Instantiate the dataset and create a DataLoader

train_dataset = CAREDataset(
    image_data=train_images_patches, target_data=train_targets_patches
)
val_dataset = CAREDataset(
    image_data=val_images_patches, target_data=val_targets_patches
)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)

<hr style="height:2px;"><div class="alert alert-block alert-success"><h1>Checkpoint 2: Train a CARE model</h1>
</div>

The image restoration task is very similar to the semantic segmentation task from the previous exercise, so we can use the same UNet model here.

![image](nb_data/carenet.png)

In [None]:
# Load the model
model = unet_module.UNet(depth=3, in_channels=1, out_channels=1)

<div class="alert alert-block alert-success"><h3>Task 2.0: Define the loss function and the optimizer</h3>
</div>

The CARE algorithm uses mean squared error (MSE) as the loss function. We can use the `MSELoss` class from PyTorch (`torch.nn.MSELoss`).
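
To make the quantity concrete, the small sketch below computes the mean squared error by hand with NumPy on made-up values. With its default `reduction="mean"`, `nn.MSELoss` averages the squared differences over all elements in the same way.

```python
import numpy as np

# Made-up prediction and target values, for illustration only
prediction = np.array([[1.0, 2.0], [3.0, 4.0]])
target = np.array([[1.0, 0.0], [3.0, 8.0]])

# Squared difference for every pixel: [[0, 4], [0, 16]]
squared_errors = (prediction - target) ** 2

# Mean over all pixels: (0 + 4 + 0 + 16) / 4
mse = squared_errors.mean()
print(mse)  # 5.0
```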


In [None]:
loss = nn.MSELoss()  # Your code here, define the loss function, hint: think about the suitable loss function
optimizer = optim.Adam(
    model.parameters(), lr=1e-4
)  # Your code here, define the optimizer

<div class="alert alert-block alert-success"><h3>Task 2.1: Train a model</h3>
</div>

Here we will train a CARE model using the classes and functions you defined in the previous tasks.
We use the same training loop as in the semantic segmentation exercise.


In [None]:
# Training loop

n_epochs = 3
device = "cuda" if cuda.is_available() else "cpu"
model.to(device)

train_losses = []
val_losses = []

for epoch in range(n_epochs):
    model.train()
    for i, (image_batch, target_batch) in enumerate(train_dataloader):
        batch = image_batch.to(device)
        target = target_batch.to(device)

        optimizer.zero_grad()
        output = model(batch)
        train_loss = loss(output, target)
        train_loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print(f"Epoch: {epoch}, Batch: {i}, Loss: {train_loss.item()}")

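    model.eval()
    with no_grad():
        # Accumulate the loss over all validation batches
        val_loss_sum = 0.0
        for batch, target in val_dataloader:
            batch = batch.to(device)
            target = target.to(device)

            output = model(batch)
            val_loss_sum += loss(output, target).item()

        # Average over batches (approximate if the last batch is smaller)
        val_loss = val_loss_sum / len(val_dataloader)
        print(f"Validation loss: {val_loss}")

    # Save the losses for plotting (last training batch, mean validation loss)
    train_losses.append(train_loss.item())
    val_losses.append(val_loss)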


<div class="alert alert-block alert-success"><h3>Plot the loss</h3>
</div>

In [None]:
# Plot training and validation losses
plt.figure(figsize=(10, 5))
plt.plot(train_losses)
plt.plot(val_losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(["Train loss", "Validation loss"])


<div class="alert alert-block alert-success"><h3>Task 2.2: Prediction on the test set</h3>
</div>

Real microscopy images are often too large to be processed in one step, so a tiling approach is used for prediction. We call it tiling rather than patching because the image has to be reconstructed from overlapping tiles to avoid artifacts at the tile boundaries, whereas during training we used randomly extracted patches.
For the sake of simplicity, we will predict on whole images here.
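
For reference, tiled prediction could look roughly like the sketch below. It is not used in this exercise. It averages predictions where tiles overlap, which is only one simple way to blend them; the function names, `tile_size`, `overlap`, and the `predict_fn` stand-in for the trained model are all assumptions made for the example.

```python
import numpy as np


def tile_starts(length: int, tile_size: int, stride: int) -> list:
    """Start coordinates of tiles that cover one axis of the given length."""
    if length <= tile_size:
        return [0]
    starts = list(range(0, length - tile_size, stride))
    starts.append(length - tile_size)  # last tile sits flush with the border
    return starts


def predict_tiled(image, predict_fn, tile_size=64, overlap=16):
    """Predict a 2D image tile by tile and average the overlapping regions."""
    stride = tile_size - overlap
    accumulated = np.zeros(image.shape, dtype=np.float64)
    counts = np.zeros(image.shape, dtype=np.float64)

    for y in tile_starts(image.shape[0], tile_size, stride):
        for x in tile_starts(image.shape[1], tile_size, stride):
            tile = image[y : y + tile_size, x : x + tile_size]
            # Sum the predictions and remember how often each pixel was covered
            accumulated[y : y + tile_size, x : x + tile_size] += predict_fn(tile)
            counts[y : y + tile_size, x : x + tile_size] += 1

    # Every pixel is covered at least once, so the division is safe
    return (accumulated / counts).astype(np.float32)


# Stand-in "model" that doubles the intensities, applied to a random image
large_image = np.random.default_rng(0).random((100, 150)).astype(np.float32)
restored = predict_tiled(large_image, predict_fn=lambda tile: tile * 2)
print(restored.shape)
```

With a real network, `predict_fn` would wrap normalization, the forward pass, and denormalization for a single tile.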

In [None]:
# Define the dataset for the test data

test_dataset = CAREDataset(
    image_data=test_images_array, target_data=test_targets_array
)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# Define the prediction loop

test_images = []
predictions = []

model.eval()
with no_grad():
    for i, (image_batch, target_batch) in enumerate(test_dataloader):
        image_batch = image_batch.to(device)
        target_batch = target_batch.to(device)
        output = model(image_batch)

        # Save the images and predictions for visualization
        test_images.append(denormalize(image_batch.cpu().numpy(), mean, std))
        predictions.append(denormalize(output.cpu().numpy(), mean, std))

<div class="alert alert-block alert-success"><h3>Visualize the predictions</h3>
</div>

In [None]:
fig, ax = plt.subplots(3, 2, figsize=(15, 15))
ax[0, 0].imshow(test_images[0][0].squeeze(), cmap="magma")
ax[0, 0].set_title("Test image")
ax[0, 1].imshow(predictions[0][0].squeeze(), cmap="magma")
ax[0, 1].set_title("Prediction")
ax[1, 0].imshow(test_images[1][0].squeeze(), cmap="magma")
ax[1, 0].set_title("Test image")
ax[1, 1].imshow(predictions[1][0].squeeze(), cmap="magma")
ax[1, 1].set_title("Prediction")
ax[2, 0].imshow(test_images[2][0].squeeze(), cmap="magma")
ax[2, 0].set_title("Test image")
ax[2, 1].imshow(predictions[2][0].squeeze(), cmap="magma")
ax[2, 1].set_title("Prediction")
plt.tight_layout()