In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal
import torch as torch
import torch.nn.functional as F
import rasterio as rio
import rasterio.windows

import copy

from tqdm.notebook import tqdm

from hyperspectral.math import zca_whitening_matrix
from hyperspectral.band_selection import *

# Load Spectra and Perform Inference #

In [None]:
# Load the endmember spectra (one per row of `ems`) and give each the same
# preprocessing the inference helpers apply to pixels: unit L2 norm, then
# mean-centring. The three names are views into `ems`, so they are modified
# in place. The context manager closes the .npz archive after reading.
with np.load('data2/AVIRIS_oil_ems.npz') as dictionary:
    ems = dictionary['ems']
[sea_water, oil1, oil2] = ems

for endmember in (sea_water, oil1, oil2):
    endmember /= np.linalg.norm(endmember, ord=2)
    endmember -= endmember.mean()

In [None]:
fig, ax = plt.subplots()
ax.plot(sea_water)
ax.set(title='Sea-water endmember (L2-normalised, mean-centred)', xlabel='Band index');

In [None]:
fig, ax = plt.subplots()
ax.plot(oil1)
ax.set(title='Oil endmember 1 (L2-normalised, mean-centred)', xlabel='Band index');

In [None]:
fig, ax = plt.subplots()
ax.plot(oil2)
ax.set(title='Oil endmember 2 (L2-normalised, mean-centred)', xlabel='Band index');

In [None]:
def infer1(in_filename, out_filename, spec):
    """Score every pixel of a raster against a target spectrum.

    The input is processed in 512x512 tiles. Each pixel spectrum is
    L2-normalised and mean-centred across bands, then dotted with ``spec``.
    The result is written as a single-band float32 tiled, deflate-compressed
    BigTIFF. Tiles whose first band is entirely zero are skipped, so they are
    never written.

    Parameters
    ----------
    in_filename : str
        Multi-band input raster.
    out_filename : str
        Path of the single-band output raster.
    spec : ndarray
        Target spectrum; its first axis must have one entry per input band.
    """
    with rio.open(in_filename, 'r') as in_ds:
        profile = copy.deepcopy(in_ds.profile)
        profile.update(count=1, driver='GTiff', bigtiff='yes', compress='deflate', predictor='2',
                       tiled='yes', dtype=np.float32, sparse_ok='yes')
        with rio.open(out_filename, 'w', **profile) as out_ds:
            for col in tqdm(range(0, in_ds.width, 512), position=0):
                width = min(col + 512, in_ds.width) - col
                for row in tqdm(range(0, in_ds.height, 512), position=1, leave=False):
                    height = min(row + 512, in_ds.height) - row
                    window = rasterio.windows.Window(col, row, width, height)
                    if not np.any(in_ds.read(1, window=window)):
                        continue
                    # (bands, rows, cols) -> (rows, cols, bands)
                    pixels = np.transpose(in_ds.read(window=window).astype(np.float32), (1, 2, 0))
                    # All-zero pixels give 0/0 = NaN here; they are zeroed below,
                    # so the RuntimeWarning is expected and silenced.
                    with np.errstate(divide='ignore', invalid='ignore'):
                        pixels /= np.linalg.norm(pixels, ord=2, axis=2, keepdims=True)
                    pixels -= pixels.mean(axis=2, keepdims=True)
                    scores = np.dot(pixels, spec)
                    scores[np.isnan(scores)] = 0
                    # The scores are laid out (rows, cols) == (height, width). The
                    # previous reshape(1, width, height) scrambled non-square edge tiles.
                    out_ds.write(scores.reshape(1, height, width).astype(np.float32), window=window)

In [None]:
# Input scene and output prefix; the per-endmember suffixes (e.g.
# '_sea_water.tif') are appended to out_filename when rasters are written.
in_filename, out_filename = (
    'data2/f100517t01p00r14rdn_b/f100517t01p00r14rdn_b_sc01_ort_img.tif',
    'data2/results/f100517t01p00r14rdn_b_sc01_ort_img_xxx',
)

