In [None]:
!git clone https://github.com/4uiiurz1/pytorch-nested-unet

In [None]:
import os
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

# Dataset root and its two sub-folders: input images and their masks
DATA_DIR = './inputs/polyp/'
image_dir, mask_dir = (os.path.join(DATA_DIR, sub) for sub in ('images', 'masks'))

In [None]:
# Show one example image next to its mask.
# The mask is looked up by the image's file name: indexing two independent
# os.listdir() results can pair an image with an unrelated mask, because
# listing order is arbitrary. Assumes each mask shares its image's file name.
example_name = sorted(os.listdir(image_dir))[0]
img = cv2.imread(os.path.join(image_dir, example_name))
mask = cv2.imread(os.path.join(mask_dir, example_name))
if img is None or mask is None:  # cv2.imread returns None instead of raising
    raise FileNotFoundError(
        f"Could not read '{example_name}' from {image_dir} and/or {mask_dir}")

# OpenCV loads BGR; matplotlib expects RGB
fig, (ax_img, ax_mask) = plt.subplots(1, 2)
ax_img.set_title('IMAGE')
ax_img.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
ax_mask.set_title('MASK')
ax_mask.imshow(cv2.cvtColor(mask, cv2.COLOR_BGR2RGB))
plt.show()

In [None]:
# Train a NestedUNet on the polyp dataset with the repository's train.py script.
# --name is expected to match the run folder read later (models/polyp_segmentation/)
# and the val.py call below; --input_w/--input_h (384) should agree with the
# 384x384 size the inference loop resizes test images to.
!python train.py --dataset polyp --arch NestedUNet --name polyp_segmentation --epochs 150 --batch_size 8 --input_w 384 --input_h 384 --img_ext jpg --mask_ext jpg --optimizer Adam

In [None]:
# Validate the trained model with the repository's val.py script;
# --name should be the same run name that was passed to train.py above.
!python val.py --name polyp_segmentation

In [13]:
from PIL import Image
import numpy as np
import os
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import torch
import archs
import yaml
from torchsummary import summary

In [14]:
# Run name passed to train.py / val.py via --name; both paths below derive from it
MODEL_NAME = 'polyp_segmentation'
# Trained weights for the run.
# NOTE(review): pytorch-nested-unet's train.py saves its checkpoint as
# models/<name>/model.pth; confirm for your checkout.
best_model = f'models/{MODEL_NAME}/model.pth'
# Configuration file saved for the same run
yml_path = f'models/{MODEL_NAME}/config.yml'

In [None]:
# Parse the training-run configuration; the model-building cell reads the
# architecture name and constructor arguments from `data`.
with open(yml_path) as config_file:
    data = yaml.load(config_file, Loader=yaml.FullLoader)
print(data)

In [None]:
# Set the device for model inference (GPU if available, otherwise CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the architecture named in the training config from the `archs` module
model_cls = getattr(archs, data['arch'])
model = model_cls(data['num_classes'], data['input_channels'], data['deep_supervision'])
model = model.to(DEVICE)
model.load_state_dict(torch.load(best_model, map_location=DEVICE))
# Switch to inference behaviour (affects layers such as BatchNorm/Dropout, if present)
model.eval()
print("model loaded")

In [None]:
# Inference inputs/outputs: test images, their ground-truth masks (looked up by
# the same file name), and the folder where visualisations are written.
test_folder = 'path to test folder'
gt_folder = 'path to groundtruth mask of test folder'
result_folder = 'path to result folder'

# exist_ok avoids the check-then-create race and makes the cell re-runnable
os.makedirs(result_folder, exist_ok=True)

# Sorted so the processed/displayed subset is the same on every run
# (os.listdir order is arbitrary)
test_file_list = sorted(os.listdir(test_folder))

