In [1]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib
import cartopy.crs as ccrs
import scipy.stats as sts

import os
import pickle

import SAM
import float_data as flt

import importlib
importlib.reload(SAM)
importlib.reload(flt)

from dask import delayed

In [3]:
model_folder = 'model'
n_classes = 8
ids = ['r1i1p1f2', 'r2i1p1f2', 'r3i1p1f2', 'r4i1p1f2', 'r5i1p1f3', 'r6i1p1f3', 'r7i1p1f3', 'r8i1p1f2', 'r9i1p1f2', 'r10i1p1f2']
mask = np.load('data/mask.npy', allow_pickle=True)
data_classes = {}
avg_profiles_dict = {}
path_ref = '{}/{}/{}'.format(model_folder, ids[0], n_classes)
with open('{}/avg.obj'.format(path_ref), 'rb') as file:
    ref_profiles = pickle.load(file)
    file.close()
  
    
for (m_i, m_id) in enumerate(ids):

    print('Starting {}'.format(m_id))
    path_id = '{}/{}'.format(model_folder, m_id)
    path_n = '{}/{}/{}'.format(model_folder, m_id, n_classes)
    path_data = 'data/{}/{}'.format(m_id, n_classes)

    with open('{}/pca.obj'.format(path_id), 'rb') as file:
        pca = pickle.load(file)
        file.close()

    with open('{}/gmm.obj'.format(path_n), 'rb') as file:
        gmm = pickle.load(file)
        file.close()

    with open('{}/avg.obj'.format(path_data), 'rb') as file:
        avg_profiles = pickle.load(file)
        file.close()

    options = {'memberId' : m_id, 'raw' : True}
    data = flt.retrieve_profiles(timeRange = slice('1970-01', '1970-12'), mask=mask, options=options)
    data = data.chunk({'time': data.sizes['time'], 'i' : 64, 'j':64})
    data_sampled = flt.normalise_data(data, ('i', 'j', 'time'))
    data_trans = flt.pca_transform(data_sampled, pca)
    
    data_c = flt.gmm_classify(data_trans, gmm).compute()
    data_classes[m_id] = data_c
    continue
    lats = data_c['lat'].values
    lons = data_c['lon'].values
    lev = data['lev'].values
    times = data['time'].values
    alpha = np.logical_not(data.isel(time=0, lev=-1).isnull().values)
    

    plt_data = data_c.values

    #plt_data[plt_data == -1] = 0
    for yr in range(np.size(plt_data, 0)):
        fig = plt.figure()
        
        ax = fig.add_subplot(projection=ccrs.SouthPolarStereo())
        ax.pcolormesh(lons, lats, plt_data[yr, :, :], transform=ccrs.PlateCarree(), alpha=alpha, cmap='YlOrRd') #, cmap='YlOrRd'
        ax.coastlines()
        ax.set_title('{}, {}'.format(m_id, times[yr].astype('datetime64[M]')))
        ax.set_facecolor((0.7, 0.7, 0.7, 1))
        
    
        fig.set_size_inches(7, 7)
        #plt.savefig('figures/anim_mm/{:03d}'.format(yr), dpi=300, bbox_inches='tight')
        
        print(yr)
        plt.show()
        plt.close(fig)
    
    
print('Done!')

Starting r1i1p1f2
Starting r2i1p1f2
Starting r3i1p1f2
Starting r4i1p1f2
Starting r5i1p1f3
Starting r6i1p1f3
Starting r7i1p1f3
Starting r8i1p1f2
Starting r9i1p1f2
Starting r10i1p1f2
Done!


In [6]:
data_c1 = data_c

In [4]:
def f(datac1, datac2):
  a = [np.unique(datac2.where(datac1==k).values, return_counts=True) for k in range(n_classes)]
  a = [(x[0][0:-1].astype('int'), x[1][0:-1]) for x in a]
  return a
counts = []
for m_id in ids:
  counts.append(f(data_classes[ids[0]], data_classes[m_id]))

In [5]:
indices = np.zeros((len(ids), n_classes))
for i in range(len(ids)):
  for j in range(n_classes):
    indices[i, j] = counts[i][j][0][np.argmax(counts[i][j][1])]
indices = indices.astype('int')

