# Step 1: Train Models

This notebook fits GMM models with 2 to 20 classes to three UK-ESM historical ensemble members and calculates the BIC, AIC and silhouette (SIL) scores for each. This is required to reproduce Figure 2 from *Heuristic Methods for Determining the Number of Classes in Unsupervised Classification of Climate Models*, E. Boland et al. 2022 (doi to follow). It requires cluster_utils.py and input data files from the googleapi CMIP6 store (see cluster_utils.py for more info).

Please attribute any plots or code from this notebook using the DOI from Zenodo: to come

Updated Feb 2023
E Atkinson & E Boland [emmomp@bas.ac.uk](mailto:emmomp@bas.ac.uk)

In [1]:
import os

from dask.distributed import Client

# Connect to an already-running Dask scheduler.
# The port of a local scheduler changes between sessions, so the address can be
# overridden through the DASK_SCHEDULER_ADDRESS environment variable; the
# original hard-coded address is kept as the fallback.
scheduler_address = os.environ.get("DASK_SCHEDULER_ADDRESS", "tcp://127.0.0.1:32937")
client = Client(scheduler_address)
client

0,1
Connection method: Direct,
Dashboard: http://127.0.0.1:8787/status,

0,1
Comm: tcp://127.0.0.1:32937,Workers: 5
Dashboard: http://127.0.0.1:8787/status,Total threads: 5
Started: 11 minutes ago,Total memory: 40.00 GiB

0,1
Comm: tcp://127.0.0.1:36144,Total threads: 1
Dashboard: http://127.0.0.1:40414/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:37548,
Local directory: /tmp/dask-worker-space/worker-ed6goqha,Local directory: /tmp/dask-worker-space/worker-ed6goqha
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 2.0%,Last seen: Just now
Memory usage: 249.66 MiB,Spilled bytes: 0 B
Read bytes: 30.57 kiB,Write bytes: 22.07 kiB

0,1
Comm: tcp://127.0.0.1:40871,Total threads: 1
Dashboard: http://127.0.0.1:42590/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:38967,
Local directory: /tmp/dask-worker-space/worker-zyklutlg,Local directory: /tmp/dask-worker-space/worker-zyklutlg
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 0.0%,Last seen: Just now
Memory usage: 246.19 MiB,Spilled bytes: 0 B
Read bytes: 27.34 kiB,Write bytes: 23.20 kiB

0,1
Comm: tcp://127.0.0.1:33069,Total threads: 1
Dashboard: http://127.0.0.1:40815/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:42081,
Local directory: /tmp/dask-worker-space/worker-5ppprifm,Local directory: /tmp/dask-worker-space/worker-5ppprifm
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 0.0%,Last seen: Just now
Memory usage: 244.47 MiB,Spilled bytes: 0 B
Read bytes: 16.36 kiB,Write bytes: 7.72 kiB

0,1
Comm: tcp://127.0.0.1:40485,Total threads: 1
Dashboard: http://127.0.0.1:39646/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:44741,
Local directory: /tmp/dask-worker-space/worker-0izax7z1,Local directory: /tmp/dask-worker-space/worker-0izax7z1
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 0.0%,Last seen: Just now
Memory usage: 240.78 MiB,Spilled bytes: 0 B
Read bytes: 15.64 kiB,Write bytes: 7.01 kiB

0,1
Comm: tcp://127.0.0.1:34073,Total threads: 1
Dashboard: http://127.0.0.1:37802/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:41127,
Local directory: /tmp/dask-worker-space/worker-mzc8bwb8,Local directory: /tmp/dask-worker-space/worker-mzc8bwb8
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 2.0%,Last seen: Just now
Memory usage: 238.30 MiB,Spilled bytes: 0 B
Read bytes: 16.00 kiB,Write bytes: 7.37 kiB


In [2]:
import numpy as np
import os
import pickle
import cluster_utils as flt
from sklearn import metrics

### User options
Leave as is to recreate the paper

In [3]:
# --- Output location ---
model_folder = 'model'  # root folder for the saved models and scores

# --- Model sweep ---
max_classes = 20  # largest number of classes to fit (the sweep starts at 2)
npca = 3          # passed as n_components when generating the training set
ntrain = 3000     # number of profiles per month to use in training dataset

# --- Data selection ---
ids = ['r1i1p1f2', 'r2i1p1f2', 'r3i1p1f2']  # ensemble member ids
tslice = slice('2001-01', '2017-12')         # time range
levSel = slice(5, 2000)                      # depth range

