In [None]:
def base_wave(n_periods=1, phase_angle=0, arr_len=100):
    """Sample a sine wave spanning `n_periods` full cycles.

    Samples are evenly spaced from `phase_angle` to
    `phase_angle + 2*pi*n_periods` radians, both endpoints included.

    Args:
        n_periods: number of full 2*pi cycles covered.
        phase_angle: starting angle in radians.
        arr_len: number of samples to return.

    Returns:
        1-D numpy array of `arr_len` sine values.
    """
    start_angle = phase_angle
    end_angle = phase_angle + n_periods * np.pi * 2
    angles = np.linspace(start_angle, end_angle, arr_len)
    return np.sin(angles)

def bias_amplitude(arr, amplitude=.5, bias=.5):
    """Scale `arr` by `amplitude`, then shift it by `bias`.

    Each step is skipped when it would be a no-op: scaling when
    `amplitude == 1`, shifting when `bias` is falsy. If both are skipped,
    the input object itself is returned.

    Returns:
        `arr * amplitude + bias` (numpy array), or `arr` unchanged.
    """
    scaled = arr if amplitude == 1 else np.multiply(arr, amplitude)
    return np.add(scaled, bias) if bias else scaled

def pixel_period(arr, pixel_period=.5):
    """Resize `arr` to a length of `pixel_period` samples.

    Args:
        arr: 1-D array-like to resize.
        pixel_period: target length. A value strictly between 0 and 1 is a
            fraction of `len(arr)`; any other value is used as an absolute
            sample count (truncated to int).

    Returns:
        numpy array of the target length. Per `np.resize`, a longer target
        repeats `arr` cyclically and a shorter one truncates it.
    """
    period = pixel_period
    if 0 < period < 1:
        period = period * len(arr)
    # np.resize requires an integer shape, including for float counts >= 1.
    return np.resize(arr, int(period))

def pad_wave(arr, pad_size, pad_split, ret_indices=True):
    """Zero-pad a 1-D array on both sides.

    Args:
        arr: 1-D array to pad.
        pad_size: total number of zeros to add (left + right).
        pad_split: pair of fractions (left, right) summing to 1. The left
            share `int(pad_size * pad_split[0])` is truncated; the remainder
            goes to the right.
        ret_indices: when True, also return a boolean mask of the padding.

    Returns:
        The padded float array, or `(padded, indices)` when `ret_indices`
        is True. `indices` is True on padded positions and False on every
        position that came from `arr`.

    Raises:
        AssertionError: if `pad_split` does not sum to 1 or is not a pair.
        ValueError: (from `np.zeros`) if `pad_size` is negative.
    """
    assert round(sum(pad_split),3) == 1, f'pad_split {pad_split} must add to 1, not {sum(pad_split)}'
    assert len(pad_split) == 2
    left_pad_size = int(pad_size * pad_split[0])
    right_pad_size = pad_size - left_pad_size

    left_pad = np.zeros(left_pad_size)
    right_pad = np.zeros(right_pad_size)

    padded = np.concatenate((left_pad, arr, right_pad), axis=0)
    if not ret_indices:
        return padded
    # Build the mask from positions rather than values: a wave sample that
    # happens to equal 0 (e.g. sin(0) with zero bias) is signal, not padding.
    total_len = padded.shape[0]
    indices = np.ones(total_len, dtype=np.bool_)
    indices[left_pad_size:total_len - right_pad_size] = False
    return padded, indices

def pixel_wave(pixel_period, amplitude, bias, arr_len, pad_split=(.5,.5), **base_kwargs):
    """Generate a scaled, biased sine segment zero-padded to `arr_len` samples.

    Args:
        pixel_period: number of samples the sine segment occupies. A value
            strictly between 0 and 1 is a fraction of `arr_len`; the result
            is truncated to int.
        amplitude: scale factor passed to `bias_amplitude`.
        bias: offset passed to `bias_amplitude`.
        arr_len: total length of the returned array.
        pad_split: (left, right) fractions of the padding, see `pad_wave`.
        **base_kwargs: forwarded to `base_wave` (e.g. `n_periods`,
            `phase_angle`).

    Returns:
        Whatever `pad_wave` returns with its default `ret_indices=True`,
        i.e. `(wave, indices)`.

    Raises:
        ValueError: if the segment would be longer than `arr_len`.
    """
    if 0 < pixel_period < 1:
        # Fraction of the output length; arr_len is a sample count, not a
        # sequence, so multiply by it directly.
        pixel_period = pixel_period * arr_len
    # np.linspace (inside base_wave) needs an integer sample count.
    pixel_period = int(pixel_period)
    if pixel_period > arr_len:
        raise ValueError(
            f'pixel_period {pixel_period} exceeds arr_len {arr_len}'
        )
    wave = base_wave(arr_len=pixel_period, **base_kwargs)
    wave = bias_amplitude(wave, amplitude=amplitude, bias=bias)

    pad_size = arr_len - pixel_period
    return pad_wave(wave, pad_size, pad_split)