In [None]:
# One score raster per endmember, written next to the common output prefix.
endmember_jobs = ((sea_water, '_sea_water.tif'), (oil1, '_oil1.tif'), (oil2, '_oil2.tif'))
for endmember_spec, suffix in endmember_jobs:
    infer1(in_filename, out_filename + suffix, endmember_spec)

# Save Samples #

In [None]:
# Collect raw pixel spectra as training samples. For each endmember:
# - "yes" samples are pixels whose score (from the rasters written by infer1)
#   is above the upper threshold;
# - "no" samples are non-empty pixels whose score is below the lower threshold.
# Scores between the two thresholds are left out.
sea_water_yes = []
sea_water_no = []

oil1_yes = []
oil1_no = []

oil2_yes = []
oil2_no = []

with rio.open(in_filename, 'r') as in_ds, \
    rio.open(out_filename + '_sea_water.tif', 'r') as sea_ds, \
    rio.open(out_filename + '_oil1.tif', 'r') as oil1_ds, \
    rio.open(out_filename + '_oil2.tif', 'r') as oil2_ds:

    width, height = in_ds.width, in_ds.height
    n_bands = in_ds.count  # previously hard-coded as 224

    # NOTE(review): sampling starts 2048 px in from the top-left corner in
    # both directions; the reason is not recorded here.
    for col in tqdm(range(2048, width, 512), position=0):
        for row in tqdm(range(2048, height, 512), position=1, leave=False):
            # Clamp edge tiles to the raster extent, as the infer* helpers do.
            window = rasterio.windows.Window(col, row, min(512, width - col), min(512, height - row))
            data0 = in_ds.read(1, window=window).reshape(-1)
            if not np.any(data0):
                continue
            nonzero = data0 != 0

            # (bands, rows, cols) -> (pixels, bands)
            data1 = np.transpose(in_ds.read(window=window).astype(np.float32), (1, 2, 0)).reshape(-1, n_bands)
            # Score rasters are single-band (written by infer1 with count=1).
            data_sea = sea_ds.read(1, window=window).astype(np.float32).reshape(-1)
            data_oil1 = oil1_ds.read(1, window=window).astype(np.float32).reshape(-1)
            data_oil2 = oil2_ds.read(1, window=window).astype(np.float32).reshape(-1)

            sea_water_yes.append(data1[data_sea > 0.80])
            sea_water_no.append(data1[(data_sea < 0.75) & nonzero])

            oil1_yes.append(data1[data_oil1 > 0.55])
            oil1_no.append(data1[(data_oil1 < 0.50) & nonzero])

            oil2_yes.append(data1[data_oil2 > 0.55])
            oil2_no.append(data1[(data_oil2 < 0.48) & nonzero])

In [None]:
def _concat_chunks(chunks):
    """Stack per-tile sample arrays into a single (n_samples, n_bands) array.

    Already-stacked input is returned unchanged, so re-running this cell is
    harmless. Calling np.concatenate on a 2-D array would instead join its
    rows into one flat 1-D array.
    """
    if isinstance(chunks, np.ndarray):
        return chunks
    return np.concatenate(chunks)


sea_water_yes = _concat_chunks(sea_water_yes)
sea_water_no = _concat_chunks(sea_water_no)

oil1_yes = _concat_chunks(oil1_yes)
oil1_no = _concat_chunks(oil1_no)

oil2_yes = _concat_chunks(oil2_yes)
oil2_no = _concat_chunks(oil2_no)

In [None]:
# Persist the harvested samples together with the (normalised) endmember
# spectra, so the training section can start from this archive.
np.savez(
    'data2/oil2.npz',
    **dict(
        sea_water_yes=sea_water_yes,
        sea_water_no=sea_water_no,
        oil1_yes=oil1_yes,
        oil1_no=oil1_no,
        oil2_yes=oil2_yes,
        oil2_no=oil2_no,
        sea_water_spectrum=sea_water,
        oil1_spectrum=oil1,
        oil2_spectrum=oil2,
    ),
)

# Whitening and Matched-Filter Training #

