## Image Preparation

The input images must satisfy the following:
- malign.png is 3-channel
- normal.png is 3-channel
- malign_mask.png is single-channel and binary

In [None]:
# Import required libraries
import cv2
import os
from skimage import io as img
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch
import subprocess
import torchvision as tv
from PIL import Image

In [None]:
def get_image_laterality(image):
    """Compare the two outermost columns of *image* to decide which side is brighter.

    Returns:
        ``(True, False)`` when the summed intensity of the last column is
        strictly greater than that of the first column, otherwise
        ``(False, True)`` (ties included). Callers use the first element as
        the "right laterality" flag.
    """
    first_col_total = np.sum(image[:, 0])
    last_col_total = np.sum(image[:, -1])
    if last_col_total > first_col_total:
        return True, False
    return False, True

In [None]:
def get_measures(image):
    """Return the bounding box of the non-zero pixels of *image*.

    Returns:
        ``(top, right, bottom, left)``: inclusive row/column indices of the
        outermost non-zero pixels (first two axes of *image*).

    Raises:
        ValueError: if *image* has no non-zero pixel.
    """
    coords = np.argwhere(image)
    row_idx = coords[:, 0]
    col_idx = coords[:, 1]
    return row_idx.min(), col_idx.max(), row_idx.max(), col_idx.min()

In [None]:
def get_start_coordinate(image, mass_path, laterality_r):
    """Return the ``(x, y)`` position where the mass-placement search starts.

    ``x`` is the leftmost non-zero column of *image* when ``laterality_r`` is
    truthy, otherwise its rightmost non-zero column. ``y`` is the topmost
    non-zero row in that column, shifted up by half of ``shape[1]`` of the
    image read from ``mass_path``.

    NOTE(review): ``shape[1]`` is the mass *width*; centring the mass
    vertically would normally use its height (``shape[0]``). Confirm which
    is intended.
    """
    mass = cv2.imread(mass_path, cv2.IMREAD_UNCHANGED)
    positions = np.nonzero(image)
    rows, cols = positions[0], positions[1]

    column = cols.min() if laterality_r else cols.max()

    # np.nonzero returns indices in row-major order, so the first match in
    # this column is its topmost non-zero pixel.
    first_row = rows[np.flatnonzero(cols == column)[0]]
    return column, int(first_row - mass.shape[1] / 2)

In [None]:
def get_correct_value(number):
    """Map a pixel value to a binary label: 0 stays 0, anything else becomes 1."""
    return 0 if number == 0 else 1

In [None]:
def image_to_binary(image, pth=None):
    """Binarise *image*: every non-zero value becomes 1, zeros stay 0.

    Args:
        image: array-like of pixel values (typically a 2-D grayscale image).
        pth: if given, the binary image is also saved there with
            ``plt.imsave`` using the gray colormap. ``None`` skips writing.

    Returns:
        ``numpy.uint8`` array of the same shape containing only 0 and 1.

    Raises:
        ValueError: if *image* is ``None`` (e.g. a failed ``cv2.imread``).
    """
    if image is None:
        raise ValueError("image is None; was the file read correctly?")
    # Vectorised comparison instead of a per-pixel Python loop.
    b_image = (np.asarray(image) != 0).astype(np.uint8)
    if pth is not None:
        plt.imsave(pth, b_image, cmap=cm.gray)
    return b_image

In [None]:
def does_collage_mask(width, height, normal, mass_pth='malign_aux.png'):
    """Return True if pasting the mass at ``(width, height)`` leaves *normal* unchanged.

    The mass image is used as its own paste mask, so only its non-zero pixels
    are written onto a copy of *normal*. The result equals the original only
    if every written pixel already had the pasted value. With the binary
    images prepared by ``is_collage_possible`` this presumably means the mass
    lies entirely on the (white) breast region.

    Args:
        width: x offset (column) of the mass's top-left corner.
        height: y offset (row) of the mass's top-left corner.
        normal: path to the (binary) background image.
        mass_pth: path to the cropped mass image; defaults to the
            ``malign_aux.png`` written by ``is_collage_possible``.

    NOTE: PIL clips parts of the mass that fall outside the image; those
    parts are not compared.
    """
    with Image.open(normal) as normal_image, Image.open(mass_pth) as mass_to_paste:
        back_im = normal_image.copy()
        back_im.paste(mass_to_paste, (width, height), mass_to_paste)
        # Same mode and size, so raw bytes compare equal iff every pixel does.
        return back_im.tobytes() == normal_image.tobytes()

