In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import time, os
import cmcrameri.cm as cm

from tqdm import tqdm
from utils import ELFData, ELFModel

## Load and filter data

In [None]:
# Configuration: perform a center crop of the ELF profiles.
# When True, the "Center crop" cell trims len(elf)//4 points from each end of
# every profile, and the model-loading cell appends '_cen' to the model paths.
center = False

In [None]:
# Load the processed unlabeled dataset (without structures) and preview it.
unlabeled = ELFData()
unlabeled.load_processed(structure=False, dirname='data/unlabeled')
n_examples = len(unlabeled.data)
print('Number of examples:', n_examples)
unlabeled.data.head()

In [None]:
# Center crop: keep roughly the middle half of every ELF profile by trimming
# dx = len(elf)//4 points from each end. dx is taken from the first row, so all
# profiles are assumed to share its length — TODO confirm.
# NOTE: this overwrites unlabeled.data['elf'], so re-running this cell without
# reloading the data crops the profiles a second time.
if center and len(unlabeled.data) > 0:
    dx = len(unlabeled.data.iloc[0]['elf'])//4
    # Slice up to len(x) - dx instead of -dx: when dx == 0 (profiles shorter than
    # 4 points) x[0:-0] is x[0:0], which would silently empty every profile.
    unlabeled.data['elf'] = unlabeled.data['elf'].apply(lambda x: x[dx:len(x) - dx])

In [None]:
# Drop entries flagged as 'mixed', then compute the PDF/CDF of each profile and
# sort by it (presumably producing the '*_srt' columns used below).
keep = ~unlabeled.data['mixed']
unlabeled.data = unlabeled.data.loc[keep]
unlabeled.get_pdf_cdf()
unlabeled.sort_by_cdf()
print('Number of examples:', len(unlabeled.data))
unlabeled.data.head()

### Calculate derived columns

In [None]:
def _pdf_over_length(row):
    """Sorted PDF divided by the mean length per sample point, l / n_points."""
    return row['pdf_srt'] / (row['l'] / len(row['pdf_srt']))


def _cdf_times_area(row):
    """Sorted CDF scaled by the area A."""
    return row['cdf_srt'] * row['A']


unlabeled.data['pdf-l_srt'] = unlabeled.data[['pdf_srt', 'l']].apply(_pdf_over_length, axis=1)
unlabeled.data['cdf-A_srt'] = unlabeled.data[['cdf_srt', 'A']].apply(_cdf_times_area, axis=1)

## Predict data classes

In [None]:
# Development convenience: re-execute utils.py so edits to it are picked up
# without restarting the kernel, then rebind ELFModel to the reloaded class.
import sys, importlib
ELFModel = importlib.reload(sys.modules['utils']).ELFModel

In [None]:
# Model hyperparameters; these values are encoded into the saved-model file
# names built in the "Load models" cell.
n_classes, n_components = 4, 10
n_estimators, max_depth = 150, 12
# One model per column group; each inner list holds the input column(s) of one model.
columns = [[name] for name in ('elf_srt', 'pdf_srt', 'pdf-l_srt', 'cdf_srt', 'cdf-A_srt')]
# Scalar features (length and area) passed alongside each column group.
features = ['l', 'A']

### Load models

In [None]:
savedir = 'models/'

# Suffix shared by the model column names (e.g. '_srt'), taken from the first column.
first_parts = columns[0][0].split('_')
tag = '_' + first_parts[-1] if len(first_parts) > 1 else ''

# Hyperparameter tokens such as 'c4', 'z10', 'n150', 'd12'; identical for every model.
hparam_tokens = [f'{prefix}{value}' for prefix, value in
                 zip(('c', 'z', 'n', 'd'), (n_classes, n_components, n_estimators, max_depth))]

models = []
for column in columns:
    # Strip the trailing '_<suffix>' from each column, join the stems, re-append the tag.
    column_name = '_'.join(''.join(name.split('_')[:-1]) for name in column) + tag
    path = '_'.join(hparam_tokens + [column_name] + features)
    if center:
        path += '_cen'

    model = ELFModel(n_classes)
    models.append(model)
    model.load_model(savedir + path)

### Predict

In [None]:
# Run every model in turn; each clf_predict call returns the frame that replaces
# unlabeled.data (presumably with '<column>_pred' columns added, read later).
bar_format = unlabeled.bar_format
for k in tqdm(range(len(columns)), bar_format=bar_format):
    current = models[k]
    current.prepare_inputs(unlabeled.data)
    unlabeled.data = current.clf_predict(unlabeled.data)

In [None]:
unlabeled.data.head(5)  # preview the first rows, now including predictions

## Analyze predictions

### Class distribution per length and area bin

