In [1]:
import torch.nn as nn
import torch
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import random
import matplotlib.pyplot as plt

In [2]:
cd ..

/Users/oliviawatkins/Documents/Schoolwork/NN/neural_nets_research


In [3]:
from neural_nets_library import training

In [4]:
iteration = 0

In [5]:
def checkProb(vec, dim, name):
    # Get rid of batch dim
    sums = torch.sum(vec, dim)
    prob = torch.max(torch.abs(sums - 1)).data[0]
    if not prob < .002:
        print("BAD PROB", prob, name, vec)
    return prob < .002
    

In [44]:
class Controller(nn.Module):
    """
    Contains the two learnable parts of the model in four independent, fully connected layers.
    First the initial values for the registers and instruction registers and second the 
    parameters that computes the required distributions. 
    """

    def __init__(self, 
                 first_arg = None, 
                 second_arg = None, 
                 output = None, 
                 instruction = None, 
                 initial_registers = None, 
                 stop_threshold = 1, 
                 multiplier = 5,
                 correctness_weight = .2, 
                 halting_weight = .2, 
                 confidence_weight = .2, 
                 efficiency_weight = .4,
                 t_max = 75):
        """
        Initialize a bunch of constants and pass in matrices defining a program.
        
        :param first_arg: Matrix with the 1st register argument for each timestep stored in the columns (RxM)
        :param second_arg: Matrix with the 2nd register argument for each timestep stored in the columns (RxM)
        :param output: Matrix with the output register for each timestep stored in the columns (RxM)
        :param instruction: Matrix with the instruction for each timestep stored in the columns (NxM)
        :param initial_registers: Matrix where each row is a distribution over the value in one register (RxM)
        :param stop_threshold: The stop probability threshold at which the controller should stop running
        :param multiplier: The factor our one-hot vectors are be multiplied by before they're softmaxed to add blur
        :param correctness_weight: Weight given to the correctness component of the loss function
        :param halting_weight: Weight given to the halting component of the loss function
        :param confidence_weight: Weight given to the confidence component of the loss function
        :param efficiency_weight: Weight given to the efficiency component of the loss function
        
        """
        super(Controller, self).__init__()
        
        # Initialize dimension constants
        R, M = initial_registers.size()
        self.M = M
        self.R = R
        self.times = []
        
        # Initialize loss function weights
        # In the ANC paper, these scalars are called, alpha, beta, gamma, and delta
        self.correctness_weight = correctness_weight
        self.halting_weight = halting_weight
        self.confidence_weight = confidence_weight
        self.efficiency_weight = efficiency_weight
        
        # And yet more initialized constants... yeah, there are a bunch, I know.
        self.t_max = t_max
        self.stop_threshold = stop_threshold
        self.multiplier = multiplier

        # Initialize parameters.  These are the things that are going to be optimized. 
        self.first_arg = nn.Parameter(multiplier * first_arg)
        self.second_arg = nn.Parameter(multiplier * second_arg)
        self.output = nn.Parameter(multiplier * output)
        self.instruction = nn.Parameter(multiplier * instruction) 
        self.registers = nn.Parameter(multiplier * initial_registers)
        IR = torch.zeros(M)
        IR[0] = 1
        self.IR = nn.Parameter(multiplier * IR)
                
        # Machine initialization
        self.machine = Machine(M, R)
        self.softmax = nn.Softmax(0)
    
    
    def forward(self, input, forward_train):
        if forward_train:
            return self.forward_train(input)
        else:
            return self.forward_predict(self, input)
        
    def forward_train(self, input, output):
        """
        Runs the controller on a certain input memory matrix. It returns the loss.
        
        :param initial_memory: The state of memory at the beginning of the program.
        :param output_meory: The desired state of memory at the end of the program.
        :param output_mask: The parts of the output memory that are relevant.
        
        :return: Returns the training loss.
        """
        initial_memory = input
        output_memory = output[0]
        output_mask = output[1]
        # Program's initial memory #TODO: Variable?
        
        self.memory = Variable(initial_memory)
        self.output_memory = Variable(output_memory)
        self.output_mask = Variable(output_mask)
        self.stop_probability = Variable(torch.zeros(1))
        
        # Copy registers so we aren't using the values from the previous iteration. Also
        # make both registers and IR into a probability distribution.
        registers = nn.Softmax(1)(self.registers)
        IR = self.softmax(self.IR)
        
        # loss initialization
        self.confidence = 0
        self.efficiency = 0
        self.halting = 0
        self.correctness = 0
        
        t = 0 
        # Run the program, one timestep at a time, until the program terminates or whe time out
        while t < self.t_max and float(self.stop_probability) < self.stop_threshold: 
            
            a = self.softmax(torch.matmul(self.first_arg, IR))
            b = self.softmax(torch.matmul(self.second_arg, IR))
            o = self.softmax(torch.matmul(self.output, IR))
            e = self.softmax(torch.matmul(self.instruction, IR))
                        
            # Update memory, registers, and IR after machine operation
            self.old_stop_probability = self.stop_probability
            self.memory, registers, IR, new_stop_prob = self.machine(e, a, b, o, self.memory, registers, IR)
            
            self.stop_probability += new_stop_prob
            self.timestep_loss(t)
            t += 1
        
        self.final_loss(t)
        self.times.append(t)
