<a href="https://colab.research.google.com/github/bylehn/auxetic_networks_jaxmd/blob/abhishek/test-auxetic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import jax.numpy as np
from jax import random
from jax.config import config; config.update("jax_enable_x64", True)
from jax_md import space, energy, minimize, simulate, quantity
from jax import random, grad
from jax import jit
from jax import lax
import networkx as nx
import numpy as onp
from scipy.spatial import Delaunay
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt

In [2]:
def createDelaunayGraph(NS, rseed, r_c, del_x):

    # This function creates a Delaunay graph of a set of points.

    # Parameters:
    #   NS: The number of points to generate.
    #   rseed: The random seed to use.
    #   r_c: The radius of the circumcircle of each edge in the graph.
    #   del_x: max noise magnitude from square lattice

    # Returns:
    #   N: The number of points in the graph.
    #   G: The graph object.
    #   X: The coordinates of the points.
    #   E: The edges of the graph.

    # Set the random seed.
    onp.random.seed(rseed)

    # Generate the points.
    xm, ym = onp.meshgrid(onp.arange(1, NS + 1), onp.arange(1, NS + 1))
    X = onp.vstack((xm.flatten(), ym.flatten())).T
    N = X.shape[0]

    # Add some noise to the points.
    X = X + del_x * 2 * (0.5 - onp.random.rand(N, 2))

    # Create the Delaunay triangulation.
    DT = Delaunay(X)

    # Get the edges of the triangulation.
    ET = onp.empty((0, 2), dtype=int)
    for T in DT.simplices:
        ET = onp.vstack((ET, [T[0], T[1]], [T[1], T[2]], [T[0], T[2]]))

    # Sort the edges.
    ET = onp.sort(ET)

    # Get the radii of the circumcircles of the edges.
    R = onp.linalg.norm(X[ET[:, 0], :] - X[ET[:, 1], :], axis=1)

    # Keep only the edges with radii less than r_c.
    EN = ET[R < r_c, :]

    # Create the adjacency matrix.
    A = onp.zeros((N, N))
    A[EN[:, 0], EN[:, 1]] = 1

    # Get the lengths of the edges.
    L = onp.linalg.norm(X[ET[:, 0], :] - X[ET[:, 1], :], axis=1)

    # Keep only the edges with lengths less than r_c.
    EL = L[R < r_c]

    # Create the graph object.
    G = nx.Graph(A)

    # Get the edges of the graph.
    E = onp.array(G.edges)

    # Get the lengths of the edges.
    L = onp.linalg.norm(X[E[:, 0], :] - X[E[:, 1], :], axis=1)

    return N, G, X, E, L

def getSurfaceNodes(G, NS):
    # Retrieve the list of nodes in the graph G
    nodes = np.array(list(G.nodes))
    # Calculate the x and y coordinates of the nodes based on the grid size NS
    x_values = nodes % NS
    y_values = nodes // NS
    # Find the nodes located on the top surface (y = NS - 1)
    top_nodes = nodes[y_values == NS - 1]
    # Find the nodes located on the bottom surface (y = 0)
    bottom_nodes = nodes[y_values == 0]
    # Find the nodes located on the left surface (x = 0)
    left_nodes = nodes[x_values == 0]
    # Find the nodes located on the right surface (x = NS - 1)
    right_nodes = nodes[x_values == NS - 1]
    # Return a dictionary with surface names as keys and node arrays as values
    return {
        'top': top_nodes,
        'bottom': bottom_nodes,
        'left': left_nodes,
        'right': right_nodes
    }

In [3]:
def make_box(R, padding):
    """
    Defines a box length

    R: position matrix
    padding: amount of space to add to the box
    """
    box_length = (np.max((np.max(R[:,0], R[:,1])) - np.min(((np.min(R[:,0], R[:,1])))))) + padding
    return box_length
    
