-
Notifications
You must be signed in to change notification settings - Fork 153
/
dataset.py
executable file
·104 lines (86 loc) · 3.39 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
from PIL import Image
import cv2
import torch
from torch.utils import data
from torchvision import transforms
from torchvision.transforms import functional as F
import numbers
import numpy as np
import random
class ImageDataTrain(data.Dataset):
    """Training dataset that yields an image together with its saliency label.

    Each non-blank line of the list file holds two whitespace-separated
    names: the image file followed by the ground-truth file. Both names are
    joined onto ``data_root`` when a sample is loaded.

    Args:
        data_root: Directory that the names in the list file are relative to.
        data_list: Path of the text file listing the training pairs.
    """

    def __init__(self, data_root, data_list):
        self.sal_root = data_root
        self.sal_source = data_list
        with open(self.sal_source, 'r') as f:
            # Drop blank lines (such as a trailing empty line) so that every
            # stored entry can be split into an image and a label name.
            stripped = (line.strip() for line in f)
            self.sal_list = [entry for entry in stripped if entry]
        self.sal_num = len(self.sal_list)

    def __getitem__(self, item):
        """Load, randomly flip and tensorize the pair selected by ``item``.

        The index is taken modulo the number of entries.

        Returns:
            A dict with the keys ``'sal_image'`` and ``'sal_label'``, each
            holding a ``torch.Tensor``.
        """
        # Split the entry once; a line with fewer than two names raises
        # IndexError here, before any file is opened.
        fields = self.sal_list[item % self.sal_num].split()
        im_name, gt_name = fields[0], fields[1]

        sal_image = load_image(os.path.join(self.sal_root, im_name))
        sal_label = load_sal_label(os.path.join(self.sal_root, gt_name))
        # The same flip decision is applied to image and label.
        sal_image, sal_label = cv_random_flip(sal_image, sal_label)

        return {
            'sal_image': torch.Tensor(sal_image),
            'sal_label': torch.Tensor(sal_label),
        }

    def __len__(self):
        """Return the number of entries read from the list file."""
        return self.sal_num
class ImageDataTest(data.Dataset):
    """Test dataset that yields an image, its list entry and its size.

    Each non-blank line of the list file is used as an image name that is
    joined onto ``data_root`` when the sample is loaded.

    Args:
        data_root: Directory that the names in the list file are relative to.
        data_list: Path of the text file listing the test images.
    """

    def __init__(self, data_root, data_list):
        self.data_root = data_root
        self.data_list = data_list
        with open(self.data_list, 'r') as f:
            # Drop blank lines so no entry resolves to the bare root directory.
            stripped = (line.strip() for line in f)
            self.image_list = [entry for entry in stripped if entry]
        self.image_num = len(self.image_list)

    def __getitem__(self, item):
        """Load the image selected by ``item``.

        Returns:
            A dict with the keys ``'image'`` (a ``torch.Tensor``), ``'name'``
            (the list entry) and ``'size'`` (the value returned alongside the
            image by ``load_image_test``).

        Raises:
            IndexError: If ``item`` is outside the range of the image list.
        """
        # Look the entry up once so the loaded file and the reported name can
        # never disagree; out-of-range indices raise IndexError.
        name = self.image_list[item]
        image, im_size = load_image_test(os.path.join(self.data_root, name))
        return {'image': torch.Tensor(image), 'name': name, 'size': im_size}

    def __len__(self):
        """Return the number of entries read from the list file."""
        return self.image_num
