<a href="https://colab.research.google.com/github/elichen/nematode/blob/main/nematode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

In [29]:
# Parameters
grid_size = 100
initial_worm_position = (grid_size // 2, grid_size // 2)  # Center
np.random.seed(42)  # For reproducible random locations
num_data_generation_steps = 1000

In [30]:
class HeuristicController:
    def decide_move(self, worm_position, pellet_positions, grid_size):
        if pellet_positions:
            closest_pellet = min(pellet_positions, key=lambda x: abs(x[0] - worm_position[0]) + abs(x[1] - worm_position[1]))
            move_direction = (np.sign(closest_pellet[0] - worm_position[0]), np.sign(closest_pellet[1] - worm_position[1]))
        else:
            move_direction = (0, 0)  # No movement if no pellets
        return move_direction

In [31]:
class WormEnvironment:
    def __init__(self, grid_size, worm_position, pellet_positions, controller):
        self.grid_size = grid_size
        self.worm_position = worm_position
        self.pellet_positions = pellet_positions
        self.controller = controller
        self.score = 0

    def update(self):
        move_direction = self.controller.decide_move(self.worm_position, self.pellet_positions, self.grid_size)
        # Ensure the worm stays within bounds
        new_position = (min(max(self.worm_position[0] + move_direction[0], 0), self.grid_size-1),
                        min(max(self.worm_position[1] + move_direction[1], 0), self.grid_size-1))
        self.worm_position = new_position
        if self.worm_position in self.pellet_positions:
            self.pellet_positions.remove(self.worm_position)
            self.score += 1
        return self.worm_position, self.pellet_positions, self.score

    def draw(self):
        grid = np.zeros((self.grid_size, self.grid_size))
        grid[self.worm_position] = -1  # Worm's position
        for pellet in self.pellet_positions:
            grid[pellet] = 1  # Pellet's position
        return grid

In [32]:
fig, ax = plt.subplots()
pellet_positions = [(np.random.randint(0, grid_size), np.random.randint(0, grid_size)) for _ in range(5)]
environment = WormEnvironment(grid_size, initial_worm_position, pellet_positions, HeuristicController())

def animate(i):
    ax.clear()
    position, pellets, score = environment.update()
    grid = environment.draw()
    ax.imshow(grid, cmap='viridis')
    ax.set_title(f'Step: {i} Score: {score}')
    return ax

ani = animation.FuncAnimation(fig, animate, frames=200, interval=100, blit=False, repeat=False)
plt.close(fig)
HTML(ani.to_jshtml())

In [33]:
controller = HeuristicController()
environment = WormEnvironment(grid_size, initial_worm_position, pellet_positions, controller)

# Data storage
inputs = []  # To store the state: worm position and closest pellet position
outputs = []  # To store the action taken

def reset_environment(environment):
    # Generate new random pellet positions
    new_pellet_positions = [(np.random.randint(0, grid_size), np.random.randint(0, grid_size)) for _ in range(10)]
    environment.pellet_positions = new_pellet_positions

for _ in range(num_data_generation_steps):
    worm_position = environment.worm_position

    if not environment.pellet_positions:  # If no pellets left
        reset_environment(environment)  # Reset or reintroduce pellets

    closest_pellet = min(environment.pellet_positions, key=lambda x: abs(x[0] - worm_position[0]) + abs(x[1] - worm_position[1]))
    normalized_worm_position = (worm_position[0] / grid_size, worm_position[1] / grid_size)
    normalized_pellet_position = (closest_pellet[0] / grid_size, closest_pellet[1] / grid_size)
    current_input = normalized_worm_position + normalized_pellet_position

    action = controller.decide_move(worm_position, environment.pellet_positions, grid_size)
    inputs.append(current_input)
    outputs.append(action)

    environment.update()

len(inputs), len(outputs)

(1000, 1000)

In [38]:
import torch
import torch.nn as nn
import torch.optim as optim

class NeuralController(nn.Module):
    def __init__(self, input_size, output_size):
        super(NeuralController, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 128),  # Input layer
            nn.ReLU(),
            nn.Linear(128, 64),  # Hidden layer
            nn.ReLU(),
            nn.Linear(64, output_size)  # Output layer
        )

    def forward(self, x):
        return self.network(x)

    def decide_move(self, worm_position, pellet_positions, grid_size):
        if not pellet_positions:
            # Option 1: Return a default move, e.g., stay in place
            return (0, 0)

        # Prepare the input for the model, e.g., normalize and format the positions
        closest_pellet = min(pellet_positions, key=lambda x: abs(x[0] - worm_position[0]) + abs(x[1] - worm_position[1]))
        normalized_worm_position = [worm_position[0] / grid_size, worm_position[1] / grid_size]
        normalized_pellet_position = [closest_pellet[0] / grid_size, closest_pellet[1] / grid_size]
        inputs = torch.tensor(normalized_worm_position + normalized_pellet_position, dtype=torch.float).unsqueeze(0)  # Add batch dimension

        # Predict the next move using the model
        with torch.no_grad():  # Disable gradient computation
            outputs = self.forward(inputs)
            action = outputs.argmax(1).item()  # Get the index of the max log-probability

        # Translate the model output to a move direction
        # This mapping depends on how you've structured your model outputs
        if action == 0:
            return (0, -1)  # Up
        elif action == 1:
            return (0, 1)  # Down
        elif action == 2:
            return (-1, 0)  # Left
        elif action == 3:
            return (1, 0)  # Right

