# In [14]:
import numpy as np
import underworld as uw
from underworld import function as fn
import glucifer
import networkx as nx
import operator





class TectModel(nx.DiGraph):

    """
    Tectonic plate model stored as a networkx directed graph.

    Nodes are plate IDs; each node carries a ``velocities`` attribute with one
    value per entry of ``self.times``. Edges are plate boundaries with a
    ``loc`` attribute (the boundary's x-position) and an ``ages`` attribute
    (dict mapping plate ID -> plate age at that boundary). The boundary type
    is encoded by the edge direction:

    * ridge: edges in both directions between the two plates,
    * subduction zone: a single edge from the subducting to the upper plate,
    * left/right domain boundary: a self-loop on the plate.

    ***note on has_edge()***
    Most methods requiring two plates take a tuple (plate1Id, plate2Id).
    The DiGraph has_edge() method can be handed such a tuple if it is
    unpacked with *:
    .has_edge(*(plate1Id, plate2Id)) //or// .has_edge(plate1Id, plate2Id)
    """

    def __init__(self, mesh, starttime, endtime, dt):
        """
        Parameters
        ----------
        mesh : object exposing ``minCoord`` / ``maxCoord`` (presumably an
            underworld mesh); index 0 of each gives the x-extent of the domain.
        starttime, endtime, dt : passed to ``np.arange`` to build
            ``self.times`` (so ``endtime`` itself is excluded).
        """
        # Explicit parent call (rather than super()) keeps Python 2 compatibility.
        nx.DiGraph.__init__(self)

        self.times = np.arange(starttime, endtime, dt)
        # Plate IDs already taken, and the pool (1..100) handed out by add_plate.
        self.plateIdUsedList = []
        self.plateIdDefaultList = list(np.arange(1, 101))

        # mesh and coordinate functions
        self.mesh = mesh
        self._coordinate = fn.input()
        self._xFn = self._coordinate[0]

    # Getters are used because the coupling between underworld objects and
    # this class may still evolve.

    @property
    def xFn(self):
        """Underworld function returning the x-coordinate of the input point."""
        return self._xFn

    @property
    def minX(self):
        """Left edge of the mesh (``mesh.minCoord[0]``)."""
        return self.mesh.minCoord[0]

    @property
    def maxX(self):
        """Right edge of the mesh (``mesh.maxCoord[0]``)."""
        return self.mesh.maxCoord[0]

    @property
    def undirected(self):
        """A fresh undirected copy of the graph (rebuilt on every access)."""
        return self.to_undirected()

    #################################
    # Read from dict, to allow checkpointing
    #################################
    def pop_from_dict_of_lists(self, d):
        """
        Populate this graph in place from a dict of dicts; returns None.

        ``d`` maps each plate ID ``u`` to a dict ``{v: edge_data}``, where
        ``edge_data`` is the attribute dict of edge ``u -> v``. This mirrors
        networkx's ``from_dict_of_dicts``. Only node IDs and edges are
        restored: node attributes (e.g. ``velocities``) and the
        ``plateIdUsedList`` / ``plateIdDefaultList`` bookkeeping are not.
        """
        self.add_nodes_from(d)
        self.add_edges_from(((u, v, data)
                             for u, nbrs in d.items()
                             for v, data in nbrs.items()))

    #################################
    # General graph query / utilities
    #################################

    def connected_plates(self, plateId):
        """
        Return a list of the plates sharing a boundary with ``plateId``.

        The undirected view is used, so ridge, subduction and domain-boundary
        connections are all included. A plate with a left/right domain boundary
        (self-loop) lists itself. ``list()`` is required because
        ``neighbors()`` returns an iterator in networkx >= 2.0 (a list in 1.x),
        and callers index and take ``len()`` of the result.
        """
        return list(self.undirected.neighbors(plateId))

    def is_subduction_boundary(self, platePair):
        """
        Return True if the two plates in ``platePair`` are joined by a one-way edge.

        The possible states are: no connection (False, False), a ridge
        (two-way connection, True, True), or a subduction zone (one-way
        connection). Only the last one returns True.
        """
        forward = self.has_edge(platePair[0], platePair[1])
        backward = self.has_edge(platePair[1], platePair[0])
        return forward != backward

    def subduction_edge_order(self, platePair):
        """
        Return ``[subductingPlate, upperPlate]`` for a subduction boundary.

        Raises ValueError if the pair is unconnected or joined by a ridge.
        """
        forward = self.has_edge(platePair[0], platePair[1])
        backward = self.has_edge(platePair[1], platePair[0])
        if forward and not backward:
            return [platePair[0], platePair[1]]
        elif backward and not forward:
            return [platePair[1], platePair[0]]
        else:
            raise ValueError("boundary does not exist, or not a subduction boundary")

    def subduction_direction(self, platePair):
        """
        Return +1. or -1. depending on the subduction geometry.

        Returns +1. when the midpoint of the subducting plate lies at a
        smaller x than the midpoint of the upper plate, and -1. otherwise.
        Raises ValueError if the pair is not a subduction boundary.
        """
        if not self.is_subduction_boundary(platePair):
            raise ValueError("boundary does not exist, or not a subduction boundary")

        # The subduction zone goes from segde[0] (subducting) to segde[1] (upper).
        segde = self.subduction_edge_order(platePair)
        # NOTE(review): get_boundaries returns [] for a plate without exactly
        # two boundaries; the mean is then nan and this returns -1.
        subMid = np.mean(self.get_boundaries(segde[0]))
        upperMid = np.mean(self.get_boundaries(segde[1]))
        return 1. if subMid < upperMid else -1.

    def subduction_boundary_from_plate(self, plateId):
        """
        Return ``[subductingPlate, upperPlate]`` for the subduction zone bordering ``plateId``.

        The first connected plate is skipped when it is joined by a two-way
        edge (a ridge, or the plate's own domain-boundary self-loop). Raises
        ValueError if ``is_subducting_plate(plateId)`` is False.
        """
        if not self.is_subducting_plate(plateId):
            raise ValueError("not a subduction boundary")

        cps = self.connected_plates(plateId)
        if self.has_edge(plateId, cps[0]) and self.has_edge(cps[0], plateId):
            return self.subduction_edge_order((cps[1], plateId))
        return self.subduction_edge_order((cps[0], plateId))

    def is_subducting_plate(self, plateId):
        """
        Return True if any boundary of ``plateId`` is a one-way (subduction) edge.

        NOTE(review): this is also True for the upper (overriding) plate of a
        subduction zone, not only the subducting one. Confirm this is intended.
        """
        return any(not (self.has_edge(plateId, b) and self.has_edge(b, plateId))
                   for b in self.connected_plates(plateId))

    def has_boundary_plate(self, plateId):
        """Return True if the plate has a domain-boundary self-loop and exactly two connections."""
        return plateId in self[plateId] and len(self.undirected[plateId]) == 2

    def get_boundaries(self, plateId):
        """
        Return the ``loc`` values of the plate's two boundaries, in neighbour order.

        The undirected graph is used so that edge direction does not matter.
        If the plate does not have exactly two connections, a message is
        printed and [] is returned.
        """
        uG = self.undirected  # build the undirected copy once
        cps = list(uG.neighbors(plateId))
        if len(cps) == 2:
            return [uG[plateId][cps[0]]['loc'], uG[plateId][cps[1]]['loc']]

        print('plate does not have 2 boundaries. Cannot define extent.')
        return []

    #################################
    # Adding plates / plate boundaries
    #################################

    def add_plate(self, ID=False, velocities=None):
        """
        Add a plate node with a ``velocities`` attribute.

        Parameters
        ----------
        ID : plate ID. If falsy, the next ID from ``plateIdDefaultList`` is used.
        velocities : a single number (broadcast to every entry of
            ``self.times``) or a sequence with the same length as ``self.times``.

        Raises ValueError if the velocities have the wrong length, or if the
        ID is already in use.
        """
        if velocities is None:
            velocities = []  # fresh list: never share a mutable default

        # Accept Python and numpy scalars; bools are excluded, as before.
        isScalar = (isinstance(velocities, (int, float, np.integer, np.floating))
                    and not isinstance(velocities, bool))
        if isScalar:
            vels = np.ones(len(self.times)) * velocities
        elif len(velocities) == len(self.times):
            vels = velocities
        else:
            raise ValueError("velocities must be a single float/int or list/array of length self.times ")

        if not ID:
            ID = self.plateIdDefaultList[0]

        if ID in self.plateIdUsedList:
            raise ValueError("plate ID already assigned")

        self.add_node(ID, velocities=vels)
        self.plateIdUsedList.append(ID)
        # User-supplied IDs need not come from the default pool; removing an
        # absent ID would raise after the node was already added.
        if ID in self.plateIdDefaultList:
            self.plateIdDefaultList.remove(ID)

    def _can_add_boundary(self, plate, other):
        """
        Return True if a boundary between ``plate`` and ``other`` may be inserted.

        The insertion is allowed when ``plate`` already borders ``other`` (the
        existing edge is simply overwritten) or when ``plate`` has fewer than
        two boundaries. A plate can never gain a third boundary.
        """
        cps = self.connected_plates(plate)
        return other in cps or len(cps) < 2

    def add_subzone(self, subPlate, upperPlate, loc, subInitAge=0.0, upperInitAge=0.0):
        """
        Add a subduction zone at x = ``loc`` as a one-way edge subPlate -> upperPlate.

        The edge's ``ages`` dict holds each plate's initial age at the
        boundary. If either plate already has two other boundaries, nothing is
        added and a message is printed.
        """
        if self._can_add_boundary(subPlate, upperPlate) and self._can_add_boundary(upperPlate, subPlate):
            self.add_edge(subPlate, upperPlate,
                          loc=loc,
                          ages={subPlate: subInitAge,
                                upperPlate: upperInitAge})
        else:
            print('plate already has 2 boundaries. Wait for plate transfer to be implemented')

    def add_ridge(self, plate1, plate2, loc, plate1InitAge=0.0, plate2InitAge=0.0):
        """
        Add a ridge at x = ``loc`` as edges in both directions between the plates.

        Each edge gets its own ``ages`` dict. If plate1 == plate2, that dict
        has only one entry. If either plate already has two other boundaries,
        nothing is added and a message is printed.
        """
        if self._can_add_boundary(plate1, plate2) and self._can_add_boundary(plate2, plate1):
            self.add_edge(plate1, plate2, loc=loc,
                          ages={plate1: plate1InitAge,
                                plate2: plate2InitAge})
            self.add_edge(plate2, plate1, loc=loc,
                          ages={plate1: plate1InitAge,
                                plate2: plate2InitAge})
        else:
            print('plate already has 2 boundaries. Wait for plate transfer to be implemented')

    def add_left_boundary(self, plate, plateInitAge=0.0):
        """
        Add the left domain boundary (x = minX) to ``plate`` as a self-loop edge.

        NOTE(review): left and right boundaries share the single self-loop
        edge, so one plate cannot hold both; the later call overwrites.
        """
        if self._can_add_boundary(plate, plate):
            self.add_edge(plate, plate, loc=self.minX,
                          ages={plate: plateInitAge})
        else:
            print('plate already has 2 boundaries. Wait for plate transfer to be implemented')

    def add_right_boundary(self, plate, plateInitAge=0.0):
        """
        Add the right domain boundary (x = maxX) to ``plate`` as a self-loop edge.

        NOTE(review): left and right boundaries share the single self-loop
        edge, so one plate cannot hold both; the later call overwrites.
        """
        if self._can_add_boundary(plate, plate):
            self.add_edge(plate, plate, loc=self.maxX,
                          ages={plate: plateInitAge})
        else:
            print('plate already has 2 boundaries. Wait for plate transfer to be implemented')

    #################################
    # Underworld functions
    #################################

    def plate_id_fn(self, boundtol=1e-5):
        """
        Return an underworld conditional function mapping x to a plate ID.

        Each plate covers the half-open interval [lb, ub) between its two
        boundaries. Intervals touching the domain edges (within ``boundtol``)
        are widened by ``boundtol``, so mesh points exactly on minX / maxX are
        still assigned. Points not covered by any plate get -99.

        Raises ValueError if a plate does not have exactly two boundaries.
        """
        condList = []
        for n in self.nodes():
            bounds = np.sort(self.get_boundaries(n))
            if len(bounds) != 2:
                raise ValueError("plate {} does not have 2 boundaries".format(n))

            lb, ub = bounds[0], bounds[1]
            # Plain abs() on numbers: an underworld fn.math.abs here would
            # build a Function object instead of comparing distances.
            if abs(lb - self.minX) < boundtol:
                lb = lb - boundtol
            if abs(ub - self.maxX) < boundtol:
                ub = ub + boundtol

            cond = operator.and_(self.xFn >= lb, self.xFn < ub)
            condList.append((cond, n))
        condList.append((True, -99))  # fallback for uncovered points

        return fn.branching.conditional(condList)

    def plate_age_fn(self):
        """
        Return a dict {plateId: underworld function} giving each plate's age along x.

        Within each plate, the age is interpolated linearly between the ages
        stored at its two boundaries. A dictionary is returned so individual
        parts can be altered. It can be combined with ``plate_id_fn()`` to
        build a piecewise plate-age field. Key 0 holds a constant zero-age
        function.

        NOTE(review): ``plate_id_fn`` uses -99 (not 0) for uncovered points.

        The graph is walked in its undirected form, which is simpler. Raises
        ValueError if a plate does not have exactly two boundaries.
        """
        ageFnDict = {0: fn.misc.constant(0.)}

        uG = self.undirected
        for n in uG.nodes():
            ns = list(uG.neighbors(n))
            if len(ns) != 2:
                raise ValueError("plate {} does not have 2 boundaries".format(n))

            loc1, age1 = uG[n][ns[0]]['loc'], uG[n][ns[0]]['ages'][n]
            loc2, age2 = uG[n][ns[1]]['loc'], uG[n][ns[1]]['ages'][n]

            # Age gradient; float() avoids integer division under Python 2.
            Agrad = (age2 - age1) / float(loc2 - loc1)
            ageFnDict[n] = age1 + Agrad * (self.xFn - loc1)

        return ageFnDict
    
    


## Simple test