# Fit 7 & 8 class GMMs and calculate average profiles for profile matching

This notebook will create the average profiles for 7 and 8 class GMMs fit to two ensembles from the UK-ESM historical simulations. These files are required to reproduce Figures 3 and 4 from *Heuristic Methods for Determining the Number of Classes in Unsupervised Classification of Climate Models*, E. Boland et al. 2022 (doi to follow). This requires cluster_utils.py and input datafiles via the googleapi CMIP6 store (see cluster_utils.py for more info)
Outputs stored in model/\[ensemble\]/\[nclasses\]/avg.obj

Please attribute any plots or code from this notebook using the DOI from Zenodo: to come

Updated Nov 2022
E Atkinson & E Boland [emmomp@bas.ac.uk](email:emmomp@bas.ac.uk)

In [None]:
from dask.distributed import Client

client = Client("tcp://127.0.0.1:39925")
client

In [1]:
from dask_gateway import Gateway
gateway = Gateway()
from dask.distributed import Client

'''properly shutdown any previous clusters'''
clusters=gateway.list_clusters()
if clusters != []:
    print(f'found {len(clusters)} clusters')
    for cluster in clusters:
        cluster = gateway.connect(cluster.name)
        client=Client(cluster)
        client.close()
        cluster.shutdown()

In [1]:
client.close()

NameError: name 'client' is not defined

In [None]:
import numpy as np
import xarray as xr

import os
import pickle

import cluster_utils as flt

### User options
Leave as is to recreate the paper

In [None]:
# Number of classes 
classes = [7,8]
model='model_20012017'
#Time range
tslice=slice('2001-01', '2017-12') 
ids = ['r1i1p1f2', 'r2i1p1f2'] 
levSel=slice(5, 2000)
npca=3 #number of PCA components
ntrain=3000 #number of profiles per month to use in training dataset

Uncomment the following two lines if you need to generate mask.npy:

In [None]:
#data = flt.retrieve_profiles(timeRange = slice('1995-01', '1995-02'))
#np.save('data/mask', data['n'])
mask = np.load('data/mask_3000.npy', allow_pickle=True)

### Train models and generate average profiles for chosen ensemble members and classes 

In [8]:
for m_id in ids:
    
    #Check if data already exists
    tests=[]
    for nn,n_classes in enumerate(classes):   
        path_data = '{}/{}/{}'.format(model,m_id, n_classes)
        tests.append(os.path.isfile('{}/avg_prof.obj'.format(path_data)))
    if np.all(tests):
        print('Found avg files for {}, skipping'.format(m_id))
        continue
    else:
        
        print('Starting {}'.format(m_id))
        options = {'memberId' : m_id}
        path_id = '{}/{}'.format(model,m_id)
        # Check if models already created
        tests=[]
        for nn,n_classes in enumerate(classes):   
            path_data = '{}/{}/{}'.format(model,m_id, n_classes)
            tests.append(os.path.isfile('{}/gmm.obj'.format(path_data)))
        if np.all(tests): # All models trained, no need to load training set       
            with open('{}/pca.obj'.format(path_id),'rb') as file:
                pca=pickle.load(file)         
        else:     # Load training set, generate PCA model
            print('No models found, generating training set')
            [data_train,pca] = flt.generate_trainingset(timeRange = tslice, mask=mask, options=options,n_components=npca,N=ntrain,levSel=levSel)
            if not os.path.exists(path_id):
                os.makedirs(path_id)
            with open('{}/pca.obj'.format(path_id), 'wb') as file:
                pickle.dump(pca, file)               
                        
        
        # Retrieve all Southern Ocean data
        options = {'memberId' : m_id}
        data = flt.retrieve_profiles(timeRange=tslice,mask=mask,options=options,levSel=levSel)
        data = data.chunk({'time': data.sizes['time'], 'n': 1024})
        # Normalise the samples
        data_norm = flt.normalise_data(data, ('n', 'time')) 
        # Transform to PCA space
        data_trans = flt.pca_transform(data_norm, pca)
        print('Finished setup for {}'.format(m_id))      

        for nn,n_classes in enumerate(classes):  
            path_n = '{}/{}/{}'.format(model,m_id, n_classes)
            #Check if model already created
            if os.path.isfile('{}/gmm.obj'.format(path_n)):
                with open('{}/gmm.obj'.format(path_n),'rb') as file:
                    gmm=pickle.load(file)                
            else:            
                print('Training {} class model'.format(n_classes))
                if not os.path.exists(path_n):
                    os.makedirs(path_data)            
                gmm = flt.train_gmm(data_train, n_classes)
                with open('{}/gmm.obj'.format(path_n), 'wb') as file:
                    pickle.dump(gmm, file)
 
            print('Classifying full dataset into {} classes'.format(n_classes))
            # Classify full dataset
            data_classes = flt.gmm_classify(data_trans, gmm)
            print('Classification complete, writing to file'.format(n_classes))
            flt.write_tonc(data_classes.reset_index('n'),n_classes,m_id,'class',path_n)     
            # Calculate average profiles for each clasee
            avg_prof = flt.avg_profiles(data, data_classes, n_classes)
            print('Average profiles calculated, writing to file'.format(n_classes))
            with open('{}/avg_prof.obj'.format(path_n), 'wb') as file:
                pickle.dump(avg_prof, file)
            print('Done with {} classes'.format(n_classes))
    
print('Done!')

Starting r1i1p1f2
Finished setup for r1i1p1f2
Classifying full dataset into 7 classes
Classification complete, writing to file
class written to model_20012017/r1i1p1f2/7/class.nc
Average profiles calculated, writing to file
Done with 7 classes
Classifying full dataset into 8 classes
Classification complete, writing to file
class written to model_20012017/r1i1p1f2/8/class.nc
Average profiles calculated, writing to file
Done with 8 classes
Starting r2i1p1f2
Finished setup for r2i1p1f2
Classifying full dataset into 7 classes
Classification complete, writing to file
class written to model_20012017/r2i1p1f2/7/class.nc
Average profiles calculated, writing to file
Done with 7 classes
Classifying full dataset into 8 classes
Classification complete, writing to file
class written to model_20012017/r2i1p1f2/8/class.nc


Task exception was never retrieved
future: <Task finished name='Task-798' coro=<Client._gather.<locals>.wait() done, defined at /srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/client.py:2002> exception=AllExit()>
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.9/site-packages/distributed/client.py", line 2011, in wait
    raise AllExit()
distributed.client.AllExit


Average profiles calculated, writing to file
Done with 8 classes
Done!