Uncomment the following three lines if you need to generate mask.npy:

In [4]:
#data = flt.retrieve_profiles(timeRange = slice('1995-01', '1995-02'),levSel=levSel)
#np.save('data/mask', data['n'])
#mask=data['n']
# Load the previously saved mask from disk. allow_pickle=True lets NumPy
# unpickle object arrays, so only load a data/mask.npy file you trust.
mask = np.load('data/mask.npy', allow_pickle=True)

In [5]:
# Build the training set for the first ensemble member as a preview; the bare
# `data` expression at the end displays the result.
options = dict(memberId=ids[0])
data, pca = flt.generate_trainingset(
    timeRange=tslice,
    mask=mask,
    options=options,
    N=ntrain,
    n_components=npca,
    levSel=levSel,
)
data

2023-06-09 11:53:30,120 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client


### Fit 2-20 class models for each ensemble member
Saves each individual PCA model, GMM model and BIC/AIC/SIL score to \[model_folder\]

Saves all BICs/AICs/SILs to \[model_folder\]/\[BICs/AICs/SILs\]2-20.obj

In [None]:
def _save_pickle(obj, path):
    """Pickle `obj` to the file at `path`, overwriting any existing file."""
    with open(path, 'wb') as file:
        pickle.dump(obj, file)


# For every ensemble member: build the training set, then fit one GMM per class
# count in 2..max_classes and record its BIC, AIC and silhouette score.
# Per-model results go to [model_folder]/[member id]/[n_classes]/ and the
# collected scores for all members are written to [model_folder] at the end.
BICs = {}
AICs = {}
SILs = {}
for m_id in ids:
    path_id = '{}/{}'.format(model_folder, m_id)
    # exist_ok makes the cell safe to re-run without a separate isdir() check
    os.makedirs(path_id, exist_ok=True)
    print('Starting {}'.format(m_id))
    options = {'memberId' : m_id}

    # Load training set
    [data,pca] = flt.generate_trainingset(timeRange = tslice, mask=mask, options=options,N=ntrain,n_components=npca,levSel=levSel)

    # One slot per class count: index 0 <-> 2 classes, ..., last <-> max_classes
    bic = np.zeros(max_classes-1)
    aic = bic.copy()
    sil = bic.copy()

    _save_pickle(pca, '{}/pca.obj'.format(path_id))

    print('Finished setup for {}'.format(m_id))

    for iin,n_classes in enumerate(range(2, max_classes+1)):

        path_n = '{}/{}'.format(path_id, n_classes)
        os.makedirs(path_n, exist_ok=True)

        gmm = flt.train_gmm(data, n_classes)
        _save_pickle(gmm, '{}/gmm.obj'.format(path_n))

        bic[iin] = gmm.bic(data)
        _save_pickle(bic[iin], '{}/bic.obj'.format(path_n))

        aic[iin] = gmm.aic(data)
        _save_pickle(aic[iin], '{}/aic.obj'.format(path_n))

        # Calculate silhouette score for 10000 point sample
        # NOTE: indices are drawn with replacement from the unseeded global
        # NumPy RNG, so the subsample (and the score) differs between runs.
        inds = np.random.randint(0, data.shape[0], 10000)
        sample = data[inds, :]  # index once instead of three times
        labels = flt.gmm_classify(sample, gmm)
        # silhouette_score is the mean of silhouette_samples, so compute the
        # per-sample values once and average them instead of evaluating the
        # pairwise distances twice.
        sample_silhouette_values = metrics.silhouette_samples(sample, labels, n_jobs=-1)
        sil[iin] = np.mean(sample_silhouette_values)
        _save_pickle(sil[iin], '{}/sil.obj'.format(path_n))
        _save_pickle(sample_silhouette_values, '{}/sil_vals.obj'.format(path_n))
        _save_pickle(labels, '{}/sil_labels.obj'.format(path_n))

        print('Finished {} with {} classes'.format(m_id, n_classes))

    BICs[m_id] = bic
    AICs[m_id] = aic
    SILs[m_id] = sil

# NOTE: the "2-20" in these file names is fixed and does not follow max_classes.
_save_pickle(BICs, '{}/BICs2-20.obj'.format(model_folder))
_save_pickle(AICs, '{}/AICs2-20.obj'.format(model_folder))
_save_pickle(SILs, '{}/SILs2-20.obj'.format(model_folder))

print('Done!')

Starting r1i1p1f2
Updating cached catalogue...
catalogue memory usage (MB): 26.848599
