In [None]:
from __future__ import print_function

import os
import sys
sys.path.append(os.path.join(os.environ['ITHEMAL_HOME'], 'learning', 'pytorch'))

In [None]:
import common_libs.utilities as ut
import data.data_cost as dt
import functools
from pprint import pprint
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, NamedTuple

In [None]:
SLOT_WIDTH = 16  # number of elements in each slot's state vector
NUM_SLOTS = 8    # number of slots the simulator can write into


class Slot(object):
    """A single simulated execution slot: a state vector plus a countdown timer.

    The state only "counts" while the slot is busy (remaining_time > 0): it is
    zeroed out in ``read()`` and discarded by ``mutate()`` once the timer has
    run out.
    """

    def __init__(self):
        # type: () -> None
        self.state = torch.randn(SLOT_WIDTH, requires_grad=True)
        self.remaining_time = torch.zeros(1, requires_grad=True)

    def _busy_mask(self):
        # 1.0 while the timer is positive, 0.0 otherwise (the comparison itself
        # is not differentiable).
        return (self.remaining_time > 0).float()

    def mutate(self, new_state, additional_time):
        """Add ``new_state`` into the slot and extend its timer by ``additional_time``.

        The previous state survives only if the slot is still busy. Returns the
        timer value from *before* this call (element 0 of the old timer tensor).
        """
        previous_time = self.remaining_time
        self.state = self.state * self._busy_mask() + new_state
        self.remaining_time = previous_time + additional_time
        return previous_time[0]

    def step(self, time):
        """Advance the clock by ``time``; the timer is clamped at zero."""
        self.remaining_time = (self.remaining_time - time).clamp(min=0)

    def read(self):
        """Return ``[remaining_time, masked state]`` concatenated into one vector.

        With a 1-element timer this has 1 + SLOT_WIDTH entries.
        """
        visible_state = self.state * self._busy_mask()
        return torch.cat([self.remaining_time, visible_state])

In [None]:
# Inputs to one simulator step (not constructed in the cells shown here).
# Field types name the tensor *class* torch.Tensor; torch.tensor is a factory
# function, not a type.
SimulatorInput = NamedTuple('SimulatorInput', [
    ('slot_vector', torch.Tensor),
    ('instruction_vector', torch.Tensor),
])

# Per-instruction outputs of NeuralProcessorSimulator.forward.
SimulatorResult = NamedTuple('SimulatorResult', [
    ('wait_time', torch.Tensor),    # time to advance every slot before the next instruction
    ('write_head', torch.Tensor),   # per-slot weights for this instruction's write
    ('write_state', torch.Tensor),  # state vector written (weighted) into each slot
    ('write_time', torch.Tensor),   # time added (weighted) to each slot's timer
])
# Outcome of simulating one basic block (see run_on_data).
# Field types name the tensor *class* torch.Tensor, not the torch.tensor factory.
ModelRunResult = NamedTuple('ModelRunResult', [
    ('prediction', torch.Tensor),  # total simulated time for the block
    ('loss', torch.Tensor),        # training loss for this block
    ('slots', List[Slot]),         # slot objects in their final state
])

In [None]:
def update_slots_from_result(slots, result):
    """Apply one SimulatorResult's soft write to every slot; return the write penalty.

    Each slot ``k`` receives ``write_head[k] * write_state`` into its state and
    ``write_head[k] * write_time`` onto its timer. The returned penalty is
    ``sum_k write_head[k] * (1 + previous_remaining_time_k + ||write_state||)``,
    so writing heavily into a slot that still had time left costs more.
    """
    # NOTE(review): if write_head sums to 1 (as a softmax output does), the
    # constant `1` term adds exactly 1 per call and carries no gradient.
    penalty = torch.tensor(0., requires_grad=True)
    for slot_idx in range(NUM_SLOTS):
        weight = result.write_head[slot_idx]
        previous_time = slots[slot_idx].mutate(weight * result.write_state,
                                               weight * result.write_time)
        penalty = penalty + weight * (1 + previous_time + result.write_state.norm())
    return penalty

