In [13]:
import jax
import jax.numpy as jnp
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import jax_md

# Generate a grid-based planar graph
def generate_grid_graph(size):
    G = nx.grid_2d_graph(size, size)
    pos = {node: (node[0] / (size - 1), node[1] / (size - 1)) for node in G.nodes}  # normalize positions to [0, 1] range
    return G, pos

# Define the energy function for auxetic behavior
def energy_function(positions, box):
    pos = jnp.array(positions, dtype=jnp.float32)
    displacement_fn, _ = jax_md.space.periodic_general(box, fractional_coordinates=False)
    
    # Calculate pairwise displacements
    n = len(positions)
    displacement = jnp.zeros((n, n, 2), dtype=jnp.float32)
    for i in range(n):
        for j in range(i + 1, n):
            displacement_ij = displacement_fn(pos[i], pos[j])
            displacement = jax.ops.index_add(displacement, (i, j), displacement_ij)
            displacement = jax.ops.index_add(displacement, (j, i), -displacement_ij)
    
    dist_matrix = jnp.sqrt(jnp.sum(displacement ** 2, axis=-1))
    target_distance = jnp.mean(dist_matrix)
    auxetic_penalty = jnp.sum((dist_matrix - target_distance) ** 2)
    
    boundary_penalty = jnp.sum(jnp.maximum(pos - 1, 0) ** 2) + jnp.sum(jnp.maximum(-pos, 0) ** 2)
    
    return auxetic_penalty + boundary_penalty

# Optimize the graph using JAX MD
def optimize_graph(G, pos, box, steps=1000, learning_rate=1e-3):
    positions = np.array([pos[n] for n in G.nodes], dtype=np.float32)
    energy_fn = lambda positions: energy_function(positions, box)
    grad_fn = jax.grad(energy_fn)
    
    positions = jnp.array(positions, dtype=jnp.float32)
    
    for step in range(steps):
        grad = grad_fn(positions)
        positions -= learning_rate * grad
        
        if step % 100 == 0:
            print(f"Step {step}: Energy = {energy_fn(positions)}")
    
    return positions

# Calculate the Poisson ratio using JAX MD elasticity routines
def calculate_poisson_ratio(energy_fn, positions, box):
    moduli_fn = jax_md.elasticity.athermal_moduli(energy_fn)
    moduli = moduli_fn(positions, box)
    C = moduli[0]
    poisson_ratio = -C[0, 1] / C[0, 0]
    return poisson_ratio

# Main script
grid_size = 5
box = jnp.eye(2) * (grid_size - 1)  # Define a box with the size of the grid

G, pos = generate_grid_graph(grid_size)
optimized_positions = optimize_graph(G, pos, box)

# Calculate Poisson ratio
positions_array = np.array([optimized_positions[n] for n in G.nodes], dtype=np.float32)
energy_fn = lambda positions: energy_function(positions, box)
poisson_ratio = calculate_poisson_ratio(energy_fn, positions_array, box)
print(f"Poisson Ratio: {poisson_ratio}")

# Update the positions in the graph
for i, n in enumerate(G.nodes):
    pos[n] = optimized_positions[i]

# Plot the optimized graph
plt.figure(figsize=(8, 8))
nx.draw(G, pos, with_labels=True, node_size=300, node_color="skyblue", edge_color="gray")
plt.show()


AttributeError: module 'jax.ops' has no attribute 'index_add'