In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [None]:
batch_size_train = 40
batch_size_test = 40
learning_rate = 0.01
log_interval = 10

num_vectors = 4
len_vectors = 10
img_height = 28
img_width = 28
win_size = 3
epsilon = .7
epochs = 4000
steps = 20

## Data Functions

In [None]:
def build_loader(train_val, batch_size, root='/files/'):
    """Create a shuffled DataLoader over normalized MNIST.

    Args:
        train_val: True for the training split, False for the test split
            (forwarded to ``torchvision.datasets.MNIST(train=...)``).
        batch_size: number of examples per batch.
        root: directory where MNIST is stored / downloaded. Defaults to the
            previously hard-coded ``'/files/'``; pass a writable (e.g.
            relative) directory when the filesystem root is not writable.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.
    """
    # (0.1307,), (0.3081,) are the commonly used MNIST mean / std.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset = torchvision.datasets.MNIST(
        root,
        train=train_val,
        download=True,
        transform=transform,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )

In [None]:
# Python booleans are capitalized; the lowercase `true` / `false` used here
# before are undefined names and raised NameError.
train_loader = build_loader(True, batch_size_train)

test_loader = build_loader(False, batch_size_test)

In [None]:
def data_to_state(example_data, batch_size):
    """Broadcast a batch of images into the layout of the bottom state level.

    Each pixel value is copied into every entry of a length-``len_vectors``
    vector, so the result can be written into ``state[:, :, :, 0, :]``.

    Args:
        example_data: image batch containing
            ``batch_size * img_height * img_width`` elements (presumably the
            ``(batch_size, 1, 28, 28)`` batches the MNIST loader yields).
        batch_size: number of images in the batch.

    Returns:
        Tensor of shape ``(batch_size, img_height, img_width, len_vectors)``
        on ``device``.
    """
    images = torch.reshape(example_data, (batch_size, img_height, img_width))
    # Repeat along a new trailing axis. The count was previously hard-coded to
    # 10, which only matched the state layout because len_vectors == 10.
    return images.unsqueeze(-1).repeat(1, 1, 1, len_vectors).to(device)

In [None]:
def targets_to_state(example_targets, batch_size):
    """Turn class labels into a per-pixel one-hot target grid.

    Args:
        example_targets: integer class labels of shape ``(batch_size,)``.
        batch_size: number of labels in the batch.

    Returns:
        Float tensor of shape ``(batch_size, img_height, img_width, 10)`` on
        ``device`` in which every pixel holds the one-hot vector of its
        example's label.
    """
    num_classes = 10  # MNIST digit classes
    one_hot = torch.nn.functional.one_hot(example_targets, num_classes=num_classes)
    # one_hot is (batch, 10). Tiling it H*W times along dim 1 and viewing the
    # result as (batch, H, W, 10) gives every pixel its own copy. The chain is
    # wrapped in parentheses: a bare line break before `.method()` is a
    # syntax error in Python.
    return (
        one_hot
        .repeat(1, img_height * img_width)
        .view((batch_size, img_height, img_width, num_classes))
        .float()
        .to(device)
    )

## CA Functions

In [None]:
def init_state(batch_size, img_height, img_width, num_vectors, len_vectors):
    """Return a freshly initialized CA state on ``device``.

    Entries are drawn uniformly from [0, 1) and scaled by 0.1, giving a
    tensor of shape
    ``(batch_size, img_height, img_width, num_vectors, len_vectors)``.
    """
    shape = (batch_size, img_height, img_width, num_vectors, len_vectors)
    noise = torch.rand(shape)
    scaled = noise * .1
    return scaled.to(device)

