# Step 2: Fit 7-9 Class GMMs and save class properties

This notebook will fit 7, 8 and 9 class GMMs to UK-ESM historical Southern Ocean data 2001-2018, following Jones et al. 2019 (https://doi.org/10.1029/2018JC014629). These models were trained in Step1_trainmodels.ipynb.

These files are required to reproduce figures 2, 3, 6, and 7 from *A Novel Heuristic Method for Detecting Overfit in Unsupervised Classification of Climate Models*, E. Boland et al. 2023 (doi to follow). This requires cluster_utils.py and input datafiles via the googleapi CMIP6 store (see cluster_utils.py for more info)

Outputs stored in \[model\]/\[ensemble\]/\[nclasses\]

Please attribute any plots or code from this notebook using the DOI from Zenodo: to come

Updated Mar 2023
E Atkinson & E Boland [emmomp@bas.ac.uk](email:emmomp@bas.ac.uk)

In [1]:
from dask.distributed import Client

client = Client("tcp://127.0.0.1:36101")
client

0,1
Connection method: Direct,
Dashboard: http://127.0.0.1:8787/status,

0,1
Comm: tcp://127.0.0.1:36101,Workers: 6
Dashboard: http://127.0.0.1:8787/status,Total threads: 6
Started: 2 hours ago,Total memory: 48.00 GiB

0,1
Comm: tcp://127.0.0.1:37892,Total threads: 1
Dashboard: http://127.0.0.1:44547/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:44056,
Local directory: /tmp/dask-worker-space/worker-d2gwzh_a,Local directory: /tmp/dask-worker-space/worker-d2gwzh_a
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 0.0%,Last seen: Just now
Memory usage: 277.58 MiB,Spilled bytes: 0 B
Read bytes: 27.77 kiB,Write bytes: 35.72 kiB

0,1
Comm: tcp://127.0.0.1:41776,Total threads: 1
Dashboard: http://127.0.0.1:33632/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:43317,
Local directory: /tmp/dask-worker-space/worker-hayhun2x,Local directory: /tmp/dask-worker-space/worker-hayhun2x
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 1.8%,Last seen: Just now
Memory usage: 279.57 MiB,Spilled bytes: 0 B
Read bytes: 16.35 kiB,Write bytes: 23.36 kiB

0,1
Comm: tcp://127.0.0.1:42827,Total threads: 1
Dashboard: http://127.0.0.1:32895/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:44990,
Local directory: /tmp/dask-worker-space/worker-jj01ya8d,Local directory: /tmp/dask-worker-space/worker-jj01ya8d
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 2.0%,Last seen: Just now
Memory usage: 260.21 MiB,Spilled bytes: 0 B
Read bytes: 26.70 kiB,Write bytes: 28.08 kiB

0,1
Comm: tcp://127.0.0.1:45345,Total threads: 1
Dashboard: http://127.0.0.1:40743/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:43291,
Local directory: /tmp/dask-worker-space/worker-m53fjluj,Local directory: /tmp/dask-worker-space/worker-m53fjluj
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 1.7%,Last seen: Just now
Memory usage: 273.64 MiB,Spilled bytes: 0 B
Read bytes: 17.78 kiB,Write bytes: 24.70 kiB

0,1
Comm: tcp://127.0.0.1:33335,Total threads: 1
Dashboard: http://127.0.0.1:36427/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:42060,
Local directory: /tmp/dask-worker-space/worker-gsjk7xoc,Local directory: /tmp/dask-worker-space/worker-gsjk7xoc
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 2.0%,Last seen: Just now
Memory usage: 276.94 MiB,Spilled bytes: 0 B
Read bytes: 17.01 kiB,Write bytes: 17.00 kiB

0,1
Comm: tcp://127.0.0.1:40192,Total threads: 1
Dashboard: http://127.0.0.1:35925/status,Memory: 8.00 GiB
Nanny: tcp://127.0.0.1:42522,
Local directory: /tmp/dask-worker-space/worker-5jvnbdg4,Local directory: /tmp/dask-worker-space/worker-5jvnbdg4
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 0.0%,Last seen: Just now
Memory usage: 283.70 MiB,Spilled bytes: 0 B
Read bytes: 26.89 kiB,Write bytes: 28.27 kiB


In [2]:
import numpy as np
import xarray as xr

import os
import pickle

import cluster_utils as flt

### User options
Leave as is to recreate the paper

In [3]:
# Number of classes 
model='model'
classes = [7,8,9]
#Time range
tslice=slice('2001-01', '2017-12') 
#Depth range
levSel=slice(5, 2000)
ids = ['r1i1p1f2', 'r2i1p1f2', 'r3i1p1f2', 'r4i1p1f2', 'r5i1p1f3', 'r6i1p1f3', 'r7i1p1f3', 'r8i1p1f2', 'r9i1p1f2', 'r10i1p1f2']
mask = np.load('data/mask.npy', allow_pickle=True)

### Fit already trained models to full dataset and generate average profiles for chosen ensemble members and classes 

In [None]:
avg_profiles = {}
for m_id in ids:
    
    for nn,n_classes in enumerate(classes):   
        path_n = '{}/{}/{}'.format(model,m_id, n_classes)

        print('Starting {}'.format(m_id))
        options = {'memberId' : m_id}
        path_id = '{}/{}'.format(model,m_id)
    
        # Load PCA
        with open('{}/pca.obj'.format(path_id),'rb') as file:
            pca=pickle.load(file)         
        
        # Retrieve ALL Southern Ocean data
        options = {'memberId' : m_id}
        data = flt.retrieve_profiles(timeRange=tslice,mask=mask,options=options,levSel=levSel)
        data = data.chunk({'time': data.sizes['time'], 'n': 1024})
        # Normalise the samples
        data_norm = flt.normalise_data(data, ('n', 'time')) 
        # Transform to PCA space
        data_trans = flt.pca_transform(data_norm, pca)
        print('Finished setup for {}'.format(m_id))      

        for nn,n_classes in enumerate(classes):  
            path_n = '{}/{}/{}'.format(model,m_id, n_classes)
            with open('{}/gmm.obj'.format(path_n),'rb') as file:
                gmm=pickle.load(file)                

            # Classify full dataset            
            print('Classifying full dataset into {} classes'.format(n_classes))
            data_classes = flt.gmm_classify(data_trans, gmm)
            data_probs = flt.gmm_prob(data_trans, gmm)
            print('Classification complete, writing to file'.format(n_classes))
            flt.write_tonc(data_classes.reset_index('n'),n_classes,m_id,'class',path_n)
            flt.write_tonc(data_probs.reset_index('n').mean('time'),n_classes,m_id,'probs',path_n)
            # Calculate average profiles for each class
            avg_prof = flt.avg_profiles(data, data_classes, n_classes)
            print('Average profiles calculated, writing to file'.format(n_classes))
            with open('{}/avg_prof.obj'.format(path_n), 'wb') as file:
                pickle.dump(avg_prof, file)      
            print('Done with {} classes'.format(n_classes))      
    
print('Done!')

Starting r1i1p1f2
Finished setup for r1i1p1f2
Classifying full dataset into 7 classes
Classification complete, writing to file
class written to model/r1i1p1f2/7/class.nc
probs written to model/r1i1p1f2/7/probs.nc
class written to model/r1i1p1f2/8/class.nc
probs written to model/r1i1p1f2/8/probs.nc
Average profiles calculated, writing to file
Done with 8 classes
Classifying full dataset into 9 classes
Classification complete, writing to file
class written to model/r1i1p1f2/9/class.nc
probs written to model/r1i1p1f2/9/probs.nc
Average profiles calculated, writing to file
Done with 9 classes
Starting r1i1p1f2
Finished setup for r1i1p1f2
Classifying full dataset into 7 classes
Classification complete, writing to file
class written to model/r1i1p1f2/7/class.nc
probs written to model/r1i1p1f2/7/probs.nc
Average profiles calculated, writing to file
Done with 7 classes
Classifying full dataset into 8 classes
Classification complete, writing to file
class written to model/r1i1p1f2/8/class.nc
pr