Skip to content

Commit

Permalink
updated data loading to allow analogy plot to work.
Browse files Browse the repository at this point in the history
  • Loading branch information
edenton committed Feb 22, 2018
1 parent b0105b5 commit 5d2435a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
8 changes: 5 additions & 3 deletions data/kth.py
Expand Up @@ -17,9 +17,11 @@ def __init__(self, train, data_root, seq_len = 20, image_size=64, data_type='drn

self.dirs = os.listdir(self.data_root)
if train:
self.train = True
data_type = 'train'
self.persons = list(range(1, 21))
else:
self.train = False
self.persons = list(range(21, 26))
data_type = 'test'

Expand Down Expand Up @@ -76,10 +78,10 @@ def __getitem__(self, index):
random.seed(index)
np.random.seed(index)
#torch.manual_seed(index)
if self.data_type == 'drnet':
return torch.from_numpy(self.get_drnet_data())
elif self.data_type == 'sequence':
if not self.train or self.data_type == 'sequence':
return torch.from_numpy(self.get_sequence())
elif self.data_type == 'drnet':
return torch.from_numpy(self.get_drnet_data())
else:
raise ValueError('Unknown data type: %d. Valid type: drnet | sequence.' % self.data_type)

Expand Down
1 change: 0 additions & 1 deletion train_drnet.py
Expand Up @@ -169,7 +169,6 @@ def plot_rec(x, epoch):

def plot_analogy(x, epoch):
x_c = x[0]


h_c = netEC(x_c)
nrow = 10
Expand Down

0 comments on commit 5d2435a

Please sign in to comment.