In [None]:
class model(nn.Module):
    """Three-stage tanh MLP with dense skip connections, used for every update rule.

    Maps ``num_inp`` features to ``num_out`` features. Dropout (p=0.1) is
    applied to the input of each stage, and the output is scaled by 0.01.
    """

    def __init__(self, num_inp, num_out):
        super().__init__()
        # Linear stages; attribute names are state_dict keys, so keep them.
        self.Q1 = nn.Linear(num_inp, num_out)
        self.K1 = nn.Linear(num_out, num_out)
        self.V1 = nn.Linear(num_out, num_out)

        # Dropout applied before every stage.
        self.m = nn.Dropout(p = 0.1)

        # Activations; only act1 (tanh) is used by forward().
        self.act = nn.LeakyReLU()
        self.act1 = nn.Tanh()
        self.act3 = nn.GELU()

    def forward(self, x):
        """Run the three stages; each later stage adds the earlier outputs."""
        drop, squash = self.m, self.act1
        first = squash(self.Q1(drop(x)))
        second = squash(self.K1(drop(first))) + first
        third = squash(self.V1(drop(second))) + first + second
        # Small scale, presumably to keep each per-step state update small.
        return third * .01

In [None]:
def get_layer_attention(center_matrix, roll_matrix, after_tri):
    """Scale each vector of ``center_matrix`` by a mixed similarity score.

    Args:
        center_matrix: 5-D tensor ``(batch, H, W, num_vectors, len_vectors)``
            whose vectors are scaled and returned.
        roll_matrix: tensor of the same shape; each vector of
            ``center_matrix`` is dotted with the vector at the same position
            and level here.
        after_tri: ``(num_vectors, num_vectors)`` matrix that mixes the
            per-level similarity scores.

    Returns:
        ``center_matrix * weights``, where ``weights`` has shape
        ``(batch, H, W, num_vectors, 1)`` and is detached from autograd.
    """
    # Per-level dot product. This equals the diagonal of
    # center @ roll^T over the last two axes, without materializing the full
    # num_vectors x num_vectors product just to discard its off-diagonal.
    after_diag = (center_matrix * roll_matrix).sum(dim=-1)

    # Mix the per-level similarities through the decay matrix.
    after_eps = torch.matmul(after_diag, after_tri)

    # Broadcast one weight per vector across its entries. The weights are
    # detached, so gradients flow only through center_matrix.
    return center_matrix * after_eps.unsqueeze(-1).detach()

In [None]:
# The 9 (dy, dx) offsets of a 3x3 neighbourhood; index 4 is the unshifted (0, 0).
_NEIGHBOR_SHIFTS = (
    (-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 0),
    (0, 1), (1, -1), (1, 0), (1, 1),
)


def _decay_matrix(num_vectors):
    """Return the ``(num_vectors, num_vectors)`` level-mixing matrix on ``device``.

    Entry ``[a, b]`` is ``epsilon ** (a - b + 1)`` when ``a >= b`` and 0
    otherwise. A row vector of per-level scores multiplied by it therefore
    gives level ``b`` a geometrically decayed sum of the scores of levels
    ``a >= b``.
    """
    eps_powers = epsilon ** torch.arange(start = 1, end = num_vectors + 1)
    shifted = torch.stack([
        torch.roll(eps_powers, shifts=(i), dims = (0))
        for i in range(eps_powers.shape[0])
    ])
    return torch.triu(shifted, diagonal = 0).T.to(device)


def _flat_level(att, level, len_vectors):
    """Flatten one level of ``att`` to ``(pixels, 9 * len_vectors)`` model input."""
    return torch.reshape(
        att[:, :, :, level, :],
        (-1, len_vectors * len(_NEIGHBOR_SHIFTS))
    )