In [None]:
def is_collage_possible(malign_mask_pth, normal_breast_pth):
    """Search for a position where the malign mass fits inside the normal breast.

    The normal image is binarised (non-zero -> 1) and saved as
    ``normal_aux.png``. The malign mask is cropped to its non-zero bounding
    box and saved as ``malign_aux.png``. Starting from the breast edge chosen
    by ``get_start_coordinate``, the mass is slid horizontally in steps of
    ``threshold`` pixels until ``does_collage_mask`` reports that pasting it
    leaves the binary breast unchanged. Only positions where the whole mass
    stays inside the image are tried.

    Args:
        malign_mask_pth: path to the one-channel binary mask of the mass.
        normal_breast_pth: path to the normal breast image.

    Returns:
        ``(x, y)`` paste offset (top-left corner) for the cropped mass, or
        ``(-1, -1)`` when the mass is larger than the breast or no tried
        position fits.

    Raises:
        FileNotFoundError: if either image cannot be read.
    """
    # Horizontal step, in pixels, between consecutive candidate positions.
    threshold = 50

    malign_mask = cv2.imread(malign_mask_pth, cv2.IMREAD_GRAYSCALE)
    if malign_mask is None:
        raise FileNotFoundError(f"Could not read image: {malign_mask_pth}")
    normal_breast = cv2.imread(normal_breast_pth, cv2.IMREAD_GRAYSCALE)
    if normal_breast is None:
        raise FileNotFoundError(f"Could not read image: {normal_breast_pth}")

    # Binarise the normal image; normal_aux.png is what does_collage_mask compares against.
    normal_breast = image_to_binary(normal_breast, 'normal_aux.png')
    normal_h, normal_w = normal_breast.shape

    # True when the image's last column is brighter than its first column.
    breast_on_right, _ = get_image_laterality(normal_breast)

    m_top, m_right, m_bottom, m_left = get_measures(malign_mask)
    n_top, n_right, n_bottom, n_left = get_measures(normal_breast)

    # Not worth trying if the mass bounding box is wider or taller than the breast's.
    if m_right - m_left > n_right - n_left or m_bottom - m_top > n_bottom - n_top:
        return -1, -1

    # Inclusive pixel size of the mass bounding box, used to keep the paste in-image.
    mass_w = m_right - m_left + 1
    mass_h = m_bottom - m_top + 1

    crop_segmentation(malign_mask_pth, 'malign_aux.png')

    c, d = get_start_coordinate(normal_breast, 'malign_aux.png', breast_on_right)

    # PIL silently clips pastes that cross the image border, which would let a
    # partly off-image mass be reported as fitting; keep the row offset inside.
    d = min(max(d, 0), max(normal_h - mass_h, 0))

    if breast_on_right:
        # Slide right from the breast's leftmost column; stop before the mass
        # would cross the right image border (bounded by width, not height).
        while c + mass_w <= normal_w:
            if does_collage_mask(c, d, 'normal_aux.png'):
                return c, d
            c += threshold
    else:
        # Slide left from the breast's rightmost column down to column 0 (inclusive),
        # skipping positions where the mass would cross the right image border.
        while c >= 0:
            if c + mass_w <= normal_w and does_collage_mask(c, d, 'normal_aux.png'):
                return c, d
            c -= threshold

    return -1, -1

