In [317]:
from IntegralQuery import SearchQuery, IntegralQuery, Filter, Range #################################################
import numpy as np
from dataclasses import dataclass
import astropy.units as u
from astropy.coordinates import SkyCoord, SkyOffsetFrame
from datetime import datetime
from numba import njit



@njit
def calculate_distance_matrix(quick_list, angle_weight, time_weight, max_distance):
    l = len(quick_list)
    distances = np.full((l,l), 2*max_distance)
    
    partitions = [0]
    for i in range(1,l):
        if quick_list[i,2]-quick_list[i-1,2] > max_distance/time_weight:
            partitions.append(i)
    partitions.append(l)
    
    for i in range(len(partitions)-1):
        for j in range(partitions[i], partitions[i+1]):

            for k in range(j+1, partitions[i+1]):
                distances[j,k] = distances[k,j] = calculate_distance(quick_list[j],quick_list[k],
                                                                     angle_weight,time_weight)
                
    np.fill_diagonal(distances,0.)
            
    return np.array(partitions), distances

@njit
def calculate_distance(point1, point2, angle_weight, time_weight): #include minimum distance ################################################
    ang_dis = np.arccos( np.sin(point1[1])*np.sin(point2[1]) 
                        + np.cos(point1[1])*np.cos(point2[1]) * np.cos(point1[0] - point2[0]) )
    if np.isnan(ang_dis): ############ find better solution?
        print("ANG_DIS IS NAN!")
    time_dis = abs(point1[2] - point2[2])
    return ( (angle_weight*ang_dis)**2 + (time_weight*time_dis)**2 )**0.5

@njit
def find_regions(distances, max_distance, partitions):
    regions = []
    for i,partition in enumerate(partitions[:-1]):
        unconnected = [j for j in range(partition, partitions[i+1])]
        while not len(unconnected)==0:
            temp_region = [unconnected.pop(0)]
            search_index = 0
            while search_index < len(temp_region):
                l = len(unconnected)
                for j in range(l-1,-1,-1):
                    if distances[ temp_region[search_index], unconnected[j] ] < max_distance:
                        temp_region.append(unconnected.pop(j))
                search_index += 1
            regions.append(sorted(temp_region))
    return regions

@njit
def choose_random_weighted_interval(weights):
    r = np.random.random(1)[0] * np.sum(weights)
    s = 0.
    for i, w in enumerate(weights):
        s += w
        if r < s:
            return i
        
@njit
def calc_pair_combinations(number):
    return (number)*(number-1)/2
    
    

class Cluster:
    def __init__(self,
                 pointing,
                 region):
        self.indices = [pointing.index]
        self.avg_distance = 0.
        self.num_pointings = 1
        self.pointings = [pointing]
        self.region = region
        
    def add_pointing(self, pointing):
        self.avg_distance = self.calc_new_avg_dist(pointing)
        self.indices.append(pointing.index)
        self.pointings.append(pointing)
        self.num_pointings += 1
        
    def should_add_pointing(self, pointing):
        if self.num_pointings < self.region.query._cluster_size_range[1]:
            if self.find_new_max_dist(pointing) < self.region.query._max_distance:
                if self.num_pointings >= self.region.query._cluster_size_range[0]:
                    if (self.calc_new_avg_dist(pointing)/self.avg_distance ################## make sigmoid? why make sigmoid?
                        < self.region.query._cluster_size_preference_threshold[self.num_pointings 
                                                                                - self.region.query._cluster_size_range[0]]):
                        return True
                else:
                    return True
        return False
    
        
    def dissolve_cluster(self): ################################
        for p in self.pointings:
            p.cluster = None
            
    def finalize_cluster(self): #######################################
        for p in self.pointings:
            p.cluster = self
            
        
        
    def calc_new_avg_dist(self, pointing):
        return ((self.avg_distance * calc_pair_combinations(self.num_pointings)
                 + np.sum(self.region.query._distances[pointing.index,self.indices]))
                 / calc_pair_combinations(self.num_pointings+1) )
        
    def find_new_max_dist(self, pointing): # only checks for new pointing!
        return np.amax(self.region.query._distances[pointing.index,self.indices])

    def find_closest_pointings(self, cluster2):
        d = self.region.query._distances[self.indices,:][:,cluster2.indices]
        c = np.unravel_index(d.argmin(), d.shape)
        return self.pointings[c[0]], cluster2.pointings[c[1]]
    
        

