In [1]:
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_array, check_is_fitted

import numpy as np
import matplotlib.pyplot as plt
import os
import copy
import inspect
import random
import shutil

from datetime import datetime
from dateutil.relativedelta import relativedelta

from scipy import optimize, ndimage
from sklearn import decomposition, cluster, model_selection, metrics
import sklearn

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import utils.dataset_utils as dataset
import utils.train_utils as train

In [2]:
def unit_vector_norm(X):
    """Remove a constant offset from X, then scale every row to unit L2 norm.

    The offset is the global minimum of X, so after the shift the smallest
    entry is 0. The caller's array is not modified; a new array is returned.
    Rows whose norm is 0 after the shift are returned as zeros instead of NaN.

    Parameters
    ----------
    X : array-like, 2-D
        One sample (in this notebook: one spectrum) per row.

    Returns
    -------
    ndarray
        Same shape as X; every non-zero row has L2 norm 1.
    """
    X = np.asarray(X)
    shifted = X - X.min()  # remove noise offset out of place: never touch the caller's data
    norms = np.sqrt((shifted ** 2).sum(axis=1))
    norms = np.where(norms > 0, norms, 1.0)  # all-zero rows stay zero instead of 0/0 = NaN
    return (shifted.T / norms).T

def blur_norm(X, s):
    """Gaussian-smooth a 3-D array over its first two axes only.

    `s` is the standard deviation used for both of the first two axes; the
    third axis (presumably the spectral axis) gets sigma 0, so every slice
    along it is filtered independently.
    """
    per_axis_sigma = (s, s, 0)
    smoothed = ndimage.gaussian_filter(X, per_axis_sigma)
    return smoothed

def mu_norm(X):
    """Centre X by subtracting its mean along axis 0 (per-column mean for 2-D input)."""
    column_means = X.mean(axis=0)
    centred = X - column_means
    return centred

class View(nn.Module):
    """nn.Module wrapper around ``Tensor.view`` so a reshape can sit inside nn.Sequential."""

    def __init__(self, shape):
        super().__init__()
        # Target shape handed to Tensor.view (may contain a single -1).
        self.shape = shape

    def forward(self, x):
        target_dims = tuple(self.shape)
        return x.view(*target_dims)

In [3]:
# Spectral crop: keep bands start_index:end_index of every cube.
# NOTE(review): an earlier comment read "use 368 to 1750" — presumably the
# wavelength range these indices select; confirm against wavelength.npy.
s = 1  # Gaussian sigma for blur_norm (applied to the first two axes only)
start_index, end_index = 115, 815
# Derive the band count from the crop so the autoencoder width (AutoEncoderConv
# reads N_WAVE) cannot drift out of sync with the data.
N_WAVE = end_index - start_index  # 700

HSI_DIR = "../data/HSI"


def _load_cube(filename):
    """Memory-map ``HSI_DIR/filename``, crop its last axis to the band range, and blur it."""
    cube = np.load(f"{HSI_DIR}/{filename}", mmap_mode='r')
    return blur_norm(cube[:, :, start_index:end_index], s)


X = _load_cube("Liver_map_150z25_60s_1TCPOBOP.npy")
Y = _load_cube("Liver_map_150z25_60s_2TCPOBOP.npy")
Z = _load_cube("Liver_map_150z25_60s_3OBOB.npy")
wavelength = np.load(f"{HSI_DIR}/wavelength.npy", mmap_mode='r')[start_index:end_index]

# Keep the cube shapes so per-pixel results can be reshaped back into images.
shape_X = X.shape
shape_Y = Y.shape
shape_Z = Z.shape

# Flatten each cube to (pixels, bands) as an independent in-memory copy.
X = X.reshape(-1, X.shape[-1]).copy()
Y = Y.reshape(-1, Y.shape[-1]).copy()
Z = Z.reshape(-1, Z.shape[-1]).copy()

