# In [1]:
import random
import math
import matplotlib.pyplot as plt
from matplotlib.path import Path
import matplotlib.patches as patches
import matplotlib.colors as mcolors
from matplotlib.patches import Arc, RegularPolygon
import numpy as np
from numpy import radians as rad
from colour import Color


def LocalColoring(std, gc, CMG, DownMScolor=-1, dag=None):
    """Plot the Morse sets and the Rook field of a 2D network.

    Every cell of the cubical complex ``std.cc`` is drawn "blown up": vertices
    as small polygons, edges as strips between them, and top cells as the
    remaining squares.  The figure shows, in drawing order:

    * a white blow-up of the cells inside the plotted window;
    * the cells of each Morse set (vertex of ``CMG``), filled with a colour
      chosen from the vertex's level in ``CMG`` (see ``ListOfNivelSets``);
    * red circular markers for self edges of vertices/edges of
      ``std.digraph`` and grey double arrows (``<->``) for pairs of opposite
      edges;
    * black arrows for the remaining one-directional edges;
    * optionally, the cells of ``dag.descendants(DownMScolor)`` in
      ``"LavenderBlush"``;
    * the Rook field from ``std.CorrectRookField`` in magenta (a circular
      arrow where the field is ``[0, 0]``).

    Args:
        std: object providing ``cc`` (cubical complex), ``dc`` (``dc.dual``
            maps a digraph vertex to a cell of ``cc``), ``digraph`` (with
            ``edges()``), ``D`` and ``CorrectRookField(cell)``.
            ``std.digraph`` is only read; it is not modified.
        gc: object providing ``complex()`` and ``value(cell)``; the values are
            compared with the vertices of ``CMG``.
        CMG: poset of Morse sets; needs ``vertices()``, ``vertices_``,
            ``children(v)`` and ``parents(v)``.
        DownMScolor: vertex of ``CMG`` whose down set is highlighted, or
            ``-1`` (default) for no highlighting.
        dag: graph whose ``descendants(v)`` defines the highlighted down set;
            defaults to ``CMG``.  Only used when ``DownMScolor != -1``.

    Returns:
        list[set]: the levels of ``CMG``.  ``P[0]`` holds the vertices without
        children, and ``P[k + 1]`` holds the parents of ``P[k]`` not already in
        ``P[k]``.  The last entry is the empty set.
    """
    _fig, ax = plt.subplots(figsize=(10, 10))

    # Side length of the vertex grid of std.cc (assumes a square grid -- TODO
    # confirm).  Kept as a float because the index tests below use it so.
    sqrt_temp = math.sqrt(std.cc.size(0))
    s_v = 0.2    # radius of the polygon drawn around each vertex
    ohang = .5   # overhang for the arrows
    aa = gc.complex()
    if dag is None:
        # The original referenced an undefined global ``dag``; the down set
        # is taken in the Morse-set poset itself unless a graph is passed.
        dag = CMG

    simp_dim_1 = [simp for simp in std.cc if std.cc.cell_dim(simp) == 1]  # all 1-dim cells
    simp_dim_2 = [simp for simp in std.cc if std.cc.cell_dim(simp) == 2]  # all 2-dim cells

    # Offset signs indexed by a neighbour's position in std.cc.topstar(vertex)
    # (DIAG), in star_dim1(vertex) (AXIS) and in std.cc.topstar(edge) (SIDE).
    # They reproduce the hand-written offsets of the original drawing code.
    DIAG = [(-1, -1), (1, -1), (-1, 1), (1, 1)]
    AXIS = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    SIDE = [-1, 1]

    # Cells of gc.complex() grouped by fibration value, built once so that
    # DagVtoCell does not rescan the whole complex for every CMG vertex.
    cells_by_value = {}
    for i in range(aa.size()):
        cells_by_value.setdefault(gc.value(i), []).append(i)

    ################ geometry / drawing helpers

    def p_pst(angle, radius, x, y):
        """Return the point at ``angle`` degrees and distance ``radius`` from (x, y)."""
        angle = math.pi * angle / 180
        return (radius * math.cos(angle) + x, radius * math.sin(angle) + y)

    def v_(j):
        """Return the cells of gc.complex() whose value equals ``j``, in index order."""
        return cells_by_value.get(j, [])

    def DagVtoCell(j):
        """Return the std.cc cells dual to the cells with value ``j`` that lie in their own top star."""
        return [aa.dual(i) for i in v_(j) if i in aa.topstar(i)]

    def star_dim1(k):
        """Return the 1-cells (in simp_dim_1 order) sharing exactly two top-star cells with ``k``."""
        top_k = set(std.cc.topstar(k))
        return [simp for simp in simp_dim_1
                if len(top_k & set(std.cc.topstar(simp))) == 2]

    def AddPolygon(verts, color):
        """Add the closed polygon through ``verts`` (the last vertex repeats the first)."""
        codes = [Path.MOVETO] + [Path.LINETO] * (len(verts) - 2) + [Path.CLOSEPOLY]
        ax.add_patch(patches.PathPatch(Path(verts, codes), facecolor=color, lw=1))

    def PointBlowUp(ii, jj, color):
        """Draw the polygon around the vertex at (ii, jj)."""
        AddPolygon([p_pst(i, s_v, ii, jj)
                    for i in [30, 60, 120, 150, 210, 240, 300, 330, 30]], color)

    def EdgeVertBlowUp(ii, jj, color):
        """Draw the strip of the vertical edge from (ii, jj) to (ii, jj + 1)."""
        AddPolygon([p_pst(i, s_v, ii, jj + k)
                    for i, k in [(60, 0), (120, 0), (240, 1), (300, 1), (60, 0)]], color)

    def EdgeHorizBlowUp(ii, jj, color):
        """Draw the strip of the horizontal edge from (ii, jj) to (ii + 1, jj)."""
        AddPolygon([p_pst(i, s_v, ii + k, jj)
                    for i, k in [(330, 0), (30, 0), (150, 1), (210, 1), (330, 0)]], color)

    def TopCellBlowUp(ii, jj, color):
        """Draw the square of the top cell with lower-left corner (ii, jj)."""
        AddPolygon([p_pst(i, s_v, ii + k, jj + l)
                    for i, k, l in [(30, 0, 0), (60, 0, 0), (300, 0, 1), (330, 0, 1),
                                    (210, 1, 1), (240, 1, 1), (120, 1, 0), (150, 1, 0),
                                    (30, 0, 0)]], color)

    def DrawCell(k, color):
        """Fill the blow-up of cell ``k`` of std.cc with ``color``."""
        x, y = std.cc.coordinates(k)
        dim = std.cc.cell_dim(k)
        if dim == 0:
            PointBlowUp(x, y, color)
        elif dim == 2:
            TopCellBlowUp(x, y, color)
        # Index-range heuristic for vertical 1-cells; relies on the cell
        # numbering of std.cc (TODO confirm against the complex's layout).
        elif dim == 1 and std.cc.size() / 2 - 1 < k < std.cc.size() - std.cc.size(2):
            EdgeVertBlowUp(x, y, color)
        else:
            EdgeHorizBlowUp(x, y, color)

    def Arrow(x, y, dx, dy):
        """Draw a black single arrow from (x, y) by (dx, dy)."""
        ax.arrow(x, y, dx, dy, head_width=0.1, head_length=0.1, fc='k', ec='k', overhang=ohang)

    def DoubleArrow(xy, xytext):
        """Draw a grey double-headed arrow between ``xytext`` and ``xy``."""
        ax.annotate('', xy=xy, xytext=xytext,
                    arrowprops={'arrowstyle': '<->', 'lw': 5, 'ec': 'gray'}, va='center')

    def SelfLoop(x, y):
        """Mark a self edge with a red circular arrow at (x, y)."""
        ax.plot(x, y, color='red', marker=r'$\circlearrowleft$', ms=20)

    def drawCirc(ax, radius, centX, centY, angle_, theta2_, color_='black'):
        """Draw a circular arc of ``theta2_`` degrees with a triangular arrow head."""
        arc = Arc([centX, centY], radius, radius, angle=angle_,
                  theta1=0, theta2=theta2_, capstyle='round', linestyle='-', lw=2, color=color_)
        ax.add_patch(arc)
        # End point of the arc, where the arrow head goes.
        endX = centX + (radius / 2) * np.cos(rad(theta2_ + angle_))
        endY = centY + (radius / 2) * np.sin(rad(theta2_ + angle_))
        # radius/orientation are keyword-only in recent matplotlib versions.
        ax.add_patch(RegularPolygon((endX, endY), 3, radius=radius / 3,
                                    orientation=rad(angle_ + theta2_), color=color_))
        # The axis limits are set once for the whole figure at the end of
        # LocalColoring (the per-arc limits here were overridden anyway).

    ################ cell classification (index heuristics of the original)

    def IsHorizEdge(c):
        """Index test used to recognise horizontal 1-cells (assumes std.cc's numbering)."""
        return std.cc.size(0) + sqrt_temp <= c <= std.cc.size(0) + 2.5 * sqrt_temp

    def IsVertEdge(c, slack=1):
        """Index test used to recognise vertical 1-cells (assumes std.cc's numbering).

        NOTE(review): the face -> vertical-edge arrows historically used
        ``slack=0`` while every other vertical test uses ``slack=1``;
        preserved as is -- confirm which upper bound is intended.
        """
        return (0 < c % sqrt_temp < sqrt_temp - 1) and (
            std.cc.size() / 2 < c < std.cc.size() - std.cc.size(2) - sqrt_temp - slack)

    def OnFringe(c):
        """Return True for cells whose double arrows are skipped (presumably grid-border cells)."""
        if c % sqrt_temp in [0, sqrt_temp - 1] and std.cc.cell_dim(c) == 0:
            return True
        return 0 < c < sqrt_temp - 1 or std.cc.size(0) - sqrt_temp <= c < std.cc.size(0)

    def Hidden(c):
        """Return True for cells whose single arrows are skipped."""
        return OnFringe(c) or std.cc.size() / 2 - sqrt_temp <= c < std.cc.size() / 2

    ################ colouring of Morse sets

    def ListOfNivelSets():
        """Return the increasing levels ("nivel sets") of CMG, ending with an empty set."""
        P0 = {v for v in CMG.vertices_ if len(CMG.children(v)) == 0}

        def NivelSet(P_i):
            """Return the parents of the vertices in ``P_i`` that are not in ``P_i``."""
            P_iplus = set()
            for a in P_i:
                P_iplus.update(CMG.parents(a))
            return P_iplus - P_i

        levels = [P0]
        while levels[-1]:
            levels.append(NivelSet(levels[-1]))
        return levels

    P = ListOfNivelSets()

    def PositionInList(P, j):
        """Return (level index, rank of ``j`` in the sorted level), or None if absent."""
        for a, level in enumerate(P):
            if j in level:
                return a, sorted(level).index(j)

    # One colour family per level, one shade per rank inside the level.
    COLOR_FAMILIES = [
        ["GreenYellow", "Lime", "MediumSpringGreen", "MediumAquamarine",
         "YellowGreen", "SeaGreen", "Olive", "DarkGreen"],                       # green
        ["LemonChiffon", "PaleGoldenrod", "Yellow", "Gold", "DarkKhaki"],        # yellow
        ["SkyBlue", "LightSkyBlue", "DeepSkyBlue", "DodgerBlue", "SteelBlue", "Blue"],  # blue
        ["Orange", "DarkOrange", "Coral", "Tomato"],                             # orange
        ["PaleTurquoise", "Aquamarine", "Cyan", "MediumTurquoise", "LightSeaGreen"],  # cyan
        ["LightSalmon", "Salmon", "IndianRed", "Crimson", "Red"],                # red
        ["Pink", "LightPink", "HotPink", "DeepPink", "PaleVioletRed", "MediumVioletRed"],  # pink
    ]

    def ColorPosition(P, j):
        """Return the colour name for vertex ``j`` from its position in the levels ``P``.

        Families and shades wrap around when there are more levels (or more
        vertices in a level) than colours, instead of raising IndexError.
        """
        level, rank = PositionInList(P, j)
        family = COLOR_FAMILIES[level % len(COLOR_FAMILIES)]
        return family[rank % len(family)]

    ################ drawing

    # Simple white background, without considering Morse sets.
    for k in simp_dim_2:
        x, y = std.cc.coordinates(k)
        if x < sqrt_temp and y < sqrt_temp:
            PointBlowUp(x, y, 'w')
            EdgeHorizBlowUp(x, y, 'w')
            EdgeVertBlowUp(x, y, 'w')
        if x < sqrt_temp - 1 and y < sqrt_temp - 1:
            TopCellBlowUp(x, y, 'w')

    # Morse sets.
    for vertex in list(CMG.vertices()):
        color = ColorPosition(P, vertex)
        for k in DagVtoCell(vertex):
            DrawCell(k, color)

    # Snapshot of the edges; std.digraph itself is never modified.
    edges = list(std.digraph.edges())
    edge_set = set(edges)

    # Self edges and double arrows.
    for (src, dst) in edges:
        if src == dst:
            cell = std.dc.dual(src)
            x, y = std.cc.coordinates(cell)
            dim = std.cc.cell_dim(cell)
            if dim == 0:
                SelfLoop(x, y)
            elif dim == 1:
                if IsHorizEdge(cell):
                    SelfLoop(x + .5, y)
                if IsVertEdge(cell):
                    SelfLoop(x, y + .5)
            # Self edges of 2-cells are intentionally not drawn.
            continue

        if (dst, src) not in edge_set:
            continue
        src_cell = std.dc.dual(src)
        dst_cell = std.dc.dual(dst)
        if OnFringe(src_cell) or OnFringe(dst_cell):
            continue
        dim_s, dim_d = std.cc.cell_dim(src_cell), std.cc.cell_dim(dst_cell)

        if dim_s == 0 and dim_d == 2:          # vertex <-> face
            x, y = std.cc.coordinates(src_cell)
            idx = std.cc.topstar(src_cell).index(dst_cell)
            if idx < len(DIAG):
                sx, sy = DIAG[idx]
                DoubleArrow((x + .23 * sx, y + .23 * sy), (x + .04 * sx, y + .04 * sy))

        if dim_s == 1 and dim_d == 2:
            x, y = std.cc.coordinates(src_cell)
            if IsHorizEdge(src_cell):          # horizontal edge <-> face
                idx = std.cc.topstar(src_cell).index(dst_cell)
                if idx < len(SIDE):
                    s = SIDE[idx]
                    DoubleArrow((x + .5, y + .01 * s), (x + .5, y + .21 * s))
            if IsVertEdge(src_cell):           # vertical edge <-> face
                idx = std.cc.topstar(src_cell).index(dst_cell)
                if idx < len(SIDE):
                    s = SIDE[idx]
                    DoubleArrow((x + .01 * s, y + .5), (x + .21 * s, y + .5))

        if dim_s == 0 and dim_d == 1:          # vertex <-> edge
            x, y = std.cc.coordinates(src_cell)
            idx = star_dim1(src_cell).index(dst_cell)
            if idx < len(AXIS):
                sx, sy = AXIS[idx]
                DoubleArrow((x + .33 * sx, y + .33 * sy), (x + .04 * sx, y + .04 * sy))

    # Edges touching hidden cells (in either direction) get no single arrow.
    hidden_edges = set()
    for (src, dst) in edges:
        if Hidden(std.dc.dual(src)) or Hidden(std.dc.dual(dst)):
            hidden_edges.add((src, dst))
            hidden_edges.add((dst, src))

    # Single (one-directional) arrows.
    for (src, dst) in edges:
        if (src, dst) in hidden_edges or (dst, src) in edge_set:
            continue  # hidden, or already drawn as a double arrow
        src_cell = std.dc.dual(src)
        dst_cell = std.dc.dual(dst)
        dim_s, dim_d = std.cc.cell_dim(src_cell), std.cc.cell_dim(dst_cell)

        if dim_s == 0 and dim_d == 2:          # vertex -> face
            x, y = std.cc.coordinates(src_cell)
            idx = std.cc.topstar(src_cell).index(dst_cell)
            if idx < len(DIAG):
                sx, sy = DIAG[idx]
                Arrow(x, y, .17 * sx, .17 * sy)

        if dim_s == 2 and dim_d == 0:          # face -> vertex
            x, y = std.cc.coordinates(dst_cell)
            idx = std.cc.topstar(dst_cell).index(src_cell)
            if idx < len(DIAG):
                sx, sy = DIAG[idx]
                Arrow(x + .25 * sx, y + .25 * sy, .15 * -sx, .15 * -sy)

        if dim_s == 1 and dim_d == 2:
            x, y = std.cc.coordinates(src_cell)
            if IsHorizEdge(src_cell):          # horizontal edge -> face
                idx = std.cc.topstar(src_cell).index(dst_cell)
                if idx < len(SIDE):
                    s = SIDE[idx]
                    Arrow(x + .5, y + .05 * s, 0, .2 * s)
            if IsVertEdge(src_cell):           # vertical edge -> face
                idx = std.cc.topstar(src_cell).index(dst_cell)
                if idx < len(SIDE):
                    s = SIDE[idx]
                    Arrow(x + .05 * s, y + .5, .2 * s, 0)

        if dim_s == 2 and dim_d == 1:
            x, y = std.cc.coordinates(dst_cell)
            if IsHorizEdge(dst_cell):          # face -> horizontal edge
                idx = std.cc.topstar(dst_cell).index(src_cell)
                if idx < len(SIDE):
                    s = SIDE[idx]
                    Arrow(x + .5, y + .35 * s, 0, .2 * -s)
            if IsVertEdge(dst_cell, slack=0):  # face -> vertical edge
                idx = std.cc.topstar(dst_cell).index(src_cell)
                if idx < len(SIDE):
                    s = SIDE[idx]
                    Arrow(x + .35 * s, y + .5, .2 * -s, 0)

        if dim_s == 0 and dim_d == 1:          # vertex -> edge
            x, y = std.cc.coordinates(src_cell)
            idx = star_dim1(src_cell).index(dst_cell)
            if idx < len(AXIS):
                sx, sy = AXIS[idx]
                Arrow(x, y, .2 * sx, .2 * sy)

        if dim_s == 1 and dim_d == 0:          # edge -> vertex
            x, y = std.cc.coordinates(dst_cell)
            idx = star_dim1(dst_cell).index(src_cell)
            if idx < len(AXIS):
                sx, sy = AXIS[idx]
                Arrow(x + .3 * sx, y + .3 * sy, .1 * -sx, .1 * -sy)

    # Optional highlighting of a down set.
    def DownCells(s):
        """Return the std.cc cells of every vertex in ``dag.descendants(s)``."""
        down_cells = set()
        for a in dag.descendants(s):
            down_cells.update(DagVtoCell(a))
        return down_cells

    def ColorDownMS(s):
        """Fill the cells of the down set of ``s`` with "LavenderBlush"."""
        for k in DownCells(s):
            DrawCell(k, "LavenderBlush")

    if DownMScolor != -1:
        ColorDownMS(DownMScolor)

    # Rook field on the top cells.
    # NOTE(review): the range stops before the last cell index
    # (std.cc.size() - 1); kept as in the original -- confirm intent.
    for a in range(std.cc.size() - std.cc.size(std.D), std.cc.size() - 1):
        x, y = std.cc.coordinates(a)
        fx, fy = std.CorrectRookField(a)
        if [fx, fy] == [0, 0]:
            drawCirc(ax, .15, x + .5, y + .5, 140, 270, color_='m')
        else:
            ax.arrow(x + .5 - 0.2 * fx, y + .5 - 0.2 * fy, 0.1 * fx, 0.1 * fy,
                     head_width=0.1, head_length=0.1, fc='m', ec='m')

    ax.set_xlim(0, sqrt_temp - 1)
    ax.set_ylim(0, sqrt_temp - 1)

    return P