In [1]:
!pip install matplotlib
import torch
import torch.nn.functional as F
import math
import torch.nn
import matplotlib.pyplot as plt
import numpy as np



  from .autonotebook import tqdm as notebook_tqdm


In [391]:
def plot_tensor(tensor, cmap='viridis'):
    """Show `tensor` as a colour-mapped image in a new matplotlib figure.

    Args:
        tensor: Array-like handed directly to ``Axes.imshow``.
        cmap: Name of the matplotlib colormap. Defaults to 'viridis'; other
            options include 'plasma', 'inferno' and 'magma'.
    """
    # Only the axes handle is needed; the figure and image handles were unused.
    _, ax = plt.subplots()
    ax.imshow(tensor, cmap=cmap)
    plt.show()

class Register(object):
    """A named block of `size` consecutive columns in the embedding vector.

    `offset` starts out as None; Embedding.__init__ overwrites it with the
    register's first column index.
    """

    def __init__(self, name, size):
        self.name, self.size = name, size
        # Unknown until the register is handed to an Embedding.
        self.offset = None

class Embedding(object):
    """Character vocabulary plus a bank of scratch registers.

    One embedded row is laid out as ``[token one-hot | registers, in order]``.
    Constructing an Embedding writes ``offset`` on every register passed in.
    """

    def __init__(self, tokens: list[str], registers: list[Register]):
        """Assign column offsets to `registers` and compute the total width.

        Args:
            tokens: Vocabulary; a token's index is its one-hot column.
            registers: Registers appended after the token columns. The first
                one must be named 'pos'. ``embed`` writes two values into it,
                so it needs a size of at least 2.

        Raises:
            Exception: If `registers` is empty or does not start with 'pos'.
        """
        self.tokens = tokens
        self.token_map = {t: i for i, t in enumerate(tokens)}
        self.registers = registers
        self.register_map = {}
        self.register_size = 0

        if len(registers) == 0 or registers[0].name != 'pos':
            raise Exception("First register must be 'pos'")

        # Registers are packed back to back, right after the token columns.
        offset = len(tokens)
        for reg in registers:
            reg.offset = offset
            offset += reg.size
            self.register_size += reg.size
            self.register_map[reg.name] = reg

        self.dim = len(tokens) + self.register_size

    def tokenize(self, string: str):
        """Return a float one-hot matrix of shape (len(string), len(tokens)).

        An empty string yields a (0, len(tokens)) tensor. A character that is
        not in the vocabulary raises KeyError.
        """
        # Explicit long dtype: torch.tensor([]) defaults to float, which
        # one_hot rejects, so an empty string used to crash here.
        indices = torch.tensor([self.token_map[c] for c in string], dtype=torch.long)
        return F.one_hot(indices, num_classes=len(self.tokens)).float()

    def embed(self, sequence):
        """Append the register columns to a one-hot `sequence`.

        The row-by-row indexing below requires a 2-D (length, vocabulary)
        input. All registers start at zero except 'pos'.
        """
        # Extra zeroed space for the registers.
        extension_tensor = torch.zeros(*sequence.shape[:-1], self.register_size)

        # The first two register columns belong to 'pos'. They hold the row
        # index as a (sin, cos) pair on a circle of period 100.
        for i in range(sequence.shape[0]):
            angle = i*(2*math.pi)/100
            extension_tensor[i, 0] = math.sin(angle)
            extension_tensor[i, 1] = math.cos(angle)

        return torch.cat((sequence, extension_tensor), dim=-1)

    def predict(self, sequence):
        """Return the vocabulary token with the highest score in the last row."""
        # Only the token columns are candidates. An argmax over the full row
        # could land on a register column and index past the end of `tokens`.
        token_scores = sequence[-1, :len(self.tokens)]
        return self.tokens[int(torch.argmax(token_scores))]

class AttentionLayer(object):
    """Single-head causal self-attention with a residual connection.

    The projection matrices are read from ``instruction.query``,
    ``instruction.key`` and ``instruction.value``.
    """

    def __init__(self, instruction):
        self.instruction = instruction

    def attend(self, seq):
        """Return `seq` plus its causally masked attention output."""
        weights = self.instruction
        length, width = seq.shape[0], seq.shape[-1]

        q = seq @ weights.query
        k = seq @ weights.key
        v = seq @ weights.value

        # Entries strictly above the diagonal (future rows) get a -1e10
        # penalty, so softmax gives them effectively zero weight.
        future_penalty = torch.triu(torch.ones(length, length), diagonal=1)*-1e10

        # Scores are scaled by sqrt of the embedding width before masking.
        scores = q @ k.T / math.sqrt(width) + future_penalty
        attention = torch.softmax(scores, dim=1)

        return seq + attention @ v
    