def compute_all(
    bottom_up_model_list, top_down_model_list, layer_att_model_list,
    state, len_vectors, num_vectors, batch_size
):
    """Compute the additive update ``delta`` for one CA step.

    The state is rolled over its height/width axes for each of the 9
    neighbourhood shifts (``torch.roll`` wraps around at the borders). Each
    rolled copy is attention-weighted against the unshifted state, and the
    9 results are concatenated per level. For ``i < num_vectors - 1``,
    level ``i + 1`` then receives:

    * bottom-up model ``i`` applied to level ``i``,
    * lateral model ``layer_att_model_list[i]`` applied to level ``i + 1``,
    * top-down model ``i`` applied to level ``i + 2`` (only when
      ``i < num_vectors - 2``).

    Level 0 never receives an update.

    Returns:
        Tensor ``(batch_size, img_height, img_width, num_vectors, len_vectors)``.
    """
    after_tri = _decay_matrix(num_vectors)

    # Previously `roll_list` was a `map` iterator: the loop exhausted it, and
    # `torch.cat(roll_list, ...)` then failed and discarded the attention
    # results. The reference tensor was the undefined `roll5`. It is assumed
    # to mean the 5th roll, i.e. the unshifted (0, 0) state. TODO confirm.
    att_list = [
        get_layer_attention(
            torch.roll(state, shifts=shift, dims=(1, 2)), state, after_tri
        )
        for shift in _NEIGHBOR_SHIFTS
    ]
    # (batch, H, W, num_vectors, 9 * len_vectors): every level now holds its
    # own weighted vector followed by those of its 8 neighbours.
    att = torch.cat(att_list, dim=4)

    n_pixels = batch_size * img_height * img_width
    delta = [
        torch.zeros((n_pixels, len_vectors), device=device)
        for _ in range(num_vectors)
    ]
    for i in range(num_vectors):
        if i < num_vectors - 2:
            top_down = top_down_model_list[i](_flat_level(att, i + 2, len_vectors))
            delta[i + 1] = delta[i + 1] + top_down
        if i < num_vectors - 1:
            bottom_up = bottom_up_model_list[i](_flat_level(att, i, len_vectors))
            lateral = layer_att_model_list[i](_flat_level(att, i + 1, len_vectors))
            delta[i + 1] = delta[i + 1] + bottom_up + lateral

    # (pixels, num_vectors, len_vectors) -> state layout, so it can be added.
    delta = torch.stack(delta, dim=1)
    return torch.reshape(
        delta,
        (batch_size, img_height, img_width, num_vectors, len_vectors)
    )

## Model Initializations

In [None]:
# One model per update rule and level (9 neighbours x len_vectors inputs).
# `.to(device)` replaces `.cuda()`, which crashed on CPU-only machines even
# though `device` already falls back to the CPU.
bottom_up_model_list = [model(9*len_vectors, len_vectors).to(device) for i in range(num_vectors-1)]
top_down_model_list = [model(9*len_vectors, len_vectors).to(device) for i in range(num_vectors-2)]
layer_att_model_list = [model(9*len_vectors, len_vectors).to(device) for i in range(num_vectors-1)]

## Optimizers

In [None]:
# Collect every trainable parameter so one Adam optimizer updates all models.
# Order per level: top-down (if present), then bottom-up, then lateral.
param_list = [
    p
    for i in range(num_vectors)
    for net in (
        ([top_down_model_list[i]] if i < num_vectors - 2 else [])
        + ([bottom_up_model_list[i], layer_att_model_list[i]] if i < num_vectors - 1 else [])
    )
    for p in net.parameters()
]

optimizer = torch.optim.Adam(param_list, lr=learning_rate)
mse = nn.MSELoss()

## Train

In [None]:
def _endless_batches(loader):
    """Yield ``(data, targets)`` batches from ``loader`` forever.

    The loader is re-iterated whenever it is exhausted (rather than using
    ``itertools.cycle``, which would replay cached batches), so
    ``shuffle=True`` still reshuffles on every pass.
    """
    while True:
        for batch in loader:
            yield batch


