In [1]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms 
import torchvision

In [2]:
# utility functions
def showImage(img):
    plt.imshow(img)
    plt.axis('off')
    plt.show()
    return None

In [3]:
def load_and_sort_images(input_folder):
    files = os.listdir(input_folder)
    
    sorted_files = []
    
    for file_name in files:
        try:
            # Attempt to split the filename and convert the second part into an integer
            part = file_name.split('_')[1].split('.')[0]
            num = int(part)  # Attempt conversion
            # If splitting and conversion are successful, add the filename and number to the list of sorted files
            sorted_files.append((file_name, num))
        except (IndexError, ValueError):
            # If splitting or conversion fails, skip the file
            pass
    
    # Sort the list of sorted files based on the extracted numbers
    sorted_files.sort(key=lambda x: x[1])
    
    # Extract only the filenames from the sorted list
    sorted_filenames = [file_name for file_name, _ in sorted_files]
    
    return sorted_filenames

# Usage
input_folder = './dataset/Images'
output_folder = './output'
sorted_filenames = load_and_sort_images(input_folder)
print(sorted_filenames)


['image_0.jpg', 'image_1.jpg', 'image_2.jpg', 'image_3.jpg', 'image_4.jpg', 'image_5.jpg', 'image_6.jpg', 'image_7.jpg', 'image_8.jpg', 'image_9.jpg', 'image_10.jpg', 'image_11.jpg', 'image_12.jpg', 'image_13.jpg', 'image_14.jpg', 'image_15.jpg', 'image_16.jpg', 'image_17.jpg', 'image_18.jpg', 'image_19.jpg', 'image_20.jpg', 'image_21.jpg', 'image_22.jpg', 'image_23.jpg', 'image_24.jpg', 'image_25.jpg', 'image_26.jpg', 'image_27.jpg', 'image_28.jpg', 'image_29.jpg', 'image_30.jpg', 'image_31.jpg', 'image_32.jpg', 'image_33.jpg', 'image_34.jpg', 'image_35.jpg', 'image_36.jpg', 'image_37.jpg', 'image_38.jpg', 'image_39.jpg', 'image_40.jpg', 'image_41.jpg', 'image_42.jpg', 'image_43.jpg', 'image_44.jpg', 'image_45.jpg', 'image_46.jpg', 'image_47.jpg', 'image_48.jpg', 'image_49.jpg', 'image_50.jpg', 'image_51.jpg', 'image_52.jpg', 'image_53.jpg', 'image_54.jpg', 'image_55.jpg', 'image_56.jpg', 'image_57.jpg', 'image_58.jpg', 'image_59.jpg', 'image_60.jpg', 'image_61.jpg', 'image_62.jpg', '

In [4]:
# load in model
model_name = 'unetc_model'
model_path = f'./model/{model_name}.pth'

model = torch.load(model_path);
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device);
model.eval();

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import albumentations as A
import cv2
import torch
from torchvision import transforms
from PIL import Image
import numpy as np

imgCounter=-1
# Sorted filenames loop through 
for filename in sorted_filenames:
    imgCounter+=1
    #print(f'Filename = {filename}')
    
    # Open the image using PIL
    img = Image.open(os.path.join(input_folder, filename))
    
    # Show the image
    #showImage(img)
    
    # Apply the transformation
    transform = A.Resize(256, 256, interpolation=cv2.INTER_NEAREST)
    transformed_img = transform(image=np.array(img))['image']
    
    # Convert the transformed image back to PIL format
    transformed_img_pil = Image.fromarray(transformed_img)
    
    # Display the transformed image
    #showImage(transformed_img_pil)
    
    # Convert the transformed image to tensor and move to device
    input_tensor = transforms.ToTensor()(transformed_img_pil).unsqueeze(0).to(device)
    #print(f'Input Tensor Shape: {input_tensor.shape}')
    
    # Perform inference
    with torch.no_grad():
        output = model(input_tensor)
        #print(f'Output Shape: {output.shape}')
        
        output_clamped = torch.clamp(output, 0, 1)
        
        zero_tensor = torch.tensor(0., device=device, requires_grad=True)
        one_tensor = torch.tensor(1., device=device, requires_grad=True)
        output_bin = torch.where(output_clamped < 0.5, zero_tensor, one_tensor)
        
        # get the indices where output_bin == 1
        indices = torch.nonzero(output_bin == 1)
        
        # Extract row and column indices for one image
        row_indices = indices[:, 2]
        col_indices = indices[:, 3]
        
        for row, col in zip(row_indices, col_indices):
            # Change the pixel value to dark blue (0, 0, 139)
            transformed_img_pil.putpixel((col, row), (0, 0, 139))
        
        # Save the modified image
        transformed_img_pil.save(f'{output_folder}/image_{imgCounter}.jpg')

    if imgCounter % 20 == 0: 
        print(f'{imgCounter} images processed...')
    # Break loop after processing one image
    #break  

print(f'DONE')

0 images processed...
20 images processed...
40 images processed...
60 images processed...
80 images processed...
100 images processed...
120 images processed...
140 images processed...
160 images processed...
180 images processed...
200 images processed...
220 images processed...
240 images processed...
260 images processed...
280 images processed...
300 images processed...
320 images processed...
340 images processed...
360 images processed...
380 images processed...
400 images processed...
420 images processed...
440 images processed...
460 images processed...
480 images processed...
500 images processed...
520 images processed...
540 images processed...
560 images processed...
580 images processed...
600 images processed...
620 images processed...
640 images processed...
660 images processed...
680 images processed...
700 images processed...
720 images processed...
740 images processed...
760 images processed...
780 images processed...
800 images processed...
820 images processed..