In [1]:
import torch.nn as nn
import torch
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data
import numpy as np


In [2]:
cd ..

/Users/oliviawatkins/Documents/Schoolwork/NN/neural_nets_research


In [3]:
from neural_nets_library import training

In [4]:
class Controller(nn.Module):
    """
    Differentiable program controller.

    Holds the learnable parts of the model as nn.Parameters: the program itself (per-timestep
    distributions over the first-argument register, second-argument register, output register
    and instruction) and the initial register values.  forward() runs that program on a
    Machine, one timestep at a time.
    """

    def __init__(self, 
                 first_arg = None, 
                 second_arg = None, 
                 output = None, 
                 instruction = None, 
                 initial_registers = None, 
                 stop_threshold = .99, 
                 blur = 3, 
                 correctness_weight = .25, 
                 halting_weight = .25, 
                 confidence_weight = .25, 
                 efficiency_weight = .25,
                 t_max = 100):
        """
        Initialize a bunch of constants and pass in matrices defining a program.
        
        :param first_arg: Matrix with the 1st register argument for each timestep stored in the columns (1xRxM)
        :param second_arg: Matrix with the 2nd register argument for each timestep stored in the columns (1xRxM)
        :param output: Matrix with the output register for each timestep stored in the columns (1xRxM)
        :param instruction: Matrix with the instruction for each timestep stored in the columns (1xNxM)
        :param initial_registers: Matrix where each row is a distribution over the value in one register (1xRxM)
        :param stop_threshold: The stop probability threshold at which the controller should stop running
        :param blur: The factor our one-hot vectors are multiplied by before they're softmaxed to add blur.
                     Pass None to skip blurring entirely.
        :param correctness_weight: Weight given to the correctness component of the loss function
        :param halting_weight: Weight given to the halting component of the loss function
        :param confidence_weight: Weight given to the confidence component of the loss function
        :param efficiency_weight: Weight given to the efficiency component of the loss function
        :param t_max: Maximum number of timesteps to run before giving up on the program halting
        
        :raises ValueError: If any of the five program matrices is missing
        """
        #TODO: Read over ANC paper, check if there are more reasonable default initial values.
        super(Controller, self).__init__()
        
        # The program matrices default to None only so they can be passed by keyword;
        # fail with a readable message instead of an obscure error further down.
        required = (
            ("first_arg", first_arg),
            ("second_arg", second_arg),
            ("output", output),
            ("instruction", instruction),
            ("initial_registers", initial_registers),
        )
        missing = [name for name, value in required if value is None]
        if missing:
            raise ValueError("Controller is missing required program matrices: " + ", ".join(missing))
        
        # Initialize dimension constants
        B, R, M = initial_registers.size()
        self.M = M
        self.R = R
        
        # Initialize loss function weights
        # In the ANC paper, these scalars are called, alpha, beta, gamma, and delta
        self.correctness_weight = correctness_weight
        self.halting_weight = halting_weight
        self.confidence_weight = confidence_weight
        self.efficiency_weight = efficiency_weight
        
        # Remaining run-time constants
        self.t_max = t_max
        self.blur_factor = blur
        self.stop_threshold = stop_threshold
        
        # Blur matrices - i.e. give each different operation/argument/register value some nonzero probability.
        # The program matrices hold one distribution per column (softmax over dim 1);
        # the registers hold one distribution per row (softmax over dim 2).
        if blur is not None:
            first_arg = self.blur(first_arg, blur, 1)
            second_arg = self.blur(second_arg, blur, 1)
            output = self.blur(output, blur, 1)
            instruction = self.blur(instruction, blur, 1)
            initial_registers = self.blur(initial_registers, blur, 2)
            
        # Initialize parameters.  These are the things that are going to be optimized. 
        self.first_arg = nn.Parameter(first_arg.data)
        self.second_arg = nn.Parameter(second_arg.data)
        self.output = nn.Parameter(output.data)
        self.instruction = nn.Parameter(instruction.data) 
        self.registers = nn.Parameter(initial_registers.data)
        
        # Machine initialization
        self.machine = Machine(B, M, R)
    
    def blur(self, matrix, scale_factor, dimension):
        """
        Takes a matrix, each row (or column) of which is a one-hot vector.
        Multiply each 1 by a constant and then softmax it, which 
        effectively "blurs" the matrix a little bit.
        
        :param matrix: Matrix to blur
        :param scale_factor: Constant to multiply the matrix by before it's softmaxed
        :param dimension: Dimension to softmax over
        
        :return: Blurred matrix
        """
        matrix = scale_factor * matrix
        softmax = nn.Softmax(dimension)
        return softmax(Variable(matrix))    
    
    def forward(self, input, train):
        """
        Runs the controller on a certain input memory matrix.
        It either returns the loss or the output memory.
        
        :param input: A three-tuple of three MxM matrices: (memory matrix, output_memory, output_mask)
        :param train: If true, accumulate and return a loss; otherwise return the final memory
        
        :return: If train is true, return the loss. Otherwise, return the output matrix
        """
        # Program's initial memory
        self.memory = input[0]
        # Desired output memory
        self.output_memory = input[1]
        # Mask with 1's in the rows of the output memory matrix which actually contain the answer.
        self.output_mask = input[2]
        
        # Initialize instruction register (1xMx1), pointing at the first program line.
        # Sized from self.M so the model does not depend on a notebook-level global `M`.
        IR = torch.zeros(1, self.M, 1)
        IR[0][0][0] = 1
        
        # Blur memory and instruction register
        if self.blur_factor is not None:
            self.memory = self.blur(self.memory, self.blur_factor, 2) 
            IR = self.blur(IR, self.blur_factor, 1)
        
        efficiency_loss = 0
        confidence_loss = 0
        self.stop_probability = 0
        
        # Start each run from the learned initial registers.  The machine returns fresh tensors
        # each step and only the local name is rebound, so self.registers itself is not overwritten.
        registers = self.registers
        
        t = 0 
        # Run the program, one timestep at a time, until the program terminates or we time out
        while t < self.t_max and self.stop_probability < self.stop_threshold: 
            # Select this timestep's argument/output/instruction distributions by
            # weighting the program columns with the instruction register.
            a = torch.bmm(self.first_arg, IR)
            b = torch.bmm(self.second_arg, IR)
            o = torch.bmm(self.output, IR)
            e = torch.bmm(self.instruction, IR)
            
            # Update memory, registers, and IR after machine operation
            self.memory, registers, IR, stop_prob = self.machine(e, a, b, o, self.memory, registers, IR)
            self.stop_probability = self.stop_probability + stop_prob[0]            
            
            # If we're training, calculate loss
            if train:
                new_efficiency, new_confidence = self.timestep_loss()
                efficiency_loss += new_efficiency
                confidence_loss += new_confidence
            
            t += 1
        
        # If we're training, return loss.  Otherwise return memory.
        if train:
            correctness_loss, halting_loss = self.final_loss()
            total_loss  = (
                self.correctness_weight * correctness_loss + 
                self.halting_weight * halting_loss + 
                self.confidence_weight * confidence_loss + 
                self.efficiency_weight * efficiency_loss)
            # Placeholder: total_loss is built from stub constants and is not differentiable yet,
            # so a differentiable dummy value is returned until the real losses exist.
            return torch.sum(self.registers) #total_loss #TODO: Aditya @ Rakia - implement loss, then replace this
        else:
            return self.memory
        
        
    def timestep_loss(self):
        """
        @ Rakia @ Aditya feel free to use this function definition or not.  My main thought was that this would 
        compute the types of loss which get updated every timestep.
        
        :return: Placeholder tuple (efficiency loss, confidence loss) for the current timestep
        """
        # TODO: Insert losses
        return (1,1)
    
    def final_loss(self):
        """
        @ Rakia @ Aditya feel free to use this function definition or not.  My main thought was that this would 
        compute the types of loss which are only evaluated once, after the program has finished running.
        
        :return: Placeholder tuple (correctness loss, halting loss)
        """
        # TODO: insert losses
        return (1,1)
   

    def lossfunctions(self, cmatrix, tjmatrix):
        """
        Draft accumulation of the correctness, efficiency and confidence losses for one timestep.
        
        NOTE(review): work in progress.  Nothing in this file calls this method, and it relies on
        attributes that nothing visible here initializes (self.program_matrix4, self.correctness,
        self.efficiency, self.confidence, self.t).  It also indexes self.stop_probability by timestep,
        whereas forward() keeps it as a running scalar.  If the stop-probability difference turns out
        to be a scalar, the torch.matmul below needs to become a plain multiplication.
        
        :param cmatrix: Mask multiplied elementwise with the squared memory error
        :param tjmatrix: Currently unused
        """
        # Squared error between target and actual memory, masked by cmatrix.
        # (`**` is exponentiation; the previous `^` is bitwise XOR and binds looser than `*`.)
        self.matrix1 = cmatrix * (self.program_matrix4 - self.memory) ** 2
        self.correctness += self.matrix1
        
        self.efficiency += (1-self.stop_probability[self.t])
        
        self.confidence += torch.matmul(
            self.stop_probability[self.t] - self.stop_probability[self.t - 1], self.matrix1)
        
      
        
        