# Each "epoch" below is one optimizer step on a single batch, not a full
# pass over the training set.
# max(1, ...) keeps the `%` below from dividing by zero when steps < 4.
half_period = max(1, steps // 2)
quarter_period = max(1, steps // 4)

# Re-enable dropout in case the evaluation cell switched the models to eval().
for net in bottom_up_model_list + top_down_model_list + layer_att_model_list:
    net.train()

batches = _endless_batches(train_loader)
for epoch in range(epochs):
    optimizer.zero_grad()
    example_data, example_targets = next(batches)
    # Use the real batch size, so a short final batch cannot break the reshapes.
    cur_batch = example_data.shape[0]

    # initialize state and write the current images into the bottom level
    state = init_state(
        cur_batch, img_height, img_width, num_vectors, len_vectors
    )
    state[:, :, :, 0, :] = data_to_state(example_data, cur_batch)
    state1 = torch.clone(state)
    state2 = torch.clone(state)
    for step in range(steps):
        delta = compute_all(
            bottom_up_model_list, top_down_model_list, layer_att_model_list,
            state, len_vectors, num_vectors, cur_batch
        )
        # small uniform noise is added on every step
        state = state + delta + .0001 * torch.rand((state.shape)).to(device)

        # Every half_period / quarter_period steps (including step 0),
        # re-inject an earlier snapshot of the state. This is a ResNet-style
        # skip intended to aid gradient flow through the unrolled steps.
        if step % half_period == 0:
            state = state + state1 * .1
            state1 = torch.clone(state)
        if step % quarter_period == 0:
            state = state + state2 * .1
            state2 = torch.clone(state)

    state = state + state1 * .1
    # Loss: the top level of every pixel vs. the per-pixel one-hot target.
    pred_out = state[:, :, :, -1]
    targ_out = targets_to_state(example_targets, cur_batch)
    loss = mse(pred_out, targ_out)
    loss.backward()
    optimizer.step()
    # Log every `log_interval` iterations (and the last one) instead of every
    # iteration. .item() prints a plain float rather than the tensor repr.
    if epoch % log_interval == 0 or epoch == epochs - 1:
        print("Epoch: {}/{}  Loss: {}".format(epoch, epochs, loss.item()))

## Test

In [None]:
# Disable dropout: the models were left in training mode, so evaluation used
# random dropout masks.
for net in bottom_up_model_list + top_down_model_list + layer_att_model_list:
    net.eval()

tot_corr = 0
tot_batches = 0  # despite the name, counts evaluated examples
# No gradients are needed; without no_grad an autograd graph spanning all
# `steps` updates was built for every batch.
with torch.no_grad():
    for example_data, target in test_loader:
        cur_batch = example_data.shape[0]
        tot_batches += cur_batch

        # initialize state and write the current images into the bottom level
        state = init_state(
            cur_batch, img_height, img_width, num_vectors, len_vectors
        )
        state[:, :, :, 0, :] = data_to_state(example_data, cur_batch)
        for step in range(steps):
            delta = compute_all(
                bottom_up_model_list, top_down_model_list, layer_att_model_list,
                state, len_vectors, num_vectors, cur_batch
            )
            state = state + delta

        # Each pixel votes for the argmax of its top-level vector; the
        # prediction is the class with the most votes. torch.argmax returns
        # the first maximum, so ties go to the lowest class index, as in the
        # former per-pixel Python loop.
        pixel_votes = torch.argmax(state[:, :, :, -1], dim=-1)
        tallies = F.one_hot(
            pixel_votes.reshape(cur_batch, -1), num_classes=state.shape[-1]
        ).sum(dim=1)
        predictions = torch.argmax(tallies, dim=-1).cpu()
        tot_corr += (predictions == target).sum().item()
        print("Acc: {}".format(tot_corr/tot_batches))

print("Final Accuracy: {}".format(tot_corr/tot_batches))

In [None]:
def save_models(
    bottom_up_model_list, top_down_model_list, layer_att_model_list,
    path="best_models/"
):
    """Save every model to its own file under ``path``.

    Files are named ``top_down_model{i}``, ``bottom_up_model{i}`` and
    ``layer_att_model{i}``, where ``i`` is the model's index in its list.
    Whole module objects are pickled via ``torch.save``. Loading them needs
    the ``model`` class importable, and recent PyTorch versions need
    ``torch.load(..., weights_only=False)``.

    Args:
        bottom_up_model_list, top_down_model_list, layer_att_model_list:
            lists of modules to save; every element is saved.
        path: output directory, created if missing. Defaults to the previous
            hard-coded ``"best_models/"``.
    """
    import os

    # torch.save does not create missing directories; previously a missing
    # best_models/ made the first save fail.
    os.makedirs(path, exist_ok=True)
    groups = (
        ("top_down_model", top_down_model_list),
        ("bottom_up_model", bottom_up_model_list),
        ("layer_att_model", layer_att_model_list),
    )
    # Iterate the lists themselves instead of deriving counts from the global
    # num_vectors.
    for prefix, models in groups:
        for i, net in enumerate(models):
            torch.save(net, os.path.join(path, "{}{}".format(prefix, i)))

In [None]:
# Persist the trained models.
save_models(
    bottom_up_model_list=bottom_up_model_list,
    top_down_model_list=top_down_model_list,
    layer_att_model_list=layer_att_model_list,
)