In [0]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset

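## WaveNet class
Two generation methods are sketched inside the class: `generate_slow()` is easier to follow, while `generate()` is much faster. Both are currently commented out; autoregressive generation is done in the last cell of this notebook instead.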

In [0]:
class Conv(nn.Module):
  """``nn.Conv1d`` with Xavier-uniform weight init and optional causal padding.

  When ``is_causal`` is True, ``forward`` left-pads the last (time) axis with
  ``(kernel_size - 1) * dilation`` zeros before convolving. With stride 1 the
  output then has the same length as the input, and output position t only
  depends on input positions <= t.

  Args:
    in_channels, out_channels, kernel_size, stride, dilation, bias: passed
      straight to ``nn.Conv1d``.
    w_init_gain: nonlinearity name given to ``nn.init.calculate_gain`` to
      scale the Xavier-uniform initialisation of the weights.
    is_causal: whether ``forward`` applies the left padding.
  """
  def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
               dilation=1, bias=True, w_init_gain='linear', is_causal=False):
    super(Conv, self).__init__()
    self.is_causal = is_causal
    self.kernel_size = kernel_size
    self.dilation = dilation

    self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size,
                          stride=stride, dilation=dilation, bias=bias)
    init_gain = nn.init.calculate_gain(w_init_gain)
    nn.init.xavier_uniform_(self.conv.weight, gain=init_gain)

  def forward(self, signal):
    if not self.is_causal:
      return self.conv(signal)
    # (left, right) padding of the time axis: only the past is padded.
    left_pad = int((self.kernel_size - 1) * self.dilation)
    padded = nn.functional.pad(signal, (left_pad, 0))
    return self.conv(padded)



class WaveNet(nn.Module):
    """Conditional WaveNet-style sequence regressor.

    A 1-D input sequence is projected to ``n_residue`` channels and passed
    through a stack of gated, causal, dilated (kernel size 2) residual layers.
    Every layer gets its own ``n_residue``-channel slice of a 1x1 projection
    of the conditioning features. The skip outputs of all layers are summed
    and mapped to ``mu`` output channels.

    Args:
        mu: number of input/output channels (originally described as the audio
            quantization size). ``preprocess`` turns a 1-D input into a single
            channel, and ``from_input`` has ``in_channels=mu``, so that path
            only works when ``mu == 1``.
        n_residue: channels in the residual stream.
        n_skip: channels in each skip connection.
        dilation_depth, n_repeat: layer dilations are
            ``1, 2, ..., 2**(dilation_depth - 1)``, repeated ``n_repeat`` times.
        n_cond_channel: number of conditioning features per time step.
    """
    def __init__(self, mu=1, n_residue=32, n_skip=512, dilation_depth=10, n_repeat=5, n_cond_channel=315):
        super(WaveNet, self).__init__()
        self.dilation_depth = dilation_depth
        self.dilations = [2 ** k for k in range(dilation_depth)] * n_repeat
        self.n_layers = dilation_depth * n_repeat

        def pointwise(n_in, n_out):
            # 1x1 convolution: a per-time-step linear map between channel counts.
            return nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1)

        def gated_branch(gain_name):
            # One causal dilated conv per layer for a single branch of the gate.
            return nn.ModuleList([
                Conv(in_channels=n_residue, out_channels=n_residue, kernel_size=2,
                     dilation=rate, w_init_gain=gain_name, is_causal=True)
                for rate in self.dilations
            ])

        # Creation order (and therefore RNG use for init) and attribute names
        # (the state_dict keys) match the saved checkpoints.
        self.from_input = pointwise(mu, n_residue)
        self.conv_sigmoid = gated_branch('sigmoid')
        self.conv_tanh = gated_branch('tanh')
        self.skip_scale = nn.ModuleList([pointwise(n_residue, n_skip) for _ in self.dilations])
        self.residue_scale = nn.ModuleList([pointwise(n_residue, n_residue) for _ in self.dilations])
        self.conv_post_1 = pointwise(n_skip, n_skip)
        self.conv_post_2 = pointwise(n_skip, mu)

        self.cond_conv = nn.Conv1d(in_channels=n_cond_channel,
                                   out_channels=self.n_layers * n_residue, kernel_size=1)

    def forward(self, input, cond_input):
        """Run the whole stack.

        Args:
            input: 1-D tensor of length T (reshaped by ``preprocess``).
            cond_input: (T, n_cond_channel) tensor; it is made channels-first
                before ``cond_conv``.
        Returns:
            (T, mu) tensor produced by ``postprocess``.
        """
        hidden = self.preprocess(input)
        cond = self.cond_conv(cond_input.unsqueeze(0).transpose(1, 2))
        # (1, n_layers * n_residue, T) -> (1, n_layers, n_residue, T): one slice per layer.
        cond = cond.view(cond.size(0), self.n_layers, -1, cond.size(2))

        layer_modules = zip(self.conv_sigmoid, self.conv_tanh, self.skip_scale, self.residue_scale)
        skips = []
        for layer_idx, (sig_conv, tanh_conv, skip_conv, res_conv) in enumerate(layer_modules):
            hidden, layer_skip = self.residue_forward(hidden, cond[:, layer_idx, :, :],
                                                      sig_conv, tanh_conv, skip_conv, res_conv)
            skips.append(layer_skip)

        # Trim every skip to the final residual length, then sum them.
        out_len = hidden.size(2)
        summed = sum([sk[:, :, -out_len:] for sk in skips])
        return self.postprocess(summed)

    def preprocess(self, input):
        """Reshape a length-T 1-D sequence to (1, 1, T) and project it to n_residue channels."""
        column = input.unsqueeze(1)              # (T, 1)
        batched = column.unsqueeze(0)            # (1, T, 1)
        channels_first = batched.transpose(1, 2) # (1, 1, T)
        return self.from_input(channels_first)

    def postprocess(self, input):
        """ELU -> 1x1 conv -> ELU -> 1x1 conv, then (1, mu, T) -> (T, mu)."""
        hidden = self.conv_post_1(nn.functional.elu(input))
        hidden = self.conv_post_2(nn.functional.elu(hidden))
        return hidden.squeeze(0).transpose(0, 1)

    def residue_forward(self, input, cond_act, conv_sigmoid, conv_tanh, skip_scale, residue_scale):
        """One gated residual layer.

        Computes ``z = sigmoid(conv_sigmoid(x) + c) * tanh(conv_tanh(x) + c)``,
        where the same conditioning slice ``c`` is added to both branches.
        Returns ``(residue_scale(z) + x, skip_scale(z))``, with ``x`` trimmed
        to the residual's time length.
        """
        gate_pre = conv_sigmoid(input)
        filter_pre = conv_tanh(input)
        gate_pre += cond_act
        filter_pre += cond_act
        gated = torch.sigmoid(gate_pre) * torch.tanh(filter_pre)
        layer_skip = skip_scale(gated)
        residual = residue_scale(gated)
        residual = residual + input[:, :, -residual.size(2):]
        return residual, layer_skip
    
#     def generate_slow(self, input, cond_input, n=100):
#         res = input.data.tolist()
#         for ii in range(n):
#             x_prev = Variable(torch.FloatTensor(res[-sum(self.dilations)-1:]))
#             cond_input_t = cond_input[len(res)+1-sum(self.dilations):len(res)+1, :]
#             y = self.forward(x_prev, cond_input_t)
# #             _, i = y.max(dim=1)
#             res.append(y.data[-1, 0])
#         return res
    
#     def generate(self, input, cond_input, n=100, temperature=None, estimate_time=False):
#         ## prepare output_buffer
#         output = self.preprocess(input)
#         cond_output = self.cond_conv(cond_input.unsqueeze(0).transpose(1, 2))
#         cond_output = cond_output.view(cond_output.size(0), self.n_layers, -1, cond_output.size(2))

#         output_buffer = []
#         i = 0
#         for s, t, skip_scale, residue_scale, d in zip(self.conv_sigmoid, self.conv_tanh, self.skip_scale, self.residue_scale, self.dilations):
#             output, _ = self.residue_forward(output, cond_output[:, i, :, :], s, t, skip_scale, residue_scale)
#             sz = 1 if d==2**(self.dilation_depth-1) else d*2
#             output_buffer.append(output[:,:,-sz-1:-1])
#             i += 1
#         ## generate new 
#         res = input.data.tolist()
#         res_cond = cond_input.data.tolist()
#         for i in range(n):
#             output = Variable(torch.FloatTensor(res[-2:]))
#             cond_output = Variable(torch.FloatTensor(res_cond[-2:]))
#             output = self.preprocess(output)
#             cond_output = self.cond_conv(cond_output.unsqueeze(0).transpose(1, 2))
#             cond_output = cond_output.view(cond_output.size(0), self.n_layers, -1, cond_output.size(2))

#             output_buffer_next = []
#             skip_connections = [] # save for generation purposes
#             dd = 0
#             for s, t, skip_scale, residue_scale, b in zip(self.conv_sigmoid, self.conv_tanh, self.skip_scale, self.residue_scale, output_buffer):
#                 output, residue = self.residue_forward(output, cond_output[:, dd, :, :], s, t, skip_scale, residue_scale)
#                 output = torch.cat([b, output], dim=2)
#                 skip_connections.append(residue)
#                 dd += 1
#                 if i%100==0:
#                     output = output.clone()
#                 output_buffer_next.append(output[:,:,-b.size(2):])
#             output_buffer = output_buffer_next
#             output = output[:,:,-1:]
#             # sum up skip connections
#             output = sum(skip_connections)
#             output = self.postprocess(output)
# #             if temperature is None:
# #                 _, output = output.max(dim=1)
# #             else:
# #                 output = output.div(temperature).exp().multinomial(1).squeeze()
#             res.append(output.data[-1, 0])
#         return res

## Load data and configure datasets


In [0]:
from collections import namedtuple
import gc
import os

from google.colab import drive
drive.mount('/content/drive')


SubjectTaskAxis = namedtuple('SubjectTaskAxis', ['subject', 'task', 'axis'])

cache_file = os.path.abspath('/content/drive/My Drive/Colab Notebooks/cached_python_all_zero_pad_data.npz')
# 'mocap_pos' is a dict stored as a 0-d object array, which can only be read
# with allow_pickle=True (numpy >= 1.16.3 defaults to False and raises).
# SECURITY: unpickling can execute arbitrary code -- only load cache files you
# generated yourself.
# The context manager closes the .npz file handle once every array below has
# been read into memory.
with np.load(cache_file, allow_pickle=True) as data:
    subjects = data['subjects']
    tasks = data['tasks']
    axes = data['axes']
    mocap_distance = data['distance']
    di_distance = data['di_distance']
    dist_diff = data['dist_diff']
    length = data['length']
    di_pos = data['pos']  # filtered double integrated position
    di_pos_u = data['pos_u']  # unfiltered double integrated position
    vel = data['vel']
    vel_norm = data['vel_norm']
    acc = data['acc']
    acc_unfiltered = data['acc_unfiltered']
    direction = data['direction']
    segment_begin = data['segment_begin']
    segment_end = data['segment_end']
    mocap_pos = data['mocap_pos'][()]  # dict is stored as no shape array, access with [()]

# Drop subjects 8 and 9. exclude_mask is True for the rows that are KEPT.
# mocap_pos is a dict rather than a per-row array, so it is not filtered here.
exclude_mask = ~((subjects == 8) | (subjects == 9))
subjects = subjects[exclude_mask]
tasks = tasks[exclude_mask]
axes = axes[exclude_mask]
mocap_distance = mocap_distance[exclude_mask]
di_distance = di_distance[exclude_mask]
dist_diff = dist_diff[exclude_mask]
length = length[exclude_mask]
di_pos = di_pos[exclude_mask, :]
di_pos_u = di_pos_u[exclude_mask, :]
vel = vel[exclude_mask, :]
vel_norm = vel_norm[exclude_mask, :]
acc = acc[exclude_mask, :]
acc_unfiltered = acc_unfiltered[exclude_mask, :]
direction = direction[exclude_mask]
segment_begin = segment_begin[exclude_mask]
segment_end = segment_end[exclude_mask]

# For later convenience
abs_distance = np.absolute(mocap_distance)
abs_di_distance = np.absolute(di_distance)

gc.collect()

In [0]:
class MocapTrainDataset(Dataset):
  """Training set of (mocap distance, acceleration) sequences with random cropping.

  One item is built for every (subject in ``select_sub``, task in {1, 2},
  axis in {1, 2, 3}) combination. The items come from the module-level arrays
  ``subjects``, ``tasks``, ``axes``, ``mocap_distance`` and ``acc``, and are
  shuffled once after construction.

  Note: ``__init__`` calls ``np.random.seed(1234)``, which reseeds numpy's
  global RNG for all code that runs afterwards.
  """
  def __init__(self, select_sub, segment_length):
    np.random.seed(1234)
    self.select_sub = select_sub
    self.segment_length = segment_length
    mocap_d_list = []
    cond_vel_list = []

    for sub in self.select_sub:
      for t in range(1, 3):
        for ax in range(1, 4):
          mask = (subjects == sub) & (tasks == t) & (axes == ax)
          mocap_d_list.append(mocap_distance[mask])
          cond_vel_list.append(acc[mask, :])

    # Sequences generally differ in length; keep one entry per sequence.
    self.mocap_d_data = self._as_object_array(mocap_d_list)
    self.cond_vel_data = self._as_object_array(cond_vel_list)

    ind_shuffle = np.arange(len(self.mocap_d_data))
    np.random.shuffle(ind_shuffle)
    self.mocap_d_data = self.mocap_d_data[ind_shuffle]
    self.cond_vel_data = self.cond_vel_data[ind_shuffle]

  @staticmethod
  def _as_object_array(seqs):
    """Pack a list of arrays into a 1-D object array, one element per array.

    ``np.array(list_of_ragged_arrays)`` raises ValueError on numpy >= 1.24;
    filling a preallocated object array element by element works for any
    mix of lengths.
    """
    packed = np.empty(len(seqs), dtype=object)
    for i, seq in enumerate(seqs):
      packed[i] = seq
    return packed

  def __getitem__(self, index):
    """Return a random ``segment_length`` crop of item ``index``.

    Sequences shorter than ``segment_length`` are returned whole.

    Returns:
      (distance, cond) float32 tensors of shapes (L,) and (L, C).
    """
    d_seq = self.mocap_d_data[index]
    cond_vel_seq = self.cond_vel_data[index]

    if len(d_seq) >= self.segment_length:
      max_start = len(d_seq) - self.segment_length
      # randint's upper bound is exclusive. The +1 makes the last full window
      # reachable and avoids randint(0, 0) raising when len == segment_length.
      start = np.random.randint(0, max_start + 1)
      d_seq = d_seq[start:start + self.segment_length]
      cond_vel_seq = cond_vel_seq[start:start + self.segment_length, :]

    # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
    return (torch.from_numpy(d_seq.astype(np.float32)),
            torch.from_numpy(cond_vel_seq.astype(np.float32)))

  def __len__(self):
    return len(self.mocap_d_data)
  
  
  
class MocapTestDataset(Dataset):
  """Evaluation set: one full, uncropped sequence per (subject, task, axis).

  Items are ordered by subject in ``select_sub``, then task in {1, 2}, then
  axis in {1, 2, 3}, with no shuffling. They are built from the module-level
  arrays ``subjects``, ``tasks``, ``axes``, ``mocap_distance``, ``acc`` and
  ``di_distance``.
  """
  def __init__(self, select_sub):
    # Reseeds numpy's global RNG, a side effect on later cells. This class
    # itself draws no random numbers.
    np.random.seed(1234)
    self.select_sub = select_sub
    mocap_d_list = []
    cond_vel_list = []
    di_d_list = []

    for sub in self.select_sub:
      for t in range(1, 3):
        for ax in range(1, 4):
          mask = (subjects == sub) & (tasks == t) & (axes == ax)
          mocap_d_list.append(mocap_distance[mask])
          cond_vel_list.append(acc[mask, :])
          di_d_list.append(di_distance[mask])

    # Sequences generally differ in length; keep one entry per sequence.
    self.mocap_d_data = self._as_object_array(mocap_d_list)
    self.cond_vel_data = self._as_object_array(cond_vel_list)
    self.di_d_data = self._as_object_array(di_d_list)

  @staticmethod
  def _as_object_array(seqs):
    """Pack a list of arrays into a 1-D object array, one element per array.

    ``np.array(list_of_ragged_arrays)`` raises ValueError on numpy >= 1.24;
    filling a preallocated object array element by element works for any
    mix of lengths.
    """
    packed = np.empty(len(seqs), dtype=object)
    for i, seq in enumerate(seqs):
      packed[i] = seq
    return packed

  def __getitem__(self, index):
    """Return ``(di_distance, acc, mocap_distance)`` for item ``index`` as float32 tensors."""
    d_seq = self.mocap_d_data[index].astype(np.float32)
    cond_vel_seq = self.cond_vel_data[index].astype(np.float32)
    di_d_seq = self.di_d_data[index].astype(np.float32)

    # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
    return torch.from_numpy(di_d_seq), torch.from_numpy(cond_vel_seq), torch.from_numpy(d_seq)

  def __len__(self):
    return len(self.mocap_d_data)

## network training

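The network is trained for one-step-ahead prediction: from the previous motion-capture distance values, conditioned on the acceleration features (`acc`), it predicts the next distance value (MSE loss). Configuration: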
- 24 channels in residue outputs
- 128 channels in skip outputs
- 5 dilation layers (n_repeat=1, dilation_depth=5)

In [0]:
n_cond = len(acc[0])      # conditioning features per time step
depth = 5                 # dilations 1, 2, 4, 8, 16 (n_repeat=1)
torch.manual_seed(1234)   # reproducible weight init; numpy is reseeded in MocapTrainDataset
net = WaveNet(mu=1, n_residue=24, n_skip=128, dilation_depth=depth, n_repeat=1, n_cond_channel=n_cond)
batch_length = 100        # maximum length of each training crop
optimizer = optim.Adam(net.parameters(), lr=0.001)
uniq_sub = np.unique(subjects)
train_sub = np.delete(uniq_sub, 0)  # hold out the first subject for testing
test_sub = [uniq_sub[0]]
trainset = MocapTrainDataset(train_sub, batch_length)
batch_size = 8            # sequences per optimizer step
dataset_len = len(trainset)
print('dataset lenght', dataset_len)
print('batch size', batch_size)

loss_save = []
max_epoch = 3000  # NOTE: each "epoch" here is a single optimizer step on one mini-batch
for epoch in range(max_epoch):
    optimizer.zero_grad()
    loss = 0
    # Sample sequences with replacement; each item access also draws a random crop.
    batch_idx = [np.random.randint(0, dataset_len) for _ in range(batch_size)]
    for idx in batch_idx:
        x, y = trainset[idx]
        # Output t (computed causally from x[:-1] and the features shifted by
        # one step) is compared with x[t + 1]: one-step-ahead prediction.
        logits = net(x[:-1], y[1:, :])
        loss = loss + nn.functional.mse_loss(logits[:, 0], x[1:])
    loss = loss / batch_size
    loss.backward()
    optimizer.step()
    loss_save.append(loss.item())

    # monitor progress on one random training sequence
    if epoch % 100 == 0:
        print('epoch {}, loss {}'.format(epoch, loss.item()))
        ii = np.random.randint(0, dataset_len)
        x, y = trainset[ii]
        with torch.no_grad():  # plotting only: no autograd graph needed
            logits = net(x[:-1], y[1:, :])
        plt.figure(figsize=[16, 4])
        plt.step(logits[:, 0].tolist(), 'tab:blue', label='pred')
        plt.step(x.tolist()[1:], 'tab:orange', ms=1, label='gt')
        plt.title('epoch {}'.format(epoch))
        plt.legend()
        plt.ylim(-1.22, 1.22)
        plt.show()

## Training loss curve and checkpoint

In [0]:
# One point per optimizer step: the mean mini-batch loss recorded in training.
fig, ax = plt.subplots(figsize=[15, 4])
ax.plot(loss_save)
ax.set(xlabel='epoch', ylabel='loss', title='loss function', ylim=(0, 0.01))
# Persist the trained weights (state_dict only) for the evaluation cells below.
torch.save(net.state_dict(), '/content/drive/My Drive/Colab Notebooks/wavenet_cond_fixed')

## Evaluation and data generation

In [0]:
# given mocap dataset 
loaded_net = torch.load('/content/drive/My Drive/Colab Notebooks/wavenet_cond_fixed')
net_test = WaveNet(mu=1,n_residue=24,n_skip=128,dilation_depth=depth,n_repeat=1, n_cond_channel=n_cond)

net_test.load_state_dict(loaded_net)
print(train_sub)
print(test_sub)
total_mae = []
testset = MocapTestDataset(test_sub) 
for kk in range(len(testset)):
  batch = testset.__getitem__(kk)

  x, y, gt = batch
  x_in = x[:-1]
  y_in = y[1:, :]
  logits = net_test(x_in, y_in)
  plt.figure(figsize=[50,4])
  plt.step(logits[:, 0].tolist(), 'tab:blue', label='pred')
  plt.step(gt[1:].tolist(),'tab:orange', ms=1, label='gt')
  plt.legend()
  plt.ylim(-1.22, 1.22)
  plt.ylabel('distance')
  plt.grid()

  total_mae += list(abs(np.array(logits[:, 0].tolist()) - np.array(gt[1:].tolist())))
  
print('total mean', np.mean(np.array(total_mae)), 'total median', np.median(total_mae))

In [0]:
# generate
testset = MocapTestDataset(test_sub) 

total_gen_mae = []
for kk in range(len(testset)):
  batch = testset.__getitem__(kk)
  dilation_dim = sum([2**i for i in range(depth)])

  x_whole, y_whole, gt = batch 
  x_s = x_whole[:dilation_dim]

  res = x_s.tolist()
  for l in range(len(y_whole)-dilation_dim):
    x_prev = Variable(torch.FloatTensor(res[-dilation_dim:]))
    results = net_test(x_prev, y_whole[l+1:l+1+dilation_dim, :])
    res.append(results[-1, 0].tolist())

    
  mask = (subjects == test_sub[0]) & (tasks == int(kk/3+1)) & (axes == int(kk % 3+1))
  plt.figure(figsize=[21,4])
  plt.step(res, 'tab:blue', label='pred')
  plt.step(gt.tolist(),'tab:orange', ms=1, label='gt')
  plt.legend()
#   plt.ylim(-1.22, 1.22)
  plt.ylabel('distance (m)')
  plt.grid()
  plt.axhline(0, color='k')
  plt.title('sub ' + str(test_sub[0]) + ' tasks ' + str(int(kk/3+1)) + ' axes ' + str(int(kk % 3+1)))

  
  plt.figure(figsize=[21,4])
  cum_x = np.cumsum(length[mask]) / (100 * 60)
  plt.step(cum_x, np.cumsum(np.array(res) - np.array(gt.tolist())), 'tab:red', label='cumulatvie error')
  plt.legend()
  plt.ylim(-2, 2)
  plt.ylabel('cumulatvie error (distance (m))')
  plt.grid()
  plt.axhline(0, color='k')

  plt.show()

  total_gen_mae += list(abs(np.array(res) - np.array(gt.tolist())))
  
print('total mean', np.mean(np.array(total_gen_mae)), 'total median', np.median(total_gen_mae))