In [174]:
from determined.experimental import client
from determined.common import yaml
import pathlib
import time

In [175]:
import threading
import queue

In [176]:
# Directory holding the model code; it is also passed as `model_dir` when experiments are created.
model_dir = pathlib.Path("./cifar10")
# Base experiment configuration file, located inside the model directory.
exp_conf_path = model_dir.joinpath("const.yaml")
# Parsed base configuration; per-experiment hyperparameter values are applied on top of it.
exp_conf = yaml.safe_load(exp_conf_path.read_text())

In [177]:
# Validation metric read from each trial's summary metrics (its "min" entry is treated as the best value).
val_metric_name = "validation_error"

# Hyperparameter name -> candidate values; one experiment is launched per (name, value) pair.
hp_search = {"learning_rate": [0.0001, 0.001, 0.01, 0.1]}

# Trial state names at which monitoring of a trial stops.
terminal_trial_states = ["CANCELED", "COMPLETED", "ERROR"]

In [178]:
def cprint(string, ansi_code):
    """
    Print `string` wrapped in an ANSI escape sequence.

    The text is preceded by the escape sequence built from `ansi_code` and
    followed by the reset sequence, so later output is not affected.
    """
    start_seq = "\033[{}m".format(ansi_code)
    end_seq = "\033[0m"
    print(start_seq + string + end_seq)

def follow_trial_logs(trial):
    """
    Stream a trial's logs to stdout, printing every line in a per-trial color.

    The color is chosen from a fixed palette by `trial.id` modulo the palette
    size, so trials whose ids differ by a multiple of the palette size share a color.
    """
    # Readable ANSI color codes: 32-36 and 92-96.
    palette = [*range(32, 37), *range(92, 97)]

    # Hacky attempt to give each trial its own color so the aggregated logs are easier to read.
    trial_color = palette[trial.id % len(palette)]

    for line in trial.logs(follow=True):
        cprint(line.rstrip(), trial_color)


def monitor_trial(trial, interval):
    """
    Follow a trial's logs and kill it early if its validation metric stalls.

    A background thread prints the trial's logs while this function polls the
    trial's summary metrics until the trial reaches a terminal state. The trial
    is killed when the best ("min") value of `val_metric_name` has stayed the
    same while the validation count advanced by at least `steps_threshold`.

    Args:
        trial: Trial object; this function uses its `id`, `state`, `reload()`,
            `summary_metrics` and `kill()` members, and hands it to
            `follow_trial_logs`.
        interval: Seconds to sleep between polls.
    """
    # Best metric value seen so far and the validation count at which it was
    # first observed. Both stay None until the first validation is reported.
    prev_val_steps = None
    prev_best_val = None
    steps_threshold = 5

    logs_thread = threading.Thread(target=follow_trial_logs, args=(trial,))
    logs_thread.start()
    while trial.state.name not in terminal_trial_states:
        trial.reload()
        summary_metrics = trial.summary_metrics
        if not summary_metrics or "validation_metrics" not in summary_metrics:
            # No validation has been reported yet; poll again later.
            time.sleep(interval)
            continue
        current_val = summary_metrics["validation_metrics"][val_metric_name]["min"]
        current_steps = summary_metrics["validation_metrics"][val_metric_name]["count"]

        if prev_best_val is None or current_val != prev_best_val:
            # First observation, or the best value changed: restart the
            # no-improvement window from the current validation count.
            prev_best_val = current_val
            prev_val_steps = current_steps
        elif should_early_stop(prev_best_val, prev_val_steps, current_val, current_steps, steps_threshold):
            print(f"Early stopping trial {trial.id} due to no improvement for {val_metric_name} for {steps_threshold} steps.")
            trial.kill()
            # Stop polling so the kill request is not re-issued on every
            # iteration while the trial winds down.
            break

        time.sleep(interval)
    logs_thread.join()

def create_experiment_with_hparams(hp_name, hp_val, val_metric_name, trial_queue):
    """
    Create one experiment with a single hyperparameter overridden and monitor its first trial.

    Args:
        hp_name: Name of the hyperparameter to override in the base config.
        hp_val: Value to assign to that hyperparameter.
        val_metric_name: Unused here; kept so existing callers keep working.
        trial_queue: Queue that receives the id of the experiment's first trial.
    """
    print(f"Starting experiment with {hp_name}={hp_val}")

    # This function runs concurrently in several threads. Build a per-call
    # config instead of mutating the shared `exp_conf`, otherwise another
    # thread can overwrite the value before this experiment is submitted.
    hyperparameters = dict(exp_conf.get("hyperparameters") or {})
    hyperparameters[hp_name] = hp_val
    trial_conf = {**exp_conf, "hyperparameters": hyperparameters}

    exp = client.create_experiment(config=trial_conf, model_dir=model_dir)

    trial = exp.await_first_trial()
    trial_queue.put(trial.id)

    monitor_trial(trial, 5)
    
