<a href="https://colab.research.google.com/github/elichen/nematode/blob/main/nematode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [89]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import random
from tqdm.notebook import tqdm

In [87]:
grid_size = 100
np.random.seed(42)  # For reproducible random locations

In [62]:
class HeuristicController:
    def decide_move(self, worm_position, pellet_positions, grid_size):
        if pellet_positions:
            closest_pellet = min(pellet_positions, key=lambda x: abs(x[0] - worm_position[0]) + abs(x[1] - worm_position[1]))
            move_direction = (np.sign(closest_pellet[0] - worm_position[0]), np.sign(closest_pellet[1] - worm_position[1]))
        else:
            move_direction = (0, 0)  # No movement if no pellets
        return move_direction

In [63]:
class WormEnvironment:
    def __init__(self, grid_size, controller):
        self.grid_size = grid_size
        self.controller = controller
        self.reset_environment()

    def reset_environment(self):
        self.worm_position = (random.randint(0, self.grid_size-1), random.randint(0, self.grid_size-1))
        self.pellet_positions = [(random.randint(0, self.grid_size-1), random.randint(0, self.grid_size-1)) for _ in range(10)]
        self.score = 0

    def update(self):
        move_direction = self.controller.decide_move(self.worm_position, self.pellet_positions, self.grid_size)
        # Ensure the worm stays within bounds
        new_position = (min(max(self.worm_position[0] + move_direction[0], 0), self.grid_size-1),
                        min(max(self.worm_position[1] + move_direction[1], 0), self.grid_size-1))
        self.worm_position = new_position
        if self.worm_position in self.pellet_positions:
            self.pellet_positions.remove(self.worm_position)
            self.score += 1
        return self.worm_position, self.pellet_positions, self.score

    def draw(self):
        grid = np.zeros((self.grid_size, self.grid_size))
        grid[self.worm_position] = -1  # Worm's position
        for pellet in self.pellet_positions:
            grid[pellet] = 1  # Pellet's position
        return grid

In [64]:
fig, ax = plt.subplots()
environment = WormEnvironment(grid_size, HeuristicController())

def animate(i):
    ax.clear()
    position, pellets, score = environment.update()
    grid = environment.draw()
    ax.imshow(grid, cmap='viridis')
    ax.set_title(f'Step: {i} Score: {score}')
    return ax

ani = animation.FuncAnimation(fig, animate, frames=200, interval=100, blit=False, repeat=False)
plt.close(fig)
HTML(ani.to_jshtml())

In [91]:
controller = HeuristicController()
environment = WormEnvironment(grid_size, controller)

# Data storage
inputs = []  # To store the state: worm position and closest pellet position
outputs = []  # To store the action taken

for _ in tqdm(range(100000)):
    environment.reset_environment()

    closest_pellet = min(environment.pellet_positions, key=lambda x: abs(x[0] - environment.worm_position[0]) + abs(x[1] - environment.worm_position[1]))
    normalized_worm_position = (environment.worm_position[0] / grid_size, environment.worm_position[1] / grid_size)
    normalized_pellet_position = (closest_pellet[0] / grid_size, closest_pellet[1] / grid_size)
    current_input = normalized_worm_position + normalized_pellet_position

    action = controller.decide_move(environment.worm_position, environment.pellet_positions, grid_size)
    inputs.append(current_input)
    outputs.append(action)

    environment.update()

len(inputs), len(outputs)

  0%|          | 0/100000 [00:00<?, ?it/s]

(100000, 100000)

In [99]:
inputs, outputs

