In [None]:
import zipfile
import os

def unzip_dataset(zip_file, extract_folder):
    """Extract every member of a zip archive into a destination folder.

    Args:
        zip_file (str): Path to the zip archive to extract.
        extract_folder (str): Directory to extract into. It is created
            (including missing parents) if it does not exist yet.

    Returns:
        None. A message is printed both when the archive is missing and
        when the extraction succeeds.
    """
    # isfile (rather than exists) also rejects a directory path, which would
    # otherwise pass the guard, create the output folder, and only then fail
    # inside ZipFile with an OS-level error.
    if not os.path.isfile(zip_file):
        print(f"Error: Zip file '{zip_file}' not found.")
        return

    # Only create the destination once we know there is something to extract
    os.makedirs(extract_folder, exist_ok=True)

    # The context manager guarantees the archive handle is closed afterwards
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(extract_folder)

    print(f"Dataset extracted successfully to '{extract_folder}'.")

# Archive to unpack and the directory that receives its contents.
# Adjust both paths to match your own environment.
zip_file = '/teamspace/studios/this_studio/C2Seg_AB_splitted.zip'
extract_folder = '/teamspace/studios/this_studio/dataset'

# Run the extraction for the paths configured above
unzip_dataset(zip_file=zip_file, extract_folder=extract_folder)


In [None]:
!pip install rasterio
!pip install segmentation_models_pytorch
!pip install -U albumentations
!pip install torchmetrics

In [None]:
import os
import torch
from torch.utils.data import Dataset
import rasterio
import numpy as np
import torchvision.transforms as T

class CustomDataset(Dataset):
    """Dataset of paired MSI/SAR rasters with a matching label raster.

    ``root_dir`` must contain three sub-folders named ``msi``, ``sar`` and
    ``label``. The file names found in ``msi`` define the samples; a file of
    the same name is read from ``sar`` and ``label`` for every sample.

    Each item is a ``(combined_image, label_image)`` tuple of float32 tensors:
    the MSI and SAR bands are min-max normalised independently of each other
    and concatenated along the band axis, and the label is the first band of
    the label raster.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with the ``msi``, ``sar`` and
                ``label`` sub-folders.
            transform (callable, optional): Optional transform applied to the
                ``(image, label)`` tuple of a sample; it must return the
                transformed ``(image, label)`` pair.
        """
        self.root_dir = root_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping deterministic across runs and machines.
        # Assumes all three sub-folders contain the same file names.
        self.image_names = sorted(os.listdir(os.path.join(root_dir, 'msi')))

    def __len__(self):
        """Return the number of samples (files found in the ``msi`` folder)."""
        return len(self.image_names)

    @staticmethod
    def _min_max_normalize(array):
        """Scale ``array`` to [0, 1] as float32 using its global min and max.

        A constant array has a zero value range; it is mapped to all zeros
        instead of dividing by zero, which would fill the sample with NaNs.
        """
        array = array.astype(np.float32)
        lowest = array.min()
        value_range = array.max() - lowest
        shifted = array - lowest
        if value_range == 0:
            # Every element equals the minimum, so `shifted` is already zero
            return shifted
        return shifted / value_range

    def __getitem__(self, idx):
        """Load, normalise and combine the rasters of sample ``idx``.

        Args:
            idx (int or torch.Tensor): Index of the sample.

        Returns:
            tuple: ``(combined_image, label_image)`` float32 tensors, passed
            through ``self.transform`` when one was supplied.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = self.image_names[idx]

        # The same file name is looked up in each of the three sub-folders
        msi_path = os.path.join(self.root_dir, 'msi', img_name)
        sar_path = os.path.join(self.root_dir, 'sar', img_name)
        label_path = os.path.join(self.root_dir, 'label', img_name)

        # read() without arguments returns every band as (bands, height, width);
        # the original notes expect 4 MSI bands — TODO confirm against the data
        with rasterio.open(msi_path) as msi_src:
            msi_image = self._min_max_normalize(msi_src.read())

        # Original notes expect 2 SAR bands — TODO confirm against the data
        with rasterio.open(sar_path) as sar_src:
            sar_image = self._min_max_normalize(sar_src.read())

        # Stack MSI and SAR along the band axis into a single multi-band image
        combined_image = np.concatenate((msi_image, sar_image), axis=0)

        # Only the first band of the label raster is used: (height, width)
        with rasterio.open(label_path) as label_src:
            label_image = label_src.read(1)

        combined_image = torch.tensor(combined_image, dtype=torch.float32)
        # NOTE(review): the label is returned as float32; losses that expect
        # integer class indices need a cast on the caller side — confirm.
        label_image = torch.tensor(label_image, dtype=torch.float32)

        if self.transform:
            combined_image, label_image = self.transform((combined_image, label_image))

        return combined_image, label_image

import numpy as np
import torch
import torchvision.transforms as T

# Custom transform function
class CustomTransform:
    """Random flip augmentation applied identically to an image and its label.

    The callable takes a ``(image, label)`` tuple and returns the transformed
    ``(image, label)`` tuple, which is the calling convention used by
    ``CustomDataset`` for its ``transform`` argument.
    """

    def __init__(self):
        """Build the augmentation pipeline stored in ``self.transform``."""
        self.transform = T.Compose([
            T.RandomHorizontalFlip(p=0.5),  # Flip left-right half of the time
            T.RandomVerticalFlip(p=0.5),    # Flip up-down half of the time
            # T.RandomRotation(degrees=30),  # Uncomment if you want to use rotation
        ])

    def __call__(self, sample):
        """Apply one random draw of the pipeline to both image and label.

        Args:
            sample (tuple): ``(image, label)`` tensors. The label is stacked
                onto the image as one extra leading plane, so it must have the
                image's shape without the first (channel) dimension.

        Returns:
            tuple: ``(image, label)`` with the input shapes and dtypes.
        """
        image, label = sample

        # Stack the label as an extra channel and transform once. This keeps
        # image and label in sync by construction, without reseeding the
        # global torch RNG for every sample (which would clobber the random
        # state used by everything else in the process).
        label_plane = label.unsqueeze(0).to(image.dtype)
        stacked = torch.cat((image, label_plane), dim=0)
        stacked = self.transform(stacked)

        # Split the stack back into the image channels and the label plane,
        # restoring the label's original dtype
        image = stacked[:-1]
        label = stacked[-1].to(label.dtype)
        return image, label



In [None]:


from torch.utils.data import DataLoader

# Number of samples per batch, shared by all three loaders
BATCH_SIZE = 16

# Location of each split inside the extracted dataset
train_root_dir = '/teamspace/studios/this_studio/dataset/C2Seg_AB_splitted/train'
test_root_dir = '/teamspace/studios/this_studio/dataset/C2Seg_AB_splitted/test'
val_root_dir = '/teamspace/studios/this_studio/dataset/C2Seg_AB_splitted/val'

# One dataset per split (no transform is passed, so samples are not augmented)
train_dataset = CustomDataset(root_dir=train_root_dir)
test_dataset = CustomDataset(root_dir=test_root_dir)
val_dataset = CustomDataset(root_dir=val_root_dir)

# Only the training loader shuffles; test and validation keep dataset order
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Sanity check: walk the validation loader and report, for every batch,
# the tensor shapes and the distinct label values it contains
for i, (images, labels) in enumerate(val_dataloader):
    print(images.shape, labels.shape)
    print(labels.unique())
    # Process your batch