#         return self.memory, self.total_loss()
        return self.total_loss()
    
    
    def test_train(self, initial_memory):
        """
        Runs the controller on a certain input memory matrix. It returns the loss.
        
        :param initial_memory: The state of memory at the beginning of the program.
        :param output_meory: The desired state of memory at the end of the program.
        :param output_mask: The parts of the output memory that are relevant.
        
        :return: Returns the training loss.
        """
        # Program's initial memory #TODO: Variable?
        
        self.memory = Variable(initial_memory)
        self.stop_probability = Variable(torch.zeros(1))
        
        # Copy registers so we aren't using the values from the previous iteration. Also
        # make both registers and IR into a probability distribution.
        registers = nn.Softmax(1)(self.registers)
        IR = self.softmax(self.IR)
        
        # loss initialization
        self.confidence = 0
        self.efficiency = 0
        self.halting = 0
        self.correctness = 0
        
        t = 0 
        # Run the program, one timestep at a time, until the program terminates or whe time out
        while t < self.t_max and float(self.stop_probability) < self.stop_threshold: 
            
            a = self.softmax(torch.matmul(self.first_arg, IR))
            b = self.softmax(torch.matmul(self.second_arg, IR))
            o = self.softmax(torch.matmul(self.output, IR))
            e = self.softmax(torch.matmul(self.instruction, IR))
            
            # Update memory, registers, and IR after machine operation
            self.old_stop_probability = self.stop_probability
            self.memory, registers, IR, new_stop_prob = self.machine(e, a, b, o, self.memory, registers, IR)
            self.stop_probability += new_stop_prob

            t += 1
            print("INTERMED", registers)
        
        self.times.append(t)
        print("REGS", registers)
        print("T", t)
        return self.memory
        
    
    def forward_prediction(self, input):
        """
        Runs the controller on a certain input memory matrix. It returns the output memory matrix.
        
        :param initial_memory: The state of memory at the beginning of the program.
        
        :return: Returns the output memory matrix.
        """
        memory = input[0]
        # Program's initial memory
        self.memory = memory
        self.stop_probability = 0
        
        # Copy registers so we aren't using the values from the previous iteration. Also
        # make both registers and IR into a probability distribution.
        registers = nn.Softmax(1)(self.registers)
        IR = self.softmax(self.IR)
        
        t = 0 
        
        # Run the program, one timestep at a time, until the program terminates or whe time out
        while t < self.t_max and self.stop_probability < self.stop_threshold: 
            
            
            
            a = self.softmax(torch.matmul(self.first_arg, IR))
            b = self.softmax(torch.matmul(self.second_arg, IR))
            o = self.softmax(torch.matmul(self.output, IR))
            e = self.softmax(torch.matmul(self.instruction, IR))
                        
            # Update memory, registers, and IR after machine operation
            self.old_stop_probability = self.stop_probability
            self.memory, registers, IR, new_stop_prob = self.machine(e, a, b, o, self.memory, registers, IR) 
            
            self.stop_probability += new_stop_prob
            t += 1
        
        return self.memory, None
    
    def timestep_loss(self, t):
        # Confidence Loss 
        mem_diff = self.output_memory - self.memory
        correctness = torch.sum(self.output_mask * mem_diff * mem_diff)
        self.confidence += (self.stop_probability - self.old_stop_probability) * correctness
        
        # Efficiency Loss
        if float(self.stop_probability) < self.stop_threshold: # don't add efficiency loss if it stops
            self.efficiency += (1 - self.stop_probability)
            
        
            
    
    def final_loss(self, t):
        # Correctness loss
        mem_diff = self.output_memory - self.memory
        self.correctness = torch.sum(self.output_mask * mem_diff * mem_diff)

        # Halting loss
        if t == self.t_max:
            self.halting = (1 - self.stop_probability)

    def total_loss(self):
        """ compute four diferent loss functions and return a weighted average of the four measuring correctness, 
        halting, efficiency, and confidence"""
        print("confidence", float(self.confidence * self.confidence_weight))
        print("efficiency", float(self.efficiency * self.efficiency_weight))
        print("halting", float(self.halting * self.halting_weight))
        print("correctness", float(self.correctness * self.correctness_weight))
        
        return  (self.correctness*self.correctness_weight) + (self.confidence_weight*self.confidence) + (self.halting_weight*self.halting) + (self.efficiency_weight*self.efficiency)     

In [7]:
class Operation(nn.Module):
    """
    Parent class for our binary operations
    """
    def __init__(self, M):
        """
        Initialize the memory length (needed so we can mod our answer in case it exceeds the range 0-M-1)
        Also calculate the output matrix for the operation
        
        :param M: Memory length
        """
        super(Operation, self).__init__()
        self.M = M
        
        # Create a MxMxM matrix where the (i,j,k) cell is 1 iff operation(i,j) = k.
        self.outputs = torch.IntTensor(M, M, M).zero_()
        for i in range(M):
            for j in range(M):
                val = self.compute(i, j)
                self.outputs[i][j][val] = 1
                
        self.outputs = Variable(self.outputs)
    
    def compute(self, x, y):
        """ 
        Perform the binary operation.  The arguments may or may not be used.
        
        :param x: First argument
        :param y: Second argument
        """
        raise NotImplementedError
    
    def forward(self):
        """
        :return: The output matrix
        """
        return self.outputs

In [8]:
class Add(Operation):

    def __init__(self, M):
        super(Add, self).__init__(M)
    
    def compute(self, x, y):
        return (x + y) % self.M


In [9]:
class Stop(Operation):
    
    def __init__(self, M):
        super(Stop, self).__init__(M)

    def compute(self, _1, _2):
        return 0

In [10]:
class Jump(Operation):
    
    def __init__(self, M):
        super(Jump, self).__init__(M)

    def compute(self, _1, _2):
        return 0 # Actual jump happens in the Machine class

In [11]:
class Decrement(Operation):
    
    def __init__(self, M):
        super(Decrement, self).__init__(M)

    def compute(self, x, _):
        return (x - 1) % self.M

In [12]:
class Increment(Operation):
    
    def __init__(self, M):
        super(Increment, self).__init__(M)

    def compute(self, x, _):
        return (x + 1) % self.M

