-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset.py
88 lines (82 loc) · 2.28 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch.utils.data as data
import torch
import numpy as np
import h5py
from skimage import transform,measure,color
# import cv2
import os
from PIL import Image
from torchvision import transforms
import random
from tqdm import tqdm
import pickle
class FastLoader(data.Dataset):
    """In-memory dataset backed by a single pickle file.

    The whole pickle is read into RAM once in ``__init__``; items are then
    served from the in-memory sequences.

    The pickle must hold a 4-item sequence ``(readHR, readLR, HD, LD)``:
    the high- and low-resolution images, plus two per-sample companion
    sequences whose entries are returned unchanged.

    WARNING: ``pickle.load`` can execute arbitrary code - only load files
    from a trusted source.
    """

    def __init__(self, file_path):
        """Load ``(readHR, readLR, HD, LD)`` from the pickle at ``file_path``.

        Raises:
            OSError: if ``file_path`` cannot be opened.
            ValueError / TypeError: if the unpickled object cannot be
                unpacked into exactly four items.
        """
        super(FastLoader, self).__init__()
        with open(file_path, 'rb') as f:
            # Named ``payload`` rather than ``data`` so it does not shadow
            # the module-level ``torch.utils.data`` alias.
            payload = pickle.load(f)
        self.readHR, self.readLR, self.HD, self.LD = payload

    def __getitem__(self, index):
        """Return ``(LR, HR, Ld, Hd)`` for sample ``index``.

        ``LR``/``HR`` are the stored images passed through torchvision's
        ``ToTensor``; ``Ld``/``Hd`` are the raw ``LD``/``HD`` entries.
        """
        to_tensor = transforms.ToTensor()
        HR = to_tensor(self.readHR[index])
        LR = to_tensor(self.readLR[index])
        Hd = self.HD[index]
        Ld = self.LD[index]
        return LR, HR, Ld, Hd

    def __len__(self):
        """Return the number of samples that can be indexed safely."""
        # __getitem__ indexes all four sequences, so the usable length is the
        # shortest of them; otherwise a short HD/LD would raise IndexError.
        return min(
            len(self.readHR),
            len(self.readLR),
            len(self.HD),
            len(self.LD),
        )
class FastLoader2(data.Dataset):
    """In-memory pickle-backed dataset that also yields a bicubic LR copy.

    The whole pickle is read into RAM once in ``__init__``. It must hold a
    4-item sequence ``(readHR, readLR, HD, LD)``. The image entries are used
    with ``.size`` and ``.resize(..., Image.BICUBIC)``, so they are expected
    to be PIL images.

    WARNING: ``pickle.load`` can execute arbitrary code - only load files
    from a trusted source.
    """

    def __init__(self, file_path):
        """Load ``(readHR, readLR, HD, LD)`` from the pickle at ``file_path``.

        Raises:
            OSError: if ``file_path`` cannot be opened.
            ValueError / TypeError: if the unpickled object cannot be
                unpacked into exactly four items.
        """
        super(FastLoader2, self).__init__()
        with open(file_path, 'rb') as f:
            # Named ``payload`` rather than ``data`` so it does not shadow
            # the module-level ``torch.utils.data`` alias.
            payload = pickle.load(f)
        self.readHR, self.readLR, self.HD, self.LD = payload

    def __getitem__(self, index):
        """Return ``(LR, HR, Ld, Hd, bic_HR)`` for sample ``index``.

        ``LR``/``HR`` are the stored images passed through torchvision's
        ``ToTensor``; ``Ld``/``Hd`` are the raw ``LD``/``HD`` entries;
        ``bic_HR`` is the LR image bicubic-resized (see note below) and then
        converted with ``ToTensor``.
        """
        im_h = self.readHR[index]
        im_l = self.readLR[index]
        # PIL's Image.size is (width, height), which is also the order
        # Image.resize expects.
        hr_width, hr_height = im_h.size
        # NOTE(review): the LR image is resized to TWICE the HR image's size,
        # so bic_HR does not match HR's resolution. If a bicubic baseline at
        # HR resolution was intended, this should probably be
        # im_l.resize(im_h.size, Image.BICUBIC). Behaviour kept - confirm.
        bic_h = im_l.resize((hr_width * 2, hr_height * 2), Image.BICUBIC)
        to_tensor = transforms.ToTensor()
        HR = to_tensor(im_h)
        LR = to_tensor(im_l)
        Hd = self.HD[index]
        Ld = self.LD[index]
        bic_HR = to_tensor(bic_h)
        return LR, HR, Ld, Hd, bic_HR

    def __len__(self):
        """Return the number of samples that can be indexed safely."""
        # __getitem__ indexes all four sequences, so the usable length is the
        # shortest of them; otherwise a short HD/LD would raise IndexError.
        return min(
            len(self.readHR),
            len(self.readLR),
            len(self.HD),
            len(self.LD),
        )
def cal_mean(image):
    """Return the arithmetic mean over every element of ``image``.

    ``image`` may be anything ``np.array`` accepts. The mean is taken over
    the flattened array and returned as a NumPy scalar.
    """
    pixels = np.array(image)
    return pixels.mean()
def test():
file_path = "./"
dfi = FastLoader(file_path)