Skip to content

Commit

Permalink
Create a dataset from a periodic function
Browse files Browse the repository at this point in the history
  • Loading branch information
nschaetti committed Feb 20, 2019
1 parent 14ffcd4 commit c8866a4
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions echotorch/datasets/PeriodicSignalDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class PeriodicSignalDataset(Dataset):
"""

# Constructor
def __init__(self, sample_len, period, n_samples, start=0):
def __init__(self, sample_len, period, n_samples, height=1.8, start=0, dtype=torch.float32):
"""
Constructor
:param sample_len: Sample's length
Expand All @@ -25,6 +25,8 @@ def __init__(self, sample_len, period, n_samples, start=0):
self.n_samples = n_samples
self.period = period
self.start = start
self.height = height
self.dtype = dtype

# Period length
if type(period) is list:
Expand Down Expand Up @@ -73,21 +75,24 @@ def _generate(self):
# List of samples
samples = list()

# Pattern
maxVal = torch.max(self.period)
minVal = torch.min(self.period)
rp = self.height * (self.period - minVal) / (maxVal - minVal) - (self.height / 2.0)
p_length = rp.size(0)

# For each sample
for i in range(self.n_samples):
# Tensor
period_tensor = torch.FloatTensor(self.period)
sample = period_tensor.repeat(int(self.sample_len // self.period_length) + 1)
sample = torch.zeros(self.sample_len, 1, dtype=self.dtype)

# Start
if type(self.start) is list:
start = self.start[i]
else:
start = self.start
# end if
# Timestep
for t in range(self.sample_len):
sample[t, 0] = rp[(t + self.start) % p_length]
# end for

# Append
samples.append(sample[start:start+self.sample_len].unsqueeze(-1))
samples.append(sample)
# end for

return samples
Expand Down

0 comments on commit c8866a4

Please sign in to comment.