# In [1]:
import numpy as np
import matplotlib.pyplot as plt
import triangle as tr
from scipy.special import roots_legendre
import scipy.integrate
from GaussJacobiQuadRule_V3 import *

# TODO: implement B

# TODO: implement I
# TODO: implement all plots:
# H1 error
# Integration convergence
# IVPINN v VPINN

# TODO: change area input to length
# TODO: think about square meshing in the elements
# TODO: add in Mariano's M stuff to the mesh and to the VPINN where applicable
# TODO: create complete foundational with new meshing and M


# TODO: organise branches and unify on develop
# TODO: implement M
# TODO: implement M_alpha
# TODO: experiment with the performance difference when doing something like scipy.integrate over each triangle instead


class Mesh:
    """Two-level triangular mesh with numerical integration helpers.

    The coarse mesh is produced by ``triangle.triangulate`` from the vertices
    given at construction time. Each coarse element can then be meshed again
    into its own ``Mesh`` (see ``generate_full_mesh``).

    Attributes set by the methods below:
        domain: dict with a single ``"vertices"`` array (set in ``__init__``).
        mesh: dict returned by ``triangle.triangulate``; this class reads its
            ``"vertices"``, ``"edges"`` and ``"triangles"`` entries
            (set by ``generate_mesh``).
        N: number of triangles in ``mesh`` (set by ``generate_mesh``).
        meshed_elements: one sub-``Mesh`` per coarse element
            (set by ``generate_full_mesh``).
    """

    def __init__(self, domain: list) -> None:
        """Store the domain vertices; no mesh is generated yet."""
        self.domain = self.create_domain(domain)

    def generate_mesh(self, h: float) -> None:
        """Triangulate the domain with area bound ``h`` and cache the element count."""
        self.mesh = self.create_mesh(self.domain, h)
        self.N = self._get_number_elements()

    @staticmethod
    def create_domain(domain: list) -> dict:
        """Wrap a sequence of vertex coordinates in the dict ``triangle`` expects."""
        return dict(vertices=np.array(domain))

    @staticmethod
    def _format_area(h: float) -> str:
        """Render ``h`` in positional (non-scientific) decimal notation.

        ``str(1e-05)`` is ``'1e-05'``; embedding that in a Triangle switch
        string would not be read as the number 0.00001, so small bounds must
        be spelled out digit by digit.
        """
        return np.format_float_positional(float(h), trim="-")

    @staticmethod
    def create_mesh(domain: dict, h: float):
        """Triangulate ``domain`` using Triangle switches ``e``, ``q`` and ``a<h>``.

        ``h`` is forwarded as the value of the ``a`` (maximum triangle area)
        switch; ``e`` requests the edge list that ``_get_edges`` reads.
        """
        return tr.triangulate(domain, f"eqa{Mesh._format_area(h)}")

    def _get_vertices(self):
        """Return the vertex coordinate array of the generated mesh."""
        return self.mesh["vertices"]

    def _get_edges(self):
        """Return the edge array (pairs of vertex indices) of the generated mesh."""
        return self.mesh["edges"]

    def _get_number_elements(self):
        """Return the number of triangles in the generated mesh."""
        return len(self._get_elements())

    def _get_elements(self):
        """Return the triangle array (triples of vertex indices) of the generated mesh."""
        return self.mesh["triangles"]

    def _get_element_points(self, n: int):
        """Return the vertex coordinates of element ``n``.

        Raises:
            ValueError: if ``n`` is not in ``0 .. N - 1``. Negative indices are
                rejected as well, because they would otherwise silently wrap
                around to elements counted from the end.
        """
        if not 0 <= n < self.N:
            raise ValueError(
                "n out of range of number of elements. N.B. elements 0-indexed"
            )
        return self._get_vertices()[self._get_elements()[n]]

    @staticmethod
    def specify_edge_length(h: float):
        """Return ``h * h / 3``.

        NOTE(review): presumably converts a target edge length into the area
        bound expected by ``generate_mesh`` -- confirm the 1/3 factor.
        """
        return h * h / 3

    def generate_sub_mesh(self, element: int, h: float):
        """Build and return a new ``Mesh`` covering coarse element ``element``.

        The sub-mesh domain is the element's three vertices and ``h`` is
        passed unchanged to ``generate_mesh`` of the new mesh.
        """
        # TODO: maybe extend to a selection of multiple elements later?
        # TODO: do we want to keep the resolution at a high level or relative to the element
        # i.e. is 0.01 on an element of 0.1 to be 0.1 (relative) or still at 0.01

        sub_mesh = Mesh(self._get_element_points(element))
        sub_mesh.generate_mesh(h)
        return sub_mesh

    def compare(self):
        """Show the input domain next to the generated mesh using ``triangle.compare``."""
        tr.compare(plt, self.domain, self.mesh)
        plt.show()

    def generate_full_mesh(self, h0: float, h1: float):
        """Generate the coarse mesh with ``h0`` and a sub-mesh of every element with ``h1``."""
        self.generate_mesh(h0)
        self.meshed_elements = [
            self.generate_sub_mesh(element, h1) for element in range(self.N)
        ]

    def plot_sub_mesh(self, to_plot: list = None, figsize: tuple = (6, 6)):
        """
        Plots the sub meshes for the indicated elements, can be one element or all

        Args:
            to_plot: indices of the coarse elements whose sub-mesh is drawn.
                ``None`` or an empty sequence draws all of them.
            figsize: size forwarded to ``plt.figure``.
        """
        vertices = self._get_vertices()

        plt.figure(figsize=figsize)

        # Plot element vertices as blue dots
        plt.plot(vertices[:, 0], vertices[:, 1], "bo")

        # ``len`` instead of truthiness so that numpy index arrays are accepted too.
        if to_plot is None or len(to_plot) == 0:
            to_plot = range(self.N)

        plt.triplot(vertices[:, 0], vertices[:, 1], self._get_elements())

        for element in to_plot:
            sub_mesh = self.meshed_elements[element]
            sub_vertices = sub_mesh._get_vertices()
            plt.triplot(
                sub_vertices[:, 0],
                sub_vertices[:, 1],
                sub_mesh._get_elements(),
                color="lightgreen",
            )

        # Plot original vertices as red dots
        plt.plot(self.domain["vertices"][:, 0], self.domain["vertices"][:, 1], "ro")

        plt.xlabel("X")
        plt.ylabel("Y")
        plt.show()

    def _get_edge_lengths(self):
        """Return the Euclidean length of every mesh edge as a 1-D array."""
        end_points = self._get_vertices()[self._get_edges()]
        deltas = end_points[:, 1] - end_points[:, 0]
        return np.sqrt((deltas[:, 0] ** 2) + (deltas[:, 1] ** 2))

    def max_H(self):
        """Return the length of the longest edge of this mesh."""
        return self._get_edge_lengths().max()

    def min_H(self):
        """Return the length of the shortest edge of this mesh."""
        return self._get_edge_lengths().min()

    def min_h(self):
        """Return the shortest edge length found in any sub-mesh."""
        return min(sub_mesh.min_H() for sub_mesh in self.meshed_elements)

    def get_refinement_ratio(self):
        """Return ``min_h() / max_H()``."""
        return self.min_h() / self.max_H()

    @staticmethod
    def translate(nodes: list, vertices: list):
        """
        Maps points given on the reference triangle
        ([0, 0], [1, 0], [0, 1]) onto the triangle spanned by ``vertices``.

        Args:
            nodes: (m, 2) reference coordinates.
            vertices: (3, 2) corner coordinates of the target triangle.

        Returns:
            ``(points, det)``: the (m, 2) mapped coordinates and the signed
            determinant of the affine map (twice the signed triangle area).
        """
        nodes = np.asarray(nodes, dtype=float)
        vertices = np.asarray(vertices, dtype=float)

        origin = vertices[0]
        edge_a = vertices[1] - origin
        edge_b = vertices[2] - origin

        # Column slices keep a trailing axis so they broadcast against the edges.
        output = nodes[:, :1] * edge_a + nodes[:, 1:2] * edge_b + origin

        det = edge_a[0] * edge_b[1] - edge_b[0] * edge_a[1]

        return output, det

    @staticmethod
    def GLQ(order: int):
        """
        Generate Gauss-Legendre based nodes and weights for
        numerical integration over the reference triangle
        ([0, 0], [1, 0], [0, 1]).

        The rule is the tensor Gauss-Legendre rule on the unit square pushed
        through the collapsed-coordinate (Duffy) map
        ``(u, v) -> (u, v * (1 - u))``, whose Jacobian ``1 - u`` is folded
        into the weights. It is exact for polynomials of total degree up to
        ``2 * order - 2`` and its weights sum to the triangle area 1/2.

        Returns:
            ``(points, weights)`` with shapes ``(order**2, 2)`` and
            ``(order**2,)``.
        """
        # Gauss-Legendre nodes and weights on [-1, 1], mapped to [0, 1]
        nodes, weights = roots_legendre(order)
        nodes = 0.5 * (nodes + 1)
        weights = 0.5 * weights

        u, v = np.meshgrid(nodes, nodes, indexing="ij")
        w_u, w_v = np.meshgrid(weights, weights, indexing="ij")

        # Simply discarding square-grid points with x + y > 1 is not a valid
        # triangle rule (constants are only integrated to O(1/order)), hence
        # the collapse of the square onto the triangle instead.
        shrink = 1.0 - u

        points = np.column_stack([u.ravel(), (v * shrink).ravel()])
        weights_2d = (w_u * w_v * shrink).ravel()

        return points, weights_2d

    def integrate(self, f, order: int) -> tuple:
        """Integrate ``f`` over every element of the mesh.

        Args:
            f: callable taking an (m, 2) array of points and returning the
                m function values (or anything broadcastable against them).
            order: number of Gauss-Legendre nodes per direction.

        Returns:
            ``(per_element, total)``: the list of element integrals and their sum.
        """
        # The reference rule is identical for all elements; build it once.
        nodes, weights = self.GLQ(order)
        element_integrals = [
            self._integrate_with_rule(element, f, nodes, weights)
            for element in self._get_elements()
        ]
        return element_integrals, np.sum(element_integrals)

    def integrate_element(self, element, f, order):
        """Integrate ``f`` over one element given as a triple of vertex indices."""
        nodes, weights = self.GLQ(order)
        return self._integrate_with_rule(element, f, nodes, weights)

    def _integrate_with_rule(self, element, f, nodes, weights):
        """Apply a reference-triangle rule ``(nodes, weights)`` to one element."""
        corner_points = self._get_vertices()[element]
        physical_points, det = self.translate(nodes, corner_points)
        # abs(): the determinant is signed, so a clockwise element would
        # otherwise contribute with the wrong sign.
        return np.sum(f(physical_points) * weights * abs(det))

    def convergence(self, f, exact):
        """Plot ``|integral - exact|`` against ``1/n`` for orders n = 10, 20, ..., 200.

        A ``1/n`` reference line is drawn alongside the error curve.
        """
        orders = np.arange(10, 201, 10)  # Specify the desired orders
        errors = np.asarray(
            [np.abs(self.integrate(f, int(order))[1] - exact) for order in orders]
        )

        plt.figure(figsize=(8, 6))
        plt.loglog(1 / orders, errors, marker="o", linestyle="-", color="b", label="Error")
        plt.loglog(1 / orders, 1 / orders, marker="o", linestyle="-", color="g", label="1/n")

        plt.xlabel("1/n")
        plt.ylabel("Absolute error")
        plt.title("Convergence Plot")
        plt.grid(True, which="both", ls="--")
        plt.legend()
        plt.show()

    def construct_M(self):
        """Not implemented yet."""
        raise NotImplementedError()