In [13]:
class Max(Operation):
    
    def __init__(self, M):
        super(Max, self).__init__(M)

    def compute(self, x, y):
        return max(x,y)

In [14]:
class Min(Operation):
    
    def __init__(self, M):
        super(Min, self).__init__(M)

    def compute(self, x, y):
        return min(x,y)

In [15]:
class Read(Operation):
    
    def __init__(self, M):
        super(Read, self).__init__(M)
        # Leave output matrix blank since we're gonna do the reading elsewhere
        self.outputs = torch.zeros(M, M, M)

    def compute(self, x, _):
        return 0 # Actual reading happens in the Machine class

In [16]:
class Subtract(Operation):
    
    def __init__(self, M):
        super(Subtract, self).__init__(M)

    def compute(self, x, y):
        return (x - y) % self.M

In [17]:
class Write(Operation):
    
    def __init__(self, M):
        super(Write, self).__init__(M)

    def compute(self, x, y):
        return 0 # Actual write happens in the Machine class

In [18]:
class Zero(Operation):
    
    def __init__(self, M):
        super(Zero, self).__init__(M)

    def compute(self, _1, _2):
        return 0

In [19]:
class Machine(nn.Module):
    """
    The Machine executes assembly instructions passed to it by the Controller.
    It updates the given memory, registers, and instruction pointer.
    The Machine doesn't have any learnable parameters.
    """
    def __init__(self, M, R):
        """
        Initializes dimensions, operations, and counters
        
        :param M: Memory length.  Integer values also take on values 0-M-1.  M is also the program length.
        :param R: Number of registers
        """
        super(Machine, self).__init__()
        
        # Store parameters as class variables
        self.R = R # Number of registers
        self.M = M # Memory length (also largest number)
        
        # Start off with 0 probability of stopping
        self.stop_probability = 0 
        
        # List of ops (must be in same order as the original ANC paper so compilation works right)
        self.ops = [ 
            Stop(M),
            Zero(M),
            Increment(M),
            Add(M),
            Subtract(M),
            Decrement(M),
            Min(M),
            Max(M),
            Read(M),
            Write(M),
            Jump(M)
        ]
        
        # Number of instructions
        self.N = len(self.ops)
        
        # Create a 4D matrix composed of the output matrices of each of the ops
        self.outputs = Variable(torch.zeros(self.N, self.M, self.M, self.M))
        
        for i in range(self.N):
            op = self.ops[i]
            self.outputs[i] = op()
                
        # Keep track of ops which will be handled specially
        self.jump_index = 10
        self.stop_index = 0
        self.write_index = 9
        self.read_index = 8 
        
    def forward(self, e, a, b, o, memory, registers, IR):
        
        """
        Run the Machine for one timestep (corresponding to the execution of one line of Assembly).
        The first four parameter names correspond to the vector names used in the original ANC paper
        
        :param e: Probability distribution over the instruction being executed (M)
        :param a: Probability distribution over the first argument register (length R)
        :param b: Probability distribution over the second argument register (length R)
        :param o: Probability distribution over the first argument register (length R)
        :param memory: Memory matrix (size MxM)
        :param registers: Register matrix (size RxM)
        :param IR: Instruction Register (length M)
        
        :return: The memory, registers, and instruction register after the timestep
        """
        
        # Calculate distributions over the two argument values by multiplying each 
        # register by the probability that register is being used.
        arg1 = torch.matmul(a, registers)
        arg2 = torch.matmul(b, registers)
        
        # Multiply the output matrix by the arg1, arg2, and e vectors. Also take care
        # of doing the read.
        
        arg1_long = arg1.view(1, -1, 1, 1)
        arg2_long = arg2.view(1, 1, -1, 1)
        instr = e.view(-1, 1, 1, 1)
        read_vec =  e[self.read_index] * torch.matmul(arg1, memory)
        out_vec = (self.outputs * arg1_long * arg2_long * instr).sum(0).sum(0).sum(0) + read_vec      
        out_vec = out_vec.squeeze(0)
    
        # Update our memory, registers, instruction register, and stopping probability
        memory = self.writeMemory(e, memory, arg1, arg2)
        registers = self.writeRegisters(out_vec, o, registers)
        IR = self.updateIR(e, IR, arg1, arg2)
        stop_prob = self.getStop(e)
        
        return(memory, registers, IR, stop_prob)
             
    def writeRegisters(self, out, o, registers):
        """
        Write the result of our operation to our registers.
        
        :param out: Probability distribution over the output value (M)
        :param o: Probability distribution over the output register (R)
        :param Registers: register matrix (RxM)
        
        :return: The updated registers (RxM)
        """
        # Multiply probability of not writing with old registers and use an outer product
        return (1 - o).unsqueeze(1) * registers + torch.ger(o, out)
    
    def updateIR(self, e, IR, arg1, arg2):
        """
        Update the instruction register
        
        :param e: Distribution over the current instruction (N)
        :param IR: Instruction register (length M)
        :param arg1: Distribution over the first argument value (length M)
        :param arg2: Distribution over the second argument value (length M)
        
        :return: The updated instruction register (BxMx1)
        """
        # probability of actually jumping
        cond = e[self.jump_index] * arg1[0]
        
        # Take a weighted sum of the instruction register with and without jumping
        return torch.cat([IR[-1], IR[:-1]], 0) * (1 - cond) + arg2 * cond
    
    def writeMemory(self, e, mem_orig, arg1, arg2):
        """
        Update the memory
        
        :param e: Distribution over the current instruction (M)
        :param mem_orig: Current memory matrix (MxM)
        :param arg1: Distribution over the first argument value (M)
        :param arg2: Distribution over the second argument value (M)
        
        :return: The updated memory matrix (MxM)
        """
        
        # Probability that we're on the write instruction
        write_probability = e[self.write_index]
        mem_write = torch.ger(arg1, arg2) 
        mem_write = mem_write + (1 - arg1).unsqueeze(1) * mem_orig
        
        return mem_orig * (1 - write_probability) + write_probability * mem_write

    def getStop(self, e):
        """
        Obtain the probability that we will stop at this timestep based on the probability that we are running the STOP op.
        
        :param e: distribution over the current instruction (length M)
        
        :return: probability representing whether the controller should stop.
        """
        return e[self.stop_index]