In [5]:
class Operation(nn.Module):
    """
    Base class for the machine's binary operations.

    Subclasses supply `compute`; the constructor tabulates it over every pair of
    argument values into a one-hot lookup tensor.
    """
    def __init__(self, M):
        """
        Remember the memory length and precompute the operation's lookup table.
        
        :param M: Memory length.  Arguments are enumerated over 0..M-1, and subclasses
                  use it to wrap results back into that range.
        """
        super(Operation, self).__init__()
        self.M = M
        
        # M x M x M table indexed [result][first argument][second argument]:
        # the cell is 1 iff compute(first, second) == result, and 0 otherwise.
        table = self.outputs = torch.zeros(M, M, M)
        for first in range(M):
            for second in range(M):
                result = self.compute(first, second)
                table[result, first, second] = 1
    
    def compute(self, x, y):
        """ 
        Evaluate the binary operation on two integer arguments.
        A subclass may ignore either argument.
        
        :param x: First argument
        :param y: Second argument
        
        :raises NotImplementedError: Always, in this base class
        """
        raise NotImplementedError
    
    def forward(self):
        """
        :return: The precomputed lookup table
        """
        lookup = self.outputs
        return lookup

In [6]:
class Add(Operation):
    """Addition of the two arguments, wrapped modulo the memory length M."""

    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Add, self).__init__(M)
    
    def compute(self, x, y):
        """Return x + y reduced modulo self.M."""
        total = x + y
        return total % self.M


