In [50]:
from collections.abc import Sequence, Callable
from typing import Any, Literal, Optional
import math
from functools import partial
from glio.imports import *
from torchzero.random.random import uniform
from torchzero.python_tools import auto_compose, identity

# Names accepted by the `scale` argument of the Param* classes.
SCALE_LITERALS = Literal[None, 'linear', 'log', 'log2', 'log10', 'ln']

# Elementwise transforms keyed by scale name; UNSCALERS holds their inverses.
# None / 'linear' pass values through unchanged; the others are torch ops and
# therefore expect tensor inputs.
# NOTE: 'log' is an alias for base-10 log (not natural log); use 'ln' for log_e.
SCALERS = dict.fromkeys((None, 'linear'), identity)
SCALERS.update(log=torch.log10, log2=torch.log2, log10=torch.log10, ln=torch.log)

class XPow:
    """Callable that raises a fixed base to a power: ``XPow(b)(x) == b ** x``.

    Serves as the inverse of the base-2/base-10 logarithmic scalers.
    """
    def __init__(self, base):
        self.base = base

    def __call__(self, x):
        base = self.base
        return pow(base, x)

def _exp(x):
    """Natural exponential for tensors (``torch.exp``) and plain numbers (``math.exp``)."""
    # torch.exp only accepts tensors, so calling it on a Python float raises TypeError.
    return torch.exp(x) if isinstance(x, torch.Tensor) else math.exp(x)

# Inverses of SCALERS, keyed by the same scale names ('log' is base 10).
# Every entry accepts both tensors and plain Python numbers.
UNSCALERS= {
    None: identity,
    'linear': identity,
    'log': XPow(10),
    'log2': XPow(2),
    'log10': XPow(10),
    'ln': _exp,
}


class ParamNumeric(torch.nn.Module):
    """One trainable scalar hyperparameter, exposed to optimizers as ``self.param``.

    ``self.param`` lives in *search space*: the user-facing value mapped through
    ``SCALERS[scale]`` and multiplied by ``mul``. ``forward`` maps it back with
    ``UNSCALERS[scale]`` and applies ``tfm``. With a logarithmic ``scale`` the
    optimizer therefore moves in log space: equal steps in ``param`` give equal
    *ratios* of the returned value.

    Args:
        name: name of the hyperparameter.
        min: lower bound of the user-facing value. ``param`` is initialised
            between the mapped bounds but is NOT clamped afterwards.
        max: upper bound of the user-facing value.
        scale: key of ``SCALERS``/``UNSCALERS``. ``None``/``'linear'`` mean no
            transform; logarithmic scales require ``min > 0`` and ``max > 0``.
        mul: factor between the scaled value and ``param``. A smaller ``mul``
            makes a fixed optimizer step change the value more.
        init: either a callable, called as
            ``init(1, true_min, true_max, dtype=torch.float32)`` with the
            search-space bounds (e.g. ``uniform``), or an explicit initial value
            in user-facing units.
        tfm: optional callable (or sequence of callables) passed through
            ``auto_compose`` and applied to the value in ``forward``.

    Raises:
        ValueError: if a logarithmic ``scale`` is combined with a non-positive bound.
    """
    def __init__(self, name:str, min:float, max:float, scale:SCALE_LITERALS, mul:float, init: Callable | float | torch.Tensor, tfm:Optional[Callable | Sequence[Callable]]):
        super().__init__()
        self.name = name
        self.min = min
        self.max = max
        self.scale = scale
        self.mul = mul
        self.init = init

        # scaler: user value -> search space; unscaler: search space -> user value.
        self.scaler = SCALERS[scale]
        self.unscaler = UNSCALERS[scale]

        if scale not in (None, 'linear') and (min <= 0 or max <= 0):
            raise ValueError(
                f"ParamNumeric {name!r}: scale={scale!r} needs positive bounds, got min={min}, max={max}"
            )

        # Search-space bounds as Python floats. They are computed in float64 so
        # linear bounds keep full precision before being handed to `init`.
        self.true_min = float(self._to_search_space(min, torch.float64))
        self.true_max = float(self._to_search_space(max, torch.float64))

        self.tfm = auto_compose(tfm)

        if callable(init):
            start = init(1, self.true_min, self.true_max, dtype=torch.float32)
        else:
            start = self._to_search_space(init, torch.float32)
        self.param = torch.nn.Parameter(start, requires_grad = True)

    def _to_search_space(self, value, dtype: torch.dtype) -> torch.Tensor:
        """Map a user-facing value (number or tensor) to a search-space tensor of ``dtype``."""
        # The torch scalers only accept tensors, so wrap plain numbers first;
        # detach so an incoming tensor's autograd history is not captured.
        wrapped = torch.as_tensor(value, dtype=dtype).detach()
        return self.scaler(wrapped) * self.mul

    def forward(self):
        """Return ``tfm`` applied to the user-facing value encoded by ``self.param``."""
        return self.tfm(self.unscaler(self.param / self.mul))

