In [115]:
import torch
import torch.nn as nn

In [117]:
# Create a simple recurrent neural network (RNN) with PyTorch
class SnaekRNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, num_layers):
        super(SnaekRNN, self).__init__()

        self.input_size = input_size  # Number of inputs to the RNN (3) 
        self.output_size = output_size # Number of outputs from the RNN (2) 
        self.hidden_size = hidden_size # Number of neurons in the hidden layer of the RNN. 

        self.rnn = nn.RNN(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=num_layers, batch_first=True)

        self.fc = nn.Linear(in_features=self.hidden_size, out_features=self.output_size)

    def forward(self, x): 
        out, _ = self.rnn(x)   # Pass in x to get out and _ from rnn layer 
        out = self.fc(out[-1]) # Pass the last output from rnn to fc layer
        return out           # Return result from fc layer

In [165]:
class SnakeFF(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, num_layers):
        super(SnakeFF, self).__init__()

        self.num_layers = num_layers
        self.input_size = input_size  # Number of inputs to the Feed Forward Network
        self.output_size = output_size # Number of outputs from the Feed Forward Network
        self.hidden_size = hidden_size # Number of neurons in the hidden layer of the Feed Forward Network.

        for i in range(num_layers):
            if i == 0:
                setattr(self, f'fc{i}', nn.Linear(in_features=self.input_size, out_features=self.hidden_size))
            else:
                setattr(self, f'fc{i}', nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size))
        
        setattr(self, f'fc{num_layers}', nn.Linear(in_features=self.hidden_size, out_features=self.output_size))

    def forward(self, x):
        layer_count = self.num_layers

        for i in range(layer_count):
            x = getattr(self, f'fc{i}')(x)
            x = nn.ReLU()(x)

        x = getattr(self, f'fc{layer_count}')(x)
        x = nn.Sigmoid()(x)

        return x

# Create training data

In [155]:
# inputs = [front(-1, 0, 1), left(-1, 0, 1), right(-1, 0, 1)]
# outputs = [left(0, 1) > 0.8, right(0, 1) > 0.8]

## create lots of training data in memory

# inputs (float, float, float)
# setup inputs. Values are either -1, 0, or 1
# inputs also need batch dimension
x_train = torch.zeros(1000, 1, 3)
for i in range(1000):
    x_train[i][0][0] = torch.randint(-1, 2, (1,)) # front
    x_train[i][0][1] = torch.randint(-1, 2, (1,)) # left
    x_train[i][0][2] = torch.randint(-1, 2, (1,)) # right

# For each input, figure out the correct output
y_train = torch.zeros(1000, 1, 2)
for i in range(1000):
    front = x_train[i][0][0]
    left = x_train[i][0][1]
    right = x_train[i][0][2]
    death = -1
    empty = 0
    food = 1

    # if front is food, go forward
    # if left and right is food, go left or right randomly 50% of the time
    # if left or right is food, go there 50% of the time
    # if front is death, turn left or right randomly, prefering the side with food
    # if left or right is death, turn away from it 25% of the time
    # outputs = [left(0, 1), right(0, 1)]
    # if left output is < 0.8 and right output is > 0.8, the snake will go right
    # if left output is > 0.8 and right output is < 0.8, the snake will go left
    # if left output is > 0.8 and right output is > 0.8, the snake will go straight
    # if left output is < 0.8 and right output is < 0.8, the snake will go straight
    if front == food:
        y_train[i, 0, 0] = 0
        y_train[i, 0, 1] = 0
    elif left == food and right == food:
        # random choose direction
        if torch.rand(1) > 0.5:
            y_train[i, 0, 0] = 0
            y_train[i, 0, 1] = 1
        else:
            y_train[i, 0, 0] = 1
            y_train[i, 0, 1] = 0
    elif left == food:
        y_train[i, 0, 0] = 1
        y_train[i, 0, 1] = 0
    elif right == food:
        y_train[i, 0, 0] = 0
        y_train[i, 0, 1] = 1
    elif front == death:
        if left == death: # go right
            y_train[i, 0, 0] = 0
            y_train[i, 0, 1] = 1
        elif right == death: # go left
            y_train[i, 0, 0] = 1
            y_train[i, 0, 1] = 0
        else: # go left or right randomly
            if torch.rand(1) > 0.5:
                y_train[i, 0, 0] = 0
                y_train[i, 0, 1] = 1
            else:
                y_train[i, 0, 0] = 1
                y_train[i, 0, 1] = 0
    else: # go straight
        y_train[i, 0, 0] = 0
        y_train[i, 0, 1] = 0

   
# get batches, each batch is just a single input and output. Don't include a paramter for batch_size, it will be set to 1
def get_batches(x_train, y_train):
    for i in range(len(x_train)):
        yield x_train[i], y_train[i]

# Train

In [162]:
# Inputs: 3 values, 1 for each direction and 1 for what's in front of it
# Outputs: 2 values, each between 0 and 1 
input_size = 3 
output_size = 2 
hidden_size = 6
num_layers = 9
epochs = 50

# Create an instance of our SnaekRNN class and define criterion and optimizer  
model = SnakeFF(input_size=input_size, output_size=output_size, hidden_size=hidden_size, num_layers=num_layers)

# train
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_function = nn.BCELoss()

for i in range(epochs):
    accuracies = []
    losses = []
    for x, y in get_batches(x_train, y_train):
        # Calculate loss and gradients
        y_pred = model(x)
        loss = loss_function(y_pred, y)
        loss.backward()

        # calculate accuracy
        accuracy = 0
        if y[0][0] == 1 and y_pred[0][0] > 0.8 and y_pred[0][1] < 0.8:
            accuracy += 1
        elif y[0][1] == 1 and y_pred[0][0] < 0.8 and y_pred[0][1] > 0.8:
            accuracy += 1
        elif y[0][0] == 0 and y[0][1] == 0 and y_pred[0][0] < 0.8 and y_pred[0][1] < 0.8:
            accuracy += 1

        accuracy = accuracy / len(y_pred)
    
        accuracies.append(accuracy)
        losses.append(loss)

        # Update weights using optimizer
        optimizer.step()
        optimizer.zero_grad()

    # Logging every 10 iterations
    if (i+1) % 10 == 0:
        accuracy = sum(accuracies) / len(accuracies)
        loss = sum(losses) / len(losses)
        print(f'Epoch: {i+1:3}/{epochs:3}  Loss: {loss.item():10.8f}  Accuracy: {accuracy:2.2f}')

        # create image of the prediction, before and after
        

Epoch:  10/ 50  Loss: 0.39368773  Accuracy: 0.47
Epoch:  20/ 50  Loss: 0.39368775  Accuracy: 0.47
Epoch:  30/ 50  Loss: 0.39368775  Accuracy: 0.47
Epoch:  40/ 50  Loss: 0.39368775  Accuracy: 0.47
Epoch:  50/ 50  Loss: 0.39368775  Accuracy: 0.47


# Predict

In [None]:
# Initialize the Model and Load the Weights from file
#model = SnaekRNN(input_size=input_size, output_size=output_size, hidden_size=hidden_size, num_layers=num_layers)
#model.load_state_dict(torch.load('snake_nn_weights_2x2x2x2x4x4x4x4'))
with torch.no_grad(): # Don't track history for testing 

    inputs = torch.tensor([[food, death, empty]]).float()

    movement = model(inputs)

    print(movement)

tensor([ -31.7520, -139.8530])
