In [1]:
import numpy as np
import pescador
import logging
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import math

import gc
import sys

from datetime import datetime

from functools import reduce

In [2]:
LOGGER = logging.getLogger('gbsd')
LOGGER.setLevel(logging.DEBUG)

In [3]:
torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True)

In [4]:
matplotlib.use('Agg')

In [5]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    
cpu = torch.device('cpu')

In [6]:
CMD_VOLENVPER = 0
CMD_DUTYLL = 1
CMD_MSB = 2
CMD_LSB = 3
CMD_COUNT = 4

def onehot_cmd(data):
    cmd = data[CMD_OFFSET]
    nd = [ 0, 0, 0, 0 ]
    nd[int(cmd)] = 1
    return nd

CH_COUNT = 2

TIME_OFFSET = 0
CH_OFFSET = 1
CMD_OFFSET = 2
PARAM1_OFFSET = 3
PARAM2_OFFSET = 4
PARAM3_OFFSET = 5
SIZE_OF_INPUT_FIELDS = 6

MAX_WINDOW_SIZE = 4096

M_CYCLES_PER_SECOND = 4194304.
NORMALIZE_TIME_BY = M_CYCLES_PER_SECOND * 10.

def norm(val, max_val):
    if val > max_val:
        return 1.
    else:
        return ((val / max_val) * 2.) - 1.

def unnorm(val, max_val):
    return ((val + 1.) / 2.) * max_val

def fresh_input(command, channel, time):
    newd = np.zeros(shape=SIZE_OF_INPUT_FIELDS, dtype=int)
    newd[TIME_OFFSET] = time
    newd[CH_OFFSET] = channel
    newd[CMD_OFFSET] = command
    return newd

def parse_bool(v):
    if v == "true":
        return 1
    elif v == "false":
        return 0
    else:
        return int(v)

def command_of_parts(command, channel, parts, time):
    inp = fresh_input(command, channel, time)
    
    if command == CMD_DUTYLL:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = int(parts[4])
    elif command == CMD_VOLENVPER:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = parse_bool(parts[4])
        inp[PARAM3_OFFSET] = int(parts[4])
    elif command == CMD_LSB:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = 0
        inp[PARAM3_OFFSET] = 0
    elif command == CMD_MSB:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = parse_bool(parts[4])
        inp[PARAM3_OFFSET] = parse_bool(parts[5])
    else:
        raise "this should not happen"
    return inp

def int32_as_bytes(ival):
    return np.frombuffer(ival.item().to_bytes(4, byteorder = 'big'), dtype=np.uint8)

def int32_of_bytes(np):
    return int.from_bytes(np, byteorder = 'big')

def int8_as_bytes(ival):
    return np.frombuffer(ival.item().to_bytes(1, byteorder='big'), dtype=np.uint8)

def int8_of_bytes(np):
    return int.from_bytes(np, byteorder = 'big')

def merge_params(data):
    command = data[CMD_OFFSET]
    if command == CMD_DUTYLL:
        return (data[PARAM1_OFFSET] << 6) | data[PARAM2_OFFSET]
    elif command == CMD_VOLENVPER:
        return (data[PARAM1_OFFSET] << 4) | (data[PARAM2_OFFSET] << 3) | data[PARAM3_OFFSET]
    elif command == CMD_LSB:
        return data[PARAM1_OFFSET]
    elif command == CMD_MSB:
        return data[PARAM1_OFFSET]  | (data[PARAM2_OFFSET] << 6) | (data[PARAM3_OFFSET] << 7)
    else:
        raise "this should not happen"
        
def unmerge_params(command, data,v ):
    if command == CMD_DUTYLL:
        data[PARAM1_OFFSET] = v >> 6;
        data[PARAM2_OFFSET] = v & 0b0011_1111
    elif command == CMD_VOLENVPER:
        data[PARAM1_OFFSET] = v >> 4
        data[PARAM2_OFFSET] = (v & 0b0000_1000) >> 3
        data[PARAM3_OFFSET] = (v & 0b0000_0111)
    elif command == CMD_LSB:
        data[PARAM1_OFFSET] = v
    elif command == CMD_MSB:
        data[PARAM1_OFFSET] = v & 0b0011_1111
        data[PARAM2_OFFSET] = (v & 0b0100_0000) >> 6
        data[PARAM3_OFFSET] = (v & 0b1000_0000) >> 7
    else:
        raise Exception("this should not happen")
        
BYTES_PER_ENTRY=7