In [7]:
class Stop(Operation):
    """Stop instruction.  As a value-producing operation it always yields 0."""
    
    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Stop, self).__init__(M)

    def compute(self, _1, _2):
        """Ignore both arguments and return the constant 0."""
        result = 0
        return result

In [8]:
class Jump(Operation):
    """Jump instruction.  Its tabulated value is always 0; the jump itself is done elsewhere."""
    
    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Jump, self).__init__(M)

    def compute(self, _1, _2):
        """Ignore both arguments and return 0."""
        # Actual jump happens in the Machine class
        result = 0
        return result

In [9]:
class Decrement(Operation):
    """Subtract one from the first argument, wrapping modulo the memory length M."""
    
    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Decrement, self).__init__(M)

    def compute(self, x, _):
        """Return x - 1 reduced modulo self.M; the second argument is ignored."""
        lowered = x - 1
        return lowered % self.M

In [10]:
class Increment(Operation):
    """Add one to the first argument, wrapping modulo the memory length M."""
    
    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Increment, self).__init__(M)

    def compute(self, x, _):
        """Return x + 1 reduced modulo self.M; the second argument is ignored."""
        raised = x + 1
        return raised % self.M

In [11]:
class Max(Operation):
    """Larger of the two arguments."""
    
    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Max, self).__init__(M)

    def compute(self, x, y):
        """Return the larger of x and y."""
        larger = max(x, y)
        return larger

