In [None]:
import os
import numpy as np
import cv2
import torch
from monai.transforms import *
from monai.transforms.compose import Transform, Randomizable
from monai.data import Dataset, DataLoader
from monai.networks.nets import DenseNet121
import shutil
from PIL import Image

In [1]:
# Define custom transforms, the preprocessing pipeline, and the dataset class
class SumDimension(Transform):
    """Reduce one axis of the input by summing over it.

    Args:
        dim: index of the axis to collapse (default 1). It is passed
            positionally to the input's ``sum`` method.
    """

    def __init__(self, dim=1):
        self.dim = dim

    def __call__(self, inputs):
        axis = self.dim
        reduced = inputs.sum(axis)
        return reduced

class MyResize(Transform):
    """Resize an image with bicubic interpolation, then take a fixed crop.

    Args:
        size: target ``(height, width)`` for the resize (default ``(120, 120)``).

    After resizing, rows and columns ``30:90`` are kept. For the default
    120x120 size this is the central 60x60 region. The crop bounds do not
    follow ``size``, so a size smaller than 90 on an axis yields a shorter
    slice on that axis.
    """

    def __init__(self, size=(120, 120)):
        self.size = size

    def __call__(self, inputs):
        height, width = self.size[0], self.size[1]
        # cv2.resize takes dsize as (width, height), i.e. reversed.
        resized = cv2.resize(inputs, dsize=(width, height), interpolation=cv2.INTER_CUBIC)
        window = slice(30, 90)
        return resized[window, window]

class Astype(Transform):
    """Cast the input to a fixed dtype through its ``astype`` method.

    Args:
        type: dtype (name or object) forwarded to ``astype``; default
            ``'uint8'``. The parameter name shadows the builtin ``type`` but
            is kept so keyword callers keep working.
    """

    def __init__(self, type='uint8'):
        self.type = type

    def __call__(self, inputs):
        target_dtype = self.type
        converted = inputs.astype(target_dtype)
        return converted

class AddChannel(Transform):
    """Prepend a length-1 leading axis, giving a channel-first layout.

    NOTE(review): this name may shadow a same-named transform exported by the
    ``monai.transforms`` star import, depending on the MONAI version.
    """

    def __call__(self, img):
        return img[None, ...]

# Preprocessing pipeline applied to each image path at inference time.
# The steps run in order, and the order matters:
#   LoadImage(image_only=True) -> read the file at the given path into an array
#   Resize((-1, 1))  -> MONAI Resize; a non-positive size keeps that spatial
#                       dimension unchanged, the second spatial dim becomes 1.
#                       NOTE(review): MONAI treats axis 0 as channels, so with a
#                       channel-last image this appears to squeeze the colour
#                       axis down to length 1 — confirm against the loaded shape.
#   Astype()         -> cast to uint8
#   SumDimension(2)  -> sum over axis 2 (presumably the length-1 axis left by Resize)
#   Astype()         -> cast back to uint8 after the sum.
#                       NOTE(review): if axis 2 were longer than 1 the sum could
#                       exceed 255 and this cast would wrap — confirm.
#   MyResize()       -> bicubic resize to 120x120, then crop [30:90, 30:90]
#   AddChannel()     -> prepend a leading channel axis (the class defined above)
#   ToTensor()       -> convert to a tensor for the model
val_transforms = Compose([
    LoadImage(image_only=True),
    Resize((-1, 1)),
    Astype(),
    SumDimension(2),
    Astype(),
    MyResize(),
    AddChannel(),
    ToTensor(),
])

# Define custom dataset class
class MedNISTDataset(Dataset):
    """Map-style dataset pairing image paths with labels.

    Item ``index`` is ``(transforms(image_files[index]), labels[index])``. The
    transform receives the raw path, so the pipeline must start with a loader.

    Args:
        image_files: indexable sequence of image paths.
        labels: indexable sequence of labels aligned with ``image_files``.
        transforms: callable applied to each path.

    ``Dataset.__init__`` is not called; this class supplies its own
    ``__len__`` and ``__getitem__``.
    """

    def __init__(self, image_files, labels, transforms):
        self.image_files, self.labels, self.transforms = image_files, labels, transforms

    def __len__(self):
        count = len(self.image_files)
        return count

    def __getitem__(self, index):
        # Transform first, then look up the label.
        sample = self.transforms(self.image_files[index])
        return sample, self.labels[index]

