In [None]:
import jax.numpy as jnp
import jax
from jax import random
from scipy.spatial import Delaunay
import matplotlib.pyplot as plt
import numpy as np
import copy
import random as rppy

In [None]:
import jax.numpy as jnp
from jax import random

def generate_positions_jax(num_positions, x_range, y_range, min_distance, key, max_attempts=None):
    """Rejection-sample 2-D points that keep a minimum pairwise spacing.

    Candidates are drawn one at a time with ``jax.random.uniform`` inside
    ``x_range`` x ``y_range``. A candidate is kept only if it lies at least
    ``min_distance`` from every point already accepted.

    Args:
        num_positions: number of points to return.
        x_range, y_range: ``(low, high)`` bounds for each coordinate.
        min_distance: minimum Euclidean distance between any two accepted points.
        key: JAX PRNG key; the same key reproduces the same points.
        max_attempts: optional cap on the number of candidates drawn. ``None``
            (the default) keeps sampling until enough points are found, which
            never terminates if the spacing cannot be satisfied.

    Returns:
        A list of ``(x, y)`` tuples of Python floats.

    Raises:
        RuntimeError: if ``max_attempts`` candidates were drawn without
            collecting ``num_positions`` points.
    """
    positions = []

    def distance(p1, p2):
        return jnp.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)

    def is_valid_position(new_position):
        return all(distance(new_position, pos) >= min_distance for pos in positions)

    attempts = 0
    while len(positions) < num_positions:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f"only placed {len(positions)} of {num_positions} points after "
                f"{attempts} attempts; min_distance={min_distance} may be too large"
            )
        attempts += 1

        key, subkey_x, subkey_y = random.split(key, 3)

        new_x = random.uniform(subkey_x, minval=x_range[0], maxval=x_range[1])
        new_y = random.uniform(subkey_y, minval=y_range[0], maxval=y_range[1])

        # .item() converts the 0-d JAX arrays into plain Python floats.
        new_position = (new_x.item(), new_y.item())

        if is_valid_position(new_position):
            positions.append(new_position)

    return positions


# Example usage:
# key = random.PRNGKey(0)
# positions = generate_positions_jax(10, (0, 100), (0, 100), 5, key)
# print(positions)

In [None]:
import numpy as np

def calculate_poisson_ratio(initial_grid, deformed_grid, axis=0):
    """Estimate Poisson's ratio from the bounding-box change of a 2-D point set.

    Strain along a column is (deformed extent - initial extent) / initial extent,
    where extent = max - min of that coordinate. The ratio is
    ``-transverse_strain / axial_strain``, with the transverse column taken as
    ``1 - axis``, so only 2-column grids are meaningful.

    Zero axial strain makes the division non-finite (inf/nan); callers check for that.
    """
    def extent(points, column):
        return np.max(points[:, column]) - np.min(points[:, column])

    def strain(column):
        before = extent(initial_grid, column)
        return (extent(deformed_grid, column) - before) / before

    other_axis = 1 - axis
    return -strain(other_axis) / strain(axis)

In [None]:
def create_triangulation_association_matrix(grid):
    """Return the symmetric 0/1 adjacency matrix of the Delaunay triangulation of ``grid``.

    Every triangle (a, b, c) contributes the edges a-b, a-c and b-c. The result
    is a float JAX array of shape ``(n, n)`` with zeros on the diagonal.

    All edges are written in one vectorized scatter. Updating with ``.at[].set``
    inside a Python loop would copy the full matrix once per edge.
    """
    simplices = np.asarray(Delaunay(grid).simplices)
    # Columns (0,1), (0,2), (1,2) of each simplex give its three edges.
    first = simplices[:, [0, 0, 1]].ravel()
    second = simplices[:, [1, 2, 2]].ravel()
    num_points = grid.shape[0]
    adj_matrix = jnp.zeros((num_points, num_points))
    adj_matrix = adj_matrix.at[first, second].set(1).at[second, first].set(1)
    return adj_matrix