In [20]:
def one_hotify(vec, number_of_classes, dimension):
    """
    Turn a tensor of integers into a matrix of one-hot vectors.
    
    :param vec: The vector to be converted.
    :param number_of_classes: How many possible classes the one hot vectors encode.
    :param dimension: Which dimension stores the elements of vec.  If 0, they're stored in the rows.  If 1, the columns.
    
    :return A matrix of one-hot vectors, each row or column corresponding to one element of vec
    """
    num_vectors = vec.size()[0]
    binary_vec = torch.zeros(num_vectors, number_of_classes)
    for i in range(num_vectors):
        binary_vec[i][vec[i]] = 1
    if dimension == 1:
        binary_vec.t_()
    
    return binary_vec

In [21]:
# # Addition task
# # Generate this by running the instructions here (but with the addition program file): https://github.com/aditya-khant/neural-assembly-compiler
# # Then get rid of the .cuda in each of the tensors since we (or at least I) don't have cuda
# init_registers = torch.IntTensor([6,2,0,1,0,0]) # Length R, should be RxM
# first_arg = torch.IntTensor([4,3,3,3,4,2,2,5]) # Length M, should be RxM
# second_arg = torch.IntTensor([5,5,0,5,5,1,4,5]) # Length M, should be RxM
# target = torch.IntTensor([4,3,5,3,4,5,5,5]) # Length M, should be RxM
# instruction = torch.IntTensor([8,8,10,5,2,10,9,0]) # Length M, should be NxM

# Increment task
init_registers = torch.IntTensor([6,0,0,0,0,0,0])
first_arg = torch.IntTensor([5,1,1,5,5,4,6])
second_arg = torch.IntTensor([6,0,6,3,6,2,6])
target = torch.IntTensor([1,6,3,6,5,6,6])
instruction = torch.IntTensor([8,10,2,9,2,10,0])

# # Access task
# init_registers = torch.IntTensor([0,0,0])
# first_arg = torch.IntTensor([0,1,1,0,2])
# second_arg = torch.IntTensor([2,2,2,1,2])
# target = torch.IntTensor([1,1,1,2,2])
# instruction = torch.IntTensor([8,2,8,9,0])



# Get dimensions we'll need
M = first_arg.size()[0]
R = init_registers.size()[0]
N = 11

# Turn the given tensors into matrices of one-hot vectors.
init_registers = one_hotify(init_registers, M, 0)
first_arg = one_hotify(first_arg, R, 1)
second_arg = one_hotify(second_arg, R, 1)
target = one_hotify(target, R, 1)
instruction = one_hotify(instruction, N, 1)

In [22]:
class AddTaskDataset(data.Dataset):
    def __init__(self, M, num_examples):
        """
        Generate a dataset for the addition task by randomly choosing two numbers in the allowed range
        and creating the initial/final matrices for adding them.
        
        :param M: The allowable range of integers (from 0 to M-1)
        :param num_examples: The number of training examples to be generated
        """
        
        self.input_list = []
        
        for i in range(num_examples):
            first_addend = random.randint(0, M-1)
            second_addend = random.randint(0, M-1)
            initial_memory = torch.zeros(M, M)
            initial_memory[0][first_addend] = 1
            initial_memory[1][second_addend] = 1
            for j in range(2, M):
                initial_memory[j][0] = 1

            
            output_memory = torch.zeros(M, M)
            output_memory[0][(first_addend + second_addend) % M] = 1

            # Output mask has ones in the row of the memory matrix where the answer will be stored.
            output_mask = torch.zeros(M, M)
            output_mask[0] = torch.ones(M)
            
            self.input_list.append((initial_memory, output_memory, output_mask))
       
    def __len__(self):
        return len(self.input_list)
    
    def __getitem__(self, i):
        """
        Get the i^th element of the dataset.
        
        :param i: The index of the element to be returned.
        :return A tuple containing i^th element of the dataset.
        """
        return self.input_list[i]

In [23]:
class TrivialAddTaskDataset(data.Dataset):
    def __init__(self, M, num_examples):
        """
        Generate a dataset for the addition task by randomly choosing two numbers in the allowed range
        and creating the initial/final matrices for adding them.
        
        :param M: The allowable range of integers (from 0 to M-1)
        :param num_examples: The number of training examples to be generated
        """
        
        self.input_list = []
        
        for i in range(num_examples):
            first_addend = random.randint(0, M-1)
            second_addend = random.randint(0, M-1)
            initial_memory = torch.FloatTensor(M, M).zero_()
            initial_memory[0][first_addend] = 1
            initial_memory[1][second_addend] = 1
            for j in range(2, M):
                initial_memory[j][0] = 1

            
            output_memory = torch.FloatTensor(M, M).zero_()
            output_memory[0][(first_addend + second_addend) % M] = 1

            # Output mask has ones in the rows of the memory matrix where the answer will be stored.
            output_mask = torch.FloatTensor(M, M).zero_()
            output_mask[2] = torch.ones(M)
            
            self.input_list.append((initial_memory, output_memory, output_mask))
       
    def __len__(self):
        return len(self.input_list)
    
    def __getitem__(self, i):
        """
        Get the i^th element of the dataset.
        
        :param i: The index of the element to be returned.
        :return A tuple containing i^th element of the dataset.
        """
        return self.input_list[i]

