In [3]:
import random, numpy as np, torch
from typing import List, Union, Tuple, Optional, Dict, Callable
from pynndescent import NNDescent
from spectraclass.data.base import DataManager
from spectraclass.graph.manager import ActivationFlow, ActivationFlowManager, afm
import hvplot.xarray
import holoviews as hv
from torch_geometric.transforms import KNNGraph
from spectraclass.learn.gcn import GCN
from spectraclass.learn.mlp import MLP
import panel as pn
from torch_geometric.data import Data
import xarray as xa

In [4]:
view_band = 10
use_edge_weights = False
use_nndescent = False

knn_transform = KNNGraph(4)
dm: DataManager = DataManager.initialize( "indianPines", 'aviris' )
project_data: xa.Dataset = dm.loadCurrentProject( "main" )
reduced_spectral_data: xa.DataArray = project_data['reduction']
raw_spectral_data: xa.DataArray = project_data['raw']
class_map: xa.DataArray = dm.getClassMap()
node_data: torch.tensor = torch.from_numpy( reduced_spectral_data.values )

if use_nndescent:
    flow: ActivationFlow = afm().getActivationFlow()
    graph: NNDescent = flow.getGraph()
    D: np.ndarray = graph.neighbor_graph[1].flatten()
    edge_weights = GCN.calc_edge_weights( D ) if use_edge_weights else None
    edge_attr: torch.tensor  = torch.from_numpy( D.reshape( D.size, 1 ) )
    spectral_graph_indices = flow.getEdgeIndex()
else:
    spectral_grid: Data = knn_transform( Data( pos=node_data ) )
    spectral_graph_indices = spectral_grid.edge_index
    edge_weights = None

[ny,nx] = raw_spectral_data.shape[1:]
I: np.array = np.array( range( ny*nx ) )
X, Y = I % nx, I // nx
pos: torch.tensor = torch.from_numpy( np.vstack( [Y,X] ).transpose() )
spatial_grid: Data = knn_transform( Data( pos=pos ) )

class_data: torch.tensor = torch.from_numpy( class_map.values.flatten().astype( np.long ) ) - 1
edge_index: torch.tensor = torch.cat( [ spectral_graph_indices, spatial_grid.edge_index], dim=1 )
graph_data = Data( x=node_data, y=class_data, pos=pos, edge_index=edge_index, edge_weights=edge_weights )
nfeatures = graph_data.num_node_features

print( f"raw_spectral_data shape = {raw_spectral_data.shape}")
print( f"reduced_spectral_data shape = {reduced_spectral_data.shape}")
print( f"num_nodes = {graph_data.num_nodes}")
print( f"num_edges = {graph_data.num_edges}")
print( f"num_node_features = {nfeatures}")
print( f"num_edge_features = {graph_data.num_edge_features}")
print( f"contains_isolated_nodes = {graph_data.contains_isolated_nodes()}")
print( f"contains_self_loops = {graph_data.contains_self_loops()}")
print( f"is_directed = {graph_data.is_directed()}")

Loading config files: ['indianPines.py'] from dir /Users/tpmaxwel/.spectraclass/config/aviris
Opening log file:  '/Users/tpmaxwel/.spectraclass/logging/aviris/indianPines.19923.log'

Reading class file: /Users/tpmaxwel/GDrive/Tom/Data/Aviris/IndianPines/documentation/Site3_Project_and_Ground_Reference_Files/19920612_AVIRIS_IndianPine_Site3_gr.tif

raw_spectral_data shape = (220, 145, 145)
reduced_spectral_data shape = (21025, 32)
num_nodes = 21025
num_edges = 168200
num_node_features = 32
num_edge_features = 0
contains_isolated_nodes = False
contains_self_loops = False
is_directed = True


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  class_data: torch.tensor = torch.from_numpy( class_map.values.flatten().astype( np.long ) ) - 1