In [None]:
def display_grid_with_bonds(grid, adj_matrix):
    """Scatter-plot the nodes of ``grid`` and draw a black segment for every bond.

    A bond between nodes a < b is drawn when ``adj_matrix[a, b] == 1``. Only the
    upper triangle is inspected, so each bond is drawn once. Calls ``plt.show()``.
    """
    plt.scatter(grid[:, 0], grid[:, 1], c='blue', marker='o', zorder=5)
    count = grid.shape[0]
    bonded_pairs = (
        (a, b)
        for a in range(count)
        for b in range(a + 1, count)
        if adj_matrix[a, b] == 1
    )
    for a, b in bonded_pairs:
        xs = [grid[a, 0], grid[b, 0]]
        ys = [grid[a, 1], grid[b, 1]]
        plt.plot(xs, ys, 'k-', zorder=1)
    axes = plt.gca()
    axes.set_aspect('equal', adjustable='box')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.title('Grid of Points with Bonds')
    plt.grid(True)
    plt.show()

In [None]:
import jax
import jax.numpy as jnp

def remove_random_bonds(adj_matrix, num_bonds=20, seed=0):
    """Return a copy of a symmetric adjacency matrix with random bonds removed.

    Bonds are the non-zero entries of the upper triangle. ``min(num_bonds, #bonds)``
    of them are chosen without replacement using ``jax.random.PRNGKey(seed)``, and
    both ``[i, j]`` and ``[j, i]`` are zeroed.

    Args:
        adj_matrix: square, symmetric JAX array.
        num_bonds: how many bonds to remove; values <= 0 remove nothing.
        seed: integer seed, so the same seed removes the same bonds.

    Returns:
        The updated matrix. If nothing is removed, the input is returned unchanged.

    Raises:
        AssertionError: if ``adj_matrix`` is not symmetric.
    """
    assert (adj_matrix == adj_matrix.T).all(), "adjacency matrix must be symmetric"
    non_zero_indices = jnp.array(jnp.triu(adj_matrix).nonzero()).T  # shape (num_bonds_available, 2)
    available = len(non_zero_indices)
    # Clamp so a negative request doesn't produce an invalid sample shape.
    count = max(0, min(num_bonds, available))
    if count == 0:
        return adj_matrix
    key = jax.random.PRNGKey(seed)
    chosen_indices = jax.random.choice(key, available, (count,), replace=False)
    indices_to_remove = non_zero_indices[chosen_indices]
    rows, cols = indices_to_remove[:, 0], indices_to_remove[:, 1]
    # One scatter per triangle instead of two full-matrix copies per bond.
    return adj_matrix.at[rows, cols].set(0).at[cols, rows].set(0)



In [None]:
def spring_force(pos1, pos2, rest_length, k=1.0):
    """Hookean spring force acting on ``pos1`` from a spring joining it to ``pos2``.

    The force is ``k * (distance - rest_length)`` along the unit vector from
    ``pos1`` to ``pos2``. A stretched spring pulls ``pos1`` toward ``pos2``; a
    compressed one pushes it away. The force on ``pos2`` is the negation, which
    is how the position-update functions apply it (``+force`` to i, ``-force`` to j).

    The previous sign (``-k * (distance - rest_length)``) gave the force on
    ``pos2``. Because callers apply the result to ``pos1``, stretched bonds were
    pushed further apart, so the spring acted as an anti-spring.

    A 1e-8 term in the denominator avoids division by zero for coincident points.
    """
    displacement = pos2 - pos1
    distance = jnp.linalg.norm(displacement)
    direction = displacement / (distance + 1e-8)
    force_magnitude = k * (distance - rest_length)
    return force_magnitude * direction

