In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np

from model.inference import *
from model.hapke_model import get_USGS_r_mixed_hapke_estimate
from utils.access_data import *
from utils.constants import *

np.set_printoptions(precision=2)
np.set_printoptions(suppress=True)

In [None]:
IMG_DIR = DATA_DIR + 'GALE_CRATER/cartOrder/cartorder/'

# Pickled image section and its wavelengths.
image_file, wavelengths_file = (IMG_DIR + 'layered_img_sec_100_150.pickle',
                                IMG_DIR + 'layered_wavelengths.pickle')

# Different CRISM images are sampled at different wavelengths, so the RELAB,
# USGS and CRISM spectra are normalized against this image's wavelengths
# before the image itself is loaded.
record_reduced_spectra(wavelengths_file)

image = get_CRISM_data(image_file, wavelengths_file, CRISM_match=True)
print("CRISM image size " + str(image.shape))

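## Testing
Plot the spectra of a rectangular cutout of pixels from a CRISM image (frt0002037a_07_if165l_trr3_CAT; note that the cell below actually loads frt0000c0ef_07_if165l_trr3_CAT).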

In [None]:
# Load the "l" CRISM image and reset all NULL pixels (stored as 65535) to 0.
# envi and imshow are imported explicitly so this cell does not depend on a
# star import elsewhere happening to provide `envi`.
import spectral.io.envi as envi
from spectral import imshow

CRISM_NULL_VALUE = 65535  # sentinel this cell treats as a NULL pixel

CRISM_DATA_PATH = DATA_DIR + 'GALE_CRATER_TEST_2/'
CRISM_IMG = CRISM_DATA_PATH + 'frt0000c0ef_07_if165l_trr3_CAT.img'
spy_image = envi.open(file=CRISM_IMG + '.hdr')

# Read the whole cube into memory, then replace NULL pixels with 0.
image_arr = spy_image[:, :, :]
img = np.where(image_arr == CRISM_NULL_VALUE, 0, image_arr)

# Band wavelengths are read from the CSV 'pixel_x_262_y_136.csv' in the same directory.
wavelengths = get_CRISM_wavelengths(CRISM_DATA_PATH + 'pixel_x_262_y_136.csv')

print(len(wavelengths))
print(img.shape)

# Band indices used for the RGB preview.
bands = (300, 200, 50)
imshow(data=img, bands=bands)

In [None]:
#  height = 450
#  width = 640

In [None]:
from spectral import imshow

def plot_cutout_spectra(img, wavelengths, sec_width, sec_height, xstart, ystart, bands):
    """
    Visualize subsection of image and corresponding spectra to see variance.

    Draws two panels:
      * the spectrum of every pixel in the cutout, with the cutout's mean
        spectrum overlaid in red;
      * every pixel's deviation from that mean spectrum.
    Then displays the cutout itself with ``imshow``.

    :param img: image array indexed as img[row, column, band]
    :param wavelengths: x-axis values, one per band
    :param sec_width: number of columns to include
    :param sec_height: number of rows to include
    :param xstart: column to start at (where 0 is left most column)
    :param ystart: row to start at (where 0 is top row)
    :param bands: band indices passed to ``imshow`` for the cutout view
    :return: mean spectrum of the cutout (one value per band)
    :raises ValueError: if sec_width or sec_height is not positive
    :raises IndexError: if the cutout does not lie fully inside img
    """
    height, width, num_wavelengths = img.shape
    if sec_width <= 0 or sec_height <= 0:
        raise ValueError("sec_width and sec_height must be positive")
    if (xstart < 0 or ystart < 0
            or xstart + sec_width > width or ystart + sec_height > height):
        raise IndexError("cutout extends outside the image")

    # Rows are indexed by y and columns by x.
    cutout = img[ystart:ystart + sec_height, xstart:xstart + sec_width, :]
    # One row per pixel, one column per band.
    pixel_spectra = cutout.reshape(-1, num_wavelengths)
    # Average in float64 regardless of img's dtype.
    avg_spectra = pixel_spectra.mean(axis=0, dtype=np.float64)

    fig, ax = plt.subplots(2, 1, constrained_layout=True, dpi=300)

    # A 2-D y array draws one line per column, i.e. one line per pixel.
    ax[0].plot(wavelengths, pixel_spectra.T, linewidth=0.5)
    ax[0].plot(wavelengths, avg_spectra, linewidth=1.0, color='red')
    ax[0].set_xlabel("Wavelength")
    ax[0].set_ylabel("Reflectance")
    ax[0].set_title("Spectra")
    ax[0].set_ylim((0, 1))

    # Deviations from the mean are both positive and negative, so the
    # y-range is symmetric about 0.
    ax[1].plot(wavelengths, (pixel_spectra - avg_spectra).T)
    ax[1].set_title("Normalized Spectra (avg subtracted)")
    ax[1].set_xlabel("Wavelength")
    ax[1].set_ylabel("reflectance - avg")
    ax[1].set_ylim((-.5, .5))

    plt.show()

    imshow(data=cutout, bands=bands)
    return avg_spectra

    

bands = (300, 200, 50)

# Cutout of 400 rows x 600 columns whose top-left pixel is (row 20, column 20).
avg_spectra = plot_cutout_spectra(
    img=img,
    wavelengths=wavelengths,
    bands=bands,
    xstart=20,
    ystart=20,
    sec_width=600,
    sec_height=400,
)


## Parallel-tempering (PT) sampler testing

In [None]:
from emcee import PTSampler

# Target: two Gaussian modes, centred at mu1 = [1, 1] and mu2 = [-1, -1].
mu1 = np.ones(2)
mu2 = -np.ones(2)

# Inverse covariance matrices: variance 0.01, i.e. width 0.1, in each dimension.
sigma1inv = np.diag([100.0, 100.0])
sigma2inv = np.diag([100.0, 100.0])


def logl(x):
    """
    Log-likelihood of ``x``: the log of the sum of the two unnormalized
    Gaussian densities centred at mu1 and mu2.
    """
    def half_neg_quadratic(mu, sigma_inv):
        # -(x - mu)^T Sigma^-1 (x - mu) / 2
        delta = x - mu
        return -np.dot(delta, np.dot(sigma_inv, delta)) / 2.0

    return np.logaddexp(half_neg_quadratic(mu1, sigma1inv),
                        half_neg_quadratic(mu2, sigma2inv))


# NOTE(review): emcee's PTSampler (imported above) was removed in emcee 3.0;
# this section presumably needs emcee 2.x (or the ptemcee package) — confirm.
def logp(x):
    """Flat (improper) log-prior: 0 everywhere."""
    return 0.0

In [None]:
ntemps = 4      # number of temperatures passed to PTSampler
nwalkers = 10   # walkers per temperature
ndim = 2        # dimensions of the target distribution

num_burnin_iterations = 100

sampler = PTSampler(ntemps, nwalkers, ndim, logl, logp)

# Initial walker positions: uniform in [-1, 1) for every temperature, walker and dimension.
p0 = np.random.uniform(-1.0, 1.0, (ntemps, nwalkers, ndim))

# Burn-in: only the final (p, lnprob, lnlike) survive the loop; they seed the
# production run in a later cell, and reset() clears what the sampler stored.
for p, lnprob, lnlike in sampler.sample(p0, iterations=num_burnin_iterations):
    pass
sampler.reset()

In [None]:
# At each iteration, this generator for PTSampler yields

# p, the current position of the walkers.
# lnprob the current posterior values for the walkers.
# lnlike the current likelihood values for the walkers.

In [None]:
num_iterations = 100

# Production run, resuming from the burn-in state (p, lnprob, lnlike) and
# thinning the stored chain by a factor of 10.
for p, lnprob, lnlike in sampler.sample(
    p0=p,
    lnprob0=lnprob,
    lnlike0=lnlike,
    iterations=num_iterations,
    thin=10,
):
    pass

In [None]:
# After reset(), the production run stores one step every thin=10 iterations,
# so the chain holds num_iterations // 10 steps per walker.
assert sampler.chain.shape == (ntemps, nwalkers, num_iterations // 10, ndim)

In [None]:
# Chain has shape (ntemps, nwalkers, nsteps, ndim). With thin=10 in the
# sampling cell, nsteps = num_iterations // 10; a hard-coded 1000 only holds
# for 10000 iterations.
assert sampler.chain.shape == (ntemps, nwalkers, num_iterations // 10, ndim)

# Mean position of the temperature-index-0 chain (the beta = 1 chain in
# emcee's ordering — confirm), averaged over walkers and steps.
mu0 = sampler.chain[0].mean(axis=(0, 1))

# Longest autocorrelation length (over any temperature).
# NOTE(review): with only num_iterations // 10 stored steps this estimate may
# be unreliable — confirm before relying on it.
max_acl = np.max(sampler.acor)

In [None]:
# Scratch check: Python 3 `/` is true division, so this prints 0.3.
print(3/10)

## Infer a single data point

In [2]:
import math
import matplotlib.pyplot as plt
import numpy as np

from model.inference import *
from model.hapke_model import get_USGS_r_mixed_hapke_estimate
from preprocessing.generate_USGS_data import generate_image

from utils.access_data import *
from utils.constants import *

np.set_printoptions(precision=6)
np.set_printoptions(suppress=True)

In [None]:
# Synthetic test image. Note: this rebinds `image`, which the first section
# set to a CRISM image.
image = generate_image(
    num_mixtures=4,
    grid_res=2,
    noise_scale=0.0001,
    res=4,
)

In [None]:
# Display the generated image's `r_image` attribute
# (presumably its reflectance values — confirm in generate_image).
image.r_image

In [None]:

# ['olivine (Fo51)',
# 'olivine (Fo80)',
# 'augite',
# 'labradorite',
# 'pigeonite',
# 'magnetite',
# 'basaltic glass']
# m_random = np.array([0, 0.4, 0.5,0.1, 0, 0, 0])
# D_random = np.array([80, 80, 60, 60, 60, 60, 60])
def get_rmse(a, b):
    """
    Root-mean-square error between two arrays.

    Accepts any array-like inputs (lists, tuples, NumPy arrays) of equal or
    broadcast-compatible shape.

    :param a: first array-like of values
    :param b: second array-like of values
    :return: sqrt(mean((a - b) ** 2)) as a Python float
    :raises ValueError: if the inputs are empty
    """
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    if diff.size == 0:
        raise ValueError("get_rmse requires non-empty inputs")
    return math.sqrt(np.mean(diff ** 2))


def print_error(m_actual, D_actual, m_est, D_est):
    """
    Print estimated and actual m and D, then the RMSE of each.

    :param m_actual: true m values
    :param D_actual: true D values
    :param m_est: estimated m values
    :param D_est: estimated D values
    :return: (m_rmse, D_rmse), each rounded to 2 decimals and returned as a str
    """
    for label, values in (("Estimated m ", m_est),
                          ("Real m    ", m_actual),
                          ("Estimated D ", D_est),
                          ("Real D    ", D_actual)):
        print(label + str(values))

    def rounded_rmse(actual, est):
        return str(round(get_rmse(actual, est), 2))

    m_rmse = rounded_rmse(m_actual, m_est)
    print("RMSE for m: " + m_rmse)
    D_rmse = rounded_rmse(D_actual, D_est)
    print("RMSE for D: " + D_rmse)
    return m_rmse, D_rmse

import os

# Number of random trials. Trial indices start at 1 so the first figure is
# saved as sample_1.
NUM_TRIALS = 1
FIGURE_DIR = "../output/figures/testing/"

# Per-trial errors, stored as unrounded floats.
m_errs = []
D_errs = []
r_errs = []

np.set_printoptions(precision=2)
# savefig does not create missing directories.
os.makedirs(FIGURE_DIR, exist_ok=True)

for i in range(1, NUM_TRIALS + 1):
    # Random ground truth: m drawn from a flat Dirichlet (entries sum to 1);
    # D drawn as integers with randint, whose upper bound GRAIN_SIZE_MAX is exclusive.
    m_random = np.random.dirichlet(np.ones(USGS_NUM_ENDMEMBERS), size=1)[0]
    D_random = np.random.randint(low=GRAIN_SIZE_MIN,
                                 high=GRAIN_SIZE_MAX,
                                 size=USGS_NUM_ENDMEMBERS)

    true_m = convert_arr_to_dict(m_random)
    true_D = convert_arr_to_dict(D_random)
    r_actual = get_USGS_r_mixed_hapke_estimate(m=true_m, D=true_D)

    est_m, est_D = infer_datapoint(iterations=1500, d=r_actual)
    r_est = get_USGS_r_mixed_hapke_estimate(convert_arr_to_dict(est_m),
                                            convert_arr_to_dict(est_D))

    # print_error returns rounded strings (used for the plot title); store the
    # unrounded floats so the summary statistics are not quantized.
    m_rmse, D_rmse = print_error(m_random, D_random, est_m, est_D)
    m_errs.append(get_rmse(m_random, est_m))
    D_errs.append(get_rmse(D_random, est_D))

    # Plot against N_WAVELENGTHS directly rather than rebinding the global
    # `wavelengths`, which the CRISM section uses.
    fig, ax = plt.subplots(1, 1, constrained_layout=True,
                           figsize=(FIG_WIDTH, FIG_HEIGHT), dpi=DPI)
    ax.plot(N_WAVELENGTHS, r_est, label="Estimated", color="orange")
    ax.plot(N_WAVELENGTHS, r_actual, label="Actual", color="blue")
    ax.set_xlabel("Wavelength")
    ax.set_ylabel("Reflectance")
    ax.legend()
    ax.set_title("RMSE M: " + str(m_rmse) + ", D: " + str(D_rmse))
    ax.set_ylim((0, 1))
    fig.savefig(FIGURE_DIR + "sample_" + str(i))
    plt.show()
    plt.close(fig)

    r_rmse = get_rmse(r_actual, r_est)
    print("Reflectance RMSE: " + str(r_rmse))
    r_errs.append(r_rmse)
    

In [None]:
# Summary statistics of the per-trial m RMSEs.
# np.float was removed in NumPy 1.24, so convert with the builtin float
# (this also accepts RMSEs recorded as strings).
m_errs_arr = np.array(m_errs).astype(float)

if m_errs_arr.size == 0:
    print("No m RMSE values recorded; run the inference loop first.")
else:
    print("Mean: " + str(np.mean(m_errs_arr)))
    print("Min: " + str(np.amin(m_errs_arr)))
    print("Max: " + str(np.amax(m_errs_arr)))

In [4]:
len(N_WAVELENGTHS)  # number of entries in N_WAVELENGTHS, the x-axis used for the reflectance plots

277