def command_to_bytes(command):
    new_arr = np.concatenate([
                    int32_as_bytes(command[TIME_OFFSET]),
                    int8_as_bytes(command[CH_OFFSET]),
                    int8_as_bytes(command[CMD_OFFSET]),
                    int8_as_bytes(merge_params(command)),]).flatten()
    return new_arr

def command_of_bytes(byte_arr):
    d = fresh_input(0, 0, 0)
    d[TIME_OFFSET] = int32_of_bytes(byte_arr[0:4])
    print(byte_arr[0:4], d[TIME_OFFSET])
    d[TIME_OFFSET] = min(d[TIME_OFFSET], 400000)
    print("TOFF:", d[TIME_OFFSET])
    print("CH:", byte_arr, byte_arr[4])
    d[CH_OFFSET] = int8_of_bytes(byte_arr[4:5])
    if d[CH_OFFSET] != 1 and d[CH_OFFSET] != 2:
        raise Exception("bad channel prediction")
    d[CMD_OFFSET] = int8_of_bytes(byte_arr[5:6])
    print("NCMD", d[CMD_OFFSET])
    unmerge_params(d[CMD_OFFSET], d, byte_arr[6])
    return d

def unnorm_feature(data):
    
    data = data.copy()
    
    # Unnormalize a channel given a specific max value
    def l_unnorm(channel, maxv):
        data[channel] = unnorm(data[channel], maxv)
    
    # Round and box a channel to a min and max value
    def l_round(channel, minv, maxv):
        data[channel] = round(min(max(data[channel], minv), maxv))
    
    l_unnorm(TIME_OFFSET, NORMALIZE_TIME_BY)
    
    if data[TIME_OFFSET] < 4:
        raise Exception('bad time prediction')
    
    l_unnorm(CH_OFFSET, CH_COUNT)
    data[CH_OFFSET] = round(data[CH_OFFSET])
    
    l_unnorm(CMD_OFFSET, CMD_COUNT)
    data[CMD_OFFSET] = round(data[CMD_OFFSET])
    
    command = data[CMD_OFFSET]
    
    if command == CMD_DUTYLL:
        l_unnorm(PARAM1_OFFSET, 2)
        l_round(PARAM1_OFFSET, 0, 2)
        
        l_unnorm(PARAM2_OFFSET, 64)
        l_round(PARAM2_OFFSET, 0, 64)
        
        data[PARAM3_OFFSET] = 0.
    elif command == CMD_VOLENVPER:
        
        l_unnorm(PARAM1_OFFSET, 16)
        l_round(PARAM1_OFFSET, 0, 16)
        
        l_round(PARAM2_OFFSET, 0, 1)
        
        l_unnorm(PARAM3_OFFSET, 7)
        l_round(PARAM3_OFFSET, 0, 7)
    elif command == CMD_LSB:
        
        l_unnorm(PARAM1_OFFSET, 255.)
        l_round(PARAM1_OFFSET, 0., 255.)
        
        data[PARAM2_OFFSET] = 0.
        
        data[PARAM3_OFFSET] = 0.
    elif command == CMD_MSB:
        l_unnorm(PARAM1_OFFSET, 7.)
        l_round(PARAM1_OFFSET, 0., 7.)
        
        l_round(PARAM2_OFFSET, 0., 1.)
        l_round(PARAM3_OFFSET, 0., 1.)
    else:
        raise Exception("pred was bad")
    
    return data

def print_feature(data, file=sys.stdout):
    
    print("FILE: ", file)
    print("FORPRINT", data)

    command = data[CMD_OFFSET]
    
    if command == CMD_DUTYLL:
        print(f"CH {data[CH_OFFSET]} DUTYLL {data[PARAM1_OFFSET]} {data[PARAM2_OFFSET]} AT {data[TIME_OFFSET]}", file=file, flush=True)
    elif command == CMD_VOLENVPER:
        print(f"CH {data[CH_OFFSET]} VOLENVPER {data[PARAM1_OFFSET]} {data[PARAM2_OFFSET]} {data[PARAM3_OFFSET]} AT {data[TIME_OFFSET]}", file=file, flush=True)
    elif command == CMD_LSB:
        print(f"CH {data[CH_OFFSET]} FREQLSB {data[PARAM1_OFFSET]} AT {data[TIME_OFFSET]}", file=file, flush=True)
    elif command == CMD_MSB:
        print(f"CH {data[CH_OFFSET]} FREQMSB {data[PARAM1_OFFSET]} {data[PARAM2_OFFSET]} {data[PARAM3_OFFSET]} AT {data[TIME_OFFSET]}", file=file, flush=True)
    else:
        print(f"Bad prediction", file=file, flush=True)

