In [None]:
"""
Code adapted from https://github.com/ngaggion/HybridGNet
"""

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import glob
import os

import cv2
import numpy as np

## Train-Val-Test splits

In [None]:
def read_split_list(path):
    """Read one example ID per line from a split file.

    Trailing whitespace/newlines are stripped and blank lines are skipped, so a
    stray empty line cannot become a bogus '' ID (which later cells would try
    to load as 'Landmarks/.npy').

    Args:
        path: Path to a text file with one example ID per line.

    Returns:
        List of IDs in file order.
    """
    with open(path) as split_file:
        stripped = (line.rstrip() for line in split_file)
        return [example_id for example_id in stripped if example_id]


train_list = read_split_list('train_list.txt')
val_list = read_split_list('val_list.txt')
test_list = read_split_list('test_list.txt')

## Process Landmarks

In [None]:
# glob returns matches in arbitrary, filesystem-dependent order; sort so the
# right-lung (RL) and left-lung (LL) lists have a deterministic, name-based order.
RL_files = sorted(glob.glob('../All_Landmarks/RL/MCU*.npy'))
LL_files = sorted(glob.glob('../All_Landmarks/LL/MCU*.npy'))
len(RL_files), len(LL_files)

In [None]:
# Pair RL/LL files by file name, not by list position: zip() over two
# independently globbed lists silently mismatches files whenever their orders
# differ, and silently drops files whenever their counts differ.
RL_by_name = {os.path.basename(path): path for path in RL_files}
LL_by_name = {os.path.basename(path): path for path in LL_files}
unpaired = RL_by_name.keys() ^ LL_by_name.keys()
assert not unpaired, f'Landmark files without an RL/LL counterpart: {sorted(unpaired)}'

for name in sorted(RL_by_name):
    RL = np.load(RL_by_name[name])
    LL = np.load(LL_by_name[name])
    # RL points first, then LL points, stacked along axis 0.
    L = np.concatenate([RL, LL], axis=0)
    np.save(os.path.join('Landmarks', name), L)

## Preprocess Images

In [None]:
# Every PNG directly under Images/ (in glob's order); preview the count and
# the first two paths.
png_pattern = 'Images/*.png'
all_files = glob.glob(png_pattern)
(len(all_files), all_files[:2])

In [None]:
OUTPUT_SIZE = 1024  # side length (pixels) of the square images written to disk


def crop_to_content(img, threshold=1):
    """Crop a 2-D grayscale image to the bounding box of pixels > `threshold`.

    Pixels with value <= threshold are treated as background.

    Returns:
        The cropped view of ``img``, or ``None`` when no pixel exceeds the
        threshold (e.g. an all-black image).
    """
    # findNonZero expects an 8-bit single-channel image; cast explicitly so the
    # boolean mask is not promoted to a 64-bit integer array.
    foreground = ((img > threshold) * 255).astype(np.uint8)
    coords = cv2.findNonZero(foreground)
    if coords is None:
        return None
    x, y, w, h = cv2.boundingRect(coords)
    return img[y:y+h, x:x+w]


def pad_to_square(img):
    """Zero-pad a 2-D image along its shorter axis so it becomes square.

    An odd difference puts the extra row/column at the bottom/right. An image
    that is already square is returned with zero padding. (The previous inline
    version left the pad widths undefined, or stale from the previous image,
    in that case.)
    """
    h, w = img.shape[:2]
    diff = abs(h - w)
    before, after = diff // 2, diff - diff // 2
    pad_y = [before, after] if h < w else [0, 0]
    pad_x = [before, after] if w < h else [0, 0]
    return np.pad(img, pad_width=[pad_y, pad_x])


# Sets give O(1) membership tests; the first matching split wins, as in an
# if/elif chain.
split_dirs = [
    ('Train/', set(train_list)),
    ('Val/', set(val_list)),
    ('Test/', set(test_list)),
]

for i, file in enumerate(all_files, start=1):
    print('\r', i, 'of', len(all_files), end='')

    img = cv2.imread(file, 0)  # 0 -> load as single-channel grayscale
    if img is None:
        print('\nCould not read image', file)
        continue

    cropimg = crop_to_content(img)
    if cropimg is None:
        print('\nNo pixels above threshold in', file)
        continue

    img = pad_to_square(cropimg)
    assert img.shape[0] == img.shape[1], 'Error padding image'

    img_ = cv2.resize(img, (OUTPUT_SIZE, OUTPUT_SIZE))

    # `file` is 'Images/<name>.png', so outputs land in '<Split>/Images/'.
    name = os.path.basename(file).split('.')[0]
    for out_dir, members in split_dirs:
        if name in members:
            # imwrite returns False (rather than raising) e.g. when the
            # output directory does not exist.
            if not cv2.imwrite(out_dir + file, img_):
                print('\nFailed to write', out_dir + file)
            break
    else:
        print('\nFile not in list:', file)

## Create and Process Masks/Landmarks

In [None]:
import sys

# Make the directory two levels up importable (the `utils.*` imports below
# are resolved from there). The guard keeps re-runs of this cell from
# appending duplicate entries.
REPO_ROOT = '../..'
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [None]:
from utils.fun import drawBinary, reverseVector
import matplotlib.pyplot as plt

In [None]:
MASK_SIZE = 1024  # should match the side length the images were resized to
blank = np.zeros([MASK_SIZE, MASK_SIZE])

splits = [('Train/', train_list), ('Val/', val_list), ('Test/', test_list)]

for split_dir, examples in splits:
    for example in examples:
        landmarks = np.load('Landmarks/' + example + '.npy')
        # Only the first two outputs of reverseVector are used; they are drawn
        # as the right (p1) and left (p2) lung masks below.
        p1, p2, _, _, _ = reverseVector(landmarks.reshape(-1))
        RLUNG = drawBinary(blank.copy(), p1)
        LLUNG = drawBinary(blank.copy(), p2)

        # Assumes drawBinary fills polygons with 255 (TODO confirm), so this
        # yields a {0, 1} mask.
        LUNG_mask = (RLUNG + LLUNG) / 255

        # np.array_equal handles a different number of unique values (e.g.
        # an empty mask, or overlapping lungs summing to 2) cleanly. Comparing
        # with `==` and np.all would broadcast or raise in those cases.
        values = np.unique(LUNG_mask)
        assert np.array_equal(values, [0., 1.]), f'{example}: unexpected mask values {values}'

        np.save(split_dir + 'Masks/' + example + '.npy', LUNG_mask)
        np.save(split_dir + 'Landmarks/' + example + '.npy', landmarks)

## Create SDF ground truth

In [None]:
from utils.SDF import sdf
from tqdm import tqdm

In [None]:
# Sorted so the processing order (and progress output) is reproducible.
mask_paths = sorted(glob.glob('*/Masks/*.npy'))
# Slice rather than index so the preview doesn't raise IndexError when empty.
len(mask_paths), mask_paths[:1]

In [None]:
def sdf_output_path(mask_path):
    """Map '<split>/Masks/<name>.npy' to '<split>/SDF/<name>.npy'.

    Only the parent directory is swapped. A plain str.replace would also
    rewrite any 'Masks' substring elsewhere in the path (e.g. in the file name).
    """
    split_dir = os.path.dirname(os.path.dirname(mask_path))
    return os.path.join(split_dir, 'SDF', os.path.basename(mask_path))


for path in tqdm(mask_paths):
    mask = np.load(path)
    lung_sdf = sdf(mask, organ=1)
    # Append a trailing singleton (channel) axis before saving.
    lung_sdf = np.expand_dims(lung_sdf, -1)
    np.save(sdf_output_path(path), lung_sdf)