def verify_indices(indices, arr=None, pad_val=0):
    """Sanity-check a boolean padding/label mask.

    Args:
        indices: mask to check; must be 1-D with numpy bool dtype.
        arr: optional array the mask describes. When given, it must have the
            same shape as `indices`, and every position where
            `arr == pad_val` must be flagged True in `indices`.
        pad_val: value used for padding in `arr`.

    Raises:
        AssertionError: if any check fails.
    """
    assert indices.ndim == 1
    assert indices.dtype == np.bool_
    # `if arr:` raises "truth value of an array is ambiguous" for any
    # multi-element numpy array; test for presence explicitly.
    if arr is not None:
        assert arr.shape == indices.shape
        assert indices[arr == pad_val].all()
    
def noise_patch(arr, indices, start, stop, **noise_kwargs):
    """Noise `arr[start:stop]` in place and flag that span in `indices`.

    `indices[start:stop]` is set to True, then the slice of `arr` is replaced
    with `add_noise(arr[start:stop], **noise_kwargs)`. Both inputs are
    mutated in place and returned for convenience.
    """
    span = slice(start, stop)
    indices[span] = True
    arr[span] = add_noise(arr[span], **noise_kwargs)
    return arr, indices
    
def add_edge_noise(arr, indices, start=0, stop=-1, size_range=(.05,.1), pad_val=0, ret_indices=True, **noise_kwargs):
    """Add noise patches at both edges of the unflagged (wave) region.

    The wave region runs from the first to the last position where `indices`
    is falsy. A patch of `int(len(indices) * U(*size_range))` samples starting
    at the region's first sample, and an independently sized patch ending at
    (and including) its last sample, are noised via `noise_patch`, which also
    flags them True in `indices`. Both arrays are modified in place and
    returned.

    Args:
        arr: 1-D signal to perturb.
        indices: mask aligned with `arr`; falsy on wave samples.
        start, stop, pad_val, ret_indices: accepted for interface
            compatibility; currently unused.
        size_range: (low, high) bounds, as fractions of `len(indices)`, from
            which each patch length is drawn with `np.random.uniform`.
        **noise_kwargs: forwarded to `noise_patch`.

    Returns:
        `(arr, indices)`.

    Raises:
        ValueError: if `indices` has no unflagged positions.
    """
    total_len = len(indices)
    wave_positions = np.flatnonzero(np.logical_not(indices))
    if wave_positions.size == 0:
        raise ValueError('indices has no unflagged (wave) positions to add edge noise to')
    first, last = wave_positions[0], wave_positions[-1]

    left_size = int(total_len * np.random.uniform(*size_range))
    right_size = int(total_len * np.random.uniform(*size_range))

    arr, indices = noise_patch(arr, indices, first, first + left_size, **noise_kwargs)
    # Slice stops are exclusive: end at last + 1 so the final wave sample is
    # covered, mirroring the inclusive start of the left patch. Clamp the start
    # at 0 so an oversized patch cannot wrap around via a negative index.
    right_start = max(last + 1 - right_size, 0)
    arr, indices = noise_patch(arr, indices, right_start, last + 1, **noise_kwargs)
    return arr, indices

def plot_wave(wave, indices=None, i=''):
    """Plot `wave`, and optionally a mask, on a wide 16x4 seaborn figure.

    Args:
        wave: 1-D array of samples. The ndim check runs after the figure
            has been opened.
        indices: optional mask. It is checked with `verify_indices` and drawn
            as a second line labelled "Label".
        i: suffix for the legend labels. An int is prefixed with a space.
    """
    suffix = ' ' + str(i) if isinstance(i, int) else i
    plt.figure(figsize=(16, 4), dpi=200)
    xs = list(range(len(wave)))
    assert wave.ndim == 1
    sns.lineplot(x=xs, y=wave, label=f'Wave{suffix}')
    if indices is not None:
        verify_indices(indices)
        sns.lineplot(x=xs, y=indices, label=f'Label{suffix}')
    plt.legend()
    plt.show()

# Demo: build one padded 120-sample wave in a 200-sample frame and plot it.
arr, indices = pixel_wave(pixel_period=120, amplitude=.4, bias=.5, arr_len=200)
plot_wave(arr)
# The same wave with a linear upward drift of 0.005 per sample added.
slanted_arr = arr + np.linspace(0, len(arr) * .005, len(arr))
plot_wave(slanted_arr)
# Noise both edges of the wave (arr is modified in place), then show the
# noisy signal and the updated mask.
arr, indices = add_edge_noise(arr, indices)
plot_wave(arr)
plot_wave(indices)