@dataclass
class Pointing:
    '''
    Dataclass that represents a single Pointing
    '''
    scw_id: str
    sky_coords: SkyCoord
    start_time: datetime
    index: int
    cluster: Cluster = None
    # delete unnecessary clusters and regions ##############################################
    
    def distance_calculator(self, pointing2, angle_weight: float, time_weight: float): #####################################################
        return ( (angle_weight * self.sky_coords.separation(pointing2.sky_coords).deg)**2
                + (time_weight * abs( (self.start_time - pointing2.start_time).total_seconds()/86400 ) )**2 )**0.5
        
    def angle_between_three_pointings(self, pointing2, pointing3, angle_weight, time_weight):
        origin = SkyCoord(self.sky_coords.ra.deg, (self.sky_coords.dec.deg%180.)-90., frame="icrs",unit="deg")
        origin_Frame = SkyOffsetFrame(origin = origin)
        
        p1 = self.sky_coords.transform_to(origin_Frame)
        p2 = pointing2.sky_coords.transform_to(origin_Frame)
        p3 = pointing3.sky_coords.transform_to(origin_Frame)
        
        a_a = p2.lon.deg - p3.lon.deg
        a_d2 = abs(p1.lat.deg - p2.lat.deg) * angle_weight
        a_d3 = abs(p1.lat.deg - p3.lat.deg) * angle_weight
        
        t_d2 = (pointing2.start_time - self.start_time).total_seconds()/86400 * time_weight
        t_d3 = (pointing3.start_time - self.start_time).total_seconds()/86400 * time_weight
        
        return np.arccos( np.clip((a_d2*a_d3*np.cos(np.deg2rad(a_a)) + t_d2*t_d3) 
                                  / np.linalg.norm([a_d2,t_d2]) / np.linalg.norm([a_d3,t_d3]), -1., 1.) )
        
        





class ClusteredQuery:
    def __init__(self,
                 scw_ids, # Has to be sorted by START_DATE
                 angle_weight,
                 time_weight,
                 max_distance,
                 cluster_size_range = (3,5),
                 cluster_size_preference_threshold = (5.,5.),
                 failed_improvements_max = 4,
                 suboptimal_cluster_size_range = (1,2), #has to start at 1
                 close_suboptimal_cluster_size_range = (1,3) # above has to be subset
                 ):
        """
        Init the Clustered Query object. Used to cluster pointings
        
        """
        self._angle_weight = float(angle_weight)
        self._time_weight = float(time_weight)
        self._max_distance = float(max_distance)
        self._cluster_size_range = cluster_size_range
        self._cluster_size_preference_threshold = cluster_size_preference_threshold
        self._failed_improvements_max = failed_improvements_max
        self._suboptimal_cluster_size_range = suboptimal_cluster_size_range
        self._close_suboptimal_cluster_size_range = close_suboptimal_cluster_size_range
        
        self._num_pointings = len(scw_ids)
        
        self.quick_list = np.zeros((self._num_pointings, 3))
        self.quick_list[:,0:2] = scw_ids[:,1:3]
        for i in range(self._num_pointings):
            self.quick_list[i,2] = (scw_ids[i,3] - datetime(2000,1,1,0,0,0)).total_seconds()/86400
            
        partitions, self._distances = calculate_distance_matrix(self.quick_list, angle_weight, time_weight, self._max_distance)
        
        self._sortable_distances = np.concatenate((self.distances,np.array([np.arange(self._num_pointings)]).T),axis=1) ################# point?
        
        self._region_indices = find_regions(self._distances, self._max_distance, partitions)
        
        
        self._pointings = np.array([Pointing(pointing[0],
                                            SkyCoord(pointing[1],pointing[2],frame="icrs",unit="deg"),
                                            pointing[3], index)
                                    for index, pointing in enumerate(scw_ids)])
                
        self._regions = []
        
        for i in self._region_indices:
            self._regions.append(Region(i, self))
            
    
    
    @property
    def pointings(self):
        """
        :returns: Base List of Pointings
        """
        return self._pointings
    
    @property
    def distances(self):
        """
        :returns: Base Distance Matrix of Pointings
        """
        return self._distances
    

    
    
