In [372]:
from IntegralQuery import SearchQuery, IntegralQuery, Filter, Range #################################################
import numpy as np
from dataclasses import dataclass
import astropy.units as u
from astropy.coordinates import SkyCoord
from datetime import datetime
from numba import njit, vectorize



@njit
def calculate_distance_matrix(quick_list, angle_weight, time_weight, max_distance):
    l = len(quick_list)
    distances = np.full((l,l), 2*max_distance)
    
    partitions = [0]
    for i in range(1,l):
        if quick_list[i,2]-quick_list[i-1,2] > max_distance/angle_weight:
            partitions.append(i)
    partitions.append(l)
    
    for i in range(len(partitions)-1):
        for j in range(partitions[i], partitions[i+1]):

            for k in range(j+1, partitions[i+1]):
                distances[j,k] = distances[k,j] = calculate_distance(quick_list[j],quick_list[k],
                                                                     angle_weight,time_weight)
                
    np.fill_diagonal(distances,0.)
            
    return np.array(partitions), distances

@njit
def calculate_distance(point1, point2, angle_weight, time_weight):
    ang_dis = np.arccos( np.sin(point1[1])*np.sin(point2[1]) # Large Error for very small angles
                        + np.cos(point1[1])*np.cos(point2[1]) * np.cos(point1[0] - point2[0]) )
    time_dis = abs(point1[2] - point2[2])
    return ( (angle_weight*ang_dis)**2 + (time_weight*time_dis)**2 )**0.5

@njit
def find_regions(distances, max_distance, partitions):
    regions = []
    for i,partition in enumerate(partitions[:-1]):
        unconnected = [j for j in range(partition, partitions[i+1])]
        while not len(unconnected)==0:
            temp_region = [unconnected.pop(0)]
            search_index = 0
            while search_index < len(temp_region):
                l = len(unconnected)
                for j in range(l-1,-1,-1):
                    if distances[ temp_region[search_index], unconnected[j] ] < max_distance:
                        temp_region.append(unconnected.pop(j))
                search_index += 1
            regions.append(sorted(temp_region))
    return regions

@njit
def choose_random_weighted_interval(weights):
    r = np.random.random(1)[0] * np.sum(weights)
    s = 0.
    for i, w in enumerate(weights):
        s += w
        if r < s:
            return i
        
        
    
    

class Cluster:
    def __init__(self,
                 pointing,
                 region):
        self.indices = [pointing.index]
        self.avg_distance = 0.
        self.num_pointings = 1
        self.pointings = [pointing]
        self.region = region
        
    def add_pointing(self, pointing):
        self.avg_distance = self.calc_new_avg_dist(pointing)
        self.indices.append(pointing.index)
        self.pointings.append(pointing)
        self.num_pointings += 1
        
    def should_add_pointing(self, pointing):
        if self.num_pointings < self.region.query._cluster_size_range[1]:
            if self.find_new_max_dist(pointing) < self.region.query._max_distance:
                if self.num_pointings >= self.region.query._cluster_size_range[0]:
                    if (self.calc_new_avg_dist(pointing)/self.avg_distance 
                        < self.region.query._cluster_size_preference_threshold[self.num_pointings 
                                                                                - self.region.query._cluster_size_range[0]]):
                        return True
                else:
                    return True
        return False
    
        
    def dissolve_cluster(self):
        for p in self.pointings:
            p.cluster = None
            
        
        
    def calc_new_avg_dist(self, pointing):
        return ((self.avg_distance*self.num_pointings
                 + np.sum(self.region.query._distances[pointing.index,self.indices]))
                 / (self.num_pointings+1))
        
    def find_new_max_dist(self, pointing):
        temp_indices = self.indices.copy()
        temp_indices.append(pointing.index)
        return np.amax(self.region.query._distances[temp_indices,:][:,temp_indices])
    
        