In [None]:
def remove_4_channel(im_path, output_path):
    """Save a copy of the image at *im_path* keeping only its first three channels.

    Drops the fourth (e.g. alpha) channel of a 4-channel image. Images with
    three or fewer channels, including 2-D grayscale images, are written
    unchanged.

    Raises:
        FileNotFoundError: if *im_path* cannot be read.
    """
    image = cv2.imread(im_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {im_path}")

    # Only slice when there is a channel axis; grayscale images are 2-D.
    output = image[:, :, :3] if image.ndim == 3 else image

    cv2.imwrite(output_path, output)

In [None]:
def resize_image(im_path, percent_original, output_path):
    """Resize an image to ``percent_original`` percent of its size (for harmonisation).

    Prints the original and resized shapes and writes the result to
    *output_path*.

    Args:
        im_path: path of the image to resize.
        percent_original: target size as a percentage of the original (> 0).
        output_path: where the resized image is written.

    Raises:
        FileNotFoundError: if *im_path* cannot be read.
        ValueError: if ``percent_original`` is not positive.
    """
    image = cv2.imread(im_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {im_path}")
    if percent_original <= 0:
        raise ValueError("percent_original must be positive")

    print('Original Dimensions : ', image.shape)

    # Clamp to 1 px so a very small percentage does not yield an empty (invalid) size.
    width = max(1, int(image.shape[1] * percent_original / 100))
    height = max(1, int(image.shape[0] * percent_original / 100))

    # INTER_AREA is OpenCV's suggested interpolation when shrinking.
    resized = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)

    print('Resized Dimensions : ', resized.shape)
    cv2.imwrite(output_path, resized)

In [None]:
def make_3_channels_mask(im_path, out_path):
    """Write a 3-channel PNG at *out_path* by repeating the mask at *im_path*.

    The mask is read with scikit-image and stacked three times along a new
    leading axis. For a 2-D mask this gives the channel-first (C, H, W)
    layout passed to ``torchvision.io.write_png``.
    """
    mask = img.imread(im_path)
    stacked = np.stack((mask, mask, mask))
    tv.io.write_png(torch.tensor(stacked), out_path)

In [None]:
def crop_segmentation(fp, outp):
    """Crop the image at *fp* to the bounding box of its non-zero pixels and save to *outp*.

    Raises:
        FileNotFoundError: if *fp* cannot be read.
        ValueError: if the image has no non-zero pixel.
    """
    imag = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
    if imag is None:
        raise FileNotFoundError(f"Could not read image: {fp}")
    positions = np.nonzero(imag)
    if positions[0].size == 0:
        raise ValueError(f"Image has no non-zero pixels to crop: {fp}")

    top, bottom = int(positions[0].min()), int(positions[0].max())
    left, right = int(positions[1].min()), int(positions[1].max())

    # PIL crop boxes are (left, upper, right, lower) with right/lower exclusive,
    # so +1 keeps the last non-zero column and row.
    with Image.open(fp) as image_object:
        cropped = image_object.crop((left, top, right + 1, bottom + 1))
    cropped.save(outp)

In [None]:
def make_collage(malign_pth, malign_mask_pth, normal_pth, width, height):
    """Paste the segmented malign mass onto the normal image.

    Writes two files in the current directory:
      * ``collage.png``: the normal image with the masked mass pasted at
        offset ``(width, height)``;
      * ``collage_mask.png``: a single-channel ("L") image of the collage's
        size holding the mass mask at the same offset.

    Args:
        malign_pth: image containing the malign mass.
        malign_mask_pth: single-channel binary mask of the mass, same size as
            ``malign_pth``.
        normal_pth: normal image used as the background.
        width: x offset (column) of the top-left corner of the cropped mass.
        height: y offset (row) of the top-left corner of the cropped mass.

    Raises:
        FileNotFoundError: if the malign image or its mask cannot be read.
        ValueError: if the mask is not single-channel or has no non-zero pixel.
    """
    malign = cv2.imread(malign_pth, cv2.IMREAD_UNCHANGED)
    if malign is None:
        raise FileNotFoundError(f"Could not read image: {malign_pth}")
    malign_mask = cv2.imread(malign_mask_pth, cv2.IMREAD_UNCHANGED)
    if malign_mask is None:
        raise FileNotFoundError(f"Could not read image: {malign_mask_pth}")
    if malign_mask.ndim != 2:
        raise ValueError(f"Mask must be single-channel, got shape {malign_mask.shape}")

    # Keep only the mass: zero every pixel (all channels) outside the mask.
    masked = malign.copy()
    masked[malign_mask == 0] = 0

    # Bounding box of the mask. Mass and mask are cropped with this SAME box so
    # their sizes always match when pasting (the masked image can have a smaller
    # non-zero extent than the mask). PIL boxes exclude right/lower, hence +1.
    rows, cols = np.nonzero(malign_mask)
    if rows.size == 0:
        raise ValueError(f"Mask has no non-zero pixels: {malign_mask_pth}")
    box = (int(cols.min()), int(rows.min()), int(cols.max()) + 1, int(rows.max()) + 1)

    # Round-trip through a PNG so OpenCV's BGR channel order is converted for PIL.
    segmented_pth = 'segmented_mass.png'
    cv2.imwrite(segmented_pth, masked)
    try:
        with Image.open(normal_pth) as normal_image:
            back_im = normal_image.copy()
        with Image.open(segmented_pth) as mass_image:
            mass_to_paste = mass_image.crop(box)
        with Image.open(malign_mask_pth) as mask_image:
            mass_mask = mask_image.crop(box)
    finally:
        # The intermediate file is no longer needed; remove it even on failure.
        try:
            os.remove(segmented_pth)
        except OSError as e:
            print(f"FAILED\nFile: {e.filename}\nError: {e.strerror}")

    # Create and save the collage.
    back_im.paste(mass_to_paste, (width, height), mass_mask)
    back_im.save('collage.png')

    # Create and save the collage mask.
    collage_mask = Image.new("L", back_im.size, 0)
    collage_mask.paste(mass_mask, (width, height))
    collage_mask.save('collage_mask.png')

In [None]:
# Find where the mass can be pasted; (-1, -1) means no valid position was found.
w, h = is_collage_possible(malign_mask_pth='malign_mask.png', normal_breast_pth='normal.png')
# Display the chosen (x, y) offset.
w, h

(1048, 997)

In [None]:
# is_collage_possible signals "no valid position" with (-1, -1); pasting there
# would silently produce a wrong collage, so stop instead.
if (w, h) == (-1, -1):
    raise RuntimeError("No valid collage position found for this mass / normal image pair.")
make_collage(malign_pth='malign.png', malign_mask_pth='malign_mask.png', normal_pth='normal.png', width=w, height=h)

In [None]:
# Crop the collage and its mask to the bounding box of the collage's non-zero pixels.
imag = cv2.imread('collage.png', cv2.IMREAD_UNCHANGED)
positions = np.nonzero(imag)

top = positions[0].min()
bottom = positions[0].max()
left = positions[1].min()
right = positions[1].max()

# PIL crop boxes exclude right/lower, so +1 keeps the last non-zero column and row.
box = (int(left), int(top), int(right) + 1, int(bottom) + 1)

with Image.open('collage.png') as imageObject:
    imageObject.crop(box).save('new_collage.png')

# The mask gets the same box so it stays aligned with the cropped collage.
with Image.open('collage_mask.png') as imageObject:
    imageObject.crop(box).save('new_collage_mask.png')

In [None]:
# Crop the normal image to the bounding box of its non-zero pixels, presumably
# so it covers the same region as new_collage.png in the harmonisation steps below.
crop_segmentation('normal.png', 'normal_cropped.png')

## Harmonizer

### Harmonizer Training

In [None]:
# Change to correct directory
%cd MedSinGAN/

/content/Adversarial-Representation-Learning-for-Medical-Imaging/MedSinGAN


In [None]:
# Make the collage mask 3-channel, replacing collage_mask.png in a single step.
# NOTE(review): this runs after `%cd MedSinGAN/`, so the relative path refers to
# MedSinGAN/collage_mask.png. Also, the cropped new_collage_mask.png is not
# converted here. Confirm both are intended.
make_3_channels_mask('collage_mask.png', 'collage_mask3.png')
os.replace('collage_mask3.png', 'collage_mask.png')

In [None]:
def execute(cmd, merge_stderr=False):
    """Run *cmd* and yield its standard output line by line while it runs.

    Args:
        cmd: argument list passed to ``subprocess.Popen`` (no shell).
        merge_stderr: if True, the child's stderr is redirected into its
            stdout so error messages are yielded too. The default (False)
            leaves stderr attached to this process's stderr.

    Yields:
        Each output line, including its trailing newline.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status after its output has been fully consumed.

    If iteration is interrupted inside the generator (e.g. ``KeyboardInterrupt``
    while waiting for output) or the generator is closed early, the child
    process is terminated rather than left running in the background.
    """
    popen = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT if merge_stderr else None,
        universal_newlines=True,
    )
    finished = False
    try:
        yield from iter(popen.stdout.readline, "")
        finished = True
    finally:
        popen.stdout.close()
        if not finished:
            popen.terminate()
            try:
                popen.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # The child ignored SIGTERM; force it down.
                popen.kill()
        return_code = popen.wait()
    if return_code:
        raise subprocess.CalledProcessError(return_code, cmd)

