In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from underwater_unet.model import UNet

% matplotlib inline

In [None]:
# U-Net from underwater_unet.model: one input channel (the MNIST images loaded
# below are grayscale) and two output classes.
model = UNet(n_channels=1, n_classes=2)

In [None]:
# MNIST test split, upscaled so it can be fed to the U-Net.
IMAGE_SIZE = (256, 256)  # resize target; presumably must suit the U-Net's downsampling depth — TODO confirm
BATCH_SIZE = 4
SEED = 0  # seeds the loader's shuffle so Restart & Run All yields the same batches

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),  # resizing to fit the U-Net architecture
    transforms.ToTensor(),
])

test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # A dedicated seeded generator makes the shuffle order reproducible
    # without touching the global RNG state.
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
def display_image_from_npy(npy_path, image_index, method="opencv"):
    """
    Load and display a single image from a .npy file.

    The file is memory-mapped, so only the selected image is read into memory
    rather than the whole array.

    Parameters:
    - npy_path (str): Path to the .npy file containing the images.
    - image_index (int): 0-based index of the image to display from the .npy file.
    - method (str): Method to use for displaying the image. Options are "opencv" or "matplotlib".

    Notes:
    - Colour images are assumed to be stored in OpenCV's BGR(A) channel order;
      the "matplotlib" path converts them to RGB (any alpha channel is dropped).
    - 2-D (or single-channel) images are shown in grayscale by "matplotlib".
    - Invalid arguments print a message and return None instead of raising.
    """
    # Validate the cheap argument first so a typo doesn't cost a file load.
    if method not in ("opencv", "matplotlib"):
        print("Invalid method. Choose 'opencv' or 'matplotlib'.")
        return

    # Memory-map instead of reading the whole dataset just to show one image.
    dataset = np.load(npy_path, mmap_mode="r")

    if len(dataset) == 0:
        print("The dataset is empty; there is no image to display.")
        return
    if image_index < 0 or image_index >= len(dataset):
        print(f"Invalid image index. Please provide an index between 0 and {len(dataset) - 1}.")
        return

    # Copy the selected image out of the memory map into a regular array.
    image = np.array(dataset[image_index])

    if method == "opencv":
        cv2.imshow(f'Image {image_index}', image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        return

    # matplotlib: drop a trailing singleton channel so it is treated as grayscale.
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]
    if image.ndim == 2:
        plt.imshow(image, cmap="gray")
    else:
        # cv2.cvtColor(BGR2RGB) rejects grayscale input and some dtypes (e.g.
        # float64); taking channels 2,1,0 is the same BGR(A) -> RGB conversion.
        plt.imshow(image[..., 2::-1])
    plt.title(f'Image {image_index}')
    plt.axis('off')
    plt.show()

In [None]:
# Path to the image array. Set the DATASET_PATH environment variable when the
# repository is not checked out at /workspaces/UnderWaterU-Net.
dataset_path = os.environ.get('DATASET_PATH', '/workspaces/UnderWaterU-Net/dataset.npy')
IMAGE_INDEX = 20  # which image to preview
display_image_from_npy(dataset_path, IMAGE_INDEX, method="matplotlib")