class Region:
    def __init__(self,
                 region_indices,
                 query
                 ):
        self.indices = region_indices
        
        self.query = query
        
        self.distances = query._distances[self.indices,:][:,self.indices] ############ point?
        self.sortable_distances = np.concatenate((self.distances,np.array([self.indices]).T),axis=1) # Adds indices as right column, sort using a[a[:,x].argsort()]
        
        self.clusters = {}
        self.potential_clusters1 = {}
        self.potential_clusters2 = {}
        for i in range(query._cluster_size_range[1]):
            self.clusters[i+1]=[]
            self.potential_clusters1[i+1]=[]
            self.potential_clusters2[i+1]=[]
            

        
        self.initial_clustering()
        
        for key, value in self.clusters.items():
            for c in value:
                print(f"{c.indices}, {c.avg_distance}")
        print(self.has_suboptimal_clusters())
        
        self.attempt_improvement()
        
        # failed_improvements = 0
        # while failed_improvements < self.query._failed_improvements_max and self.has_suboptimal_clusters():
        #     if not self.attempt_improvement():
        #         failed_improvements += 1
        #     else:
        #         failed_improvements = 0
        
    
    def initial_clustering(self): ################## also check following pointings
        cluster = Cluster(self.query._pointings[self.indices[0]], self)
        for index in self.indices[1:]:
            if cluster.should_add_pointing(self.query._pointings[index]):
                cluster.add_pointing(self.query._pointings[index])
            else:
                cluster.finalize_cluster()
                self.clusters[cluster.num_pointings].append(cluster)
                cluster = Cluster(self.query._pointings[index], self)
        cluster.finalize_cluster()
        self.clusters[cluster.num_pointings].append(cluster)
        
    def attempt_improvement(self):
        c1 = self.find_suboptimal_cluster()
        c2 = self.find_close_suboptimal_cluster(c1)
        found_path, recluster_indices = self.find_cluster_path(c1,c2)
        if not found_path:
            return False
        else:
            self.recluster_pointings(recluster_indices, c1)
        print()
        print("Done!")
        print()
        for key, value in self.potential_clusters1.items():
            for c in value:
                print(f"{c.indices}, {c.avg_distance}")
        print()
        for key, value in self.potential_clusters2.items():
            for c in value:
                print(f"{c.indices}, {c.avg_distance}")
            
        
    
    def find_suboptimal_cluster(self):
        size_weights = np.array([len(self.clusters[i]) / i**2
                                 for i in range(self.query._suboptimal_cluster_size_range[0],
                                                self.query._suboptimal_cluster_size_range[1] + 1)])
        size = choose_random_weighted_interval(size_weights) + 1
        index = np.random.randint(len(self.clusters[size]))
        cluster = self.clusters[size].pop(index)
        self.potential_clusters1[size].append(cluster)
        return cluster
    
    def find_close_suboptimal_cluster(self, cluster): ################################# maximum distance?
        clusters = []
        cluster_size_indices = [0]
        for i in range(self.query._close_suboptimal_cluster_size_range[0], 
                       self.query._close_suboptimal_cluster_size_range[1] + 1):
            clusters.extend(self.clusters[i])
            cluster_size_indices.append( len(self.clusters[i]) + cluster_size_indices[i-1] )
        cluster_weights = np.zeros(len(clusters))
        for i, c in enumerate(clusters):
            d = self.query._distances[cluster.indices,:][:,c.indices]
            cluster_weights[i] = np.exp( -3.*np.amin(d) / self.query._max_distance ) ##### weights
        for i in range(1, self.query._close_suboptimal_cluster_size_range[1] + 1):
            cluster_weights[cluster_size_indices[i-1]:cluster_size_indices[i]] /= i
        index = choose_random_weighted_interval( cluster_weights )
        for size, csi in enumerate(cluster_size_indices):
            if not index >= csi:
                break
        true_index = index - cluster_size_indices[size-1]
        
        cluster2 = self.clusters[size].pop(true_index)
        self.potential_clusters1[size].append(cluster2)
        return cluster2
        
        
        
        
    
    def find_cluster_path(self, cluster1, cluster2):
        indices_in = set(cluster1.indices)
        indices_to = set(cluster2.indices)
        arrived = False
        while not arrived:
        
            indices_out = np.array([i for i in self.indices if i not in indices_in])
            pointing1, pointing2 = cluster1.find_closest_pointings(cluster2)
            print(pointing1.index, pointing2.index)

            close_indices = self.find_closest_points(pointing1.index, indices_out, 4) #### use 3?
            print(close_indices)
            
            angle = np.vectorize(lambda p: pointing1.angle_between_three_pointings(pointing2, p, self.query._angle_weight,
                                                                                   self.query._time_weight))
            
            angles_filtered = angle(self.query._pointings[close_indices])
            
            distances_filtered = self.query._distances[pointing1.index, close_indices]
            
            close_indices_weights = ((distances_filtered < self.query._max_distance) 
                                     * np.exp(-4. * distances_filtered / self.query._max_distance) ##### weights
                                     * np.cos(angles_filtered/2.)**6)
            
            print(close_indices_weights)
            
            if np.amax(np.cos(angles_filtered)) < 0.:
                print("Cannot go forwards!")
                return False, None
            
            index = close_indices[choose_random_weighted_interval(close_indices_weights)]
            if index is None:
                return False, None
            cluster1 = self.query._pointings[index].cluster
            indices_in = indices_in | set(cluster1.indices)
            if index in indices_to:
                arrived = True
            else:
                self.clusters[cluster1.num_pointings].remove(cluster1)
                self.potential_clusters1[cluster1.num_pointings].append(cluster1)
        
        return True,indices_in
        
        
    def recluster_pointings(self, recluster_indices, start_cluster):
        
        index = start_cluster.indices[choose_random_weighted_interval( 
                                      np.exp(np.average(self.query._distances[start_cluster.indices,:][:,list(recluster_indices)], axis=1) ##### weights
                                             * 4. / self.query._max_distance) )]
        
        cluster = Cluster(self.query._pointings[index], self)
        already_clustered = {index}
        not_clustered = np.array([i for i in recluster_indices if i not in already_clustered])
        
        print()
        print("Recluster Pointings!")
        print(index)
        print(already_clustered)
        print(not_clustered)
        
        
        while not_clustered.size != 0: # check case of only one left
            print()
            print("New Cluster!")

            cluster_done = False
            
            while not cluster_done:
                print()
                print("Continue Cluster!")
                print(index)
                print(already_clustered)
                print(not_clustered)
                
                close_indices = self.find_closest_points(index, not_clustered, 2) ############# use 3?
                
                internal_distances = np.average(self.query._distances[cluster.indices,:][:,close_indices], axis=0)
                external_distances = np.average(self.query._distances[close_indices,:][:,close_indices], axis=0)
                
                close_indices_weights = np.exp( (4.*external_distances - 4.*internal_distances) / self.query._max_distance ) ##### weights
                
                print(close_indices)
                print(internal_distances)
                print(external_distances)
                print(close_indices_weights)
                
                added = False
                while not added:
                    random_index = choose_random_weighted_interval(close_indices_weights)
                    print(random_index)
                    if random_index is None:
                        cluster_done = True
                        break
                    
                    index = close_indices[random_index]
                    print(index)
                    if cluster.should_add_pointing(self.query._pointings[index]):
                        added = True
                        cluster.add_pointing(self.query._pointings[index])
                        already_clustered.add(index)
                        not_clustered = np.array([i for i in recluster_indices if i not in already_clustered])
                        
                        print("Added!")
                        
                        if cluster.num_pointings == self.query._cluster_size_range[1]:
                            cluster_done = True
                        
                            
                            
                    else:
                        print("Found None!")
                        close_indices_weights[random_index] = 0.
                
                if not_clustered.size == 0:
                    self.potential_clusters2[cluster.num_pointings].append(cluster)
                    break
                
                if cluster_done:
                    self.potential_clusters2[cluster.num_pointings].append(cluster)
                    
                    weights = np.exp(3*external_distances / self.query._max_distance)
                    if added:
                        weights[random_index] = 0.
                    index = close_indices[choose_random_weighted_interval(weights)]
                    cluster = Cluster(self.query._pointings[index], self)
                    already_clustered.add(index)
                    not_clustered = np.array([i for i in recluster_indices if i not in already_clustered])
                    break
                    
                else:
                    index = cluster.indices[np.random.randint(cluster.num_pointings)]
                    
            
            
            
            
            
            
            
            
        
    def find_closest_points(self, start_index, search_indices, number_multiplier):
        num_points = min(len(search_indices), self.query._cluster_size_range[1]*number_multiplier)
        sortable_pointings = [i for i in range(num_points)]
        distances = self.query._distances[start_index,search_indices]
        return search_indices[np.argpartition(distances, sortable_pointings)[:num_points]]

    def calc_cluster_cost(self, clusters): 
        pass
    
    def has_suboptimal_clusters(self):
        s = 0
        for i in range(self.query._suboptimal_cluster_size_range[0], 
                       self.query._suboptimal_cluster_size_range[1] + 1):
            s += len(self.clusters[i])
        if s >= 2:
            return True
        elif s >= 1:
            for i in range(self.query._suboptimal_cluster_size_range[1] + 1,
                           self.query._close_suboptimal_cluster_size_range[1] + 1):
                s += len(self.clusters[i])
            if s >= 2:
                return True
        return False
            
        