In [12]:
class Min(Operation):
    """Smaller of the two arguments."""
    
    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Min, self).__init__(M)

    def compute(self, x, y):
        """Return the smaller of x and y."""
        smaller = min(x, y)
        return smaller

In [13]:
class Read(Operation):
    """Memory read instruction.  Its lookup table is all zeros; the read itself is done elsewhere."""
    
    def __init__(self, M):
        super(Read, self).__init__(M)
        # Leave output matrix blank since we're gonna do the reading elsewhere:
        # replace the table the parent constructor just built with an all-zero one.
        blank = torch.zeros(M, M, M)
        self.outputs = blank

    def compute(self, x, _):
        """Return 0 for every argument pair (only used while the parent builds its table)."""
        # Actual reading happens in the Machine class
        result = 0
        return result

In [14]:
class Subtract(Operation):
    """Difference of the two arguments, wrapped modulo the memory length M."""
    
    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Subtract, self).__init__(M)

    def compute(self, x, y):
        """Return x - y reduced modulo self.M."""
        difference = x - y
        return difference % self.M

In [15]:
class Write(Operation):
    """Memory write instruction.  Its tabulated value is always 0; the write itself is done elsewhere."""
    
    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Write, self).__init__(M)

    def compute(self, x, y):
        """Ignore both arguments and return 0."""
        # Actual write happens in the Machine class
        result = 0
        return result

In [16]:
class Zero(Operation):
    """Constant operation whose result is always 0."""
    
    def __init__(self, M):
        # The parent constructor tabulates compute() over all argument pairs.
        super(Zero, self).__init__(M)

    def compute(self, _1, _2):
        """Ignore both arguments and return the constant 0."""
        result = 0
        return result

