# In [None]:
import torch
from torch import nn
from torchmetrics.functional import jaccard_index
from torchmetrics.functional.classification import multiclass_accuracy
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from transformers import get_cosine_schedule_with_warmup
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation

from transformers import SegformerImageProcessor
import pandas as pd 
from torch.utils.data import Dataset, random_split
from torch.utils.data import DataLoader
import os
from PIL import Image
import numpy as np
import wandb
import matplotlib.pyplot as plt
import random

# adapted from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SegFormer/Fine_tune_SegFormer_on_custom_dataset.ipynb
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, root_dir):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            image_processor (SegformerImageProcessor): image processor to prepare images + segmentation maps.
        """
        self.root_dir = root_dir
        self.image_processor = SegformerImageProcessor(
            image_mean = [74.90, 85.26, 80.06], # use mean calculated over our dataset
            image_std = [15.05, 13.88, 12.01], # use std calculated over our dataset
            do_reduce_labels=False
            )

        self.img_dir = os.path.join(self.root_dir, "images")
        self.ann_dir = os.path.join(self.root_dir, "masks")
        
        # Get all image filenames without extension
        dataframe = pd.read_csv(
            f"{root_dir}/orig_palsa_labels.csv", 
            names=['filename', 'palsa'], 
            header=0
            )
        
        dataframe = dataframe.loc[dataframe['palsa']>0]
        dataframe = dataframe[~dataframe['filename'].str.endswith('aug')]
        checked_names = list(dataframe['filename'])
        self.filenames = [os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f[:-4] in checked_names]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_name = self.filenames[idx]
        img_path = os.path.join(self.img_dir, f"{img_name}.jpg")
        ann_path = os.path.join(self.ann_dir, f"{img_name}.png")

        image = Image.open(img_path)
        segmentation_map = Image.open(ann_path)

        # randomly crop + pad both image and segmentation map to same size
        encoded_inputs = self.image_processor(image, segmentation_map, return_tensors="pt")

        for k,v in encoded_inputs.items():
          encoded_inputs[k].squeeze_() # remove batch dimension

        return encoded_inputs
    
# Create the full dataset
root_dir = "/root/Permafrost-Segmentation/Pseudomasks/TRAIN"
full_dataset = SemanticSegmentationDataset(root_dir)

# Display a random dataset sample (raw and preprocessed image) and print per-channel statistics of the raw image.
def print_value_ranges(dataset):
    idx = random.randint(0, len(dataset) - 1)  # Select a random index
    encoded_inputs = dataset[idx]  # Get the processed image and mask

    # Extract the image
    import matplotlib.pyplot as plt

    # Get the image name
    img_name = dataset.filenames[idx]  # Get the corresponding image filename
    img_path = os.path.join("/root/Permafrost-Segmentation/Pseudomasks/TRAIN/images", f"{img_name}.jpg")
    print(f"Image Name: {img_name}.jpg, Path: {img_path}")

    # Load the original image
    image = Image.open(img_path)

    # Plot the original image
    plt.imshow(image)
    plt.axis('off')  # Hide axis
    plt.title(f"Image: {img_name}.jpg")
    plt.show()

    # Plot the encoded image
    encoded_image = encoded_inputs['pixel_values'].squeeze(0).permute(1, 2, 0).numpy()  # Convert to numpy array
    plt.imshow(encoded_image)
    plt.axis('off')  # Hide axis
    plt.title(f"Encoded Image: {img_name}.jpg")
    plt.show()

    # Get value ranges for each channel
    image_np = np.array(image)  # Convert to numpy array for value range calculations
    channels = image_np.shape[2]
    for c in range(channels):
        channel_data = image_np[:, :, c]
        min_value = channel_data.min()
        max_value = channel_data.max()
        mean_value = channel_data.mean()
        std_value = channel_data.std()

        print(f"Channel {c}:")
        print(f"  Min: {min_value}, Max: {max_value}, Mean: {mean_value}, Std: {std_value}")


print_value_ranges(full_dataset)