# Fit 8 Class GMM and Calculate Average Profiles

This notebook fits 8-class GMMs to UK-ESM historical Southern Ocean data for 2001-2017, following Jones et al. 2019 (https://doi.org/10.1029/2018JC014629). The output files are required to reproduce Figures YY from *Heuristic Methods for Determining the Number of Classes in Unsupervised Classification of Climate Models*, E. Boland et al. 2022 (doi to follow). Running it requires cluster_utils.py and the input data files, which are read from the googleapi CMIP6 store (see cluster_utils.py for more info).

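Outputs are stored in \[model\]/\[ensemble\]/pca.obj, \[model\]/\[ensemble\]/\[nclasses\]/gmm.obj, \[model\]/\[ensemble\]/\[nclasses\]/avg_class.obj and \[model\]/\[ensemble\]/\[nclasses\]/avg_prof.obj, where \[model\] is the `model` option set under "User options" below.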

Please attribute any plots or code from this notebook using the DOI from Zenodo: to come

Updated Nov 2022
E Atkinson & E Boland [emmomp@bas.ac.uk](mailto:emmomp@bas.ac.uk)

In [31]:
from dask.distributed import Client

# Connect to an already-running Dask scheduler at a fixed local address.
# NOTE(review): the port looks specific to one session - presumably it has to
# be updated to match the scheduler you start; confirm before running.
client = Client("tcp://127.0.0.1:43361")
# Bare last expression so the notebook displays the client summary.
client

0,1
Connection method: Direct,
Dashboard: http://127.0.0.1:8787/status,

0,1
Comm: tcp://127.0.0.1:43361,Workers: 1
Dashboard: http://127.0.0.1:8787/status,Total threads: 1
Started: Just now,Total memory: 8.00 GiB

0,1
Comm: tcp://127.0.0.1:36477,Total threads: 1
Dashboard: http://127.0.0.1:36877/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:39361,
Local directory: /tmp/dask-worker-space/worker-49avh9u8,Local directory: /tmp/dask-worker-space/worker-49avh9u8
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 2.0%,Last seen: Just now
Memory usage: 140.50 MiB,Spilled bytes: 0 B
Read bytes: 9.78 kiB,Write bytes: 13.74 kiB


In [30]:
from dask_gateway import Gateway
from dask.distributed import Client

gateway = Gateway()

# Properly shut down any clusters left over from previous sessions.
# The loop uses its own names (stale_cluster / stale_client) so that it does
# not rebind the notebook-level `client` to a client that has just been closed.
clusters = gateway.list_clusters()
if clusters:
    print(f'found {len(clusters)} clusters')
    for cluster_report in clusters:
        stale_cluster = gateway.connect(cluster_report.name)
        # A client is attached and closed before the cluster is shut down.
        # NOTE(review): presumably this releases any lingering connections -
        # confirm it is still required.
        stale_client = Client(stale_cluster)
        stale_client.close()
        stale_cluster.shutdown()

In [29]:
# Close the Dask client currently bound to `client`.
# NOTE(review): the saved output of this cell is a ConnectionRefusedError
# traceback - presumably the scheduler was already gone when it ran; confirm
# whether this cell is still needed.
client.close()

2022-11-16 11:38:33,958 - distributed.client - ERROR - 
ConnectionRefusedError: [Errno 111] Connection refused

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/comm/core.py", line 291, in connect
    comm = await asyncio.wait_for(
  File "/srv/conda/envs/notebook/lib/python3.9/asyncio/tasks.py", line 479, in wait_for
    return fut.result()
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/comm/tcp.py", line 461, in connect
    convert_stream_closed_error(self, e)
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/comm/tcp.py", line 142, in convert_stream_closed_error
    raise CommClosedError(f"in {obj}: {exc.__class__.__name__}: {exc}") from exc
distributed.comm.core.CommClosedError: in <distributed.comm.tcp.TCPConnector object at 0x7f7a61b10310>: ConnectionRefusedError: [Errno 111] Connection refused

During ha

In [32]:
import numpy as np
import xarray as xr

import os
import pickle

import cluster_utils as flt

### User options
Leave these options unchanged to reproduce the results in the paper.

In [33]:
# Name of the top-level output directory (first component of every output path)
model = 'model_20012017'

# Number of GMM classes to fit
n_classes = 8

# Time range of the data to use
tslice = slice('2001-01', '2017-12')

# Ensemble member ids to process (each is passed to cluster_utils as 'memberId')
ids = [
    'r1i1p1f2',
    'r2i1p1f2',
    'r3i1p1f2',
    'r4i1p1f2',
    'r5i1p1f3',
    'r6i1p1f3',
    'r7i1p1f3',
    'r8i1p1f2',
    'r9i1p1f2',
    'r10i1p1f2',
]

# Number of PCA components
npca = 3

# Number of profiles per month to use in the training dataset
ntrain = 7000

Uncomment the following two lines if you need to generate mask.npy:

In [34]:
# One-off step (kept commented out): build the mask from a short retrieval and
# save it. np.save appends the '.npy' extension, so this writes data/mask.npy.
#data = flt.retrieve_profiles(timeRange = slice('1995-01', '1995-02'))
#np.save('data/mask', data['n'])
# Load the previously saved mask; it is passed as `mask=` to the cluster_utils
# calls in the training loop. allow_pickle=True lets NumPy unpickle object
# arrays, so only load a mask file that you generated yourself.
mask = np.load('data/mask.npy', allow_pickle=True)

### Train models and generate average profiles for chosen ensemble members and classes 

In [None]:
avg_profiles = {}


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` without leaving a partial file behind.

    The object is first written to ``path + '.tmp'`` and then moved into place
    with os.replace, so ``path`` only appears once the dump has completed.
    The loop below uses the presence of output files to decide whether an
    ensemble member can be skipped, so a half-written file must never be
    mistaken for a finished result.
    """
    tmp_path = '{}.tmp'.format(path)
    with open(tmp_path, 'wb') as file:
        pickle.dump(obj, file)
    os.replace(tmp_path, path)


# For every ensemble member: load (or train) the PCA and GMM models, classify
# the full dataset, then store the time-averaged classification and the
# average profile of each class.
for m_id in ids:
    path_id = '{}/{}'.format(model, m_id)          # per-member directory
    path_n = '{}/{}'.format(path_id, n_classes)    # per-member, per-class-count directory
    pca_file = '{}/pca.obj'.format(path_id)
    gmm_file = '{}/gmm.obj'.format(path_n)
    avg_class_file = '{}/avg_class.obj'.format(path_n)
    avg_prof_file = '{}/avg_prof.obj'.format(path_n)

    # Skip a member only when BOTH outputs exist. Checking avg_class.obj alone
    # would permanently skip a member whose run stopped before avg_prof.obj
    # was written.
    if os.path.isfile(avg_class_file) and os.path.isfile(avg_prof_file):
        print('Found avg files for {}, skipping'.format(m_id))
        continue

    print('Starting {}'.format(m_id))
    options = {'memberId': m_id}

    # Reuse the PCA and GMM models if both have already been created
    if os.path.isfile(pca_file) and os.path.isfile(gmm_file):
        # NOTE: pickle.load executes code embedded in the file; only load
        # model files produced by this notebook.
        with open(pca_file, 'rb') as file:
            pca = pickle.load(file)
        with open(gmm_file, 'rb') as file:
            gmm = pickle.load(file)
    else:
        # Generate training set and PCA model
        print('No model found, generating training set')
        data_train, pca = flt.generate_trainingset(
            timeRange=tslice, mask=mask, options=options,
            n_components=npca, N=ntrain)
        # path_n lies inside path_id, so this creates both directories
        os.makedirs(path_n, exist_ok=True)
        # NOTE(review): pca.obj lives in the member directory and is shared by
        # every class-count sub-directory; it is overwritten here whenever
        # gmm.obj for the current n_classes is missing. Confirm that models
        # saved earlier for other class counts stay consistent with it.
        _dump_pickle(pca, pca_file)
        # Train the GMM on the training set
        print('Training GMM')
        gmm = flt.train_gmm(data_train, n_classes)
        _dump_pickle(gmm, gmm_file)

    print('Loading full dataset')
    # Load full Southern Ocean data to fit
    data = flt.retrieve_profiles(timeRange=tslice, mask=mask, options=options)
    # One chunk spanning the whole time axis, blocks of 1024 along 'n'
    data = data.chunk({'time': data.sizes['time'], 'n': 1024})
    # Normalise the samples
    data_norm = flt.normalise_data(data, ('n', 'time'))
    # Transform to PCA space
    data_trans = flt.pca_transform(data_norm, pca)
    print('Classifying full dataset into {} classes'.format(n_classes))
    # Classify full dataset
    data_classes = flt.gmm_classify(data_trans, gmm)

    # Time average classification. The values are computed BEFORE the output
    # file is touched, so a failed computation cannot leave an empty
    # avg_class.obj behind.
    avg_class = data_classes.mean('time')
    avg_class_values = avg_class.data.compute()
    print('Time Average Classification calculated, writing to file')
    _dump_pickle(avg_class_values, avg_class_file)

    # Calculate average profiles for each class
    avg_prof = flt.avg_profiles(data, data_classes, n_classes)
    print('Average profiles calculated, writing to file')
    _dump_pickle(avg_prof, avg_prof_file)
    print('Done with {} classes'.format(n_classes))

print('Done!')

Found avg files for r1i1p1f2, skipping
Starting r2i1p1f2
Loading full dataset
Classifying full dataset into 8 classes
Time Average Classification calculated, writing to file
