# StainNET
All the code in this notebook is based on the **StainNET package** and the related [paper](https://www.frontiersin.org/journals/medicine/articles/10.3389/fmed.2021.746307/full).\
StainNET is derived from the StainGAN algorithm: it is trained by distillation to reproduce StainGAN's output with a much smaller network. Both algorithms are applied here.

StainNET repository and related tutorial notebook:\
https://github.com/khtao/StainNet\
https://github.com/khtao/StainNet/blob/master/demo.ipynb

StainGAN:\
https://github.com/xtarx/StainGAN


Once again, the author suggests installing the SPAMS dependency with conda, following this [link](https://anaconda.org/conda-forge/python-spams). The page lists the commands below. They are alternatives, one per channel label, so running the first one is normally enough:
```bash
conda install conda-forge::python-spams
conda install conda-forge/label/broken::python-spams
conda install conda-forge/label/cf201901::python-spams
conda install conda-forge/label/cf202003::python-spams
conda install conda-forge/label/gcc7::python-spams
```

Moreover, the package's GitHub repository has to be cloned into the `../data/packages/` folder so that its model definitions (`models.py`) and pre-trained checkpoints can be imported:
```bash
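# run from the notebook folder (2_image_normalisation); clones into ../data/packages/StainNet
git clone https://github.com/khtao/StainNet.git ../data/packages/StainNet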
```
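To check that the clone landed where the notebook expects it, a small sketch like the following can be run before importing anything. The expected file list is an assumption taken from the import and checkpoint paths used in this notebook, and `missing_stainnet_files` is a hypothetical helper, not part of the package.

```python
from pathlib import Path

# Files this notebook relies on, relative to the cloned repository root
# (assumption: taken from the imports and checkpoint paths used in this notebook).
EXPECTED_FILES = [
    "models.py",
    "checkpoints/camelyon16_dataset/latest_net_G_A.pth",
    "checkpoints/camelyon16_dataset/latest_net_G_B.pth",
    "checkpoints/camelyon16_dataset/StainNet-Public-centerUni_layer3_ch32.pth",
]


def missing_stainnet_files(repo_root):
    """Return the expected files that are NOT present under repo_root (hypothetical helper)."""
    root = Path(repo_root)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]


# Usage: an empty list means the clone is complete.
# print(missing_stainnet_files("../data/packages/StainNet"))
```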

---
# 0. - Imports and setting paths

In this case, we have to set the working directory first, because the NN models are imported through a relative path from the `models.py` file of the previously cloned repository.
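A relative entry in `sys.path` is resolved against whatever the working directory is at import time, which is why `os.chdir` has to come first. A sketch of a more robust variant, assuming the repository location is still known relative to the current directory when the cell runs: resolve it once to an absolute path. `add_to_import_path` is a hypothetical helper.

```python
import sys
from pathlib import Path


def add_to_import_path(directory):
    """Prepend the absolute form of `directory` to sys.path, once (hypothetical helper).

    Once resolved, later `from models import ...` statements no longer depend on
    the working directory. Inserting at position 0 lets the repository's generic
    `models.py` take precedence over any other module named `models`.
    """
    abs_dir = str(Path(directory).expanduser().resolve())
    if abs_dir not in sys.path:
        sys.path.insert(0, abs_dir)
    return abs_dir


# Usage (instead of sys.path.append with a relative path):
# add_to_import_path("../data/packages/StainNet/")
```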

In [22]:
import numpy as np
import subprocess
import torch
import sys
import datetime
from PIL import Image 
import matplotlib.pyplot as plt

%matplotlib inline

In [23]:
import os
os.chdir("/disk2/user/gabgam/work/gigi_env/the_project/2_image_normalisation/")
print(os.getcwd())

/disk2/work/gabgam/gigi_env/the_project/2_image_normalisation


In [24]:
sys.path.append('../data/packages/StainNet/')

In [25]:
from models import StainNet, ResnetGenerator

In [26]:
# make only the first GPU visible; this must be set before CUDA is initialised
# (i.e. before the first .cuda() call), otherwise it has no effect
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 0 = first GPU, 1 = second GPU

In [27]:
# #INPUT_FOLDER = "../1_tiling/output/satac_C1/tiling_output/v3_allspots/tiles_100/"  # Replace with the path to your folder containing images
# INPUT_FOLDER = "../1_tiling/output/satac_C1/tiling_output/v3_allspots/tiles_68/"  # Replace with the path to your folder containing images

In [28]:
#INPUT_FOLDER = "../1_tiling/output/visium_2022_FF_WG_10X/tiling_output/img_not_changed_allspots/tiles_100"  # Replace with the path to your folder containing images
# INPUT_FOLDER = "../1_tiling/output/visium_2022_FF_WG_10X/tiling_output/img_not_changed_allspots/tiles_68"  # Replace with the path to your folder containing images


In [29]:
# INPUT_FOLDER = "../1_tiling/output/visium_FFPE_dcis_idc_10X/tiling_output/img_not_changed_allspots/tiles_100"  # Replace with the path to your folder containing images
INPUT_FOLDER = "../1_tiling/output/visium_FFPE_dcis_idc_10X/tiling_output/img_not_changed_allspots/tiles_68"  # Replace with the path to your folder containing images

---
# 1. - Normalisation

In this case, normalisation does not use a target image.\
The network has to be trained first and is then applied to our images, so the "target" is effectively the set of images used to train it. I will not train the model now, as I do not have time for it at the moment (maybe in the future), and will instead use the pre-trained models provided by the [package](https://github.com/khtao/StainNet/tree/master/checkpoints).\
I would like to highlight that the models used here come from the `camelyon16_dataset` checkpoints, i.e. they were trained on the very large [Camelyon16 imaging dataset](https://camelyon16.grand-challenge.org/), derived from sentinel lymph nodes of breast cancer patients from two different medical centers.

In [30]:
pretrained_models_path = '../data/packages/StainNet/checkpoints/camelyon16_dataset/'
print(os.listdir(pretrained_models_path))

['StainNet-Public-centerUni_layer3_ch32.pth', 'latest_net_G_A.pth', 'latest_net_G_B.pth']


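We define two helper functions: `norm` converts a PIL RGB image into a `(1, 3, H, W)` tensor scaled to [-1, 1], the input format of the networks, and `un_norm` converts a network output back into an `(H, W, 3)` uint8 array.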

In [31]:
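def norm(image):
    """PIL RGB image -> float tensor of shape (1, 3, H, W) scaled to [-1, 1]."""
    image = np.array(image).astype(np.float32)
    image = image.transpose((2, 0, 1))       # HWC -> CHW
    image = ((image / 255) - 0.5) / 0.5      # [0, 255] -> [-1, 1]
    image = image[np.newaxis, ...]           # add batch dimension
    image = torch.from_numpy(image)
    return image

def un_norm(image):
    """Network output (1, 3, H, W) in [-1, 1] -> uint8 array of shape (H, W, 3)."""
    image = image.cpu().detach().numpy()[0]  # drop batch dimension
    image = (image * 0.5 + 0.5) * 255        # [-1, 1] -> [0, 255]
    # clip before casting: StainNet's last layer is not bounded, and out-of-range
    # floats would wrap around instead of saturating when cast to uint8
    image = np.clip(image, 0, 255).astype(np.uint8).transpose((1, 2, 0))  # CHW -> HWC
    return image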
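The two helpers are an affine rescaling between [0, 255] and [-1, 1] plus a HWC/CHW transpose. A NumPy-only sketch of the same arithmetic (hypothetical names `to_model_range` and `from_model_range`, no torch needed) makes the mapping easy to test; unlike `un_norm` it rounds instead of truncating, and it clips before the uint8 cast.

```python
import numpy as np


def to_model_range(rgb_uint8):
    """(H, W, 3) uint8 -> (1, 3, H, W) float32 in [-1, 1], mirroring `norm`."""
    x = rgb_uint8.astype(np.float32).transpose((2, 0, 1))
    x = ((x / 255) - 0.5) / 0.5
    return x[np.newaxis, ...]


def from_model_range(batch):
    """(1, 3, H, W) float in [-1, 1] -> (H, W, 3) uint8, mirroring `un_norm`.

    Rounds to the nearest integer and clips to [0, 255] so that values
    slightly outside [-1, 1] saturate instead of wrapping around.
    """
    x = (batch[0] * 0.5 + 0.5) * 255
    x = np.clip(np.rint(x), 0, 255)
    return x.astype(np.uint8).transpose((1, 2, 0))
```

With rounding, a round trip through both functions returns the original uint8 image exactly.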

## 1.1 - StainGAN normalisation

Setting up the paths.

In [32]:
# setting the paths
normalisation_method = 'stainGAN'

tiles_info = INPUT_FOLDER.split('/')

# Remember: no target needed
output_folder = f"./output/{tiles_info[3]}/{tiles_info[5]}/{tiles_info[6]}/{normalisation_method}"
print(output_folder)

# create the output folder
os.makedirs(output_folder, exist_ok=True)

./output/visium_FFPE_dcis_idc_10X/img_not_changed_allspots/tiles_68/stainGAN
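The output folder is built from fixed positions (3, 5 and 6) of `INPUT_FOLDER.split('/')`, so it silently depends on how many directories precede the dataset name. A sketch of reading the same three components from the end of the path instead; `output_components` is hypothetical and assumes the layout `<...>/<dataset>/tiling_output/<tiling_variant>/<tile_size>` seen in the paths above.

```python
from pathlib import PurePosixPath


def output_components(input_folder):
    """Return (dataset, tiling_variant, tile_size) from a tiling input path.

    Hypothetical helper. Components are read from the END of the path, so an
    absolute path, extra leading directories or a trailing slash do not shift them.
    """
    parts = PurePosixPath(input_folder).parts
    return parts[-4], parts[-2], parts[-1]


example = "../1_tiling/output/visium_FFPE_dcis_idc_10X/tiling_output/img_not_changed_allspots/tiles_68"
print(output_components(example))
# ('visium_FFPE_dcis_idc_10X', 'img_not_changed_allspots', 'tiles_68')
```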


Loading the pre-trained NN.

### 1.1.1 - Model A: pre-trained for StainGAN

Let's load the first model: `latest_net_G_A.pth`.

In [33]:
# load pretrained StainGAN
model_GAN = ResnetGenerator(3, 3, ngf=64, norm_layer=torch.nn.InstanceNorm2d, n_blocks=9).cuda()
model_GAN.load_state_dict(torch.load('../data/packages/StainNet/checkpoints/camelyon16_dataset/latest_net_G_A.pth', weights_only=True))


<All keys matched successfully>

The actual normalisation loop over all tiles:

In [34]:
starttime = datetime.datetime.now()

# Output folder for model A
output_folder_model_A = os.path.join(output_folder, "model_A")
os.makedirs(output_folder_model_A, exist_ok=True)

# ---------------------------------------------------------------------------------
# File to log images that fail normalisation
normalisation_fails_file = f"{output_folder_model_A}/0_failed_to_normalise.txt" # 0 just for having the file listed as first

model_GAN.eval()  # inference mode; only needs to be set once

with open(normalisation_fails_file, "w") as file:
    file.write("The following are the tiles not normalised:\n")
    
    # Process each image in the input folder
    for filename in os.listdir(INPUT_FOLDER):
        image_path = os.path.join(INPUT_FOLDER, filename)

        try:
            # Load and preprocess the image (inside try: unreadable files are logged, not fatal)
            img = Image.open(image_path).convert("RGB")

            # Perform normalisation
            with torch.no_grad():
                img_gan = model_GAN(norm(img).cuda())
                img_normed_array = un_norm(img_gan)
                
            # Convert the normalised image back to PIL format
            img_normed_pil = Image.fromarray(img_normed_array)
            
            # Ensure output matches input size
            if img_normed_pil.size != img.size:
                print(f"Had to perform resizing step! Original size = {img.size}, Size after GAN normalisation = {img_normed_pil.size}")
                img_normed_pil = img_normed_pil.resize(img.size, Image.Resampling.LANCZOS)

            # Save the normalised image
            output_path = os.path.join(output_folder_model_A, f"{os.path.splitext(filename)[0]}_{normalisation_method}_modelA.jpg") # or .png (but it's way bigger)
            img_normed_pil.save(output_path)
            
        except Exception as e:
            file.write(f"{filename}\n")
            #print(f"Error processing {filename}: {e}")


endtime = datetime.datetime.now()
difference = endtime - starttime

# delete the time log file of a previous run, if any
for filename in os.listdir(output_folder_model_A):
    if filename.startswith("0_started_"):
        file_path = os.path.join(output_folder_model_A, filename)
        if os.path.isfile(file_path):  # Check if it is a file
            os.remove(file_path)      # Delete the file
            print(f"Deleted: {file_path}")

# saving the start and finish time in the file's name for simplicity in the reading.
with open(f"{output_folder_model_A}/0_started_at_{starttime}_finished_at_{endtime}.txt", "w") as file:
    file.write(f"The run started at {starttime} and finished at {endtime}.")

print(f"Finished! The normalisation took {difference.total_seconds():.1f} seconds!")


Finished! The normalisation took 28.3 seconds!
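The normalisation loops in this notebook differ only in the model call and the output name. A sketch of factoring the shared part into one helper; `process_folder` is hypothetical. It catches any error raised for a file, including while opening it, skips sub-directories, and records the error type and message next to the file name.

```python
import os


def process_folder(input_folder, process_one, log_path):
    """Apply `process_one(path)` to every file in `input_folder` (hypothetical helper).

    Exceptions are caught per file and logged to `log_path`, so one bad tile
    does not stop the whole run. Returns (number_of_successes, failed_names).
    """
    n_ok, failed = 0, []
    with open(log_path, "w") as log:
        log.write("The following are the tiles not normalised:\n")
        for filename in sorted(os.listdir(input_folder)):
            path = os.path.join(input_folder, filename)
            if not os.path.isfile(path):
                continue  # skip sub-directories
            try:
                process_one(path)
                n_ok += 1
            except Exception as e:
                failed.append(filename)
                log.write(f"{filename}\t{type(e).__name__}: {e}\n")
    return n_ok, failed


# Usage sketch for model A (assumes Image, torch, norm, un_norm, model_GAN and the
# folders/variables defined in the cells above):
# def normalise_with_gan_A(path):
#     img = Image.open(path).convert("RGB")
#     with torch.no_grad():
#         out = Image.fromarray(un_norm(model_GAN(norm(img).cuda())))
#     stem = os.path.splitext(os.path.basename(path))[0]
#     out.save(os.path.join(output_folder_model_A, f"{stem}_{normalisation_method}_modelA.jpg"))
#
# model_GAN.eval()
# process_folder(INPUT_FOLDER, normalise_with_gan_A, normalisation_fails_file)
```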


### 1.1.2 - Model B: pre-trained for StainGAN
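Let's load the second model: `latest_net_G_B.pth`. StainGAN is based on CycleGAN, whose two generators `G_A` and `G_B` translate in opposite directions between the two stain domains, so model B is presumably the reverse mapping of model A.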

In [35]:
# load pretrained StainGAN
model_GAN = ResnetGenerator(3, 3, ngf=64, norm_layer=torch.nn.InstanceNorm2d, n_blocks=9).cuda()
model_GAN.load_state_dict(torch.load('../data/packages/StainNet/checkpoints/camelyon16_dataset/latest_net_G_B.pth', weights_only=True))


<All keys matched successfully>

The actual normalisation loop over all tiles:

In [36]:
starttime = datetime.datetime.now()

# Output folder for model B
output_folder_model_B = os.path.join(output_folder, "model_B")
os.makedirs(output_folder_model_B, exist_ok=True)

# ---------------------------------------------------------------------------------
# File to log images that fail normalisation
normalisation_fails_file = f"{output_folder_model_B}/0_failed_to_normalise.txt" # 0 just for having the file listed as first

model_GAN.eval()  # inference mode; only needs to be set once

with open(normalisation_fails_file, "w") as file:
    file.write("The following are the tiles not normalised:\n")
    
    # Process each image in the input folder
    for filename in os.listdir(INPUT_FOLDER):
        image_path = os.path.join(INPUT_FOLDER, filename)

        try:
            # Load and preprocess the image (inside try: unreadable files are logged, not fatal)
            img = Image.open(image_path).convert("RGB")

            # Perform normalisation
            with torch.no_grad():
                img_gan = model_GAN(norm(img).cuda())
                img_normed_array = un_norm(img_gan)

            # Convert the normalised image back to PIL format
            img_normed_pil = Image.fromarray(img_normed_array)
            
            # Ensure output matches input size
            if img_normed_pil.size != img.size:
                print(f"Had to perform resizing step! Original size = {img.size}, Size after GAN normalisation = {img_normed_pil.size}")
                img_normed_pil = img_normed_pil.resize(img.size, Image.Resampling.LANCZOS)
            
            # Save the normalised image
            output_path = os.path.join(output_folder_model_B, f"{os.path.splitext(filename)[0]}_{normalisation_method}_modelB.jpg") # or .png (but it's way bigger)
            img_normed_pil.save(output_path)
            
        except Exception as e:
            file.write(f"{filename}\n")
            #print(f"Error processing {filename}: {e}")


endtime = datetime.datetime.now()
difference = endtime - starttime

# delete the time log file of a previous run, if any
for filename in os.listdir(output_folder_model_B):
    if filename.startswith("0_started_"):
        file_path = os.path.join(output_folder_model_B, filename)
        if os.path.isfile(file_path):  # Check if it is a file
            os.remove(file_path)      # Delete the file
            print(f"Deleted: {file_path}")

# saving the start and finish time in the file's name for simplicity in the reading.
with open(f"{output_folder_model_B}/0_started_at_{starttime}_finished_at_{endtime}.txt", "w") as file:
    file.write(f"The run started at {starttime} and finished at {endtime}.")

print(f"Finished! The normalisation took {difference.total_seconds():.1f} seconds!")

Finished! The normalisation took 19.1 seconds!
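Each run also leaves a `0_started_at_..._finished_at_...` marker file, written by near-identical code in every loop cell. A sketch of a shared helper; `write_timing_log` is hypothetical. It takes the end time once, so the file name and its content agree, and formats the times without `:` so the name is also valid on Windows.

```python
import datetime
import glob
import os


def write_timing_log(folder, start, end):
    """Replace any previous '0_started_*' file in `folder` with one for this run.

    Hypothetical helper. Returns the path of the new file.
    """
    for old in glob.glob(os.path.join(folder, "0_started_*")):
        if os.path.isfile(old):
            os.remove(old)
    fmt = "%Y-%m-%d_%H-%M-%S"
    path = os.path.join(folder, f"0_started_at_{start:{fmt}}_finished_at_{end:{fmt}}.txt")
    with open(path, "w") as f:
        f.write(f"The run started at {start} and finished at {end} "
                f"({(end - start).total_seconds():.1f} s).")
    return path


# Usage: write_timing_log(output_folder_model_B, starttime, datetime.datetime.now())
```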


## 1.2 - StainNET normalisation

Let's set the correct paths and load the actual StainNET pre-trained model, `StainNet-Public-centerUni_layer3_ch32.pth`.

In [37]:
# setting the paths
normalisation_method = 'stainNET'

tiles_info = INPUT_FOLDER.split('/')

# Remember: no target needed
output_folder_stainnet = f"./output/{tiles_info[3]}/{tiles_info[5]}/{tiles_info[6]}/{normalisation_method}"
print(output_folder_stainnet)

# create the output folder
os.makedirs(output_folder_stainnet, exist_ok=True)

./output/visium_FFPE_dcis_idc_10X/img_not_changed_allspots/tiles_68/stainNET


In [38]:
# load pretrained StainNet
model_Net = StainNet().cuda()
model_Net.load_state_dict(torch.load("../data/packages/StainNet/checkpoints/camelyon16_dataset/StainNet-Public-centerUni_layer3_ch32.pth", weights_only=True))

<All keys matched successfully>
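Unlike the ResNet-based StainGAN generator, StainNet is a tiny fully convolutional network. As far as I can tell from the repository's `models.py`, the default `StainNet()` stacks three convolutions with 32 hidden channels, ReLU in between and kernel size 1, which matches the `layer3_ch32` checkpoint name. With 1×1 kernels each output pixel depends only on the RGB value of the same input pixel. A NumPy sketch of that idea with random, untrained weights (the weights and the function name `pixelwise_mlp` are illustrative only, not the checkpoint):

```python
import numpy as np


def pixelwise_mlp(batch, weights, biases):
    """Apply a stack of 1x1 convolutions to a (1, C, H, W) array.

    A 1x1 convolution is a matrix multiply over the channel axis at every pixel,
    so the whole network is an MLP applied independently to each pixel's colour.
    ReLU is applied between layers but not after the last one.
    """
    x = batch[0]                                          # (C, H, W)
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = np.einsum("oc,chw->ohw", w, x) + b[:, None, None]
        if i < len(weights) - 1:
            x = np.maximum(x, 0)                          # ReLU
    return x[np.newaxis, ...]


# Random weights with the 3 -> 32 -> 32 -> 3 shape of a 3-layer, 32-channel network.
rng = np.random.default_rng(0)
shapes = [(32, 3), (32, 32), (3, 32)]
weights = [rng.normal(size=s) * 0.1 for s in shapes]
biases = [np.zeros(s[0]) for s in shapes]

img = rng.uniform(-1, 1, size=(1, 3, 8, 8))
out = pixelwise_mlp(img, weights, biases)
```

Because no spatial context is involved, the output always has the input's size (so no resizing step is needed) and the per-pixel cost is very small, which is consistent with the StainNet loop in this notebook being much faster than the StainGAN ones.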

The actual normalisation loop over all tiles:

In [39]:
starttime = datetime.datetime.now()

# ---------------------------------------------------------------------------------
# File to log images that fail normalisation
normalisation_fails_file = f"{output_folder_stainnet}/0_failed_to_normalise.txt" # 0 just for having the file listed as first

model_Net.eval()  # inference mode; only needs to be set once

with open(normalisation_fails_file, "w") as file:
    file.write("The following are the tiles not normalised:\n")
    
    # Process each image in the input folder
    for filename in os.listdir(INPUT_FOLDER):
        image_path = os.path.join(INPUT_FOLDER, filename)

        try:
            # Load and preprocess the image (inside try: unreadable files are logged, not fatal)
            img = Image.open(image_path).convert("RGB")

            # Perform normalisation
            with torch.no_grad():
                img_net = model_Net(norm(img).cuda())
                img_normed_array = un_norm(img_net)

            # Convert the normalised image back to PIL format
            img_normed_pil = Image.fromarray(img_normed_array)

            # Save the normalised image
            output_path = os.path.join(output_folder_stainnet, f"{os.path.splitext(filename)[0]}_{normalisation_method}.jpg") # or .png (but it's way bigger)
            img_normed_pil.save(output_path)
            
        except Exception as e:
            file.write(f"{filename}\n")
            #print(f"Error processing {filename}: {e}")


endtime = datetime.datetime.now()
difference = endtime - starttime

# delete the time log file of a previous run, if any
for filename in os.listdir(output_folder_stainnet):
    if filename.startswith("0_started_"):
        file_path = os.path.join(output_folder_stainnet, filename)
        if os.path.isfile(file_path):  # Check if it is a file
            os.remove(file_path)      # Delete the file
            print(f"Deleted: {file_path}")

# saving the start and finish time in the file's name for simplicity in the reading.
with open(f"{output_folder_stainnet}/0_started_at_{starttime}_finished_at_{endtime}.txt", "w") as file:
    file.write(f"The run started at {starttime} and finished at {endtime}.")

print(f"Finished! The normalisation took {difference.total_seconds():.1f} seconds!")


Finished! The normalisation took 5.8 seconds!


---
# Final - Saving the environment requirements

In [40]:
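# Save package versions to a .txt file
# (call pip through the kernel's own interpreter, so this environment's packages are listed)
with open("requirements_for_stainnet_env.txt", "w") as f:
    subprocess.run([sys.executable, "-m", "pip", "freeze"], stdout=f, check=True)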
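Inside a notebook, the `pip` found on the shell `PATH` is not guaranteed to belong to the kernel's environment. A sketch of a subprocess-free alternative using the standard library's `importlib.metadata` (Python 3.8+), assuming plain `name==version` pins are enough. Unlike `pip freeze`, it does not give editable or direct-URL installs any special treatment. `frozen_requirements` is a hypothetical helper.

```python
from importlib.metadata import distributions


def frozen_requirements():
    """Return sorted 'name==version' strings for the running interpreter's packages."""
    pins = set()
    for dist in distributions():
        name = dist.metadata.get("Name")
        if name:  # skip entries with broken metadata
            pins.add(f"{name}=={dist.version}")
    return sorted(pins, key=str.lower)


# Usage:
# with open("requirements_for_stainnet_env.txt", "w") as f:
#     f.write("\n".join(frozen_requirements()) + "\n")
```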