In [12]:
import random, numpy as np, torch
from typing import List, Union, Tuple, Optional, Dict, Callable
from pynndescent import NNDescent
from spectraclass.data.base import DataManager
from spectraclass.graph.manager import ActivationFlow, ActivationFlowManager, afm
import hvplot.xarray
import holoviews as hv
from torch
from spectraclass.learn.gcn import GCN
from spectraclass.learn.mlp import MLP
import panel as pn
from torch_geometric.data import Data
import xarray as xa

In [15]:
view_band = 10
use_edge_weights = False

dm: DataManager = DataManager.initialize( "indianPines", 'aviris' )
project_data: xa.Dataset = dm.loadCurrentProject( "main" )
flow: ActivationFlow = afm().getActivationFlow()
graph: NNDescent = flow.getGraph()
D: np.ndarray = graph.neighbor_graph[1].flatten()
reduced_spectral_data: xa.DataArray = project_data['reduction']
raw_spectral_data: xa.DataArray = project_data['raw']
class_map: xa.DataArray = dm.getClassMap()
edge_weights = GCN.calc_edge_weights( D ) if use_edge_weights else None
[ny,nx] = raw_spectral_data.shape[1:]
I: np.array = np.array( range( ny*nx ) )
X, Y = I % nx, I // nx
from torch_geometric.transforms import KNNGraph
edge_attr: torch.tensor  = torch.from_numpy( D.reshape( D.size, 1 ) )
edge_index: torch.tensor = flow.getEdgeIndex()
node_data: torch.tensor = torch.from_numpy( reduced_spectral_data.values )
pos: torch.tensor = torch.from_numpy( np.vstack( [Y,X] ).transpose() )
class_data: torch.tensor = torch.from_numpy( class_map.values.flatten().astype( np.long ) ) - 1
graph_data = Data( x=node_data, y=class_data, pos=pos, edge_index=edge_index, edge_weights=edge_weights )
nfeatures = graph_data.num_node_features

print( f"raw_spectral_data shape = {raw_spectral_data.shape}")
print( f"reduced_spectral_data shape = {reduced_spectral_data.shape}")
print( f"num_nodes = {graph_data.num_nodes}")
print( f"num_edges = {graph_data.num_edges}")
print( f"num_node_features = {nfeatures}")
print( f"num_edge_features = {graph_data.num_edge_features}")
print( f"contains_isolated_nodes = {graph_data.contains_isolated_nodes()}")
print( f"contains_self_loops = {graph_data.contains_self_loops()}")
print( f"is_directed = {graph_data.is_directed()}")

Loading config files: ['indianPines.py'] from dir /Users/tpmaxwel/.spectraclass/config/aviris
Opening log file:  '/Users/tpmaxwel/.spectraclass/logging/aviris/indianPines.5186.log'
Get Activation flow for dsid aviris_hyperspectral_data/19920612_AVIRIS_IndianPine_Site3.1000-1000_0-0_b-1000-1000-0-0-Autoencoder-32

Reading class file: /Users/tpmaxwel/GDrive/Tom/Data/Aviris/IndianPines/documentation/Site3_Project_and_Ground_Reference_Files/19920612_AVIRIS_IndianPine_Site3_gr.tif

raw_spectral_data shape = (220, 145, 145)
reduced_spectral_data shape = (21025, 32)
num_nodes = 21025
num_edges = 105125
num_node_features = 32
num_edge_features = 0
contains_isolated_nodes = False
contains_self_loops = True
is_directed = True


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  class_data: torch.tensor = torch.from_numpy( class_map.values.flatten().astype( np.long ) ) - 1


In [3]:
# band_image_data: np.ndarray = reduced_spectral_data.values[:,view_band].reshape( [1] + list(raw_spectral_data.shape[1:]) )
# band_image: xa.DataArray = class_map.copy( True, band_image_data )
# class_plot = class_map.hvplot.image( cmap='Category20', clim=(0, 20) )
# band_plot = band_image.hvplot.image( cmap='jet', clim=( -1.5, 1.5 ) )
# pn.Row( class_plot, band_plot  ).show("Indian Pines")

In [4]:
nn_transform = KNNGraph(4)
num_class_exemplars = 5
class_data: np.ndarray = class_map.values.flatten()
nclasses: int = class_map.values.max()
class_masks: List[np.ndarray] = [ (class_data == (iC+1) ) for iC in range(nclasses) ]
test_mask: np.ndarray = (class_data > 0)
nodata_mask = np.logical_not( test_mask )
class_indices = [ np.argwhere(class_masks[iC]).flatten() for iC in range(nclasses) ]
train_class_indices = [  np.random.choice(class_indices[iC], size=num_class_exemplars, replace=False )  for iC in range(nclasses)  ]
train_indices = np.hstack( train_class_indices )
train_mask = np.full( [ node_data.shape[0] ], False, dtype=bool )
train_mask[ train_indices ] = True
test_mask[ train_indices ] = False

graph_data['train_mask'] = torch.from_numpy( train_mask )
graph_data['test_mask'] = torch.from_numpy( test_mask )
graph_data['nodata_mask'] = torch.from_numpy( nodata_mask )
graph_data

Data(edge_index=[2, 105125], test_mask=[21025], train_mask=[21025], x=[21025, 32], y=[21025])

In [5]:
nhidden = 32
sgd_parms = dict( nepochs = 1000, lr = 0.02, weight_decay = 0.0005, dropout = True )
MODEL = GCN
ntrails = 5
accuracy = []

for iT in range(ntrails):
    model = MODEL( nfeatures, nhidden, nclasses )
    MODEL.train_model( model, graph_data, **sgd_parms )
    ( pred, acc ) = MODEL.evaluate_model( model, graph_data )
    accuracy.append( acc )

acc_data = np.array(accuracy)
print( f"Average accuracy over {ntrails} trials = {acc_data.mean()}, std = {acc_data.std()}")

Training model with lr=0.02, weight_decay=0.0005, nepochs=1000, dropout=True
epoch: 0, loss = 3.27402925491333
epoch: 25, loss = 1.1978318691253662
epoch: 50, loss = 0.8777629137039185
epoch: 75, loss = 0.801292896270752
epoch: 100, loss = 0.7167412042617798
epoch: 125, loss = 0.6534552574157715
epoch: 150, loss = 0.6655129790306091
epoch: 175, loss = 0.5843487977981567
epoch: 200, loss = 0.5623810887336731
epoch: 225, loss = 0.5817524194717407
epoch: 250, loss = 0.5197635889053345
epoch: 275, loss = 0.4757311940193176
epoch: 300, loss = 0.5555009245872498
epoch: 325, loss = 0.531602680683136
epoch: 350, loss = 0.55058753490448
epoch: 375, loss = 0.41903916001319885
epoch: 400, loss = 0.4557510018348694
epoch: 425, loss = 0.45506247878074646
epoch: 450, loss = 0.42208462953567505
epoch: 475, loss = 0.4218785762786865
epoch: 500, loss = 0.4339276850223541
epoch: 525, loss = 0.3204771876335144
epoch: 550, loss = 0.38203054666519165
epoch: 575, loss = 0.4494887888431549
epoch: 600, loss =

In [6]:
# pred_image_data: np.ndarray = pred_data.reshape( [1] + list(raw_spectral_data.shape[1:]) )
# pred_image: xa.DataArray = class_map.copy( True, pred_image_data )
# class_plot = class_map.hvplot.image( cmap='Category20' )
# pred_plot = pred_image.hvplot.image( cmap='Category20' )
# pn.Row( class_plot, pred_plot  ).show("Indian Pines")
#