In [17]:
class Machine(nn.Module):
    """
    The Machine executes assembly instructions passed to it by the Controller.
    It computes updated memory, registers, and instruction register for one timestep,
    working on probability distributions rather than concrete values.
    The Machine doesn't have any learnable parameters.
    """
    def __init__(self, B, M, R):
        """
        Initializes dimensions, operations, and the combined operation lookup table.
        
        :param B: Batch size (meant to be 1)
        :param M: Memory length.  Integer values also take on values 0-M-1.  M is also the program length.
        :param R: Number of registers
        """
        super(Machine, self).__init__()
        
        # Store parameters as class variables
        self.R = R # Number of registers
        self.M = M # Memory length (values range over 0..M-1)
        self.B = B # Batch size
        
        # Start off with 0 probability of stopping
        # NOTE(review): this attribute is not read anywhere in this class.
        self.stop_probability = torch.zeros(B)
        
        # List of ops (must be in same order as the original ANC paper so compilation works right)
        # NOTE(review): this is a plain Python list, so the ops are not registered as submodules.
        self.ops = [ 
            Stop(M),
            Zero(M),
            Increment(M),
            Add(M),
            Subtract(M),
            Decrement(M),
            Min(M),
            Max(M),
            Read(M),
            Write(M),
            Jump(M)
        ]
        
        # Number of instructions
        self.N = len(self.ops)
        
        # Create a 4D matrix (N x M x M x M) composed of the output matrices of each of the ops.
        # Each op's table is indexed [result][first argument][second argument] (see Operation).
        self.outputs = Variable(torch.zeros(self.N, M, M, M))
        
        for i in range(self.N):
            op = self.ops[i]
            self.outputs[i] = op()
            
        # Add an extra batch dimension: N x M x M x M -> B x N x M x M x M
        self.outputs = torch.unsqueeze(self.outputs, 0)
        self.outputs = self.outputs.expand(B, -1, -1, -1, -1)
        
        # Keep track of ops which will be handled specially.
        # These are positions in self.ops above and must be kept in sync with that list.
        self.jump_index = 10
        self.stop_index = 0
        self.write_index = 9
        self.read_index = 8
        

        
    def forward(self, e, a, b, o, memory, registers, IR):
        
        """
        Run the Machine for one timestep (corresponding to the execution of one line of Assembly).
        The first four parameter names correspond to the vector names used in the original ANC paper
        
        :param e: Probability distribution over the instruction being executed (B x N x 1)
        :param a: Probability distribution over the first argument register (B x R x 1)
        :param b: Probability distribution over the second argument register (B x R x 1)
        :param o: Probability distribution over the output register (B x R x 1)
        :param memory: Memory matrix (size B x M x M)
        :param registers: Register matrix (size B x R x M)
        :param IR: Instruction Register (B x M x 1)
        
        :return: Tuple (memory, registers, instruction register, stop probability) after the timestep
        """
        
        # Dimensions B x R x 1 -> B x 1 x R
        a = torch.transpose(a, 1, 2)
        b = torch.transpose(b, 1, 2)
        
        # Calculate distributions over the two argument values by multiplying each 
        # register by the probability that register is being used.
        # (B x 1 x R) @ (B x R x M) -> B x 1 x M
        arg1 = torch.bmm(a, registers)
        arg2 = torch.bmm(b, registers)
        
        # Multiply the output matrix by the arg1 and arg2 vectors to take the distributions
        # over both argument values into account.
        # Before we do this, we're going to have to do a bunch of dimension squishing.
        
        # arg1_long dimensions: B x 1 x M --> B x 1 x 1 x 1 x M
        arg1_long = torch.unsqueeze(arg1, 1)
        arg1_long = torch.unsqueeze(arg1_long, 1)
        
        # Contracts the first-argument axis of every op's lookup table with arg1.
        outputs_x_arg1 = torch.matmul(arg1_long, self.outputs)
        
        # outputs_x_arg1 dimensions: B x N x M x 1 x M -> B x N x M x M
        outputs_x_arg1 = torch.squeeze(outputs_x_arg1, 3)
        
        # arg2_long dimensions: B x 1 x M --> B x 1 x M x 1
        arg2_long = torch.unsqueeze(arg2, 3)
        
        # Contracts the remaining second-argument axis with arg2.
        outputs_x_args = torch.matmul(outputs_x_arg1, arg2_long)
        
        # outputs_x_args dimensions: B x N x M x 1 -> B x N x M
        outputs_x_args = torch.squeeze(outputs_x_args, 3)
        
        # e dimensions B x N x 1 -> B x 1 x N
        # From here on (including the helper methods below) e is in this transposed layout.
        e = torch.transpose(e, 1, 2)
        
        # Probability that the current instruction is READ.
        # read_vec dimensions B x 1 -> B x 1 x 1
        read_vec =  e[:, :, self.read_index]
        read_vec = read_vec.unsqueeze(1)
        
        # Length Bx1xM vector over the output of the operation
        out_vec = torch.matmul(e, outputs_x_args)
        
        # Deal with memory reads separately: the READ table is all zeros, so its result comes from
        # the memory rows weighted by arg1 (treated as an address distribution).
        out_vec = out_vec + read_vec * torch.matmul(arg1, memory)        
        
        # Update our memory, registers, instruction register, and stopping probability
        memory = self.writeMemory(e, o, memory, arg1, arg2)
        registers = self.writeRegisters(out_vec, o, registers)
        IR = self.updateIR(e, IR, arg1, arg2)
        stop_prob = self.getStop(e)
        
        return(memory, registers, IR, stop_prob)
        
        
    def writeRegisters(self, out, o, registers):
        """
        Write the result of our operation to our registers.
        
        :param out: Probability distribution over the output value (Bx1xM)
        :param o: Probability distribution over the output register (BxRx1)
        :param registers: Register matrix (BxRxM)
        
        :return: The updated registers (BxRxM)
        """
        
        # Multiply probability of writing to each output register by the distribution over the value we're writing there.
        # (B x R x 1) @ (B x 1 x M) -> B x R x M
        new_register_vals = torch.matmul(o, out)
        
        # Multiply each original register cell by the probability of not writing to that register
        old_register_vals = (1-o).expand(self.B, self.R, self.M) * registers
        
        # Take a weighted sum over the old and new register values
        registers =  new_register_vals + old_register_vals
        
        return registers
    
    def updateIR(self, e, IR, arg1, arg2):
        """
        Update the instruction register
        
        :param e: Distribution over the current instruction, as transposed in forward (Bx1xN)
        :param IR: Instruction register (BxMx1)
        :param arg1: Distribution over the first argument value (Bx1xM)
        :param arg2: Distribution over the second argument value (Bx1xM)
        
        :return: The updated instruction register (BxMx1)
        """
        
        # Dimensions B x 1 x M -> B x M x 1
        arg2 = arg2.transpose(1, 2)
        
        # Probability that we're on the jump instruction
        jump_probability = e[:, :, self.jump_index]
        
        # Probability that the first argument is 0
        is_zero = arg1[:, :, 0]
        
        # Slicing lost a dimension.  Let's add it back (B x 1 -> B x 1 x 1)
        jump_probability = torch.unsqueeze(jump_probability, 1)
        is_zero = torch.unsqueeze(is_zero, 1)
        
        # If we're not jumping, just shift IR by one slot
        # (the last entry wraps around to the front).
        wraparound = IR[:, -1]
        normal_instructions = IR[:, :-1]
        
        # Indexing a single position with an integer drops that dimension.  Add it back.
        wraparound = wraparound.unsqueeze(1)
        IR_no_jump = torch.cat([wraparound, normal_instructions], 1)
        
        # If we are on a jump instruction, check whether the argument's 0.
        # If it is, jump to the location specified by arg2.  Otherwise, increment like normal.
        IR_jump = arg2 * is_zero + (1 - is_zero) * IR_no_jump
        
        # Take a weighted sum of the instruction register with and without jumping
        IR = IR_no_jump * (1 - jump_probability) + IR_jump * jump_probability
        
        return IR
    
    def writeMemory(self, e, o, mem_orig, arg1, arg2):
        """
        Update the memory
        
        :param e: Distribution over the current instruction, as transposed in forward (Bx1xN)
        :param o: Distribution over the output register (not used by this method)
        :param mem_orig: Current memory matrix (BxMxM)
        :param arg1: Distribution over the first argument value (Bx1xM)
        :param arg2: Distribution over the second argument value (Bx1xM)
        
        :return: The updated memory matrix (BxMxM)
        """
        
        # Probability that we're on the write instruction
        write_probability = e[:,:, self.write_index]
        
        # write_probability dimensions: Bx1 -> B x 1 x 1
        write_probability = torch.unsqueeze(write_probability, 1)
        
        # arg1 dimensions: B x 1 x M -> B x M x 1
        arg1 = torch.transpose(arg1, 1, 2)
        
        # If we are on a write instruction, write the value arg2 into the memory row addressed by arg1.
        # Otherwise, leave memory as is.
        # mem_changed is the outer product (B x M x 1) @ (B x 1 x M) -> B x M x M;
        # mem_unchanged scales each existing row by the probability it is not the write target.
        mem_changed = torch.bmm(arg1, arg2)
        mem_unchanged = mem_orig * (1-arg1).expand(-1, -1, self.M)
        mem_write = mem_changed + mem_unchanged
        
        # Take a weighted sum over the new memory and old memory
        memory = mem_orig * (1 - write_probability) + mem_write * write_probability

        return memory
        
    def getStop(self, e):
        """
        Obtain the probability that we will stop at this timestep based on the probability that we are running the STOP op.
        
        :param e: Distribution over the current instruction, as transposed in forward (Bx1xN)
        
        :return: Probability of the STOP op for the first batch element only, taken from `.data`
                 (so it is not part of the autograd graph).
        """
        return e[:, :, self.stop_index].data[0]


