In [None]:
# import sys
# sys.path.append('shallow')

# Imports

In [None]:
#export
from functools import partial, wraps

import numpy as np

# Code

In [None]:
#export
def annealer(f):
    "Decorator: turn `f(start, end, pos)` into a factory `f(start, end)` that returns `partial(f, start, end)`."
    @wraps(f)  # keep f's __name__/__doc__ so decorated schedules stay introspectable
    def _inner(start, end): return partial(f, start, end)
    return _inner

@annealer
def sched_lin(start, end, pos):
    "Linear interpolation: `start` at pos=0, `end` at pos=1."
    span = end - start
    return start + pos * span
@annealer
def sched_cos(start, end, pos):
    "Half-cosine interpolation: `start` at pos=0, `end` at pos=1."
    weight = 1 + np.cos(np.pi * (1 - pos))  # 0 at pos=0, 2 at pos=1
    return start + weight * (end - start) / 2
@annealer
def sched_const(start, end, pos):
    "Constant schedule: always returns `start`; `end` and `pos` are accepted but ignored."
    return start
@annealer
def sched_exp(start, end, pos):
    "Geometric interpolation: `start` at pos=0, `end` at pos=1 (divides by `start`, so it must be non-zero)."
    ratio = end / start
    return start * ratio ** pos

In [None]:
#export
def combine_scheds(scheds):
    """Chain several schedules into one function of `pos`, expected in [0, 1].

    `scheds` is an iterable of `(pct, sched_fn)` pairs. Each `sched_fn` takes a local
    position in [0, 1] and covers a `pct` fraction of the overall range. The pcts must
    be non-negative and sum to 1 (within float tolerance). A `pos` below 0 raises
    ValueError, because no segment starts at or before it.
    """
    pairs = list(scheds)  # materialize so a generator can be read twice
    pcts = [pair[0] for pair in pairs]
    fscheds = [pair[1] for pair in pairs]

    # Tolerant comparison: e.g. sum([.1] * 10) == 0.9999999999999999, not exactly 1.0.
    total = sum(pcts)
    assert np.isclose(total, 1.), f"schedule percentages must sum to 1, got {total}"
    pcts = np.array([0] + pcts)
    assert (pcts >= 0).all(), "schedule percentages must be non-negative"
    # Segment boundaries: segment i spans [pcts[i], pcts[i+1]].
    pcts = np.cumsum(pcts, 0)

    def _inner(pos):
        # Index of the last boundary at or before `pos` ([0] selects the only axis; pcts is 1-D).
        idx = (pos >= pcts).nonzero()[0].max()
        # A pos on (or past) the final boundary belongs to the last segment.
        if idx == len(pcts) - 1: idx -= 1
        # Rescale pos into the segment's own [0, 1] range.
        actual_pos = (pos - pcts[idx]) / (pcts[idx + 1] - pcts[idx])
        return fscheds[idx](actual_pos)
    return _inner

# Tests

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
annealings = "NO LINEAR COS EXP".split()

a = np.arange(0, 100)          # x-axis: step index (also used by later plotting cells)
p = np.linspace(0.01, 1, 100)  # schedule position evaluated at each step

fns = [sched_const, sched_lin, sched_cos, sched_exp]
fig, ax = plt.subplots()
for fn, t in zip(fns, annealings):
    f = fn(2, 1e-2)
    ax.plot(a, [f(o) for o in p], label=t)
ax.set(title="Annealing schedules from 2 to 1e-2", xlabel="step", ylabel="value")
ax.legend();

In [None]:
sched = combine_scheds([
    (.3, sched_cos(0.3, 0.6)),  # first 30%: cosine rise 0.3 -> 0.6
    (.7, sched_cos(0.6, 0.2)),  # remaining 70%: cosine decay 0.6 -> 0.2
])

In [None]:
fig, ax = plt.subplots()
ax.plot(a, [sched(o) for o in p])
ax.set(title="Combined schedule", xlabel="step", ylabel="value");

In [None]:
# sched_const ignores its `end` argument. Pass an explicit value rather than IPython's `_`
# (the previous cell's output): `_` is hidden kernel state and only exists in an interactive shell.
sched = combine_scheds([
    (.2, sched_const(32, 32)),
    (.4, sched_const(64, 64)),
    (.4, sched_const(128, 128)),
])
fig, ax = plt.subplots()
ax.plot(a, [sched(o) for o in p])
ax.set(title="Piecewise-constant schedule", xlabel="step", ylabel="value");