In [1]:
import numpy as np
import pescador
import logging
import os

import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from datetime import datetime

In [2]:
LOGGER = logging.getLogger('gbsd')
LOGGER.setLevel(logging.DEBUG)

In [3]:
torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True)

In [4]:
matplotlib.use('Agg')

In [5]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
CMD_VOLENVPER = 0
CMD_DUTYLL = 1
CMD_MSB = 2
CMD_LSB = 3
CMD_COUNT = 4

def onehot_cmd(data):
    cmd = data[CMD_OFFSET]
    nd = [ 0, 0, 0, 0 ]
    nd[int(cmd)] = 1
    return nd


CH_1 = 1
CH_2 = 2
CH_COUNT = 2

TIME_OFFSET = 0
CH_OFFSET = 1
CMD_OFFSET = 2
CHANNEL_OFFSET = 3
PARAM1_OFFSET = 4
PARAM2_OFFSET = 5
PARAM3_OFFSET = 6
SIZE_OF_INPUT_FIELDS = 7

WINDOW_SIZE = 1024

NORMALIZE_TIME_BY = float(4194304 * 3) # 1 second is 4194304 cycles so this is 10s

def norm(val, max_val):
    return ((val / max_val) * 2.) - 1.

def unnorm(val, max_val):
    return ((val + 1.) / 2.) * max_val

def fresh_input(command, channel, time):
    newd = np.zeros(shape=SIZE_OF_INPUT_FIELDS, dtype=float)
    newd[TIME_OFFSET] = norm(time, NORMALIZE_TIME_BY)

    if int(channel) == 1:
        newd[CH_OFFSET] = norm(CH_1, CH_COUNT)
    elif int(channel) == 2:
        newd[CH_OFFSET] = norm(CH_2, CH_COUNT)
    else:
        raise "I didn't expect this"

    newd[CMD_OFFSET] = norm(channel, CMD_COUNT)
    return newd

def nop():
    return fresh_input(NOP_CMD_OFFSET, 1, 0)

def norm_command_of_parts(command, channel, parts, time):
    inp = fresh_input(command, channel, time)
    
    if command == CMD_DUTYLL:
        inp[PARAM1_OFFSET] = norm(float(parts[3]), 2.)
        inp[PARAM2_OFFSET] = norm(float(parts[4]), 64.)
    elif command == CMD_VOLENVPER:
        inp[PARAM1_OFFSET] = float(parts[3]) / 16.
        inp[PARAM2_OFFSET] = float(parts[4])
        inp[PARAM3_OFFSET] = float(parts[4]) / 7.
    elif command == CMD_LSB:
        inp[PARAM1_OFFSET] = norm(float(parts[3]), 255.)
        inp[PARAM2_OFFSET] = 0.
        inp[PARAM3_OFFSET] = 0
    elif command == CMD_MSB:
        inp[PARAM1_OFFSET] = norm(float(parts[3]), 7.)
        inp[PARAM2_OFFSET] = float(bool(parts[4]))
        inp[PARAM3_OFFSET] = float(bool(parts[5]))
    else:
        raise "this should not happen"
        
    return inp

def unnorm_feature(data):
    def l_unnorm(channel, maxv):
        data[channel] = unnorm(data[channel], maxv)
    l_unnorm(TIME_OFFSET, NORMALIZE_TIME_BY)
    l_unnorm(CH_OFFSET, CH_COUNT)
    l_unnorm(CMD_OFFSET, CMD_COUNT)
    return data

def load_training_data(src):
    data = []
    file = open(src, 'r')
    for line in file:
        parts = line.split()
        if len(parts) > 0 and parts[0] == "CH":
            #print(parts)
            channel = int(parts[1])
            command = parts[2]
            time = int(parts[-1])
            if command == "DUTYLL":
                new_item = norm_command_of_parts(CMD_DUTYLL, channel, parts, time)
            elif command == "VOLENVPER":
                new_item = norm_command_of_parts(CMD_VOLENVPER, channel, parts, time)
            elif command == "FREQLSB":
                new_item = norm_command_of_parts(CMD_LSB, channel, parts, time)
            elif command == "FREQMSB":
                new_item = norm_command_of_parts(CMD_MSB, channel, parts, time)
             # Otherwise unknown   
            data.append(new_item)
           #print("NEXTCMD", data[-1])
    return data

