In [1]:
import numpy as np
import pescador
import logging
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import math

import gc

from datetime import datetime

In [2]:
# Notebook-wide logger; DEBUG threshold so debug records are not filtered at this logger.
LOGGER = logging.getLogger(name='gbsd')
LOGGER.setLevel(level=logging.DEBUG)

In [3]:
# Show plain decimals instead of scientific notation when printing arrays/tensors.
np.set_printoptions(suppress=True)
torch.set_printoptions(sci_mode=False)

In [4]:
# Non-interactive backend: figures are rendered off-screen only.
matplotlib.use(backend='Agg')

In [5]:
# Use the GPU when one is visible; `cpu` is kept for moving results back to host memory.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

cpu = torch.device('cpu')

In [6]:
# Command kinds stored in the CMD_OFFSET column of an input row.
CMD_VOLENVPER, CMD_DUTYLL, CMD_MSB, CMD_LSB = range(4)
# Number of distinct command kinds.
CMD_COUNT = 4

def onehot_cmd(data):
    """Return a one-hot list (length CMD_COUNT) for the command of an input row.

    The command index is read from column CMD_OFFSET and converted with int().
    """
    # Sized from CMD_COUNT so it stays in sync with the command constants
    # (previously a hard-coded list of four zeros).
    one_hot = [0] * CMD_COUNT
    one_hot[int(data[CMD_OFFSET])] = 1
    return one_hot


# Channel ids and how many there are.
CH_1, CH_2 = 0, 1
CH_COUNT = 2

# Column layout of one input row (see fresh_input / command_of_parts).
TIME_OFFSET, CH_OFFSET, CMD_OFFSET = 0, 1, 2
PARAM1_OFFSET, PARAM2_OFFSET, PARAM3_OFFSET = 3, 4, 5
SIZE_OF_INPUT_FIELDS = 6

# Commands per training window; also the positional-encoding table length.
MAX_WINDOW_SIZE = 5000

M_CYCLES_PER_SECOND = 4_194_304.0
# Time values are normalized against ten seconds' worth of cycles.
NORMALIZE_TIME_BY = M_CYCLES_PER_SECOND * 10.0

def norm(val, max_val):
    """Map `val` from [0, max_val] onto [-1, 1]; values above max_val saturate at 1."""
    if val > max_val:
        return 1.
    scaled = val / max_val
    return scaled * 2. - 1.

def unnorm(val, max_val):
    """Inverse of `norm` on [-1, 1]: map `val` back onto [0, max_val]."""
    fraction = (val + 1.) / 2.
    return fraction * max_val

def fresh_input(command, channel, time):
    """Return a zeroed integer row with the time, channel and command columns filled in."""
    row = np.zeros(shape=SIZE_OF_INPUT_FIELDS, dtype=int)
    for offset, value in ((TIME_OFFSET, time), (CH_OFFSET, channel), (CMD_OFFSET, command)):
        row[offset] = value
    return row

def parse_bool(v):
    """Return 1 for "true", 0 for "false", otherwise int(v)."""
    literals = {"true": 1, "false": 0}
    if v in literals:
        return literals[v]
    return int(v)

def command_of_parts(command, channel, parts, time):
    """Build an input row from one whitespace-split log line.

    `parts` is expected to look like ``CH <channel> <NAME> <params...> AT <time>``
    (presumably the same layout print_feature emits), so parameters start at
    index 3.

    Raises:
        ValueError: if `command` is not one of the CMD_* constants.
    """
    inp = fresh_input(command, channel, time)

    if command == CMD_DUTYLL:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = int(parts[4])
    elif command == CMD_VOLENVPER:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = parse_bool(parts[4])
        # The third parameter is the next token; parts[4] is the flag above
        # (reading parts[4] again duplicated the flag or failed on "true").
        inp[PARAM3_OFFSET] = int(parts[5])
    elif command == CMD_LSB:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = 0
        inp[PARAM3_OFFSET] = 0
    elif command == CMD_MSB:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = parse_bool(parts[4])
        inp[PARAM3_OFFSET] = parse_bool(parts[5])
    else:
        # Raising a bare string is itself a TypeError in Python 3.
        raise ValueError(f"this should not happen: unknown command {command!r}")
    return inp