class ParamInt(ParamNumeric):
    """Numeric hyperparameter whose output is passed through ``int``.

    The cast comes after the user ``tfm``, assuming ``Compose`` applies its
    callables in order.
    """
    def __init__(self, name:str, min:float, max:float, scale:SCALE_LITERALS, mul:float, init: Callable | float | torch.Tensor | Any, tfm:Optional[Callable]):
        pipeline = Compose(auto_compose(tfm), int)
        super().__init__(
            name=name,
            min=min,
            max=max,
            scale=scale,
            mul=mul,
            init=init,
            tfm=pipeline,
        )

class ParamFloat(ParamNumeric):
    """Numeric hyperparameter whose output is passed through ``float``.

    The cast comes after the user ``tfm``, assuming ``Compose`` applies its
    callables in order.
    """
    def __init__(self, name:str, min:float, max:float, scale:SCALE_LITERALS, mul:float, init: Callable | float | torch.Tensor, tfm:Optional[Callable]):
        pipeline = Compose(auto_compose(tfm), float)
        super().__init__(
            name=name,
            min=min,
            max=max,
            scale=scale,
            mul=mul,
            init=init,
            tfm=pipeline,
        )

def _to_bool(x): return x > 0
class ParamBool(ParamNumeric):
    """Boolean hyperparameter: a linear value initialised within [-1, 1] and thresholded by ``_to_bool``."""
    def __init__(self, name:str, mul:float, init: Callable | bool | float | torch.Tensor, tfm:Optional[Callable]):
        pipeline = Compose(auto_compose(tfm), _to_bool)
        super().__init__(
            name=name,
            min=-1,
            max=1,
            scale=None,
            mul=mul,
            init=init,
            tfm=pipeline,
        )

def _choose(x, choices:Sequence): return choices[int(x % len(choices))]
class ParamCategorical(ParamNumeric):
    """Categorical hyperparameter encoded as a float index into ``choices``.

    The value is initialised within ``[0, len(choices)]`` and mapped to an element
    by ``_choose``, which wraps out-of-range indices around.

    Raises:
        ValueError: if ``choices`` is empty.
    """
    def __init__(self, name:str, choices: Sequence, mul:float, init: Callable | float | torch.Tensor | Any, tfm:Optional[Callable]):
        if len(choices) == 0:
            # Otherwise this would surface later as a modulo-by-zero / NaN error in _choose.
            raise ValueError(f"ParamCategorical {name!r}: `choices` must not be empty")
        tfm = Compose(auto_compose(tfm), partial(_choose, choices=choices))

        # An `init` equal to one of the choices is converted to its index. The
        # +0.5 centres the value inside that index's bin, so float rounding in
        # `param / mul` cannot truncate it down to the previous choice.
        if init in choices: init = choices.index(init) + 0.5
        super().__init__(name = name, min = 0, max = len(choices), mul = mul, scale=None,  init = init, tfm=tfm)