def load_training_data(src):
    data = []
    file = open(src, 'r')
    for line in file:
        parts = line.split()
        if len(parts) > 0 and parts[0] == "CH":
            #print(parts)
            channel = int(parts[1])
            command = parts[2]
            time = int(parts[-1])
            if command == "DUTYLL":
                new_item = command_of_parts(CMD_DUTYLL, channel, parts, time)
            elif command == "VOLENVPER":
                new_item = command_of_parts(CMD_VOLENVPER, channel, parts, time)
            elif command == "FREQLSB":
                new_item = command_of_parts(CMD_LSB, channel, parts, time)
            elif command == "FREQMSB":
                new_item = command_of_parts(CMD_MSB, channel, parts, time)
            else:
                print("Unknown", command)
             # Otherwise unknown   
            data.append(new_item)
    return data

@pescador.streamable
def samples_from_training_data(src, window_size):
    sample_data = None
    try:
        sample_data = load_training_data(src)
    except Exception as e:
        LOGGER.error('Could not load {}: {}'.format(src, str(e)))
        raise StopIteration()

    while True:
        if len(sample_data) < window_size:
            sample = sample_data
        else:
            # Sample a random window from the audio file
            start_idx = np.random.randint(0, len(sample_data) - window_size)
            sample = sample_data[start_idx:(start_idx + window_size)]
            
        sample = np.array([command_to_bytes(x) for x in sample])

        yield { 'X':sample }

def create_batch_generator(paths, window_size):
    streamers = []
    for path in paths:
        streamers.append(samples_from_training_data(path, window_size))
    mux = pescador.StochasticMux(streamers, n_active=1, rate=1).iterate()
    
    return mux

def training_files(dirp):
    return [
      os.path.join(root, fname)
      for (root, dir_names, file_names) in os.walk(dirp, followlinks=True)
      for fname in file_names
    ]

def create_data_split(paths, window_size=MAX_WINDOW_SIZE):
    train_gen = create_batch_generator(paths, window_size)
    return train_gen



In [7]:
print("Collecting training data")
train_gen = create_data_split(training_files("../..//training_data/"))
print("Collected")

Collecting training data
Collected


In [8]:
from torch.distributions.gamma import Gamma
from torch.distributions.normal import Normal
import torch

class NIGDist():
    
    def __init__(self, m, vinv, a, b):
        self.m = m
        self.vinv = vinv
        self.a = a
        self.b = b

    def update(self, total, moment, n):
        
        newVinv = self.vinv  + n
        newM = (1.0 / newVinv) * (self.vinv * self.m + total)
        self.a += n / 2
        self.b += 0.5 * (self.m * self.m * self.vinv + moment - newM * newM * newVinv)
        self.m = newM
        self.vinv = newVinv
        
        
    def update(self, x):
        total = torch.sum(x, dim=1)
        moment = torch.sum(x * x, dim=1)
        n = x.shape[1]
        self.update(total, moment, n)

    def sample(self):
        vars = 1.0 / (Gamma(self.a, self.b)).sample()
        means = Normal(self.m, vars / self.vinv).sample()
        return means, vars

    def update_component(self, index, x):
        total = torch.sum(x)
        moment = torch.sum(x * x)
        n = x.shape[0]
        newVinv = self.vinv[index]  + n
        newM = (1.0 / newVinv) * (self.vinv[index] * self.m[index] + total)
        self.a[index] += n / 2
        self.b[index] += 0.5 * (self.m[index] * self.m[index] * self.vinv[index] + moment - newM * newM * newVinv)
        self.m[index] = newM
        self.vinv[index] = newVinv

def standardNIG(batch, device):
    return NIGDist(torch.zeros(batch, device=device), torch.ones(batch, device=device), torch.ones(batch, device=device), torch.ones(batch, device=device))