class MLPLayer(object):
    """Two-layer GELU perceptron applied with a residual connection.

    Weights and biases are read from ``instruction.first_weights``,
    ``first_bias``, ``second_weights`` and ``second_bias``. With
    ``debug=True`` every intermediate tensor is plotted and printed.
    """

    def __init__(self, instruction, debug=False):
        self.instruction = instruction
        self.debug = debug

    def _trace(self, label, tensor):
        """Plot and print an intermediate tensor when debugging is on."""
        if self.debug:
            plot_tensor(tensor)
            print(label, tensor)

    def forward(self, seq):
        """Return ``seq + (gelu(seq @ W1 + b1) @ W2 + b2)``."""
        params = self.instruction

        hidden = F.gelu(seq @ params.first_weights + params.first_bias)
        self._trace('AAA', hidden)

        projected = hidden @ params.second_weights
        self._trace('BBB', projected)

        update = projected + params.second_bias
        self._trace('XXX', update)

        return seq + update
    
# Vocabulary: the ten digits plus '+', '=', space and newline.
tokens = list('0123456789+= \n')

# Position-style registers, two columns each. Embedding.embed fills 'pos'
# with a (sin, cos) pair; the other three start at zero.
pos, left_pos, right_pos, out_pos = (
    Register(name, 2) for name in ('pos', 'left_pos', 'right_pos', 'out_pos')
)

# Token-sized registers: one column per vocabulary entry.
left_digit, right_digit, out_digit, final_digit = (
    Register(name, len(tokens)) for name in ('left', 'right', 'out', 'final')
)

# Single-column scalar registers.
carry = Register('carry', 1)
distance = Register('distance', 1)

# 'pos' has to come first (Embedding enforces it). The order of the rest
# determines each register's column offset.
embedding = Embedding(
    tokens,
    [
        pos, left_pos, right_pos, out_pos,
        left_digit, right_digit, out_digit,
        carry, distance, final_digit,
    ],
)
        
class FindAndStore(object):
    """Attention weights that copy the 'pos' columns of attended rows into `register`.

    The query matrix is +1e10 in the column of `token` and -1e10 everywhere
    else, which is meant to make every row attend to rows holding `token`
    whatever the current row contains. The value matrix moves the two 'pos'
    components into the first two columns of `register`.
    """

    def __init__(self, embedding: Embedding, token: str, register: Register):
        dim = embedding.dim
        src = embedding.register_map['pos'].offset
        dst = register.offset

        # Query: strongly favour the column of the wanted token.
        looking_for = torch.full((dim, dim), -1e10)
        looking_for[:, int(embedding.token_map[token])] = 1e10

        # Value: route both position components to the destination register.
        move_position = torch.zeros(dim, dim)
        for k in range(2):
            move_position[src + k, dst + k] = 1.0

        self.key = torch.eye(dim)
        self.query = looking_for
        self.value = move_position

# Embeds a sample prompt. `ex` is not referenced again in the visible cells;
# presumably left over from interactive debugging (TODO: confirm before removing).
ex = embedding.embed(embedding.tokenize('10+10=2111')) 
        
class GetRelativeToken(object):
    """Attention weights that read the token a position register points at.

    Query: the (sin, cos) pair stored in `pos_reg`, rotated by `-steps`
    positions on the period-100 circle and written into the 'pos' columns.
    Key: each row's own 'pos' columns scaled by 1e10.
    Value: the token one-hot columns copied into the `out` register.
    """

    def __init__(self, embedding: Embedding, pos_reg: Register, steps: int, out: Register):
        dim = embedding.dim
        vocab = len(embedding.tokens)
        cursor = pos_reg.offset
        here = embedding.register_map['pos'].offset

        # Key: expose each row's position with a very large gain.
        sharpen = torch.zeros(dim, dim)
        for k in (0, 1):
            sharpen[here + k, here + k] = 1e10

        # Query: 2x2 rotation from the cursor register into the 'pos' columns.
        angle = (-steps)*(2*math.pi)/100
        s, c = math.sin(angle), math.cos(angle)
        turn = torch.zeros(dim, dim)
        turn[cursor, here] = c
        turn[cursor + 1, here] = -s
        turn[cursor, here + 1] = s
        turn[cursor + 1, here + 1] = c

        # Value: shift the token one-hot block so it lands in `out`.
        carry_token = torch.zeros(dim, dim)
        for t in range(vocab):
            carry_token[t, out.offset + t] = 1.0

        self.query = turn
        self.key = sharpen
        self.value = carry_token
        
