In [11]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter



In [30]:
# Import the project-local X4K1000FPS dataset loader.
from X_ado_train_loader import X4K1000FPSDataset

In [None]:

# Custom Dataset Class
class CustomImageDataset(Dataset):
    """Dataset of consecutive frame triplets read from per-sequence subfolders.

    ``root_dir`` is expected to contain one subdirectory per sequence; every
    regular file inside a subdirectory is treated as a frame, ordered by file
    name. Each item is a sliding window of three consecutive frames
    ``(i, i+1, i+2)``, so a sequence with ``k`` frames yields ``max(k - 2, 0)``
    items.

    Hidden entries (names starting with ``'.'``, e.g. ``.DS_Store`` or
    ``.ipynb_checkpoints``) and nested directories are skipped, and
    subdirectories are visited in sorted order so that index -> triplet is
    deterministic across filesystems.

    Args:
        root_dir: Directory containing one subfolder per frame sequence.
        transform: Optional callable applied to each RGB ``PIL.Image``. Any
            frame that is not a ``torch.Tensor`` afterwards (including the
            ``transform=None`` case) is converted to a float32 CHW tensor
            scaled to ``[0, 1]``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_groups = []

        # os.listdir order is filesystem-dependent; sort so indexing is stable.
        for subdir in sorted(os.listdir(root_dir)):
            if subdir.startswith('.'):
                continue
            subdir_path = os.path.join(root_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue
            frames = [
                os.path.join(subdir_path, name)
                for name in sorted(os.listdir(subdir_path))
                if not name.startswith('.')
                and os.path.isfile(os.path.join(subdir_path, name))
            ]
            # Sliding window of 3 consecutive frames; empty when len(frames) < 3.
            for i in range(len(frames) - 2):
                self.image_groups.append(frames[i:i + 3])

    def __len__(self):
        return len(self.image_groups)

    def __getitem__(self, idx):
        """Return the ``idx``-th triplet stacked along a new leading dim 0.

        The result has shape ``(3, *frame_shape)``. All three frames must have
        the same shape after ``transform``, otherwise ``torch.stack`` raises.
        """
        frames = []
        for img_path in self.image_groups[idx]:
            # Context manager releases the file handle; convert() returns an
            # independent image, so it stays usable after the file is closed.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            if not isinstance(image, torch.Tensor):
                # torch.stack only accepts tensors; covers transform=None and
                # transforms that return PIL images.
                image = self._pil_to_tensor(image)
            frames.append(image)
        return torch.stack(frames, dim=0)

    @staticmethod
    def _pil_to_tensor(image):
        """Convert an 8-bit-per-band PIL image to a float32 (C, H, W) tensor in [0, 1].

        Uses only torch, so no torchvision/numpy dependency is needed.
        Assumes one byte per band (modes such as 'RGB' or 'L').
        """
        width, height = image.size
        channels = len(image.getbands())
        data = torch.frombuffer(bytearray(image.tobytes()), dtype=torch.uint8)
        return data.view(height, width, channels).permute(2, 0, 1).float().div(255.0)

# Resize-to-256 pipeline. Given its own name so it is not silently shadowed:
# the DataLoader cell below builds a separate `transform` (ToTensor only).
# Currently unused in this notebook.
resize_transform = transforms.Compose([
    transforms.Resize((256, 256)),  # fixed output size regardless of source resolution
    transforms.ToTensor(),          # PIL image -> tensor
])


In [33]:

# Dataset and DataLoader
# Dataset location: override with the X4K_ROOT environment variable instead of
# editing the notebook; the default keeps the original path.
DATA_ROOT = os.environ.get('X4K_ROOT', '/home/jyzhao/Code/Datasets/X4K1000FPS')
SEED = 0

transform = transforms.Compose([
    transforms.ToTensor(),  # the only step: convert each PIL frame to a tensor
])
dataset = X4K1000FPSDataset(root_dir=DATA_ROOT, transform=transform)
# Seeded generator so shuffle=True yields the same sample order on every
# Restart & Run All. The dataset itself may still apply its own randomness.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

In [34]:
# TensorBoard Setup
# The context manager closes (and flushes) the writer even if fetching or
# logging raises.
with SummaryWriter('runs/dataset_visualization') as writer:
    # Take element [0] of the first batch; with batch_size=1 and a
    # tensor-returning dataset this is the single sample.
    # NOTE(review): assumes X4K1000FPSDataset returns a frame stack shaped
    # (num_frames, C, H, W) — TODO confirm against X_ado_train_loader.
    images = next(iter(dataloader))[0]
    # add_images expects 4-D input (default dataformats='NCHW'); fail early
    # with a readable message instead of deep inside TensorBoard.
    assert images.dim() == 4, f'expected (N, C, H, W) frames, got {tuple(images.shape)}'
    writer.add_images('sample_images', images, 0)
print('Images logged to TensorBoard.')


Images logged to TensorBoard.
