In [1]:
import json
import os
import shutil
from collections import defaultdict

import cv2
import numpy as np
import torch
from monai.data import Dataset, DataLoader
from monai.networks.nets import DenseNet121
from monai.transforms import Compose, LoadImage, Resize, ToTensor, Activations, AsDiscrete
from PIL import Image  # to read images
from PIL import ImageFile  # holds the LOAD_TRUNCATED_IMAGES flag

In [2]:
# Let PIL decode truncated image files instead of raising an error.
# The flag is read from PIL.ImageFile; assigning it on the PIL.Image module (as this
# cell used to do) only creates an unused attribute and has no effect.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Path to the ground truth JSON file
gt_json_file = "./scripts/instances_val.json"
# Annotation category that gets cropped. Presumably "boat": the crops are classified by
# boat_model.pth in the next cell. Confirm against the JSON "categories" list.
TARGET_CATEGORY_ID = 2

# Read the annotations once and index the target boxes by image id, instead of
# re-opening and re-scanning the whole JSON file for every image on disk.
with open(gt_json_file, encoding="utf-8") as f:
    gt_data = json.load(f)

boxes_by_image_id = defaultdict(list)
for annotation in gt_data["annotations"]:
    if int(annotation["category_id"]) == TARGET_CATEGORY_ID:
        bbox = annotation["bbox"]
        # [x, y, width, height] -> (left, upper, right, lower), the box format Image.crop takes.
        boxes_by_image_id[int(annotation["image_id"])].append(
            [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
        )

# Crop every target box out of each sea-state image (classes 1..4) into ./crops/<class>/.
for ii in range(4):
    images_directory = f"./sea_state_classified/{ii+1}"
    crops_directory = f"./crops/{ii+1}"
    # Without this, every save into a missing folder failed and was silently ignored.
    os.makedirs(crops_directory, exist_ok=True)

    for filename in os.listdir(images_directory):
        # The source image id is the third "_"-separated field of the name, minus its extension.
        im_id = int(filename.split("_")[2].split(".")[0])
        boxes = boxes_by_image_id.get(im_id, [])
        if not boxes:
            continue

        image_path = f"{images_directory}/{filename}"
        stem = filename.split(".")[0]
        try:
            img = Image.open(image_path)
        except OSError as err:
            print(f"Skipping unreadable image {image_path}: {err}")
            continue

        # Open each image once and close it afterwards (it used to be re-opened per box
        # and never closed).
        with img:
            for bbox in boxes:
                crop_path = f"{crops_directory}/{stem}__{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg"
                try:
                    img.crop(bbox).save(crop_path)
                except OSError as err:
                    # Best effort: report the failure and continue with the next box.
                    print(f"Could not save crop {crop_path}: {err}")


In [3]:
# Define transforms and Dataloader 
class SumDimension:
    """Transform that removes one axis of its input by summing along it.

    Args:
        dim: index of the axis to sum over (default 1).
    """

    def __init__(self, dim=1):
        self.dim = dim

    def __call__(self, inputs):
        axis = self.dim
        reduced = inputs.sum(axis)
        return reduced

class MyResize:
    """Bicubic resize followed by a fixed crop of rows/columns 30:90.

    Args:
        size: target (height, width). It is passed to ``cv2.resize`` as
            ``dsize=(width, height)``. Default (120, 120).

    NOTE(review): the crop window is hard-coded to 30:90 on both axes. That is the
    central 60x60 patch only for the default 120x120 size; other sizes are not
    re-centred.
    """

    def __init__(self, size=(120, 120)):
        self.size = size

    def __call__(self, inputs):
        out_width, out_height = self.size[1], self.size[0]
        resized = cv2.resize(
            inputs,
            dsize=(out_width, out_height),
            interpolation=cv2.INTER_CUBIC,
        )
        row_window = slice(30, 90)
        col_window = slice(30, 90)
        return resized[row_window, col_window]

class Astype:
    """Transform that casts its input with ``.astype`` to a fixed dtype.

    Args:
        type: target dtype passed to ``astype`` (default 'uint8'). No clipping is
            done before the cast.
    """

    def __init__(self, type='uint8'):
        # `type` shadows the builtin only inside this method; kept for API compatibility.
        self.type = type

    def __call__(self, inputs):
        target_dtype = self.type
        converted = inputs.astype(target_dtype)
        return converted

class AddChannel:
    """Transform that prepends a length-1 leading axis (e.g. H x W -> 1 x H x W)."""

    def __call__(self, img):
        with_channel = img[None, ...]
        return with_channel

# Preprocessing applied to each crop path before it is fed to the 1-channel model.
# NOTE(review): the array layout produced by LoadImage depends on the MONAI version and
# image reader. The custom Astype transform calls `.astype`, so this pipeline assumes
# NumPy arrays at that point. Newer MONAI releases return tensors, so confirm the
# installed version.
# Steps, in order:
#   LoadImage        read the file and return only the image data (no metadata)
#   Resize((-1, 1))  MONAI Resize treats axis 0 as the channel axis; -1 keeps the first
#                    spatial axis and the last axis is resized to length 1. Presumably
#                    this collapses a trailing colour axis into a single value. Verify.
#   Astype()         cast to uint8
#   SumDimension(2)  drop the now length-1 axis 2 by summing over it
#   Astype()         cast back to uint8 (presumably because the sum widened the dtype)
#   MyResize()       bicubic resize to 120x120, then keep rows/cols 30:90
#   AddChannel()     prepend the channel axis
#   ToTensor()       convert to a tensor
val_transforms = Compose([
    LoadImage(image_only=True),
    Resize((-1, 1)),
    Astype(),
    SumDimension(2),
    Astype(),
    MyResize(),
    AddChannel(),
    ToTensor(),
])

# NOTE(review): `to_onehot` is not used anywhere in the visible notebook, and its 6
# classes disagree with the model's out_channels=2. `n_classes` may be deprecated or
# removed depending on the MONAI version. Confirm before reusing this.
to_onehot = AsDiscrete(to_onehot=6, n_classes=6)

class MedNISTDataset(Dataset):
    """Map-style dataset that pairs image paths with labels.

    Item ``index`` is ``(transforms(image_files[index]), labels[index])``. The
    transform runs lazily, each time an item is fetched.

    Args:
        image_files: indexable sequence of image paths.
        labels: indexable sequence of labels, aligned with ``image_files``.
        transforms: callable mapping a path to a model input.

    NOTE(review): the name presumably comes from the MONAI MedNIST tutorial; here the
    dataset wraps boat crops. ``super().__init__`` is not called, so this relies on
    ``__len__`` and ``__getitem__`` being fully overridden below.
    """

    def __init__(self, image_files, labels, transforms):
        self.image_files, self.labels, self.transforms = image_files, labels, transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        sample = self.transforms(self.image_files[index])
        return sample, self.labels[index]

# Load the boat classifier: DenseNet121, 1-channel 2-D input, 2 output classes.
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet121(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
).to(device)

# map_location lets GPU-saved weights load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load('./models/boat_model.pth', map_location=device))
model.eval()

# Classify every crop and move it to ./crops_checked/<sea state>/<predicted class>/.
# Crops are fed to the model directly, in batches. Previously each file was moved
# through ./temp_boat one at a time, and a new DataLoader was built per file; leftovers
# in ./temp_boat could then pair a prediction with the wrong file.
for ii in range(4):
    main_dir = f"./crops/{ii+1}"
    filenames = sorted(os.listdir(main_dir))
    if not filenames:
        continue

    crop_paths = [os.path.join(main_dir, name) for name in filenames]
    # Placeholder labels: the dataset API requires them, but they are not used here.
    placeholder_labels = [0] * len(crop_paths)
    crops_ds = MedNISTDataset(crop_paths, placeholder_labels, val_transforms)
    # shuffle defaults to False, so predictions come back in `filenames` order.
    crops_loader = DataLoader(crops_ds, batch_size=32, num_workers=2)

    y_pred = []
    with torch.no_grad():
        for batch_images, _ in crops_loader:
            pred = model(batch_images.to(device).float()).argmax(dim=1)
            y_pred.extend(pred.cpu().tolist())

    for filename, label in zip(filenames, y_pred):
        dest_dir = f"./crops_checked/{ii+1}/{label}"
        os.makedirs(dest_dir, exist_ok=True)
        shutil.move(f"{main_dir}/{filename}", f"{dest_dir}/{filename}")


In [4]:
# For each sea-state class (1..4), delete every image in ./sea_state_classified/<n>
# that has no crop in ./crops_checked/<n>/1 (crops the classifier assigned to class 1).
# WARNING: this permanently removes files.
# If ./crops_checked/<n>/1 is missing, os.listdir raises FileNotFoundError. That is
# deliberate: treating a missing folder as "no crops" would delete every image.
for ii in range(4):
    image_directory = f"./sea_state_classified/{ii+1}"
    crop_directory = f"./crops_checked/{ii+1}/1"

    # Crop names are "<image stem>__<x0>_<y0>_<x1>_<y1>.jpg", so the text before "__"
    # is the stem of the source image. A set gives O(1) membership tests (this was a
    # list, which made the loop below quadratic).
    stems_with_crops = {name.split('__')[0] for name in os.listdir(crop_directory)}

    for filename in os.listdir(image_directory):
        if filename.split('.')[0] not in stems_with_crops:
            os.remove(f"{image_directory}/{filename}")


In [7]:
# Disabled cleanup cell. When uncommented it permanently deletes every file in
# ./crops_checked/<sea state>/1 for all four sea states.
# for ii in range(4):
#     # Define the directory path for the sea state images.
#     image_directory = f"./sea_state_classified/{ii+1}"
#
#     # Define the directory path for checked crops corresponding to each sea state class.
#     crop_directory = f"./crops_checked/{ii+1}/1"
#     for filename in os.listdir(crop_directory):
#         os.remove(f"{crop_directory}/{filename}")