def angular_spring_force(p1, p2, p3, rest_angle, k=1.0):
    """Scalar restoring term for the angle p1-p2-p3, with the vertex at p2.

    The angle is measured as ``atan2(|v1 x v2|, v1 . v2)`` with ``v1 = p1 - p2``
    and ``v2 = p3 - p2``, so it lies in [0, pi]. The return value is the scalar
    ``-k * (angle - rest_angle)``, not a force vector.
    NOTE(review): the position updaters in this notebook add this scalar to both
    coordinates of the vertex.
    """
    arm_a = p1 - p2
    arm_b = p3 - p2
    sine_part = jnp.linalg.norm(jnp.cross(arm_a, arm_b))
    cosine_part = jnp.dot(arm_a, arm_b)
    current_angle = jnp.arctan2(sine_part, cosine_part)
    return -k * (current_angle - rest_angle)

In [None]:
import jax.numpy as jnp

def calculate_initial_lengths_and_angles(grid, adj_matrix):
    """Record the rest state of a bonded 2-D point network.

    Returns:
        rest_lengths: float array shaped like ``adj_matrix``. For every bond
            ``adj_matrix[i, j] == 1`` with i < j it holds ``|grid[i] - grid[j]|``
            in both ``[i, j]`` and ``[j, i]``; all other entries are zero.
        rest_angles: dict keyed by ``(i, n1, n2)`` (Python ints, ``n1 < n2`` in
            neighbor order) giving the angle at vertex ``i`` between bonds i-n1
            and i-n2.

    Angles use ``atan2(|v1 x v2|, v1 . v2)``, the same measure as
    ``angular_spring_force``, so the rest configuration yields zero angular
    force. Unlike ``arccos(dot / norms)``, this cannot produce NaN from rounding
    or from coincident points.
    """
    num_points = grid.shape[0]
    # Host-side copy so the many scalar comparisons below avoid device dispatch.
    adjacency = np.asarray(adj_matrix)
    rest_lengths = jnp.zeros_like(adj_matrix, dtype=float)
    rest_angles = {}

    for i in range(num_points):
        for j in range(i + 1, num_points):
            if adjacency[i, j] == 1:
                dist = jnp.linalg.norm(grid[i] - grid[j])
                rest_lengths = rest_lengths.at[i, j].set(dist).at[j, i].set(dist)

    for i in range(num_points):
        neighbors = [int(n) for n in np.flatnonzero(adjacency[i] == 1)]
        for a in range(len(neighbors)):
            for b in range(a + 1, len(neighbors)):
                neighbor1 = neighbors[a]
                neighbor2 = neighbors[b]
                v1 = grid[neighbor1] - grid[i]
                v2 = grid[neighbor2] - grid[i]
                angle = jnp.arctan2(jnp.linalg.norm(jnp.cross(v1, v2)), jnp.dot(v1, v2))
                rest_angles[(i, neighbor1, neighbor2)] = angle

    return rest_lengths, rest_angles

def update_positions(grid, adj_matrix, fixed_indices, dt=0.01, num_iterations=100, k_spring=1.0, k_angle=0.1):
    """Relax a bonded 2-D network with explicit Euler steps, keeping some nodes fixed.

    Each iteration accumulates, for every node:
      * the spring force from ``spring_force`` for every bond (``+force`` on i,
        ``-force`` on j), with rest lengths measured from the initial ``grid``;
      * the scalar from ``angular_spring_force`` for every pair of neighbors
        around the node. NOTE(review): this scalar is added to both x and y.
    Forces on ``fixed_indices`` are then zeroed and ``grid += dt * forces``.

    Args:
        grid: (n, 2) JAX array of node positions (not modified).
        adj_matrix: (n, n) 0/1 adjacency; a bond is an entry equal to 1.
        fixed_indices: indices of nodes that must not move (sequence or array).

    Returns:
        The relaxed (n, 2) positions.
    """
    grid = grid.copy()
    num_points = grid.shape[0]
    rest_lengths, rest_angles = calculate_initial_lengths_and_angles(grid, adj_matrix)

    # The topology never changes during relaxation, so enumerate bonds and angle
    # triples once. Iteration order matches the original nested loops.
    adjacency = np.asarray(adj_matrix)
    bonds = [(i, j) for i in range(num_points) for j in range(i + 1, num_points) if adjacency[i, j] == 1]
    angle_triples = []
    for i in range(num_points):
        neighbors = [int(n) for n in np.flatnonzero(adjacency[i] == 1)]
        for a in range(len(neighbors)):
            for b in range(a + 1, len(neighbors)):
                angle_triples.append((i, neighbors[a], neighbors[b]))
    # An index array also accepts plain Python lists, which `.at[list]` rejects.
    fixed = jnp.asarray(fixed_indices, dtype=int)

    for _ in range(num_iterations):
        forces = jnp.zeros_like(grid)
        for i, j in bonds:
            force = spring_force(grid[i], grid[j], rest_lengths[i, j], k_spring)
            forces = forces.at[i].add(force)
            forces = forces.at[j].add(-force)
        for centre, neighbor1, neighbor2 in angle_triples:
            rest_angle = rest_angles[(centre, neighbor1, neighbor2)]
            angle_force = angular_spring_force(grid[neighbor1], grid[centre], grid[neighbor2], rest_angle, k_angle)
            forces = forces.at[centre].add(angle_force)

        forces = forces.at[fixed].set(0.0)
        grid = grid + dt * forces

    return grid


