-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_prefetcher.py
32 lines (28 loc) · 1.03 KB
/
data_prefetcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
class DataPrefetcher(object):
    """Fetch ``(rgb, t, gt)`` batches from a loader one step ahead of use.

    With CUDA available, the next batch is copied to the GPU on a dedicated
    side stream while the caller works on the current batch. Every tensor is
    converted to ``float``. Without CUDA, the tensors are only converted to
    ``float`` and stay on the device the loader produced them on.

    Usage::

        prefetcher = DataPrefetcher(loader)
        rgb, t, gt = prefetcher.next()
        while rgb is not None:
            ...
            rgb, t, gt = prefetcher.next()

    or, equivalently, ``for rgb, t, gt in DataPrefetcher(loader): ...``.
    """

    def __init__(self, loader):
        """Wrap *loader* and immediately start fetching its first batch.

        Args:
            loader: an iterable whose items unpack into three tensors
                ``(rgb, t, gt)``. Presumably a ``torch.utils.data.DataLoader``;
                NOTE(review): non-blocking host->device copies only overlap
                with compute when the source tensors are in pinned memory.
        """
        self.loader = iter(loader)
        # The side stream lets host->device copies overlap with work on the
        # current stream. None means CUDA is unavailable (CPU fallback).
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.preload()

    def _convert(self, tensor):
        """Return *tensor* as float, moved to the GPU when CUDA is in use."""
        if self.stream is not None:
            tensor = tensor.cuda(non_blocking=True)
        return tensor.float()

    def preload(self):
        """Load the next batch into ``next_rgb``, ``next_t`` and ``next_gt``.

        When the loader is exhausted, all three attributes are set to ``None``.
        """
        try:
            rgb, t, gt = next(self.loader)
        except StopIteration:
            self.next_rgb = self.next_t = self.next_gt = None
            return
        if self.stream is None:
            self.next_rgb, self.next_t, self.next_gt = (
                self._convert(rgb), self._convert(t), self._convert(gt))
            return
        # Enqueue the copies on the side stream so they overlap with compute.
        with torch.cuda.stream(self.stream):
            self.next_rgb, self.next_t, self.next_gt = (
                self._convert(rgb), self._convert(t), self._convert(gt))

    def next(self):
        """Return the prefetched ``(rgb, t, gt)`` batch and start the next one.

        Returns:
            The batch as a 3-tuple of float tensors, or ``(None, None, None)``
            once the loader is exhausted (and on every call after that).
        """
        rgb, t, gt = self.next_rgb, self.next_t, self.next_gt
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # Make later work on the current stream wait until the copies
            # for this batch have finished on the side stream.
            current.wait_stream(self.stream)
            # These tensors were allocated on the side stream. Recording the
            # current stream stops the caching allocator from giving their
            # memory to a later preload() while it is still being read here.
            for tensor in (rgb, t, gt):
                if tensor is not None:
                    tensor.record_stream(current)
        self.preload()
        return rgb, t, gt

    def __iter__(self):
        """Return self, so the prefetcher can be used in a ``for`` loop."""
        return self

    def __next__(self):
        """Return the next batch or raise ``StopIteration`` when exhausted."""
        rgb, t, gt = self.next()
        if rgb is None:
            raise StopIteration
        return rgb, t, gt