In [1]:
import numpy as np
import pescador
import logging
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast
from torch.distributions.gamma import Gamma
from torch.distributions.normal import Normal
from local_attention import LocalAttention

from IPython.display import display, clear_output

import math

import gc
import sys

from datetime import datetime

from functools import reduce
from torch.distributions.categorical import Categorical
from torch.distributions.one_hot_categorical import OneHotCategorical

In [2]:
# Notebook-wide logger, named 'gbsd'; its threshold is lowered to DEBUG.
LOGGER = logging.getLogger('gbsd')
LOGGER.setLevel(logging.DEBUG)

In [3]:
# Turn off scientific notation when printing tensors and numpy arrays.
torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True)

In [4]:
# Select matplotlib's 'Agg' backend.
matplotlib.use('Agg')

In [5]:
# Compute device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Explicit CPU handle for moving results back to host memory.
cpu = torch.device('cpu')

In [6]:
# Command identifiers stored in the CMD_OFFSET field of an input record.
CMD_VOLENVPER, CMD_DUTYLL, CMD_MSB, CMD_LSB = range(4)
CMD_COUNT = 4

# Positions of the fields inside one input record.
(TIME_OFFSET,
 CH_OFFSET,
 CMD_OFFSET,
 PARAM1_OFFSET,
 PARAM2_OFFSET,
 PARAM3_OFFSET) = range(6)
SIZE_OF_INPUT_FIELDS = 6

# Default number of entries per sampled window.
MAX_WINDOW_SIZE = 2048

M_CYCLES_PER_SECOND = 4_194_304.0
NORMALIZE_TIME_BY = 10. * M_CYCLES_PER_SECOND

def fresh_input(command, channel, time):
    """Create an input record with only its header fields filled in.

    Args:
        command: value stored at CMD_OFFSET.
        channel: value stored at CH_OFFSET.
        time: value stored at TIME_OFFSET.

    Returns:
        Integer numpy array of length SIZE_OF_INPUT_FIELDS; every field
        other than the three above is zero.
    """
    record = np.zeros(shape=SIZE_OF_INPUT_FIELDS, dtype=int)
    header = ((TIME_OFFSET, time), (CH_OFFSET, channel), (CMD_OFFSET, command))
    for offset, value in header:
        record[offset] = value
    return record

def parse_bool(v):
    """Convert a flag token to an integer.

    The literals "true" and "false" map to 1 and 0; anything else is passed
    through int(), so numeric tokens keep their numeric value.
    """
    if v == "true":
        return 1
    if v == "false":
        return 0
    return int(v)

def command_of_parts(command, channel, parts, time):
    """Build an input record from the whitespace-split tokens of one log line.

    The token layout matches what print_feature writes:
    parts[0] == "CH", parts[1] is the channel, parts[2] the command name and
    parts[3:] the command parameters (followed by "AT" and the time).

    Args:
        command: one of the CMD_* identifiers.
        channel: channel number stored in the record.
        parts: tokens of the log line.
        time: timestamp stored in the record.

    Returns:
        Integer numpy array of length SIZE_OF_INPUT_FIELDS.

    Raises:
        ValueError: if `command` is not a known CMD_* identifier.
    """
    inp = fresh_input(command, channel, time)

    if command == CMD_DUTYLL:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = int(parts[4])
    elif command == CMD_VOLENVPER:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = parse_bool(parts[4])
        # The third parameter is its own token; reading parts[4] again would
        # duplicate the second parameter and lose the third.
        inp[PARAM3_OFFSET] = int(parts[5])
    elif command == CMD_LSB:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = 0
        inp[PARAM3_OFFSET] = 0
    elif command == CMD_MSB:
        inp[PARAM1_OFFSET] = int(parts[3])
        inp[PARAM2_OFFSET] = parse_bool(parts[4])
        inp[PARAM3_OFFSET] = parse_bool(parts[5])
    else:
        # Raising a bare string is itself a TypeError in Python 3.
        raise ValueError(f"unknown command id: {command!r}")
    return inp