In [39]:
from torch.utils.data import TensorDataset, DataLoader

# Convert lists to tensors
inputs_tensor = torch.tensor(inputs, dtype=torch.float)
outputs_tensor = torch.tensor(outputs, dtype=torch.long)

# Create a TensorDataset and DataLoader
dataset = TensorDataset(inputs_tensor, outputs_tensor)
data_loader = DataLoader(dataset, batch_size=64, shuffle=True)

  inputs_tensor = torch.tensor(inputs, dtype=torch.float)
  outputs_tensor = torch.tensor(outputs, dtype=torch.long)


In [43]:
input_size = 4  # Assuming 2 for worm position + 2 for closest pellet position
output_size = 4  # Assuming 4 possible actions: up, down, left, right

model = NeuralController(input_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 1000  # Number of epochs to train for
print_interval = num_epochs // 10  # Calculate print interval

for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in data_loader:
        # Convert labels from one-hot to indices if necessary
        _, labels_indices = labels.max(dim=1)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels_indices)  # Use labels_indices here
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    if (epoch+1) % print_interval == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(data_loader)}')

Epoch [100/1000], Loss: 0.4255770742893219
Epoch [200/1000], Loss: 0.183755561709404
Epoch [300/1000], Loss: 0.10274732112884521
Epoch [400/1000], Loss: 0.06657607108354568
Epoch [500/1000], Loss: 0.041227299720048904
Epoch [600/1000], Loss: 0.024723617359995842
Epoch [700/1000], Loss: 0.015230825170874596
Epoch [800/1000], Loss: 0.009911916218698025
Epoch [900/1000], Loss: 0.005970858037471771
Epoch [1000/1000], Loss: 0.0041078003123402596


In [44]:
fig, ax = plt.subplots()
model.eval()
pellet_positions = [(np.random.randint(0, grid_size), np.random.randint(0, grid_size)) for _ in range(5)]
environment = WormEnvironment(grid_size, initial_worm_position, pellet_positions, model)

def animate(i):
    ax.clear()
    position, pellets, score = environment.update()
    grid = environment.draw()
    ax.imshow(grid, cmap='viridis')
    ax.set_title(f'Step: {i} Score: {score}')
    return ax

ani = animation.FuncAnimation(fig, animate, frames=200, interval=100, blit=False, repeat=False)
plt.close(fig)
HTML(ani.to_jshtml())