In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal
import torch as torch
import torch.nn.functional as F
import rasterio as rio
import rasterio.windows

import copy

from tqdm.notebook import tqdm

from hyperspectral.math import zca_whitening_matrix
from hyperspectral.band_selection import *

# Explore #

This spectrum was obtained from [The USGS Spectral Library Version 7](https://crustal.usgs.gov/speclab/QueryAll07a.php). It can be found by going to that site and searching for `Algea` (the library's spelling) or similar in the quick-search box.

## Clip Spectrum ##

Only the first 993 of the 2150 bands in the library spectrum are valid; they correspond to the first 103 of the 224 AVIRIS bands.

In [None]:
filename = 'data2/spectra/0/splib07a_Red_Coated_Algea_Water_RCAW1_ASDFRa_AREF.txt'
N_VALID_SAMPLES = 993  # leading library values that are usable
N_AVIRIS_BANDS = 224   # length of the resampled spectrum (one value per AVIRIS band)
N_VALID_BANDS = 103    # leading AVIRIS bands covered by the valid library values

# The file has a one-line header followed by one value per line.
spectrum = np.loadtxt(filename, skiprows=1, ndmin=1)
# Unit L2 norm over the valid part, then zero out the unusable tail.
spectrum = spectrum / np.linalg.norm(spectrum[:N_VALID_SAMPLES], ord=2)
spectrum[N_VALID_SAMPLES:] = 0
# Resample to the AVIRIS band count, shift so the minimum over the valid bands
# is 0, and zero the bands the library spectrum does not cover.
spectrum_normalized = scipy.signal.resample(spectrum, N_AVIRIS_BANDS) - spectrum.mean()
spectrum_normalized -= np.min(spectrum_normalized[:N_VALID_BANDS])
spectrum_normalized[N_VALID_BANDS:] = 0

In [None]:
# Reference spectrum after resampling; bands past the valid range are zero.
fig, ax = plt.subplots()
ax.plot(spectrum_normalized)
ax.set(title='Reference spectrum resampled to 224 AVIRIS bands', xlabel='AVIRIS band index', ylabel='Normalized value');

In [None]:
np.shape(spectrum)  # length of the raw (un-resampled) library spectrum

## Perform Inference ##

In [None]:
tile_size = 64
def infer1(in_filename, out_filename, spec, valid_bands=103):
    """Write a per-pixel matched-filter score raster for ``in_filename``.

    The input is processed in ``tile_size`` x ``tile_size`` windows (the
    notebook-global defined just above), clipped at the raster edges.  For each
    pixel the band vector is divided by its L2 norm over the first
    ``valid_bands`` bands, has its mean over all bands removed, and is dotted
    with ``spec``.  Scores are written to a single-band float32 GeoTIFF that
    reuses the input's profile.  Windows whose first band is all zero are
    skipped (left unwritten in the sparse output).

    Args:
        in_filename: path of the multi-band input raster.
        out_filename: path of the GeoTIFF to create.
        spec: reference spectrum, one value per input band.
        valid_bands: number of leading bands used for the per-pixel norm;
            the default 103 is the range covered by the library spectrum.
    """
    with rio.open(in_filename, 'r') as in_ds:
        profile = copy.deepcopy(in_ds.profile)
        profile.update(count=1, driver='GTiff', bigtiff='yes', compress='deflate', predictor='2', tiled='yes', dtype=np.float32, sparse_ok='yes')
        with rio.open(out_filename, 'w', **profile) as out_ds:
            for col in tqdm(range(0, in_ds.width, tile_size), position=0):
                width = min(col + tile_size, in_ds.width) - col
                for row in tqdm(range(0, in_ds.height, tile_size), position=1, leave=False):
                    height = min(row + tile_size, in_ds.height) - row
                    window = rasterio.windows.Window(col, row, width, height)
                    # Cheap emptiness test on band 1 before reading every band.
                    if np.abs(in_ds.read(1, window=window)).sum() == 0:
                        continue
                    # rasterio returns (bands, rows, cols); move bands last.
                    data = np.transpose(in_ds.read(window=window).astype(np.float32), (1, 2, 0))
                    # Normalise with the valid leading bands only; later bands are unusable.
                    norm = np.linalg.norm(data[:, :, :valid_bands], ord=2, axis=2)[..., None].astype(np.float32)
                    with np.errstate(divide='ignore', invalid='ignore'):
                        data /= norm
                        # NOTE(review): centred over all bands, whereas the training
                        # samples are centred over the first 103 only — confirm intent.
                        data -= np.mean(data, axis=2)[..., None]
                        scores = np.dot(data, spec)
                    # Zero-norm pixels yield NaN/inf; store them as 0.
                    scores[~np.isfinite(scores)] = 0
                    # scores is (rows, cols) = (height, width); rasterio expects (1, height, width).
                    out_ds.write(scores.reshape(1, height, width).astype(np.float32), window=window)

In [None]:
# (input raster, output score raster) pairs for the two scenes.
in_out = [
    (
        'data2/LakeErie/f080723t01p00r06/f080723t01p00r06rdn_c_sc01.tif',
        'data2/f080723t01p00r06rdn_c_sc01_result.tif',
    ),
    (
        'data2/LakeMichigan/f080709t01p00r13/f080709t01p00r13rdn_c_sc01.tif',
        'data2/f080709t01p00r13rdn_c_sc01_result.tif',
    ),
]

In [None]:
# Score every scene with the plain (un-whitened) reference spectrum.
for in_filename, out_filename in in_out:
    infer1(in_filename, out_filename, spec=spectrum_normalized)

# Save Samples #

In [None]:
# Lake Michigan input raster and its infer1 score raster.
in_filename1, in_filename2 = (
    'data2/LakeMichigan/f080709t01p00r13/f080709t01p00r13rdn_c_sc01.tif',
    'data2/f080709t01p00r13rdn_c_sc01_result.tif',
)

In [None]:
tile_size = 64
algea_yes = []  # per-tile arrays of pixel spectra whose infer1 score is > 2.0
algea_no = []   # per-tile arrays of non-empty pixel spectra whose non-zero score is < 1.0
with rio.open(in_filename1, 'r') as ds1, rio.open(in_filename2, 'r') as ds2:
    width, height = ds1.width, ds1.height
    n_bands = ds1.count  # expected to be 224 for these scenes
    for col in tqdm(range(0, width, tile_size), position=0):
        # Clip edge windows to the raster instead of reading past it.
        tile_w = min(col + tile_size, width) - col
        for row in tqdm(range(0, height, tile_size), position=1, leave=False):
            tile_h = min(row + tile_size, height) - row
            window = rasterio.windows.Window(col, row, tile_w, tile_h)
            data0 = ds1.read(1, window=window).reshape(-1)
            if np.abs(data0).sum() == 0:
                continue
            # (bands, rows, cols) -> one row per pixel, one column per band.
            data1 = np.transpose(ds1.read(window=window).astype(np.float32), (1, 2, 0)).reshape(-1, n_bands)
            data2 = ds2.read(1, window=window).astype(np.float32).reshape(-1)
            algea_yes.append(data1[data2 > +2.0])
            algea_no.append(data1[(data2 < +1.0) & (data2 != 0) & (data0 != 0)])

In [None]:
# Stack the per-tile lists into (n_pixels, n_bands) arrays.  Guarded so that
# re-running this cell does not concatenate an already-stacked 2-D array
# (which np.concatenate would flatten to 1-D).
if isinstance(algea_yes, list):
    algea_yes = np.concatenate(algea_yes)
if isinstance(algea_no, list):
    algea_no = np.concatenate(algea_no)

In [None]:
# Persist the sampled pixels and the (pre-resampling) reference spectrum.
np.savez(
    'data2/algea.npz',
    **{'algea_yes': algea_yes, 'algea_no': algea_no, 'spectrum': spectrum},
)

# Band Selection and Whitening #

## Whitening (Sphering) ##

![image.png](attachment:image.png)

In [None]:
# Load the sampled pixel spectra and the reference spectrum; the context
# manager closes the underlying .npz file once the arrays are read.
with np.load('data2/algea.npz') as dictionary:
    pos = dictionary['algea_yes']
    neg = dictionary['algea_no']
    spectrum = dictionary['spectrum']

# Resample the reference to 224 bands and remove its overall mean.
spectrum = scipy.signal.resample(spectrum, 224) - spectrum.mean()

# Only the first 103 bands are usable
spectrum = spectrum[:103]
pos = pos[:,:103]
neg = neg[:,:103]

# Per-pixel unit L2 norm, then zero mean across bands.
# NOTE(review): `spectrum` is not put through the same norm/centring — confirm intent.
pos /= np.linalg.norm(pos, ord=2, axis=1).reshape(-1,1)
neg /= np.linalg.norm(neg, ord=2, axis=1).reshape(-1,1)

pos -= np.mean(pos, axis=1).reshape(-1,1)
neg -= np.mean(neg, axis=1).reshape(-1,1)

In [None]:
# Fit ZCA whitening on the background (negative) spectra.  Presumably returns
# (matrix, per-band mean) — confirm in hyperspectral.math.  Later cells call
# whiten(..., W, 0), so this `mean` is not applied.
W, mean = zca_whitening_matrix(neg)

In [None]:
def whiten(m, W, mean):
    """Apply the linear map ``(m - mean) @ W`` along the last axis of ``m``.

    Args:
        m: array of shape ``(..., k)``; any number of leading axes (including none).
        W: 2-D matrix of shape ``(k, p)``; need not be square.
        mean: value broadcast against each length-``k`` row before the product
            (callers in this notebook pass 0).

    Returns:
        Array of shape ``m.shape[:-1] + (p,)``.
    """
    m = np.asarray(m)
    lead_shape = m.shape[:-1]
    rows = m.reshape(-1, m.shape[-1]) - mean
    out = np.matmul(rows, W)
    # Restore the leading axes; the last axis takes W's output width.
    return out.reshape(*lead_shape, out.shape[-1])

In [None]:
# Push the samples and the reference through the same ZCA matrix (no centring).
whitened_pos, whitened_neg, whitened_spectrum = (whiten(arr, W, 0) for arr in (pos, neg, spectrum))

In [None]:
(whitened_pos @ whitened_spectrum).mean()  # mean filter score over positive samples

In [None]:
(whitened_neg @ whitened_spectrum).mean()  # mean filter score over background samples

In [None]:
# rasterio band indexes are 1-based, so this selects bands 1..103.
channels = tuple(range(0+1,103+1))
tile_size = 64
def infer2(in_filename, out_filename, spec, whitening=None):
    """Write a whitened matched-filter score raster for ``in_filename``.

    Reads the notebook-global ``channels`` in ``tile_size`` x ``tile_size``
    windows clipped at the raster edges.  Each pixel is divided by its L2 norm
    over those bands, made zero-mean across them, passed through
    ``whiten(.., whitening, 0)`` and dotted with ``spec``.  Scores are written
    to a single-band float32 GeoTIFF that reuses the input's profile; windows
    whose first band is all zero are skipped.

    Args:
        in_filename: path of the multi-band input raster.
        out_filename: path of the GeoTIFF to create.
        spec: whitened reference spectrum; its length must match the output
            width of the whitening matrix.
        whitening: matrix passed to ``whiten``.  When None (default) the
            notebook-global ``W`` is looked up at call time.
    """
    if whitening is None:
        whitening = W
    with rio.open(in_filename, 'r') as in_ds:
        profile = copy.deepcopy(in_ds.profile)
        profile.update(count=1, driver='GTiff', bigtiff='yes', compress='deflate', predictor='2', tiled='yes', dtype=np.float32, sparse_ok='yes')
        with rio.open(out_filename, 'w', **profile) as out_ds:
            for col in tqdm(range(0, in_ds.width, tile_size), position=0):
                width = min(col + tile_size, in_ds.width) - col
                for row in tqdm(range(0, in_ds.height, tile_size), position=1, leave=False):
                    height = min(row + tile_size, in_ds.height) - row
                    window = rasterio.windows.Window(col, row, width, height)
                    # Cheap emptiness test on band 1 before reading the selected bands.
                    if np.abs(in_ds.read(1, window=window)).sum() == 0:
                        continue
                    # rasterio returns (bands, rows, cols); move bands last.
                    data = np.transpose(in_ds.read(channels, window=window).astype(np.float32), (1, 2, 0))
                    norm = np.linalg.norm(data, ord=2, axis=2)[..., None].astype(np.float32)
                    with np.errstate(divide='ignore', invalid='ignore'):
                        data /= norm
                        data -= np.mean(data, axis=2)[..., None]
                        data = whiten(data, whitening, 0)
                        scores = np.dot(data, spec)
                    # Zero-norm pixels yield NaN/inf; store them as 0.
                    scores[~np.isfinite(scores)] = 0
                    # scores is (rows, cols) = (height, width); rasterio expects (1, height, width).
                    out_ds.write(scores.reshape(1, height, width).astype(np.float32), window=window)

In [None]:
out_filename = 'data2/f080709t01p00r13rdn_c_sc01_result_whitened.tif'
# Score the Lake Michigan scene with the ZCA-whitened reference.
infer2(in_filename1, out_filename, spec=whitened_spectrum)

## Optimization ##

![image.png](attachment:image.png)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU without a GPU

In [None]:
indices = list(range(neg.shape[0]))[::32]  # every 32nd background sample

In [None]:
neg_subset = neg[np.asarray(indices, dtype=np.intp)]  # copy of the sub-sampled background rows

In [None]:
# Oversampling factor that roughly balances positives against the background
# subset; clamped to 1 so positives are never dropped when they outnumber it.
ratio = max(1, neg_subset.shape[0] // pos.shape[0])
print(ratio)

In [None]:
pos_repeated = pos.repeat(ratio, axis=0)  # each positive row duplicated `ratio` times in place

In [None]:
# Training tensors: background subset (label 0) followed by the oversampled
# positives (label 1), each with a trailing singleton axis.  np.float64 is used
# because the np.float alias no longer exists in NumPy >= 1.24.
samples = np.concatenate([neg_subset, pos_repeated], axis=0)
samples = torch.from_numpy(samples.astype(np.float64)).unsqueeze(2).to(device)

labels = np.concatenate([np.zeros((neg_subset.shape[0], 1)), np.ones((pos_repeated.shape[0], 1))])
labels = torch.from_numpy(labels.astype(np.float64)).unsqueeze(2).to(device)

target = torch.from_numpy(spectrum.astype(np.float64)).unsqueeze(0).unsqueeze(2).to(device)

In [None]:
# Initialise the trainable filter from the ZCA matrix scaled by 1/35.
# NOTE(review): 35 is an unexplained scale factor, and the second argument (0)
# is presumably the initial bias (model.bias is read later) — confirm in band_selection.
model = MatchedFilter(W/35, 0).to(device)

In [None]:
# Train the filter on the balanced samples; 1000 is presumably the iteration/epoch count — confirm in band_selection.vanilla_train.
model = vanilla_train(model, samples, labels, target, device, 1000)

In [None]:
# Copy the learned parameters back to host memory as NumPy arrays.
opt_W, opt_bias = (param.detach().cpu().numpy() for param in (model.W, model.bias))

In [None]:
np.shape(opt_W)

In [None]:
W = np.squeeze(opt_W)  # drop singleton axes (presumably leaving a 2-D matrix for whiten)

In [None]:
opt_whitened_spectrum = whiten(m=spectrum, W=W, mean=0)

In [None]:
out_filename = 'data2/f080709t01p00r13rdn_c_sc01_result_whitened_opt.tif'
# Score with the optimised filter (infer2 picks up the global W set above).
infer2(in_filename1, out_filename, spec=opt_whitened_spectrum)

## Band Selection ##

In [None]:
start = neg_subset.shape[0]
length = pos.shape[0]
# samples = [neg_subset; pos_repeated].  np.repeat duplicates each positive
# `ratio` times consecutively, so averaging *all* rows from `start` onward gives
# the equally weighted mean of every positive.  The first `length` rows alone
# would cover only the first ~length/ratio positives.
subset_of_samples = samples[start:]
mean_of_samples = subset_of_samples.mean(dim=0).unsqueeze(dim=0)

In [None]:
# Band ranking from the project's `argsort` (not numpy's; presumably from the
# band_selection star import).  Presumably ordered least -> most salient, since
# the "best" bands are later taken from the end.
according_to_salience = argsort(model, mean_of_samples, target, [0])

In [None]:
according_to_salience  # display the band ranking

In [None]:
# Persist the trained filter parameters and the band ranking.
np.savez(
    'data2/W.npz',
    **{'opt_W': opt_W, 'opt_bias': opt_bias, 'according_to_salience': according_to_salience},
)

# Test Selected Bands #

In [None]:
# Reload the saved band ranking; the context manager closes the .npz file,
# and indexing (rather than .get) fails loudly if the key is missing.
with np.load('data2/W.npz') as dictionary:
    according_to_salience = dictionary['according_to_salience']

## Best 48 ##

Select the 48 most salient bands.

In [None]:
# Assumes according_to_salience is in ascending salience order, so the tail holds the best bands.
best_48 = np.squeeze(according_to_salience[-48:])
pos48, neg48 = (np.squeeze(arr[:, best_48]) for arr in (pos, neg))
spectrum48 = spectrum[best_48].reshape(1, -1)

In [None]:
# Re-fit ZCA whitening on the background spectra restricted to the 48 selected bands.
W, mean = zca_whitening_matrix(neg48)

In [None]:
whitened_spectrum48 = np.reshape(whiten(spectrum48, W, 0), (-1, 1))  # column vector

In [None]:
tile_size = 64
def infer3(in_filename, out_filename, spec, W, bands):
    """Write a whitened matched-filter score raster using a chosen band subset.

    Reads ``bands`` in ``tile_size`` x ``tile_size`` windows clipped at the
    raster edges.  Each pixel is divided by its L2 norm over those bands, made
    zero-mean across them, passed through ``whiten(.., W, 0)`` and dotted with
    ``spec``.  Scores go to a single-band float32 GeoTIFF that reuses the
    input's profile; windows whose band 1 is all zero are skipped.

    Args:
        in_filename: path of the multi-band input raster.
        out_filename: path of the GeoTIFF to create.
        spec: whitened reference spectrum, shape ``(k,)`` or ``(k, 1)`` where
            ``k`` is the output width of ``W``.
        W: whitening matrix for the selected bands.
        bands: tuple of 1-based rasterio band indexes to read.
    """
    with rio.open(in_filename, 'r') as in_ds:
        profile = copy.deepcopy(in_ds.profile)
        profile.update(count=1, driver='GTiff', bigtiff='yes', compress='deflate', predictor='2', tiled='yes', dtype=np.float32, sparse_ok='yes')
        with rio.open(out_filename, 'w', **profile) as out_ds:
            for col in tqdm(range(0, in_ds.width, tile_size), position=0):
                width = min(col + tile_size, in_ds.width) - col
                for row in tqdm(range(0, in_ds.height, tile_size), position=1, leave=False):
                    height = min(row + tile_size, in_ds.height) - row
                    window = rasterio.windows.Window(col, row, width, height)
                    # Emptiness is judged on band 1, independent of `bands`.
                    if np.abs(in_ds.read(1, window=window)).sum() == 0:
                        continue
                    # rasterio returns (bands, rows, cols); move bands last.
                    data = np.transpose(in_ds.read(bands, window=window).astype(np.float32), (1, 2, 0))
                    norm = np.linalg.norm(data, ord=2, axis=2)[..., None].astype(np.float32)
                    with np.errstate(divide='ignore', invalid='ignore'):
                        data /= norm
                        data -= np.mean(data, axis=2)[..., None]
                        data = whiten(data, W, 0)
                        scores = np.dot(data, spec)
                    # Zero-norm pixels yield NaN/inf; store them as 0.
                    scores[~np.isfinite(scores)] = 0
                    # scores is (height, width) or (height, width, 1); rasterio expects (1, height, width).
                    out_ds.write(scores.reshape(1, height, width).astype(np.float32), window=window)

In [None]:
out_filename = 'data2/f080709t01p00r13rdn_c_sc01_result_whitened_48.tif'
# rasterio band indexes are 1-based, hence the +1 on the 0-based selection.
infer3(in_filename1, out_filename, whitened_spectrum48, W, bands=tuple(best_48 + 1))

## Worst 48 ##

Select the 48 least salient bands, as a control.

In [None]:
# according_to_salience ends with the most salient bands (best_48 takes the
# last 48), so the 48 least salient are the *first* 48 entries.
worst_48 = according_to_salience[:48].squeeze()
pos48 = pos[:, worst_48].squeeze()
neg48 = neg[:, worst_48].squeeze()
spectrum48 = spectrum[worst_48].reshape(1,-1)
W, mean = zca_whitening_matrix(neg48)
whitened_spectrum48 = whiten(spectrum48, W, 0).reshape(-1,1)

In [None]:
out_filename = 'data2/f080709t01p00r13rdn_c_sc01_result_whitened_48_worst.tif'
# rasterio band indexes are 1-based, hence the +1 on the 0-based selection.
infer3(in_filename1, out_filename, whitened_spectrum48, W, bands=tuple(worst_48 + 1))

**Results:** whitening fails, and so does band selection.

# Optimization from Identity #

In [None]:
# Identity "whitening": leaves the data unchanged; used as the optimisation start point.
W = np.identity(103)
whitened_pos, whitened_neg, whitened_spectrum = [whiten(arr, W, 0) for arr in (pos, neg, spectrum)]

In [None]:
# Start the trainable filter from the identity matrix; the second argument (0) is presumably the initial bias — confirm in band_selection.
model = MatchedFilter(W, 0).to(device)

In [None]:
# Train from the identity start; 1000 is presumably the iteration/epoch count — confirm in band_selection.vanilla_train.
model = vanilla_train(model, samples, labels, target, device, 1000)

In [None]:
# Copy the learned parameters to NumPy and whiten the reference with the learned matrix.
opt_W, opt_bias = (param.detach().cpu().numpy() for param in (model.W, model.bias))
W = np.squeeze(opt_W)
opt_whitened_spectrum = whiten(m=spectrum, W=W, mean=0)

In [None]:
out_filename = 'data2/f080709t01p00r13rdn_c_sc01_result_whitened_eye_opt.tif'
# Score with the filter optimised from the identity (infer2 picks up the global W).
infer2(in_filename1, out_filename, spec=opt_whitened_spectrum)

In [None]:
# Mean positive score under the optimised filter.  The samples are whitened with
# the same matrix W as the spectrum, matching how infer2 scores each pixel.
np.dot(whiten(pos, W, 0), opt_whitened_spectrum).mean()

In [None]:
# Mean background score under the optimised filter.  The samples are whitened
# with the same matrix W as the spectrum, matching how infer2 scores each pixel.
np.dot(whiten(neg, W, 0), opt_whitened_spectrum).mean()