<a href="https://colab.research.google.com/github/bylehn/auxetic_networks_jaxmd/blob/abhishek/test-auxetic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as np
import numpy as onp
from jax import random
from jax.config import config; config.update("jax_enable_x64", True)
from jax_md import space, energy, minimize, simulate, quantity
from jax import random, grad
from jax import jit
from jax import lax
import networkx as nx
from scipy.spatial import Delaunay
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt
from matplotlib import animation
import seaborn as sns
  
sns.set_style(style='white')

def format_plot(x, y):  
  plt.xlabel(x, fontsize=20)
  plt.ylabel(y, fontsize=20)
  
def finalize_plot(shape=(1, 1)):
  plt.gcf().set_size_inches(
    shape[0] * 1.5 * plt.gcf().get_size_inches()[1], 
    shape[1] * 1.5 * plt.gcf().get_size_inches()[1])
  plt.tight_layout()

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [32]:
!pip install JSAnimation

Collecting JSAnimation
  Downloading JSAnimation-0.1.tar.gz (8.9 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: JSAnimation
  Building wheel for JSAnimation (setup.py) ... [?25ldone
[?25h  Created wheel for JSAnimation: filename=JSAnimation-0.1-py3-none-any.whl size=11407 sha256=b8a600b5d8d2e771f8db82eb0442440f1f1597a26eae9ea22c0e13c992ea33ee
  Stored in directory: /home/fabian/.cache/pip/wheels/1e/16/64/028f540fe8f5eae5026a423bfd88356248074379b79f27b646
Successfully built JSAnimation
Installing collected packages: JSAnimation
Successfully installed JSAnimation-0.1


In [2]:
def createDelaunayGraph(NS, rseed, r_c, del_x):

    # This function creates a Delaunay graph of a set of points.

    # Parameters:
    #   NS: The number of points to generate.
    #   rseed: The random seed to use.
    #   r_c: The radius of the circumcircle of each edge in the graph.
    #   del_x: max noise magnitude from square lattice

    # Returns:
    #   N: The number of points in the graph.
    #   G: The graph object.
    #   X: The coordinates of the points.
    #   E: The edges of the graph.

    # Set the random seed.
    onp.random.seed(rseed)

    # Generate the points.
    xm, ym = onp.meshgrid(onp.arange(1, NS + 1), onp.arange(1, NS + 1))
    X = onp.vstack((xm.flatten(), ym.flatten())).T
    N = X.shape[0]

    # Add some noise to the points.
    X = X + del_x * 2 * (0.5 - onp.random.rand(N, 2))

    # Create the Delaunay triangulation.
    DT = Delaunay(X)

    # Get the edges of the triangulation.
    ET = onp.empty((0, 2), dtype=int)
    for T in DT.simplices:
        ET = onp.vstack((ET, [T[0], T[1]], [T[1], T[2]], [T[0], T[2]]))

    # Sort the edges.
    ET = onp.sort(ET)

    # Get the radii of the circumcircles of the edges.
    R = onp.linalg.norm(X[ET[:, 0], :] - X[ET[:, 1], :], axis=1)

    # Keep only the edges with radii less than r_c.
    EN = ET[R < r_c, :]

    # Create the adjacency matrix.
    A = onp.zeros((N, N))
    A[EN[:, 0], EN[:, 1]] = 1

    # Get the lengths of the edges.
    L = onp.linalg.norm(X[ET[:, 0], :] - X[ET[:, 1], :], axis=1)

    # Keep only the edges with lengths less than r_c.
    EL = L[R < r_c]

    # Create the graph object.
    G = nx.Graph(A)

    # Get the edges of the graph.
    E = onp.array(G.edges)

    # Get the lengths of the edges.
    L = onp.linalg.norm(X[E[:, 0], :] - X[E[:, 1], :], axis=1)

    return N, G, X, E, L

def getSurfaceNodes(G, NS):
    # Retrieve the list of nodes in the graph G
    nodes = np.array(list(G.nodes))
    # Calculate the x and y coordinates of the nodes based on the grid size NS
    x_values = nodes % NS
    y_values = nodes // NS
    # Find the nodes located on the top surface (y = NS - 1)
    top_nodes = nodes[y_values == NS - 1]
    # Find the nodes located on the bottom surface (y = 0)
    bottom_nodes = nodes[y_values == 0]
    # Find the nodes located on the left surface (x = 0)
    left_nodes = nodes[x_values == 0]
    # Find the nodes located on the right surface (x = NS - 1)
    right_nodes = nodes[x_values == NS - 1]
    # Return a dictionary with surface names as keys and node arrays as values
    return {
        'top': top_nodes,
        'bottom': bottom_nodes,
        'left': left_nodes,
        'right': right_nodes
    }

In [3]:
def make_box(R, padding):
    """
    Defines a box length

    R: position matrix
    padding: amount of space to add to the box
    """
    box_length = (np.max((np.max(R[:,0], R[:,1])) - np.min(((np.min(R[:,0], R[:,1])))))) + padding
    return box_length
    
def create_spring_constants(R,E,k_1):
    """
    Creates spring constants for each edge in the graph

    k_1: spring constant for a spring of unit length
    R: position matrix
    E: edge matrix
    """
    displacements = R[E[:, 0],:] - R[E[:, 1], :]
    distance = np.linalg.norm(displacements, axis=1)
    return (k_1/distance).reshape(-1,1), distance

@jit
def compute_distance(point1, point2):
    """
    Calculate the Euclidean distance between two points.
    """
    return np.sqrt((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)


#@jit
def constrained_force_fn(R, energy_fn, left_indices, right_indices, mask):
    """
    Calculates forces with frozen edges.

    R: position matrix
    energy_fn: energy function
    left_indices: indices of left boundary nodes
    right_indices: indices of right boundary nodes
    """
    
    
    def new_force_fn(R):
        force_fn = quantity.force(energy_fn)
        total_force = force_fn(R)
        total_force *= mask
        return total_force

    return new_force_fn


@jit
def fitness(poisson):
    """
    Constructs a fitness function based on the Poisson ratio.
    """
    return (poisson + 1)**2

@jit
def poisson_ratio(initial_horizontal, initial_vertical, final_horizontal, final_vertical):
    """
    Calculate the Poisson ratio based on average edge positions.
    
    initial_horizontal: initial horizontal edge positions
    initial_vertical: initial vertical edge positions
    final_horizontal: final horizontal edge positions
    final_vertical: final vertical edge positions
    output: Poisson ratio
    """

    delta_horizontal = final_horizontal - initial_horizontal
    delta_vertical = final_vertical - initial_vertical

    return -delta_vertical / delta_horizontal

@jit
def update_kbonds(gradients, k_bond, learning_rate = 0.01):
    """
    Updates spring constants based on gradients.

    
    """
    gradients_perpendicular = gradients - np.mean(gradients)
    gradients_normalized = gradients_perpendicular / np.max(gradients_perpendicular)
    k_bond_new = k_bond * (1 - learning_rate * gradients_normalized)

    return k_bond_new

@jit
def compute_force_norm(fire_state):
    return np.linalg.norm(fire_state.force)


def remove_zero_rows(log_dict):
    """
    Remove rows (entries) in the log dictionary that are all zeros.
    """
    for key in log_dict:
        log_dict[key] = log_dict[key][~np.all(log_dict[key] == 0.0, axis=(1, 2))]
    return log_dict

In [10]:
steps = 500
write_every = 10
perturbation = 0.1


def simulate_auxetic(R, k_bond, shift, surface_nodes, perturbation, displacement, E, bond_lengths):
    """
    Simulates the auxetic process.

    """
    # Get the surface nodes.
    top_indices = surface_nodes['top']
    bottom_indices = surface_nodes['bottom']
    left_indices = surface_nodes['left']
    right_indices = surface_nodes['right']  
    mask = np.ones(R.shape)   
    mask = mask.at[left_indices].set(0)
    mask = mask.at[right_indices].set(0)

    log_first_min = {
    'force': np.zeros((steps // write_every,) + R.shape),
    'position': np.zeros((steps // write_every,) + R.shape) 
    }

    log_second_min = {
    'force': np.zeros((steps // write_every,) + R.shape),
    'position': np.zeros((steps // write_every,) + R.shape)
    }

    def step_fn_generator(apply):
        def step_fn(i, state_and_log):
            """
            Minimizes the configuration at each step.

            i: step number
            state_and_log: state and log dictionary
            """
            fire_state, log = state_and_log
            
            #energy = quantity.energy(energy_fn)
            #log['energy'] = log['energy'].at[i].set(energy(fire_state.position))
            #log['energy'] = lax.cond(i % write_every == 0,
            #                         lambda e: e.at[i // write_every].set(np.array(energy_fn(fire_state.position))),
            #                         lambda e: e,
            #                         log['energy'])
            log['force'] = lax.cond(i % write_every == 0,
                                    lambda p: p.at[i // write_every].set(fire_state.force),
                                    lambda p: p,
                                    log['force'])
            
            log['position'] = lax.cond(i % write_every == 0,
                                    lambda p: p.at[i // write_every].set(fire_state.position),
                                    lambda p: p,
                                    log['position'])
            
            
            fire_state = apply(fire_state)
            return fire_state, log

        return step_fn

    # First minimization before pinching the source nodes.
    energy_fn = energy.simple_spring_bond(displacement, E, length=bond_lengths, epsilon=k_bond[:, 0])  

    fire_init, fire_apply = minimize.fire_descent(energy_fn, shift)
    fire_apply = jit(fire_apply)
    fire_state = fire_init(R)
    step_fn = step_fn_generator(fire_apply)
    fire_state, log_first_min = lax.fori_loop(0, steps, step_fn, (fire_state, log_first_min))
    R_init = fire_state.position
    # Initial dimensions (before deformation)
    initial_horizontal = np.mean(R[right_indices], axis=0)[0] - np.mean(R[left_indices], axis=0)[0]
    initial_vertical = np.mean(R[top_indices], axis=0)[1] - np.mean(R[bottom_indices], axis=0)[1]

    # Shift the left edge.
    R_init = R_init.at[left_indices, 0].add(perturbation)

    # Second minimization after pinching the source nodes.
    energy_fn = energy.simple_spring_bond(displacement, E, length=bond_lengths, epsilon=k_bond[:, 0])
    force_fn = constrained_force_fn(R_init, energy_fn, left_indices, right_indices, mask)
    fire_init, fire_apply = minimize.fire_descent(force_fn, shift)
    fire_state = fire_init(R_init)
    step_fn = step_fn_generator(fire_apply)
    fire_state, log_second_min = lax.fori_loop(0, steps, step_fn, (fire_state, log_second_min))
    R_final = fire_state.position
    # Final dimensions (after deformation)
    final_horizontal = np.mean(R_final[right_indices], axis=0)[0] - np.mean(R_final[left_indices], axis=0)[0]
    final_vertical = np.mean(R_final[top_indices], axis=0)[1] - np.mean(R_final[bottom_indices], axis=0)[1]

    # Calculate the poisson ratio.
    poisson = poisson_ratio(initial_horizontal, initial_vertical, final_horizontal, final_vertical)
    fit = fitness(poisson)
    
    return fit #, poisson, log_first_min, log_second_min #, traj1, traj2, step

    

In [11]:
#create graph
N,G,X,E,bond_lengths =createDelaunayGraph(10, 25, 2.0, 0.4)
R = np.array(X)
k_bond, _ = create_spring_constants(R,E,1.0)
surface_nodes = getSurfaceNodes(G, 10)
displacement, shift = space.free() #displacement = points in space, shift = small shifts of each particle
grad_f = grad(simulate_auxetic, argnums=1) 

In [26]:
surface_nodes['left']
#surface_nodes['right']

Array([ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90], dtype=int64)

In [6]:
fit, poisson, traj1, traj2 = simulate_auxetic(R, k_bond, shift, surface_nodes, perturbation, displacement, E, bond_lengths)

In [12]:
poisson

Array(0.17124101, dtype=float64)

In [None]:
traj2['force'][1]

In [13]:
opt_steps = 200
k_temp = k_bond
for i in range(opt_steps):
    net_fitness = simulate_auxetic(R, k_temp, shift, surface_nodes, perturbation, displacement, E, bond_lengths)
    gradients = grad_f(R, k_temp, shift, surface_nodes, perturbation, displacement, E, bond_lengths)
    k_temp = update_kbonds(gradients, k_temp)
    print(i, np.max(gradients), net_fitness)

0 0.024324568145968894 1.3718055103843918
1 0.024475382926723147 1.3690813644657622
2 0.0246269886042878 1.366394270933596
3 0.02477937065208631 1.3637432429821665
4 0.024932514962797917 1.3611273347123676
5 0.025086407815863934 1.3585456386637942
6 0.02524103584757678 1.3559972835488174
7 0.025396386024212483 1.3534814321653568
8 0.025552445617486243 1.350997279474085
9 0.025709202182049605 1.3485440508210333
10 0.02586664353561679 1.3461210002927844
11 0.026024757739558323 1.3437274091926699
12 0.026183533082738887 1.3413625846242496
13 0.026342958065671704 1.3390258581758718
14 0.026503021386534136 1.3367165846939717
15 0.026663711927973257 1.3344341411395089
16 0.026825018745488625 1.3321779255200863
17 0.026986931056646724 1.3299473558903498
18 0.027149438231005324 1.3277418694168257
19 0.027312529780995465 1.3255609215003925
20 0.027476195353904196 1.3234039849523154
21 0.027640424723923213 1.3212705492207588
22 0.027805207785325595 1.3191601196617275
23 0.027970534546185172 1.31

KeyboardInterrupt: 

In [34]:
ms = 30
R_plt = onp.array(traj2['position'][-1])

plt.plot(R_plt[:N, 0], R_plt[:N, 1], 'o', markersize=ms * 0.5)

# Plotting bonds
for bond in E:
    point1 = R_plt[bond[0]]
    point2 = R_plt[bond[1]]
    plt.plot([point1[0], point2[0]], [point1[1], point2[1]], c='black')  # Using black for bond color


plt.xlim([0, np.max(R_plt[:, 0])])
plt.ylim([0, np.max(R_plt[:, 1])])

plt.axis('on')

finalize_plot((1, 1))

<IPython.core.display.Javascript object>

In [35]:
%matplotlib notebook
from matplotlib.animation import FuncAnimation
from JSAnimation.IPython_display import display_animation
from IPython.display import HTML

# Set style
sns.set_style(style='white')

# Define the init function, which sets up the plot
def init():
    plt.xlim([0, np.max(traj2['position'][:, :, 0])])
    plt.ylim([0, np.max(traj2['position'][:, :, 1])])
    plt.axis('on')
    return plt

# Define the update function, which is called for each frame
def update(frame):
    plt.clf()  # Clear the current figure
    R_plt = traj2['position'][frame]
    plt.plot(R_plt[:N, 0], R_plt[:N, 1], 'o', markersize=ms * 0.5)

    # Plotting bonds
    for bond in E:
        point1 = R_plt[bond[0]]
        point2 = R_plt[bond[1]]
        plt.plot([point1[0], point2[0]], [point1[1], point2[1]], c='black')  # Using black for bond color
    return plt

# Create the animation
ani = FuncAnimation(plt.figure(), update, frames=range(len(traj2['position'])), init_func=init, blit=False)

# Display the animation
HTML(ani.to_jshtml())

<IPython.core.display.Javascript object>