In [110]:
import time, random, numpy as np, torch
from typing import List, Union, Tuple, Optional, Dict, Callable
import networkx as nx
from spectraclass.data.base import DataManager
import torch_geometric
from torchmetrics.functional import *
import hvplot.xarray
import holoviews as hv
import panel as pn
from torch_geometric.data import Data
import xarray as xa

In [125]:
# Optional square sub-tile for quick experiments (e.g. tile_size = 10); tile_size <= 0 uses the full scene.
tile_size = -1
tile_offset = 200   # row/column index of the tile's upper-left corner
# Named tile_start/tile_end (not t0/t1) so later timing code cannot silently clobber them.
tile_start, tile_end = tile_offset, tile_offset + tile_size

dm: DataManager = DataManager.initialize( "pavia", 'aviris' )
project_data: xa.Dataset = dm.loadCurrentProject( "main" )
reduced_spectral_data: xa.DataArray = project_data['reduction']
raw_spectral_data: xa.DataArray = project_data['raw']
xaClassmap: xa.DataArray = dm.getClassMap()
class_map: np.ndarray = xaClassmap.values
nfeatures = reduced_spectral_data.shape[1]
# NOTE(review): assumes the rows of the reduction are pixels in row-major order of the class map -- confirm.
node_data_map: np.ndarray = reduced_spectral_data.values.reshape( list(class_map.shape) + [ nfeatures ] )
if tile_size > 0:
    # Tile the DataArray too, so the class plot drawn later matches the tiled class_map.
    xaClassmap = xaClassmap[ tile_start:tile_end, tile_start:tile_end ]
    class_map = xaClassmap.values
    node_data_map = node_data_map[ tile_start:tile_end, tile_start:tile_end, : ]
gshape = list(class_map.shape)
# One row of reduced features per pixel (graph node), in flattened row-major pixel order.
node_data: torch.Tensor = torch.from_numpy( node_data_map.reshape( [ gshape[0]*gshape[1], nfeatures ] ) )

# Exclude label 0 explicitly (assumed to mark unlabelled pixels -- confirm for other datasets).
# The previous `np.unique(...)[1:]` silently dropped a real class whenever 0 was absent, e.g. in a tile.
labels_present = np.unique( class_map )
classes = labels_present[ labels_present != 0 ]
class_sizes = [ np.count_nonzero(class_map == iC) for iC in classes ]
nclasses: int = len(classes)
print( f" Classes({nclasses}) = {classes}, sizes = {class_sizes} ")

Loading config files: ['pavia.py'] from dir /Users/tpmaxwel/.spectraclass/config/aviris
Opening log file:  '/Users/tpmaxwel/.spectraclass/logging/aviris/pavia.41047.log'

Reading class file: /Users/tpmaxwel/GDrive/Tom/Data/Aviris/Pavia/Pavia_gt.mat

Reading variable 'paviaU_gt' from Matlab dataset '/Users/tpmaxwel/GDrive/Tom/Data/Aviris/Pavia/Pavia_gt.mat': b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Fri May 20 18:25:52 2011'
 Classes(9) = [1 2 3 4 5 6 7 8 9], sizes = [6631, 18649, 2099, 3064, 1345, 5029, 1330, 3682, 947] 


In [112]:
# Pixel-adjacency graph for the (rows, cols) image: `grid` is the (edge_index, positions) pair
# returned by torch_geometric.utils.grid (neighbourhood definition is PyG's -- confirm if it matters).
grid = torch_geometric.utils.grid( *gshape )
grid_edges, grid_pos = grid
grid_edges: torch.Tensor

# Spectra at the two endpoints of every edge: row 0 of grid_edges holds sources, row 1 targets.
node_spectra = [ node_data[ endpoint_ids ] for endpoint_ids in ( grid_edges[0], grid_edges[1] ) ]

# Edge weight = cosine dissimilarity of the endpoint spectra; abs() guards against tiny negative round-off.
similarity = cosine_similarity( node_spectra[0], node_spectra[1], 'none' )
weights: torch.Tensor = ( 1.0 - similarity ).abs()

In [113]:
num_class_exemplars = 3   # labelled "training" pixels drawn per class
exemplar_seed = 0         # fixed seed so the exemplar draw (and everything downstream) is reproducible
exemplar_rng = np.random.default_rng( exemplar_seed )

# Flattened labels share the row-major pixel ordering used for graph node ids.
class_data: np.ndarray = class_map.flatten()
class_masks: List[np.ndarray] = [ (class_data == iC) for iC in classes ]
# Flat pixel (node) indices belonging to each class, in the same order as `classes`.
class_indices: List[np.ndarray] = [ np.flatnonzero( class_mask ) for class_mask in class_masks ]
# Distinct exemplar pixels per class; raises ValueError if a class has fewer than num_class_exemplars pixels.
train_class_indices: List[np.ndarray] = [
    exemplar_rng.choice( member_indices, size=num_class_exemplars, replace=False )
    for member_indices in class_indices
]


