In [939]:
import torch.nn as nn
import torch
from torch.autograd import Variable

In [940]:
class Controller(nn.Module):
    """
    Contains the two learnable parts of the model in four independent, fully connected layers.
    First the initial values for the registers and instruction registers and second the 
    parameters that computes the required distributions. 
    """
    def __init(self, program_matrix1, program_matrix2, program_matrix3, program_matrix4, initial_memory, 
               initial_registers, instruction_register, stop_threshold):
        """
        Initializes registers, memory and register matrix and their dimensions 
        Initializes four program matrices one of dimension NxM and three of dimension RxM
        """
        super(Controller, self).__init__()
        R, M = initial_registers.size()
        self.M = M
        self.R = R
        
        self.program_matrix1 = nn.Parameter(program_matrix1)
        self.program_matrix2 = nn.Parameter(program_matrix2)
        self.program_matrix3 = nn.Parameter(program_matrix3)
        self.program_matrix4 = nn.Parameter(program_matrix4)    
        
        # Memory matrix (M x M)
        self.memory = initial_memory
        
        # Register Matrix (R x M)
        self.registers = initial_registers
        
        # Instruction Register (M)
        self.IR = instruction_register
        
        # Machine initialization
        self.Machine = Machine(M, R, stop_threshold)
                
    def forward(self):
        einput = torch.bmm(self.program_matrix1, self.IR)
        ainstruction = torch.bmm(self.program_matrix2, self.IR)
        binstruction = torch.bmm(self.program_matrix3, self.IR)
        outputin = torch.bmm(self.program_matrix4, self.IR)
        
        # Updating memory, registers, and IR after machine operation
        self.memory, self.registers, self.IR, self.stop = self.Machine
        (einput, ainstruction, binstruction, outputin, memory, registers, IR)
                
        
    def lossfunctions(self, cmatrix, tjmatrix):
        """ compute four different loss functions and return a weighted average of the four measuring correctness, 
        halting, efficiency, and confidence"""
        sum ((cmatrix)*(self.program_matrix4 - self.memory)^2)
        ### need to add loss functions::: mehdi需要帮我们
        救命
        
        
# HOW TO BLUR? ==> Add a constant, then softmax
#TODO: Add in some fuzz factor????
      
        
        

In [941]:
class Operation(nn.Module):
    """
    Parent class for our binary operations
    """
    def __init__(self, M):
        """
        Initialize the memory length (needed so we can mod our answer in case it exceeds the range 0-M-1)
        
        :param M: Memory length
        """
        super(Operation, self).__init__()
        self.M = M #TODO: Check this gets updated!
        self.outputs = torch.zeros(M, M, M)
        for i in range(M):
            for j in range(M):
                val = self.compute(i, j)
                self.outputs[val][i][j] = 1
    
    def compute(self, x, y):
        """ 
        Perform the binary operation.  The arguments may or may not be used.
        
        :param x: First argument
        :param y: Second argument
        """
        raise NotImplementedError
    
    def forward(self):
        return self.outputs

In [942]:
class Add(Operation):

    def __init__(self, M):
        super(Add, self).__init__(M)
    
    def compute(self, x, y):
        return (x + y) % self.M


In [943]:
class Stop(Operation):
    
    def __init__(self, M):
        super(Stop, self).__init__(M)

    def compute(self, _1, _2):
        return 0

In [944]:
class Jump(Operation):
    
    def __init__(self, M):
        super(Jump, self).__init__(M)

    def compute(self, _1, _2):
        return 0 # Actual jump happens in the Machine class

In [945]:
class Decrement(Operation):
    
    def __init__(self, M):
        super(Decrement, self).__init__(M)

    def compute(self, x, _):
        return (x - 1) % self.M

In [946]:
class Increment(Operation):
    
    def __init__(self, M):
        super(Increment, self).__init__(M)

    def compute(self, x, _):
        return (x + 1) % self.M

In [947]:
class Max(Operation):
    
    def __init__(self, M):
        super(Max, self).__init__(M)

    def compute(self, x, y):
        return max(x,y)

In [948]:
class Min(Operation):
    
    def __init__(self, M):
        super(Min, self).__init__(M)

    def compute(self, x, y):
        return min(x,y)

In [949]:
class Read(Operation):
    
    def __init__(self, M):
        super(Read, self).__init__(M)
        self.outputs = torch.zeros(M, M, M) # Clear output matrix out since we're gonna do the reading elsewhere

    def compute(self, x, _):
        return 0 # Actual reading happens in the Machine class