def int32_as_bytes(ival):
    """Encode a scalar as four big-endian bytes.

    Args:
        ival: object exposing .item() that returns a Python int
            (for example a numpy integer scalar).

    Returns:
        uint8 numpy array of length 4.
    """
    raw = ival.item().to_bytes(4, byteorder='big')
    return np.frombuffer(raw, dtype=np.uint8)

def int32_of_bytes(np):
    """Decode a big-endian byte sequence into a Python int.

    The parameter is called `np`, so inside this function the name refers to
    the byte sequence and not to the numpy module.
    """
    decoded = int.from_bytes(np, byteorder='big')
    return decoded

def int8_as_bytes(ival):
    """Encode a scalar as a single byte.

    Args:
        ival: object exposing .item() that returns a Python int
            (for example a numpy integer scalar).

    Returns:
        uint8 numpy array of length 1.
    """
    raw = ival.item().to_bytes(1, byteorder='big')
    return np.frombuffer(raw, dtype=np.uint8)

def int8_of_bytes(np):
    """Decode a big-endian byte sequence (one byte in this file) into an int.

    The parameter is called `np`, so inside this function the name refers to
    the byte sequence and not to the numpy module.
    """
    decoded = int.from_bytes(np, byteorder='big')
    return decoded

def merge_params(data):
    """Pack the parameter fields of an input record into a single integer.

    Bit layout per command (the inverse is unmerge_params):
        CMD_DUTYLL:    PARAM1 << 6 | PARAM2
        CMD_VOLENVPER: PARAM1 << 4 | PARAM2 << 3 | PARAM3
        CMD_LSB:       PARAM1
        CMD_MSB:       PARAM1 | PARAM2 << 6 | PARAM3 << 7

    Args:
        data: input record indexed by the *_OFFSET constants.

    Returns:
        The packed parameter value.

    Raises:
        ValueError: if data[CMD_OFFSET] is not a known CMD_* identifier.
    """
    command = data[CMD_OFFSET]
    if command == CMD_DUTYLL:
        return (data[PARAM1_OFFSET] << 6) | data[PARAM2_OFFSET]
    elif command == CMD_VOLENVPER:
        return (data[PARAM1_OFFSET] << 4) | (data[PARAM2_OFFSET] << 3) | data[PARAM3_OFFSET]
    elif command == CMD_LSB:
        return data[PARAM1_OFFSET]
    elif command == CMD_MSB:
        return data[PARAM1_OFFSET] | (data[PARAM2_OFFSET] << 6) | (data[PARAM3_OFFSET] << 7)
    else:
        # Raising a bare string is itself a TypeError in Python 3.
        raise ValueError(f"unknown command id: {command!r}")
        
def unmerge_params(command, data, v):
    """Unpack a merged parameter value into the PARAM fields of `data`.

    Inverse of merge_params. `data` is modified in place and nothing is
    returned. CMD_LSB only writes PARAM1; the other commands write the fields
    their bit layout defines.

    Raises:
        Exception: if `command` is not a known CMD_* identifier.
    """
    if command == CMD_DUTYLL:
        data[PARAM1_OFFSET] = v >> 6
        data[PARAM2_OFFSET] = v & 0b0011_1111
        return

    if command == CMD_VOLENVPER:
        data[PARAM1_OFFSET] = v >> 4
        data[PARAM2_OFFSET] = (v & 0b0000_1000) >> 3
        data[PARAM3_OFFSET] = v & 0b0000_0111
        return

    if command == CMD_LSB:
        data[PARAM1_OFFSET] = v
        return

    if command == CMD_MSB:
        data[PARAM1_OFFSET] = v & 0b0011_1111
        data[PARAM2_OFFSET] = (v & 0b0100_0000) >> 6
        data[PARAM3_OFFSET] = (v & 0b1000_0000) >> 7
        return

    raise Exception("this should not happen")
        
# Serialized size of one command as laid out by command_to_bytes:
# 4 time bytes + 1 channel byte + 1 command byte + 1 merged-parameter byte.
BYTES_PER_ENTRY = 4 + 1 + 1 + 1

def command_to_bytes(command):
    """Serialize one input record into its byte representation.

    The result is the concatenation of the 4-byte big-endian time, the
    channel byte, the command byte and the merged-parameter byte.

    Returns:
        1-D uint8 numpy array.
    """
    pieces = [
        int32_as_bytes(command[TIME_OFFSET]),
        int8_as_bytes(command[CH_OFFSET]),
        int8_as_bytes(command[CMD_OFFSET]),
        int8_as_bytes(merge_params(command)),
    ]
    return np.concatenate(pieces).flatten()

