In [9]:
import os

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

In [8]:
class CamoDataset(Dataset):
    """Camouflage image dataset, optionally paired with label images.

    Images and labels are paired *by position* after sorting each directory's
    file names, so both directories must contain the same number of files and
    their sorted orders must correspond (e.g. matching file stems).
    Hidden entries (names starting with ".") and subdirectories are ignored.

    Args:
        img_dir: Directory of input images; each is converted to RGB.
        label_dir: Optional directory of label images (loaded without mode
            conversion). When falsy, ``__getitem__`` returns only the image.
        transform: For unlabeled data, a callable ``transform(img) -> img``.
            For labeled data with ``target_transform=None``, a *joint* callable
            ``transform(img, label) -> (img, label)``. A plain single-argument
            ``torchvision.transforms.Compose`` will NOT work in that case.
            When ``target_transform`` is given, ``transform`` is applied to the
            image alone.
        target_transform: Optional callable applied to the label alone. Lets
            separate image/label pipelines be used instead of a joint one.

    Raises:
        ValueError: if ``label_dir`` is given and the image and label counts differ.
    """

    def __init__(self, img_dir, label_dir=None, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir

        self.transform = transform
        self.target_transform = target_transform

        self.img_names = self._list_files(img_dir)
        self.label_names = self._list_files(label_dir) if label_dir else None

        # Pairing is positional, so a count mismatch would silently misalign every pair.
        if self.label_names is not None and len(self.label_names) != len(self.img_names):
            raise ValueError(
                f"Found {len(self.img_names)} images in {img_dir!r} but "
                f"{len(self.label_names)} labels in {label_dir!r}; "
                "images and labels are paired by sorted position."
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of regular, non-hidden files in ``directory``."""
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load(path, mode=None):
        """Open an image, read its pixels into memory, and close the file handle."""
        with Image.open(path) as im:
            # Both convert() and copy() load the pixel data, so the returned
            # image stays usable after the underlying file is closed.
            return im.convert(mode) if mode else im.copy()

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img = self._load(os.path.join(self.img_dir, self.img_names[idx]), "RGB")

        if self.label_names is None:
            if self.transform:
                img = self.transform(img)
            return img

        label = self._load(os.path.join(self.label_dir, self.label_names[idx]))
        if self.target_transform is not None:
            # Separate pipelines: transform is image-only here.
            if self.transform:
                img = self.transform(img)
            label = self.target_transform(label)
        elif self.transform:
            # Original contract: a joint transform receives and returns both.
            img, label = self.transform(img, label)
        return img, label

In [None]:
# Configuration for this cell.
IMG_SIZE = (256, 256)
BATCH_SIZE = 8
SEED = 0  # seeds the DataLoader shuffle order so runs are reproducible

# Image pipeline: resize, convert to a float tensor, then normalize each RGB channel to [-1, 1].
transform = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Mask pipeline: nearest-neighbour resize so no new (interpolated) label values
# appear, and no RGB normalization. PILToTensor keeps the raw label values as
# uint8. NOTE(review): rescale/binarize downstream as the loss requires.
mask_transform = T.Compose([
    T.Resize(IMG_SIZE, interpolation=T.InterpolationMode.NEAREST),
    T.PILToTensor(),
])


def joint_transform(img, label):
    """Apply the image pipeline to ``img`` and the mask pipeline to ``label``.

    CamoDataset calls ``transform(img, label)`` for labeled data, so a plain
    single-argument ``T.Compose`` cannot be passed to the labeled dataset directly.
    """
    return transform(img), mask_transform(label)


data_path = os.path.abspath("../data/")

labeled_dataset = CamoDataset(
    os.path.join(data_path, "datasets", "camouflage_1", "img"),
    os.path.join(data_path, "datasets", "camouflage_1", "labels"),
    joint_transform,
)
unlabeled_dataset = CamoDataset(
    os.path.join(data_path, "raw", "camouflage_1"),
    transform=transform,
)

labeled_loader = DataLoader(
    labeled_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
unlabeled_loader = DataLoader(
    unlabeled_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)