In [255]:
import time, random, numpy as np, torch
from typing import List, Union, Tuple, Optional, Dict, Callable
import networkx as nx
from spectraclass.data.base import DataManager
import torch_geometric
from torchmetrics.functional import cosine_similarity
from sklearn.metrics import mean_squared_error
import hvplot.xarray
import holoviews as hv
import panel as pn
from torch_geometric.data import Data
import xarray as xa

def metric( x: torch.Tensor, y: torch.Tensor, type: str ):
    if type == "mse":
        diff = x - y
        sum_squared_error = torch.sum( diff * diff, dim=1 )
        return  torch.sqrt( sum_squared_error / x.shape[1] )
    elif type == "cosine":
        return 1 - cosine_similarity( x, y, "none" )
    else:
        raise Exception( f"Unknown metric type: {type}")

def npmetric( x: np.ndarray, y: np.ndarray, type: str = "mse" ):
    if type == "mse":
        sqdiff  = np.average( (x - y) ** 2, axis=1 )
        return np.sqrt(sqdiff)
    else:
        raise Exception( f"Unknown metric type: {type}")

In [295]:
tile_size = -1 #10
tile_offset = 200
t0, t1 = tile_offset, tile_offset+tile_size
dm: DataManager = DataManager.initialize( "indianPines", 'aviris' )
project_data: xa.Dataset = dm.loadCurrentProject( "main" )
reduced_spectral_data: xa.DataArray = project_data['reduction']
raw_spectral_data: xa.DataArray = project_data['raw']
xaClassmap: xa.DataArray = dm.getClassMap()
class_map: np.ndarray = xaClassmap.values
nodata_mask = class_map == 0
nfeatures = reduced_spectral_data.shape[1]
node_data_map: np.ndarray = reduced_spectral_data.values.reshape( list(class_map.shape) + [ nfeatures ] )
if tile_size > 0:
    class_map = class_map[t0:t1,t0:t1]
    node_data_map = node_data_map[t0:t1,t0:t1,:]
gshape = list(class_map.shape)
node_data: torch.tensor = torch.from_numpy( node_data_map.reshape( [ gshape[0]*gshape[1], nfeatures ] ) )
all_classes = np.unique( class_map )
classes = np.delete( all_classes, np.where(all_classes <= 0) )
class_sizes = [ np.count_nonzero(class_map == iC) for iC in classes ]
nclasses: int = len(classes)
print( f" Classes({nclasses}) = {classes}, sizes = {class_sizes} ")

Loading config files: ['indianPines.py'] from dir /Users/tpmaxwel/.spectraclass/config/aviris
Opening log file:  '/Users/tpmaxwel/.spectraclass/logging/aviris/indianPines.41047.log'

Reading class file: /Users/tpmaxwel/GDrive/Tom/Data/Aviris/Pavia/Pavia_gt.mat

Reading variable 'paviaU_gt' from Matlab dataset '/Users/tpmaxwel/GDrive/Tom/Data/Aviris/Pavia/Pavia_gt.mat': b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Fri May 20 18:25:52 2011'
 Classes(9) = [1 2 3 4 5 6 7 8 9], sizes = [6631, 18649, 2099, 3064, 1345, 5029, 1330, 3682, 947] 


In [296]:
grid = torch_geometric.utils.grid( *gshape )
grid_edges: torch.Tensor = grid[0]
grid_pos = grid[1]
node_spectra = [ node_data[ grid_edges[i] ] for i in (0,1) ]
# weights: torch.Tensor = torch.abs( 1.0 - cosine_similarity( *node_spectra, 'none' ) )
weights: torch.Tensor = metric( *node_spectra, type="mse" )

In [297]:
num_class_exemplars = 5
class_data: np.ndarray = class_map.flatten()
test_mask: np.ndarray = (class_data > 0)
class_masks: Dict[int,np.ndarray] = { iC: (class_data == iC) for iC in classes }
class_indices: Dict[int,np.ndarray] = { iC: np.argwhere(class_mask).flatten() for iC, class_mask in class_masks.items() }
train_class_indices: Dict[int,np.ndarray] = { iC: np.random.choice(class_indices[iC], size=num_class_exemplars, replace=False) for iC in classes }
for train_indices in train_class_indices.values(): test_mask[train_indices] = 0

