In [274]:
from functools import wraps, partial
from dataclasses import dataclass, field
import inspect
import random
from copy import copy, deepcopy
from typing import Sequence, TypeVar, cast, Iterable, Dict
import torch
from numpy import cos, sin, tan, tanh, log, exp
from pdb import set_trace

def uniform(low, high, size=None):
    return random.uniform(low,high) if size is None else torch.FloatTensor(size).uniform_(low,high)

def log_uniform(low, high, size=None):
    res = uniform(log(low), log(high), size)
    return exp(res) if size is None else res.exp_()

def rand_bool(p, size=None): return uniform(0,1,size)<p


def get_default_args(func):
    return {k: v.default
            for k, v in inspect.signature(func).parameters.items()
            if v.default is not inspect.Parameter.empty}

def get_arg_names(func):
    return list(inspect.signature(func).parameters)

def get_dynamic_var_args(func):
    return {k: v.default
            for k, v in inspect.signature(func).parameters.items()}

def listify(p=None, q=None):
    if p is None: p=[]
    elif not isinstance(p, Iterable): p=[p]
    n = q if type(q)==int else 1 if q is None else len(q)
    if len(p)==1: p = p * n
    return p

def bind_args(func, v):
    arg_names = get_arg_names(func)
    bound_args = { arg_names[i]: vi for i, vi in enumerate(v)}
    return bound_args
        

def resolve_fun_annotations(func, kwargs):
    params = copy(func.__annotations__)
    resolved = {}
    for k, v in kwargs.items():
        if k in params and not isinstance(v, DynamicVar):
            rand_func = params[k]
            if isinstance(v, Dict): resolved[k] = rand_func(**v)
            elif isinstance(v, Iterable):
                arg_names = get_arg_names(rand_func)
                bound_args = bind_args(rand_func, v)
                resolved[k] = rand_func(**bound_args)
            else: resolved[k] = v
        else: resolved[k] = v
    return resolved


def dynamic_resolve(a):
    if isinstance(a, DynamicVar): return a()
    elif isinstance(a, Iterable): return [dynamic_resolve(ai) for ai in a]
    elif isinstance(a, Dict): return {k:dynamic_resolve(vi) for k,vi in a.items()}
    else: return a

def dynamic_release(a):
    if isinstance(a, DynamicVar): return a.release()
    elif isinstance(a, list):
        for ai in a:
            dynamic_release(ai)
    elif isinstance(a, dict):
        for vi in a.values():
            dynamic_release(vi)
    
def clone_var(v):
    if isinstance(v, DynamicVar): return v.clone()
    else: return copy(v)
    
    
class DynamicVar(object):
    def __init__(self, calc_func, args=None, kwargs=None):
        print('init')
        self.value_ = None
        self.func = calc_func       
        self.def_args = get_default_args(self.func)   
        self.bound_args = bind_args(self.func, args) if args else {}
        self.kwargs = resolve_fun_annotations(self.func, kwargs) if kwargs else {}
            
    def clone(self, **kwargs):
        copy_kwargs = {k:clone_var(v) for k,v in self.kwargs.items()}
        copy_bound = {k:clone_var(v) for k,v in self.bound_args.items()}
        kwargs = {**copy_kwargs, **copy_bound, **kwargs}
        return DynamicVar(self.func, args=None, kwargs=kwargs)
    
    def override(self, **kwargs):
        #kwargs = resolve_fun_annotations(self.func, kwargs)
        kwargs = {**self.kwargs, **self.bound_args, **kwargs}
        return DynamicVar(self.func, args=None, kwargs=kwargs)
    
    def resolve(self, *args, **kwargs):
        args = [dynamic_resolve(a) for a in args]
        kwargs = resolve_fun_annotations(self.func, kwargs) 
        kwargs = {**self.def_args, **self.kwargs, **self.bound_args, **kwargs}
        kwargs = {k:dynamic_resolve(v) for k,v in kwargs.items()}
        self.value_ = self.func(*args, **kwargs)

    def release(self):
        dynamic_release(self.kwargs)
        dynamic_release(self.func)
        self.value_ = None
    
    def __repr__(self):
        return f'{self.func.__name__}:{self.kwargs}:{self.value_}'
    
    def __call__(self, *args, **kwargs):
        if self.value_ is None:
            self.resolve(*args, **kwargs)
        return self.value_


def dynamic_var(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        return DynamicVar(func, args=args, kwargs=kwargs) 
    return wrapper

def dynamic_func(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        f = DynamicVar(func, kwargs=kwargs)
        if args: return f(*args)
        else: return f
    return wrapper

@dynamic_var
def Uniform(low, high, size=None):
    if size is None: value = random.uniform(low,high)
    else: value = torch.FloatTensor(size).uniform_(low, high)
    return value

@dynamic_var
def LogUniform(low, high, size=None):
    res = uniform(log(low), log(high), size)
    value = exp(res) if size is None else res.exp_()
    return value

@dynamic_var
def Bool(p, size=None):
    return uniform(0,1,size)<p

@dynamic_var
def Constant(c):
    return c

@dynamic_var
def Choice(choices):
    return random.choice(choices)



@dynamic_func
def mult(x, by1:UniformVar=1, by2:UniformVar=1, do_mult:BoolVar=True):
    print('calc mult:', x, by1, by2, do_mult)
    if do_mult: return x * by1 * by2
    else: return x



In [279]:
m = mult()
m2 = m.clone(by2=Choice([5,6,7]), by1=(3.,3.1))
m(5), m2(5)

init
init
init
init
calc mult: 5 1 1 True
calc mult: 5 3.033929049173557 5 True


(5, 75.84822622933892)

In [204]:
c = Choice(choices=[1,2,3])

init


In [205]:
c.release()
c(), c()

(1, 1)

In [206]:
for i in range(5):
    print(mult(2))

init
calc mult: 2 1 1 True
2
init
calc mult: 2 1 1 True
2
init
calc mult: 2 1 1 True
2
init
calc mult: 2 1 1 True
2
init
calc mult: 2 1 1 True
2


In [207]:
m1 = mult(by1=Choice([2,4,6]), do_mult=0.5)
print(m1)
for i in range(5):
    print(m1(5))

init
init
init
mult:{'by1': Choice:{}:None, 'do_mult': BoolVar:{'p': 0.5}:None}:None
calc mult: 5 6 1 True
30
30
30
30
30


In [208]:
for i in range(5):
    m1.release()
    print(m1(5))

calc mult: 5 6 1 True
30
calc mult: 5 4 1 True
20
calc mult: 5 2 1 False
5
calc mult: 5 6 1 False
5
calc mult: 5 2 1 True
10


In [209]:
# uses all params
m2 = m1.clone(by1=(3.,3.3))
m3 = m1.override(by2=Choice([7,8]))
m1.release()
m2.release()
m3.release()
for i in range(5):
    print(m1(5), m2(5), m3(5))

init
init
init
init
init
init
calc mult: 5 6 1 True
calc mult: 5 3.0081518382378283 1 False
calc mult: 5 6 8 True
30 5 240
30 5 240
30 5 240
30 5 240
30 5 240