In [None]:
import jax.numpy as jnp

def update_positions_v3(grid, adj_matrix, fixed_indices, pulling_indices, dt=0.01, num_iterations=100, k_spring=1e9, k_angle=0.1, pulling_force=(-1.0, 0.0)):
    """Relax a bonded 2-D network under a constant external pull (explicit Euler).

    Each iteration starts from the constant load ``pulling_force`` on every node
    in ``pulling_indices``, then adds:
      * the spring force from ``spring_force`` for every bond (``+force`` on i,
        ``-force`` on j), with rest lengths measured from the initial ``grid``;
      * the scalar from ``angular_spring_force`` for every pair of neighbors
        around a node. NOTE(review): this scalar is added to both x and y.
    Finally, forces on ``fixed_indices`` are zeroed and ``grid += dt * forces``.

    The zeroing of fixed nodes used to happen before the spring and angle forces
    were added, so "fixed" nodes still moved. It now runs last, as in
    ``update_positions``.

    Args:
        grid: (n, 2) JAX array of node positions (not modified).
        adj_matrix: (n, n) 0/1 adjacency; a bond is an entry equal to 1.
        fixed_indices: indices of nodes that must not move (sequence or array).
        pulling_indices: indices receiving ``pulling_force`` every step.

    Returns:
        The relaxed (n, 2) positions.
    """
    grid = grid.copy()
    num_points = grid.shape[0]
    rest_lengths, rest_angles = calculate_initial_lengths_and_angles(grid, adj_matrix)

    # The topology and the external load never change, so compute them once.
    # Iteration order matches the original nested loops.
    adjacency = np.asarray(adj_matrix)
    bonds = [(i, j) for i in range(num_points) for j in range(i + 1, num_points) if adjacency[i, j] == 1]
    angle_triples = []
    for i in range(num_points):
        neighbors = [int(n) for n in np.flatnonzero(adjacency[i] == 1)]
        for a in range(len(neighbors)):
            for b in range(a + 1, len(neighbors)):
                angle_triples.append((i, neighbors[a], neighbors[b]))

    pull_vector = jnp.array(pulling_force)
    external_forces = jnp.zeros_like(grid)
    for index in pulling_indices:
        external_forces = external_forces.at[index].add(pull_vector)
    fixed = jnp.asarray(fixed_indices, dtype=int)

    for _ in range(num_iterations):
        forces = external_forces
        for i, j in bonds:
            force = spring_force(grid[i], grid[j], rest_lengths[i, j], k_spring)
            forces = forces.at[i].add(force)
            forces = forces.at[j].add(-force)

        for centre, neighbor1, neighbor2 in angle_triples:
            rest_angle = rest_angles[(centre, neighbor1, neighbor2)]
            angle_force = angular_spring_force(grid[neighbor1], grid[centre], grid[neighbor2], rest_angle, k_angle)
            forces = forces.at[centre].add(angle_force)

        # Zero fixed nodes only after every force has been accumulated.
        forces = forces.at[fixed].set(0.0)
        grid = grid + dt * forces

    return grid