In [None]:
# Normal breast collage Harmonizer creation.
# Uses the cropped normal image and cropped naive collage so the resulting model
# lands in TrainedModels/normal_cropped/, where the fine-tune and evaluation
# cells below look for it (get_latest_model).
command = [
    "python", "main_train.py",
    "--train_mode", "harmonization",
    "--gpu", "0",
    "--train_stages", "3",
    "--im_min_size", "120",
    "--lrelu_alpha", "0.3",
    "--niter", "1000",
    "--batch_norm",
    "--input_name", "normal_cropped.png",
    "--naive_img", "new_collage.png",
]

# Stream training output; execute() raises CalledProcessError on failure.
for line in execute(command):
    print(line, end="")

Training model (TrainedModels/normal/2022_02_03_09_44_38_harmonization_niter_1000_lr_scale_0.1_nstages_3_BN_act_lrelu_0.3)
Training model with the following parameters:
	 number of stages: 3
	 number of concurrently trained stages: 3
	 learning rate scaling: 0.1
	 non-linearity: lrelu
Training on image pyramid: [torch.Size([1, 3, 148, 120]), torch.Size([1, 3, 192, 157]), torch.Size([1, 3, 250, 204])]



KeyboardInterrupt: ignored

### Fine-Tune

In [None]:
def _model_timestamp(name):
    """Return the leading ``YYYY_MM_DD_HH_MM_SS`` of *name* as one integer, or None."""
    fields = name.split("_")[:6]
    if len(fields) < 6 or not all(field.isdecimal() for field in fields):
        return None
    # Fields are zero-padded, so e.g. 20220203094438 orders chronologically.
    return int("".join(fields))