def int32_as_bytes(ival):
    """Encode a numpy integer scalar as 4 big-endian bytes, returned as a uint8 array."""
    python_int = ival.item()
    packed = python_int.to_bytes(4, 'big')
    return np.frombuffer(packed, dtype=np.uint8)

def int8_as_bytes(ival):
    """Encode a numpy integer scalar in 0..255 as a single-byte uint8 array."""
    packed = ival.item().to_bytes(1, 'big')
    return np.frombuffer(packed, dtype=np.uint8)

def merge_params(data):
    """Pack the PARAM* columns of an input row into one parameter value.

    Bit layout per command:
      - DUTYLL:    (param1 << 6) | param2
      - VOLENVPER: (param1 << 4) | (param2 << 3) | param3
      - LSB:       param1
      - MSB:       param1 | (param2 << 6) | (param3 << 7)

    Raises:
        ValueError: if the CMD_OFFSET column holds an unknown command.
    """
    command = data[CMD_OFFSET]
    if command == CMD_DUTYLL:
        return (data[PARAM1_OFFSET] << 6) | data[PARAM2_OFFSET]
    elif command == CMD_VOLENVPER:
        return (data[PARAM1_OFFSET] << 4) | (data[PARAM2_OFFSET] << 3) | data[PARAM3_OFFSET]
    elif command == CMD_LSB:
        return data[PARAM1_OFFSET]
    elif command == CMD_MSB:
        return data[PARAM1_OFFSET] | (data[PARAM2_OFFSET] << 6) | (data[PARAM3_OFFSET] << 7)
    else:
        # Raising a bare string is itself a TypeError in Python 3.
        raise ValueError(f"this should not happen: unknown command {command!r}")

def command_to_bytes(command):
    """Serialize one input row to 7 bytes: 4-byte time, channel, command, packed params."""
    pieces = [
        int32_as_bytes(command[TIME_OFFSET]),
        int8_as_bytes(command[CH_OFFSET]),
        int8_as_bytes(command[CMD_OFFSET]),
        int8_as_bytes(merge_params(command)),
    ]
    return np.concatenate(pieces).flatten()