In [950]:
class Subtract(Operation):
    
    def __init__(self, M):
        super(Subtract, self).__init__(M)

    def compute(self, x, y):
        return (x - y) % self.M

In [951]:
class Write(Operation):
    
    def __init__(self, M):
        super(Write, self).__init__(M)

    def compute(self, x, y):
        return 0 # Actual write happens in the Machine class

In [952]:
class Zero(Operation):
    
    def __init__(self, M):
        super(Zero, self).__init__(M)

    def compute(self, _1, _2):
        return 0

In [953]:
class Machine(nn.Module):
    """
    The Machine executes assembly instructions passed to it by the Controller.
    It updates the given memory, registers, and instruction pointer.
    The Machine doesn't have any learnable parameters.
    """
    def __init__(self, B, M, R, stop_threshold):
        """
        Initializes dimensions, operations, and counters
        
        :param M: Memory length.  Integer values also take on values 0-M-1.  M is also the program length.
        :param R: Number of registers
        :stop_threshold: Accumulated probability of stopping after which the program terminates.
        """
        super(Machine, self).__init__()
        
        # Store parameters as class variables
        self.R = R # Number of registers
        self.M = M # Memory length (also largest number)
        self.B = B # Batch size
        self.stop_threshold = stop_threshold
        
        # Start off with 0 probability of stopping
        self.stop_probability = torch.zeros(B)
        
        # A list of all our possible ops
        self.ops = [
            Jump(M), 
            Stop(M), 
            Write(M), 
            Read(M), 
            Add(M), 
            Subtract(M), 
            Increment(M),
            Decrement(M),
            Min(M),
            Max(M),
            Zero(M)
        ]
        
        # Number of instructions
        self.N = len(self.ops)
        
        self.outputs = Variable(torch.zeros(self.N, M, M, M))
        
        for i in range(self.N):
            op = self.ops[i]
            self.outputs[i] = op()
            
            
        self.outputs = torch.unsqueeze(self.outputs, 0)
        self.outputs = self.outputs.expand(B, -1, -1, -1, -1)
        
        # Keep track of the index of certain ops which are dealt with specially
        self.jump_index = 0
        self.stop_index = 1
        self.write_index = 2
        self.read_index = 3
        

        
    def forward(self, e, a, b, o, memory, registers, IR):
        
        """
        Run the Machine for one timestep (corresponding to the execution of one line of Assembly).
        The first four parameter names correspond to the vector names used in the original ANC paper
        
        :param e: Probability distribution over the instruction being executed (length M)
        :param a: Probability distribution over the first argument register (length R)
        :param b: Probability distribution over the second argument register (length R)
        :param o: Probability distribution over the first argument register (length R)
        :param memory: Memory matrix (size M x M)
        :param registers: Register matrix (size R x M)
        :param IR: Instruction Register (length M)
        
        :return: The memory, registers, and instruction register after the timestep
        """
        
        # Give all vectors an extra dimension
        
        # Dimensions B x R -> B x 1 x R
        a = torch.unsqueeze(a, 1)
        b = torch.unsqueeze(b, 1)
        
        
        # Calculate distributions over the two argument values by multiplying each 
        # register by the probability that register is being used.
        arg1 = torch.bmm(a, registers)
        arg2 = torch.bmm(b, registers)
        
        
        # Arg1 dimensions: B x 1 x M --> B x 1 x 1 x 1 x 3
        arg1_long = torch.unsqueeze(arg1, 1)
        arg1_long = torch.unsqueeze(arg1_long, 1)
        
        
        # A bunch of matrix-y stuff 
        #arg1: BxMx1; Outputs = NxMxMxM

        x = torch.matmul(arg1_long, self.outputs)
        
        # x dimensions: B x N x M x 1 x M -> B x N x M x M
        x = torch.squeeze(x, 3)
        
        # Arg2 dimensions: B x 1 x M --> B x 1 x M x 1
        arg2_long = torch.unsqueeze(arg2, 3)
        
        
        y = torch.matmul(x, arg2_long)
        
        # y dimensions: B x N x M x 1 -> B x N x M
        y = torch.squeeze(y, 3)
        
        # Dimensions B x N -> B x 1 x N
        e = torch.unsqueeze(e, 1)
        read_vec =  e[:, :, self.read_index]
        # Dimensions B x 1 -> B x 1 x 1
        read_vec = read_vec.unsqueeze(1)
        
        out_vec = torch.matmul(e, y) # Length M vector over the output of the operation
        # Deal with memory reads separately
        out_vec = out_vec + read_vec * torch.matmul(arg1, memory)        
        torch.Size([4, 1, 3])
        
        # Update our memory, registers, instruction register, and stopping probability
        memory = self.writeMemory(e, memory, arg1, arg2)
        registers = self.writeRegisters(out_vec, o, registers)
        IR = self.updateIR(e, IR, arg1, arg2)
        should_stop = self.updateStop(e)
        
        return(memory, registers, IR, should_stop)
        
        
    def writeRegisters(self, out, o, registers):
        """
        Write the result of our operation to our registers.
        
        :param out: probability distribution over the output value
        :param o: probability distribution over the output register
        :param registers: register matrix
        
        :return: the updated registers
        """
        # Multiply probability of writing to each output register by the value 
        
        # o dimensions: B x R -> B x R x 1
        o = torch.unsqueeze(o, 2)
        
        new_register_vals = torch.matmul(o, out)
        # Multiply each original register cell by the probabilty of not writing to that register
        old_register_vals = (1-o).expand(self.B, self.R, self.M) * registers
        
        registers =  new_register_vals + old_register_vals
        return registers
    
    def updateIR(self, e, IR, arg1, arg2):
        """
        Update the instruction register
        
        :param e: distribution over the current instruction (length N)
        :param IR: instruction register (length M)
        :param arg1: distribution over the first argument value (length M)
        :param arg2: distribution over the second argument value (length M)
        
        :return: the updated instruction register
        """
        # IR - length M vector
        jump_probability = e[:, :, self.jump_index]
        
        is_zero = arg1[:, :, 0]
        # Slicing lost a dimension.  Let's add it back
        jump_probability = torch.unsqueeze(jump_probability, 1)
        is_zero = torch.unsqueeze(is_zero, 1)
        
        
        # If we're not jumping, just shift IR by one slot
        wraparound = IR[:, -1]
        normal_instructions = IR[:, :-1]
        
        # for whatever reason, when you chop off one row/column, that dimension disappears
        # add it back
        wraparound = wraparound.unsqueeze(1)
        IR_no_jump = torch.cat([wraparound, normal_instructions], 1)
        # If we are on a jump instruction, check whether the argument's 0.
        # If it is, jump to the location specified by arg2.  Otherwise, increment like normal.
        IR_jump = arg2 * is_zero + (1 - is_zero) * IR_no_jump
        
        IR = IR_no_jump * (1 - jump_probability) + IR_jump * jump_probability
        return IR
    
    def writeMemory(self, e, mem_orig, arg1, arg2):
        """
        Update the memory
        
        :param e: distribution over the current instruction (B x 1 x M)
        :param mem_orig: current memory matrix (B x M x M)
        :param arg1: distribution over the first argument value (B x 1 x M)
        :param arg2: distribution over the second argument value (B x 1 x M)
        
        :return: the updated memory matrix
        """
        write_probability = e[:,:, self.write_index]
        # Write_prob dimensions: B x 1 x 1
        write_probability = torch.unsqueeze(write_probability, 1)
        
        # Arg1 dimensions: B x 1 x M -> B x M x 1
        arg1 = torch.transpose(arg1, 1, 2)
        
        # If we are on a write instruction, write the value arg2 in register arg1. Otherwise, leave memory as is.
        mem_write = torch.bmm(arg1, arg2) 
        temp3 = 1 - write_probability
        memory = mem_orig * (1 - write_probability) + mem_write * write_probability
        return memory
        
    def updateStop(self, e):
        """
        Update the probability that we will stop at this timestep based on the probability that we are running the STOP op.
        
        :param e: distribution over the current instruction (length M)
        
        :return: boolean representing whether the controller should stop.
        """
        self.stop_probability += e[:, :, self.stop_index].data[0]
        return (self.stop_probability > self.stop_threshold)