@pescador.streamable
def samples_from_training_data(src, window_size=WINDOW_SIZE):
    sample_data = None

    try:
        sample_data = load_training_data(src)
    except Exception as e:
        LOGGER.error('Could not load {}: {}'.format(src, str(e)))
        raise StopIteration()

    true_window_size = window_size + 1

    # Pad small samples with nop
    while len(sample_data) < true_window_size:
        sample_data.append(nop())

    while True:

        if len(sample_data) == true_window_size:
            sample = sample_data
        else:
            # Sample a random window from the audio file
            start_idx = np.random.randint(0, len(sample_data) - true_window_size)
            end_idx = start_idx + true_window_size
            sample = sample_data[start_idx:end_idx]

        sample_input = sample[0:window_size]
        sample_output = sample[window_size:window_size+1]

        sample_input = np.array(sample_input).astype(np.float32)
        sample_output = np.array(sample_output).astype(np.float32)

        yield { 'X':sample_input, 'Y': sample_output }

def create_batch_generator(paths, batch_size):
    streamers = []
    for path in paths:
        print("Creating a batch generator")
        streamers.append(samples_from_training_data(path))
        print("Done creating batch generator")
    mux = pescador.ShuffledMux(streamers)
    batch_gen = pescador.buffer_stream(mux, batch_size)
    return batch_gen

def training_files(dirp):
    return [
      os.path.join(root, fname)
      for (root, dir_names, file_names) in os.walk(dirp, followlinks=True)
      for fname in file_names
    ]

def create_data_split(paths, batch_size):
    train_gen = create_batch_generator(paths, batch_size)
    return train_gen



In [7]:
print("Collecting training data")
train_gen = create_data_split(training_files("../..//training_data/"), 1)
print("Collected")

Collecting training data
Creating a batch generator
Done creating batch generator
Creating a batch generator
Done creating batch generator
Creating a batch generator
Done creating batch generator
Creating a batch generator
Done creating batch generator
Creating a batch generator
Done creating batch generator
Creating a batch generator
Done creating batch generator
Collected


In [None]:
NUM_EVENTS_PER_ROUND = WINDOW_SIZE
DIM = SIZE_OF_INPUT_FIELDS * NUM_EVENTS_PER_ROUND

class CommandNet(nn.Module):

    def __init__(self):
        super(CommandNet, self).__init__()

        self.main = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Conv2d(1, 1, kernel_size=(SIZE_OF_INPUT_FIELDS, 2)),
            nn.Conv2d(1, 1, kernel_size=(SIZE_OF_INPUT_FIELDS, 2)),
            nn.Flatten(),
            nn.Dropout(p=0.2),
            nn.Linear(5060, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Sigmoid(),
            nn.Linear(128, SIZE_OF_INPUT_FIELDS),
        )

    def forward(self, sequence):
        output = self.main(sequence)
        return output


EPOCHS = 2000
ROUND_SZ = 10000

def train():
    
    lr=0.0001

    command_generator = CommandNet().to(device)
    command_optimizer= optim.Adam(command_generator.parameters(), lr=lr, weight_decay=1e-5)
    command_criterion = nn.MSELoss()

    for iteration in range(EPOCHS):
        print(f"Round {iteration}")

        for i in range(ROUND_SZ):

          command_optimizer.zero_grad()

          data = next(train_gen)
          data_train_cmd = torch.Tensor(data['X']).to(device)
          data_test_cmd = torch.Tensor(data['Y'][0]).to(device)
            
          prediction_command = command_generator(data_train_cmd)
          #print(prediction_command.flatten(), data_test_command.flatten())
          command_loss = command_criterion(prediction_command, data_test_cmd)
          command_loss.backward()
          command_optimizer.step()


        print("Command batch loss:", command_loss.item())
        print("Last data:", unnorm_feature(data_test_cmd.detach().cpu().numpy()[0]))
        print("Last prediction:", unnorm_feature(prediction_command.detach().cpu().numpy()[0]))
        torch.save(command_generator.state_dict, "./" + str(int(datetime.now().timestamp())))
        print("Saved checkpoint")

    return data['X'][0], command_generator.eval()

seed, command_generator = train()

Round 0
Command batch loss: 0.006072419695556164
Last data: [28.125       2.          2.          0.          0.14285715  1.
  1.        ]
Last prediction: [75612.75           2.0705736      2.0556736      0.00683031
     0.09456278     1.1110128      1.1480008 ]