In [None]:
# Preview base_wave's defaults: one period, zero phase, 100 samples.
wave = base_wave(n_periods=1, phase_angle=0, arr_len=100)
plot_wave(wave)

In [None]:
# Build a training set over a grid of (amplitude, pixel_period, phase_angle).
ARG_STEPS = 20
# np.arange's third argument is a step size, not a count: np.arange(.1, 1, 20)
# and np.arange(0, np.pi, 20) each yield a single value. linspace with
# endpoint=False samples ARG_STEPS values over the same half-open intervals.
amplitudes = np.linspace(.1, 1, ARG_STEPS, endpoint=False)
# Integer periods from 60 up to (excluding) SINE_LEN, every ARG_STEPS samples.
pixel_periods = np.arange(60, SINE_LEN, ARG_STEPS)
phase_angles = np.linspace(0, np.pi, ARG_STEPS, endpoint=False)

y_list = []
X_list = []
for i, args in enumerate(itertools.product(amplitudes, pixel_periods, phase_angles)):
    amplitude, pixel_period, phase_angle = args
    # One (amplitude, pixel_period, phase_angle) row per generated sample so y
    # aligns row-for-row with X. np.repeat would instead flatten args into a
    # 1-D array of length 3 * NOISE_SAMPLES.
    y_list.append(np.tile(args, (NOISE_SAMPLES, 1)))
    # NOTE(review): pixel_wave is deterministic, so these NOISE_SAMPLES rows
    # are identical until noise is added — confirm where noise is meant to go.
    X_list.append(
        np.vstack(
            [
                # pixel_wave returns (wave, indices); keep only the wave so
                # each sample contributes exactly one row of length SINE_LEN.
                pixel_wave(
                    pixel_period,
                    amplitude,
                    .5,
                    n_periods=1,
                    phase_angle=phase_angle,
                    arr_len=SINE_LEN,
                )[0]
                for _ in range(NOISE_SAMPLES)
            ]
        )
    )
    if i % 1000 == 0:
        sns.lineplot(data=X_list[-1].transpose())
        plt.show()

y_train = np.concatenate(y_list, axis=0)
print(y_train.shape)
X_train = np.concatenate(X_list, axis=0)
print(X_train.shape)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleNet(nn.Module):
    """Small MLP mapping an `in_features`-long input vector to one scalar.

    Layers: in_features -> 128 -> 64 -> 32 -> 1, with LeakyReLU after each
    hidden Linear. `forward` returns a pair of outputs.
    """

    def __init__(self, in_features=None):
        """Build the network.

        Args:
            in_features: length of each input vector. When None, the
                module-level SINE_LEN is used.
        """
        super().__init__()
        if in_features is None:
            in_features = SINE_LEN
        # Module names and Sequential layouts are unchanged, so state_dict
        # keys (hidden.0.weight, output.0.weight, ...) stay the same.
        self.hidden = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.LeakyReLU(),
        )
        self.output = nn.Sequential(
            nn.Linear(32, 1),
        )

    def forward(self, x):
        """Run `x` through the network twice.

        Returns:
            tuple `(output1, output2)`, each with the last dimension of `x`
            replaced by 1.
        """
        # NOTE(review): there are no stochastic layers, so both passes are
        # identical. This looks like a placeholder for a two-branch design;
        # confirm intent.
        output1 = self.output(self.hidden(x))
        output2 = self.output(self.hidden(x))
        return output1, output2

# Full-batch training of SimpleNet with NAdam on MSE loss.
model = SimpleNet()
criterion = nn.MSELoss()
optimizer = optim.NAdam(model.parameters())
# nn.Linear weights are float32. Convert once, outside the loop, so that a
# float64 numpy array is accepted; torch.as_tensor reuses an existing
# float32 tensor without copying.
# NOTE(review): `X` is not defined in the visible cells (X_train / y_train
# are) — confirm which arrays are intended.
inputs = torch.as_tensor(X, dtype=torch.float32)
# NOTE(review): targets = X has shape (N, in_features) while each head
# outputs (N, 1), so MSELoss broadcasts — confirm the intended targets.
targets = inputs
for epoch in range(N_EPOCHS):
    optimizer.zero_grad()
    outputs = model(inputs)
    # SimpleNet.forward returns a tuple of heads and MSELoss only accepts
    # tensors, so score each head separately and sum the losses.
    loss = sum(criterion(output, targets) for output in outputs)
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch} Loss: {loss}')