In [954]:
# Mini-test
# M=3, R=2, N=11, B=4
machine = Machine(4, 3, 2, .5)
e = Variable(torch.FloatTensor([[.1,.1,.1,.1,.1,.1,.1,.1,.1,.05,.05],
                                [0,0,.6,.05,.05,.05,.05,.05,.05,.05,.05],
                                [0,0,.6,.05,.05,.05,.05,.05,.05,.05,.05],
                                [0,0,.05,.05,.05,.05,.6,.05,.05,.05,.05]])) 
a = Variable(torch.FloatTensor([[.1, .9],
                                [.3, .7],
                                [.3, .7],
                                [.7, .3]]))
b = Variable(torch.FloatTensor([[.8, .2],
                                [.5, .5],
                                [.5, .5],
                                [.9, .1]]))
o = Variable(torch.FloatTensor([[.6, .4],
                                [.6, .4],
                                [.6, .4],
                                [.6, .4]]))
memory = Variable(torch.FloatTensor([[[1,0,0], [1,0,0], [0,0,1]],
                                     [[0,1,0], [1,0,0], [0,0,1]],
                                     [[0,1,0], [1,0,0], [0,0,1]],
                                     [[.5,.5,0], [1,0,0], [0,0,1]]]))
registers = Variable(torch.FloatTensor([[[.4, .5, .1], [.2, .6, .2]],
                                        [[.4, .1, .5], [.2, .6, .2]],
                                        [[.4, .1, .5], [.2, .6, .2]],
                                        [[.4, .5, .1], [.6, .2, .2]]]))