class ThompsonTuner():
    def __init__(self, stepper, num_options, device):
        self.stepper = stepper
        self.arms = num_options        
        self.belief = standardNIG(self.arms, device)
        self.histogram = torch.zeros(self.arms, device=device)
        self.n = 0
        self.ngpu = torch.zeros(1, device=device)
        self.device = device

    def step(self):
        samples, _ = self.belief.sample()
        play = torch.argmax(samples)
        obs = self.stepper(play)
        self.belief.update_component(play, obs)
        self.histogram[play] += 1
        self.n += 1
        self.ngpu += 1
        if self.arms == 2:
            pass # Special analytic case
        if self.n >= 5:
            #print(torch.max(self.histogram) / self.ngpu)
            if (torch.max(self.histogram) / self.ngpu) > 0.80:
                choice = torch.argmax(self.histogram)
                self.stepper.choose(choice)
                self.ngpu.zero_()
                self.n = 0
                self.histogram.zero_()
                self.belief.m = self.belief.m[choice].repeat(self.arms)
                self.belief.vinv = self.belief.vinv[choice].repeat(self.arms)
                self.belief.a = self.belief.a[choice].repeat(self.arms)
                self.belief.b = self.belief.b[choice].repeat(self.arms)
                self.belief.vinv /= 2
                self.belief.a /= 2
                self.belief.b /= 2

class LRTuner():
    #Note: Stepper produces correlated outputs. Take two steps per tuning step to avoid? (Overlapping minibatches?)
    class Stepper():
        def __init__(self, base):
            self.base = base
            
        def __call__(self, choice):
            return self.base.opt(self.base.LRs[choice])
            

        def choose(self, choice):
            #print("chose!")
            ratio = self.base.ratio
            if choice == 0:
                self.base.LRs /= ratio
            else:
                self.base.LRs *= ratio
            
    def __init__(self, lr, opt, device, ratio=1.259921049):
        self.opt = opt
        self.ratio = ratio
        self.LRs = torch.tensor([lr, lr * ratio], device=device)
        self.tuner = ThompsonTuner(LRTuner.Stepper(self), 2, device)

    def step(self):
        self.tuner.step()

    def zero_grad(self):
        self.opt.zero_grad()

In [9]:
def swish(x, b):
    return x * torch.sigmoid(b * x)

class Swish(nn.Module):
    def __init__(self, chan):
        super(Swish, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(chan)))

    def forward(self, x):
        return swish(x, self.weight)

In [10]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout = 0.1, max_len = MAX_WINDOW_SIZE):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [11]:
class Functional(nn.Module):
    def __init__(self, f):
        super(Functional, self).__init__()
        self.f = f

    def forward(self, x):
        return self.f(x)

In [None]:
EPOCHS = 500000
ROUND_SZ = 100

class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super(CausalConv1d, self).__init__()
        self.pad = (kernel_size - 1) * dilation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size , dilation=dilation, **kwargs)
        
    def forward(self, x):
        #pad here to only add to the left side
        x = F.pad(x, (self.pad, 0))
        return self.conv(x)

class ResidualBlock(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, skip_channels, dilation=1):
        super(ResidualBlock, self).__init__()
        self.dilation = dilation
        self.conv_sig = CausalConv1d(input_channels, output_channels, kernel_size, dilation)#dim
        self.sig = nn.Sigmoid()
        self.conv_tan = CausalConv1d(input_channels, output_channels, kernel_size, dilation)#dim
        self.tanh = nn.Tanh()
        
        #separate weights for residual and skip channels
        self.conv_r = nn.Conv1d(output_channels, output_channels, 1)#dim -> k = 1
        self.conv_s = nn.Conv1d(output_channels, skip_channels, 1)
        
    def forward(self, x):
        o = self.sig(self.conv_sig(x)) * self.tanh(self.conv_tan(x))
        skip = self.conv_s(o)
        residual = self.conv_r(o)
        return residual, skip

class WaveNet(nn.Module):
    def __init__(self, skip_channels=256, num_blocks=3, num_layers=12, num_hidden=128, kernel_size=2): 
        super(WaveNet, self).__init__()

        self.embed = nn.Embedding(skip_channels, skip_channels)
        self.positional_embedding = PositionalEncoding(skip_channels)
        self.causal_conv = CausalConv1d(skip_channels, num_hidden, kernel_size)
        self.res_stack = nn.ModuleList()

        for b in range(num_blocks):
            for i in range(num_layers):
                self.res_stack.append(ResidualBlock(num_hidden, num_hidden, kernel_size, skip_channels=skip_channels, dilation=2**i))
        
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv1d(skip_channels, skip_channels, 1)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv1d(skip_channels, skip_channels, 1)
        #self.output = nn.Softmax()
        
    def forward(self, x):

        o = self.embed(x)
        o = self.positional_embedding(o)
        o = o.permute(0,2,1)

        skip_vals = []
        #initial causal conv
        o = self.causal_conv(o)
        
        #run res blocks
        for i, layer in enumerate(self.res_stack):
            o, s = layer(o)
            skip_vals.append(s)
            
        #sum skip values and pass to last portion of network
        o = reduce((lambda a,b: a+b), skip_vals)
        o = self.relu1(o)
        o = self.conv1(o)
        o = self.relu2(o)
        o = self.conv2(o)
        
        return o #self.output(o)

