## Clone Project

In [None]:
# Clone the repo
# NOTE(review): <DRIVE> is presumably a placeholder for a GitHub access token embedded in the URL —
# substitute it at run time and never save/commit the notebook with a real token in it.
!git clone https://<DRIVE>@github.com/java-master007/Adversarial-Representation-Learning-for-Medical-Imaging.git

Cloning into 'Adversarial-Representation-Learning-for-Medical-Imaging'...
remote: Enumerating objects: 499, done.[K
remote: Counting objects: 100% (499/499), done.[K
remote: Compressing objects: 100% (353/353), done.[K
remote: Total 499 (delta 286), reused 343 (delta 134), pack-reused 0[K
Receiving objects: 100% (499/499), 49.62 MiB | 12.27 MiB/s, done.
Resolving deltas: 100% (286/286), done.


In [None]:
# Change to the correct directory
# %cd changes the kernel's working directory, so relative paths used below
# (e.g. requirements.txt) resolve inside the cloned repository.
%cd Adversarial-Representation-Learning-for-Medical-Imaging/

/content/Adversarial-Representation-Learning-for-Medical-Imaging


In [None]:
# Install requirements
# On Colab the runtime must be restarted after this cell finishes.
# NOTE(review): `%pip` targets the kernel's environment more reliably than `!pip`.
! pip install -r requirements.txt

## Image Preparation

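The input images must satisfy:
- `malign.png` (image containing the mass) is 3-channel
- `normal.png` (background image the mass is pasted onto) is 3-channel
- `malign_mask.png` (mask of the mass) is one-channel and **binary**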

In [None]:
# Import required libraries
import os
import subprocess

import cv2
from skimage import io as img
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.image as mpimg
import torch
import torchvision as tv
from PIL import Image, ImageDraw, ImageFilter

In [None]:
def get_image_laterality(image):
    """Compare the total intensity of the first and last columns of ``image``.

    Returns a pair of Python bools: ``(True, False)`` when the last (right-most)
    column sums to more than the first (left-most) column, otherwise
    ``(False, True)`` (ties included). ``is_collage_possible`` reads the first
    flag as ``R``.
    """
    first_column_total = np.sum(image[:, 0])
    last_column_total = np.sum(image[:, -1])
    right_heavier = bool(last_column_total > first_column_total)
    return right_heavier, not right_heavier

In [None]:
def get_measures(image):
    """Bounding box of the non-zero entries of ``image``.

    Returns ``(top, right, bottom, left)``: the min/max row index and the
    max/min column index of all non-zero entries (inclusive indices).
    Raises ``ValueError`` (from numpy's min/max) if nothing is non-zero.
    """
    nonzero_idx = np.nonzero(image)
    rows, cols = nonzero_idx[0], nonzero_idx[1]
    return rows.min(), cols.max(), rows.max(), cols.min()

In [None]:
def get_start_coordinate(image):
    """Return ``(x, y)`` of the centre of the bottom-most non-zero row of ``image``.

    ``y`` is the largest row index holding a non-zero entry; ``x`` is the mean
    (truncated to int) of the column indices that are non-zero in that row.
    For multi-channel images a column counts as non-zero if any of its
    channels is non-zero, so channel indices never leak into ``x``.
    Raises ``ValueError`` if ``image`` has no non-zero entries.
    """
    image = np.asarray(image)
    bottom = np.nonzero(image)[0].max()

    bottom_row = image[bottom]
    if bottom_row.ndim > 1:
        # Collapse the channel axis: one flag per column.
        bottom_row = bottom_row.reshape(bottom_row.shape[0], -1).any(axis=1)

    columns = np.nonzero(bottom_row)[0]
    x_bottom = int(np.mean(columns))
    return x_bottom, bottom

In [None]:
def get_correct_value(number):
    """Binarise a single pixel value: 0 stays 0, anything else becomes 1."""
    if number != 0:
        return 1
    return 0

In [None]:
def image_to_binary(image, pth):
    """Binarise ``image`` (0 stays 0, any non-zero value becomes 1) and overwrite ``pth`` with it.

    The file is written with ``plt.imsave`` using the gray colormap, so the PNG
    on disk is a colour-mapped rendering of the 0/1 array, not the raw values.
    Any existing file at ``pth`` is replaced.

    Returns:
        np.ndarray of dtype uint8 holding the 0/1 values.
    """
    # Vectorised equivalent of applying get_correct_value to every pixel.
    b_image = (np.asarray(image) != 0).astype(np.uint8)

    # plt.imsave overwrites an existing file, so the file is not deleted first;
    # deleting it first would lose the original if saving failed.
    plt.imsave(pth, b_image, cmap=cm.gray)
    return b_image

In [None]:
def does_collage_mask(width, height, malign, normal):
    """Return True if pasting the cropped mass at (width, height) leaves ``normal`` unchanged.

    The image at ``malign`` is cropped to its non-zero bounding box (via
    ``crop_segmentation``). It is then pasted onto a copy of ``normal``, using
    itself as the paste mask. ``is_collage_possible`` passes the binarised normal
    image here, so an unchanged result means every pasted pixel already matched
    the pixel underneath.

    Args:
        width: x (column) of the paste's top-left corner.
        height: y (row) of the paste's top-left corner.
        malign: path of the mass (mask) image.
        normal: path of the background image.
    """
    aux_pth = 'malign_aux.png'
    crop_segmentation(malign, aux_pth)
    try:
        # Context managers close the underlying files even if paste raises.
        with Image.open(normal) as normal_image, Image.open(aux_pth) as mass_to_paste:
            back_im = normal_image.copy()
            back_im.paste(mass_to_paste, (width, height), mass_to_paste)
            # Both images have the same mode and size, so comparing raw bytes is
            # equivalent to comparing pixel lists without building Python lists.
            return back_im.tobytes() == normal_image.tobytes()
    finally:
        # The cropped mass is only needed for this check; best-effort cleanup so it
        # does not linger in the working directory.
        try:
            os.remove(aux_pth)
        except OSError:
            pass

In [None]:
def is_collage_possible(malign_mask_pth, normal_breast_pth):
  """Search for a position where the malign mass can be pasted inside the normal breast.

  The search starts one mass height (plus a margin) above the bottom-most tissue
  pixel of the normal image. It moves up in steps of ``threshold`` pixels until
  pasting the cropped mass mask leaves the binarised normal image unchanged
  (see ``does_collage_mask``).

  WARNING: ``image_to_binary`` overwrites ``normal_breast_pth`` on disk with its
  binarised version.

  Returns:
    (x, y) top-left paste coordinates, or (-1, -1) when the mass is larger than
    the tissue or no valid position is found.

  Raises:
    FileNotFoundError: if either image cannot be read. This check runs before
    anything is written to disk.
  """
  # Margin (in pixels) used for the horizontal shift and as the vertical search step.
  threshold = 50

  malign_mask = cv2.imread(malign_mask_pth, cv2.IMREAD_GRAYSCALE)
  normal_breast = cv2.imread(normal_breast_pth, cv2.IMREAD_GRAYSCALE)
  # cv2.imread returns None instead of raising on a missing/unreadable file.
  for pth, image in ((malign_mask_pth, malign_mask), (normal_breast_pth, normal_breast)):
    if image is None:
      raise FileNotFoundError(f"Could not read image: {pth}")

  _, normal_x = normal_breast.shape
  normal_breast = image_to_binary(normal_breast, normal_breast_pth)

  # First flag is True when the right-most column carries more tissue than the left-most.
  is_right, _ = get_image_laterality(normal_breast)

  m_top, m_right, m_bottom, m_left = get_measures(malign_mask)
  n_top, n_right, n_bottom, n_left = get_measures(normal_breast)

  malign_mass_width = abs(m_right - m_left)
  malign_mass_height = abs(m_bottom - m_top)
  normal_breast_width = abs(n_right - n_left)
  normal_breast_height = abs(n_bottom - n_top)

  # A mass larger than the tissue can never fit: skip the search.
  if malign_mass_width > normal_breast_width or malign_mass_height > normal_breast_height:
    return -1, -1

  # Start at the centre of the bottom-most tissue row.
  c, d = get_start_coordinate(normal_breast)

  if is_right:
    # Keep the mass inside the right border: shift left by the overflow plus the margin.
    if normal_x - c < malign_mass_width:
      c -= malign_mass_width - (normal_x - c) + threshold
  else:
    # Shift right by the shortfall plus the margin.
    if c < malign_mass_width:
      c += malign_mass_width - c + threshold

  # Go up the mass height plus the margin.
  d -= malign_mass_height + threshold

  # Move up until pasting the mask no longer changes the tissue image.
  while d > threshold:
    if does_collage_mask(c, d, malign_mask_pth, normal_breast_pth):
      return c, d
    d -= threshold

  return -1, -1

In [None]:
# Remove the 4th (alpha) channel from the collage image
def remove_4_channel(im_path, output_path):
    """Write a copy of ``im_path`` to ``output_path`` keeping at most its first three channels.

    Single-channel (2-D) images are written unchanged.

    Raises:
        FileNotFoundError: if ``im_path`` cannot be read.
        OSError: if ``output_path`` cannot be written.
    """
    image = cv2.imread(im_path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {im_path}")

    if image.ndim == 3:
        # Keep the first three channels (OpenCV orders them B, G, R) and drop the rest.
        image = np.ascontiguousarray(image[:, :, :3])

    if not cv2.imwrite(output_path, image):
        raise OSError(f"Could not write image: {output_path}")

In [None]:
# Resize image for harmonisation
def resize_image(im_path, percent_original, output_path):
    """Scale the image at ``im_path`` by ``percent_original`` percent and write it to ``output_path``.

    Both dimensions use the same factor (truncated to int, at least 1 px) and
    ``cv2.INTER_AREA`` interpolation. The original and resized shapes are printed.

    Raises:
        ValueError: if ``percent_original`` is not positive.
        FileNotFoundError: if ``im_path`` cannot be read.
        OSError: if ``output_path`` cannot be written.
    """
    if percent_original <= 0:
        raise ValueError(f"percent_original must be positive, got {percent_original}")

    image = cv2.imread(im_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {im_path}")

    print('Original Dimensions : ',image.shape)

    # cv2.resize expects dsize as (width, height). Clamp to 1 px so small
    # percentages do not yield an invalid zero-sized target.
    width = max(1, int(image.shape[1] * percent_original / 100))
    height = max(1, int(image.shape[0] * percent_original / 100))
    resized = cv2.resize(image, (width, height), interpolation = cv2.INTER_AREA)

    print('Resized Dimensions : ',resized.shape)
    if not cv2.imwrite(output_path, resized):
        raise OSError(f"Could not write image: {output_path}")

In [None]:
# Make mask have 3 channels
def make_3_channels_mask(im_path, out_path):
  """Read a mask and write it as a 3-channel PNG whose channels are identical copies.

  ``tv.io.write_png`` expects a uint8 tensor laid out as (C, H, W), so the mask is
  stacked along a new leading axis.

  - Multi-channel inputs (H, W, C), e.g. a mask that was already converted, are
    first reduced to their first channel. Re-running on the same file therefore
    does not produce a 4-D array.
  - Boolean masks are mapped to 0/255 because write_png needs uint8.
  """
  mask = img.imread(im_path)
  if mask.ndim == 3:
    mask = mask[..., 0]
  if mask.dtype == bool:
    mask = mask.astype(np.uint8) * 255
  stacked = np.stack((mask, mask, mask), axis=0)
  tv.io.write_png(torch.from_numpy(stacked), out_path)

In [None]:
# Crops the segmentation by its limits
def crop_segmentation(fp, outp):
  """Crop the image at ``fp`` to the bounding box of its non-zero pixels and save it to ``outp``.

  Raises:
    FileNotFoundError: if ``fp`` cannot be read.
    ValueError: if the image has no non-zero pixels.
  """
  pixels = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
  if pixels is None:
    raise FileNotFoundError(f"Could not read image: {fp}")

  nonzero_idx = np.nonzero(pixels)
  rows, cols = nonzero_idx[0], nonzero_idx[1]
  if rows.size == 0:
    raise ValueError(f"No non-zero pixels to crop in: {fp}")

  # PIL crop boxes are half-open (right/lower are exclusive), so +1 keeps the
  # last non-zero column and row inside the crop.
  box = (int(cols.min()), int(rows.min()), int(cols.max()) + 1, int(rows.max()) + 1)

  with Image.open(fp) as image_object:
    image_object.crop(box).save(outp)

In [None]:
# Makes a collage given the malign image, the malign mask, and the normal image
def make_collage(malign_pth, malign_mask_pth, normal_pth, width, height):
  """Paste the segmented malign mass onto the normal image.

  Writes:
    /content/collage.png: the normal image with the masked mass pasted at (width, height).
    /content/collage_mask.png: an "L" image of the same size holding the cropped mask
      at the same position.

  Args:
    malign_pth: image containing the mass.
    malign_mask_pth: mask of the mass; must have the same height/width as malign_pth.
    normal_pth: background image the mass is pasted onto.
    width: x (column) of the top-left corner of the pasted crop.
    height: y (row) of the top-left corner of the pasted crop.

  Raises:
    FileNotFoundError: if malign_pth or malign_mask_pth cannot be read.
    ValueError: if the mask has no non-zero pixels.
  """
  malign = cv2.imread(malign_pth, cv2.IMREAD_UNCHANGED)
  if malign is None:
    raise FileNotFoundError(f"Could not read image: {malign_pth}")
  mask = cv2.imread(malign_mask_pth, cv2.IMREAD_UNCHANGED)
  if mask is None:
    raise FileNotFoundError(f"Could not read image: {malign_mask_pth}")
  if mask.ndim == 3:
    # Multi-channel mask: use its first channel.
    mask = mask[:, :, 0]

  rows, cols = np.nonzero(mask)
  if rows.size == 0:
    raise ValueError(f"Mask has no non-zero pixels: {malign_mask_pth}")
  # Crop the mass and the mask with the SAME box (taken from the mask). The paste
  # mask then always matches the pasted image in size, even if the mass has dark
  # pixels at its border. PIL boxes are half-open, hence the +1.
  box = (int(cols.min()), int(rows.min()), int(cols.max()) + 1, int(rows.max()) + 1)

  # Zero every pixel outside the mask (the 2-D mask indexes whole pixels).
  masked = malign.copy()
  masked[mask == 0] = 0

  # Round-trip through a PNG so PIL receives the pixels in RGB order (cv2 arrays are BGR).
  segmented_pth = '/content/segmented_mass.png'
  cv2.imwrite(segmented_pth, masked)
  try:
    with Image.open(normal_pth) as normal_image, \
         Image.open(segmented_pth) as segmented, \
         Image.open(malign_mask_pth) as full_mask:
      mass_to_paste = segmented.crop(box)
      mass_mask = full_mask.crop(box)

      # Creates collage and save
      back_im = normal_image.copy()
      back_im.paste(mass_to_paste, (width, height), mass_mask)
      back_im.save('/content/collage.png')

      # Creates collage mask
      collage_mask = Image.new("L", back_im.size, 0)
      collage_mask.paste(mass_mask, (width, height))
      collage_mask.save('/content/collage_mask.png')
  finally:
    # Delete the temporary segmented image even if collage creation failed.
    try:
      os.remove(segmented_pth)
    except OSError as e:
      print(f"FAILED\nFile: {e.filename}\nError: {e.strerror}")


In [None]:
# Search for a paste position; (-1, -1) means no valid position was found.
# NOTE(review): these relative paths resolve against the current directory (the cloned repo after
# the %cd above), whereas the next cell uses /content/... — confirm which location holds the images.
# is_collage_possible overwrites normal_breast_pth with its binarised version.
w, h = is_collage_possible(malign_mask_pth='malign_mask.png', normal_breast_pth='normal.png')
w, h

In [None]:
# Search for a paste position; (-1, -1) means no valid position was found.
# NOTE(review): this binarises /content/normal.png in place (via image_to_binary). Later cells still
# use that file as the collage background and as the harmonisation training image — confirm intended.
w_collage, h_collage = is_collage_possible('/content/malign_mask.png', '/content/normal.png')

In [None]:
# Build /content/collage.png and /content/collage_mask.png.
# TODO(review): the paste position is hardcoded to (1000, 1000) instead of using
# (w_collage, h_collage) from the search above — confirm which is intended.
make_collage(malign_pth='/content/malign.png', malign_mask_pth='/content/malign_mask.png', normal_pth='/content/normal.png', width=1000, height=1000)

## Harmonizer

### Harmonizer Training

In [None]:
# Change to correct directory
# The cells below run main_train.py and evaluate_model.py from this directory.
%cd MedSinGAN/

/content/Adversarial-Representation-Learning-for-Medical-Imaging/MedSinGAN


In [None]:
# Make the collage mask 3-channel, replacing /content/collage_mask.png in place
make_3_channels_mask('/content/collage_mask.png', '/content/collage_mask3.png')
# os.replace overwrites the destination in one step (atomic on POSIX), so there
# is never a moment where collage_mask.png is missing.
os.replace('/content/collage_mask3.png', '/content/collage_mask.png')

In [None]:
# Normal breast collage Harmonizer creation
# Trains a 3-stage harmonisation model (1000 iterations per stage) on /content/normal.png,
# with /content/collage.png as the naive image. The run is written under TrainedModels/normal/.
!python main_train.py --train_mode harmonization --gpu 0 --train_stages 3 --im_max_size 720 --lrelu_alpha 0.3 --niter 1000 --batch_norm --input_name /content/normal.png --naive_img /content/collage.png

Training model (TrainedModels/normal/2022_01_19_14_52_13_harmonization_niter_1000_lr_scale_0.1_nstages_3_BN_act_lrelu_0.3)
Training model with the following parameters:
	 number of stages: 3
	 number of concurrently trained stages: 3
	 learning rate scaling: 0.1
	 non-linearity: lrelu
Training on image pyramid: [torch.Size([1, 3, 31, 25]), torch.Size([1, 3, 149, 121]), torch.Size([1, 3, 720, 585])]

stage [0/2]:: 100% 1000/1000 [00:48<00:00, 20.51it/s]
stage [1/2]:: 100% 1000/1000 [10:42<00:00,  1.56it/s]
stage [2/2]:: 100% 1000/1000 [1:20:51<00:00,  4.85s/it]
Time for training: 5553.907527446747 seconds


### Fine-Tune

In [None]:
# Get the latest model
def get_latest_model(base_path="TrainedModels/normal/"):
  """Return the path of the most recently created model directory under ``base_path``.

  Model directory names start with a ``YYYY_MM_DD_HH_MM_SS`` timestamp (e.g.
  ``2022_01_19_14_52_13_harmonization_...``); the entry with the largest timestamp
  wins. Entries whose name does not start with six numeric ``_``-separated fields
  (e.g. ``.ipynb_checkpoints``) are ignored.

  Raises:
    FileNotFoundError: if ``base_path`` does not exist or holds no timestamped entries.
  """
  def _timestamp(name):
    # Six leading numeric fields joined -> YYYYMMDDHHMMSS as an int; None if not a model name.
    fields = name.split("_")[:6]
    if len(fields) < 6 or not all(field.isdecimal() for field in fields):
      return None
    return int("".join(fields))

  stamped = [(_timestamp(name), name) for name in os.listdir(base_path)]
  stamped = [(code, name) for code, name in stamped if code is not None]
  if not stamped:
    raise FileNotFoundError(f"No timestamped model directories found in {base_path}")

  _, desired = max(stamped, key=lambda item: item[0])
  return os.path.join(base_path, desired)

In [None]:
# FINE TUNE: continue training the most recent harmonisation model on the collage
m = get_latest_model()
fine_tune_cmd = [
    "python", "main_train.py", "--gpu", "0", "--train_mode", "harmonization",
    "--input_name", "/content/normal.png", "--naive_img", "/content/collage.png",
    "--fine_tune", "--model_dir", str(m),
]
# An argument list (no shell) keeps paths intact. check=True stops the notebook if
# fine-tuning fails, instead of silently returning a non-zero exit status.
subprocess.run(fine_tune_cmd, check=True);

In [None]:
# FINE TUNE
# !python main_train.py --gpu 0 --train_mode harmonization --input_name /content/normal.png --naive_img /content/collage.png --fine_tune --model_dir TrainedModels/normal/2022_01_17_19_51_18_harmonization_niter_1000_lr_scale_0.1_nstages_8_BN_act_lrelu_0.3 

### Harmonise the Naive Collage

In [None]:
# Normal breast collage harmonisation with the most recent model
m = get_latest_model()
harmonise_cmd = [
    "python", "evaluate_model.py", "--gpu", "0",
    "--model_dir", str(m), "--naive_img", "/content/collage.png",
]
# check=True surfaces a failed evaluation instead of silently continuing.
subprocess.run(harmonise_cmd, check=True);

In [None]:
# Normal breast collage harmonisation
#!python evaluate_model.py --gpu 0 --model_dir TrainedModels/normal/2022_01_17_21_07_38_harmonization_fine-tune_niter_2000_lr_scale_0.1_nstages_8_BN_act_lrelu_0.3 --naive_img /content/collage.png

0

### Evaluate Results

In [None]:
# Resizes an image to a specific dimension
def resize_to_dim(img_pth, width, height, out_pth):
  """Resize the image at ``img_pth`` to exactly ``width`` x ``height`` pixels and save it to ``out_pth``.

  Uses cv2.resize's default interpolation; the aspect ratio is not preserved.

  Raises:
    FileNotFoundError: if ``img_pth`` cannot be read.
    OSError: if ``out_pth`` cannot be written.
  """
  base = cv2.imread(img_pth, cv2.IMREAD_UNCHANGED)
  # cv2.imread returns None instead of raising on failure.
  if base is None:
    raise FileNotFoundError(f"Could not read image: {img_pth}")
  # cv2.resize takes dsize as (width, height), i.e. (columns, rows).
  resized = cv2.resize(base, (width, height))
  if not cv2.imwrite(out_pth, resized):
    raise OSError(f"Could not write image: {out_pth}")

In [None]:
# Resize the normal image to 204 (width) x 250 (height) pixels as the reference for evaluation.
# NOTE(review): 204x250 is a magic size, presumably the resolution of the images written by
# evaluate_model.py — confirm and derive it instead of hardcoding.
resize_to_dim('/content/normal.png', 204, 250, '/content/normal_resized.png')

In [None]:
# Project-local evaluation helpers (presumably importable because the working directory is MedSinGAN/).
import evaluate_generation

# Reference image the harmonised outputs are compared against.
base_img = '/content/normal_resized.png'
# NOTE(review): presumably evaluate_model.py names its output folder 'Evaluation_' followed by the
# --naive_img value, hence the '/content/collage.png' suffix — confirm.
eval_folder = os.path.join(get_latest_model(),'Evaluation_/content/collage.png')

evaluator = evaluate_generation.GenerationEvaluator(base_img, eval_folder)

lpips = evaluator.run_lpips()
ssim, ms_ssim = evaluator.run_mssim()
print(f"LPIPS: {lpips}\nSSIM: {ssim}\nMS-SSIM: {ms_ssim}")


LPIPS: 0.1152394488453865
SSIM: 0.8789518475532532
MS-SSIM: 0.9195001125335693


## Save Harmonizer Model

In [None]:
# Import files to download zips
from google.colab import files

In [None]:
# Zip the mlruns metrics to analyse and download the archive
# (files comes from google.colab, imported in the previous cell).
!zip -r /content/mlrun.zip /content/Adversarial-Representation-Learning-for-Medical-Imaging/MedSinGAN/mlruns
files.download("/content/mlrun.zip")

In [None]:
# Zip and download the most recent model. get_latest_model picks the newest timestamp,
# not a metric-based "best" run; compare against the mlruns metrics before relying on it.
m = get_latest_model()
zip_path = "/content/best_harmonisation_model.zip"
# Argument list instead of a shell string; check=True avoids downloading a missing or stale zip
# when zipping fails.
subprocess.run(["zip", "-r", zip_path, str(m)], check=True)
files.download(zip_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Utils

In [None]:
# Remove the MLFlow runs
# WARNING: irreversible — download mlrun.zip first if the metrics are still needed.
! rm -r /content/Adversarial-Representation-Learning-for-Medical-Imaging/MedSinGAN/mlruns

In [None]:
# Remove all trained models
# WARNING: irreversible — download the model zip first if it is still needed.
! rm -r /content/Adversarial-Representation-Learning-for-Medical-Imaging/MedSinGAN/TrainedModels

In [None]:
# Remove all images
# Moves up two levels (from .../MedSinGAN to /content), then deletes every .png below it.
# This includes the input images, collages and any PNGs inside the cloned repo — irreversible.
%cd ../..
! find . -name "*.png" -type f -delete

/content