def get_latest_model(base_path="TrainedModels/normal_cropped/"):
    """Return the path of the most recent model directory inside *base_path*.

    Model directory names start with a ``YYYY_MM_DD_HH_MM_SS`` timestamp
    (e.g. ``2022_02_03_09_44_38_harmonization_...``); the latest timestamp
    wins. Entries that do not start with six numeric ``_``-separated fields
    are ignored.

    Raises:
        FileNotFoundError: if *base_path* does not exist or holds no
            timestamped entry.
    """
    candidates = []
    for name in os.listdir(base_path):
        stamp = _model_timestamp(name)
        if stamp is not None:
            candidates.append((stamp, name))

    if not candidates:
        raise FileNotFoundError(f"No timestamped model directory found in {base_path!r}")

    _, latest = max(candidates, key=lambda item: item[0])
    return os.path.join(base_path, latest)

In [None]:
# FINE TUNE: continue training the latest harmonisation model from
# TrainedModels/normal_cropped/ with new_collage.png as the naive image.
# Arguments are passed as a list (no shell); execute() streams the output and
# raises CalledProcessError if the command fails.
m = get_latest_model()
fine_tune_cmd = [
    "python", "main_train.py",
    "--gpu", "0",
    "--train_mode", "harmonization",
    "--input_name", "normal_cropped.png",
    "--naive_img", "new_collage.png",
    "--fine_tune",
    "--model_dir", str(m),
]
for line in execute(fine_tune_cmd):
    print(line, end="")