In [None]:
n_bins = 10
# Assign every row to one of n_bins equal-width bins of its length ('l') and
# area ('A'). Digitizing against the left edges yields labels 1..n_bins.
for quantity in ('l', 'A'):
    bins = np.histogram_bin_edges(unlabeled.data[quantity], bins=n_bins)
    labels = np.digitize(unlabeled.data[quantity], bins[:-1], right=False)
    unlabeled.data[quantity + '_bin'] = labels.tolist()

In [None]:
# Plot the latent projection of model i, one slice per length bin.
i = 0      # model / column-group index into `columns` and `models`
index = 0  # column index within columns[i]
column = '_'.join([''.join(k.split('_')[:-1]) for k in columns[i]]) + tag
z_column = 'z_' + columns[i][index]
# Shape of a single latent vector. Equal-width bins can be empty, and np.stack
# raises ValueError on an empty sequence, so empty bins get a (0, *z_shape) array.
# NOTE(review): confirm plot_projection_slices handles empty slices gracefully.
z_shape = np.shape(unlabeled.data[z_column].iloc[0])
x, y = [], []
for k in range(1, n_bins + 1):
    in_bin = unlabeled.data['l_bin'] == k
    z_vals = unlabeled.data.loc[in_bin, z_column].values
    x.append(np.stack(z_vals) if len(z_vals) else np.empty((0,) + z_shape))
    y.append(unlabeled.data.loc[in_bin, column + '_pred'].values)
fig = models[i].plot_projection_slices(x=x, y=y, axes=[0,1], cmap=models[i].dmap, order=True, index=index)

In [None]:
# Same projection as the previous cell (reuses its i, index, column), sliced by area bin.
z_column = 'z_' + columns[i][index]
# np.stack raises on an empty sequence; empty bins get a (0, *z_shape) array instead.
z_shape = np.shape(unlabeled.data[z_column].iloc[0])
x, y = [], []
for k in range(1, n_bins + 1):
    in_bin = unlabeled.data['A_bin'] == k
    z_vals = unlabeled.data.loc[in_bin, z_column].values
    x.append(np.stack(z_vals) if len(z_vals) else np.empty((0,) + z_shape))
    y.append(unlabeled.data.loc[in_bin, column + '_pred'].values)
fig = models[i].plot_projection_slices(x=x, y=y, axes=[0,1], cmap=models[i].dmap, order=True, index=index)

### Inspect candidate MDHs

In [None]:
# Formula identifiers singled out for inspection.
# NOTE(review): `formulas` is not used below — get_mdhs is called with
# formulas=None. Confirm whether formulas=formulas was intended.
formulas = [473, 479, 736, 816, 1125, 1413, 1511, 1587, 1684, 2306, 2344, 2430, 2975]
threshold = 0.6
column = columns[0]
stems = [''.join(name.split('_')[:-1]) for name in column]
column_name = '_'.join(stems) + tag
mdh = unlabeled.get_mdhs(
    column=column_name,
    n_classes=n_classes,
    threshold=threshold,
    formulas=None,
)

In [None]:
i = 29  # positional row of `mdh` to inspect
entry = mdh.iloc[i]
struct = entry.structure
print(entry.formula)
# rotation='0x,0y,0z' presumably means no rotation about any axis.
unlabeled.plot_structure(struct, rotation='0x,0y,0z');

In [None]:
# Sweep the get_mdhs threshold and count how many candidates are returned at each value.
thresholds = np.arange(0.,0.9,0.05)
column = columns[0]
# Loop-invariant: the prediction column name depends only on `column`.
column_name = '_'.join([''.join(k.split('_')[:-1]) for k in column]) + tag
n_mdh = np.zeros_like(thresholds)
# Distinct loop names so the global `threshold` (0.6) and `i` are not clobbered.
for idx, thr in enumerate(thresholds):
    try:
        n_mdh[idx] = len(unlabeled.get_mdhs(column=column_name, n_classes=n_classes, threshold=thr))
    except Exception as err:
        # get_mdhs can fail at high thresholds (presumably when no candidates remain).
        # Stop the sweep, leaving the remaining counts at 0, but report why instead of
        # silently swallowing every error (a bare except also caught KeyboardInterrupt).
        print(f'Stopped sweep at threshold {thr:.2f}: {err!r}')
        break

In [None]:
# Number of candidates vs. threshold: counts on the left axis, and the same curve
# expressed as a percentage of the dataset on the right axis.
n_total = len(unlabeled.data)
fig, ax = plt.subplots(figsize=(5,2))
ax.plot(thresholds, n_mdh, color=unlabeled.palette[0])
# A secondary axis tied to the primary by an exact count <-> percent mapping
# replaces the previous invisible white "scaling" line, which showed up on
# dark backgrounds and only aligned because both curves autoscaled alike.
_ax = ax.secondary_yaxis('right', functions=(lambda n: 100 * n / n_total,
                                             lambda p: p * n_total / 100))
ax.set_title('Candidate MDHs vs. threshold')
ax.set_xlabel('Threshold')
ax.set_ylabel('Number of candidates');
_ax.set_ylabel('Percentage (%)');