class Multiply(object):
    """MLP weights forming a lookup table for one digit-plus-digit step with carry.

    Despite the class name, the table encodes addition: the output slot is
    ``(i + j) % 10``, or ``(i + j + 1) % 10`` when the carry is set. Every
    digit pair (i, j) owns two adjacent hidden units:

    * even index: pre-activation is 100 when left=i, right=j and carry=0;
    * odd index: pre-activation is 99 when left=i, right=j and carry=1.

    The second layer rescales those activations to about 1, writes the result
    digit into `out_token` and adjusts the carry column.
    """

    def __init__(self, embedding: Embedding, left_token: Register, right_token: Register, carry: Register, out_token: Register):
        dim = embedding.dim
        width = 10*10*2

        w1 = torch.zeros(dim, width)
        b1 = torch.full((width,), -900.0)
        w2 = torch.zeros(width, dim)

        for pair in range(10*10):
            i, j = divmod(pair, 10)
            total = i + j
            plain, carried = 2*pair, 2*pair + 1

            # No incoming carry: 500 + 500 - 900 = 100; a set carry suppresses it.
            w1[left_token.offset + i, plain] = 500
            w1[right_token.offset + j, plain] = 500
            w1[carry.offset, plain] = -1000
            w2[plain, out_token.offset + total % 10] = 0.01
            if total >= 10:
                # Raise the carry for the next digit.
                w2[plain, carry.offset] = 0.01

            # Incoming carry: 3 * 333 - 900 = 99, hence the 1/0.99 rescale.
            w1[left_token.offset + i, carried] = 333
            w1[right_token.offset + j, carried] = 333
            w1[carry.offset, carried] = 333
            w2[carried, out_token.offset + (total + 1) % 10] = 0.01 * (1.0/0.99)
            # If this sum overflows, the carry is already set, so leave it alone.
            # Otherwise clear the carry.
            w2[carried, carry.offset] = 0.0 if (total + 1) >= 10 else -0.01 * (1.0/0.99)

        self.first_weights = w1
        self.first_bias = b1

        self.second_weights = w2
        self.second_bias = torch.zeros(dim)
                
class Clear(object):
    """MLP weights that subtract the GELU-gated value of the given registers.

    Each listed register column gets +100 on the first-layer diagonal and
    -0.01 on the second-layer diagonal. For a non-negative input x the update
    is roughly ``-0.01 * gelu(100 * x)``, i.e. about ``-x``.
    """

    def __init__(self, embedding: Embedding, registers: list[Register]):
        dim = embedding.dim

        def diagonal(value):
            # Square matrix carrying `value` on the diagonal of every register column.
            matrix = torch.zeros(dim, dim)
            for reg in registers:
                for k in range(reg.size):
                    matrix[reg.offset + k, reg.offset + k] = value
            return matrix

        self.first_weights = diagonal(100.0)
        self.first_bias = torch.zeros(dim)

        self.second_weights = diagonal(-0.01)
        self.second_bias = torch.zeros(dim)
                
class DiffPos(object):
    """MLP weights that add a position-difference measure to `distance`.

    Hidden unit 0 sees ``100*(lx - rx) + 1000*(ly - ry)`` and unit 1 sees the
    negation, where (lx, ly) and (rx, ry) are the two columns of `left_pos`
    and `right_pos`. Both GELU outputs are added into `distance`.
    """

    def __init__(self, embedding: Embedding, left_pos: Register, right_pos: Register, distance: Register):
        dim = embedding.dim

        # The two components must use different scales, otherwise (1, 2)
        # would compute to the same value as (2, 1).
        scales = (1e2, 1e3)

        w1 = torch.zeros(dim, 2)
        for unit, sign in ((0, 1.0), (1, -1.0)):
            for reg, direction in ((left_pos, sign), (right_pos, -sign)):
                for k, scale in enumerate(scales):
                    w1[reg.offset + k, unit] = direction * scale

        # Both hidden units feed the distance column with unit weight.
        w2 = torch.zeros(2, dim)
        w2[:, distance.offset] = 1.0

        self.first_weights = w1
        self.first_bias = torch.zeros(2)

        self.second_weights = w2
        self.second_bias = torch.zeros(dim)
        