In [4]:
class AutoEncoderConv(nn.Module):
    """Fully connected autoencoder for 1-D spectra (no convolutions despite the name).

    Parameters
    ----------
    n_components : int
        Size of the latent code produced by ``encode``.
    depth : int
        Linear+ReLU stages before the final Linear in each half; ``depth - 1``
        of them are ``neurons -> neurons`` hidden layers, each with its own weights.
    neurons : int
        Hidden-layer width.
    bias : bool
        Whether every Linear layer has a bias.
    n_wave : int or None
        Spectrum length at the input and output. ``None`` (default) uses the
        module-level ``N_WAVE``.
    **kwargs
        Accepted and ignored, so a superset of hyper-parameters can be passed.
    """

    def __init__(self, n_components=10, depth=2, neurons=100, bias=True, n_wave=None, **kwargs):
        super().__init__()
        if n_wave is None:
            n_wave = N_WAVE

        def hidden_stack():
            # Fresh Linear/ReLU modules per hidden layer. Repeating a tuple, e.g.
            # `(nn.Linear(...), nn.ReLU()) * k`, inserts the SAME Linear object
            # k times and silently ties all hidden-layer weights together.
            return [layer
                    for _ in range(depth - 1)
                    for layer in (nn.Linear(neurons, neurons, bias=bias), nn.ReLU(True))]

        self.encode = nn.Sequential(
            # NOTE(review): Dropout3d zeroes whole channels. If inputs are
            # (N, 1, 1, 1, n_wave) — the shape `decode` emits — this drops entire
            # spectra rather than single bands (nn.Dropout would drop bands).
            # Confirm which is intended.
            nn.Dropout3d(0.25),
            nn.Flatten(),
            nn.Linear(n_wave, neurons, bias=bias),
            nn.ReLU(True),
            *hidden_stack(),
            nn.Linear(neurons, n_components, bias=bias),
        )

        self.decode = nn.Sequential(
            nn.Linear(n_components, neurons, bias=bias),
            nn.ReLU(True),
            *hidden_stack(),
            nn.Linear(neurons, n_wave, bias=bias),
            # 5-D output, presumably matching the input layout used by the loss.
            View((-1, 1, 1, 1, n_wave)),
        )

    def forward(self, x):
        """Encode then decode ``x``; the result is shaped (-1, 1, 1, 1, n_wave)."""
        return self.decode(self.encode(x))
        

