In [2]:
#!/usr/bin/env python3
import argparse
import os
import shutil
import matplotlib.pyplot as plt
import numpy as np
import yaml

from typing import Union, LiteralString, List, Set, Tuple
import pandas as pd

In [3]:
FileName = Union[LiteralString, str]
Numeric = Union[int, float]
DEBUG: bool = False
VERBOSE: bool = False

In [None]:
class ExtractClusterData:

    def __init__(
                self, 
                 data: Union[pd.DataFrame, FileName], 
                 is_offline: bool = True,
                 number_cluster_vars: int = 6,
                 perform_smoothen: bool = False
                 ) -> None:
        
        # Attributes
        self.is_offline: bool = is_offline
        self.number_cluster_vars: int = number_cluster_vars
        self.perform_smoothen: bool = perform_smoothen

        # Create Panda Frame
        self.data = self.create_panda_dataframe(data)

        # Start and End indices
        self.start: int = 0
        self.end: int = self.data.size

        # Figs
        self.figs: Dict[int, plt.Fig] = dict()

        self.define_col_headers()    
    
    def define_col_headers(self) -> None:
        
        # For each column in the data frame
        for c in self.data.columns:

            # Create a class attribute and assign the column
            # NOTE: The .strip() method removes leading and trailing
            # whitespace.
            # NOTE: The .replace(' ', '_') replaces inner spaces with
            # underscores
            self.__setattr__(c.strip().replace(' ', '_'), self.data[c])

        # Prints all the keys
        if VERBOSE: print(self.__dict__.keys())  

    def smoothen() -> None:
        pass

    @staticmethod
    # NOTE: From `https://stackoverflow.com/questions/34017866/arrow-on-a-line-plot`
    def add_arrow(line, position=None, direction='right', size=15, color=None):
        """
        add an arrow to a line.

        line:       Line2D object
        position:   x-position of the arrow. If None, mean of xdata is taken
        direction:  'left' or 'right'
        size:       size of the arrow in fontsize points
        color:      if None, line color is taken.
        """
        if color is None:
            color = line.get_color()

        xdata = line.get_xdata()
        ydata = line.get_ydata()

        if position is None:
            position = xdata.mean()
        # find closest index
        start_ind = np.argmin(np.absolute(xdata - position))
        if direction == 'right':
            end_ind = start_ind + 1
        else:
            end_ind = start_ind - 1

        line.axes.annotate('',
            xytext=(xdata[start_ind], ydata[start_ind]),
            xy=(xdata[end_ind], ydata[end_ind]),
            arrowprops=dict(arrowstyle="->", color=color),
            size=size
        )

    # NOTE: Got from ChatGPT
    @staticmethod
    def add_arrows_to_line(line, n_arrows=20, arrow_style='->', color=None, size=15, ind_spacing=10):
        """
        Adds arrows to a matplotlib Line2D object to indicate direction.

        Parameters:
        - line        : matplotlib Line2D object (e.g. from ax.plot)
        - n_arrows    : number of arrows to draw
        - arrow_style : arrow style (e.g., '->', '-|>')
        - color       : arrow color (defaults to line color)
        - size        : arrow size
        - spacing     : 'data' or 'index' based spacing
        """
        import numpy as np
        import matplotlib.pyplot as plt

        x = np.array(line.get_xdata())
        y = np.array(line.get_ydata())
        ax = line.axes  # get the Axes the line belongs to

        if color is None:
            color = line.get_color()

        # if spacing == 'data':
        #     dist = np.cumsum(np.sqrt(np.diff(x)**2 + np.diff(y)**2))
        #     dist = np.insert(dist, 0, 0)
        #     arrow_locs = np.linspace(0, dist[-1], n_arrows + 2)[1:-1]
        #     arrow_indices = [np.searchsorted(dist, loc) for loc in arrow_locs]
        # else:
        arrow_indices = np.linspace(1, len(x)-2, n_arrows).astype(int)


        s = ind_spacing//2
        for i in arrow_indices:
            
            next = i + s if i + s < len(x) else len(x) - 1
            prev = i - s if i - s > 0 else 0

            dx = x[next] - x[prev]
            dy = y[next] - y[prev]
            ax.annotate('', xy=(x[next], y[next]), xytext=(x[prev], y[prev]),
                        arrowprops=dict(arrowstyle=arrow_style, color=color),
                        size=size)

    def _get_fig(self,
                i, 
                start: int, 
                end: int, 
                x: pd.DataFrame,
                y: pd.DataFrame,
                use_arrows: bool = False,
                subplot_tuple: Tuple[int] = (1, 1, 1),
                figsize: Tuple[Numeric] = None,
                *args,
                **kwargs
                ) -> plt.Figure:
        
        
        # Create fig
        fig = plt.figure(num=i, figsize=figsize)

        # TODO: Make this configurable? Adjust spacing 
        plt.subplots_adjust(hspace=0.5)
        
        # Extract vars
        row, col, ax_ind = subplot_tuple

        # Check if the axis already exists in the figure
        if len(fig.axes) >= ax_ind:  # If axes exist at ax_ind, use it
            ax = fig.axes[ax_ind - 1] 
            _ax_ind: int = ax_ind - 1
        else: 
            # Otherwise, create a new subplot
            ax = fig.add_subplot(row, col, ax_ind)
            _ax_ind: int = -1

        # Create line
        line, = fig.axes[_ax_ind].plot(x[start:end], y[start:end], *args, **kwargs)

        # Add arrows if desired
        if use_arrows: self.add_arrows_to_line(line)
        
        # Add figure to figs
        self.figs[i] = fig
        

        return fig

    
    def _get_time_domain_fig(
                            self,
                            i, 
                            start: int, 
                            end: int, 
                            y: pd.DataFrame,
                            subplot_tuple: Tuple[int] = None,
                            figsize: Tuple[Numeric] = None,
                            *args,
                            **kwargs
                            ) -> plt.Figure:

        return self._get_fig(i, 
                            start, 
                            end, 
                            self.timestamp, 
                            y, 
                            subplot_tuple = subplot_tuple, 
                            figsize= figsize,
                            *args, **kwargs)
    

    def get_time_domain_figs(
                        self, 
                        ylims: List[Tuple[int]] = None, 
                        use_multi_y_axis: bool = False,
                        use_subplots: bool = False,
                        figsize: Tuple[int] = (6, 14), 
                        fontsize: int = 10
                        ) -> List[plt.Figure]:

        # Form static data structures
        # TODO: Can this be more generalizable and
        # not hardcoded?
        cluster_vars: Dict[int, List[pd.DataFrame]] = \
            {
             1: [self.x_c_des, self.x_c],
             2: [self.y_c_des, self.y_c],
             3: [self.t_c_des, self.t_c],
             4: [self.p_des, self.p],
             5: [self.q_des, self.q],
             6: [self.B_des, self.B],
             }
        
        # Form static data structures
        # TODO: Can this be more generalizable and
        # not hardcoded?
        theta_rstr: str = '\theta'
        title_vars: Dict[int, List[str]] = \
            {
                1: ["X-position\n of cluster $x_c$",
                   "X-Position (m)"],
                2: ["Y-position\n of cluster $y_c$",
                   "Y-Position (m)"],
                3: ["Heading\n of cluster $\\theta_c$",
                   f"Heading (rad)"],
                4: ["$p$-length\n of cluster configuration",
                   "Length (m)"],
                5: ["$q$-length\n of cluster configuration",
                   "Length (m)"],
                6: ["$\\beta$-angle\n of cluster configuration",
                   "Angle (rad)"]
            }
        
        # Form static data structures
        # TODO: Can this be more generalizable and
        # not hardcoded?
        if ylims == None:
            ylims = [(-15, 15), (-15, 15), (-4, 4), (0, 15), (0, 15), (-4, 4)]

        # Initialize empty list of 6 length
        if use_subplots:
            figs: list = [None]
        else:
            figs: list = [None] * len(cluster_vars.keys())
        if DEBUG: print(len(figs))

        for k, list_values in cluster_vars.items():
            if DEBUG: print(k)

            # TODO: Fix hardcorded tuples
            if use_subplots:
                subplot_tuple: Tuple[int] = (int(len(cluster_vars.keys())/2), 2, k)
            else:
                subplot_tuple: Tuple[int] = (1, 1, 1)
            for v in list_values:

                # If using subplots
                ind: int = 0 if use_subplots else k-1

                figs[ind] = \
                    self._get_time_domain_fig(
                    ind,
                    self.start,
                    self.end,
                    v,
                    subplot_tuple,
                    figsize=figsize
                    )
                plt.title(title_vars[k][0], fontsize=16, fontweight='bold')
                plt.xlabel("Time (s)", fontsize=fontsize)
                plt.ylabel(title_vars[k][1], fontsize=fontsize)
                plt.legend(["Desired", "Actual"], fontsize=fontsize)
                plt.ylim(ylims[k-1])
                plt.grid(True)

                if use_subplots:
                    # Tight layout
                    plt.tight_layout() 
        return figs
           
    def get_centroid_phase_portrait(
                           self,
                           ind: int,
                           start: int,
                           end: int,
                           fontsize: int = 10,
                           ) -> plt.Figure:

        fig = self._get_fig(ind, start, end, self.x_c_des, self.y_c_des, use_arrows= True, color="blue")
        fig = self._get_fig(ind, start, end, self.x_c, self.y_c, use_arrows= True, color="green")
        plt.title("Phase portrait of cluster centroid", fontsize=16, fontweight='bold')
        plt.xlabel("X-position of cluster centroid $x_c$ (m)", fontsize=fontsize)
        plt.ylabel("Y-position of cluster centroid $y_c$ (m)", fontsize=fontsize)
        plt.legend(["Desired", "Actual"], fontsize=fontsize)
        plt.grid(True)
        plt.axis('equal')

        return fig

    @staticmethod
    def annotate_polygon(
                        ind: int,
                        x: List[Numeric],
                        y: List[Numeric],
                        label: str = None,
                        include_label: bool = True,
                        *args,
                        **kwargs 
                        ) -> plt.Figure:
        
        # If x and y do not match length
        if len(x) != len(y):
            raise("Incorrect size")

        fig = plt.figure(ind)

        # Plot triangle
        plt.plot(x, y, *args, **kwargs)

        colors: List[str] = ["red", "blue", "green", "orange", "purple"]
        # Add vertices
        for i in range(len(x)):

            plt.scatter(x[i], y[i], color=colors[i % len(x) - 1])

        # Add label
        # Default to labeling the first vertex as "start"
        if include_label: 

            if label:
                p = 0
                start_point = [x[p], y[p]]
                ax = fig.gca()
                ax.text(start_point[0] - 5, 
                        start_point[1] - 0.1, 
                        label, 
                        fontsize=12, 
                        color='Black')


        plt.axis('equal')
        
        return fig
        

    def get_cluster_phase_portrait(
                          self,
                           ind: int,
                           start: int,
                           end: int,
                           include_label: bool = True,
                           fontsize: int = 10,
                           legend_fontsize: int = 10
                           ) -> plt.Figure:
        
        # Form static data structures
        # TODO: Can this be more generalizable and
        # not hardcoded?
        xy_pos: Dict[int, List[pd.DataFrame]] = \
            {1: [self.x_1, self.y_1],
             2: [self.x_2, self.y_2],
             3: [self.x_3, self.y_3]}
        
        
        for _, v in xy_pos.items():
            fig = \
                self._get_fig(
                ind,
                start,
                end,
                v[0],
                v[1],
                use_arrows= True
                )

        # plt.title("Phase portrait of 3-cluster configuration of robots")
        plt.xlabel("X-position $x$ (m)", fontsize=fontsize)
        plt.ylabel("Y-position $y$ (m)", fontsize=fontsize)



        # TODO: Can these endpoints be implemented within the annotate_polygon() 
        # instead
        # Creating the points of the cluster config
        x_endpoints: List[Numeric] = [v[0][end] for _, v in xy_pos.items()]
        y_endpoints: List[Numeric] = [v[1][end] for _, v in xy_pos.items()]

        # Adding the first point again to the end to close the polygon
        x_endpoints.append(x_endpoints[0])
        y_endpoints.append(y_endpoints[0])

        if DEBUG:
            print(x_endpoints)
            print(y_endpoints)

        fig = self.annotate_polygon(ind, x_endpoints, y_endpoints, label="End", include_label= include_label, color="blue")

        # Creating the points of the cluster config
        x_startpoints: List[Numeric] = [v[0][start] for _, v in xy_pos.items()]
        y_startpoints: List[Numeric] = [v[1][start] for _, v in xy_pos.items()]

        # Adding the first point again to the end to close the polygon
        x_startpoints.append(x_startpoints[0])
        y_startpoints.append(y_startpoints[0])

        fig = self.annotate_polygon(ind, x_startpoints, y_startpoints, label="Start", include_label= include_label, color="blue")

        # TODO: Fix hardcoded number of robots
        plt.legend([f"Robot {i}, $r_{i}$" for i in range(1,4)], fontsize=legend_fontsize)
        plt.grid(True)

        return fig

    # TODO: Add histogram plot
    def get_sensor_histogram_plot() -> plt.Figure:
        pass


    def export_fig(
                self, 
                ind: int, 
                filename: FileName = None, 
                dirname: FileName = "figures",  
                ext: str= ".png",
                dir_exists_ok: bool = True,
                ) -> bool:

        try:
                            
            # Create directory
            os.makedirs(dirname, exist_ok=dir_exists_ok)

            # Get figure
            fig = self.figs[ind]

            if filename == None:

                # Get current axis
                ax = fig.axes[0]

                # Extract title and replace name
                filename: str = ax.get_title() \
                                        .lower() \
                                        .replace(' ', '_') \
                                        .replace('\\n', '') \
                                        .replace('$', '') \
                                        .replace('\\','')
                if VERBOSE: print(f"Filename: {file_name}")
            
            fig.savefig(os.path.join(dirname, filename + ext), bbox_inches='tight')
            if VERBOSE: print("Saved figure!")

        except Exception as e:
            print(f"Error: {e}")
            return False
        return True
        

    def export_figs(self, dirname: FileName = "figures", ext: str= ".png") -> bool:

        try:
            
            # Create directory
            os.makedirs(dirname, exist_ok=True)

            for k, fig in self.figs.items():

                # Get current axis
                ax = fig.axes[0]

                # Extract title and replace name
                file_name: str = ax.get_title() \
                                        .lower() \
                                        .replace(' ', '_') \
                                        .replace('\n', '') \
                                        .replace('$', '') \
                                        .replace('\\','') 
                if VERBOSE: print(f"Filename: {file_name}")
                
                fig.savefig(os.path.join(dirname,file_name + ext), bbox_inches='tight')
                if VERBOSE: print("Saved figure!")

        except Exception as e:
            print(f"Error: {e}")
            return False
        return True
        
    def clear_fig(self, ind: int = -1) -> None:
        self.figs[ind].clf()
        
    def clear_figs(self) -> None:
        for _, f in self.figs.items():
            self.clear_fig(f.number)


    @staticmethod
    def create_panda_dataframe(data) -> pd.DataFrame:
        
        # Get datatype
        typ = type(data)

        # TODO: Check which type of string
        if typ == str or typ == LiteralString:
            return pd.read_csv(data)
        elif typ == pd.DataFrame:
            return data           