In [318]:
searchquerry = SearchQuery(object_name="Cyg X-1", resultmax=0)
cat = IntegralQuery(searchquerry)
f = Filter(SCW_TYPE="POINTING")
scw_ids = cat.apply_filter(f,True)



In [319]:
test = ClusteredQuery(scw_ids[:4], 1, 1, 2.8, cluster_size_preference_threshold = (1,1))

[3], 0.0
[0, 1, 2], 0.08185850791460011
True
3 0
[0 2 1]
[0.18432071 0.16085585 0.16064026]

Recluster Pointings!
3
{3}
[0 1 2]

New Cluster!

Continue Cluster!
3
{3}
[0 1 2]
[0 2 1]
[1.18375462 1.27787504 1.2789768 ]
[0.07321653 0.0460936  0.04440689]
[0.20464408 0.17209859 0.17141439]
0
0
Added!

Continue Cluster!
0
{0, 3}
[1 2]
[1 2]
[0.69313577 0.69511495]
[0.01296296 0.01296296]
[0.37844844 0.37737993]
0
1
Added!

Continue Cluster!
1
{0, 1, 3}
[2]
[2]
[0.47205194]
[0.]
[0.50948243]
0
2
Added!

Done!

[3], 0.0
[0, 1, 2], 0.08185850791460011

