In [11]:
from pathlib import Path
import pickle
import itertools
import json

import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

from ruben import PersistenceDiagram

In [170]:
# Root output directory; the driver loop further down iterates over the
# sub-directories of outdir / 'task1', one per sampler.
outdir = Path('../out/')
# JSON file with the reference model configurations. The notebook reads
# ['metrics']['train_acc'] and ['metrics']['test_acc'] of every entry from it.
# NOTE(review): absolute, machine-specific path -- this cell only runs on the
# original author's machine; consider a relative or configurable location.
ref_file = Path(
    "/Users/otis/Documents/rubens_speelhoekje/google/google_data/public_data/reference_data/task1_v4/model_configs.json")

In [171]:
from gtda.diagrams import PersistenceEntropy, ComplexPolynomial

def features_for_dim(pd, dim):
	"""Summary statistics of the diagram points of one homology dimension.

	`pd` is a 2-D array with exactly three columns (birth, death, dimension);
	only the rows whose third column equals `dim` are used. For `dim == 0` the
	last selected row must have an infinite death, which is replaced by 1
	before any statistic is computed. The boolean selection yields a copy, so
	the caller's array is not modified.

	Returns a list of eight array-likes, in this order:
	mean of (birth, death), that mean squared, `1/mean + log(mean)` passed
	through `np.nan_to_num`, std of (birth, death), mean of `birth - death`,
	that value squared, mean of the midpoint `(birth + death) / 2`, and that
	value squared. The first four have two entries, the last four have one.
	"""
	births, deaths, _ = pd[pd[:, 2] == dim].T
	if dim == 0:
		# The last dimension-0 point is expected to never die; cap it at 1.
		assert deaths[-1] == np.inf
		deaths[-1] = 1
	with np.errstate(invalid='ignore', divide='ignore'):
		pairs = np.column_stack((births, deaths))
		mean_bd = pairs.mean(axis=0)
		# A zero mean gives inf + (-inf) = NaN here; nan_to_num maps it to 0.
		inverted = np.nan_to_num(1 / mean_bd + np.log(mean_bd))
		# Note the sign: birth minus death, as in the original feature.
		mean_life = np.mean(births - deaths)
		mean_half_life = np.mean((births + deaths) / 2)
		return [
			mean_bd,
			mean_bd**2,
			inverted,
			pairs.std(axis=0),
			[mean_life],
			[mean_life**2],
			[mean_half_life],
			[mean_half_life**2],
		]

# Feature name -> number of columns that feature occupies in the row built for
# one homology dimension. Insertion order matters: it mirrors the order of the
# list returned by features_for_dim, and feature_mask walks it in that order.
features = {
	# Computed separately for births and deaths: two columns each.
	**dict.fromkeys((
		'avg_birth_death',
		'avg_birth_death_squared',
		'avg_birth_death_inverted',
		'std_birth_death',
	), 2),
	# Scalar statistics: one column each.
	**dict.fromkeys((
		'avg_life',
		'avg_life_squared',
		'avg_half_life',
		'avg_half_life_squared',
	), 1),
}

def features_all_dims(pd):
	"""Feature matrix of one persistence diagram, one row per dimension.

	Row 0 holds the concatenated output of `features_for_dim` for dimension 0,
	row 1 the same for dimension 1.
	"""
	rows = []
	for dim in (0, 1):
		rows.append(np.concatenate(features_for_dim(pd, dim)))
	return np.vstack(rows)

In [172]:
# Generalization gap (train accuracy minus test accuracy) of every reference
# model, keyed by 'model_<id>' where <id> is the model's key in ref_file.
gen_gaps = {
	f'model_{model_id}': (
		config['metrics']['train_acc'] - config['metrics']['test_acc']
	)
	for model_id, config in json.loads(ref_file.read_text()).items()
}

def get_labeled_data(dir):
	"""Load the features and generalization gaps of the models found in `dir`.

	For every model name in `gen_gaps` the file `<dir>/<model>.pd.npy` is
	looked up. Models without such a file are reported on stdout and skipped.

	Parameters
	----------
	dir : path object supporting `/` and `.exists()`
		Directory holding the `.pd.npy` arrays (a `pathlib.Path` in this
		notebook).

	Returns
	-------
	X : np.ndarray
		Stack of the `features_all_dims` output of every loaded model, in
		`gen_gaps` order.
	y : list
		The `gen_gaps` value of each loaded model, aligned with `X`.
	"""
	# Keep each model name next to its file rather than re-deriving the name
	# from the file stem afterwards: splitting the stem on '.' would truncate
	# any model name that itself contains a dot and misalign X and y.
	found = {}
	missing = []
	for model in gen_gaps:
		file = dir / (model + '.pd.npy')
		if file.exists():
			found[model] = file
		else:
			missing.append(model)
	if missing:
		print('Missing ', missing)
	X = np.stack([features_all_dims(np.load(file)) for file in found.values()])
	y = [gen_gaps[model] for model in found]
	return X, y

