In [1]:
#default_exp hypster_prepare

In [2]:
#export
from hypster.oo_hp import *

In [3]:
#export
from inspect import signature
import functools
from collections import OrderedDict

In [4]:
#export
class HypsterPrepare(HypsterBase):
    """Lazy expression node whose arguments may contain hypster hyperparameters.

    A node is one of:

    * a *root* node (``base_call is None``): ``call`` is a callable that is
      invoked as ``call(*args, **kwargs)`` when sampled (this is what
      ``prepare`` builds);
    * an *attribute* node: ``call`` is an attribute name looked up on the value
      produced by ``base_call.sample(trial)``. The attribute is invoked with
      ``args``/``kwargs`` only when the node represents a call
      (``node.method(...)``, including zero-argument calls); a bare
      ``node.attr`` yields the attribute value itself.

    Nothing is evaluated until ``sample(trial)``. Every element of
    ``args``/``kwargs`` is resolved with ``sample_hp(..., trial)``. The result is
    cached per ``(study name, trial number)`` so a node reachable through
    several downstream nodes is evaluated only once per trial.
    """

    def __init__(self, call, base_call, *args, **kwargs):
        self.call            = call
        self.base_call       = base_call
        self.args            = args
        self.kwargs          = kwargs
        self.trials_sampled  = set()   # trial numbers seen in the latest study (legacy bookkeeping)
        self.studies_sampled = set()   # study names seen so far (legacy bookkeeping)
        self.base_object     = None    # last value sampled from ``base_call``
        # Root nodes always invoke ``call``; attribute nodes only when they carry
        # arguments or were produced by ``__call__`` (see below).
        self._is_call        = base_call is None or bool(args or kwargs)
        # (study_name, trial_number) whose result is currently stored in ``res``.
        self._sampled_key    = None

    def sample(self, trial):
        """Evaluate this node for ``trial`` and return the result.

        Repeated calls for the same trial of the same study return the cached
        result without sampling again.
        """
        study_name = trial.study.study_name
        key = (study_name, trial.number)
        if key == self._sampled_key:
            return self.res

        if self.base_call is not None:
            self.base_object = self.base_call.sample(trial)

        self.sampled_args   = [sample_hp(arg, trial) for arg in self.args]
        self.sampled_kwargs = OrderedDict((name, sample_hp(value, trial))
                                          for name, value in self.kwargs.items())

        # Dispatch on the node kind, not on the truthiness of the sampled base
        # object (a falsy base object is still a valid target).
        if self.base_call is None:
            result = self.call(*self.sampled_args, **self.sampled_kwargs)
        else:
            attr = getattr(self.base_object, self.call)
            if self._is_call:
                result = attr(*self.sampled_args, **self.sampled_kwargs)
            else:
                result = attr

        # Record the trial only after evaluation succeeded, so a failed
        # evaluation is retried instead of returning a stale result.
        self.res = result
        self._sampled_key = key
        if study_name not in self.studies_sampled:
            self.trials_sampled = set()
        self.trials_sampled.add(trial.number)
        self.studies_sampled.add(study_name)
        return result

    def __call__(self, *args, **kwargs):
        """Return a NEW node representing a call on this node; ``self`` is unchanged.

        ``node.method(...)`` becomes a deferred method call on the same base.
        Calling any other node (a root node, or a node that already represents
        a call) defers calling the value that node samples to.
        """
        if self.base_call is not None and not self._is_call:
            # `node.method(...)`: same attribute on the same base, now invoked.
            new_node = HypsterPrepare(self.call, self.base_call)
        else:
            # Call the sampled value of this node.
            new_node = HypsterPrepare("__call__", self)
        # Assigned after construction so user keyword arguments named
        # `call` / `base_call` cannot collide with __init__'s parameters.
        new_node.args     = args
        new_node.kwargs   = kwargs
        new_node._is_call = True
        return new_node

    def __getattr__(self, name, *args, **kwargs):
        """Defer ``node.name``: return an attribute node on this node's sampled value.

        Only invoked for names not found through normal lookup. Dunder names
        raise ``AttributeError`` so Python protocols (``copy``/``pickle``
        probing ``__deepcopy__``, ``__setstate__``, ...) behave normally
        instead of receiving a lazy node.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return HypsterPrepare(name, self, *args, **kwargs)

In [5]:
#export
def prepare(call):
    """Make ``call`` (a function or class) hyperparameter-aware.

    Returns a wrapper. When the wrapper is invoked, it checks whether any
    positional or keyword argument contains a hypster hyperparameter
    (``contains_hypster(arg, HYPSTER_TYPES)``):

    * if so, the call is deferred and a ``HypsterPrepare`` node is returned;
      evaluate it later with ``node.sample(trial)``;
    * otherwise ``call(*args, **kwargs)`` runs immediately and its result is
      returned.

    Usable as a decorator (``@prepare``) or as a plain function
    (``Cls2 = prepare(Cls)``).
    """
    @functools.wraps(call)
    def wrapper_decorator(*args, **kwargs):
        all_args = (*args, *kwargs.values())
        # Generator (not a list) so the scan stops at the first hyperparameter.
        if any(contains_hypster(arg, HYPSTER_TYPES) for arg in all_args):
            # Attach the arguments after construction so user keyword arguments
            # named `call` / `base_call` cannot collide with the parameters of
            # HypsterPrepare.__init__.
            node = HypsterPrepare(call, None)
            node.args = args
            node.kwargs = kwargs
            return node
        return call(*args, **kwargs)
    return wrapper_decorator

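# Testing `prepare`

The cells below apply `prepare` to a plain function and to a class, then sample the resulting lazy objects in a small Optuna study.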

In [6]:
@prepare
def foo(a, *args, b="hi!", c=None, **kwargs):
    """Demo function: print a summary line, then each extra positional arg.

    Returns a string echoing ``a``, ``b`` and ``c``.
    """
    header = f"{a} and {b}" if c is None else f"{a} and {b} and {c}"
    print(header)

    # One extra positional argument per line; nothing is printed when there are none.
    if args:
        print(*args, sep="\n")

    return f"returned {a} and {b} and {c}"

In [7]:
@prepare
class Cls:
    """Demo class: the constructor prints its arguments; ``shmul`` prints ``batch``."""

    def __init__(self, a, *args, b="hi!", c=None, **kwargs):
        header = f"{a} and {b}" if c is None else f"{a} and {b} and {c}"
        print(header)
        # One extra positional argument per line; nothing is printed when there are none.
        if args:
            print(*args, sep="\n")

    def shmul(self, batch):
        """Print ``batch``."""
        print(batch)

In [8]:
# No argument is a hyperparameter, so `prepare` calls `foo` immediately:
# it prints right away and returns a plain string, not a lazy node.
eager_result = foo("hi", "hello", "hola", b="shmuli")
assert eager_result == "returned hi and shmuli and None"

hi and shmuli
hello
hola


In [9]:
# `b` is a hyperparameter, so `prepare` defers construction: `hps` is a lazy
# node and nothing is printed until it is sampled.
b_choices = HpCategorical("b", ["Shmuli", "Buli"])
hps = Cls("hi", "hello", "hola", b=b_choices)

In [10]:
# Build the deferred expression `Cls(...).shmul(batch=32)`; it runs only when
# `z.sample(trial)` is called.
shmul = hps.shmul
z = shmul(batch=32)

In [11]:
#export
import optuna

In [12]:
#export
def run_func_test(x, n_trials=5, timeout=600):
    """Sample ``x`` in a throwaway Optuna study, printing each sampled value.

    Parameters
    ----------
    x : object with a ``sample(trial)`` method (e.g. a ``HypsterPrepare`` node).
    n_trials : int
        Number of trials to run.
    timeout : float or None
        Forwarded to ``study.optimize``: stop after this many seconds
        (``None`` means no limit).

    Returns
    -------
    The Optuna study, so callers can inspect ``study.trials``.
    """
    def objective(trial):
        y = x.sample(trial)
        print(y)
        return 1.0  # constant objective: only the sampling side effects matter

    # NOTE(review): 0 is logging.NOTSET, which defers to ancestor loggers; if the
    # intent is to silence Optuna, a level such as optuna.logging.WARNING may be
    # what is wanted — confirm. This changes Optuna's global logging state.
    optuna.logging.set_verbosity(0)
    pruner = optuna.pruners.NopPruner()
    study = optuna.create_study(direction="maximize", pruner=pruner)
    study.optimize(objective, n_trials=n_trials, timeout=timeout)
    return study

In [13]:
# Sample the deferred `Cls(..., b=<hp>).shmul(batch=32)` expression over a few trials.
# NOTE: the saved output below predates the current HypsterPrepare.sample (its
# traceback references a bare `studies_sampled`); re-run the cell to refresh it.
run_func_test(z, n_trials=5);

[W 2020-05-09 23:42:14,036] Setting status of trial#0 as TrialState.FAIL because of the following error: NameError("name 'studies_sampled' is not defined")
Traceback (most recent call last):
  File "C:\Users\user\Anaconda3\lib\site-packages\optuna\study.py", line 677, in _run_trial
    result = func(trial)
  File "<ipython-input-12-60b4d4bff52e>", line 4, in objective
    y = x.sample(trial)
  File "<ipython-input-4-36108ffc0ecf>", line 13, in sample
    if trial.study.study_name not in studies_sampled:
NameError: name 'studies_sampled' is not defined


NameError: name 'studies_sampled' is not defined

In [None]:
class Cls:
    """Plain demo class: stores a name, last name and nickname, then prints them."""

    def __init__(self, name, last_name="", nickname=""):
        self.name, self.last_name, self.nickname = name, last_name, nickname
        greeting = f"{self.name} {self.last_name} {self.nickname}"
        print(greeting)

In [None]:
# Without `prepare`, construction is eager: the constructor prints immediately.
plain_obj = Cls("Gilad", nickname="The King!")

In [None]:
# Wrap the plain class so calls containing hyperparameters are deferred.
Cls2 = prepare(call=Cls)

In [None]:
# `last_name` is a hyperparameter, so `x` is a lazy node; `Cls.__init__` runs
# only when `x` is sampled.
last_name_hp = HpCategorical("last", ["The King!", "The Best King!"])
x = Cls2("Gilad", last_name_hp)

In [None]:
# Each trial samples a last name and constructs `Cls` with it.
run_func_test(x, n_trials=5);

In [None]:
from nbdev.export import notebook2script

In [None]:
# nbdev: write `#export` cells into library modules (this notebook targets
# `hypster_prepare`, set by `#default_exp` at the top).
# NOTE(review): called with no argument, nbdev's notebook2script presumably
# converts every notebook in the project, not only this one — confirm for the
# installed nbdev version.
notebook2script()