(tensor([[0.6400, 0.5600, 0.5700, 0.5700],
         [0.0000, 0.0300, 0.6300, 0.1300],
         [0.0200, 0.6700, 0.0600, 0.7700],
         [0.7000, 0.0300, 0.5900, 0.0800],
         [0.6800, 0.0000, 0.5400, 0.1800],
         [0.6400, 0.7800, 0.6700, 0.9400],
         [0.5700, 0.2500, 0.6100, 0.1500],
         [0.2400, 0.3000, 0.2600, 0.2500],
         [0.3700, 0.7400, 0.4400, 0.7400],
         [0.2000, 0.9200, 0.0400, 0.6500],
         [0.3900, 0.0300, 0.4000, 0.3800],
         [0.3600, 0.2600, 0.3600, 0.2400],
         [0.1300, 0.2500, 0.1400, 0.3800],
         [0.5500, 0.7300, 0.4500, 0.7300],
         [0.3500, 0.5900, 0.5600, 0.6100],
         [0.1900, 0.9200, 0.1300, 0.7800],
         [0.1800, 0.5600, 0.4400, 0.5500],
         [0.1400, 0.6900, 0.0300, 0.4900],
         [0.2900, 0.4100, 0.1800, 0.5400],
         [0.1400, 0.0100, 0.1500, 0.8300],
         [0.1200, 0.9000, 0.0300, 0.8200],
         [0.1200, 0.7800, 0.1300, 0.5800],
         [0.3200, 0.4400, 0.3100, 0.6200],
         [0

In [92]:
import torch
import torch.nn as nn
import torch.optim as optim

class NeuralController(nn.Module):
    def __init__(self, input_size, output_size):
        super(NeuralController, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 128),  # Input layer
            nn.ReLU(),
            nn.Linear(128, 64),  # Hidden layer
            nn.ReLU(),
            nn.Linear(64, output_size)  # Output layer
        )

    def forward(self, x):
        return self.network(x)

    def decide_move(self, worm_position, pellet_positions, grid_size):
        # Prepare the input for the model, e.g., normalize and format the positions
        closest_pellet = min(pellet_positions, key=lambda x: abs(x[0] - worm_position[0]) + abs(x[1] - worm_position[1]))
        normalized_worm_position = [worm_position[0] / grid_size, worm_position[1] / grid_size]
        normalized_pellet_position = [closest_pellet[0] / grid_size, closest_pellet[1] / grid_size]
        inputs = torch.tensor(normalized_worm_position + normalized_pellet_position, dtype=torch.float).unsqueeze(0)  # Add batch dimension

        # Predict the next move using the model
        with torch.no_grad():  # Disable gradient computation
            outputs = self.forward(inputs)
            action = outputs.argmax(1).item()  # Get the index of the max log-probability

        # Translate the model output to a move direction
        # This mapping depends on how you've structured your model outputs
        if action == 0:
            return (0, -1)  # Up
        elif action == 1:
            return (0, 1)  # Down
        elif action == 2:
            return (-1, 0)  # Left
        elif action == 3:
            return (1, 0)  # Right

In [93]:
from torch.utils.data import TensorDataset, DataLoader

# Convert lists to tensors
inputs_tensor = torch.tensor(inputs, dtype=torch.float)
outputs_tensor = torch.tensor(outputs, dtype=torch.long)

# Create a TensorDataset and DataLoader
dataset = TensorDataset(inputs_tensor, outputs_tensor)
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

In [96]:
input_size = 4  # Assuming 2 for worm position + 2 for closest pellet position
output_size = 4  # Assuming 4 possible actions: up, down, left, right

model = NeuralController(input_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10  # Number of epochs to train for
print_interval = num_epochs // 10  # Calculate print interval

for epoch in tqdm(range(num_epochs)):
    running_loss = 0.0
    for inputs, labels in data_loader:
        # Convert labels from one-hot to indices if necessary
        _, labels_indices = labels.max(dim=1)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels_indices)  # Use labels_indices here
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    if (epoch+1) % print_interval == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(data_loader)}')

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch [1/10], Loss: 0.2607015926445705
Epoch [2/10], Loss: 0.06950312646097544
Epoch [3/10], Loss: 0.04438646633645146
Epoch [4/10], Loss: 0.032963394584528675
Epoch [5/10], Loss: 0.02583160602803576
Epoch [6/10], Loss: 0.021729968462352537
Epoch [7/10], Loss: 0.01891262696592695
Epoch [8/10], Loss: 0.015851075337790344
Epoch [9/10], Loss: 0.012882223632307918
Epoch [10/10], Loss: 0.012448214556690616


In [98]:
fig, ax = plt.subplots()
model.eval()
pellet_positions = [(np.random.randint(0, grid_size), np.random.randint(0, grid_size)) for _ in range(5)]
environment = WormEnvironment(grid_size, model)

def animate(i):
    ax.clear()
    position, pellets, score = environment.update()
    grid = environment.draw()
    ax.imshow(grid, cmap='viridis')
    ax.set_title(f'Step: {i} Score: {score}')
    return ax

ani = animation.FuncAnimation(fig, animate, frames=200, interval=100, blit=False, repeat=False)
plt.close(fig)
HTML(ani.to_jshtml())