def unnorm_feature(data):
    """Map a normalized, model-predicted input row back to command values.

    Works on a copy of `data` (assumes a float numpy row; with an integer
    array the float assignments below would be truncated). The time,
    channel and command columns are un-normalized with `unnorm`. The PARAM
    columns are then un-normalized and/or rounded and clamped according to
    the predicted command.

    NOTE(review): nothing in this notebook calls this function; the byte-based
    training pipeline bypasses it. It is presumably left over from an earlier
    float-feature model.

    Raises:
        Exception: 'bad time prediction' if the un-normalized time is below 4.
        Exception: "pred was bad" if the rounded command is not a known CMD_*.
    """
    
    data = data.copy()
    
    # Unnormalize a channel given a specific max value
    # ("channel" here is a column offset of the row, e.g. TIME_OFFSET)
    def l_unnorm(channel, maxv):
        data[channel] = unnorm(data[channel], maxv)
    
    # Round and box a channel to a min and max value
    def l_round(channel, minv, maxv):
        data[channel] = round(min(max(data[channel], minv), maxv))
    
    l_unnorm(TIME_OFFSET, NORMALIZE_TIME_BY)
    
    # Times below 4 are treated as an invalid prediction.
    if data[TIME_OFFSET] < 4:
        raise Exception('bad time prediction')
    
    # NOTE(review): channel and command are rounded but not clamped, so they
    # can round up to CH_COUNT / CMD_COUNT (the latter ends in the final else).
    l_unnorm(CH_OFFSET, CH_COUNT)
    data[CH_OFFSET] = round(data[CH_OFFSET])
    
    l_unnorm(CMD_OFFSET, CMD_COUNT)
    data[CMD_OFFSET] = round(data[CMD_OFFSET])
    
    command = data[CMD_OFFSET]
    
    if command == CMD_DUTYLL:
        # NOTE(review): merge_params packs (param1 << 6) | param2 into one byte,
        # so param2 must stay below 64; clamping to 64 (and param1 to 2) looks
        # off by one. Confirm the intended ranges.
        l_unnorm(PARAM1_OFFSET, 2)
        l_round(PARAM1_OFFSET, 0, 2)
        
        l_unnorm(PARAM2_OFFSET, 64)
        l_round(PARAM2_OFFSET, 0, 64)
        
        data[PARAM3_OFFSET] = 0.
    elif command == CMD_VOLENVPER:
        
        # NOTE(review): merge_params shifts param1 left by 4 and int8_as_bytes
        # needs a value < 256, so a clamp of 16 can overflow. Confirm.
        l_unnorm(PARAM1_OFFSET, 16)
        l_round(PARAM1_OFFSET, 0, 16)
        
        # Flag column is rounded/clamped without un-normalizing; presumably it
        # is predicted directly in [0, 1]. Confirm.
        l_round(PARAM2_OFFSET, 0, 1)
        
        l_unnorm(PARAM3_OFFSET, 7)
        l_round(PARAM3_OFFSET, 0, 7)
    elif command == CMD_LSB:
        
        l_unnorm(PARAM1_OFFSET, 255.)
        l_round(PARAM1_OFFSET, 0., 255.)
        
        data[PARAM2_OFFSET] = 0.
        
        data[PARAM3_OFFSET] = 0.
    elif command == CMD_MSB:
        l_unnorm(PARAM1_OFFSET, 7.)
        l_round(PARAM1_OFFSET, 0., 7.)
        
        # Flag columns: rounded/clamped to 0..1 without un-normalizing.
        l_round(PARAM2_OFFSET, 0., 1.)
        l_round(PARAM3_OFFSET, 0., 1.)
    else:
        raise Exception("pred was bad")
    
    return data

def print_feature(data):
    """Print an input row in the textual log format, or "Bad prediction".

    The row is cast to int and the channel is printed as its stored value + 1.
    """
    row = data.copy().astype(int)
    row[CH_OFFSET] += 1
    ch, at, kind = row[CH_OFFSET], row[TIME_OFFSET], row[CMD_OFFSET]

    if kind == CMD_DUTYLL:
        message = f"CH {ch} DUTYLL {row[PARAM1_OFFSET]} {row[PARAM2_OFFSET]} AT {at}"
    elif kind == CMD_VOLENVPER:
        message = f"CH {ch} VOLENVPER {row[PARAM1_OFFSET]} {row[PARAM2_OFFSET]} {row[PARAM3_OFFSET]} AT {at}"
    elif kind == CMD_LSB:
        message = f"CH {ch} FREQLSB {row[PARAM1_OFFSET]} AT {at}"
    elif kind == CMD_MSB:
        message = f"CH {ch} FREQMSB {row[PARAM1_OFFSET]} {row[PARAM2_OFFSET]} {row[PARAM3_OFFSET]} AT {at}"
    else:
        message = "Bad prediction"
    print(message)

def load_training_data(src):
    """Parse a textual command log into a list of input rows.

    Only lines whose first token is "CH" are used. The channel is the
    second token, the command name the third, and the time the last token.
    Lines with an unrecognized command name are reported and skipped.

    Args:
        src: path of the log file.
    Returns:
        list of rows built by command_of_parts, in file order.
    """
    command_ids = {
        "DUTYLL": CMD_DUTYLL,
        "VOLENVPER": CMD_VOLENVPER,
        "FREQLSB": CMD_LSB,
        "FREQMSB": CMD_MSB,
    }
    data = []
    # `with` guarantees the file is closed, even if a line fails to parse.
    with open(src, 'r') as file:
        for line in file:
            parts = line.split()
            if not parts or parts[0] != "CH":
                continue
            channel = int(parts[1])
            command = parts[2]
            time = int(parts[-1])
            cmd = command_ids.get(command)
            if cmd is None:
                # Skip: there is no row for an unknown command (previously a
                # stale or undefined `new_item` was appended here).
                print("Unknown", command)
                continue
            data.append(command_of_parts(cmd, channel, parts, time))
    return data