class IsZero(object):
    """MLP weights that replace a non-negative register value with a zero flag.

    With `d` the value in the single column of `zero` (assumed d >= 0):

    * unit 0 computes ``gelu(10 - 100*d)``, which is 10 at d == 0 and falls
      to about 0 once d is a few tenths; scaled by 0.1 this contributes 1 or 0;
    * unit 1 computes ``gelu(d + 10)``, about ``d + 10``; with weight -1 and
      the +10 output bias it cancels the residual `d`.

    After the residual add, the column holds about 1 when d == 0 and about 0
    for larger d.
    """

    def __init__(self, embedding: Embedding, zero: Register):
        self.first_weights = torch.zeros(embedding.dim, 2)
        self.first_weights[zero.offset, 0] = -100.0
        self.first_weights[zero.offset, 1] = 1.0

        # Both hidden units need the +10 bias. This used to be a length-1
        # tensor that reached the second unit only through broadcasting.
        # Spelling it out gives the same forward result and makes the bias
        # width match the hidden width, as in the other instruction classes.
        self.first_bias = torch.full((2,), 10.0)

        self.second_weights = torch.zeros(2, embedding.dim)
        self.second_weights[0, zero.offset] = 0.1
        self.second_weights[1, zero.offset] = -1.0

        self.second_bias = torch.zeros(embedding.dim)
        self.second_bias[zero.offset] = 10.0
        
class CopyOnMatch(object):
    """MLP weights that add `src` into `dst` only where `match` is set.

    Hidden unit i sees ``100*match + 100*src[i] - 190``, which is positive
    (10) only when both inputs are 1. The second layer scales that by 0.1 and
    adds it to column i of `dst`.
    """

    def __init__(self, embedding: Embedding, match: Register, src: Register, dst: Register):
        dim, width = embedding.dim, src.size

        w1 = torch.zeros(dim, width)
        b1 = torch.full((width,), -190.0)
        w2 = torch.zeros(width, dim)

        for unit in range(width):
            # AND gate: the match flag and the source slot must both be on.
            w1[match.offset, unit] = 100
            w1[src.offset + unit, unit] = 100
            # Route the gated activation to the matching destination slot.
            w2[unit, dst.offset + unit] = 0.1

        self.first_weights = w1
        self.first_bias = b1

        self.second_weights = w2
        self.second_bias = torch.zeros(dim)

class StepPosition(object):
    """MLP weights that move (sin, cos) position registers by fixed step counts.

    The first layer holds ``R - I`` for each register, where R is the 2x2
    rotation by ``offset`` positions on the period-100 circle. The +1000 and
    -1000 biases keep GELU in its near-linear range, so after the residual
    add each register holds roughly its old value rotated by R. Columns
    outside the listed registers receive a zero update.

    Args:
        embedding: Supplies ``dim``, the width of the square weight matrices.
        positions: Registers to move; each uses two columns from its offset.
        offsets: Integer step count for each register in `positions`, in the
            same order. Positive values advance the encoded position,
            negative values move it back.

    Raises:
        ValueError: If `positions` and `offsets` differ in length.
    """

    def __init__(self, embedding: Embedding, positions: list[Register], offsets: list[int]):
        if len(positions) != len(offsets):
            raise ValueError(
                f"positions and offsets must have the same length, got {len(positions)} and {len(offsets)}"
            )

        rotation = torch.zeros(embedding.dim, embedding.dim)

        for reg, step in zip(positions, offsets):
            angle = -step*(2*math.pi)/100
            sin = math.sin(angle)
            cos = math.cos(angle)

            # Store R - I; the layer's residual connection adds the identity back.
            rotation[reg.offset, reg.offset] = cos - 1
            rotation[reg.offset + 1, reg.offset] = -sin
            rotation[reg.offset, reg.offset + 1] = sin
            rotation[reg.offset + 1, reg.offset + 1] = cos - 1

        self.first_weights = rotation
        # A large positive bias keeps the GELU approximately linear.
        self.first_bias = torch.ones(embedding.dim)*1000

        # The second layer passes everything through and removes that bias.
        self.second_weights = torch.eye(embedding.dim, embedding.dim)
        self.second_bias = torch.ones(embedding.dim)*-1000

In [423]:
import tqdm

# Attention layers that look for a marker token and store its position:
# '+' goes into left_pos, '=' goes into both right_pos and out_pos.
find_plus = AttentionLayer(instruction=FindAndStore(embedding=embedding, token='+', register=left_pos))
find_equal = AttentionLayer(instruction=FindAndStore(embedding=embedding, token='=', register=right_pos))
find_out = AttentionLayer(instruction=FindAndStore(embedding=embedding, token='=', register=out_pos))