In [None]:
def whiten(m, W, mean):
    """Apply ``(m - mean) @ W`` along the last axis of ``m``.

    All leading axes are kept. The function therefore works for:
    - a single spectrum ``(n_in,)``;
    - a batch ``(n, n_in)``;
    - an image tile ``(rows, cols, n_in)``.

    Parameters
    ----------
    m : ndarray, shape (..., n_in)
    W : ndarray, shape (n_in, n_out)
        Matrix applied on the right. For a square ``W`` the output has the
        same shape as ``m``. A non-square ``W`` changes the last axis to
        ``n_out``; the old implementation could not handle that case.
    mean : scalar or array broadcastable to (n_in,)
        Subtracted before the projection. Callers in this notebook pass 0.

    Returns
    -------
    ndarray, shape (..., n_out)
    """
    leading_shape = m.shape[:-1]
    flat = m.reshape(-1, m.shape[-1]) - mean
    projected = np.matmul(flat, W)
    return projected.reshape(*leading_shape, projected.shape[-1])

In [None]:
def infer2(in_filename, out_filename, spec, W=None):
    """Score every pixel against a whitened target spectrum.

    The tiling and per-pixel preprocessing match ``infer1``:
    - 512x512 tiles;
    - L2 normalisation, then mean-centring across bands.
    Pixels are then passed through ``whiten(pixels, W, 0)`` before the dot
    product with ``spec``. Output is a single-band float32 BigTIFF. Tiles whose
    first band is entirely zero are skipped, so they are never written.

    Parameters
    ----------
    in_filename : str
        Multi-band input raster.
    out_filename : str
        Path of the single-band output raster.
    spec : ndarray
        Whitened target spectrum; its first axis must match the number of
        columns of ``W``.
    W : ndarray, optional
        Whitening matrix. When omitted, the notebook-global ``W`` is looked up
        at call time, which is what this function always did before the
        parameter was added.
    """
    if W is None:
        W = globals()['W']
    with rio.open(in_filename, 'r') as in_ds:
        profile = copy.deepcopy(in_ds.profile)
        profile.update(count=1, driver='GTiff', bigtiff='yes', compress='deflate', predictor='2',
                       tiled='yes', dtype=np.float32, sparse_ok='yes')
        with rio.open(out_filename, 'w', **profile) as out_ds:
            for col in tqdm(range(0, in_ds.width, 512), position=0):
                width = min(col + 512, in_ds.width) - col
                for row in tqdm(range(0, in_ds.height, 512), position=1, leave=False):
                    height = min(row + 512, in_ds.height) - row
                    window = rasterio.windows.Window(col, row, width, height)
                    if not np.any(in_ds.read(1, window=window)):
                        continue
                    # (bands, rows, cols) -> (rows, cols, bands)
                    pixels = np.transpose(in_ds.read(window=window).astype(np.float32), (1, 2, 0))
                    # All-zero pixels give 0/0 = NaN; they are zeroed after scoring.
                    with np.errstate(divide='ignore', invalid='ignore'):
                        pixels /= np.linalg.norm(pixels, ord=2, axis=2, keepdims=True)
                    pixels -= pixels.mean(axis=2, keepdims=True)
                    pixels = whiten(pixels, W, 0)
                    scores = np.dot(pixels, spec)
                    scores[np.isnan(scores)] = 0
                    # Scores are (height, width). The previous reshape(1, width, height)
                    # scrambled non-square edge tiles.
                    out_ds.write(scores.reshape(1, height, width).astype(np.float32), window=window)

In [None]:
# Load the oil1 training samples and target spectrum saved above. The context
# manager closes the .npz archive; the arrays are fully read into memory first.
with np.load('data2/oil2.npz') as dictionary:
    pos = dictionary['oil1_yes']
    neg = dictionary['oil1_no']
    spectrum = dictionary['oil1_spectrum']

# Match the target spectrum to the samples' band count (previously a
# hard-coded 224).
n_bands = pos.shape[1]
spectrum = scipy.signal.resample(spectrum, n_bands) - spectrum.mean()

# Per-sample preprocessing identical to infer1/infer2: unit L2 norm, then
# mean-centring across bands.
pos /= np.linalg.norm(pos, ord=2, axis=1, keepdims=True)
neg /= np.linalg.norm(neg, ord=2, axis=1, keepdims=True)

pos -= pos.mean(axis=1, keepdims=True)
neg -= neg.mean(axis=1, keepdims=True)