In [None]:
num_points = 5  # Number of points in each column
x_fixed_even = 0  # x-coordinate of the `pull` column
y_min, y_max = 0, 100  # Range for y-coordinates
even_y_values = np.linspace(y_min, y_max, num_points)
# Two vertical columns of evenly spaced boundary nodes, named for their intended
# roles: `pull` at x = x_fixed_even and `hold` at x = 100.
pull = jnp.array(np.column_stack((np.full(num_points, x_fixed_even), even_y_values)))
hold = jnp.array(np.column_stack((np.full(num_points, 100), even_y_values)))

In [None]:
def nodesAndAdj(num_positions, x_range, y_range, min_distance, key):
    """Build one candidate network: the boundary columns plus random interior nodes.

    Returns a dict with:
      'nodes': the module-level ``hold`` rows, then the ``pull`` rows, then
          ``num_positions`` points sampled by ``generate_positions_jax``;
      'adj_matrix': the Delaunay adjacency of those nodes.
    """
    sampled = jnp.array(generate_positions_jax(num_positions, x_range, y_range, min_distance, key))
    all_nodes = jnp.concatenate((hold, pull, sampled))
    return {
        'nodes': all_nodes,
        'adj_matrix': create_triangulation_association_matrix(all_nodes),
    }



In [None]:
# Genetic Algorithm
class GeneticAlgorithm:
    """Evolve interior-node layouts of a bonded 2-D network towards a low Poisson ratio.

    Each individual is a dict ``{'nodes': (N, 2) array, 'adj_matrix': (N, N) array}``
    built by ``nodesAndAdj``. Its rows are the 5 ``hold`` nodes, then the 5
    ``pull`` nodes, then the interior nodes. Fitness is the negated Poisson ratio
    measured after a short pulling simulation, so higher fitness means a lower
    (more auxetic) Poisson ratio.
    """

    # Row layout produced by nodesAndAdj: concatenate((hold, pull, interior)).
    HOLD_INDICES = (0, 1, 2, 3, 4)
    PULL_INDICES = (5, 6, 7, 8, 9)
    NUM_BOUNDARY_NODES = 10
    # Fitness assigned when the simulated Poisson ratio is not finite.
    PENALTY = -1e10

    def __init__(self, population_size, mutation_rate, generations):
        self.population_size = population_size
        self.mutation_rate = mutation_rate
        self.generations = generations
        self.population = [nodesAndAdj(6, (0, 100), (0, 100), 5, random.PRNGKey(_)) for _ in range(population_size)]
        self.number_nodes = 6 + 10  # 6 interior + 10 boundary nodes

    def fitness(self, grid):
        """Return the negated Poisson ratio of ``grid`` after a pulling simulation.

        The hold column is fixed and the pull column receives ``(-1, 0)`` each step.
        Returns ``PENALTY`` when the ratio is not finite (e.g. zero axial strain).
        """
        input_field = jnp.array(grid['nodes'])
        input_matrix = jnp.array(grid['adj_matrix'])
        # The pulled nodes are rows 5-9 (the `pull` column). Rows 10-14, used
        # previously, are random interior nodes.
        disturbed_grid = update_positions_v3(
            input_field, input_matrix,
            jnp.array(self.HOLD_INDICES), jnp.array(self.PULL_INDICES),
            dt=0.1, num_iterations=50, k_spring=0.1, k_angle=0.1, pulling_force=(-1.0, 0.0),
        )

        poisson_ratio = calculate_poisson_ratio(grid['nodes'], disturbed_grid)

        if not np.isfinite(poisson_ratio):
            return self.PENALTY
        return -poisson_ratio

    def _selection_weights(self, fitness_scores):
        """Map fitness scores to non-negative roulette weights (``None`` = uniform).

        ``random.choices`` needs non-negative weights with a positive total, but
        raw fitness can be negative. Scores are therefore shifted so the worst
        valid individual gets a tiny positive weight, and penalized individuals
        get weight 0.
        """
        scores = [float(s) for s in fitness_scores]
        valid = [s for s in scores if s > self.PENALTY]
        if not valid:
            return None
        floor = min(valid)
        return [s - floor + 1e-6 if s > self.PENALTY else 0.0 for s in scores]

    def select_parents(self, fitness_scores=None):
        """Draw two parents (with replacement), with probability increasing in fitness.

        Args:
            fitness_scores: optional precomputed ``fitness`` values aligned with
                ``self.population``. Pass them in to avoid re-simulating the whole
                population for every draw.
        """
        if fitness_scores is None:
            fitness_scores = [self.fitness(grid) for grid in self.population]
        weights = self._selection_weights(fitness_scores)
        return rppy.choices(self.population, weights=weights, k=2)

    @staticmethod
    def _sorted_order(nodes, sort_index):
        """Row permutation: rows before ``sort_index`` unchanged, the rest sorted by (x, y)."""
        coords = np.asarray(nodes).tolist()
        tail = sorted(range(sort_index, len(coords)), key=lambda idx: (coords[idx][0], coords[idx][1]))
        return np.array(list(range(sort_index)) + tail, dtype=int)

    def crossover(self, parent1, parent2, sort_index):
        """Average the two parents' nodes after sorting their non-boundary rows.

        Rows ``[:sort_index]`` are kept in order. The remaining rows of each parent
        are sorted by (x, y) so that "similar" interior nodes are paired before
        averaging. The child inherits parent1's bond graph, permuted the same way
        as parent1's nodes so each bond still connects the same (now averaged) nodes.
        """
        child = copy.deepcopy(parent1)

        order_p1 = self._sorted_order(parent1['nodes'], sort_index)
        order_p2 = self._sorted_order(parent2['nodes'], sort_index)
        sorted_nodes_p1 = jnp.asarray(parent1['nodes'])[order_p1]
        sorted_nodes_p2 = jnp.asarray(parent2['nodes'])[order_p2]

        n = self.number_nodes
        child['nodes'] = (sorted_nodes_p1[:n] + sorted_nodes_p2[:n]) / 2
        adjacency = np.asarray(parent1['adj_matrix'])
        child['adj_matrix'] = jnp.asarray(adjacency[np.ix_(order_p1, order_p1)])

        return child

    def mutate(self, grid):
        """With probability ``mutation_rate``, remove one random bond from ``grid`` in place."""
        if rppy.random() < self.mutation_rate:
            input_matrix = jnp.array(grid['adj_matrix'])
            # remove_random_bonds returns a new matrix (JAX arrays are immutable),
            # so store the result. A varying seed makes each mutation differ.
            grid['adj_matrix'] = remove_random_bonds(input_matrix, 1, seed=rppy.randrange(2**31 - 1))

    def evolve(self):
        """Run ``generations`` rounds of selection, crossover and mutation, printing the best fitness."""
        # Each population is simulated once; its scores serve both selection
        # and reporting.
        scores = [self.fitness(grid) for grid in self.population]
        for generation in range(self.generations):
            new_population = []
            for _ in range(self.population_size):
                parent1, parent2 = self.select_parents(scores)
                child = self.crossover(parent1, parent2, self.NUM_BOUNDARY_NODES)
                self.mutate(child)
                new_population.append(child)
            self.population = new_population
            scores = [self.fitness(grid) for grid in self.population]

            #  best fitness for monitoring
            best_fitness = max(scores)
            print(f"Generation {generation}, Best Fitness: {best_fitness}")

# Small demonstration run: 5 individuals, 10% mutation probability, 5 generations.
ga = GeneticAlgorithm(
    generations=5,
    mutation_rate=0.1,
    population_size=5,
)
ga.evolve()

Generation 0, Best Fitness: 1.1506661176681519
Generation 1, Best Fitness: 1.1506661176681519
Generation 2, Best Fitness: 1.1506661176681519


KeyboardInterrupt: 