In [6]:
for (k, (n, c)) in enumerate(a):
  plt.scatter(n[0:-1], c[0:-1], marker='+', label=k)
plt.legend()

NameError: name 'a' is not defined

In [5]:
def gmm_prob(data_c1, data_c2):
  
  """
  Replace the nan with -1
  """

  pca_size = data_trans.sizes['pca_comp']
  gmm_size = gmm.n_components
  def func(arr):
    arr_r = np.reshape(arr, (-1, pca_size))
    
    inds = np.isnan(arr_r) 
    arr_r[inds] = 0
    
    out = gmm.predict_proba(arr_r)
    out_sizes = list(np.shape(arr))
    out_sizes[-1] = gmm_size
    out[inds[:, 0]] = np.nan
    out = np.reshape(out, out_sizes)
    return out

  result = xr.apply_ufunc(
        func,
        data_trans,
        input_core_dims=[['pca_comp']],
        output_core_dims=[['k']],
        dask='parallelized',
        output_dtypes=('float64',),
        vectorize=False,
        dask_gufunc_kwargs={
            'output_sizes' : {'k' : gmm_size}
        }
    )
  
  return result

In [7]:
b = np.unique(data_c1.where(data_c2==k), return_counts=True)
b

NameError: name 'data_c1' is not defined

In [11]:
indices[0:-1, :]

array([[0., 1., 2., 3., 4., 5., 6.],
       [1., 0., 5., 6., 2., 3., 4.],
       [1., 4., 0., 6., 5., 2., 3.],
       [3., 0., 2., 5., 1., 4., 6.],
       [2., 6., 4., 3., 0., 5., 1.],
       [6., 2., 5., 0., 1., 4., 3.],
       [2., 5., 0., 1., 4., 6., 3.],
       [3., 5., 4., 2., 1., 6., 0.],
       [0., 4., 2., 6., 3., 5., 1.]])

In [6]:
avg_profiles = {}
for m_id in ids:
    
    print('Starting {}'.format(m_id))
    path_id = 'model/{}'.format(m_id)
    path_n = 'model/{}/{}'.format(m_id, n_classes)
    path_data = 'data/{}/{}'.format(m_id, n_classes)
    

    if os.path.exists('{}/avg.obj'.format(path_data)):
        with open('{}/avg.obj'.format(path_data), 'rb') as file:
            avg_profiles[m_id] = pickle.load(file)
            file.close()
print('Done!')

Starting r1i1p1f2
Starting r2i1p1f2
Starting r3i1p1f2
Starting r4i1p1f2
Starting r5i1p1f3
Starting r6i1p1f3
Starting r7i1p1f3
Starting r8i1p1f2
Starting r9i1p1f2
Starting r10i1p1f2
Done!


In [7]:
indices_avg = np.array([flt.match_profiles(avg_profiles[ids[0]], avg_profiles[x]) for x in ids])

In [8]:

for (a, b, i) in zip(indices, indices_avg, ids):
  print(i)
  print("  Bijective spatial plots: {}".format(len(a) == len(np.unique(a))))
  print("  Bijective profiles: {}".format(len(b) == len(np.unique(b))))
  print("  Same assignment? {}".format(np.all(a == b)))

r1i1p1f2
  Bijective spatial plots: True
  Bijective profiles: True
  Same assignment? True
r2i1p1f2
  Bijective spatial plots: False
  Bijective profiles: False
  Same assignment? True
r3i1p1f2
  Bijective spatial plots: True
  Bijective profiles: True
  Same assignment? True
r4i1p1f2
  Bijective spatial plots: False
  Bijective profiles: False
  Same assignment? True
r5i1p1f3
  Bijective spatial plots: False
  Bijective profiles: False
  Same assignment? True
r6i1p1f3
  Bijective spatial plots: False
  Bijective profiles: False
  Same assignment? True
r7i1p1f3
  Bijective spatial plots: False
  Bijective profiles: False
  Same assignment? False
r8i1p1f2
  Bijective spatial plots: True
  Bijective profiles: True
  Same assignment? True
r9i1p1f2
  Bijective spatial plots: True
  Bijective profiles: True
  Same assignment? True
r10i1p1f2
  Bijective spatial plots: True
  Bijective profiles: True
  Same assignment? True