0

In [None]:
# FINE TUNE
# !python main_train.py --gpu 0 --train_mode harmonization --input_name normal.png --naive_img collage.png --fine_tune --model_dir TrainedModels/normal/2022_01_17_19_51_18_harmonization_niter_1000_lr_scale_0.1_nstages_8_BN_act_lrelu_0.3 

### Harmonise the Naive Collage

In [None]:
# Normal breast collage harmonisation with the latest (fine-tuned) model.
# Arguments are passed as a list (no shell); execute() streams the output and
# raises CalledProcessError if the command fails.
m = get_latest_model()
harmonise_cmd = [
    "python", "evaluate_model.py",
    "--gpu", "0",
    "--model_dir", str(m),
    "--naive_img", "new_collage.png",
]
for line in execute(harmonise_cmd):
    print(line, end="")

0

In [None]:
# Normal breast collage harmonisation
#!python evaluate_model.py --gpu 0 --model_dir TrainedModels/normal/2022_01_17_21_07_38_harmonization_fine-tune_niter_2000_lr_scale_0.1_nstages_8_BN_act_lrelu_0.3 --naive_img collage.png

### Evaluate Results

In [None]:
def resize_to_dim(img_pth, width, height, out_pth):
    """Resize the image at *img_pth* to exactly ``width`` x ``height`` pixels.

    The aspect ratio is not preserved; OpenCV's default interpolation is
    used. The result is written to *out_pth*.

    Raises:
        FileNotFoundError: if *img_pth* cannot be read.
    """
    base = cv2.imread(img_pth, cv2.IMREAD_UNCHANGED)
    if base is None:
        raise FileNotFoundError(f"Could not read image: {img_pth}")
    # cv2.resize takes the target size as (width, height), i.e. (columns, rows).
    resized = cv2.resize(base, (width, height))
    cv2.imwrite(out_pth, resized)

In [None]:
# Resize the cropped normal image to 204 (width) x 250 (height). This presumably
# matches the harmoniser's final-scale output (cf. the 250x204 last stage of the
# training pyramid printed above) so the evaluation compares same-sized images;
# confirm if the training input changes.
resize_to_dim('normal_cropped.png', 204, 250, 'normal_resized.png')

In [None]:
# NOTE(review): this cell runs after `%cd MedSinGAN/`; importing through the
# `MedSinGAN.` package prefix presumably relies on the parent directory being
# on sys.path. Confirm.
from MedSinGAN.evaluate_generation import GenerationEvaluator

# Reference image: the cropped normal image resized in the previous cell.
base_img = 'normal_resized.png'
# Folder inside the latest model directory, presumably where evaluate_model.py
# wrote the harmonised results for new_collage.png.
eval_folder = os.path.join(get_latest_model(),'Evaluation_new_collage.png')

evaluator = GenerationEvaluator(base_img, eval_folder)

# LPIPS plus SSIM / MS-SSIM scores between the reference and the evaluated outputs.
lpips = evaluator.run_lpips()
ssim, ms_ssim = evaluator.run_mssim()
print(f"LPIPS: {lpips}\nSSIM: {ssim}\nMS-SSIM: {ms_ssim}")

## Save Model

In [None]:
# Zip everything from this model
!zip -r current_model.zip .

## Delete Current Model

In [None]:
# Delete everything that the model produces
! rm -r mlruns TrainedModels