In [1]:

import os
from PIL import Image
from pathlib import Path
from dotenv import load_dotenv
import dataset
import tifffile as tiff
import numpy as np

def _require_env_path(name):
    """Return the environment variable *name* as a ``Path``.

    Raises:
        RuntimeError: if the variable is unset or empty. Without this check
            ``Path(None)`` fails with an opaque TypeError, and an empty value
            would silently resolve to the current working directory.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name!r} is not set; define it in your .env file."
        )
    return Path(value)


def _to_uint8(image):
    """Linearly stretch *image* to 0-255 and return it as a uint8 array.

    Negative values are clipped to 0 before stretching (equivalent to the
    previous ``np.clip(image, 0, np.max(image))``). Non-finite values
    (NaN/inf) are excluded from the min/max computation and written as 0, so a
    single NaN no longer corrupts the whole image. A constant image, or one
    with no finite values at all, becomes all zeros instead of triggering a
    0/0 division.
    """
    arr = np.asarray(image, dtype=np.float64)
    finite = np.isfinite(arr)
    if not finite.any():
        return np.zeros(arr.shape, dtype=np.uint8)

    arr = np.clip(np.where(finite, arr, 0.0), 0.0, None)
    lo = arr[finite].min()
    hi = arr[finite].max()
    if hi <= lo:
        # Constant image: there is no range to stretch.
        return np.zeros(arr.shape, dtype=np.uint8)

    scaled = (arr - lo) / (hi - lo) * 255.0
    scaled[~finite] = 0.0
    # astype truncates toward zero, matching the original conversion.
    return np.clip(scaled, 0.0, 255.0).astype(np.uint8)


def _jpeg_compatible(image_8bit):
    """Reduce a uint8 array to a shape that can be saved as JPEG (L or RGB).

    - ``(H, W, 1)`` is squeezed to ``(H, W)``; Pillow's ``Image.fromarray``
      cannot build an image from a single trailing channel.
    - More than three trailing bands are cut down to the first three.
    Any other shape is returned unchanged.
    """
    if image_8bit.ndim == 3:
        if image_8bit.shape[2] == 1:
            return image_8bit[:, :, 0]
        if image_8bit.shape[2] > 3:
            # NOTE(review): assumes channels-last (H, W, C) and that the first
            # three bands are the ones wanted for display; a channels-first
            # (C, H, W) array would be cropped along the width instead.
            return image_8bit[:, :, :3]
    return image_8bit


load_dotenv()

data_path = _require_env_path("prepare_dataset_folder")
processed_path = _require_env_path("processed_images_folder")
split = "train"  # Adjust as needed for 'val' or 'test'

# Ensure the processed data folder exists
processed_path.mkdir(parents=True, exist_ok=True)

# Instantiate the dataset. It has been seen to fail with
# "AssertionError: dataset_splits.csv not found in data_path." when that file
# is missing, so check the prepare_dataset_folder setting first.
forest_sat_dataset = dataset.MineSATDataset(split=split, data_path=data_path)

# Export each sample's images (everything except the mask) as JPEGs into a
# directory named after the sample's original directory.
for index in range(len(forest_sat_dataset)):
    filepath = forest_sat_dataset.filepaths[index]
    original_dir_name = Path(filepath).name

    images_dict = forest_sat_dataset.get_images(index)

    processed_dir_path = processed_path / original_dir_name
    processed_dir_path.mkdir(parents=True, exist_ok=True)

    for img_type, image in images_dict.items():
        # Only the index/RGB images are exported; the mask is skipped.
        if img_type == "Mask":
            continue

        image_8bit = _jpeg_compatible(_to_uint8(image))

        # The output file is named after the get_images() key: "<img_type>.jpg".
        jpg_path = processed_dir_path / f"{img_type}.jpg"
        Image.fromarray(image_8bit).save(jpg_path, "JPEG")

AssertionError: dataset_splits.csv not found in data_path.