In [None]:
# Whitening matrix fitted on the non-target ("oil1_no") samples.
# `mean` is not used afterwards: whiten() is always called with mean=0.
[W, mean] = zca_whitening_matrix(neg)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA (previously this required editing the cell).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Keep every 128th negative sample to thin out the (much larger) negative set.
indices = [*range(0, neg.shape[0], 128)]

In [None]:
# Row subset of the negatives selected by `indices` (a copy, not a view).
neg_subset = np.take(neg, indices, axis=0)

In [None]:
# Oversampling factor for the positives, so the two classes are roughly
# balanced. Integer division replaces int(a / b). The floor of 1 ensures the
# positives are never dropped entirely when they outnumber the negative subset.
assert pos.shape[0] > 0, 'no positive samples to train on'
ratio = max(1, neg_subset.shape[0] // pos.shape[0])
ratio

In [None]:
# Each positive row is repeated `ratio` times (copies are adjacent) to balance
# the classes.
pos_repeated = pos.repeat(ratio, axis=0)

In [None]:
# Training tensors. Rows are ordered [negatives..., positives...]; later cells
# rely on that ordering. np.float was removed in NumPy 1.24; np.float64 is the
# dtype it aliased.
samples = np.concatenate([neg_subset, pos_repeated], axis=0)
# presumably (n_samples, n_bands) -> (n_samples, n_bands, 1)
samples = torch.from_numpy(samples.astype(np.float64)).unsqueeze(2).to(device)

# 0 for negatives, 1 for positives: (n_samples, 1) -> (n_samples, 1, 1)
labels = np.concatenate([np.zeros((neg_subset.shape[0], 1)), np.ones((pos_repeated.shape[0], 1))])
labels = torch.from_numpy(labels.astype(np.float64)).unsqueeze(2).to(device)

# (n_bands,) -> (1, n_bands, 1)
target = torch.from_numpy(spectrum.astype(np.float64)).unsqueeze(0).unsqueeze(2).to(device)

In [None]:
# Matched filter initialised from the ZCA matrix.
# NOTE(review): the 1/5000 scaling and the second argument (0) are
# unexplained constants; confirm their meaning against MatchedFilter.
model = MatchedFilter(W / 5000, 0)
model = model.to(device)

In [None]:
model = vanilla_train(
    model,
    samples,
    labels,
    target,
    device,
    1000,  # NOTE(review): presumably the iteration/epoch count; confirm in band_selection
)

In [None]:
# Copy the trained parameters out of the graph, onto the host, and into NumPy.
opt_W = model.W.detach().cpu().numpy()
opt_bias = model.bias.detach().cpu().numpy()

In [None]:
# Replaces the ZCA matrix in the global `W` with the trained one; infer2 reads
# the notebook-global `W`.
W = np.squeeze(opt_W)

In [None]:
opt_whitened_spectrum = whiten(spectrum, W, mean=0)

## Inference is Optional ##

In [None]:
infer2(
    in_filename,
    out_filename + '_opt_oil1.tif',
    spec=opt_whitened_spectrum,
)

# Band Selection #

In [None]:
# Rows of `samples` are ordered [negatives..., positives...], so every row
# from `start` on is a (repeated) positive sample.
start = neg_subset.shape[0]
length = samples.shape[0] - start
# Use the whole positive block. np.repeat puts the copies of each positive row
# next to each other, so the first pos.shape[0] rows covered only the first
# 1/ratio of the distinct positives. Every positive appears `ratio` times, so
# the mean over the block equals the mean over `pos`.
subset_of_samples = samples[start:]
mean_of_samples = subset_of_samples.mean(dim=0).unsqueeze(dim=0)

In [None]:
# Band ordering from band_selection.argsort, using the mean positive sample as
# input. NOTE(review): the meaning of the trailing [0] argument is defined in
# hyperspectral.band_selection; confirm.
according_to_salience = argsort(
    model,
    mean_of_samples,
    target,
    [0],
)

In [None]:
# Display the full band ordering. The last 48 entries are used as the selected
# bands further down.
according_to_salience

In [None]:
# Persist the trained filter and the band ranking for the band-subset test.
np.savez(
    'data2/W2.npz',
    **dict(opt_W=opt_W, opt_bias=opt_bias, according_to_salience=according_to_salience),
)

# Test Selected Bands #

In [None]:
def infer3(in_filename, out_filename, spec, W, bands):
    """Whitened matched-filter scoring restricted to a subset of bands.

    The tiling matches ``infer2``, but only the bands listed in ``bands`` are
    read. The per-pixel L2 normalisation and mean-centring are therefore
    computed over that subset. The empty-tile check still looks at band 1 of
    the input, whether or not band 1 is in ``bands``. Output is a single-band
    float32 BigTIFF.

    Parameters
    ----------
    in_filename : str
        Multi-band input raster.
    out_filename : str
        Path of the single-band output raster.
    spec : ndarray
        Whitened target spectrum; its first axis must match the number of
        columns of ``W``. A ``(n, 1)`` column vector also works.
    W : ndarray, shape (len(bands), n_out)
        Matrix applied via ``whiten(pixels, W, 0)``; normally square.
    bands : sequence of int
        1-based rasterio band indexes to read.
    """
    with rio.open(in_filename, 'r') as in_ds:
        profile = copy.deepcopy(in_ds.profile)
        profile.update(count=1, driver='GTiff', bigtiff='yes', compress='deflate', predictor='2',
                       tiled='yes', dtype=np.float32, sparse_ok='yes')
        with rio.open(out_filename, 'w', **profile) as out_ds:
            for col in tqdm(range(0, in_ds.width, 512), position=0):
                width = min(col + 512, in_ds.width) - col
                for row in tqdm(range(0, in_ds.height, 512), position=1, leave=False):
                    height = min(row + 512, in_ds.height) - row
                    window = rasterio.windows.Window(col, row, width, height)
                    if not np.any(in_ds.read(1, window=window)):
                        continue
                    # (len(bands), rows, cols) -> (rows, cols, len(bands))
                    pixels = np.transpose(in_ds.read(bands, window=window).astype(np.float32), (1, 2, 0))
                    # All-zero pixels give 0/0 = NaN; they are zeroed after scoring.
                    with np.errstate(divide='ignore', invalid='ignore'):
                        pixels /= np.linalg.norm(pixels, ord=2, axis=2, keepdims=True)
                    pixels -= pixels.mean(axis=2, keepdims=True)
                    pixels = whiten(pixels, W, 0)
                    scores = np.dot(pixels, spec)
                    scores[np.isnan(scores)] = 0
                    # Scores hold height*width values in (row, col) order. The
                    # previous reshape(1, width, height) scrambled non-square edge tiles.
                    out_ds.write(scores.reshape(1, height, width).astype(np.float32), window=window)

In [None]:
# Reload the saved band ranking. The context manager closes the .npz archive
# once the array has been read into memory.
with np.load('data2/W2.npz') as dictionary:
    according_to_salience = dictionary['according_to_salience']

In [None]:
# The last 48 entries of the ranking are taken as the selected bands.
# NOTE(review): presumably the ranking is ordered least to most salient.
# NOTE(review): pos/neg were L2-normalised and mean-centred over ALL bands
# before this subset is taken, but infer3 normalises over the selected bands
# only. The whitening matrix fitted on neg48 may therefore not match the
# preprocessing used at inference time; confirm this is intended.
best_48 = np.squeeze(according_to_salience[-48:])
pos48 = np.squeeze(pos[:, best_48])
neg48 = np.squeeze(neg[:, best_48])
spectrum48 = spectrum[best_48].reshape(1, -1)

In [None]:
# Refit the whitening on the 48-band negatives; overwrites the global `W`.
[W, mean] = zca_whitening_matrix(neg48)

In [None]:
# (1, 48) target -> whitened -> (48, 1) column vector for infer3's dot product.
whitened_spectrum48 = whiten(spectrum48, W, mean=0).reshape(-1, 1)

In [None]:
infer3(
    in_filename,
    out_filename + '_opt48_oil1.tif',
    spec=whitened_spectrum48,
    W=W,
    bands=tuple(best_48 + 1),  # rasterio band indexes are 1-based
)