In [32]:
def one_hotify(vec, length, dimension):
    """
    Turn a 1-D tensor of integers into a matrix of one-hot vectors.
    
    :param vec: The 1-D integer tensor to be converted.
    :param length: Size of each one-hot vector (the other matrix dimension is the length of vec)
    :param dimension: Which dimension stores the elements of vec.  If 0, they're stored in the rows
                      (result is len(vec) x length).  If 1, the columns (result is length x len(vec)).
    
    :return: A matrix of one-hot vectors, each row or column corresponding to one element of vec
    
    :raises ValueError: If dimension is neither 0 nor 1 (previously this silently returned None)
    """
    if dimension not in (0, 1):
        raise ValueError("dimension must be 0 (rows) or 1 (columns), got {!r}".format(dimension))
    
    count = vec.size()[0]
    
    # Index tensors must be integral; convert explicitly so IntTensor inputs index reliably.
    positions = torch.arange(count).long()
    hot_indices = vec.long()
    
    if dimension == 0:
        binary_vec = torch.zeros(count, length)
        binary_vec[positions, hot_indices] = 1
    else:
        binary_vec = torch.zeros(length, count)
        binary_vec[hot_indices, positions] = 1
    return binary_vec
        
        
    
    
    
    
    

In [33]:
# Addition task
# Generate this by running the instructions here (but with the addition program file): https://github.com/aditya-khant/neural-assembly-compiler
# Then get rid of the .cuda in each of the tensors since we (or at least I) don't have cuda

