To run this notebook, load it in a local Jupyter instance (`pip install jupyter`). You'll also need these dependencies:

```
pip install tf-nightly
pip install google-cloud-storage
pip install requests
pip install google-api-python-client
```

If you're not running inside a Google Cloud VM, you may also need to set up Application Default Credentials:

```
gcloud auth application-default login
```

You also need to configure [OAuth](https://support.google.com/cloud/answer/6158849?hl=en). The process is fairly involved and is best described [here](https://github.com/googleapis/google-api-python-client/blob/master/docs/client-secrets.md). At the end you download a `client_secrets.json` file, which is used below.

In [None]:
import pandas as pd

pd.set_option('max_colwidth', 100)

resp = client.projects().locations().studies().list(
    parent=tune_vetting.study_parent()).execute()
studies = pd.DataFrame(resp['studies'])
studies = studies.sort_values('createTime', ascending=False)
studies.head(5)
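
`sort_values` keeps each row's original index label, so after sorting, `studies['name'][0]` looks up the row *labelled* 0 (the first study the API returned), not the first row of the sorted frame. Positional access needs `.iloc`. A minimal sketch with made-up study names (the `demo` frame is hypothetical, named so it doesn't overwrite `studies`):

```python
import pandas as pd

# Hypothetical study names and creation times, for illustration only.
demo = pd.DataFrame({
    'name': ['studies/a', 'studies/b', 'studies/c'],
    'createTime': ['2021-01-01T00:00:00Z', '2021-03-01T00:00:00Z',
                   '2021-02-01T00:00:00Z'],
})
demo = demo.sort_values('createTime', ascending=False)

# Label-based lookup: the row whose original index was 0 ('studies/a').
print(demo['name'][0])
# Position-based lookup: the first row after sorting ('studies/b').
print(demo['name'].iloc[0])
```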

In [None]:
study = studies['name'].iloc[0]  # most recent study: first row after sorting
study_id = '{}/studies/{}'.format(tune_vetting.study_parent(), study.split('/')[-1])
print(study_id)
resp = client.projects().locations().studies().trials().list(parent=study_id).execute()

metrics_loss = []
params = []
trial_ids = []
for trial in resp['trials']:
  if 'finalMeasurement' not in trial:
    continue
    
  if 'value' not in trial['finalMeasurement']['metrics'][0]:
    continue

  loss, = (m['value'] for m in trial['finalMeasurement']['metrics'] if m['metric'] == 'loss')  
  
  params.append(trial['parameters'])
  metrics_loss.append(loss)
  trial_ids.append(int(trial['name'].split('/')[-1]))
  
print(len(resp['trials']), 'total trials')
print(len(trial_ids), 'valid trials')
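
The filtering in the loop above can be factored into a pure function, which makes it easy to check against hand-written trial dicts. A sketch, assuming only the response fields the loop already reads (`name`, `parameters`, and `finalMeasurement.metrics[*]` with `metric`/`value`); `extract_losses` and the sample trials are hypothetical. Unlike the loop, it looks up the `loss` metric by name instead of first checking that `metrics[0]` has a value.

```python
def extract_losses(trials, metric_name='loss'):
  """Returns (trial_ids, losses, params) for trials that report `metric_name`.

  Trials without a final measurement, or whose final measurement has no
  value for `metric_name`, are skipped.
  """
  trial_ids, losses, params = [], [], []
  for trial in trials:
    metrics = trial.get('finalMeasurement', {}).get('metrics', [])
    values = [m['value'] for m in metrics
              if m.get('metric') == metric_name and 'value' in m]
    if not values:
      continue
    trial_ids.append(int(trial['name'].split('/')[-1]))
    losses.append(values[0])
    params.append(trial.get('parameters', []))
  return trial_ids, losses, params


# Hypothetical trials shaped like the dicts the loop above reads.
sample_trials = [
    {'name': 'studies/s/trials/1', 'parameters': [],
     'finalMeasurement': {'metrics': [{'metric': 'loss', 'value': 0.3}]}},
    {'name': 'studies/s/trials/2', 'parameters': []},  # no final measurement
    {'name': 'studies/s/trials/3', 'parameters': [],
     'finalMeasurement': {'metrics': [{'metric': 'loss'}]}},  # no value
]
ids, losses, _ = extract_losses(sample_trials)
print(ids, losses)  # [1] [0.3]
```

Applied to the real response it would be called as `extract_losses(resp['trials'])`.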

In [None]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt

matplotlib.rcParams.update({'font.size': 16})

fig, ax = plt.subplots()
# Clip the loss at 0.5 so a few diverged trials don't flatten the rest of the plot.
ax.plot(trial_ids, np.minimum(metrics_loss, 0.5))
ax.set_xlabel('trial ID')
ax.set_ylabel('validation loss')

# Print the five trials with the lowest validation loss (more if there are ties).
cutoff = sorted(metrics_loss)[min(4, len(metrics_loss) - 1)]
for trial_id, loss in zip(trial_ids, metrics_loss):
  if loss <= cutoff:
    print(trial_id, loss)

best = int(np.argmin(metrics_loss))  # index into trial_ids and params
print('Best trial:', trial_ids[best])
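
Selecting the best few trials can also be done with `heapq.nsmallest`, which returns exactly `k` rows even when losses tie and simply returns fewer when fewer than `k` trials have finished. A sketch; `top_k_trials` and the sample values are hypothetical.

```python
import heapq


def top_k_trials(trial_ids, losses, k=5):
  """Returns the k (trial_id, loss) pairs with the lowest loss, best first."""
  return heapq.nsmallest(k, zip(trial_ids, losses), key=lambda pair: pair[1])


# Hypothetical values, for illustration only.
print(top_k_trials([10, 11, 12, 13], [0.4, 0.12, 0.9, 0.2], k=2))
# [(11, 0.12), (13, 0.2)]
```

In this notebook the call would be `top_k_trials(trial_ids, metrics_loss)`.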

Best so far: trial 306 (validation loss 0.11842478811740875) in study `projects/mdan-playground/locations/us-central1/studies/6_vrevised_1b_vrevised`.

In [None]:
import pprint
from astronet import models

config = models.get_model_config('AstroCNNModelVetting', config_name)

for param in params[best]:
  tune_vetting.map_param(config['hparams'], config['vetting_hparams'], param, config['inputs'])

print(train.FLAGS.train_steps)
pprint.pprint(config['vetting_hparams'])
pprint.pprint(config['hparams'])
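
To eyeball the winning hyperparameters before they are mapped into the config, the trial's parameter list can be flattened into a plain dict. This sketch *assumes* the AI Platform Optimizer REST shape, where each entry has a `parameter` name and one of `floatValue`, `intValue` (carried as a string, since JSON encodes int64 as text) or `stringValue`; inspect `params[best]` to confirm. `params_to_dict` is a hypothetical helper and is independent of `tune_vetting.map_param`.

```python
def params_to_dict(parameters):
  """Flattens a list of trial parameters into {name: value}.

  Assumes each entry looks like {'parameter': name, 'floatValue': x},
  with 'intValue' or 'stringValue' in place of 'floatValue' for other
  parameter types. Entries with none of these keys are skipped.
  """
  flat = {}
  for p in parameters:
    for key in ('floatValue', 'intValue', 'stringValue'):
      if key in p:
        flat[p['parameter']] = p[key]
        break
  return flat


# Hypothetical parameters, for illustration only.
example = [
    {'parameter': 'learning_rate', 'floatValue': 1e-4},
    {'parameter': 'num_layers', 'intValue': '3'},
]
print(params_to_dict(example))  # {'learning_rate': 0.0001, 'num_layers': '3'}
```

With real data: `pprint.pprint(params_to_dict(params[best]))`.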

In [None]:
import difflib
import pprint
from astronet import models

config1 = models.get_model_config('AstroCNNModelVetting', config_name)

config2 = models.get_model_config('AstroCNNModelVetting', config_name)
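for param in params[best]:
  tune_vetting.map_param(config2['hparams'], config2['vetting_hparams'], param, config2['inputs'])

# Show only the settings the best trial changed relative to the base config.
pp = pprint.PrettyPrinter()
print('\n'.join(difflib.unified_diff(
  pp.pformat(config1).split('\n'), pp.pformat(config2).split('\n'),
  fromfile='base', tofile='best trial', n=0, lineterm='',
)))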
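
An alternative to diffing pretty-printed text is to walk the two configs directly and report each changed leaf by its dotted path. A sketch, assuming the configs are (or can be converted to) nested dicts; `dict_diff` and the sample configs are hypothetical, and a key missing on one side is reported as `None`.

```python
def dict_diff(a, b, prefix=''):
  """Yields (path, old, new) for every leaf that differs between nested dicts."""
  for key in sorted(set(a) | set(b), key=str):
    path = f'{prefix}{key}'
    old, new = a.get(key), b.get(key)
    if isinstance(old, dict) and isinstance(new, dict):
      # Recurse into sub-dicts, extending the dotted path.
      yield from dict_diff(old, new, path + '.')
    elif old != new:
      yield path, old, new


# Hypothetical configs, for illustration only.
base = {'hparams': {'learning_rate': 1e-5, 'batch_size': 64}, 'inputs': {}}
tuned = {'hparams': {'learning_rate': 3e-5, 'batch_size': 64}, 'inputs': {}}
for path, old, new in dict_diff(base, tuned):
  print(f'{path}: {old} -> {new}')
# hparams.learning_rate: 1e-05 -> 3e-05
```

Unlike the text diff, this does not depend on how `pformat` wraps long values, so a single changed list element or float shows up as one line.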

```
python astronet/tune_vetting.py --model=AstroCNNModelVetting --config_name=vrevised --train_files=../mnt/tess/astronet/tfrecords-vetting-5-toi-train/* --eval_files=../mnt/tess/astronet/tfrecords-vetting-6-toi-val/* --pretrain_model_dir=../mnt/tess/astronet/checkpoints/revised_tuned_30_run_1 --train_steps=0 --tune_trials=10000 --client_secrets=../client_secrets.json --study_id=6_vrevised_1c
```

```
python astronet/tune_vetting.py --model=AstroCNNModelVetting --config_name=vrevised --train_files=../mnt/tess/astronet/tfrecords-vetting-5-toi-train/* --eval_files=../mnt/tess/astronet/tfrecords-vetting-6-toi-val/* --pretrain_model_dir=../mnt/tess/astronet/checkpoints/revised_tuned_30_run_1 --train_steps=0 --tune_trials=10000 --client_secrets=../client_secrets.json --study_id=6_vrevised_1c --client_id=lab2
```

```
python astronet/tune_vetting.py --model=AstroCNNModelVetting --config_name=vrevised --train_files=../mnt/tess/astronet/tfrecords-vetting-5-toi-train/* --eval_files=../mnt/tess/astronet/tfrecords-vetting-6-toi-val/* --pretrain_model_dir=../mnt/tess/astronet/checkpoints/revised_tuned_30_run_1 --train_steps=0 --tune_trials=10000 --client_secrets=../client_secrets.json --study_id=6_vrevised_1c --client_id=lab3
```
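
The three invocations above differ only in `--client_id` (the first omits it). A sketch that builds them from one shared argument list, so the flags cannot drift between workers; printing with `shlex.join` also quotes the `*` globs so the shell hands them to the script unexpanded. `COMMON_ARGS` and `worker_command` are hypothetical names, and the argument lists could equally be passed to `subprocess.Popen`.

```python
import shlex

# Flags shared by every worker, copied from the commands above.
COMMON_ARGS = [
    'python', 'astronet/tune_vetting.py',
    '--model=AstroCNNModelVetting',
    '--config_name=vrevised',
    '--train_files=../mnt/tess/astronet/tfrecords-vetting-5-toi-train/*',
    '--eval_files=../mnt/tess/astronet/tfrecords-vetting-6-toi-val/*',
    '--pretrain_model_dir=../mnt/tess/astronet/checkpoints/revised_tuned_30_run_1',
    '--train_steps=0',
    '--tune_trials=10000',
    '--client_secrets=../client_secrets.json',
    '--study_id=6_vrevised_1c',
]


def worker_command(client_id=None):
  """Returns the argv for one worker; None omits --client_id, like the first command."""
  args = list(COMMON_ARGS)
  if client_id is not None:
    args.append(f'--client_id={client_id}')
  return args


for cid in (None, 'lab2', 'lab3'):
  print(shlex.join(worker_command(cid)))
```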