In [1]:
import os.path as osp
from PIL import Image

import torch.utils.data as data

In [None]:
class makeDatapathList():
    """Build index-aligned image / annotation path lists for a Pascal VOC split.

    Args:
        rootpath: VOC root directory (e.g. ``./data/VOCdevkit/VOC2012``), with or
            without a trailing path separator.

    Calling an instance with a phase name (e.g. ``"train"`` or ``"val"``) reads
    ``ImageSets/Segmentation/<phase>.txt`` under ``rootpath`` and returns
    ``[img_list, anno_list]`` where entry ``i`` of each list refers to the same id.
    """

    def __init__(self, rootpath):
        self.rootpath = rootpath
        # VOC stores RGB images under "JPEGImages" (plural) and class masks under
        # "SegmentationClass"; %s is filled with an id read from the split file.
        self.img_path_template = osp.join(rootpath, 'JPEGImages', '%s.jpg')
        self.anno_path_template = osp.join(rootpath, 'SegmentationClass', '%s.png')

    def make_list(self, phase):
        """Return ``[img_list, anno_list]`` for the ids listed in the ``phase`` split file.

        Blank lines in the split file are ignored.

        Raises:
            FileNotFoundError: if ``ImageSets/Segmentation/<phase>.txt`` is missing.
        """
        # Join path components rather than concatenating strings, so rootpath
        # works with or without a trailing separator.
        id_names = osp.join(self.rootpath, 'ImageSets', 'Segmentation', f'{phase}.txt')
        img_list = []
        anno_list = []
        # Context manager so the split file is always closed.
        with open(id_names, encoding='utf-8') as id_file:
            for line in id_file:
                file_id = line.strip()
                if not file_id:
                    # Skip blank lines instead of emitting ".jpg"/".png" paths.
                    continue
                img_list.append(self.img_path_template % file_id)
                anno_list.append(self.anno_path_template % file_id)
        return [img_list, anno_list]

    def __call__(self, phase):
        """Shorthand for ``make_list(phase)``."""
        return self.make_list(phase)

In [None]:
# Pascal VOC 2012 root; ImageSets/Segmentation/<split>.txt under it lists the ids per split.
rootpath = './data/VOCdevkit/VOC2012/'
datapath_list = makeDatapathList(rootpath)

# Each call yields [image_paths, annotation_paths], index-aligned; unpack train then val.
(train_img_list, train_anno_list), (val_img_list, val_anno_list) = (
    datapath_list(split) for split in ('train', 'val')
)

In [None]:
# check: number of image / annotation paths in each split
(len(train_img_list), len(train_anno_list), len(val_img_list), len(val_anno_list))

In [None]:
from utils.data_augumentation import Compose,Scale,RandomRotation,RandomMirror,Resize,Normalize_Tensor

class dataTransform():
    """Phase-keyed preprocessing applied jointly to an image and its annotation.

    Args:
        input_size: value handed to ``Resize`` in both pipelines.
        color_mean: per-channel mean handed to ``Normalize_Tensor``.
        color_std: per-channel std handed to ``Normalize_Tensor``.

    ``self.data_transform`` maps ``'train'`` and ``'val'`` to ``Compose`` pipelines.
    The train pipeline additionally runs ``Scale``, ``RandomRotation`` and
    ``RandomMirror`` (presumably random augmentations — see utils.data_augumentation)
    before the shared resize + normalize steps.
    """

    def __init__(self, input_size, color_mean, color_std):
        train_pipeline = Compose([
            Scale(scale=[0.5, 1.5]),
            RandomRotation(angle=[-10, 10]),
            RandomMirror(),
            Resize(input_size),
            Normalize_Tensor(color_mean, color_std),
        ])
        val_pipeline = Compose([
            Resize(input_size),
            Normalize_Tensor(color_mean, color_std),
        ])
        self.data_transform = dict(train=train_pipeline, val=val_pipeline)

    def __call__(self, phase, img, anno_class_img):
        """Run the pipeline registered for ``phase`` on the (img, annotation) pair."""
        pipeline = self.data_transform[phase]
        return pipeline(img, anno_class_img)

In [None]:
class VOCDataset(data.Dataset):
    """Pascal VOC segmentation dataset yielding preprocessed (image, annotation) pairs.

    Args:
        img_list: paths of the input images.
        anno_list: paths of the class-annotation images, index-aligned with ``img_list``.
        phase: key passed to ``transform`` (e.g. ``"train"`` or ``"val"``).
        transform: callable ``transform(phase, img, anno_class_img)`` returning the
            preprocessed ``(img, anno_class_img)`` pair.

    Raises:
        ValueError: if ``img_list`` and ``anno_list`` have different lengths.
    """

    def __init__(self, img_list, anno_list, phase, transform):
        # Items are looked up by the same index in both lists, so they must align.
        if len(img_list) != len(anno_list):
            raise ValueError(
                f"img_list and anno_list must have the same length "
                f"(got {len(img_list)} and {len(anno_list)})"
            )
        self.img_list = img_list
        self.anno_list = anno_list
        self.phase = phase
        self.transform = transform

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.img_list)

    def __getitem__(self, index):
        """Return the preprocessed ``(img, anno_class_img)`` pair at ``index``."""
        img, anno_class_img = self.pull_item(index)
        return img, anno_class_img

    def pull_item(self, index):
        """Open the image and annotation at ``index`` and apply the phase transform.

        Returns:
            The ``(img, anno_class_img)`` tuple produced by ``self.transform``.
        """
        img = Image.open(self.img_list[index])
        anno_class_img = Image.open(self.anno_list[index])
        img, anno_class_img = self.transform(self.phase, img, anno_class_img)
        # __getitem__ unpacks this pair, so it must be returned explicitly.
        return img, anno_class_img


In [None]:
# check: build both datasets and inspect the tensor shapes of one sample from each.
# Normalization statistics (the commonly used ImageNet RGB mean/std values).
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)
INPUT_SIZE = 475  # passed to Resize inside dataTransform

# One transform object serves both splits: the phase is selected per call.
voc_transform = dataTransform(input_size=INPUT_SIZE, color_mean=color_mean, color_std=color_std)
train_dataset = VOCDataset(train_img_list, train_anno_list, phase="train", transform=voc_transform)
val_dataset = VOCDataset(val_img_list, val_anno_list, phase="val", transform=voc_transform)

# Index each dataset once: indexing twice re-reads the files and, for "train",
# re-runs the (presumably random) augmentations, so the printed image and
# annotation shapes would come from different draws.
train_img, train_anno = train_dataset[0]
val_img, val_anno = val_dataset[0]
print(train_img.shape, train_anno.shape)
print(val_img.shape, val_anno.shape)


In [None]:
# DataLoader requires a positive batch size; batch_size=0 raises ValueError.
batch_size = 8
train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# No shuffling for validation, so evaluation order (and the batch inspected below) is reproducible.
val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}
datalodaers_dict = dataloaders_dict  # keep the original (misspelled) name working for later cells

# Pull one validation batch to confirm the collated tensor shapes.
batch_iterator = iter(dataloaders_dict["val"])
imgs, anno_class_imgs = next(batch_iterator)
print(imgs.size())
print(anno_class_imgs.size())