Saved checkpoint
Round 1
Command batch loss: 0.09043489396572113
Last data: [16.125  2.     2.     0.     0.375  0.     0.   ]
Last prediction: [-60945.              2.158892        2.153615        0.00243907
      0.527324        0.5769598       0.49569464]
Saved checkpoint
Round 2
Command batch loss: 0.08581046760082245
Last data: [136708.12      2.        2.        0.        0.       -1.        0.  ]
Last prediction: [134836.88            1.8336294       1.8755605       0.00105389
      0.21677679     -0.27761525      0.01702073]
Saved checkpoint
Round 3
Command batch loss: 0.00041969731682911515
Last data: [28.125       2.          2.          0.          0.14285715  1.
  1.        ]
Last prediction: [154064.25            

Command batch loss: 0.00029262356110848486
Last data: [16.125       1.          1.          0.          0.71428573  1.
  1.        ]
Last prediction: [-36876.              0.9772401       0.9652324       0.00100656
      0.74002534      0.9798607       0.9888359 ]
Saved checkpoint
Round 30
Command batch loss: 0.00010477915202500299
Last data: [16.125       1.          1.          0.          0.71428573  1.
  1.        ]
Last prediction: [-2442.75           1.0042485      0.9856639      0.00001742
     0.7063648      0.98800886     1.0213856 ]
Saved checkpoint
Round 31
Command batch loss: 0.003360210219398141
Last data: [55.875       1.          1.          0.          0.42857143  1.
  1.        ]
Last prediction: [12954.             1.0162852      1.0298575     -0.00083357
     0.58024293     0.9961266      1.0030417 ]
Saved checkpoint
Round 32
Command batch loss: 0.0028781022410839796
Last data: [592.125        2.           2.           0.          -0.16078432
   0.           0.      

Command batch loss: 0.00023436718038283288
Last data: [2863.875    2.       2.       0.       0.875    0.       0.   ]
Last prediction: [89709.             1.9897552      1.9703255      0.00122762
     0.85155785    -0.02282808    -0.00726402]
Saved checkpoint
Round 59
Command batch loss: 0.05156256631016731
Last data: [68148.            1.            1.            0.           -0.9372549
     0.            0.       ]
Last prediction: [133294.12            0.98768103      1.0026844      -0.0044052
     -0.33676106     -0.00323154      0.0073773 ]
Saved checkpoint
Round 60
Command batch loss: 0.0011537434766069055
Last data: [264.        2.        2.        0.        0.        0.28125   0.     ]
Last prediction: [27467.25           1.9910731      1.9903021     -0.00050578
    -0.06620082     0.22293705    -0.01308067]
Saved checkpoint
Round 61
Command batch loss: 0.003039608709514141
Last data: [68100.             1.             1.             0.
    -0.22352941     0.             0.   

Command batch loss: 6.412428774638101e-05
Last data: [336220.12           1.             1.             0.
     -0.4745098      0.             0.       ]
Last prediction: [309267.              0.9874671       0.98485696     -0.00032998
     -0.48882052      0.00297638     -0.00153336]
Saved checkpoint
Round 88
Command batch loss: 0.001362775219604373
Last data: [55.875  2.     2.     0.     1.     1.     1.   ]
Last prediction: [98389.5            1.9971919      2.0313787      0.0016688
     0.92299765     0.9565517      0.96505445]
Saved checkpoint
Round 89
Command batch loss: 0.00020798125478904694
Last data: [16.125       1.          1.          0.          0.71428573  1.
  1.        ]
Last prediction: [10803.75           1.0313729      1.0270035      0.00017779
     0.6977194      0.99725103     0.9979165 ]
Saved checkpoint
Round 90
Command batch loss: 0.00013258621038403362
Last data: [16.125  1.     1.     0.     1.     1.     1.   ]
Last prediction: [-38906.25            0.99670

Command batch loss: 0.019069764763116837
Last data: [592.125       2.          2.          0.          0.3882353   0.
   0.       ]
Last prediction: [151159.12            2.006562        2.010687        0.00026741
      0.02451476      0.02099247      0.01051752]
Saved checkpoint
Round 118
Command batch loss: 0.0033729150891304016
Last data: [66112.125         1.            1.            0.            0.3882353
     0.            0.       ]
Last prediction: [174384.              1.0005115       1.0065036       0.0001433
      0.38753596      0.1388175       0.06350274]
Saved checkpoint
Round 119
Command batch loss: 0.0012052158126607537
Last data: [686440.1         1.          1.          0.          0.6875      0.
      0.    ]
Last prediction: [215118.75            0.9745894       0.96041167      0.00000928
      0.65024537      0.0189927      -0.00617954]
Saved checkpoint
Round 120
Command batch loss: 0.004865818656980991
Last data: [264.        1.        1.        0.        0.     

Command batch loss: 1.8294722394784912e-05
Last data: [1288.125     1.        1.        0.        0.0625    0.        0.    ]
Last prediction: [16130.25           1.001558       1.0022213     -0.00076603
     0.06792892    -0.00342473     0.00877773]
Saved checkpoint
Round 148
Command batch loss: 0.00024434112128801644
Last data: [203388.      1.      1.      0.      1.     -1.      0.]
Last prediction: [405214.5             0.9954797       0.9941013       0.00106367
      1.0104327      -1.0166825      -0.01624455]
Saved checkpoint
Round 149
Command batch loss: 0.002002799417823553
Last data: [28.125       2.          2.          0.          0.42857143  1.
  1.        ]
Last prediction: [432.           1.901695     1.8993988   -0.00301874   0.46508417
   1.0171486    1.0137551 ]
Saved checkpoint
Round 150
Command batch loss: 0.00017537349776830524
Last data: [16.125  2.     2.     0.     1.     1.     1.   ]
Last prediction: [16748.625          1.9893596      1.9859321     -0.00059628

Command batch loss: 0.00012629374396055937
Last data: [28.125       2.          2.          0.          0.14285715  1.
  1.        ]
Last prediction: [20635.5            1.9911126      1.9822884      0.00034191
     0.16608715     1.0083287      1.0103344 ]
Saved checkpoint
Round 178
Command batch loss: 6.718916847603396e-05
Last data: [16.125       1.          1.          0.          0.71428573  1.
  1.        ]
Last prediction: [-9823.5            0.9896902      0.98424983    -0.00015545
     0.6988322      0.99443084     0.9945486 ]
Saved checkpoint
Round 179
Command batch loss: 6.777178350603208e-05
Last data: [16.125       1.          1.          0.          0.71428573  1.
  1.        ]
Last prediction: [-26858.25            0.9907678       0.9799515       0.00303125
      0.7016065       1.0082572       1.0056838 ]
Saved checkpoint
Round 180
Command batch loss: 0.0004738961870316416
Last data: [69672.             1.             1.             0.
     0.39607844     0.            

Command batch loss: 0.0012593194842338562
Last data: [16.125       2.          2.          0.          0.42857143  1.
  1.        ]
Last prediction: [184887.75            1.97718         1.9765122      -0.00012191
      0.38224998      1.0577959       0.9574905 ]
Saved checkpoint
Round 206
Command batch loss: 0.00047570824972353876
Last data: [55.875      2.         2.         0.        -0.8745098  0.
  0.       ]
Last prediction: [-34038.75            1.9963691       1.9957608      -0.00035114
     -0.8173835      -0.00331376      0.00289788]
Saved checkpoint
Round 207
Command batch loss: 5.246752698440105e-05
Last data: [66628.125         1.            1.            0.           -0.4745098
     0.            0.       ]
Last prediction: [162319.5             1.0086441       1.0005746      -0.00071945
     -0.47607785      0.00761169      0.00046809]
Saved checkpoint
Round 208
Command batch loss: 0.0013251281343400478
Last data: [482476.12            2.              2.              0.


Command batch loss: 0.05163596197962761
Last data: [3712.125         2.            2.            0.            0.13725491
    0.            0.        ]
Last prediction: [40217.25           2.004233       2.009298       0.00031485
     0.40797335     0.4127229      0.34314686]
Saved checkpoint
Round 235
Command batch loss: 0.00012758467346429825
Last data: [28.125  1.     1.     0.     1.     1.     1.   ]
Last prediction: [-6846.             0.9998669      0.99475384    -0.00093778
     1.0248729      1.0140446      1.0082588 ]
Saved checkpoint
Round 236
Command batch loss: 0.00020004049292765558
Last data: [273868.12            2.              2.              0.
     -0.23921569      0.              0.        ]
Last prediction: [383613.              2.0039477       2.0079067       0.00150058
     -0.27179885     -0.00038135     -0.00085827]
Saved checkpoint
Round 237
Command batch loss: 0.0010912338038906455
Last data: [28.125       2.          2.          0.          0.14285715  1.
 