# Directory whose sub-folders hold the staged images to classify.
editted_test_dir = './temp'

# Use the first CUDA device when available; otherwise fall back to CPU so the
# notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pre-trained classifier: 2-D DenseNet-121 with 1 input channel and 4 classes.
model = DenseNet121(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
).to(device)

# map_location remaps tensors saved on a GPU onto the selected device; without
# it a CUDA-saved checkpoint cannot be loaded on a CPU-only machine.
model.load_state_dict(torch.load('./models/sea_state_model.pth', map_location=device))
# Inference mode (affects layers such as dropout and batch norm).
model.eval()

# Directory containing the images to classify. Each file is moved out of here,
# classified, and written to sea_state_classified/<class number>/.
image_directory = './outputs'

# Staging folder: each image is moved here before inference.
staging_dir = os.path.join(editted_test_dir, "1")
os.makedirs(staging_dir, exist_ok=True)


def _save_classified(processed_path, class_number):
    """Resize a staged image back to its original size and file it by class.

    Args:
        processed_path: path of the staged image.
        class_number: 1-based predicted class; names the output sub-folder.

    The original image is opened from ``./inputs/images/`` using the third
    ``_``-separated field of the staged file name, and only its size is read.
    The staged image, resized to that size, is saved as
    ``sea_state_classified/<class_number>/<file name>``. The staged file is
    then deleted.
    """
    processed_path = str(processed_path)
    name = os.path.basename(processed_path)
    # NOTE(review): assumes staged names have at least three "_"-separated
    # fields and the third is the original image's file name — confirm.
    original_image_path = f"./inputs/images/{name.split('_')[2]}"
    with Image.open(original_image_path) as original_image:
        original_size = original_image.size

    out_dir = os.path.join("sea_state_classified", str(class_number))
    os.makedirs(out_dir, exist_ok=True)
    # Close the staged file before deleting it; an open handle blocks deletion
    # on Windows.
    with Image.open(processed_path) as image:
        image.resize(original_size).save(os.path.join(out_dir, name))
    os.remove(processed_path)


for filename in os.listdir(image_directory):
    file_path = os.path.join(image_directory, filename)
    # Skip sub-directories and anything else that is not a regular file.
    if not os.path.isfile(file_path):
        continue

    shutil.move(file_path, os.path.join(staging_dir, filename))

    # Gather every image currently under the temp directory. The sorted
    # position of each class sub-folder becomes the label of its images.
    t_class_names = sorted(
        d for d in os.listdir(editted_test_dir)
        if os.path.isdir(os.path.join(editted_test_dir, d))
    )
    t_image_files = [[os.path.join(editted_test_dir, t_class_name, x)
                      for x in os.listdir(os.path.join(editted_test_dir, t_class_name))]
                     for t_class_name in t_class_names]

    # Flatten into parallel path / label lists.
    t_image_file_list = [x for sublist in t_image_files for x in sublist]
    t_image_label_list = [i for i, sublist in enumerate(t_image_files) for _ in sublist]
    testX, testY = np.array(t_image_file_list), np.array(t_image_label_list)

    editted_test_ds = MedNISTDataset(testX, testY, val_transforms)
    # shuffle defaults to False, so batches follow testX order. That is what
    # lets each prediction be matched back to its own source path below.
    editted_test_loader = DataLoader(editted_test_ds, batch_size=32, num_workers=2)

    with torch.no_grad():
        offset = 0
        for test_images, _ in editted_test_loader:
            pred = model(test_images.to(device).float()).argmax(dim=1)
            batch_paths = testX[offset:offset + len(pred)]
            offset += len(pred)
            # Each image is filed under its own prediction (1-based class).
            for processed_path, class_idx in zip(batch_paths, pred.tolist()):
                _save_classified(processed_path, class_idx + 1)