@dataclass
class Pointing:
    '''
    Dataclass that represents a single Pointing
    '''
    scw_id: str
    sky_coords: SkyCoord
    start_time: datetime
    index: int
    cluster: Cluster = None
    cluster2: Cluster = None ############ Reset this, delete unnecessary clusters and regions
    
    def distance_calculator(self, pointing2, angle_weight: float, time_weight: float): #####################################################
        return ( (angle_weight * self.sky_coords.separation(pointing2.sky_coords).deg)**2
                + (time_weight * abs( (self.start_time - pointing2.start_time).total_seconds()/86400 ) )**2 )**0.5






class ClusteredQuery:
    def __init__(self,
                 scw_ids, # Has to be sorted by START_DATE
                 angle_weight,
                 time_weight,
                 max_distance,
                 cluster_size_range = (3,5),
                 cluster_size_preference_threshold = (3,3),
                 failed_improvements_max = 4,
                 suboptimal_cluster_size_range = (1,2), #has to start at 1
                 close_suboptimal_cluster_size_range = (1,3) # above has to be subset
                 ):
        """
        Init the Clustered Query object. Used to cluster pointings
        
        """
        self._max_distance = float(max_distance)
        self._cluster_size_range = cluster_size_range
        self._cluster_size_preference_threshold = cluster_size_preference_threshold
        self._failed_improvements_max = failed_improvements_max
        self._suboptimal_cluster_size_range = suboptimal_cluster_size_range
        self._close_suboptimal_cluster_size_range = close_suboptimal_cluster_size_range
        
        self._num_pointings = len(scw_ids)
        
        self.quick_list = np.zeros((self._num_pointings, 3))
        self.quick_list[:,0:2] = scw_ids[:,1:3]
        for i in range(self._num_pointings):
            self.quick_list[i,2] = (scw_ids[i,3] - datetime(2000,1,1,0,0,0)).total_seconds()/86400
            
        partitions, self._distances = calculate_distance_matrix(self.quick_list, angle_weight, time_weight, self._max_distance)
        
        self._region_indices = find_regions(self._distances, self._max_distance, partitions)
        
        
        self._pointings = np.array([Pointing(pointing[0],
                                            SkyCoord(pointing[1],pointing[2],frame="icrs",unit="deg"),
                                            pointing[3], index)
                                    for index, pointing in enumerate(scw_ids)])
                
        self._regions = []
        
        for i in self._region_indices:
            self._regions.append(Region(i, self))
            
    
    
    @property
    def pointings(self):
        """
        :returns: Base List of Pointings
        """
        return self._pointings
    
    @property
    def distances(self):
        """
        :returns: Base Distance Matrix of Pointings
        """
        return self._distances
    

    
    