In [None]:
class CatEmbedder(nn.Module):
    """Embed an instruction as ``[opcode | srcs (zero-padded) | dsts (zero-padded)]``.

    Every token id is looked up in one shared ``nn.Embedding``. Missing operand
    positions are filled with zero vectors, so the output always has
    ``emb_dim * (1 + max_n_srcs + max_n_dsts)`` elements.

    Because this is an ``nn.Module``, assigning it to an attribute of another
    module registers its embedding table. The table then appears in
    ``model.parameters()``/``state_dict()`` and is updated by the optimizer.
    A table captured in a plain closure would stay at its random
    initialisation.
    """

    def __init__(self, n_syms, emb_dim, max_n_srcs, max_n_dsts):
        super(CatEmbedder, self).__init__()
        self.n_syms = n_syms
        self.emb_dim = emb_dim
        self.max_n_srcs = max_n_srcs
        self.max_n_dsts = max_n_dsts
        self.embedding = nn.Embedding(n_syms, emb_dim)

    def _embed_operands(self, ids, length):
        """Embed up to ``length`` operand ids, padding the remainder with zeros."""
        assert len(ids) <= length, 'got {} operands, at most {} supported'.format(len(ids), length)
        # Ids beyond the vocabulary are mapped onto the last table entry.
        real = [self.embedding(torch.tensor(min(tok, self.n_syms - 1))) for tok in ids]
        padding = [torch.zeros(self.emb_dim) for _ in range(length - len(ids))]
        return real + padding

    def forward(self, instr):
        """Return the 1-D concatenated embedding of ``instr``.

        ``instr`` must expose ``opcode`` (an int id) and ``srcs``/``dsts``
        (sequences of int ids).
        """
        opc = self.embedding(torch.tensor(instr.opcode))
        srcs = self._embed_operands(instr.srcs, self.max_n_srcs)
        dsts = self._embed_operands(instr.dsts, self.max_n_dsts)
        return torch.cat([opc] + srcs + dsts)


def cat_embedder(emb_dim, max_n_srcs, max_n_dsts):
    """Build a ``CatEmbedder`` sized to the project's symbol dictionary.

    The result is called as ``embed(instr)``. Because it is an ``nn.Module``,
    its embedding weights are trained when it is stored on a model.
    """
    sym_dict, _ = ut.get_sym_dict()
    return CatEmbedder(len(sym_dict), emb_dim, max_n_srcs, max_n_dsts)

In [None]:
class NeuralProcessorSimulator(nn.Module):
    """Per-instruction controller that decides how to advance time and write slots.

    Args:
        emb_dim: width of each token embedding produced by ``cat_embedder``.
        max_n_srcs, max_n_dsts: operand padding passed to ``cat_embedder``.
        hidden_dim: width of each of the two projections (instruction, slots)
            that are concatenated before the output heads.

    The defaults reproduce the original fixed architecture: a 128-dim embedder
    with 3 sources and 3 destinations, and 128-dim projections.
    """

    def __init__(self, emb_dim=128, max_n_srcs=3, max_n_dsts=3, hidden_dim=128):
        super(NeuralProcessorSimulator, self).__init__()
        self.embedder = cat_embedder(emb_dim, max_n_srcs, max_n_dsts)
        # The embedder concatenates one opcode plus max_n_srcs + max_n_dsts operands.
        instr_width = emb_dim * (1 + max_n_srcs + max_n_dsts)
        self.instr_vec_emb = nn.Linear(instr_width, hidden_dim)
        # Each Slot.read() yields a timer value followed by SLOT_WIDTH state entries.
        self.slot_vec_emb = nn.Linear((1 + SLOT_WIDTH) * NUM_SLOTS, hidden_dim)
        concat_width = 2 * hidden_dim
        self.wait_time_out = nn.Linear(concat_width, 1)
        self.write_head_out = nn.Linear(concat_width, NUM_SLOTS)
        self.write_state_out = nn.Linear(concat_width, SLOT_WIDTH)
        self.write_time_out = nn.Linear(concat_width, 1)

    def forward(self, instr_vec, slot_vec):
        """Map an instruction embedding and the slot read-out to a SimulatorResult.

        Written for unbatched 1-D inputs: the two projections are joined with
        ``torch.cat`` along dim 0, and the write head is a softmax over dim 0.

        Returns:
            SimulatorResult with
              wait_time   shape (1,), non-negative
              write_head  shape (NUM_SLOTS,), softmax weights
              write_state shape (SLOT_WIDTH,)
              write_time  shape (1,), non-negative
        """
        instr_hidden = F.relu(self.instr_vec_emb(instr_vec))
        slot_hidden = F.relu(self.slot_vec_emb(slot_vec))
        joint = torch.cat([instr_hidden, slot_hidden])

        # abs() keeps the predicted durations non-negative.
        return SimulatorResult(
            wait_time=self.wait_time_out(joint).abs(),
            write_head=F.softmax(self.write_head_out(joint), dim=0),
            write_state=self.write_state_out(joint),
            write_time=self.write_time_out(joint).abs(),
        )
    
