In [1]:
from PIL import Image
import torch
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
import os
from torchvision import transforms
import torchvision
import numpy as np
from random import sample 
import re
import cv2

def ls(directory):
    return [os.path.join(directory, path) for path in os.listdir(directory)]
    

class ChangeDetectionDataset(Dataset):
    def __init__(self, root, seq_length = 5, transform = None, distortion = None, size=(256,256)):
        self.root = root
        self.seq_length = seq_length
        
        all_imgs = os.listdir(root)
        self.image_paths = sorted(all_imgs)
        self.image_paths = [os.path.join(self.root, image_path) for image_path in self.image_paths] 
        self.images = [Image.open(img_loc).convert("RGB") for img_loc in self.image_paths]
        
        if transform is None:
            self.transform = transforms.Compose([transforms.RandomCrop(size),lambda x: np.array(x),transforms.ToTensor()])
        
        if distortion is None:
            self.distortion= transforms.Compose([transforms.ColorJitter(0.6,0.6,0.5,0.1),transforms.RandomErasing()])
        
    def __getitem__(self, index):
        y = self.images[index]
        if self.transform:
            y = self.transform(y)
        xs = [self.distortion(y.clone()) for _ in range(self.seq_length)]
        xs = torch.stack(xs)
        return xs, y
    
    def __len__(self):
        return len(self.images)
    
    def show(self, idx):
        return self.images[idx]

In [2]:
root = "data/dataset2014/dataset/"
categories = ls(root)
videos =  [ls(category) for category in categories]
video = videos[0][0]

root = video
root


'data/dataset2014/dataset/badWeather\\blizzard'

In [3]:
class ChangeDetectionVideo(Dataset):
    def __init__(self, root, seq_length = 10, transform = None, size=(256,256)):
        self.root = root
        self.seq_length = seq_length
        
        
        #  Extract temporal ROI (frames that have a GT label)
        f = open(os.path.join(root, "temporalROI.txt"), "rt")
        data = f.read()
        f.close()
        data = data.split()
        low, high = int(data[0]) - 1,int(data[1]) - 1

        
        # Paths for images and GT labels 
        image_files = sorted(ls(root+"/input"))[low:high]
        gt_files = sorted(ls(root+"/groundtruth"))[low:high]
        
        
        if transform is None:
            self.transform = transforms.Compose([transforms.Resize(size),lambda x: np.array(x),transforms.ToTensor()])
        else:
            self.transform = transform
        
        # Load images in PIL format
        self.image_files = image_files
        self.images = [self.transform(Image.open(img_loc).convert("RGB")) for img_loc in self.image_files]
        
        # Load GT labels in PIL format
        self.gt_files = gt_files
        self.gt = [self.transform(Image.open(img_loc).convert("RGB")) for img_loc in self.gt_files]
        

        
        
    def __getitem__(self, index):
        labels = self.gt[index:index+self.seq_length]
        labels = torch.stack(labels)
        xs = self.images[index:index+self.seq_length]
        xs = torch.stack(xs)
        return xs, labels
    
    def __len__(self):
        return len(self.images) - self.seq_length + 1
    
dataset = ChangeDetectionVideo(root)

In [7]:
x,y = dataset[len(dataset)-1]
y

tensor([[[[0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          ...,
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667]],

         [[0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          ...,
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667]],

         [[0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0