def should_early_stop(prev_best_val, prev_val_steps, current_best_val, current_val_steps, stop_threshold):
    """
    Primitive early-stopping check.

    Returns True when at least `stop_threshold` steps have passed since
    `prev_val_steps` and the best validation value is still equal to
    `prev_best_val`; returns False otherwise.
    """
    # Not enough steps have elapsed yet to judge the trial.
    if current_val_steps < prev_val_steps + stop_threshold:
        return False
    # Enough steps have elapsed: stop only if the best value has not moved.
    return bool(current_best_val == prev_best_val)


In [None]:
# Queue collecting the first-trial id of every launched experiment.
trial_queue = queue.Queue()
exp_threads = []

# Launch one experiment (in its own thread) per hyperparameter value.
for hp_name, hp_vals in hp_search.items():
    for hp_val in hp_vals:
        exp_thread = threading.Thread(
            target=create_experiment_with_hparams,
            args=(hp_name, hp_val, val_metric_name, trial_queue),
        )
        exp_threads.append(exp_thread)
        exp_thread.start()

for thread in exp_threads:
    thread.join()

print("All trials completed. Generating summary report.")

# All worker threads have been joined, so the queue contents are stable;
# take a snapshot without draining the queue.
trial_ids = list(trial_queue.queue)

trial_vals = []
for trial_id in trial_ids:
    trial = client.get_trial(trial_id=trial_id)

    # A trial that ended before reporting the validation metric has nothing to rank.
    summary_metrics = trial.summary_metrics
    if (
        not summary_metrics
        or "validation_metrics" not in summary_metrics
        or val_metric_name not in summary_metrics["validation_metrics"]
    ):
        print(f"Skipping trial {trial.id}: no {val_metric_name} recorded.")
        continue

    # Smaller is better
    trial_best_val = summary_metrics["validation_metrics"][val_metric_name]["min"]

    for hparam in hp_search.keys():
        trial_vals.append({
            "trial_id": trial.id,
            "hparam_name": hparam,
            "hparam_val": trial.hparams[hparam],
            "val_metric_name": val_metric_name,
            "best_val_metric": trial_best_val,
        })

trial_vals.sort(key=lambda x: x["best_val_metric"])

print("=" * 100)
print(f"Hyperparameter space: {hp_search}")
# Count trials, not report rows: there is one row per (trial, hyperparameter).
print(f"Trials completed: {len(trial_ids)}")
if trial_vals:
    print(f"Best validation: {trial_vals[0]}")
else:
    print(f"Best validation: none (no trial reported {val_metric_name})")

Starting experiment with learning_rate=0.0001
Starting experiment with learning_rate=0.001d 0 files
Starting experiment with learning_rate=0.01nd 0 files
Starting experiment with learning_rate=0.1and 0 files
Preparing files to send to master... 6.5KB and 6 files                                                      

Preparing files to send to master... 6.5KB and 6 files
Preparing files to send to master... 6.5KB and 6 files
[34m[2023-07-20T00:30:38.106535Z]          || INFO: Scheduling Trial 72 (Experiment 72) (id: 8a4aa7d6-aff9-4295-9dab-247134f17857)[0m
[36m[2023-07-20T00:30:38.106652Z]          || INFO: Scheduling Trial 74 (Experiment 73) (id: 8a4aa7d6-aff9-4295-9dab-247134f17857)[0m
[35m[2023-07-20T00:30:38.330514Z]          || INFO: Scheduling Trial 73 (Experiment 74) (id: 8a4aa7d6-aff9-4295-9dab-247134f17857)[0m
[33m[2023-07-20T00:30:38.015778Z]          || INFO: Scheduling Trial 71 (Experiment 71) (id: 8a4aa7d6-aff9-4295-9dab-247134f17857)[0m
[33m[2023-07-20T00:30:38.59

In [None]:
from determined.common.experimental.trial import LogLevel

# Filtering logs -> print the trial's log lines selected by a search string and a
# minimum log level of DEBUG. `trial` is the variable left over from the summary loop.
# NOTE(review): an earlier note here mentioned filtering by agent and by timestamp,
# but this call passes neither an agent filter nor a timestamp bound — confirm intent.
filtered_logs = trial.logs(search_text="a658e1d8", min_level=LogLevel.DEBUG)
for entry in filtered_logs:
    print(entry)