In [None]:
import numpy as np
from scipy import sparse
# import cupy as cp

import sys
sys.path.append('../')
from general_purpose import utilities as ut

import logging
logger = logging.getLogger()

In [None]:
def _build_weight_matrix(reshape_mask, lat, allocate):
    r'''
    Shared assembly loop of `compute_weight_matrix` and `compute_weight_matrix_sparse`.

    Every pair of neighbouring grid points (a, b) of the same field contributes
    `weight * (p_a - p_b)**2` to $p^\top W p$: `weight` is added to W[a,a] and W[b,b],
    and `-weight` to W[a,b] and W[b,a].

    Parameters
    ----------
    reshape_mask : np.ndarray[bool]
        3d mask (lat, lon, field) used to flatten a snapshot `p` into a one dimensional array
    lat : np.ndarray[float]
        latitudes in degrees, indexed by the first axis of `reshape_mask`
    allocate : callable
        called with the `(n, n)` shape tuple (n = np.sum(reshape_mask)); must return a
        zero-initialized matrix supporting `W[a, b] += x` (e.g. `np.zeros`, `sparse.lil_matrix`)

    Returns
    -------
    the matrix returned by `allocate`, filled in

    Raises
    ------
    ValueError
        if `reshape_mask` is not a 3d array
    '''
    shape = reshape_mask.shape
    shape_r = (np.sum(reshape_mask),)
    if len(shape) != 3:
        raise ValueError(f'reshape_mask should be a 3d array! Instead {reshape_mask.shape = }')

    geosep = ut.Reshaper(reshape_mask)
    W = allocate(shape_r*2)
    n_lat, n_lon, n_fields = shape

    def add_pair(a, b, weight):
        # a, b are (lat, lon, field) grid indices. A pair for which `reshape_index` raises
        # IndexError is skipped: this is how the (i+1) / (j+1) neighbours past the last
        # row/column are dropped. NOTE(review): presumably it also raises for points removed
        # by `reshape_mask` - confirm in ut.Reshaper.
        try:
            ind1 = geosep.reshape_index(a)
            ind2 = geosep.reshape_index(b)
        except IndexError:
            # lazy %-formatting: the message is only built when DEBUG logging is enabled
            logger.debug('IndexError: %s-%s', a, b)
            return
        W[ind1,ind1] += weight
        W[ind2,ind2] += weight
        W[ind1,ind2] += -weight
        W[ind2,ind1] += -weight

    #f -> field
    #i -> lat
    #j -> lon
    for f in range(n_fields):
        for i in range(n_lat):
            # cos(lat) weights the latitude pairs (grid cell area); the longitude pairs get
            # 1/cos(lat) to account for the proper longitudinal gradient.
            # NOTE: at +-90 deg cos(lat) ~ 0 and wi becomes huge / inf.
            w = np.cos(lat[i]*np.pi/180)
            wi = 1./w
            for j in range(n_lon):
                add_pair((i,j,f), (i+1,j,f), w)   # latitude gradient
                add_pair((i,j,f), (i,j+1,f), wi)  # longitude gradient
            # periodic longitude: connect the last longitude back to the first one
            add_pair((i,n_lon - 1,f), (i,0,f), wi)
    return W


def compute_weight_matrix(reshape_mask, lat):
    r'''
    Compute the matrix W such that
    $$ H_2(p) = p^\top W p $$

    i.e. the area-weighted sum of squared differences between neighbouring grid points
    (latitude neighbours, longitude neighbours and the periodic longitude wrap), per field.

    Parameters
    ----------
    reshape_mask : np.ndarray[bool]
        3d mask (lat, lon, field) to flatten a snapshot `p` into a one dimensional array, eventually removing zero variance features
    lat : np.ndarray[float]
        latitude vector in degrees, used to compute the grid cell area and the proper longitudinal gradients

    Returns
    -------
    np.ndarray[float]
        dense W of shape (n, n), with n = np.sum(reshape_mask)

    Raises
    ------
    ValueError
        if `reshape_mask` is not a 3d array
    '''
    return _build_weight_matrix(reshape_mask, lat, np.zeros)


def compute_weight_matrix_sparse(reshape_mask, lat):
    r'''
    Compute the matrix W such that
    $$ H_2(p) = p^\top W p $$

    Same matrix as `compute_weight_matrix`, built as a scipy sparse matrix.

    Parameters
    ----------
    reshape_mask : np.ndarray[bool]
        3d mask (lat, lon, field) to flatten a snapshot `p` into a one dimensional array, eventually removing zero variance features
    lat : np.ndarray[float]
        latitude vector in degrees, used to compute the grid cell area and the proper longitudinal gradients

    Returns
    -------
    scipy.sparse.dia_matrix
        W of shape (n, n), with n = np.sum(reshape_mask), in DIA (diagonal) format

    Raises
    ------
    ValueError
        if `reshape_mask` is not a 3d array
    '''
    # LIL format is meant for incremental item assignment; convert to DIA once at the end
    return _build_weight_matrix(reshape_mask, lat, sparse.lil_matrix).todia()

## Perform the same operations on the dense and sparse matrices and check that the results agree

In [None]:
# 2-degree grid: 44 latitudes (0..86 deg) x 180 longitudes, with 3 fields per grid point.
# NOTE: the dense W built below has shape (n, n) with n = 44*180*3 = 23760,
# i.e. about 4.5 GB of float64.
lat = np.arange(0, 87, 2)
lon = np.arange(0, 360, 2)
reshape_mask = np.full((lat.size, lon.size, 3), True)
reshape_mask.shape

### Creation

In [None]:
%%time
# Build the dense W; its timing is compared with the sparse construction in the next cell.
W = compute_weight_matrix(reshape_mask, lat)

In [None]:
%%time
# Same W assembled as a scipy sparse matrix (LIL during assembly, returned in DIA format).
W_sparse = compute_weight_matrix_sparse(reshape_mask, lat)

In [None]:
# Exact equality is expected: both builders perform the same floating-point additions
# in the same order. Densify explicitly instead of relying on ndarray == sparse dispatch.
np.testing.assert_array_equal(W, W_sparse.toarray())

### Convert to array

In [None]:
%%time
# Time the sparse -> dense conversion; `todense` returns an np.matrix (not an ndarray).
W2 = W_sparse.todense()

In [None]:
np.shape(W)  # (n, n) with n = number of True entries in reshape_mask

In [None]:
# True when W_sparse is a scipy sparse *matrix* (a sparse.spmatrix subclass)
isinstance(W_sparse, sparse.spmatrix)

In [None]:
# np.asarray does not densify a scipy sparse matrix: it wraps it in a 0-d object array.
# toarray() returns the dense values as a plain ndarray.
W3 = W_sparse.toarray()

In [None]:
# toarray() (ndarray) and todense() (np.matrix) of the same sparse matrix must give identical values.
np.testing.assert_array_equal(W3, np.asarray(W2))

In [None]:
# Inspect the converted matrix W3
W3

In [None]:
w = W
u2 = np.asarray(w.toarray() if isinstance(w, sparse.dia_matrix) else w, dtype=np.float32)

In [None]:
(u == u2).all()