<a href="https://colab.research.google.com/github/brothermin00/JNU_2023/blob/main/unet_lung_cancer_2023.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

U-Net: lung image segmentation with a pretrained model

In [None]:
# Mount Google Drive at /content/drive.
# NOTE(review): no cell visible in this notebook reads from /content/drive --
# the files used later are downloaded into the working directory -- so this
# mount may be unnecessary; confirm before removing.
from google.colab import drive
drive.mount('/content/drive')

Download Model and Dataset

In [None]:
!wget https://github.com/onebottlekick/JNU_dl/releases/download/unet/lung_005_z160_anno.jpg
!wget https://github.com/onebottlekick/JNU_dl/releases/download/unet/lung_005_z160.jpg
!wget https://github.com/onebottlekick/JNU_dl/releases/download/unet/unet.h5

Import Modules

In [None]:
import tensorflow as tf
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

Load Model

In [None]:
# Load the Keras model stored in unet.h5 (fetched by the download cell).
# NOTE(review): the cells visible here only call model.predict(); if loading
# ever fails on custom losses/metrics, passing compile=False may be an
# option -- confirm against the saved model.
model = tf.keras.models.load_model('unet.h5')

Utils

In [None]:
def process_img(img):
    """Load an image file as an RGB array scaled to the range [0, 1].

    Args:
        img: Image location handed straight to ``PIL.Image.open`` (the
            notebook passes a filename string).

    Returns:
        Float ``numpy.ndarray`` with the three RGB channels on the last
        axis; every 8-bit channel value is divided by 255.0.
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been converted, instead of leaving that to Pillow/GC.
    with Image.open(img) as img_file:
        rgb = np.asarray(img_file.convert('RGB'))
    return rgb / 255.0

def process_anno(annotation):
    """Load an annotation image and binarize it into a 0/1 mask.

    Args:
        annotation: Image location handed straight to ``PIL.Image.open``
            (the notebook passes a filename string).

    Returns:
        2-D ``numpy.ndarray`` with the dtype of the decoded channel:
        1 where the red channel is above 127.5, 0 everywhere else.
    """
    # The context manager releases the underlying file handle as soon as the
    # pixel data has been converted.
    with Image.open(annotation) as anno_file:
        # Only the first (red) channel of the RGB conversion is used.
        red = np.asarray(anno_file.convert('RGB'))[:, :, 0]
    # Threshold at mid-gray; casting the boolean result back to the channel
    # dtype yields the same 0/1 array the zeros_like + masked assignment did.
    return (red > 127.5).astype(red.dtype)

def show_img(x):
    """Draw ``x`` on the current axes with the gray colormap, hide the axis
    decorations, and show the figure."""
    axes = plt.gca()
    axes.imshow(x, cmap='gray')
    axes.axis('off')
    plt.show()

Load Dataset

In [None]:
# File names are kept under their own names, separate from the decoded arrays.
IMG_PATH = 'lung_005_z160.jpg'
ANNO_PATH = 'lung_005_z160_anno.jpg'

# img: RGB array scaled to [0, 1]; annotation: 0/1 mask from the red channel.
img, annotation = process_img(IMG_PATH), process_anno(ANNO_PATH)

In [None]:
# Preview the input image first, then its annotation mask.
for picture in (img, annotation):
    show_img(picture)

Model Prediction

In [None]:
# Add a leading batch axis so the single image is passed as a batch of one.
batch = img[np.newaxis, ...]

# Keep only channel 0 of the model output, then zero out negative scores.
first_channel = model.predict(batch)[:, :, :, 0]
first_channel[first_channel < 0] = 0

# Drop the batch axis again, leaving a 2-D score map.
prediction = np.squeeze(first_channel, axis=0)

In [None]:
# Display the clamped channel-0 prediction map as a grayscale image.
show_img(prediction)

In [None]:
# Build an 8-bit copy of the image (range 0~1 -> 0~255) WITHOUT overwriting
# `img`. Reassigning `img` here made the cell non-idempotent: running it a
# second time multiplied uint8 data by 255 (wrap-around), and re-running the
# prediction cell afterwards fed 0~255 values to the model.
# `img` stays a float array in [0, 1], which plt.imshow accepts for RGB data.
img_uint8 = (img * 255).astype(np.uint8)

# Red channel of the 8-bit image; used as the grayscale base for the overlays.
template = img_uint8[:, :, 0].copy()

Get Prediction Mask

In [None]:
# Paint pixels whose prediction score exceeds 0.5 with 255 on a copy of the
# red-channel template, and preview the result.
pred_region = prediction > 0.5
pred_highlight = template.copy()
pred_highlight[pred_region] = 255
show_img(pred_highlight)

# Only the third (blue) channel carries the highlight; the first two channels
# are the untouched template.
pred_mask = np.stack([template, template, pred_highlight], axis=-1)

Get Label Mask

In [None]:
# Paint annotated pixels (annotation above 0.5) with 255 on a copy of the
# red-channel template, and preview the result.
label_region = annotation > 0.5
label_highlight = template.copy()
label_highlight[label_region] = 255
show_img(label_highlight)

# Only the second (green) channel carries the highlight; the other two
# channels are the untouched template.
mask = np.stack([template, label_highlight, template], axis=-1)

Plot Results

In [None]:
# One row, one column per panel: input image, label overlay, prediction overlay.
panels = (
    ('Image', img),
    ('Mask', mask),
    ('Prediction Mask', pred_mask),
)

for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.imshow(picture)
    plt.axis('off')
    plt.title(caption)

plt.show()