class Trial(torch.nn.Module):
    """Lazily creates one trainable parameter module per suggested name.

    The first ``suggest_*`` call for a name registers a submodule under that name
    with the given settings. Later calls reuse it, ignoring their other
    arguments, and return its current value.
    """
    def __init__(self):
        super().__init__()

    def _suggest(self, name:str, factory:Callable[[], torch.nn.Module]):
        """Return the current value of submodule ``name``, creating it with ``factory()`` on first use."""
        # `_modules` is the registry `add_module` writes to. Checking it directly
        # avoids rebuilding a dict of every submodule on each suggestion.
        if name not in self._modules:
            self.add_module(name, factory())
        return self._modules[name]()

    def suggest_float(self, name:str, min, max, scale:SCALE_LITERALS = None, mul=1., init = uniform, tfm=None):
        """Return float hyperparameter ``name`` in ``[min, max]`` (settings used on first call only)."""
        return self._suggest(name, lambda: ParamFloat(name = name, min = min, max = max, scale = scale, mul = mul, init = init, tfm=tfm))

    def suggest_int(self, name:str, min, max, scale:SCALE_LITERALS = None, mul=0.01, init = uniform, tfm=None):
        """Return integer hyperparameter ``name`` (settings used on first call only)."""
        return self._suggest(name, lambda: ParamInt(name = name, min = min, max = max, scale = scale, mul = mul, init = init, tfm=tfm))

    def suggest_bool(self, name:str, mul=0.01, init = uniform, tfm=None):
        """Return boolean hyperparameter ``name`` (settings used on first call only)."""
        return self._suggest(name, lambda: ParamBool(name = name, mul = mul, init = init, tfm=tfm))

    def suggest_categorical(self, name:str, choices: Sequence, mul=0.01, init = uniform, tfm=None):
        """Return an element of ``choices`` for hyperparameter ``name`` (settings used on first call only)."""
        return self._suggest(name, lambda: ParamCategorical(name = name, choices = choices, mul = mul, init = init, tfm=tfm))

In [61]:
class Study:
    """Evaluates ``objective(trial)`` and tracks the loss history and best parameters.

    ``__call__`` takes no arguments and runs one evaluation, so a study can be
    used wherever a zero-argument closure is expected.
    """
    def __init__(self, objective):
        self.objective = objective
        self.trial = Trial()
        # Warm-up evaluation: `suggest_*` creates parameters lazily, so this
        # registers them and `parameters()` is populated before an optimizer is
        # built. Its result is not recorded in the history.
        self.objective(self.trial)

        self.loss_history = []
        self.lowest_loss = float('inf')
        # Detached copies of `parameters()` (in iteration order) at the lowest loss seen.
        self.best_params = None

    def parameters(self):
        """Return the trainable parameters of the underlying trial."""
        return self.trial.parameters()

    def step(self):
        """Evaluate the objective once, record the loss, and return it unchanged."""
        loss = self.objective(self.trial)

        if isinstance(loss, torch.Tensor): float_loss = float(loss.detach().cpu())
        else: float_loss = float(loss)
        self.loss_history.append(float_loss)

        if float_loss < self.lowest_loss:
            self.lowest_loss = float_loss
            # detach(): snapshots should not require grad or keep autograd links.
            self.best_params = [p.detach().clone() for p in self.parameters()]

        return loss

    def __call__(self): return self.step()

In [63]:
from torchzero.optim import AcceleratedRandomSearch
def objective(trial:Trial):
    # Toy objective: squared distance of (x, y) from the origin (minimum 0 at (0, 0)).
    coords = (trial.suggest_float("x", -10, 10), trial.suggest_float("y", -10, 10))
    return sum(c ** 2 for c in coords)

study = Study(objective)
optimizer = AcceleratedRandomSearch(study.parameters(), [-10, 10])

n_steps = 1000
for _step in range(n_steps):
    optimizer.zero_grad()
    # `study` is passed as the closure argument; presumably the optimizer
    # calls it to evaluate the objective (Study.__call__ -> Study.step).
    loss = optimizer.step(study)
    print(loss, end = '\r')

1.8457534934271646e-09

In [68]:
# Probe: `uniform` rejects infinite bounds. This call raises
# "RuntimeError: from is out of bounds for float" (cell output below).
uniform(10, -float('inf'), float('inf'))

RuntimeError: from is out of bounds for float