In [114]:
# Augment the pixel graph with extra nodes:
#   [0, base_index)                          image pixels (grid nodes)
#   [base_index, base_index + nclasses)      one "training" node per class, joined to that class's exemplars
#   base_index + nclasses                    a root node joined to every training node
# All added edges get weight 0 so the spanning tree always keeps them in preference to grid edges.
#
# Read the pristine grid positions from `grid` (not `grid_pos`, which this cell overwrites) so that
# re-running the cell does not shift every node id and append duplicate rows.
base_pos: torch.Tensor = grid[1]
base_index = base_pos.shape[0]
enodes = []
tnode_indices = []
tnode_pos = []
for (iC, iNodes) in enumerate( train_class_indices ):
    tnode_index = base_index + iC
    exemplar_nodes = torch.from_numpy( iNodes ).long()
    # Size from the actual exemplar count rather than assuming num_class_exemplars.
    tNodes = torch.full( [ exemplar_nodes.numel() ], tnode_index, dtype=torch.long )
    enodes.append( torch.vstack( [ tNodes, exemplar_nodes ] ) )
    tnode_indices.append( tnode_index )
    tnode_pos.append( ( -1.0, float(iC) ) )   # training nodes drawn in a column left of the image

# Derived from base_index so an empty class list does not raise IndexError.
rnode_index = base_index + len( tnode_indices )
rNodes = torch.full( [ len(tnode_indices) ], rnode_index, dtype=torch.long )
enodes.append( torch.vstack( [ rNodes, torch.tensor( tnode_indices, dtype=torch.long ) ] ) )
augmented_edges = torch.hstack( [ grid_edges ] + enodes )
# Match the dtype/device of the spectral weights so the concatenation needs no type promotion.
additional_weights = weights.new_zeros( augmented_edges.shape[1] - grid_edges.shape[1] )
augmented_weights = torch.cat( [ weights, additional_weights ] )

# One position row per node, including the root (placed left of the training-node column),
# so the row count matches the node count rnode_index + 1.
root_pos = ( -2.0, ( len(tnode_indices) - 1 ) / 2.0 )
extra_pos = torch.tensor( tnode_pos + [ root_pos ], dtype=base_pos.dtype ).reshape( -1, 2 )
grid_pos = torch.vstack( [ base_pos, extra_pos ] )

In [115]:
from networkx.algorithms.tree.mst import minimum_spanning_tree
mst_start = time.time()   # dedicated name so the tile offset `t0` is not clobbered
num_nodes = rnode_index + 1

# Kept for downstream use. torch.arange replaces the deprecated torch.range and has the same
# values (0..rnode_index) and the same float dtype. Per-edge distances go in edge_attr as a tensor.
data = Data( x=torch.arange( num_nodes, dtype=torch.float ), edge_index=augmented_edges, edge_attr=augmented_weights )

# Build the weighted undirected graph directly instead of via to_networkx:
#  * to_networkx copies edge attributes only when they are named in `edge_attrs`. The previous call
#    therefore produced edges without a 'distance' key, and the MST treated every edge as weight 1.
#  * to_networkx(to_undirected=True) keeps only one triangle of the adjacency. The training/root edges
#    are stored in one direction only, so they could be dropped (depends on the PyG version).
nxgraph = nx.Graph()
nxgraph.add_nodes_from( range( num_nodes ) )
edge_sources, edge_targets = augmented_edges.tolist()
nxgraph.add_weighted_edges_from(
    ( (u, v, w) for u, v, w in zip( edge_sources, edge_targets, augmented_weights.tolist() ) if u != v ),
    weight='distance' )   # self-loops never belong to a spanning tree, so they are skipped

spanning_forest: nx.Graph = minimum_spanning_tree( nxgraph, weight='distance', algorithm='prim' )
# Removing the root splits the tree into subtrees, each normally hanging off one training node.
spanning_forest.remove_node( rnode_index )
print( f"Computed spanning forest in {time.time()-mst_start} sec" )

  data = Data( x=torch.range(0,rnode_index), edge_index=augmented_edges, edge_attr=dict( distance=augmented_weights ) )


Computed spanning forest in 5.325803995132446 sec


In [123]:
from networkx.algorithms.components import connected_components
cc_start = time.time()   # dedicated name so the tile offset `t0` is not clobbered
# Node ids of each tree left after removing the root. The iteration order of the components is
# arbitrary and does NOT follow class order.
class_trees: List[np.ndarray] = [ np.array( list( component ) ) for component in connected_components( spanning_forest ) ]
print( f"Computed connected_components in {time.time()-cc_start} sec" )

Computed connected_components in 0.2564969062805176 sec


In [128]:
from spectraclass.gui.spatial.image import toXA

num_pixels = gshape[0] * gshape[1]
index_thresh = num_pixels   # node ids >= this are training/root nodes, not pixels
seg_data = np.zeros( num_pixels, dtype=np.int64 )   # 0 = pixel not in any tree rooted at a training node
for class_nodes in class_trees:
    usable_class_nodes = class_nodes[ class_nodes < index_thresh ]
    # Training node k (id index_thresh + k) represents classes[k].
    seed_nodes = class_nodes[ (class_nodes >= index_thresh) & (class_nodes < index_thresh + nclasses) ]
    if seed_nodes.size == 0:
        continue   # tree contains no training node: leave its pixels unlabelled
    # Label by the class whose training node is in this tree, not by the arbitrary component order,
    # so segmentation values are comparable with the class map. If zero-weight ties merged several
    # training nodes into one tree, the lowest id wins (NOTE(review): ambiguous case).
    seg_data[ usable_class_nodes ] = classes[ seed_nodes.min() - index_thresh ]

seg_map: xa.DataArray = toXA( "Segmentation", seg_data.reshape(gshape) )
seg_plot = seg_map.hvplot.image( cmap='Category20' )
class_plot = xaClassmap.hvplot.image( cmap='Category20' )
seg_plot