# Compiled program as integer indices; each vector is expanded to one-hot form below.
init_registers = torch.IntTensor([6,2,0,1,0,0])    # R initial register values  -> R x M
first_arg = torch.IntTensor([4,3,3,3,4,2,2,5])     # M first-argument registers  -> R x M
second_arg = torch.IntTensor([5,5,0,5,5,1,4,5])    # M second-argument registers -> R x M
target = torch.IntTensor([4,3,5,3,4,5,5,5])        # M output registers          -> R x M
instruction = torch.IntTensor([8,8,10,5,2,10,9,0]) # M instruction indices       -> N x M

# Get dimensions we'll need
M = first_arg.size(0)       # program / memory length
R = init_registers.size(0)  # number of registers
N = 11                      # number of distinct instructions

# Turn the given tensors into matrices of one-hot vectors and add a fake leading batch dimension.
# Registers keep one distribution per row; the program matrices keep one per column.
init_registers = one_hotify(init_registers, M, 0).unsqueeze(0)
first_arg = one_hotify(first_arg, R, 1).unsqueeze(0)
second_arg = one_hotify(second_arg, R, 1).unsqueeze(0)
target = one_hotify(target, R, 1).unsqueeze(0)
instruction = one_hotify(instruction, N, 1).unsqueeze(0)

# An example starting memory.  In this case, we're adding 3+4 and expect 7.
initial_memory = one_hotify(torch.IntTensor([3,4,0,0,0,0,0,0]), M, 0)
output_memory = torch.zeros(M, M)
output_memory[0, 7] = 1

# Output mask has ones in the rows of the memory matrix where the answer will be stored.
output_mask = torch.zeros(M, M)
output_mask[0] = torch.ones(M)