def command_of_bytes(byte_arr):
    """Deserialize one serialized entry back into an input record.

    Args:
        byte_arr: indexable byte sequence; bytes 0-3 hold the time, byte 4
            the channel, byte 5 the command and byte 6 the merged parameters.

    Returns:
        Integer numpy array of length SIZE_OF_INPUT_FIELDS.

    Raises:
        Exception: if the channel is neither 1 nor 2, or (from
            unmerge_params) if the command byte is unknown.
    """
    record = fresh_input(0, 0, 0)
    record[TIME_OFFSET] = int32_of_bytes(byte_arr[0:4])

    channel = int8_of_bytes(byte_arr[4:5])
    record[CH_OFFSET] = channel
    if channel not in (1, 2):
        raise Exception("bad channel prediction")

    record[CMD_OFFSET] = int8_of_bytes(byte_arr[5:6])
    unmerge_params(record[CMD_OFFSET], record, byte_arr[6])
    return record

def print_feature(data, file=sys.stdout):
    """Write one input record as a text line in the training-log format.

    Args:
        data: input record indexed by the *_OFFSET constants.
        file: text stream to write to; flushed after the line is written.

    An unknown command id is written as the line "Bad prediction".
    """
    command = data[CMD_OFFSET]
    if command == CMD_DUTYLL:
        text = f"CH {data[CH_OFFSET]} DUTYLL {data[PARAM1_OFFSET]} {data[PARAM2_OFFSET]} AT {data[TIME_OFFSET]}"
    elif command == CMD_VOLENVPER:
        text = f"CH {data[CH_OFFSET]} VOLENVPER {data[PARAM1_OFFSET]} {data[PARAM2_OFFSET]} {data[PARAM3_OFFSET]} AT {data[TIME_OFFSET]}"
    elif command == CMD_LSB:
        text = f"CH {data[CH_OFFSET]} FREQLSB {data[PARAM1_OFFSET]} AT {data[TIME_OFFSET]}"
    elif command == CMD_MSB:
        text = f"CH {data[CH_OFFSET]} FREQMSB {data[PARAM1_OFFSET]} {data[PARAM2_OFFSET]} {data[PARAM3_OFFSET]} AT {data[TIME_OFFSET]}"
    else:
        text = f"Bad prediction"
    print(text, file=file, flush=True)

def load_training_data(src):
    """Parse a command log file into a list of input records.

    Only lines whose first token is "CH" are parsed; all other lines are
    ignored. A line with an unrecognised command name is reported on stdout
    and skipped.

    Args:
        src: path of the text file to read.

    Returns:
        List of integer numpy arrays, one per parsed command, in file order.
    """
    command_ids = {
        "DUTYLL": CMD_DUTYLL,
        "VOLENVPER": CMD_VOLENVPER,
        "FREQLSB": CMD_LSB,
        "FREQMSB": CMD_MSB,
    }

    data = []
    # The context manager closes the file even when a line fails to parse.
    with open(src, 'r') as file:
        for line in file:
            parts = line.split()
            if len(parts) == 0 or parts[0] != "CH":
                continue

            channel = int(parts[1])
            name = parts[2]
            time = int(parts[-1])

            command = command_ids.get(name)
            if command is None:
                # Skip the line: there is no record to append for it.
                print("Unknown", name)
                continue

            data.append(command_of_parts(command, channel, parts, time))
    return data