In [None]:
with open("plotting_params.yaml") as f:
    data = yaml.safe_load(f)


grapher: ExtractClusterData = ExtractClusterData(
    data["data"],
)

for d in data["clusters"]:

    dir = d["dirname"]

    if os.path.exists(dir):
        shutil.rmtree(dir) 

    if d["enable"]:

        grapher.start = d["time"]["start"]
        grapher.end = d["time"]["end"]

        if d["figures"]["time_history"]["enable"]:
            # Time history
            time_history_config = d["figures"]["time_history"]
            use_subplots = time_history_config["use_subplots"]
            fontsize = time_history_config["fontsize"]

            figsize = time_history_config["figsize"] if use_subplots else None

            ylims = time_history_config["ylim"]
            # Plot time domain graphs
            figs = grapher.get_time_domain_figs(ylims, 
                                                use_subplots=use_subplots,
                                                figsize=figsize,
                                                fontsize=fontsize)

            if use_subplots:
                # Export figure
                grapher.export_fig(
                    ind=figs[0].number,
                    filename=time_history_config["filename"],
                    dirname=d["dirname"],
                    ext=".pdf"
                )
            else:
                grapher.export_figs(
                    dirname=d["dirname"],
                    ext="_" + d["dirname"] + ".pdf"
                )





        if d["figures"]["centroid_phase_portrait"]["enable"]:
            
            centroid_portrait_config = d["figures"]["centroid_phase_portrait"]
            fontsize = centroid_portrait_config["fontsize"]

            # Plot centroid phase portrait
            phase_portrait: plt.Figure = \
                    grapher.get_centroid_phase_portrait(
                                                ind=7,
                                                start=grapher.start,
                                                end=grapher.end,
                                                fontsize=fontsize
                                                )
            phase_portrait.show()

            # Export figure
            grapher.export_fig(
                ind=7,
                filename=centroid_portrait_config["filename"],
                dirname=d["dirname"],
                ext=".pdf"
            )


        if d["figures"]["cluster_phase_portrait"]["enable"]:

            cluster_phase_config = d["figures"]["cluster_phase_portrait"]
            fontsize = cluster_phase_config["fontsize"]
            legend_fontsize = cluster_phase_config["legend_fontsize"]

            # Plot cluster phase portrait
            cluster_phase_portrait: plt.Figure = \
                    grapher.get_cluster_phase_portrait(
                        ind=8,
                        start=grapher.start,
                        end=grapher.end,
                        include_label=cluster_phase_config["include_label"],
                        fontsize=fontsize,

                    )
            cluster_phase_portrait.show()

            # Export figure
            grapher.export_fig(
                ind=8,
                filename=cluster_phase_config["filename"],
                dirname=d["dirname"],
                ext=".pdf"
            )

    # Clear all figs
    grapher.clear_figs()


  fig = plt.figure(num=i, figsize=figsize)
  cluster_phase_portrait.show()
  fig = plt.figure(num=i, figsize=figsize)
  cluster_phase_portrait.show()
  fig = plt.figure(num=i, figsize=figsize)
  cluster_phase_portrait.show()


<Figure size 900x1000 with 0 Axes>

<Figure size 640x480 with 0 Axes>

In [None]:
# Find desired
ind: List[int] = list()
vals: List[Numeric] = [grapher.t_c_des[0]]

for i in range(1, len(grapher.t_c_des) - 1):

    # If rising from 0
    if abs(grapher.t_c_des[i]) >= 1e-01 and (abs(grapher.t_c_des[i] - vals[-1]) >= 1):
        vals.append(grapher.t_c_des[i])
        ind.append(i)