def create_spring_constants(R,E,k_1):
    """
    Creates spring constants for each edge in the graph

    k_1: spring constant for a spring of unit length
    R: position matrix
    E: edge matrix
    """
    displacements = R[E[:, 0],:] - R[E[:, 1], :]
    distance = np.linalg.norm(displacements, axis=1)
    return (k_1/distance).reshape(-1,1), distance

@jit
def compute_distance(point1, point2):
    """
    Calculate the Euclidean distance between two points.
    """
    return np.sqrt((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)


#@jit
def constrained_force_fn(R, energy_fn, left_indices, right_indices):
    """
    Calculates forces with frozen sources.

    R: position matrix
    energy_fn: energy function
    source_idx: indices of the source nodes
    """
    force_fn = quantity.force(energy_fn)

    def new_force_fn(R):
        total_force = force_fn(R)
        total_force = total_force.at[left_indices, :].set(0.0)
        total_force = total_force.at[right_indices, :].set(0.0)
        return total_force

    return new_force_fn

@jit
def fitness(poisson):
    """
    Constructs a fitness function based on the Poisson ratio.
    """
    return (poisson + 1)**2

@jit
def poisson_ratio(initial_horizontal, initial_vertical, final_horizontal, final_vertical):
    """
    Calculate the Poisson ratio based on average edge positions.
    
    initial_horizontal: initial horizontal edge positions
    initial_vertical: initial vertical edge positions
    final_horizontal: final horizontal edge positions
    final_vertical: final vertical edge positions
    """

    delta_horizontal = final_horizontal - initial_horizontal
    delta_vertical = final_vertical - initial_vertical

    return -delta_vertical / delta_horizontal

@jit
def update_kbonds(gradients, k_bond, learning_rate = 0.01):
    """
    Updates spring constants based on gradients.

    
    """
    gradients_perpendicular = gradients - np.mean(gradients)
    gradients_normalized = gradients_perpendicular / np.max(gradients_perpendicular)
    k_bond_new = k_bond * (1 - learning_rate * gradients_normalized)

    return k_bond_new  

In [16]:
steps = 200
write_every = 10
perturbation = 0.5


def simulate_auxetic(R, k_bond, shift, surface_nodes, perturbation, displacement, E, bond_lengths):
    """
    Simulates the allosteric process.

    """
    # Get the surface nodes.
    top_indices = surface_nodes['top']
    bottom_indices = surface_nodes['bottom']
    left_indices = surface_nodes['left']
    right_indices = surface_nodes['right']     
    
    log_first_min = {
    #'energy': np.zeros((steps,)),
    'position': np.zeros((steps // write_every,) + R.shape) 
    }
    log_second_min = {
    #'energy': np.zeros((steps,)),
    'position': np.zeros((steps // write_every,) + R.shape)
    }

    def step_fn(i, state_and_log):
        """
        Minimizes the configuration at each step.

        i: step number
        state_and_log: state and log dictionary
        """
        fire_state, log = state_and_log
        
        #energy = quantity.energy(energy_fn)
        #log['energy'] = log['energy'].at[i].set(energy(fire_state.position))
        #log['energy'] = lax.cond(i % write_every == 0,
        #                         lambda e: e.at[i // write_every].set(np.array(energy_fn(fire_state.position))),
        #                         lambda e: e,
        #                         log['energy'])

        log['position'] = lax.cond(i % write_every == 0,
                                lambda p: p.at[i // write_every].set(fire_state.position),
                                lambda p: p,
                                log['position'])
        
        
        fire_state = fire_apply(fire_state)
        return fire_state, log

    # First minimization before pinching the source nodes.
    energy_fn = energy.simple_spring_bond(displacement, E, length=bond_lengths, epsilon=k_bond[:, 0])  

    fire_init, fire_apply = minimize.fire_descent(energy_fn, shift)
    fire_apply = jit(fire_apply)
    #step = jit(lambda i, state: fire_apply(state))
    fire_state = fire_init(R)
    fire_state, log = lax.fori_loop(0, steps, step_fn, (fire_state, log_first_min))
    #fire_state = lax.fori_loop(0, steps, step, fire_state)
    R_init = fire_state.position
    # Initial dimensions (before deformation)
    initial_horizontal = np.mean(R[right_indices], axis=0)[0] - np.mean(R[left_indices], axis=0)[0]
    initial_vertical = np.mean(R[top_indices], axis=0)[1] - np.mean(R[bottom_indices], axis=0)[1]

    # Shift the left edge.
    R_init = R_init.at[left_indices, 0].add(perturbation)

    # Second minimization after pinching the source nodes.
    energy_fn = energy.simple_spring_bond(displacement, E, length=bond_lengths, epsilon=k_bond[:, 0])
    force_fn = constrained_force_fn(R_init, energy_fn, left_indices, right_indices)
    fire_init, fire_apply = minimize.fire_descent(force_fn, shift)
    fire_state = fire_init(R_init)
    #fire_state = lax.fori_loop(0, steps, step, fire_state)
    fire_state, log = lax.fori_loop(0, steps, step_fn, (fire_state, log_second_min))
    R_final = fire_state.position
    # Final dimensions (after deformation)
    final_horizontal = np.mean(R_final[right_indices], axis=0)[0] - np.mean(R_final[left_indices], axis=0)[0]
    final_vertical = np.mean(R_final[top_indices], axis=0)[1] - np.mean(R_final[bottom_indices], axis=0)[1]

    # Calculate the poisson ratio.
    poisson = poisson_ratio(initial_horizontal, initial_vertical, final_horizontal, final_vertical)
    fit = fitness(poisson)
    
    return fit, poisson#, traj1, traj2, step

    

In [17]:
#create graph
N,G,X,E,bond_lengths =createDelaunayGraph(10, 25, 2.0, 0.4)
R = np.array(X)
k_bond, _ = create_spring_constants(R,E,1.0)
surface_nodes = getSurfaceNodes(G, 10)
displacement, shift = space.free() #displacement = points in space, shift = small shifts of each particle
grad_f = grad(simulate_auxetic, argnums=1) 

In [18]:
simulate_auxetic(R, k_bond, shift, surface_nodes, perturbation, displacement, E, bond_lengths)

(Array(50.91042257, dtype=float64), Array(6.135154, dtype=float64))

In [9]:
%timeit grad_f(R, k_bond, shift, surface_nodes, perturbation, displacement, E, bond_lengths)

2.34 s ± 102 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
poisson

NameError: name 'poisson' is not defined

In [None]:
opt_steps = 1000
k_temp = k_bond
for i in range(opt_steps):
    net_fitness = simulate_auxetic(R, k_temp, shift, surface_nodes, perturbation, displacement, E, bond_lengths)
    gradients = grad_f(R, k_temp, shift, surface_nodes, perturbation, displacement, E, bond_lengths)
    k_temp = update_kbonds(gradients, k_temp)
    print(i, np.max(gradients), net_fitness)

In [17]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots

# Convert the trajectory to a NumPy array
trajectory_array = np.array(traj_shifted)

# Create a subplot for the scatter plot
fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter'}]], print_grid=False)

# Create a scatter plot using Plotly
trace = go.Scatter(x=trajectory_array[:, 0], y=trajectory_array[:, 1], mode='markers')
fig.add_trace(trace)

# Create and add slider
steps = []
for i, frame in enumerate(trajectory_array):
    step = dict(
        method="update",
        args=[
            {"x": [frame[:, 0]], "y": [frame[:, 1]]},
        ],
        label=str(i),
    )
    steps.append(step)

slider = dict(steps=steps, active=0, pad={"t": 50}, currentvalue={"prefix": "Frame: "})

fig.update_layout(sliders=[slider])

# Update the layout
fig.update_layout(
    width=600,
    height=600,
    xaxis=dict(title='X', showgrid=False),
    yaxis=dict(title='Y', showgrid=False),
    plot_bgcolor='rgba(255, 255, 255, 1)', # White background
    margin=dict(l=50, r=50, b=50, t=50) # Adjust margin if needed
)

fig.show()