@pescador.streamable
def samples_from_training_data(src, window_size, start_at_sample):
    """Endlessly yield random byte windows taken from one training file.

    Args:
        src: path of the training log to load.
        window_size: window length in entries (commands); it is scaled by
            BYTES_PER_ENTRY internally.
        start_at_sample: when true, the window start is rounded to a multiple
            of BYTES_PER_ENTRY so that it begins on an entry boundary.

    Yields:
        1-D numpy array of serialized bytes. When the file holds no more than
        one window of data, the whole file is yielded every time.

    The generator ends without yielding if the file cannot be loaded or
    contains no commands.
    """
    # Scale the window size by the bytes per entry
    window_size = window_size * BYTES_PER_ENTRY

    try:
        sample_data = load_training_data(src)
    except Exception as e:
        LOGGER.error('Could not load {}: {}'.format(src, str(e)))
        # A plain return ends the generator. Raising StopIteration inside a
        # generator body is converted to RuntimeError (PEP 479).
        return

    sample_data = np.array([command_to_bytes(x) for x in sample_data]).flatten()

    if len(sample_data) == 0:
        LOGGER.error('No commands found in {}'.format(src))
        return

    while True:
        if len(sample_data) <= window_size:
            sample = sample_data
        else:
            # randint's upper bound is exclusive, so +1 makes the last full
            # window reachable.
            start_idx = np.random.randint(0, len(sample_data) - window_size + 1)

            # Both lengths are multiples of BYTES_PER_ENTRY, so rounding to the
            # nearest entry boundary cannot move the window past the end.
            if start_at_sample:
                start_idx = BYTES_PER_ENTRY * round(start_idx / BYTES_PER_ENTRY)

            sample = sample_data[start_idx:(start_idx + window_size)]

        yield sample

def create_batch_generator(paths, window_size, start_at_sample):
    """Build an iterator that draws samples from the given training files.

    One streamer is created per path and all of them are combined with a
    pescador StochasticMux (n_active=1, rate=1).

    Returns:
        The iterator returned by the mux's iterate().
    """
    streamers = [
        samples_from_training_data(path, window_size, start_at_sample)
        for path in paths
    ]
    return pescador.StochasticMux(streamers, n_active=1, rate=1).iterate()

def training_files(dirp):
    """List every file below `dirp`, following symbolic links.

    Returns:
        List of paths (directory joined with file name) in os.walk order;
        empty when the directory does not exist or holds no files.
    """
    found = []
    for root, _dir_names, file_names in os.walk(dirp, followlinks=True):
        for fname in file_names:
            found.append(os.path.join(root, fname))
    return found

def create_data_split(paths, window_size=MAX_WINDOW_SIZE, start_at_sample=True):
    """Return the sample iterator for `paths`.

    Thin wrapper around create_batch_generator; no train/validation split is
    performed here.
    """
    return create_batch_generator(paths, window_size, start_at_sample)

class SampleDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that streams random byte windows from a directory.

    Args:
        path: directory that is searched recursively for training files.
        window_size: number of entries the caller wants per window; one extra
            entry is requested so that inputs and labels can be obtained by
            shifting the sequence by one position.
    """

    def __init__(self, path, window_size):
        # super(SampleDataset) without an instance is an unbound super object,
        # so the parent initialiser never ran.
        super().__init__()

        files = training_files(path)

        print("Training files: ", files)

        # Honour the requested window size instead of always using the
        # module-level default.
        self.loader = create_data_split(files, window_size=window_size + 1, start_at_sample=False)

    def __iter__(self):
        # A for loop (rather than `yield from`) is deliberate: closing this
        # generator must not close the shared underlying stream, because
        # callers may request a new iterator many times. The loop also ends
        # cleanly if the stream is ever exhausted.
        for sample in self.loader:
            yield sample

In [7]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout.

    Args:
        d_model: size of the last (feature) dimension of the input.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions for which encodings are precomputed.
        batch_first: when False (default) dimension 0 of the input is treated
            as the position axis; when True dimension 1 is.

    NOTE(review): CommandNet and AttentionNet pass the (batch, positions,
    channels) output of nn.Embedding with the default batch_first=False, so
    the added encoding depends on the batch index and not on the position.
    Opting in to batch_first=True there also needs max_len to cover the full
    sequence length, which changes the shape of the 'pe' buffer stored in
    existing checkpoints.
    """

    def __init__(self, d_model, dropout = 0.1, max_len = 2048, batch_first = False):
        super().__init__()

        assert(MAX_WINDOW_SIZE <= max_len)

        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one cosine column fewer than sine
        # columns; for an even d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(position * div_term[:d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding along the position axis and apply dropout."""
        if self.batch_first:
            # (positions, 1, d_model) -> (1, positions, d_model)
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [8]:
# Convolution kernel width, first in entries and then in bytes.
KERNEL_SIZE_SAMPLES = 16
KERNEL_SIZE = KERNEL_SIZE_SAMPLES * BYTES_PER_ENTRY
# Default number of dilated layers per block in CommandNet.
NUM_LAYERS = 4
# Number of seed bytes kept for generation: 16 kernel widths.
# NOTE(review): the exponent equals NUM_LAYERS today but is not derived from it.
RECEPTIVE_FIELD_BYTES = (2 ** 4) * KERNEL_SIZE