In [46]:
class IncTaskDataset(data.Dataset):
    def __init__(self, M, list_len, num_examples):
        """
        Generate a dataset for the list task by randomly choosing two numbers in the allowed range
        and creating the initial/final matrices for adding them.
        
        :param M: The allowable range of integers (from 0 to M-1)
        :param list_len: The list length
        :param num_examples: The number of training examples to be generated
        """
        
        if list_len > M:
            raise ValueError("Cannot have a list longer than M")
        
        self.input_list = []
        self.output_list = []
        
        for i in range(num_examples):
#             list_val = random.randint(1, M-1)
            list_val = i % M
            initial_memory = torch.zeros(M, M)
            output_memory = torch.zeros(M, M)
            # Output mask is length of the list itself
            output_mask = torch.zeros(M, M)
            
            for i in range(list_len):
                initial_memory[i][list_val] = 1
                output_memory[i][(list_val + 1 ) % M] = 1
                output_mask[i] = torch.ones(M)
                
            for j in range(list_len, M):
                initial_memory[j][0] = 1
            
#             self.input_list.append((initial_memory, output_memory, output_mask))
            self.input_list.append(initial_memory)
            print("IM", initial_memory)
            self.output_list.append((output_memory, output_mask))
       
    def __len__(self):
        return len(self.input_list)
    
    def __getitem__(self, i):
        """
        Get the i^th element of the dataset.
        
        :param i: The index of the element to be returned.
        :return A tuple containing i^th element of the dataset.
        """
        return self.input_list[i], self.output_list[i]

In [25]:
class AccessTaskDataset(data.Dataset):
    def __init__(self, M, num_examples):
        """
        Generate a dataset for the access task by randomly generating an array.
        The task is to access the 3rd element of the array
        
        :param M: The allowable range of integers (from 0 to M-1)
        :param num_examples: The number of training examples to be generated
        """
        self.input_list = []
        
        for i in range(num_examples):
            
            initial_memory = torch.zeros(M, M)
            output_memory = torch.zeros(M, M)
            
            # Set the initial memory
            for i in range(1,M):
                list_val = random.randint(0, M-1)
                initial_memory[i][list_val] = 1
                
                if i == 4:
                    output_memory[0, list_val] = 1
            
            # Get 3rd element of array
            initial_memory[0, 3] = 1
            
            # Output mask is length of the list itself
            output_mask = torch.zeros(M, M)
            output_mask[0] = torch.ones(M)
            
            self.input_list.append((initial_memory, output_memory, output_mask))
       
    def __len__(self):
        return len(self.input_list)
    
    def __getitem__(self, i):
        """
        Get the i^th element of the dataset.
        
        :param i: The index of the element to be returned.
        :return A tuple containing i^th element of the dataset.
        """
        return self.input_list[i]

In [26]:
class TrivialAccessTaskDataset(data.Dataset):
    def __init__(self, M, num_examples):
        """
        Generate a dataset for the access task by randomly generating an array.
        The task is to access the 3rd element of the array
        
        :param M: The allowable range of integers (from 0 to M-1)
        :param num_examples: The number of training examples to be generated
        """
        self.input_list = []
        
        for i in range(num_examples):
            
            initial_memory = torch.zeros(M, M)
            output_memory = torch.zeros(M, M)
            
            # Set the initial memory
            for i in range(1,M):
                list_val = random.randint(0, M-1)
                initial_memory[i][list_val] = 1
                
#                 if i == 4:
#                     output_memory[0, list_val] = 1
            
            # Get 3rd element of array
            initial_memory[0, 3] = 1
            output_memory[0, 4] = 1
            
            # Output mask is length of the list itself
            output_mask = torch.zeros(M, M)
            output_mask[0] = torch.ones(M)
            
            self.input_list.append((initial_memory, output_memory, output_mask))
       
    def __len__(self):
        return len(self.input_list)
    
    def __getitem__(self, i):
        """
        Get the i^th element of the dataset.
        
        :param i: The index of the element to be returned.
        :return A tuple containing i^th element of the dataset.
        """
        return self.input_list[i]

In [50]:
num_examples = 10 #7200

# M = 8 # Don't change this (as long as we're using the add-task)
# dataset = AddTaskDataset(M, num_examples)
# dataset = TrivialAddTaskDataset(M, num_examples)

M = 7 # Don't change this (as long as we're using the inc-task)
dataset = IncTaskDataset(M, 5, num_examples)

# M = 5
# dataset = AccessTaskDataset(M, num_examples)

# M = 5
# dataset = TrivialAccessTaskDataset(M, num_examples)

data_loader = data.DataLoader(dataset, batch_size = 1) # Don't change this batch size.  You have been warned.

def anc_validation_criterion(output, label):
    initial_memory = label[0]
    target_memory = label[1]
    target_mask = label[2]
    
    output2 = output.data * target_mask
    target_memory = target_memory * target_mask
    _, initial_indices = torch.max(initial_memory, 2)
    _, target_indices = torch.max(target_memory, 2)
    _, output_indices = torch.max(output2, 2)
    _, unmasked_indices = torch.max(output.data, 2)
    return 1 - torch.equal(output_indices, target_indices)

plot_every = 10


IM 
    1     0     0     0     0     0     0
    1     0     0     0     0     0     0
    1     0     0     0     0     0     0
    1     0     0     0     0     0     0
    1     0     0     0     0     0     0
    1     0     0     0     0     0     0
    1     0     0     0     0     0     0
[torch.FloatTensor of size 7x7]

IM 
    0     1     0     0     0     0     0
    0     1     0     0     0     0     0
    0     1     0     0     0     0     0
    0     1     0     0     0     0     0
    0     1     0     0     0     0     0
    1     0     0     0     0     0     0
    1     0     0     0     0     0     0
[torch.FloatTensor of size 7x7]