In [None]:
# Define a function to blend original image with predicted mask
def blend_images(ori, pred, alpha=0.5):
    """Overlay a predicted mask on an image.

    Args:
        ori: OpenCV image in BGR channel order (it is converted BGR->RGB here).
        pred: array accepted by ``Image.fromarray`` (the inference loop passes
            an HxWx3 uint8 mask); it is resized to ``ori``'s width/height.
        alpha: weight of ``pred`` in the blend; 0 keeps only ``ori``, 1 keeps
            only ``pred``. Defaults to 0.5, an even mix.

    Returns:
        3-channel uint8 array in BGR order with the same height/width as ``ori``.
    """
    # PIL works in RGB, so convert the OpenCV (BGR) image first
    rgb = cv2.cvtColor(ori, cv2.COLOR_BGR2RGB)
    background = Image.fromarray(rgb).convert('RGBA')
    # Image.blend requires both images to share the same size and mode
    overlay = Image.fromarray(pred).resize((rgb.shape[1], rgb.shape[0])).convert('RGBA')
    blended = Image.blend(background, overlay, alpha=alpha)
    # RGBA -> BGR: restore OpenCV channel order and drop the alpha channel
    return cv2.cvtColor(np.array(blended), cv2.COLOR_RGBA2BGR)

In [None]:
# Process each test image: predict a mask and build a 4-panel strip
# (input image | ground-truth mask | predicted mask | blend), save it to
# result_folder, and display the first few strips inline.
IMG_SIZE = 384        # side length of every panel; keep in sync with --input_w/--input_h used for training
PRED_THRESHOLD = 240  # predicted intensities (0-255 scale) below this are zeroed
N_DISPLAY = 10        # number of strips displayed inline
MAX_IMAGES = 100      # process at most this many test files


def _read_resized(path):
    """Read an image with OpenCV (BGR) and resize it to IMG_SIZE x IMG_SIZE."""
    image = cv2.imread(path)
    if image is None:  # cv2.imread returns None instead of raising on failure
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.resize(image, (IMG_SIZE, IMG_SIZE))


def _put_label(image, text, x):
    """Draw a red caption (BGR (0, 0, 255)) near the top of ``image``, in place."""
    cv2.putText(image, text, (x, 40), cv2.FONT_HERSHEY_DUPLEX, 1, (0, 0, 255),
                thickness=3, lineType=cv2.LINE_AA)


# Inference behaviour for layers such as BatchNorm/Dropout (if present); harmless if already set
model.eval()

with torch.no_grad():  # inference only: no autograd graph needed
    for idx, file_name in enumerate(test_file_list[:MAX_IMAGES]):
        # Ground-truth mask and test image share the same file name
        gt = _read_resized(os.path.join(gt_folder, file_name))
        img = _read_resized(os.path.join(test_folder, file_name))

        # HxWx3 uint8 -> 1x3xHxW float32 scaled to [0, 1]
        # NOTE(review): confirm this matches the preprocessing used by train.py
        input_tensor = torch.from_numpy(img.astype('float32') / 255)
        input_tensor = input_tensor.unsqueeze(0).permute(0, 3, 1, 2).to(DEVICE)

        output = model(input_tensor)
        if isinstance(output, (list, tuple)):
            # With deep supervision the model may return one output per head; use the last
            output = output[-1]
        output = torch.sigmoid(output).permute(0, 2, 3, 1).cpu().numpy()

        # Scale to 0-255 and zero out low-confidence pixels
        pred = output[0] * 255
        pred = np.where(pred < PRED_THRESHOLD, 0, pred)
        # Grey -> 3 channels; assumes a single output channel (num_classes == 1) — TODO confirm
        pred_ = np.repeat(pred, 3, -1).astype(np.uint8)

        # Blend original image with predicted mask
        output_final = blend_images(img, pred_)[:, :, :3]

        _put_label(img, "Original Image", 70)
        _put_label(gt, "GroundTruth Mask", 60)
        _put_label(pred_, "Predicted Mask", 70)
        _put_label(output_final, "Blended Images", 60)

        # Side-by-side strip of the four IMG_SIZE x IMG_SIZE panels
        strip = np.concatenate([img, gt, pred_, output_final], axis=1)
        cv2.imwrite(os.path.join(result_folder, file_name), strip)

        if idx < N_DISPLAY:
            # Wide figure so the 4:1 strip is legible
            plt.figure(figsize=(16, 4))
            plt.imshow(cv2.cvtColor(strip, cv2.COLOR_BGR2RGB))
            plt.axis('off')
            plt.show()