In [298]:
# from torch_geometric.transforms import KNNGraph
# nneighbors = 5
# t0 = time.time()
# node_data: torch.tensor = torch.from_numpy( reduced_spectral_data.values )
# knn_transform = KNNGraph(nneighbors)
# node_indices = np.array( range(node_data.shape[0]) )
# spectral_grid: Data = knn_transform( Data( x=node_indices, pos=node_data ) )
# spectral_graph_indices = spectral_grid.edge_index
# spectral_nodes = [ node_data[ spectral_graph_indices[i] ] for i in (0,1) ]
# spectral_weights: torch.Tensor = metric( *spectral_nodes, type="mse" )
# print( f"KNNGraph: computed spectral KNN graph with weights in {time.time()-t0} sec" )

In [299]:
from pynndescent import NNDescent
nneighbors = 5
t0 = time.time()
node_data: np.ndarray = reduced_spectral_data.values
n_trees = 5 + int(round((node_data.shape[0]) ** 0.5 / 20.0))
n_iters = max(5, 2 * int(round(np.log2(node_data.shape[0]))))
kwargs = dict(n_trees=n_trees, n_iters=n_iters, n_neighbors=nneighbors, max_candidates=60, verbose=True, metric="euclidean" )
nnd =  NNDescent( node_data, **kwargs )
dI: np.ndarray = nnd.neighbor_graph[0]
print( f"NNDescent: computed spectral KNN graph in {time.time()-t0} sec, dI shape = {dI.shape}" )

Mon Aug 30 14:06:51 2021 Building RP forest with 28 trees
Mon Aug 30 14:06:52 2021 NN descent for 36 iterations
	 1  /  36
	 2  /  36
	Stopping threshold met -- exiting after 2 iterations
NNDescent: computed spectral KNN graph in 4.254267930984497 sec, dI shape = (207400, 5)


In [300]:
t0 = time.time()
sp_node_indices = np.array(range(dI.shape[0]))
sp_edges_list = [ np.vstack( [sp_node_indices] * nneighbors ).transpose().flatten(), dI.flatten() ]
sp_nodes = [ node_data[ node_indices ] for node_indices in sp_edges_list ]
sp_weights: np.ndarray = npmetric( *sp_nodes, type="mse" )
print( f"Computed spectral graph weights and edges in {time.time()-t0} sec" )

Computed spectral graph weights and edges in 0.33932018280029297 sec


In [301]:
from torch_geometric.utils import k_hop_subgraph
nnodes = node_data.shape[0]
sp_edges = torch.tensor( np.vstack( sp_edges_list ) )
zero_mask = nodata_mask.flatten()

for iClass in classes:
    train_mask = np.full( [nnodes], False )
    train_nodes = torch.tensor( train_class_indices[iClass] )
    print( f"Class[{iClass}]: " )
    for ihop in range( 0, 4 ):
        if ihop == 0:
            gnodes = train_nodes
        else:
            (gnodes, edges, node_map, node_mask) = k_hop_subgraph( train_nodes, ihop, sp_edges, num_nodes=nnodes, flow="target_to_source" )
        train_mask[ gnodes.numpy() ] = True
        n_nodata = np.count_nonzero( train_mask & zero_mask )
        train_mask[ zero_mask ] = False
        n_gnodes = np.count_nonzero( train_mask )
        gclass_data = class_data[ train_mask ]
        n_correct = np.count_nonzero( gclass_data == iClass )
        print(f"   * Hops: {ihop}, nodes: {n_gnodes}, nodata: {n_nodata}, correct: {n_correct}, accuracy: {n_correct/n_gnodes}")

Class[1]: 
   * Hops: 0, nodes: 5, nodata: 0, correct: 5, accuracy: 1.0
   * Hops: 1, nodes: 14, nodata: 11, correct: 14, accuracy: 1.0
   * Hops: 2, nodes: 41, nodata: 38, correct: 39, accuracy: 0.9512195121951219
   * Hops: 3, nodes: 94, nodata: 119, correct: 88, accuracy: 0.9361702127659575
Class[2]: 
   * Hops: 0, nodes: 5, nodata: 0, correct: 5, accuracy: 1.0
   * Hops: 1, nodes: 19, nodata: 6, correct: 18, accuracy: 0.9473684210526315
   * Hops: 2, nodes: 49, nodata: 31, correct: 48, accuracy: 0.9795918367346939
   * Hops: 3, nodes: 131, nodata: 76, correct: 127, accuracy: 0.9694656488549618