# Fully connected layers that move position registers by fixed step counts.
step_out = MLPLayer(instruction=StepPosition(embedding=embedding, positions=[out_pos], offsets=[-1]))
step_pos = MLPLayer(
    instruction=StepPosition(
        embedding=embedding,
        positions=[left_pos, right_pos, out_pos],
        offsets=[-1, -1, 1],
    )
)

# Attention layers that read the token at the position held in a register.
read_left = AttentionLayer(
    instruction=GetRelativeToken(embedding=embedding, pos_reg=left_pos, steps=0, out=left_digit)
)
read_right = AttentionLayer(
    instruction=GetRelativeToken(embedding=embedding, pos_reg=right_pos, steps=0, out=right_digit)
)

# Lookup layer that combines the two digits and the carry symbolically.
multiply = MLPLayer(
    instruction=Multiply(
        embedding=embedding,
        left_token=left_digit,
        right_token=right_digit,
        carry=carry,
        out_token=out_digit,
    )
)

# Utility layer that clears the digit registers in the residual stream.
clear_lro = MLPLayer(instruction=Clear(embedding=embedding, registers=[left_digit, right_digit, out_digit]))

# Layers that track whether the out cursor matches the current row and,
# if so, copy the out digit into the final register.
diff_pos = MLPLayer(instruction=DiffPos(embedding=embedding, left_pos=pos, right_pos=out_pos, distance=distance))
is_zero = MLPLayer(instruction=IsZero(embedding=embedding, zero=distance))
copy_out = MLPLayer(instruction=CopyOnMatch(embedding=embedding, match=distance, src=out_digit, dst=final_digit))

# Linear read-out: final register slot k maps to column k (the token columns);
# every other column is zeroed.
final_projection = torch.zeros(embedding.dim, embedding.dim)
final_projection[
    final_digit.offset:final_digit.offset + final_digit.size,
    :final_digit.size,
] = torch.eye(final_digit.size)

def read_out(x):
    """Debug helper: print the argmax slot of the `out` register in the last row."""
    start = out_digit.offset
    digit_slots = x[-1, start:start + out_digit.size]
    print('Out:', torch.argmax(digit_slots))

def generate(input_string):
    """Run the hand-built network once and append the predicted character.

    Args:
        input_string: Prompt made only of vocabulary characters, e.g. '0999+0111='.

    Returns:
        `input_string` with one predicted token appended.
    """
    # Embed the prompt: one-hot tokens plus zeroed registers and positions.
    x = embedding.embed(embedding.tokenize(input_string))

    # Find the '+' and '=' markers and store their positions in the cursor
    # registers, then move the output cursor by -1.
    for locate in (find_plus, find_equal, find_out):
        x = locate.attend(x)
    x = step_out.forward(x)

    # One digit step: reset the scratch registers, move the cursors, read both
    # operand digits, combine them, and copy the result into the final
    # register on rows where the out cursor matches the row position.
    digit_step = (
        clear_lro.forward,
        step_pos.forward,
        read_left.attend,
        read_right.attend,
        multiply.forward,
        diff_pos.forward,
        is_zero.forward,
        copy_out.forward,
    )
    for _ in range(4):
        for apply in digit_step:
            x = apply(x)

    # Move the final register into the token columns and pick the best token.
    logits = x @ final_projection
    return input_string + embedding.predict(logits)

# Demo: extend one prompt by ten characters, printing after each step.
test = '0999+0111='
for i in range(10):
    test = generate(test)
    print(test)

# Exhaustive check over every pair of operands below 1000. Operands are
# zero-padded to four digits and the expected answer is the four-digit sum
# written least-significant digit first. The recorded run took about 3.5 hours.
to_check = [divmod(k, 1000) for k in range(1000 * 1000)]
passed = 0

for i, j in tqdm.tqdm(to_check):
    test = f'{i:04d}+{j:04d}='
    expected = test + f'{i + j:04d}'[::-1]

    # Four generated characters make up the full answer.
    for n in range(4):
        test = generate(test)

    if test == expected:
        passed += 1
        continue
    print("Failed!", test, expected)

print("All done! Passed count: ", passed, "of", len(to_check))


0999+0111=0
0999+0111=01
0999+0111=011
0999+0111=0111
0999+0111=01110
0999+0111=011100
0999+0111=0111000
0999+0111=01110000
0999+0111=011100000
0999+0111=0111000000


100%|██████████████████████████████████████████████████████████████████████████| 1000000/1000000 [3:33:21<00:00, 78.12it/s]

All done! Passed count:  1000000 of 1000000



