From 5d2435a612ff630602457b2f3aab03edac41fad9 Mon Sep 17 00:00:00 2001 From: edenton Date: Thu, 22 Feb 2018 18:34:20 -0500 Subject: [PATCH] updated data loading to allow analogy plot to work. --- data/kth.py | 8 +++++--- train_drnet.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/data/kth.py b/data/kth.py index c539cf5..abb1e9d 100644 --- a/data/kth.py +++ b/data/kth.py @@ -17,9 +17,11 @@ def __init__(self, train, data_root, seq_len = 20, image_size=64, data_type='drn self.dirs = os.listdir(self.data_root) if train: + self.train = True data_type = 'train' self.persons = list(range(1, 21)) else: + self.train = False self.persons = list(range(21, 26)) data_type = 'test' @@ -76,10 +78,10 @@ def __getitem__(self, index): random.seed(index) np.random.seed(index) #torch.manual_seed(index) - if self.data_type == 'drnet': - return torch.from_numpy(self.get_drnet_data()) - elif self.data_type == 'sequence': + if not self.train or self.data_type == 'sequence': return torch.from_numpy(self.get_sequence()) + elif self.data_type == 'drnet': + return torch.from_numpy(self.get_drnet_data()) else: raise ValueError('Unknown data type: %d. Valid type: drnet | sequence.' % self.data_type) diff --git a/train_drnet.py b/train_drnet.py index 1f79069..9df3768 100644 --- a/train_drnet.py +++ b/train_drnet.py @@ -169,7 +169,6 @@ def plot_rec(x, epoch): def plot_analogy(x, epoch): x_c = x[0] - h_c = netEC(x_c) nrow = 10