In [5]:
# band_image_data: np.ndarray = reduced_spectral_data.values[:,view_band].reshape( [1] + list(raw_spectral_data.shape[1:]) )
# band_image: xa.DataArray = class_map.copy( True, band_image_data )
# class_plot = class_map.hvplot.image( cmap='Category20', clim=(0, 20) )
# band_plot = band_image.hvplot.image( cmap='jet', clim=( -1.5, 1.5 ) )
# pn.Row( class_plot, band_plot  ).show("Indian Pines")

In [6]:
num_class_exemplars = 5
class_data: np.ndarray = class_map.values.flatten()
nclasses: int = class_map.values.max()
class_masks: List[np.ndarray] = [ (class_data == (iC+1) ) for iC in range(nclasses) ]
test_mask: np.ndarray = (class_data > 0)
nodata_mask = np.logical_not( test_mask )
class_indices = [ np.argwhere(class_masks[iC]).flatten() for iC in range(nclasses) ]
train_class_indices = [  np.random.choice(class_indices[iC], size=num_class_exemplars, replace=False )  for iC in range(nclasses)  ]
train_indices = np.hstack( train_class_indices )
train_mask = np.full( [ node_data.shape[0] ], False, dtype=bool )
train_mask[ train_indices ] = True
test_mask[ train_indices ] = False

graph_data['train_mask'] = torch.from_numpy( train_mask )
graph_data['test_mask'] = torch.from_numpy( test_mask )
graph_data['nodata_mask'] = torch.from_numpy( nodata_mask )
graph_data

Data(edge_index=[2, 168200], nodata_mask=[21025], pos=[21025, 2], test_mask=[21025], train_mask=[21025], x=[21025, 32], y=[21025])

In [7]:
nhidden = 32
sgd_parms = dict( nepochs = 500, lr = 0.01, weight_decay = 0.0005, dropout = True )
MODEL = GCN
ntrials = 5
accuracy = []

for iT in range(ntrials):
    model = MODEL( nfeatures, nhidden, nclasses )
    MODEL.train_model( model, graph_data, **sgd_parms )
    ( pred, acc ) = MODEL.evaluate_model( model, graph_data )
    accuracy.append( acc )

acc_data = np.array(accuracy)
print( f"Average accuracy over {ntrials} trials = {acc_data.mean()}, std = {acc_data.std()}")

Init GCN: Base Layer weights = [[ 0.22294292 -0.20553133  0.07582068 ... -0.07034136 -0.05711201
   0.05917844]
 [-0.23161839  0.28137854  0.11917046 ...  0.209245    0.08683249
   0.12960252]
 [-0.04322138 -0.27885464 -0.22181621 ... -0.04073685 -0.29417714
  -0.08672206]
 ...
 [-0.16806774 -0.21350883  0.23745194 ...  0.05271587  0.09251931
  -0.21410692]
 [ 0.05852076  0.06490099 -0.11355208 ...  0.21631655 -0.0317381
  -0.00350234]
 [ 0.1004841  -0.06332095 -0.08251467 ...  0.13808352 -0.26738203
   0.17541328]]
Training model with lr=0.02, weight_decay=0.0005, nepochs=1000, dropout=True
epoch: 0, loss = 2.966397523880005
epoch: 25, loss = 1.221968173980713
epoch: 50, loss = 0.8989861607551575
epoch: 75, loss = 0.822873592376709
epoch: 100, loss = 0.7708004713058472
epoch: 125, loss = 0.6899546384811401
epoch: 150, loss = 0.5988602042198181
epoch: 175, loss = 0.5611194372177124
epoch: 200, loss = 0.5343203544616699
epoch: 225, loss = 0.5337408781051636
epoch: 250, loss = 0.52068299

In [8]:
# pred_image_data: np.ndarray = pred_data.reshape( [1] + list(raw_spectral_data.shape[1:]) )
# pred_image: xa.DataArray = class_map.copy( True, pred_image_data )
# class_plot = class_map.hvplot.image( cmap='Category20' )
# pred_plot = pred_image.hvplot.image( cmap='Category20' )
# pn.Row( class_plot, pred_plot  ).show("Indian Pines")
#