def get_loader(config, mode='train', pin=False):
    """Build a ``DataLoader`` for the training or the test dataset.

    Args:
        config: Object providing ``train_root``/``train_list`` (used when
            ``mode`` is ``'train'``) or ``test_root``/``test_list`` (any
            other mode), plus ``batch_size`` and ``num_thread``.
        mode: ``'train'`` selects the training dataset with shuffling; every
            other value selects the test dataset without shuffling.
        pin: Forwarded to the loader as ``pin_memory``.

    Returns:
        The configured ``data.DataLoader``.
    """
    is_train = mode == 'train'
    if is_train:
        dataset = ImageDataTrain(config.train_root, config.train_list)
    else:
        dataset = ImageDataTest(config.test_root, config.test_list)

    # Only the training loader shuffles its samples.
    return data.DataLoader(
        dataset=dataset,
        batch_size=config.batch_size,
        shuffle=is_train,
        num_workers=config.num_thread,
        pin_memory=pin,
    )
def load_image(path):
    """Read an image with OpenCV and return it as a channel-first array.

    Args:
        path: Location of the image file.

    Returns:
        A float32 numpy array: the image returned by ``cv2.imread`` with
        fixed per-channel offsets subtracted and its axes reordered with
        ``transpose((2, 0, 1))``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    if not os.path.exists(path):
        raise FileNotFoundError('File {} not exists'.format(path))

    im = cv2.imread(path)
    if im is None:
        # cv2.imread signals a decoding failure by returning None rather
        # than raising, so report it here with the offending path.
        raise ValueError('File {} could not be read as an image'.format(path))

    in_ = np.array(im, dtype=np.float32)
    # Presumably the per-channel dataset mean in OpenCV's B, G, R channel
    # order -- TODO confirm against the pretrained backbone in use.
    in_ -= np.array((104.00699, 116.66877, 122.67892))
    return in_.transpose((2, 0, 1))
def load_image_test(path):
    """Read a test image with OpenCV and also report its spatial size.

    Args:
        path: Location of the image file.

    Returns:
        A tuple ``(image, im_size)``. ``image`` is a float32 numpy array: the
        image returned by ``cv2.imread`` with fixed per-channel offsets
        subtracted and its axes reordered with ``transpose((2, 0, 1))``.
        ``im_size`` is a tuple of the first two dimensions of the array
        before that reordering.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    if not os.path.exists(path):
        raise FileNotFoundError('File {} not exists'.format(path))

    im = cv2.imread(path)
    if im is None:
        # cv2.imread signals a decoding failure by returning None rather
        # than raising, so report it here with the offending path.
        raise ValueError('File {} could not be read as an image'.format(path))

    in_ = np.array(im, dtype=np.float32)
    # Record the size before the axes are reordered below.
    im_size = tuple(in_.shape[:2])
    # Presumably the per-channel dataset mean in OpenCV's B, G, R channel
    # order -- TODO confirm against the pretrained backbone in use.
    in_ -= np.array((104.00699, 116.66877, 122.67892))
    return in_.transpose((2, 0, 1)), im_size
def load_sal_label(path):
    """Read a saliency label image and scale it by 1/255.

    Args:
        path: Location of the label image file.

    Returns:
        A float32 numpy array holding the label divided by 255 with a new
        leading axis prepended. When the decoded image has three dimensions
        only index 0 of its last axis is kept.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError('File {} not exists'.format(path))

    # Convert inside the context manager so the pixel data is read before
    # the underlying file handle is released.
    with Image.open(path) as im:
        label = np.array(im, dtype=np.float32)

    if label.ndim == 3:
        # Multi-channel label: keep the first channel only.
        label = label[:, :, 0]
    label = label / 255.
    return label[np.newaxis, ...]
def cv_random_flip(img, label):
    """Reverse the last axis of ``img`` and ``label`` with probability 1/2.

    A single random draw decides for both arrays, so they are either both
    reversed or both returned untouched.

    Args:
        img: Array indexed with three axes.
        label: Array indexed with three axes.

    Returns:
        The tuple ``(img, label)``: the original objects when no flip is
        drawn, otherwise fresh copies reversed along the third axis.
    """
    if random.randint(0, 1) != 1:
        return img, label

    # .copy() materializes the reversed view as a new array.
    mirrored_img = img[:, :, ::-1].copy()
    mirrored_label = label[:, :, ::-1].copy()
    return mirrored_img, mirrored_label