In [183]:
def feature_mask(included):
	"""0/1 mask over the feature columns of one dimension's row.

	Every feature in the `features` table contributes as many entries as its
	column count: ones if its name is in `included`, zeros otherwise.
	"""
	segments = []
	for name, width in features.items():
		flag = 1 if name in included else 0
		segments.append(np.full(width, flag))
	return np.concatenate(segments)

def run_experiment(X, y, model_fn, seed=None):
	"""Score `model_fn` on every non-empty combination of feature groups.

	Parameters
	----------
	X : np.ndarray
		Indexed as `X[sample, dimension, feature_column]`; the last axis must
		follow the column layout described by the `features` table.
	y : sequence
		One target value per sample.
	model_fn : callable
		Called without arguments to build a fresh estimator exposing
		`fit(X, y)` (returning the fitted estimator) and `score(X, y)`.
	seed : optional
		Forwarded to `np.random.default_rng` to draw the train/test split
		seed. The default `None` keeps the previous behaviour (a different
		random split on every call); pass an int for reproducible scores.

	Returns
	-------
	dict
		Maps each tuple of feature names to the estimator's `score` on the
		held-out part of a single 70 % / 30 % train/test split.
	"""
	names = tuple(features)
	combos = [
		combo
		for r in range(1, len(names) + 1)
		for combo in itertools.combinations(names, r)
	]
	split_seed = np.random.default_rng(seed).integers(2**32 - 1)
	# Every combination is evaluated on the same split. Splitting the row
	# indices once is equivalent to re-splitting the data with the same
	# random_state on every iteration, without repeating the work.
	# TODO(author's open question): use cross-validation here, and vary the
	# splits across combinations or keep them fixed? With one fixed split, the
	# best score over all combinations is an optimistic estimate.
	train_idx, test_idx = train_test_split(
		np.arange(len(X)), train_size=0.7, random_state=split_seed)
	y = np.asarray(y)
	y_train, y_test = y[train_idx], y[test_idx]
	results = {}
	for combo in tqdm(combos):
		selected = feature_mask(combo).astype(bool)
		# Keep the chosen columns of every dimension, then flatten per sample.
		X_flat = X[:, :, selected].reshape((len(X), -1))
		reg = model_fn().fit(X_flat[train_idx], y_train)
		results[combo] = reg.score(X_flat[test_idx], y_test)
	return results

In [185]:
# Run the feature-selection experiment for every sampler, i.e. for every
# sub-directory of outdir / 'task1'.
# The loop variable is not called `dir`, so the builtin dir() stays usable in
# the kernel afterwards. sorted() makes the report order stable, since
# iterdir() yields entries in arbitrary order.
for sampler_dir in sorted((outdir / 'task1').iterdir()):
	# 'stratifiedkmeans' is skipped on purpose (reason not recorded here).
	if not sampler_dir.is_dir() or sampler_dir.name == 'stratifiedkmeans':
		continue
	print('Sampler: ', sampler_dir.name)
	X, y = get_labeled_data(sampler_dir)
	results = run_experiment(X, y, LinearRegression)
	print('All features:')
	# The combination containing every feature is the one with the longest key.
	print(max(results.items(), key=lambda item: len(item[0])))
	print('Best selection:')
	print(max(results.items(), key=lambda item: item[1]))
	print('\n\n')

Sampler:  importance
Missing  ['model_156', 'model_157', 'model_158', 'model_220', 'model_221']


100%|██████████| 255/255 [00:00<00:00, 358.38it/s]


All features:
(('avg_birth_death', 'avg_birth_death_squared', 'avg_birth_death_inverted', 'std_birth_death', 'avg_life', 'avg_life_squared', 'avg_half_life', 'avg_half_life_squared'), 0.5510391039486231)
Best selection:
(('avg_birth_death', 'avg_birth_death_inverted', 'std_birth_death', 'avg_life'), 0.7285709264822879)



Sampler:  random
Missing  ['model_28', 'model_29', 'model_156', 'model_157', 'model_220', 'model_221']


100%|██████████| 255/255 [00:00<00:00, 389.12it/s]

All features:
(('avg_birth_death', 'avg_birth_death_squared', 'avg_birth_death_inverted', 'std_birth_death', 'avg_life', 'avg_life_squared', 'avg_half_life', 'avg_half_life_squared'), 0.7885525356416322)
Best selection:
(('avg_birth_death_inverted', 'std_birth_death', 'avg_life_squared'), 0.8858811366890595)