@pescador.streamable
def samples_from_training_data(src, window_size=MAX_WINDOW_SIZE):
    """Endlessly yield random windows of one training log, serialized to bytes.

    Each item is ``{'X': array}`` with one row of bytes per command (see
    command_to_bytes). Files with fewer than `window_size` commands are
    yielded whole. If the file cannot be loaded, or holds no commands, the
    problem is logged and the stream ends.
    """
    try:
        sample_data = load_training_data(src)
    except Exception as e:
        LOGGER.error('Could not load {}: {}'.format(src, str(e)))
        # `raise StopIteration` inside a generator is turned into RuntimeError
        # (PEP 479); returning is how a generator ends its stream.
        return

    if not sample_data:
        # Otherwise this would yield empty samples forever.
        LOGGER.warning('No commands found in {}'.format(src))
        return

    while True:
        if len(sample_data) < window_size:
            sample = sample_data
        else:
            # Random window of commands. randint's upper bound is exclusive, so
            # +1 makes the last window reachable and avoids randint(0, 0)
            # raising when the file has exactly window_size commands.
            start_idx = np.random.randint(0, len(sample_data) - window_size + 1)
            sample = sample_data[start_idx:(start_idx + window_size)]

        sample = np.array([command_to_bytes(x) for x in sample])

        yield {'X': sample}

def create_batch_generator(paths, batch_size):
    """Multiplex one sample stream per file in `paths` into a single iterator.

    `batch_size` is accepted for API compatibility but is not used.
    """
    streamers = [samples_from_training_data(path) for path in paths]
    return pescador.StochasticMux(streamers, n_active=1, rate=1).iterate()

def training_files(dirp):
    """Return the path of every file under `dirp`, recursing and following symlinks."""
    found = []
    for root, _subdirs, names in os.walk(dirp, followlinks=True):
        found.extend(os.path.join(root, name) for name in names)
    return found

def create_data_split(paths, batch_size):
    """Return the training sample iterator for `paths` (no validation split is made)."""
    return create_batch_generator(paths, batch_size)



In [7]:
# Build the endless training-sample iterator over every file in the data directory.
print("Collecting training data")
train_gen = create_data_split(paths=training_files(dirp="../..//training_data/"), batch_size=1)
print("Collected")

Collecting training data
Collected


In [8]:
from torch.distributions.gamma import Gamma
from torch.distributions.normal import Normal
import torch