IM 
    0     0     1     0     0     0     0
    0     0     1     0     0     0     0
    0     0     1     0     0     0     0
    0     0     1     0     0     0     0
    0     0     1     0     0     0     0
    1     0     0     0     0     0     0
    1     0     0     0     0     0     0
[torch.FloatTensor of size 7x7]

IM 
   

In [51]:
# Initialize our controller
controller = Controller(first_arg = first_arg, 
                        second_arg = second_arg, 
                        output = target, 
                        instruction = instruction, 
                        initial_registers = init_registers, 
                        stop_threshold = .9, 
                        multiplier = 5,
                        correctness_weight = 1, 
                        halting_weight = 5, 
                        efficiency_weight = 0.1, 
                        confidence_weight = 0.5, 
                        t_max = 50) 

# Learning rate is a tunable hyperparameter. The paper used 1 or 0.1.
optimizer = optim.Adam(controller.parameters(), lr = 1)

best_model, train_plot_losses, validation_plot_losses = training.train_model_anc(
    controller, 
    data_loader,  
    optimizer, 
    num_epochs = 1, 
    print_every = 10, 
    plot_every = plot_every, 
    deep_copy_desired = False, 
#     validation_criterion = anc_validation_criterion, 
#     forward_train = True, 
    batch_size = 5) # In the paper, they used batch sizes of 1 or 5
    
    #kangaroo

Epoch 0/0
----------
LR is set to 0.005
confidence 0.0
efficiency 0.23650184273719788
halting 0.0
correctness 9.950525283813477
confidence 0.0
efficiency 1.0624462366104126
halting 0.0
correctness 7.775989055633545
confidence 0.0
efficiency 1.0632847547531128
halting 0.0
correctness 7.558948516845703
confidence 0.0
efficiency 1.061649203300476
halting 0.0
correctness 7.561375617980957
confidence 0.0
efficiency 1.051038146018982
halting 0.0
correctness 7.73557186126709
confidence 0.0
efficiency 1.0442863702774048
halting 0.0
correctness 7.881731986999512
confidence 0.0
efficiency 1.0539664030075073
halting 0.0
correctness 7.1672163009643555
confidence 0.0
efficiency 0.23617839813232422
halting 0.0
correctness 9.950095176696777
confidence 0.0
efficiency 1.0642060041427612
halting 0.0
correctness 7.754860877990723
confidence 0.0
efficiency 1.0651696920394897
halting 0.0
correctness 7.535272598266602
Epoch Number: 0, Batch Number: 10, Training Loss: 8.9810
Time so far is 0m 0s

Training co

In [None]:
# plt.plot([x * plot_every for x in range(len(train_plot_losses))], train_plot_losses)
# plt.show()

In [None]:
# plt.plot(range(len(controller.times)), controller.times)
# plt.show()

In [None]:
# plt.plot([x * plot_every for x in range(len(validation_plot_losses))], validation_plot_losses)
# plt.show()

In [None]:

# cutoff = 0.7

# def getBest(vec):
#     maxVal, index = torch.max(vec, 0)
#     if maxVal.data[0] > cutoff:
#         return index.data[0]

# def bestRegister(vec):
#     index = getBest(vec)
#     if index is not None:
#         return "R" + str(1 + index)
#     return "??"
    
# def bestInstruction(vec):
#     ops = [ 
#         "STOP",
#         "ZERO",
#         "INC",
#         "ADD",
#         "SUB",
#         "DEC",
#         "MIN",
#         "MAX",
#         "READ",
#         "WRITE",
#         "JEZ"
#     ]
#     index = getBest(vec)
#     if index is not None:
#         return ops[index]
#     return "??"
    
# # registers = controller.registers

# # # Add task
# # orig_register = [6,2,0,1,0,0]
# # orig_output = [4,3,5,3,4,5,5,5]
# # orig_instruction = [8,8,10,5,2,10,9,0]
# # orig_first = [4,3,3,3,4,2,2,5]
# # orig_second = [5,5,0,5,5,1,4,5]
# # orig_ir = [1,0,0,0,0,0,0,0]

# # # INC Task
# # orig_register = [6,0,0,0,0,0,0]
# # orig_first = [5,1,1,5,5,4,6]
# # orig_second = [6,0,6,3,6,2,6]
# # orig_output = [1,6,3,6,5,6,6]
# # orig_instruction = [8,10,2,9,2,10,0]

# # Access Task
# orig_register = [0,0,0]
# orig_first = [0,1,1,0,2]
# orig_second = [2,2,2,1,2]
# orig_output = [1,1,1,2,2]
# orig_instruction = [8,2,8,9,0]
# orig_ir = [1,0,0,0,0]


# R, M = controller.registers.size()
    
# def printProgram():   
    
#     print("IR = " + str(getBest(controller.IR)))
    
#     # Print registers
#     for i in range(R):
#         print("R" + str(i + 1) + " = " + str(getBest(controller.registers[i,:])))

#     print()

#     # Print the actual program
#     for i in range (M):
#         print(bestRegister(controller.output[:, i]) + " = " + 
#               bestInstruction(controller.instruction[:, i]) + "(" +
#               bestRegister(controller.first_arg[:, i]) + ", " +
#               bestRegister(controller.second_arg[:, i]) + ")")




    
# def compareOutput():
#     # compare our output to theirs
#     # we get one point for every matching number
#     match_count = 0
#     softmax = nn.Softmax(0)
#     for i in range(R):
#         if getBest(nn.Softmax(1)(controller.registers)[i,:]) == orig_register[i]:
#             match_count += 1
#     for i in range (M):
#         if getBest(softmax(controller.output)[:, i]) == orig_output[i]:
#             match_count += 1
#         if getBest(softmax(controller.instruction)[:, i]) == orig_instruction[i]:
#             match_count += 1
#         if getBest(softmax(controller.first_arg)[:, i]) == orig_first[i]:
#             match_count += 1
#         if getBest(softmax(controller.second_arg)[:, i]) == orig_second[i]:
#             match_count += 1
#     if getBest(softmax(controller.IR)) == orig_ir:
#         match_count += 1