class Region:
    def __init__(self,
                 region_indices,
                 query
                 ):
        self.indices = region_indices
        
        self.query = query
        
        self.distances = query._distances[self.indices,:][:,self.indices] ############ point?
        self.sortable_distances = np.concatenate((self.distances,np.array([self.indices]).T),axis=1) # Adds indices as right column, sort using a[a[:,x].argsort()]
        
        self.clusters = {}
        self.potential_clusters1 = {}
        self.potential_clusters2 = {}
        for i in range(query._cluster_size_range[1]):
            self.clusters[i+1]=[]
            self.potential_clusters1[i+1]=[]
            self.potential_clusters2[i+1]=[]
            

        
        self.initial_clustering()
        
        failed_improvements = 0
        while failed_improvements < self.query._failed_improvements_max and self.has_suboptimal_clusters():
            if not self.attempt_improvement():
                failed_improvements += 1
            else:
                failed_improvements = 0
        
    
    def initial_clustering(self):
        cluster = Cluster(self.query._pointings[self.indices[0]], self)
        for index in self.indices[1:]:
            if cluster.should_add_pointing(self.query._pointings[index]):
                cluster.add_pointing(self.query._pointings[index])
            else:
                self.clusters[cluster.num_pointings].append(cluster)
                cluster = Cluster(self.query._pointings[index], self)
        self.clusters[cluster.num_pointings].append(cluster)
        
    def attempt_improvement(self):
        c1 = self.find_suboptimal_cluster()
        c2 = self.find_close_suboptimal_cluster(c1)
        pass
    
    def find_suboptimal_cluster(self):
        size_weights = np.array([len(self.clusters[i]) / i**2
                                 for i in range(self.query._suboptimal_cluster_size_range[0],
                                                self.query._suboptimal_cluster_size_range[1] + 1)])
        size = choose_random_weighted_interval(size_weights) + 1
        index = np.random.randint(len(self.clusters[size]))
        cluster = self.clusters[size].pop(index)
        self.potential_clusters1[size].append(cluster)
        return cluster
    
    def find_close_suboptimal_cluster(self, cluster): ######################################## test this
        clusters = []
        cluster_size_indices = [0]
        for i in range(self.query._close_suboptimal_cluster_size_range[0], 
                       self.query._close_suboptimal_cluster_size_range[1] + 1):
            clusters.append(self.clusters[i])
            cluster_size_indices.append( len(self.clusters[i]) + cluster_size_indices[i-1] )
        cluster_weights = np.zeros(len(clusters))
        for i, c in enumerate(clusters):
            cluster_weights[i] = 1 / np.amin( self.query._distances[cluster._indices,:][:,c._indices] )**2
        for i in range(1, self.query._close_suboptimal_cluster_size_range[1] + 1):
            cluster_weights[cluster_size_indices[i-1]:cluster_size_indices[i]] /= i
        index = choose_random_weighted_interval( cluster_weights )
        for size, csi in enumerate(cluster_size_indices):
            if not index >= csi:
                break
        true_index = index - cluster_size_indices[size-1]
        
        cluster2 = self.clusters[size].pop(true_index)
        self.potential_clusters1[size].append(cluster2)
        return cluster2
        
        
        
        
    
    def find_cluster_path(self, cluster1, cluster2):
        pass
    
    def calc_cluster_cost(self, clusters):
        pass
    
    def recluster_pointings(self, pointings):
        pass
    
    def has_suboptimal_clusters(self):
        s = 0
        for i in range(self.query._suboptimal_cluster_size_range[0], 
                       self.query._suboptimal_cluster_size_range[1] + 1):
            s += len(self.clusters[i])
        if s >= 2:
            return True
        elif s >= 1:
            for i in range(self.query._suboptimal_cluster_size_range[1] + 1,
                           self.query._close_suboptimal_cluster_size_range[1] + 1):
                s += len(self.clusters[i])
            if s >= 2:
                return True
        return False
            
        






In [373]:
searchquerry = SearchQuery(object_name="Cyg X-1", resultmax=0)
cat = IntegralQuery(searchquerry)
f = Filter(SCW_TYPE="POINTING")
scw_ids = cat.apply_filter(f,True)



In [374]:
test = ClusteredQuery(scw_ids, 1, 1, 2.8)
#print(test._region_indices)

In [376]:
for key,value in test._regions[1].clusters.items():
    print(f"{key},{len(value)}")

1,0
2,2
3,15
4,11
5,119


In [377]:
test._distances

array([[0.        , 0.10729474, 0.11235486, ..., 5.6       , 5.6       ,
        5.6       ],
       [0.10729474, 0.        , 0.02592593, ..., 5.6       , 5.6       ,
        5.6       ],
       [0.11235486, 0.02592593, 0.        , ..., 5.6       , 5.6       ,
        5.6       ],
       ...,
       [5.6       , 5.6       , 5.6       , ..., 0.        , 2.18506734,
        2.81393936],
       [5.6       , 5.6       , 5.6       , ..., 2.18506734, 0.        ,
        0.62942099],
       [5.6       , 5.6       , 5.6       , ..., 2.81393936, 0.62942099,
        0.        ]])

In [383]:
a = np.reshape( np.arange(100), (10,10) )
a[[1,4,5],:][:,[3,7,8,9]]

array([[13, 17, 18, 19],
       [43, 47, 48, 49],
       [53, 57, 58, 59]])