class CausalConv1d(nn.Module):
    """1-D convolution whose input is padded on the left only.

    The padding is (kernel_size - 1) * dilation, so with Conv1d's default
    stride the output has the same length as the input and position t only
    sees inputs at positions <= t.

    Extra keyword arguments are forwarded to nn.Conv1d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super().__init__()
        self.pad = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, **kwargs)

    def forward(self, x):
        # Pad the last dimension on the left side only.
        left_padded = F.pad(x, (self.pad, 0))
        return self.conv(left_padded)

class ResidualBlock(nn.Module):
    """Gated causal-convolution block with separate residual and skip outputs.

    forward() multiplies a sigmoid branch with a tanh branch and projects the
    product through two independent 1x1 convolutions.

    NOTE(review): the block input is not added to the "residual" output here.
    """

    def __init__(self, input_channels, output_channels, kernel_size, skip_channels, dilation=1):
        super().__init__()

        # Sigmoid branch.
        self.conv_sig = CausalConv1d(input_channels, output_channels, kernel_size, dilation)
        self.sig = nn.Sigmoid()

        # Tanh branch.
        self.conv_tan = CausalConv1d(input_channels, output_channels, kernel_size, dilation)
        self.tanh = nn.Tanh()

        # Independent 1x1 projections for the residual and the skip path.
        self.conv_r = nn.Conv1d(output_channels, output_channels, 1)
        self.conv_s = nn.Conv1d(output_channels, skip_channels, 1)

    def forward(self, x):
        """Return the pair (residual, skip) computed from x."""
        gate = self.sig(self.conv_sig(x))
        signal = self.tanh(self.conv_tan(x))
        gated = gate * signal

        skip = self.conv_s(gated)
        residual = self.conv_r(gated)
        return residual, skip
    
class AttentionResBlock(nn.Module):
    """Gated block that uses causal LocalAttention modules instead of convolutions.

    kernel_size and dilation are accepted but not used by this block.

    NOTE(review): forward() calls each attention module with one tensor, which
    AttentionNet supplies as (batch, channels, positions). Confirm that the
    installed local_attention version accepts this call form and layout.
    """

    def __init__(self, input_channels, output_channels, kernel_size, skip_channels, dilation=1):
        super().__init__()

        # Sigmoid branch.
        self.attn_sig = LocalAttention(window_size=512, dim=input_channels, causal=True)
        self.sig = nn.Sigmoid()

        # Tanh branch.
        self.attn_tan = LocalAttention(window_size=512, dim=input_channels, causal=True)
        self.tanh = nn.Tanh()

        # Independent 1x1 projections for the residual and the skip path.
        self.conv_r = nn.Conv1d(output_channels, output_channels, 1)
        self.conv_s = nn.Conv1d(output_channels, skip_channels, 1)

    def forward(self, x):
        """Return the pair (residual, skip) computed from x."""
        gate = self.sig(self.attn_sig(x))
        signal = self.tanh(self.attn_tan(x))
        gated = gate * signal

        skip = self.conv_s(gated)
        residual = self.conv_r(gated)
        return residual, skip
    
class AttentionNet(nn.Module):
    """Byte-level sequence model built from AttentionResBlock layers.

    The structure is: embedding -> positional encoding -> causal convolution
    -> num_blocks * num_layers AttentionResBlocks -> summed skip outputs ->
    ReLU / 1x1 conv / ReLU / 1x1 conv.

    skip_channels is used both as the embedding table size and as the number
    of output channels.
    """

    def __init__(self, skip_channels=256, num_blocks=4, num_layers=5, num_hidden=256, kernel_size=KERNEL_SIZE):
        super().__init__()

        self.embed = nn.Embedding(skip_channels, skip_channels)
        self.positional_embedding = PositionalEncoding(skip_channels)
        self.causal_conv = CausalConv1d(skip_channels, num_hidden, kernel_size)

        # Every block repeats the dilation schedule 1, 2, ..., 2**(num_layers - 1).
        self.res_stack = nn.ModuleList(
            AttentionResBlock(num_hidden, num_hidden, kernel_size, skip_channels=skip_channels, dilation=2**level)
            for _ in range(num_blocks)
            for level in range(num_layers)
        )

        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv1d(skip_channels, skip_channels, 1)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv1d(skip_channels, skip_channels, 1)

    def forward(self, x):
        """Map integer tokens to per-position outputs with skip_channels channels."""
        hidden = self.embed(x)
        hidden = self.positional_embedding(hidden)
        # Conv1d expects the channel axis before the position axis.
        hidden = self.causal_conv(hidden.permute(0, 2, 1))

        # Run the stack, keeping every skip output.
        skip_vals = []
        for layer in self.res_stack:
            hidden, skip = layer(hidden)
            skip_vals.append(skip)

        # Sum the skip outputs and feed them to the output head.
        merged = reduce(torch.add, skip_vals)
        merged = self.conv1(self.relu1(merged))
        return self.conv2(self.relu2(merged))

# Stack of dilated causal convolutions; every layer widens the lookback by
# (kernel_size - 1) * dilation positions.
class CommandNet(nn.Module):
    """Byte-level sequence model built from gated dilated causal convolutions.

    The structure is: embedding -> positional encoding -> causal convolution
    -> num_blocks * num_layers ResidualBlocks -> summed skip outputs ->
    ReLU / 1x1 conv / ReLU / 1x1 conv.

    skip_channels is used both as the embedding table size and as the number
    of output channels.
    """

    def __init__(self, skip_channels=256, num_blocks=2, num_layers=NUM_LAYERS, num_hidden=256, kernel_size=KERNEL_SIZE):
        super().__init__()

        self.embed = nn.Embedding(skip_channels, skip_channels)
        self.positional_embedding = PositionalEncoding(skip_channels)
        self.causal_conv = CausalConv1d(skip_channels, num_hidden, kernel_size)

        # Every block repeats the dilation schedule 1, 2, ..., 2**(num_layers - 1).
        self.res_stack = nn.ModuleList(
            ResidualBlock(num_hidden, num_hidden, kernel_size, skip_channels=skip_channels, dilation=2**level)
            for _ in range(num_blocks)
            for level in range(num_layers)
        )

        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv1d(skip_channels, skip_channels, 1)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv1d(skip_channels, skip_channels, 1)

    # TODO: Move the receptive field here

    def forward(self, x):
        """Map integer tokens to per-position outputs with skip_channels channels."""
        hidden = self.embed(x)
        hidden = self.positional_embedding(hidden)
        # Conv1d expects the channel axis before the position axis.
        hidden = self.causal_conv(hidden.permute(0, 2, 1))

        # Run the stack, keeping every skip output.
        skip_vals = []
        for layer in self.res_stack:
            hidden, skip = layer(hidden)
            skip_vals.append(skip)

        # Sum the skip outputs and feed them to the output head.
        merged = reduce(torch.add, skip_vals)
        merged = self.conv1(self.relu1(merged))
        return self.conv2(self.relu2(merged))

def load(path):
    """Create the model, optimizer and LR scheduler, optionally from a checkpoint.

    Args:
        path: checkpoint prefix; "<path>.model" and "<path>.optimizer" are
            read. Pass None to start from freshly initialised weights.

    Returns:
        Tuple (command_generator, optimizer, scheduler). The scheduler is
        always newly created; its state is not restored from disk.
    """
    lr = 0.01
    momentum = 0.8

    command_generator = CommandNet()

    optimizer = optim.SGD(
        command_generator.parameters(),
        lr=lr,
        momentum=momentum
    )

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.97, min_lr=0.0001)
    command_generator = command_generator.to(device)

    # Restore state only after the model has been moved: the optimizer places
    # its loaded state on the device of the parameters it belongs to.
    if path is not None:
        # map_location lets a checkpoint written on a GPU host load on a
        # CPU-only host (and the other way round).
        command_generator.load_state_dict(torch.load(path + ".model", map_location=device))
        optimizer.load_state_dict(torch.load(path + ".optimizer", map_location=device))
        #scheduler = torch.load(path + ".scheduler")

    return command_generator, optimizer, scheduler

In [None]:
# EPOCHS is the number of rounds train() runs; ROUND_SZ is the number of
# optimisation steps in each round.
EPOCHS = 500000
ROUND_SZ = 100

print("Collecting training data")
# No batch_size is passed, so the DataLoader default applies.
loader = torch.utils.data.DataLoader(SampleDataset("../../training_data/",window_size=MAX_WINDOW_SIZE))
print("Collected")

def train(path):
    """Train CommandNet on next-byte prediction, checkpointing after every round.

    Each round runs ROUND_SZ optimisation steps, feeds the mean loss to the
    LR scheduler and writes "last.checkpoint"; every tenth round a
    timestamp-named checkpoint is written as well.

    Args:
        path: checkpoint prefix handed to load(), or None to start fresh.

    Returns:
        The model in eval mode, once EPOCHS rounds have completed.
    """
    command_generator, optimizer, scheduler = load(path)

    criterion = nn.CrossEntropyLoss()
    running_loss = torch.zeros(1, device=device)

    # One persistent iterator: calling iter(loader) for every step would
    # rebuild the DataLoader iterator each time.
    batches = iter(loader)

    # Report progress roughly ten times per round (integer step, at least 1).
    report_every = max(1, ROUND_SZ // 10)

    def step():
        """Run one round and return the mean loss as a 1-element tensor."""
        print("Starting batch")
        running_loss.zero_()

        for i in range(ROUND_SZ):

            if i % report_every == 0:
                print("Batch completion:", (float(i) / float(ROUND_SZ)) * 100., "%")

            seq = next(batches).long().to(device)
            # Next-token setup: labels are the inputs shifted by one position.
            inputs = seq[:, :-1]
            labels = seq[:, 1:]

            optimizer.zero_grad()

            # NOTE(review): autocast is used without a GradScaler; consider
            # adding one if half-precision gradients underflow.
            with autocast():
                outputs = command_generator(inputs)
                loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()
            running_loss.add_(loss.detach())

        return running_loss / ROUND_SZ

    def save(name):
        """Write the model and optimizer state under the prefix ./<name>."""
        torch.save(command_generator.state_dict(), "./" + name + ".model")
        torch.save(optimizer.state_dict(), "./" + name + ".optimizer")

        # Saving the scheduler seems to break stuff
        #torch.save(scheduler, "./" + name + ".scheduler")

    for i in range(0, EPOCHS):
        loss_value = step().item()
        scheduler.step(loss_value)

        print("Loss:", loss_value)
        print("LR:", optimizer.param_groups[0]['lr'])

        print("Saving checkpoint")

        # Timestamp every 10th epoch to test fits later
        if i % 10 == 0:
            save(str(int(datetime.now().timestamp())))

        save("./last.checkpoint")
        print("Saved checkpoint")

    return command_generator.eval()

# Resume from the rolling checkpoint written by train(); use train(None)
# instead to start from freshly initialised weights.
#train(None)
train("./last.checkpoint")

Collecting training data
Training files:  ['../../training_data/harvest_title', '../../training_data/tetris_world_music_a', '../../training_data/tetris_2_game_menu_b', '../../training_data/barbie', '../../training_data/casper', '../../training_data/pk_oak', '../../training_data/zelda_title', '../../training_data/star_wars_return_of_the_jedi_title', '../../training_data/tetris_attack', '../../training_data/tetris_attack_ingame', '../../training_data/tetris_world', '../../training_data/batman_return_of_the_joker_title', '../../training_data/star_wars_return_of_the_jedi_ingame', '../../training_data/aliens_into', '../../training_data/kirbys_dreamland_title', '../../training_data/knights_quest_title', '../../training_data/toy_story_menu', '../../training_data/star_wars_title', '../../training_data/pk_title', '../../training_data/batman_forever_title', '../../training_data/bubble_ghost_ingame', '../../training_data/tetris_attack_menu', '../../training_data/lion_king_menu', '../../training_d

In [None]:
print("Collecting training data")
# Random windows from the out-of-sample files, using create_data_split's
# defaults (MAX_WINDOW_SIZE entries, starts aligned to entry boundaries).
train_gen = create_data_split(training_files("../../out_of_sample/"))
print("Collected")

# Restore the latest checkpoint; the optimizer and scheduler that load() also
# returns are not needed for generation.
command_generator, _, _ = load("./last.checkpoint")
command_generator = command_generator.eval()

# Cut the seed to the receptive window of our model so that it executes faster
seed = next(train_gen)[:RECEPTIVE_FIELD_BYTES]

def max_of(v, begin, end):
    """Return the index into `v` of the largest value within v[begin:end]."""
    offset_in_slice = np.argmax(v[begin:end])
    return begin + offset_in_slice

# Decode the seed one BYTES_PER_ENTRY-sized entry at a time and write each
# entry to seed.txt in the text format produced by print_feature.
with open('seed.txt', 'w') as f:
    for i in range(0, len(seed), BYTES_PER_ENTRY):
        # Progress trace on stdout: byte offset and seed shape.
        print("Seed value :", i, seed.shape)
        cmd = command_of_bytes(seed[i:i+BYTES_PER_ENTRY])
        print_feature(cmd, file=f)
        
class MovingWindow():
    """Fixed-length sliding window over a pre-allocated 1-D token buffer.

    The buffer holds the seed followed by `growth` times as many free slots.
    append() writes into the next free slot and advances the window by one;
    when the buffer is full, the live window is first copied back to the
    front so that no item is lost.

    Args:
        seed: initial window contents (sequence or array of integers).
        growth: number of free slots to pre-allocate, as a multiple of the
            seed length.
        target_device: device for the buffer; None uses the notebook-wide
            `device`.

    Raises:
        ValueError: if the seed is empty or growth is smaller than 1.
    """

    def __init__(self, seed, growth=16, target_device=None):
        seed_tokens = torch.as_tensor(seed).long()
        if len(seed_tokens) == 0:
            raise ValueError("seed must contain at least one element")
        if growth < 1:
            raise ValueError("growth must be at least 1")

        free_slots = torch.zeros(len(seed_tokens) * growth, dtype=torch.long)
        buffer_device = device if target_device is None else target_device
        self.seq = torch.cat((seed_tokens, free_slots)).to(buffer_device)
        self.start = 0
        self.len = len(seed_tokens)

    def end(self):
        """Index one past the last element of the current window."""
        return self.start + self.len

    def append(self, item):
        """Slide the window forward by one position, ending with `item`."""
        if self.end() == len(self.seq):
            # Out of free slots: move the live window to the front. clone()
            # guards against overlapping source and destination ranges.
            # (torch.roll returns a new tensor, so calling it without using
            # the result would leave the buffer unchanged.)
            self.seq[:self.len] = self.seq[self.start:self.end()].clone()
            self.start = 0

        self.seq[self.end()] = item
        self.start += 1

    def window(self):
        """Return a view of the current window."""
        return self.seq[self.start:self.end()]
    
window = MovingWindow(seed)

# Sample one byte per iteration and, every BYTES_PER_ENTRY bytes, decode the
# latest entry and write it to output.txt. no_grad() keeps inference from
# building an autograd graph on every step.
with open('output.txt', 'w') as f, torch.no_grad():

    for i in range(BYTES_PER_ENTRY * 10000):
        seq = window.window().unsqueeze(0)
        # Keep only the outputs for the last position of the window.
        pred = command_generator(seq).to(cpu).permute(0,2,1).squeeze(0)[-1]
        pred = Categorical(logits=pred).sample()
        window.append(pred)

        if (i + 1) % BYTES_PER_ENTRY == 0:
            try:
                last_sample = window.window()[-BYTES_PER_ENTRY:].detach().cpu().numpy().astype(np.uint8)
                last_sample = command_of_bytes(last_sample)
                print_feature(last_sample, file=f)
            except Exception as err:
                # Exception (not BaseException) so that interrupting the
                # kernel still stops the loop.
                print("pred was not valid because:", err)

    del pred