#     percent_orig = match_count / (len(orig_register) + len(orig_output) + 
#                                            len(orig_instruction) + len(orig_first) + len(orig_second) + 1)
#     return percent_orig
#     print("PERCENT MATCH", percent_orig)
    
# printProgram()
# compareOutput()

# # Original Add Program   
# # R1 = 6
# # R2 = 2
# # R3 = 0
# # R4 = 1
# # R5 = 0
# # R6 = 0


# # R5 = READ(R5, R6)
# # R4 = READ(R4, R6)
# # R6 = JEZ(R4, R1)
# # R4 = DEC(R4, R6)
# # R5 = INC(R5, R6)
# # R6 = JEZ(R3, R2)
# # R6 = WRITE(R3, R5)
# # R6 = STOP(R6, R6)

# #koala

In [None]:
# Test a bunch of times
num_trials = 20

num_original_convergences = 0
num_0_losses = 0
num_better_convergences = 0
otherPrograms = []
for i in range(num_trials):
    print("Trial ", i)
    best_model, train_plot_losses, validation_plot_losses = training.train_model_anc(
        controller, 
        data_loader,  
        optimizer, 
        num_epochs = 15, 
        print_every = 50, 
        plot_every = plot_every, 
        deep_copy_desired = False, 
        validation_criterion = anc_validation_criterion, 
        forward_train = True, 
        batch_size = 5) # In the paper, they used batch sizes of 1 or 5
    percent_orig = compareOutput()
    if percent_orig > .99:
        num_original_convergences += 1
    end_losses = validation_plot_losses[-2:]
    if sum(end_losses) < .01:
        num_0_losses += 1
    if percent_orig < .99 and sum(end_losses) < .01:
        num_better_convergences += 1
        otherPrograms.append((controller.output, controller.instruction, controller.first_arg, controller.second_arg, controller.registers))
print("LOSS CONVERGENCES", num_0_losses * 1.0 / num_trials)
print("ORIG CONVERGENCES", num_original_convergences * 1.0 / num_trials)
print("BETTER CONVERGENCES", num_better_convergences * 1.0 / num_trials)

    

In [None]:
softmax = nn.Softmax(1)
print(softmax(controller.instruction))
print(controller.memory)

In [None]:
# printProgram()
softmax = nn.Softmax(0)
print(softmax(controller.output))
print(softmax(controller.first_arg))
print(softmax(controller.second_arg))
print(softmax(controller.instruction))
print(nn.Softmax(1)(controller.registers))

In [None]:
# IGNORE THIS... WORK IN PROGRESS
# self.ops = [ 
#     0 Stop(M),
#     1 Zero(M),
#     2 Increment(M),
#     3 Add(M),
#     4 Subtract(M),
#     5 Decrement(M),
#     6 Min(M),
#     7 Max(M),
#     8 Read(M),
#     9 Write(M),
#     10 Jump(M)
# ]
N = 11


# # Stop Test
# M = 4
# R = 3
# init_registers = torch.IntTensor([0,0,0])
# first_arg = torch.IntTensor([0,0,0,0])
# second_arg = torch.IntTensor([0,0,0,0])
# target = torch.IntTensor([0,0,0,0])
# instruction =  torch.IntTensor([3, 0, 0, 0]) # OK


# init_registers = one_hotify(init_registers, M, 0)
# first_arg = one_hotify(first_arg, R, 1)
# second_arg = one_hotify(second_arg, R, 1)
# target = one_hotify(target, R, 1)
# instruction = one_hotify(instruction, N, 1)
# instruction[0,1] = 0.5
# instruction[0,2] = 0.5
# instruction[2,1] = 0.5
# instruction[2,2] = 0.5

# memory = torch.IntTensor([0,0,0,0])
# memory = one_hotify(memory, M, 0)

# # What we expect: stops after 3 iterations; reg should have  [0.5, 0.25, 0.25]


# # Write test

# M = 2
# R = 3
# init_registers = torch.IntTensor([1,1,0])
# first_arg = torch.IntTensor([1,0])
# second_arg = torch.IntTensor([0,0])
# target = torch.IntTensor([0,0])
# instruction =  torch.IntTensor([0,0]) #

# init_registers = one_hotify(init_registers, M, 0)
# first_arg = one_hotify(first_arg, R, 1)
# second_arg = one_hotify(second_arg, R, 1)
# target = one_hotify(target, R, 1)
# instruction = one_hotify(instruction, N, 1)
# instruction[0,0] = 0.5
# instruction[9,0] = 0.5

# memory = torch.IntTensor([0,0])
# memory = one_hotify(memory, M, 0)

# # What we expect: stops after 2 iterations; index 1 of memory should have value (0:.5; 1:.5)


# # Read test

# M = 2
# R = 3
# init_registers = torch.IntTensor([0,0,0])
# first_arg = torch.IntTensor([0,0,0,0,0,0])
# second_arg = torch.IntTensor([0,0,0,0,0,0])
# target = torch.IntTensor([0,0,0,0,0,0])
# instruction =  torch.IntTensor([0,10,0,0,2,0])

# init_registers = one_hotify(init_registers, M, 0)
# first_arg = one_hotify(first_arg, R, 1)
# second_arg = one_hotify(second_arg, R, 1)
# target = one_hotify(target, R, 1)
# instruction = one_hotify(instruction, N, 1)
# instruction[0,0] = 0.5
# instruction[9,0] = 0.5