class NIGDist():
    """A batch of independent Normal-Inverse-Gamma (NIG) beliefs.

    All four parameter tensors share one shape, with one entry per component:
      m    -- centre of the Normal over the unknown mean
      vinv -- pseudo-count: the mean's variance is (variance / vinv)
      a, b -- the variance is drawn as 1 / Gamma(concentration=a, rate=b)
    """

    def __init__(self, m, vinv, a, b):
        self.m = m
        self.vinv = vinv
        self.a = a
        self.b = b

    def update_stats(self, total, moment, n):
        """Conjugate update of every component from sufficient statistics.

        Args:
            total: per-component sum of observations.
            moment: per-component sum of squared observations.
            n: number of observations behind `total` / `moment`.
        """
        newVinv = self.vinv + n
        newM = (1.0 / newVinv) * (self.vinv * self.m + total)
        self.a += n / 2
        self.b += 0.5 * (self.m * self.m * self.vinv + moment - newM * newM * newVinv)
        self.m = newM
        self.vinv = newVinv

    def update(self, x):
        """Update every component from a block of observations `x`.

        Statistics are summed over dim 1, so `x` is presumably
        (components, observations).
        """
        # This used to call a first method also named `update`; the second
        # definition replaced it, so the call hit this method with the wrong
        # arity and raised TypeError. The statistics update is `update_stats`.
        total = torch.sum(x, dim=1)
        moment = torch.sum(x * x, dim=1)
        self.update_stats(total, moment, x.shape[1])

    def sample(self):
        """Draw one (means, variances) pair, one entry per component."""
        variances = 1.0 / (Gamma(self.a, self.b)).sample()
        # Normal is parameterised by standard deviation, while the NIG mean has
        # *variance* variances / vinv, hence the square root.
        means = Normal(self.m, torch.sqrt(variances / self.vinv)).sample()
        return means, variances

    def update_component(self, index, x):
        """Conjugate update of the single component `index` from 1-D observations `x`."""
        total = torch.sum(x)
        moment = torch.sum(x * x)
        n = x.shape[0]
        newVinv = self.vinv[index] + n
        newM = (1.0 / newVinv) * (self.vinv[index] * self.m[index] + total)
        self.a[index] += n / 2
        self.b[index] += 0.5 * (self.m[index] * self.m[index] * self.vinv[index] + moment - newM * newM * newVinv)
        self.m[index] = newM
        self.vinv[index] = newVinv

def standardNIG(batch, device):
    """Return `batch` independent NIG beliefs with prior m=0 and vinv=a=b=1 on `device`."""
    mean = torch.zeros(batch, device=device)
    # Three separate tensors: the beliefs are later updated in place.
    vinv, shape, rate = (torch.ones(batch, device=device) for _ in range(3))
    return NIGDist(mean, vinv, shape, rate)

class ThompsonTuner():
    """Thompson-sampling bandit over `num_options` arms.

    Each arm's reward is tracked by an independent NIG belief (standardNIG).
    `stepper(arm)` plays an arm and returns the observed reward tensor.
    After at least 100 plays, if one arm accounts for more than 80% of them,
    `stepper.choose(arm)` is called and the play counters reset. Every arm's
    belief is then re-seeded from the winner's, with vinv, a and b halved.
    """

    def __init__(self, stepper, num_options, device):
        self.stepper = stepper
        self.arms = num_options
        self.belief = standardNIG(self.arms, device)
        self.histogram = torch.zeros(self.arms, device=device)  # plays per arm
        self.n = 0                                               # plays since last commit
        self.ngpu = torch.zeros(1, device=device)                # same count, as a tensor on `device`
        self.device = device

    def step(self):
        """Sample each arm's mean, play the best one, and fold in its reward."""
        sampled_means, _ = self.belief.sample()
        arm = torch.argmax(sampled_means)
        reward = self.stepper(arm)
        self.belief.update_component(arm, reward)
        self.histogram[arm] += 1
        self.n += 1
        self.ngpu += 1

        if self.n < 100:
            return
        share = torch.max(self.histogram) / self.ngpu
        if share > 0.80:
            self._commit(torch.argmax(self.histogram))

    def _commit(self, winner):
        """Report `winner` to the stepper, reset counters, and re-seed all beliefs from it."""
        self.stepper.choose(winner)
        self.ngpu.zero_()
        self.n = 0
        self.histogram.zero_()
        belief = self.belief
        for field in ('m', 'vinv', 'a', 'b'):
            setattr(belief, field, getattr(belief, field)[winner].repeat(self.arms))
        # Halving the pseudo-counts presumably weakens the inherited belief so
        # the arms can diverge from it again.
        belief.vinv /= 2
        belief.a /= 2
        belief.b /= 2