[3, 0, 1, 2], 0.6643636643410807


In [361]:
test = ClusteredQuery(scw_ids[:200], 1, 1, 2.8, cluster_size_preference_threshold = (100,100))
#print(test._region_indices)

[24, 25], 1.985149479153767
[57, 58], 0.0407060185186765
[5, 6, 7], 1.831296171661494
[12, 13, 14], 1.8267541396306226
[30, 31, 32], 1.823811338951977
[37, 38, 39], 1.8256945060766447
[44, 45, 46], 1.8223984167765328
[8, 9, 10, 11], 1.2999530762680331
[15, 16, 17, 18], 1.2981536904708968
[26, 27, 28, 29], 1.296709040564737
[33, 34, 35, 36], 1.2955873080209874
[40, 41, 42, 43], 1.2937298578418075
[0, 1, 2, 3, 4], 0.7751899214493297
[19, 20, 21, 22, 23], 1.6728204843475751
[47, 48, 49, 50, 51], 1.0296703154623565
[52, 53, 54, 55, 56], 0.7528444914308206
True
57 46
[56 55 54 48 41 34 27 53 51 52 16 50 49 42 47  9 40 35 33 28]
[2.48517331e-01 2.22375725e-01 2.38179932e-01 1.21367886e-01
 9.72960845e-02 7.69975414e-02 6.01226389e-02 2.73961814e-05
 3.16031905e-05 3.15217658e-05 3.87254692e-02 3.71343789e-03
 3.74974079e-03 3.93296203e-03 5.54180070e-02 3.01427079e-02
 5.12837401e-02 4.01689835e-03 4.54442530e-02 4.02479408e-03]
52 46
[51 50 49 42 35 28 46 39 32 17 25 21 10 48 41 14  4 34  3

TypeError: '>=' not supported between instances of 'NoneType' and 'int'