In [6]:
def MSE_loss(x, model):
    """Reconstruction error plus a penalty that pushes reference spectra apart.

    Despite the name, the value returned is ``MSE + separation_penalty``.

    Parameters
    ----------
    x : torch.Tensor
        Input batch. The spectrum is assumed to be the 5th axis, e.g.
        (N, 1, 1, 1, n_wave) — TODO confirm against the data loader.
    model : object with ``encode`` / ``decode`` callables
        ``encode(x)`` must return an (N, n_components) weight matrix W and
        ``decode(W)`` a tensor shaped like ``x``.

    Returns
    -------
    torch.Tensor (scalar)
        ``mean_over_batch(sum_over_axis4((decode(W) - x) ** 2))`` plus, when
        n_components >= 2,
        ``(mean |off-diagonal cosine| + sum_pairs 1 / (L1 + 0.5)) / n_pairs``
        over the unit-normalised reference spectra ``W.T @ x``. With a single
        component there are no pairs, so only the MSE term is returned.
    """
    W = model.encode(x)
    x_ = model.decode(W)

    # Summed squared error over the spectral axis, averaged over the rest.
    mse = ((x_ - x) ** 2).sum(4).mean()

    n_comp = W.size(1)
    if n_comp < 2:
        # No pairs of reference spectra: the penalty would be 0/0, so skip it.
        return mse

    # Reference spectra (one row per component): latent-weighted sums of the
    # inputs, scaled to unit L2 norm. reshape (not squeeze) keeps the batch
    # axis when the batch holds a single sample.
    flat_x = x.reshape(x.size(0), -1)
    ref = W.T @ flat_x
    ref = (ref.T / torch.sqrt((ref ** 2).sum(axis=1))).T

    # Mean absolute cosine similarity between distinct reference spectra
    # (the diagonal contributes exactly n_comp and is removed).
    penalty = (torch.abs(ref @ ref.T).sum() - n_comp) / (n_comp ** 2 - n_comp)
    # Inverse L1 distance for every unordered pair: large when spectra are close.
    for i in range(n_comp):
        for j in range(i + 1, n_comp):
            penalty = penalty + 1 / (torch.abs(ref[i] - ref[j]).sum() + 0.5)
    penalty = penalty / (n_comp * (n_comp - 1) // 2)  # number of unordered pairs

    return mse + penalty

In [7]:
class ReferenceVectorClassifierAE(BaseEstimator):
    """Express spectra as non-negative combinations of autoencoder-derived reference spectra.

    ``fit`` trains an ``AutoEncoderConv`` on unit-normalised spectra and builds
    one reference spectrum per latent component as the latent-weighted sum of
    the training spectra (``W.T @ X``, then unit-normalised). ``predict`` solves
    a non-negative least-squares problem per sample against those spectra.

    Hyper-parameters are passed as keyword arguments and stored in
    ``self.kwargs``; the subsets accepted by ``cluster.KMeans`` and
    ``AutoEncoderConv`` are mirrored into ``k_means_kwargs`` / ``ae_kwargs``.
    Keys read directly: ``cuda`` (optional), ``batch_size``, ``epochs``,
    ``loss_func``, ``log_step``.
    """

    def __init__(self, **kwargs):
        self.kwargs = {}
        self.k_means_kwargs = {}  # NOTE(review): filled by set_params but not used by this class
        self.ae_kwargs = {}
        self.set_params(**kwargs)

        # 'cuda' is optional; when omitted, use a GPU if one is available.
        _use_cuda = torch.cuda.is_available() and self.kwargs.get('cuda', True)
        if _use_cuda:
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
        self.device = torch.device('cuda' if _use_cuda else 'cpu')

    def fit(self, x, y=None, **kwargs):
        """Train the autoencoder on x and derive the reference spectra.

        Parameters
        ----------
        x : ndarray, 2-D (n_samples, n_wave)
            Spectra, one per row. The caller's array is not modified.
        y : ignored
            Accepted so sklearn utilities that call ``fit(X, y)`` work.
        **kwargs
            Hyper-parameter overrides, applied via ``set_params`` first.

        Returns
        -------
        self
        """
        self.set_params(**kwargs)
        # Private copy: unit_vector_norm may shift its argument in place, and
        # ref_org below is computed from this same (possibly shifted) array.
        x = np.array(x)
        X = unit_vector_norm(x)

        # ---------------- Autoencoder ----------------
        self.model = AutoEncoderConv(**self.ae_kwargs).to(self.device)

        parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = optim.Adam(parameters)
        train_loader, test_loader = dataset.load_liver(X, self.kwargs['batch_size'])

        for epoch in range(self.kwargs['epochs']):
            print('-' * 50)
            print('Epoch {:3d}/{:3d}'.format(epoch + 1, self.kwargs['epochs']))
            start_time = datetime.now()
            train.train(self.model, self.optimizer, train_loader, self.kwargs['loss_func'], self.kwargs['log_step'], self.device)
            time_diff = relativedelta(datetime.now(), start_time)
            print('Elapsed time: {}h {}m {}s'.format(time_diff.hours, time_diff.minutes, time_diff.seconds))
            loss = train.test(self.model, test_loader, self.kwargs['loss_func'], self.device)
            print('Validation| bits: {:2.2f}'.format(loss), flush=True)

        # Latent codes for every training sample (eval mode disables dropout).
        self.model.eval()
        with torch.no_grad():
            W = self.model.encode(dataset.load_liver_all(X).to(self.device))
        self.z = W  # kept on self.device
        W = W.cpu().numpy()

        # ---------------- Reference spectra ----------------
        # One row per latent component: latent-weighted sum of the normalised spectra.
        self.reference_spectra_ = unit_vector_norm(W.T @ X)
        # Same construction from the un-normalised copy of the input.
        self.ref_org = unit_vector_norm(W.T @ x)

        return self

    def predict(self, X):
        """Return the non-negative weight of each reference spectrum for every row of X.

        X is validated with ``check_array`` (copied, so the caller's array is
        never modified), unit-normalised as in ``fit``, and each row is solved
        with ``scipy.optimize.nnls`` against ``reference_spectra_``.

        Returns
        -------
        ndarray, shape (n_samples, n_reference_spectra)
        """
        check_is_fitted(self)

        X = check_array(X, copy=True)
        X = unit_vector_norm(X)

        basis = self.reference_spectra_.T
        return np.array([optimize.nnls(basis, row)[0] for row in X])

    def get_reference_vectors(self):
        """Unit-normalised reference spectra built from the normalised training data."""
        return self.reference_spectra_

    def get_org_reference_vectors(self):
        """Unit-normalised reference spectra built from the un-normalised training data."""
        return self.ref_org

    def get_params(self, deep=False):
        """Return a shallow copy of the stored hyper-parameters (``deep`` is ignored).

        Returning a copy keeps callers from mutating ``self.kwargs`` directly.
        """
        return dict(self.kwargs)

    def set_params(self, **kwargs):
        """Merge kwargs into self.kwargs and refresh the KMeans / autoencoder subsets."""
        self.kwargs.update(kwargs)
        k_means_names = inspect.signature(cluster.KMeans).parameters
        ae_names = inspect.signature(AutoEncoderConv).parameters
        self.k_means_kwargs.update({k: v for k, v in kwargs.items() if k in k_means_names})
        self.ae_kwargs.update({k: v for k, v in kwargs.items() if k in ae_names})
        return self

In [8]:
def error_map(estimator, X, y=None):
    """Per-sample squared reconstruction error of X under a fitted estimator.

    X is unit-normalised with ``unit_vector_norm`` (on a copy) before the
    comparison, because ``predict`` fits normalised spectra against unit-norm
    reference spectra; subtracting raw X would mostly measure signal intensity
    instead of fit quality. Normalising already-normalised input is (up to
    rounding) a no-op, so callers that pre-normalise get the same result.

    Returns
    -------
    ndarray, shape (n_samples,)
        Mean over features of (reconstruction - normalised X) ** 2.
    """
    X_norm = unit_vector_norm(np.array(X, dtype=float))
    RCA = estimator.predict(X_norm)
    ref_vec = estimator.get_reference_vectors()
    return ((RCA @ ref_vec - X_norm) ** 2).mean(1)

def score_func(estimator, X, y=None):
    """Average per-sample reconstruction error of `estimator` on unit-normalised X.

    Lower is better. `y` is unused; presumably it is there so the function
    matches the sklearn scorer signature ``(estimator, X, y)``.
    """
    normalised = unit_vector_norm(X)
    per_sample = error_map(estimator, normalised)
    return per_sample.mean()

def print_mean_std(X):
    """Format X's mean and standard deviation as two left-aligned, 12-wide scientific fields.

    Despite the name, nothing is printed: the formatted string is returned.
    """
    return "{:<12.4e}{:<12.4e}".format(X.mean(), X.std())

def cross_val_X_Y_Z(rvc, X, Y, Z):
    """Fit `rvc` on the rows of X and Y stacked together, then return its score on Z.

    The score comes from ``score_func`` (mean reconstruction error; lower is better).
    """
    train_rows = np.concatenate((X, Y), axis=0)
    rvc.fit(train_rows)
    held_out_score = score_func(rvc, Z)
    return held_out_score

In [9]:
n = 3  # number of latent components / reference spectra

# Seed the Python, NumPy and PyTorch RNGs so weight init, dropout and any
# shuffling drawn from them repeat across Restart & Run All.
# NOTE(review): cudnn.benchmark (enabled by ReferenceVectorClassifierAE on GPU)
# can still make GPU runs non-deterministic, and the data loader may use its
# own RNG — confirm in utils.dataset_utils.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

kwargs = {'n_clusters': n,
          'n_components': n,
          'min_weight': 0,  # NOTE(review): not read by any code visible in this notebook
          'batch_size': 64,
          'cuda': True,
          'log_step': 10,
          'loss_func': MSE_loss,
          'epochs': 5,
          'depth': 3,
          'neurons': 30,
          'bias': True
          }

header = f"{'mean':12}{'std':12}"  # column header matching print_mean_std's layout


In [None]:
# Build the estimator from the config dict and train it on the first liver map.
rvc = ReferenceVectorClassifierAE(**kwargs)
rvc.fit(x=X)

--------------------------------------------------
Epoch   1/  5
  2020-06-05 15:57:27|     0/   52| bits: 12.69
  2020-06-05 15:57:27|    10/   52| bits: 7.76
  2020-06-05 15:57:27|    20/   52| bits: 5.90
  2020-06-05 15:57:27|    30/   52| bits: 4.06
  2020-06-05 15:57:27|    40/   52| bits: 2.55
  2020-06-05 15:57:27|    50/   52| bits: 2.07
Elapsed time: 0h 0m 2s
Validation| bits: 2.01
--------------------------------------------------
Epoch   2/  5
  2020-06-05 15:57:31|     0/   52| bits: 1.99
  2020-06-05 15:57:31|    10/   52| bits: 1.51
  2020-06-05 15:57:31|    20/   52| bits: 1.23
  2020-06-05 15:57:31|    30/   52| bits: 1.00
  2020-06-05 15:57:32|    40/   52| bits: 0.91
  2020-06-05 15:57:32|    50/   52| bits: 0.61
Elapsed time: 0h 0m 2s
Validation| bits: 0.74
--------------------------------------------------
Epoch   3/  5
  2020-06-05 15:57:35|     0/   52| bits: 0.66
  2020-06-05 15:57:36|    10/   52| bits: 0.71
  2020-06-05 15:57:36|    20/   52| bits: 0.59
  2020-

In [None]:
# Unmix every pixel of X into reference-spectrum weights and report the fit error.
RCA_vector = rvc.predict(X)

print("fit score: ", score_func(rvc, X))

# Min-max scale each weight column to [0, 1] for display. A constant column
# would divide 0 by 0, so its range is treated as 1 (it renders as 0).
n_ref = RCA_vector.shape[1]  # actual number of reference spectra returned by predict
RCA_scaled = RCA_vector - RCA_vector.min(0)
col_range = RCA_scaled.max(0)
col_range[col_range == 0] = 1
RCA_scaled = RCA_scaled / col_range

# Back to image layout: swap the first two axes, then flip the new first axis.
weight_img = np.swapaxes(RCA_scaled.reshape((*shape_X[:2], n_ref)), 0, 1)[::-1]

# Show the weights three components at a time as RGB; a trailing group with
# fewer than three components is zero-padded.
for first in range(0, n_ref, 3):
    group = weight_img[:, :, first:first + 3]
    rgb = np.zeros((*group.shape[:2], 3))
    rgb[:, :, :group.shape[2]] = group
    plt.figure(figsize=(20, 4))
    plt.title(f"components {first}-{first + group.shape[2] - 1} as RGB")
    plt.imshow(rgb)
    plt.show()

# Per-pixel reconstruction error map.
plt.figure(figsize=(20, 4))
plt.imshow(error_map(rvc, X).reshape(shape_X[:2]).T, cmap='gray', vmin=0)
plt.show()

In [None]:
# One curve per reference spectrum, plotted against wavelength.
fig, ax = plt.subplots(figsize=(20, 12))
for idx, spectrum in enumerate(rvc.get_reference_vectors()):
    ax.plot(wavelength, spectrum, label=idx)
ax.legend()
plt.show()

In [None]:
# Reconstruction plot: run every unit-normalised spectrum of X through the autoencoder.
with torch.no_grad():
    # Use the estimator's own device (CPU when CUDA is unavailable or disabled)
    # instead of a hard-coded 'cuda:0'.
    x_ = rvc.model(dataset.load_liver_all(unit_vector_norm(X)).to(rvc.device))

plt.figure(figsize=(20, 12))
# Overlay every 10th of the first 1000 reconstructions (fewer if X is smaller).
for i in range(0, min(1000, len(x_)), 10):
    plt.plot(wavelength, x_[i].flatten().cpu().numpy(), label='rec')
plt.show()

In [None]:
# Per-component centre and spread of the latent codes stored by fit().
mu, sigma = rvc.z.mean(dim=0), rvc.z.std(dim=0)
(mu, sigma)

In [None]:
# Decode the latent mean and compare it with the mean of the normalised data.
with torch.no_grad():
    centre_decoded = rvc.model.decode(mu).flatten().cpu().detach().numpy()
fig, ax = plt.subplots(figsize=(20, 12))
ax.plot(wavelength, centre_decoded, label='centre_z')
ax.plot(wavelength, unit_vector_norm(X).mean(0), label='centre')
ax.legend()
plt.show()

In [None]:
# Decode a probe of +3 sigma along each latent axis (all other components 0)
# and overlay the results with the data centre, sample 0 and its reconstruction.
n_comp = rvc.kwargs['n_components']
with torch.no_grad():
    probes = torch.eye(n_comp).to(rvc.device) * sigma * 3
    # reshape rather than squeeze keeps one row per probe even when n_comp == 1.
    reconstruct = rvc.model.decode(probes).reshape(n_comp, -1).cpu().numpy()

print(rvc.z[0])

X_unit = unit_vector_norm(X)
plt.figure(figsize=(20, 12))
for i, r in enumerate(reconstruct):
    plt.plot(wavelength, r, label=i)
plt.plot(wavelength, X_unit.mean(0), label='centre')
plt.plot(wavelength, X_unit[0], label='org')
# Reconstruction of sample 0: `x_` comes from the reconstruction-plot cell,
# which must be run first.
plt.plot(wavelength, x_[0].flatten().cpu().numpy(), label='rec')
plt.legend()
plt.show()


In [None]:
# Latent-code scatter: component 0 vs component 1 for every training pixel.
z_host = rvc.z.detach().cpu()
plt.plot(z_host[:, 0].numpy(), z_host[:, 1].numpy(), '.')

In [None]:
# Latent-code scatter: component 2 vs component 3. Those columns only exist
# when n_components >= 4, so guard the indexing instead of raising IndexError.
z_host = rvc.z.detach().cpu()
if z_host.shape[1] > 3:
    plt.plot(z_host[:, 2].numpy(), z_host[:, 3].numpy(), '.')
else:
    print(f"Skipped: latent space has {z_host.shape[1]} components; need at least 4.")

In [None]:
# print(rvc.z.mean(0))
# print(rvc.z.std(0))

# plt.figure(figsize = (20,12))
# for i in range(0,3750,100):
#     plt.plot(rvc.z[i].cpu().detach().numpy())
# plt.show()