IR = Variable(torch.FloatTensor([[.1, .9, .1],
                                 [.1, .9, .1],
                                 [.1, .9, .1],
                                 [.1, .4, .5]]))
mem, regs, ir, stop = machine(e, a, b, o, memory, registers, IR)
print("MEM", mem)
print("REGS", regs)
print("IR", ir)
print("STOP", stop)

# # M=4, R=2, N=11
# machine = Machine(3, 2, .5)
# e = Variable(torch.FloatTensor([.1,.1,.1,.1,.1,.1,.1,.1,.1,.05,.05])) 
# a = Variable(torch.FloatTensor([.1, .9]))
# b = Variable(torch.FloatTensor([.8, .2]))
# o = Variable(torch.FloatTensor([.6, .4]))
# memory = Variable(torch.FloatTensor([[1,0,0], [1,0,0], [0,0,1]]))
# registers = Variable(torch.FloatTensor([[.4, .5, .1], [.2, .6, .2]]))
# IR = Variable(torch.FloatTensor([.1, .9, .1]))
# mem, regs, ir = machine(e, a, b, o, memory, registers, IR)


torch.Size([4, 1, 3])

AAAAAA

 0.1000
 0.1000
 0.1000
 0.1000
[torch.FloatTensor of size 4]

0.5

 0
 0
 0
 0
[torch.ByteTensor of size 4]

MEM Variable containing:
(0 ,.,.) = 
  0.9079  0.0114  0.0026
  0.9212  0.0307  0.0071
  0.0068  0.0099  0.9023

(1 ,.,.) = 
  0.0468  0.4546  0.0546
  0.4810  0.0945  0.0945
  0.0522  0.0609  0.4609

(2 ,.,.) = 
  0.0468  0.4546  0.0546
  0.4810  0.0945  0.0945
  0.0522  0.0609  0.4609

(3 ,.,.) = 
  0.4847  0.4858  0.0025
  0.9586  0.0096  0.0023
  0.0027  0.0031  0.9507
[torch.FloatTensor of size 4x3x3]

REGS Variable containing:
(0 ,.,.) = 
  0.5373  0.3134  0.1493
  0.3715  0.4356  0.1929

(1 ,.,.) = 
  0.6227  0.1080  0.2693
  0.4285  0.4053  0.1662

(2 ,.,.) = 
  0.6227  0.1080  0.2693
  0.4285  0.4053  0.1662

(3 ,.,.) = 
  0.3457  0.4234  0.2309
  0.4838  0.2689  0.2473
[torch.FloatTensor of size 4x2x3]

IR Variable containing:
(0 ,.,.) = 
  0.1057  0.1092  0.8828
  0.1057  0.1092  0.8828
  0.1057  0.1092  0.8828
  0.4969  0.1092  0.3938


In [955]:
# TASKS 
# - Compilation
# - Train function
# - Blurring
# - Running the tests they ran, verifying that we get similar results

In [956]:
# model = Controller()
# lower_seq_length = 3
# upper_seq_length = 10
# num_batches = 10000

# dataset = CopyTaskDataset(num_batches, batch_size, lower_seq_length, upper_seq_length, seq_size)
# data_loader = data.DataLoader(dataset, batch_size=batch_size)
# def train_model(model, dset_loader, training_criterion, optimizer)