Class[3]: 
   * Hops: 0, nodes: 5, nodata: 0, correct: 5, accuracy: 1.0
   * Hops: 1, nodes: 16, nodata: 9, correct: 10, accuracy: 0.625
   * Hops: 2, nodes: 43, nodata: 42, correct: 22, accuracy: 0.5116279069767442
   * Hops: 3, nodes: 123, nodata: 125, correct: 51, accuracy: 0.4146341463414634
Class[4]: 
   * Hops: 0, nodes: 5, nodata: 0, correct: 5, accuracy: 1.0
   * Hops: 1, nodes: 12, n

In [None]:
# t0 = time.time()
# G = nx.Graph()
# sp_edge_list = [ tup for tup in zip(  sp_edges[0].tolist(), sp_edges[1].tolist(), sp_weights.tolist()  ) ]
# G.add_nodes_from( sp_node_indices )
# G.add_weighted_edges_from( sp_edge_list, weight="distance" )
# print( f"Construted networkx spectral graph in {time.time()-t0} sec" )

In [288]:
base_index = grid_pos.shape[0]
enodes = []
tnode_indices = []
tnode_classes = {}
tnode_pos = []
for (iC, iNodes) in train_class_indices.items():
    tnode_index = base_index+iC
    tNodes = torch.full( [num_class_exemplars], tnode_index )
    enodes.append( torch.vstack( [ tNodes, torch.from_numpy(iNodes) ] ) )
    tnode_indices.append( tnode_index )
    tnode_pos.append( ( -1.0, float(iC) ) )
    tnode_classes[ tnode_index ] = iC

rnode_index = tnode_indices[-1]+1
rNodes = torch.full( [len(tnode_indices)], rnode_index )
enodes.append( torch.vstack( [ rNodes, torch.tensor( tnode_indices ) ] ) )
augmented_edges = torch.hstack( [grid_edges] + enodes )
additional_weights = torch.zeros( [ augmented_edges.shape[1]-grid_edges.shape[1] ] )
augmented_weights = torch.cat( [ weights, additional_weights ] )
grid_pos = torch.vstack( [ grid_pos, torch.tensor( tnode_pos ) ] )

In [224]:
from networkx.algorithms.tree.mst import minimum_spanning_tree
t0 = time.time()
data = Data( x=torch.range(0,rnode_index), edge_index=augmented_edges, edge_attr=dict( distance=augmented_weights ) )
nxgraph = torch_geometric.utils.to_networkx( data, to_undirected=True )
spanning_forest: nx.Graph = minimum_spanning_tree( nxgraph, weight='distance', algorithm='prim' )
spanning_forest.remove_node( rnode_index )
print( f"Computed spanning forest in {time.time()-t0} sec" )

  data = Data( x=torch.range(0,rnode_index), edge_index=augmented_edges, edge_attr=dict( distance=augmented_weights ) )


Computed spanning forest in 9.436674118041992 sec


In [225]:
from networkx.algorithms.components import connected_components
t0 = time.time()
tgen = connected_components( spanning_forest )
class_segments = {}
for c in nx.connected_components(spanning_forest):
    seg_nodes = np.array(list(c))
    tnode_index = seg_nodes.max()
    if tnode_index in tnode_classes:
        iC = tnode_classes[ tnode_index ]
        class_segments[ iC ] = seg_nodes[ seg_nodes != tnode_index ]
print( f"Computed connected_components in {time.time()-t0} sec" )
#nx.draw_networkx( spanning_forest, node_size=20, node_shape="o", labels=None )

Computed connected_components in 0.31702589988708496 sec


In [226]:
from spectraclass.gui.spatial.image import toXA

seg_data = np.full( gshape, 0 ).flatten()
index_thresh = gshape[0]*gshape[1]
for iC, class_nodes in class_segments.items():
    seg_data[ class_nodes ] = iC
seg_data[ nodata_mask.flatten() ] = 0
seg_map: xa.DataArray = toXA( "Segmentation", seg_data.reshape(gshape) )

correct = np.count_nonzero( seg_data[test_mask] == class_data[test_mask] )
acc = correct / np.count_nonzero(test_mask)
print( f"Classification accuracy (#exemplars/class={num_class_exemplars}): {acc}")

Classification accuracy (#exemplars/class=5): 0.6569469471812034


In [228]:
seg_plot = seg_map.hvplot.image( cmap='Category20' )
class_plot = xaClassmap.hvplot.image( cmap='Category20' )
# pn.Column( [ class_plot, seg_plot ] )
# class_plot
seg_plot