In [35]:
class ListDataset(data.Dataset):
    """Dataset over a hand-written list of (initial memory, output memory, output mask) examples."""

    def __init__(self):
        # Manually add lists of our initial memory.
        # The three tensors are read from the notebook's module-level variables.
        example = (initial_memory, output_memory, output_mask)
        self.input_list = [example]
       
    def __len__(self):
        """Number of stored examples."""
        count = len(self.input_list)
        return count
    
    def __getitem__(self, i):
        """Return the i-th example paired with a 1x1 zero tensor as its label."""
        sample = self.input_list[i]
        label = torch.zeros(1, 1)
        return sample, label

In [36]:
# Wrap the single hand-written example in a DataLoader.
dataset = ListDataset()
data_loader = data.DataLoader(dataset, batch_size = 1) # Don't change this batch size.  You have been warned.

# TODO: Make this prettier!
def anc_loss(loss, label):
    """
    Identity criterion: return the first argument unchanged and ignore the label.

    Controller.forward already returns a loss when run with train=True, so there is
    nothing left to compute here.
    """
    return loss

# Initialize our controller
controller = Controller(first_arg = first_arg, 
                        second_arg = second_arg, 
                        output = target, 
                        instruction = instruction, 
                        initial_registers = init_registers, 
                        stop_threshold = .9, 
                        blur = 3, 
                        correctness_weight = .3, 
                        halting_weight = .3, 
                        confidence_weight = .3, 
                        efficiency_weight = .1, 
                        t_max = 10) 

# Learning rate is a tunable hyperparameter
# The paper didn't mention which one they used
# NOTE(review): the recorded output below logs "LR is set to 0.001" (later 0.0001), not the 0.01
# passed here; presumably train_model_anc applies its own schedule - confirm.
optimizer = optim.Adam(controller.parameters(), lr = 0.01)

# TODO: Choose better initial values by checking what the paper did.
# TODO: Think about whether there's a better validation criterion than just the loss.
# NOTE(review): batch_size=10 here differs from the DataLoader's batch_size of 1 above;
# how train_model_anc uses this argument is not visible in this notebook - confirm.
best_model, train_plot_losses, validation_plot_losses = training.train_model_anc(
    controller, 
    data_loader, 
    anc_loss, 
    optimizer, 
    print_every=1, 
    num_epochs=10, 
    deep_copy_desired=False, 
    validation_criterion=anc_loss, 
    forward_train=True, 
    batch_size=10)

Epoch 0/9
----------
LR is set to 0.001
forward
Epoch Number: 0, Batch Number: 1, Training Loss: 0.0000
Time so far is 0m 0s

Epoch 1/9
----------
forward
Epoch Number: 1, Batch Number: 1, Training Loss: 0.0000
Time so far is 0m 0s

Epoch 2/9
----------
forward
Epoch Number: 2, Batch Number: 1, Training Loss: 0.0000
Time so far is 0m 0s

Epoch 3/9
----------
forward
Epoch Number: 3, Batch Number: 1, Training Loss: 0.0000
Time so far is 0m 0s

Epoch 4/9
----------
forward
Epoch Number: 4, Batch Number: 1, Training Loss: 0.0000
Time so far is 0m 0s

Epoch 5/9
----------
forward
Epoch Number: 5, Batch Number: 1, Training Loss: 0.0000
Time so far is 0m 0s

Epoch 6/9
----------
forward
Epoch Number: 6, Batch Number: 1, Training Loss: 0.0000
Time so far is 0m 0s

Epoch 7/9
----------
LR is set to 0.0001
forward
Epoch Number: 7, Batch Number: 1, Training Loss: 0.0000
Time so far is 0m 0s

Epoch 8/9
----------
forward
Epoch Number: 8, Batch Number: 1, Training Loss: 0.0000
Time so far is 0m 0s