In [1]:
import os
import numpy as np
import cv2
import torch
from monai.transforms import *
from monai.data import Dataset, DataLoader
from monai.networks.nets import DenseNet121
import shutil
import os
from PIL import Image
# assign directory

## Define the MONAI transforms, Dataset and DataLoader used to pre-process the data
class SumDimension(Transform):
    """Transform that collapses one axis of its input by summing over it.

    Args:
        dim: axis passed to the input's ``sum`` method on every call
            (defaults to 1).
    """

    def __init__(self, dim=1):
        self.dim = dim

    def __call__(self, inputs):
        """Return ``inputs`` summed along ``self.dim``."""
        reduce_axis = self.dim
        collapsed = inputs.sum(reduce_axis)
        return collapsed
class MyResize(Transform):
    """Resize an image with OpenCV, then keep a fixed window of the result.

    Args:
        size: ``(rows, cols)`` of the intermediate resized image. It is
            reversed into the ``(width, height)`` order that
            ``cv2.resize`` expects for ``dsize``.
    """

    def __init__(self, size=(120,120)):
        self.size = size

    def __call__(self, inputs):
        """Return the ``[30:90, 30:90]`` window of the bicubic-resized input."""
        target_rows, target_cols = self.size[0], self.size[1]
        resized = cv2.resize(
            inputs,
            dsize=(target_cols, target_rows),
            interpolation=cv2.INTER_CUBIC,
        )
        # The crop bounds are fixed; with the default 120x120 size this is
        # the central 60x60 region.
        return resized[30:90, 30:90]
class Astype(Transform):
    """Transform that casts its input via the input's own ``astype`` method.

    Args:
        type: dtype specifier forwarded to ``astype`` (defaults to
            ``'uint8'``).
    """

    def __init__(self, type='uint8'):
        self.type = type

    def __call__(self, inputs):
        """Return ``inputs`` converted to ``self.type``."""
        target_dtype = self.type
        return inputs.astype(target_dtype)

# Pre-processing pipeline applied, in order, to every image path at inference.
# NOTE(review): AddChannel is deprecated in newer MONAI releases -- confirm
# that the installed version still provides it.
val_transforms = Compose(
    [
        LoadImage(image_only=True),   # read the image file from its path
        Resize((-1, 1)),              # NOTE(review): presumably squeezes the last axis to length 1 -- confirm
        Astype(type='uint8'),
        SumDimension(dim=2),          # sum away axis 2
        Astype(type='uint8'),         # cast again after the sum
        MyResize(size=(120, 120)),    # cv2 resize, then the fixed [30:90, 30:90] crop
        AddChannel(),                 # prepend a channel axis (the model is built with in_channels=1)
        ToTensor(),
    ]
)

class MedNISTDataset(Dataset):
    """Dataset pairing image files with labels, transforming images lazily.

    Args:
        image_files: indexable collection of image paths.
        labels: indexable collection of labels, parallel to ``image_files``.
        transforms: callable applied to ``image_files[index]`` on access.
    """

    def __init__(self, image_files, labels, transforms):
        self.image_files, self.labels, self.transforms = (
            image_files,
            labels,
            transforms,
        )

    def __len__(self):
        """Return the number of image files."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Return ``(transformed image, label)`` for the sample at ``index``."""
        sample = self.transforms(self.image_files[index])
        target = self.labels[index]
        return sample, target

# ---------------------------------------------------------------------------
# Sea-state inference script.
#
# For every file in ./outputs:
#   1. stage it under <editted_test_dir>/1,
#   2. run it through the trained DenseNet121 classifier,
#   3. resize the staged image to the size of its source image in
#      ../inputs/val_images and save it under
#      sea_state_classified/<predicted class>/,
#   4. delete the staged copy.
# ---------------------------------------------------------------------------
editted_test_dir='./temp'

# Use the GPU when one is present and fall back to the CPU otherwise, so the
# script no longer crashes on CUDA-less machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet121(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
).to(device)

# map_location lets a checkpoint saved on a GPU load on whichever device was
# selected above.
model.load_state_dict(
    torch.load('./models/sea_state_model.pth', map_location=device)
)
model.eval()


#test
# Class folders found under the staging directory; a folder's position in the
# sorted list is its integer label.
# NOTE(review): the folders are presumably ['1', '2', '3', '4'] -- confirm.
t_class_names0 = os.listdir(editted_test_dir)
t_class_names = sorted(t_class_names0)
t_num_class = len(t_class_names)

#image directory
image_directory = './outputs'

for filename in os.listdir(image_directory):
    f = os.path.join(image_directory, filename)
    # Skip sub-directories and anything else that is not a regular file.
    if not os.path.isfile(f):
        continue

    # Stage the image in class folder '1'. That label is only a placeholder;
    # the true class is what the model predicts below.
    shutil.move(f, os.path.join(editted_test_dir, '1', filename))

    # Collect every staged file, grouped by class folder.
    t_image_files = [
        [
            os.path.join(editted_test_dir, t_class_name, x)
            for x in os.listdir(os.path.join(editted_test_dir, t_class_name))
        ]
        for t_class_name in t_class_names
    ]

    t_image_file_list = []
    t_image_label_list = []
    for class_label, class_files in enumerate(t_image_files):
        t_image_file_list.extend(class_files)
        t_image_label_list.extend([class_label] * len(class_files))

    testX=np.array(t_image_file_list)
    testY=np.array(t_image_label_list)

    editted_test_ds = MedNISTDataset(testX, testY, val_transforms)
    editted_test_loader = DataLoader(editted_test_ds, batch_size=32, num_workers=2)

    # The loader is not shuffled, so batches follow t_image_file_list order.
    # A running offset therefore maps each prediction back to the exact file
    # it was computed from.
    batch_start = 0
    with torch.no_grad():
        for test_data in editted_test_loader:
            test_images = test_data[0].to(device)
            pred = model(test_images.float()).argmax(dim=1)

            batch_files = t_image_file_list[batch_start:batch_start + len(pred)]
            batch_start += len(pred)

            for staged_path, class_index in zip(batch_files, pred.tolist()):
                staged_name = os.path.basename(staged_path)

                # NOTE(review): assumes staged names look like
                # "<a>_<b>_<source file name>" and that the source name itself
                # contains no underscore -- confirm the naming scheme.
                original_image_path = f"../inputs/val_images/{staged_name.split('_')[2]}"
                with Image.open(original_image_path) as original_image:
                    original_size = original_image.size

                # Output folders are numbered from 1, predictions from 0.
                out_dir = f"sea_state_classified/{class_index + 1}"
                os.makedirs(out_dir, exist_ok=True)

                # Read the STAGED file: the ./outputs path was already moved
                # away above and no longer exists.
                with Image.open(staged_path) as staged_image:
                    staged_image.resize(original_size).save(f"{out_dir}/{staged_name}")

                os.remove(staged_path)

ModuleNotFoundError: No module named 'torch'