class LRTuner():
    """Learning-rate tuner driven by a two-armed ThompsonTuner.

    The arms are the candidate rates ``LRs = [lr, lr * ratio]``. Playing an
    arm calls ``opt(rate)``, whose return value is the reward. When the bandit
    settles on an arm, both candidates are divided by `ratio` (arm 0) or
    multiplied by it (arm 1).
    """
    #Note: Stepper produces correlated outputs. Take two steps per tuning step to avoid? (Overlapping minibatches?)

    class Stepper():
        """Adapter giving ThompsonTuner the `__call__` / `choose` interface it expects."""

        def __init__(self, base):
            self.base = base

        def __call__(self, choice):
            rate = self.base.LRs[choice]
            return self.base.opt(rate)

        def choose(self, choice):
            ratio = self.base.ratio
            if choice != 0:
                self.base.LRs *= ratio
            else:
                self.base.LRs /= ratio

    def __init__(self, lr, opt, device, ratio=1.259921049):
        self.opt = opt
        self.ratio = ratio
        self.LRs = torch.tensor([lr, lr * ratio], device=device)
        self.tuner = ThompsonTuner(LRTuner.Stepper(self), 2, device)

    def step(self):
        """Run one bandit step, i.e. one call of `opt`."""
        self.tuner.step()

    def zero_grad(self):
        # NOTE(review): train() passes a plain function as `opt`, which has no
        # zero_grad(); this only works when `opt` is optimizer-like.
        self.opt.zero_grad()

In [9]:
def swish(x, b):
    """Swish activation with slope `b`: x * sigmoid(b * x)."""
    gate = torch.sigmoid(b * x)
    return x * gate

class Swish(nn.Module):
    """Module wrapper around `swish` with a learnable slope per channel, initialised to 1."""

    def __init__(self, chan):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(chan))

    def forward(self, x):
        return swish(x, self.weight)

In [10]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an embedded sequence, then apply dropout.

    By default `x` is (seq_len, batch, d_model), matching the
    (max_len, 1, d_model) `pe` buffer. Pass ``batch_first=True`` for
    (batch, seq_len, d_model) input. seq_len must not exceed max_len.
    """

    def __init__(self, d_model, dropout = 0.1, max_len = MAX_WINDOW_SIZE, batch_first=False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1 if self.batch_first else 0)
        if seq_len > self.pe.size(0):
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}")
        enc = self.pe[:seq_len]              # (seq_len, 1, d_model)
        if self.batch_first:
            enc = enc.transpose(0, 1)        # (1, seq_len, d_model)
        return self.dropout(x + enc)

In [11]:
class Functional(nn.Module):
    """Wrap a plain callable (e.g. a lambda) so it can be used as an nn.Module layer."""

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        fn = self.f
        return fn(x)

In [13]:
class CausalConv1d(nn.Conv1d):
    """1-D convolution that only sees current and past positions.

    The input is left-padded with ``dilation * (kernel_size - 1)`` zeros.
    With stride 1, the output therefore has the input's length, and
    output[t] depends only on input[..., :t + 1]. `kernel_size` is assumed
    to be an int.
    """
    # From https://github.com/Straw1239/lang-model-experiments/blob/master/src/main/ConvLM.py
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.left_padding = dilation * (kernel_size - 1)

    def forward(self, input):
        # Pad the last (time) axis directly. The previous unsqueeze(2)/squeeze(2)
        # round trip padded the wrong axis for unbatched (C, L) input.
        x = F.pad(input, (self.left_padding, 0))
        return super(CausalConv1d, self).forward(x)
    
class ConvLM(nn.Module):
    """Small byte-level conv LM: embedding -> grouped causal conv -> SELU -> causal 1x1 conv to 256 logits."""

    def __init__(self):
        super().__init__()
        self.embd = nn.Embedding(256, 256)
        self.conv1 = CausalConv1d(256, 256, 32, groups=64)
        self.conv2 = CausalConv1d(256, 256, 1)

    def forward(self, x):
        # Embedding output is moved to channels-first for Conv1d.
        hidden = self.embd(x).permute(0, 2, 1)
        hidden = self.conv1(hidden)
        return self.conv2(F.selu(hidden))

class ResBlock(nn.Module):
    """Residual wrapper: forward(x) = x + base(x)."""

    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, x):
        residual = self.base(x)
        return x + residual

class ApplyConv(nn.Module):
    """Apply `func` to `x` viewed as rows of x.shape[1] values, then restore x's shape."""

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        original_shape = x.shape
        rows = x.view(-1, original_shape[1])
        return self.func(rows).view(*original_shape)