KERNEL_SIZE_SAMPLES=16
KERNEL_SIZE=BYTES_PER_ENTRY * KERNEL_SIZE_SAMPLES
    
def command_net():
    return WaveNet(kernel_size=KERNEL_SIZE)

def load(path):
    command_generator = command_net()
    
    if path != None:
        command_generator.load_state_dict(torch.load(path))
        
    command_generator = command_generator.to(device)

    return command_generator

def train(path, baseOpt=optim.SGD, tune_ratio=1.2599, device=torch.device(0)):
    
    lr = 0.01
    momentum=0.85
    
    command_generator = load(path)

    baseOpt = baseOpt(command_generator.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    running_loss = torch.zeros(1, device=device)
    last_loss = torch.zeros(1, device=device)

    def step(lr):
        
        for i in range(ROUND_SZ):
            ntrain =  next(train_gen)['X'].flatten().copy()
            seq = torch.Tensor(ntrain).long().to(device)
            inputs = seq[:-1].unsqueeze(0)
            labels = seq[1:].unsqueeze(0)
    
            baseOpt.zero_grad()

            for g in baseOpt.param_groups:
                g['lr'] = lr

            outputs = command_generator(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            baseOpt.step()
            running_loss.add_(loss.detach())
            
            if i == 0:
                last_loss.copy_(loss.detach())
            result = last_loss - loss.detach()
            
            seq = seq.detach().to(cpu)
            del inputs
            del labels
            del seq

        result = running_loss / ROUND_SZ
        return result

    opt = LRTuner(lr, step, device, tune_ratio) 

    for i in range(0, EPOCHS):
        opt.step()

        print("Loss:", running_loss.item() / ROUND_SZ)
        running_loss.zero_()
        print("LR:",opt.LRs[0].item())
        
        print("Saving checkpoint")
        torch.save(command_generator.state_dict(), "./" + str(int(datetime.now().timestamp())) + ".checkpoint.model")
        torch.save(command_generator.state_dict(), "./last.checkpoint.model")
        print("Saved checkpoint")
        
        gc.collect()
    
    return command_generator.eval()

train("../../training_data/")

Loss: 3.2088217163085937
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 2.304729461669922
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 2.0896235656738282
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 1.960494384765625
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 1.9032095336914063
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 1.8320292663574218
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 1.7759666442871094
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 1.7249778747558593
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 1.7201437377929687
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 1.7098617553710938
LR: 0.009999999776482582
Saving checkpoint
Saved checkpoint
Loss: 1.6526358032226562
LR: 0.01259899977594614
Saving checkpoint
Saved checkpoint
Loss: 1.5716722106933594
LR: 0.01259899977594614
Saving checkpoint
S

In [None]:
#%%capture cap --no-stderr

print("Collecting training data")
train_gen = create_data_split(training_files("../../training_data/"))
print("Collected")

command_generator = load("./last.checkpoint.model").eval()

seed = next(train_gen)['X'].flatten().copy()

def max_of(v, begin, end):
    return begin + np.argmax(v[begin:end])

with open('output.txt', 'w') as f:

    for i in range(0, len(seed), BYTES_PER_ENTRY):
        print("Seed value :", i)
        cmd = command_of_bytes(seed[i:i+BYTES_PER_ENTRY])
        print_feature(cmd)
        print_feature(cmd, file=f)
    
    for i in range(BYTES_PER_ENTRY * 10000):
        seq = torch.Tensor(seed).long().to(device).unsqueeze(0)
        pred = command_generator(seq).detach().to(cpu).permute(0,2,1).squeeze(0).numpy()
        new_pred_bytes = np.array([np.argmax(x) for x in pred[-1:]]).astype(np.uint8)
        seed = np.concatenate([seed[1:], new_pred_bytes])

        print("Step: ", i)
    
        if (i + 1) % BYTES_PER_ENTRY == 0:
            try:
                print("New pred bytes:", seed[-7:])
                new_pred = command_of_bytes(seed[-7:])
                print_feature(new_pred)
                print_feature(new_pred, file=f)
            except BaseException as err:
                print("pred was not valid because:", err)

    del pred