# memory = torch.IntTensor([0,0])
# memory = one_hotify(memory, M, 0)

# # What we expect: stops after 2 iterations; index 1 of memory should have value (0:.5; 1:.5)


# Normal ops test

M = 5
R = 3
init_registers = torch.IntTensor([0,0,0])
first_arg = torch.IntTensor([0,0,0,0,0])
second_arg = torch.IntTensor([1,1,1,1,1])
target = torch.IntTensor([0,0,0,0,0])
instruction =  torch.IntTensor([1,3,4,7,0])

init_registers = one_hotify(init_registers, M, 0)
first_arg = one_hotify(first_arg, R, 1)
second_arg = one_hotify(second_arg, R, 1)
target = one_hotify(target, R, 1)
instruction = one_hotify(instruction, N, 1)
# zero, inc
instruction[1,0] = 0.5
instruction[2,0] = 0.5

# add, dec
instruction[3,1] = 0.5
instruction[5,1] = 0.5

# sub, min
instruction[4,2] = 0.5
instruction[6,2] = 0.5

# max, write
instruction[7,3] = 0.5
instruction[9,3] = 0.5


memory = torch.IntTensor([0,1,2,3,4])
memory = one_hotify(memory, M, 0)

# What we expect: stops after 2 iterations; index 1 of memory should have value (0:.5; 1:.5)





# Initialize our controller
controller = Controller(first_arg = first_arg, 
                        second_arg = second_arg, 
                        output = target, 
                        instruction = instruction, 
                        initial_registers = init_registers, 
                        stop_threshold = .9, 
                        multiplier = 50,
                        correctness_weight = 10, 
                        halting_weight = 0, 
                        efficiency_weight = 1, 
                        confidence_weight = 0, 
                        t_max = 50) 



In [None]:
N=11


# # AddTest
# M = 3
# R = 4

# init_registers = torch.IntTensor([0,1,0,0])
# first_arg = torch.IntTensor([2,0,3])
# second_arg = torch.IntTensor([1,2,3])
# target = torch.IntTensor([2,3,3])
# instruction = torch.IntTensor([3,9,0])

# # # DecTest
# M = 3
# R = 3
# init_registers = torch.IntTensor([1,0,0])
# first_arg = torch.IntTensor([0,1,2])
# second_arg = torch.IntTensor([2,0,2])
# target = torch.IntTensor([0,2,2])
# instruction = torch.IntTensor([5,9,0])

# # IncTest
# M = 3
# R = 3
# init_registers = torch.IntTensor([0,0,0])
# first_arg = torch.IntTensor([1,0,2])
# second_arg = torch.IntTensor([2,1,2])
# target = torch.IntTensor([1,2,2])
# instruction = torch.IntTensor([2,9,0])

# JezTest
M = 3
R = 3
init_registers = torch.IntTensor([2,0,1,0])
first_arg = torch.IntTensor([1,1,2,3])
second_arg = torch.IntTensor([3,0,1,3])
target = torch.IntTensor([1,3,3,3])
instruction = torch.IntTensor([1,10,9,0])

# # MaxTest
# M = 3
# R = 4
# init_registers = torch.IntTensor([2,1,0,0])
# first_arg = torch.IntTensor([1,2,3])
# second_arg = torch.IntTensor([0,1,3])
# target = torch.IntTensor([1,3,3])
# instruction = torch.IntTensor([7,9,0])

# # MinTest
# M = 3
# R = 4
# init_registers = torch.IntTensor([1,2,0,0])
# first_arg = torch.IntTensor([1,2,3])
# second_arg = torch.IntTensor([0,1,3])
# target = torch.IntTensor([1,3,3])
# instruction = torch.IntTensor([6,9,0])

# ReadTest
# M=3
# R=3
# init_registers = torch.IntTensor([0,0,0])
# first_arg = torch.IntTensor([0,1,2])
# second_arg = torch.IntTensor([2,0,2])
# target = torch.IntTensor([0,2,2])
# instruction = torch.IntTensor([8,9,0])

# # SubTest
# M = 3
# R = 4
# init_registers = torch.IntTensor([1,2,0,0])
# first_arg = torch.IntTensor([1,2,3])
# second_arg = torch.IntTensor([0,1,3])
# target = torch.IntTensor([1,3,3])
# instruction = torch.IntTensor([4,9,0])

# # WriteTest
# M = 2
# R = 3
# init_registers = torch.IntTensor([1,0,0])
# first_arg = torch.IntTensor([1,2])
# second_arg = torch.IntTensor([0,2])
# target = torch.IntTensor([2,2])
# instruction = torch.IntTensor([9,0])

# ZeroTest
M = 3
R = 3
init_registers = torch.IntTensor([0,1,0])
first_arg = torch.IntTensor([1,0,2])
second_arg = torch.IntTensor([2,1,2])
target = torch.IntTensor([1,2,2])
instruction = torch.IntTensor([1,9,0])

init_registers = one_hotify(init_registers, M, 0)
first_arg = one_hotify(first_arg, R, 1)
second_arg = one_hotify(second_arg, R, 1)
target = one_hotify(target, R, 1)
instruction = one_hotify(instruction, N, 1)
initial_memory = torch.zeros(M,M)
initial_memory[:, 2] = 1
# initial_memory[0,0] = 0
# initial_memory[0,2] = 1
                                    
                                    
controller = Controller(first_arg = first_arg, 
                        second_arg = second_arg, 
                        output = target, 
                        instruction = instruction, 
                        initial_registers = init_registers, 
                        stop_threshold = .9, 
                        multiplier = 50,
                        correctness_weight = 10, 
                        halting_weight = 0, 
                        efficiency_weight = 1, 
                        confidence_weight = 0, 
                        t_max = 5) 
print(controller.forward_train(initial_memory, initial_memory, initial_memory))