def res_conv_lm(layers, channels, p_builder, kernel):
    """Build the byte-level residual conv LM as an nn.Sequential.

    Layout: Embedding(256, channels) -> PositionalEncoding -> permute to
    channels-first -> `layers` x [depthwise CausalConv1d, ResBlock(p_builder(i))]
    -> 1x1 Conv1d to 256 logits.

    NOTE(review): the embedding output here is (batch, seq, channels), but
    PositionalEncoding slices its table by x.size(0), i.e. the batch size,
    so every position receives the same encoding. Fixing it changes the model
    and needs a longer table (the training sequences are longer than
    MAX_WINDOW_SIZE), so it is left as-is.
    """
    depthwise = [CausalConv1d(channels, channels, kernel, groups=channels) for _ in range(layers)]
    mixers = [ResBlock(p_builder(i)) for i in range(layers)]
    stack = [nn.Embedding(256, channels), PositionalEncoding(channels), Functional(lambda x: x.permute(0, 2, 1))]
    # Interleave: conv, residual block, conv, residual block, ...
    for conv, block in zip(depthwise, mixers):
        stack.append(conv)
        stack.append(block)
    stack.append(nn.Conv1d(channels, 256, 1))
    return nn.Sequential(*stack)

def standard_conv_lm(layers, channels, ksize=8, hfac=4, activation=None):
    """Residual conv LM whose pointwise blocks are 1x1 conv -> activation -> 1x1 conv.

    Args:
        layers: number of (causal depthwise conv, residual pointwise block) pairs.
        channels: model width.
        ksize: causal convolution kernel size.
        hfac: expansion factor of the hidden 1x1 conv.
        activation: module placed inside every pointwise block. The same
            instance is reused by all blocks of one model, so its parameters
            are tied across layers. Defaults to a new ``Swish(1)`` per call.
    """
    if activation is None:
        # A `Swish(1)` default in the signature is created once, when the
        # function is defined, so its trainable weight was shared by every
        # model ever built (e.g. the training and generation models).
        activation = Swish(1)

    def builder(i):
        return nn.Sequential(nn.Conv1d(channels, channels * hfac, 1),
                             activation,
                             nn.Conv1d(channels * hfac, channels, 1))

    return res_conv_lm(layers, channels, builder, ksize)

def command_net():
    """The network used for training and generation: 6 residual layers, 512 channels."""
    return standard_conv_lm(layers=6, channels=512)

# Number of tuner rounds run by train(); each round runs ROUND_SZ optimizer steps.
EPOCHS = 500_000
ROUND_SZ = 100

def load(path):
    """Build a fresh command network, optionally restore weights, and move it to `device`.

    Args:
        path: file written by ``torch.save(model.state_dict(), ...)``, or None
            to keep the random initialisation.
    Returns:
        the model on the notebook-wide `device`.
    """
    command_generator = command_net()

    if path is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only
        # machine; the model is moved to `device` right after.
        command_generator.load_state_dict(torch.load(path, map_location='cpu'))

    command_generator = command_generator.to(device)

    return command_generator

def memReport():
    """Print the type and size of every tensor currently tracked by the garbage collector."""
    tracked_tensors = (obj for obj in gc.get_objects() if torch.is_tensor(obj))
    for tensor in tracked_tensors:
        print(type(tensor), tensor.size())