def run_on_data(model, datum, debug=False):
    """Simulate one basic block instruction by instruction and score the prediction.

    Args:
        model: a NeuralProcessorSimulator. Uses ``model.embedder(instr)`` and
            ``model(instr_vec, slot_vec)``.
        datum: an item with ``.block.instrs`` and a numeric target ``.y``.
        debug: if True, print every instruction and the model's raw outputs.

    Returns:
        ModelRunResult(prediction, loss, slots). ``prediction`` is the summed
        wait times plus the largest time still remaining in any slot. ``loss`` is
        the accumulated slot-write penalty plus the squared error against
        ``datum.y``.
    """
    block = datum.block
    slots = [Slot() for _ in range(NUM_SLOTS)]
    overfill_loss = torch.tensor(0., requires_grad=True)
    wait_time = torch.tensor(0., requires_grad=True)
    last_idx = len(block.instrs) - 1

    for i, instr in enumerate(block.instrs):
        slot_vec = torch.cat([slot.read() for slot in slots])
        instr_vec = model.embedder(instr)
        result = model(instr_vec, slot_vec)
        if debug:
            print('Instr {}'.format(instr))
            # NamedTuple instances have no __dict__ on Python 3, so vars()
            # raises TypeError there; _asdict() works on Python 2 and 3.
            pprint({name: value.data for name, value in result._asdict().items()})
            print()
        overfill_loss = overfill_loss + update_slots_from_result(slots, result)

        # The last instruction's wait time is not added; work still in flight
        # is accounted for by the slots' remaining time below.
        if i == last_idx:
            break
        wait_time = wait_time + result.wait_time[0]
        for slot in slots:
            slot.step(result.wait_time)

    remaining_time = torch.max(torch.cat([slot.remaining_time for slot in slots]))
    total_time = wait_time + remaining_time
    # The target is a constant. It must not require grad (that would also
    # raise for an integer y, since only floating tensors can require grad).
    target = torch.tensor(float(datum.y))
    wrongness_loss = F.mse_loss(total_time, target)
    loss = overfill_loss + wrongness_loss
    return ModelRunResult(total_time, loss, slots)

In [None]:
# Seed every RNG this notebook draws from (dataset loading, parameter init,
# Slot states, and step_sgd's choice of datum) so that runs are reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# step_sgd reads the global `data`. Load it here, before the first
# step_sgd(True) call below, so that Restart & Run All works top to bottom.
data = dt.load_dataset('../inputs/embeddings/code_delim.emb', '../inputs/data/time_skylake_test.data')

neural_processor_simulator = NeuralProcessorSimulator()
optimizer = torch.optim.Adam(neural_processor_simulator.parameters(), lr=1e-4)

In [None]:
def step_sgd(debug=False):
    """Take one optimizer step on a single randomly chosen datum and print progress.

    Reads the notebook globals ``data``, ``neural_processor_simulator`` and
    ``optimizer``. The printed "sAPE" value is the ratio prediction / actual,
    so 1.0 is a perfect prediction. It is not an absolute error.
    """
    optimizer.zero_grad()
    datum = data.data[random.randrange(len(data.data))]
    run = run_on_data(neural_processor_simulator, datum, debug=debug)
    run.loss.backward()

    ratio = run.prediction / datum.y
    # Blank out the previous progress line, then redraw it in place via '\r'.
    print(' ' * 80, end='\r')
    message = 'sAPE: {:4.2f}, pred: {:6.2f}, actual: {:6.2f}, loss: {:8.2f}'.format(
        ratio, run.prediction, datum.y, run.loss)
    print(message, end='\r')

    optimizer.step()

In [None]:
# One training step, printing the model's outputs for every instruction.
step_sgd(debug=True)

In [None]:
# Train until interrupted (Kernel > Interrupt). Interrupting ends this cell
# cleanly instead of with a traceback. Set MAX_TRAIN_STEPS to an int to stop
# on its own, e.g. so that Restart & Run All can reach the following cells.
MAX_TRAIN_STEPS = None
n_train_steps = 0
try:
    while MAX_TRAIN_STEPS is None or n_train_steps < MAX_TRAIN_STEPS:
        step_sgd()
        n_train_steps += 1
except KeyboardInterrupt:
    pass  # deliberate: interrupting is the normal way to stop training
# step_sgd redraws its progress line with '\r'; finish that line first.
print('\nstopped after {} steps'.format(n_train_steps))

In [None]:
# Load the dataset that step_sgd samples from (it reads the global `data`).
# NOTE(review): arguments are presumably (embedding file, timing data file);
# both paths are relative to the notebook's working directory.
data = dt.load_dataset(
    '../inputs/embeddings/code_delim.emb',
    '../inputs/data/time_skylake_test.data',
)