def train(path, baseOpt=optim.SGD, lr=0.001, momentum=0.9, tune_ratio=1.2599, device=None):
    """Train the command network with a Thompson-sampled learning rate.

    Each of the EPOCHS tuner rounds runs ROUND_SZ optimizer steps on samples
    from the global `train_gen`. It then prints the mean loss and the lower
    candidate learning rate, and saves a timestamped checkpoint plus
    `last.checkpoint.model` in the working directory.

    Args:
        path: checkpoint to resume from, or None to start from scratch.
        baseOpt: optimizer class; must accept `lr` and `momentum`.
        lr: initial learning rate handed to LRTuner.
        momentum: optimizer momentum.
        tune_ratio: multiplicative spacing between the two candidate rates.
        device: where to train. Defaults to wherever `load` placed the model
            (the old default, torch.device(0), was CUDA-only and could
            disagree with that).
    Returns:
        the trained model in eval mode.
    """
    command_generator = load(path)
    if device is None:
        device = next(command_generator.parameters()).device
    else:
        command_generator = command_generator.to(device)

    optimizer = baseOpt(command_generator.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    running_loss = torch.zeros(1, device=device)
    last_loss = torch.zeros(1, device=device)

    def step(lr):
        """Run ROUND_SZ steps at learning rate `lr`; return the loss decrease over the round."""
        # The rate is fixed for the whole round, so set it once.
        for g in optimizer.param_groups:
            g['lr'] = lr

        result = torch.zeros(1, device=device)
        for i in range(ROUND_SZ):
            ntrain = next(train_gen)['X'].flatten().copy()
            seq = torch.Tensor(ntrain).long().to(device)
            # Next-byte prediction: inputs are bytes [0, n-1), labels bytes [1, n).
            inputs = seq[:-1].unsqueeze(0)
            labels = seq[1:].unsqueeze(0)

            optimizer.zero_grad()
            outputs = command_generator(inputs)  # (1, 256, n-1) logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            detached = loss.detach()
            running_loss.add_(detached)
            if i == 0:
                last_loss.copy_(detached)
            result = last_loss - detached

        return result

    opt = LRTuner(lr, step, device, tune_ratio)

    for _ in range(EPOCHS):
        opt.step()

        print("Loss:", running_loss.item() / ROUND_SZ)
        running_loss.zero_()
        print("LR:", opt.LRs[0].item())

        print("Saving checkpoint")
        state = command_generator.state_dict()
        torch.save(state, "./" + str(int(datetime.now().timestamp())) + ".checkpoint.model")
        torch.save(state, "./last.checkpoint.model")
        print("Saved checkpoint")

        gc.collect()
        torch.cuda.empty_cache()

    return command_generator.eval()

# Resume training from the last checkpoint; this runs EPOCHS rounds and blocks the kernel.
train(path="./last.checkpoint.model")

IndentationError: unexpected indent (3410489966.py, line 56)

In [None]:
#%%capture cap --no-stderr
# The output.txt cell reads `cap`, which only exists when the line above is uncommented.
command_generator = load("./last.checkpoint.model").eval()

seed = torch.Tensor(next(train_gen)['X'].flatten().copy()).long().to(device).unsqueeze(0)

# Repeatedly replace the whole sequence with the model's argmax next-byte
# predictions. no_grad: nothing is trained here, so don't build autograd graphs.
with torch.no_grad():
    for _ in range(1024):
        print("Pre pred", seed)
        pred = command_generator(seed).to(cpu).permute(0, 2, 1).squeeze(0).numpy()
        pred = pred.argmax(axis=1)  # most likely byte at each position
        print(pred, pred.shape)
        seed = torch.from_numpy(pred).long().to(device).unsqueeze(0)
        del pred

In [None]:
# `cap` only exists when the generation cell runs under `%%capture cap --no-stderr`
# (that magic is commented out there by default).
if 'cap' in globals():
    with open('output.txt', 'w') as f:
        f.write(cap.stdout)
else:
    print("Nothing captured: uncomment `%%capture cap --no-stderr` in the generation cell and re-run it.")