# Train an online portfolio

In [1]:
import jax.numpy as jnp
import jax
import optax


def neg_log_dot(relative_prices: jax.Array, portfolio: jax.Array):
    """Return the negative natural log of ``dot(relative_prices, portfolio)``.

    Args:
        relative_prices: Array used as the first operand of ``jnp.dot``.
        portfolio: Array used as the second operand of ``jnp.dot``.

    Returns:
        ``-log(dot(relative_prices, portfolio))`` as a JAX array.
    """
    portfolio_return = jnp.dot(relative_prices, portfolio)
    return -jnp.log(portfolio_return)

In [2]:
# Seeded PRNG key used to draw the synthetic data below.
key = jax.random.PRNGKey(42)
T, N = 100, 3  # the price matrix has T rows and N columns

# Exponentiating standard-normal draws makes every entry strictly positive.
historical_prices = jnp.exp(jax.random.normal(key, shape=(T, N)))
historical_prices

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Array([[ 9.266193  ,  0.39556924,  1.5939081 ],
       [ 1.5317304 ,  0.96124506,  0.2886936 ],
       [ 1.7270536 ,  5.056033  ,  0.74298453],
       [ 1.1581417 ,  0.46063346,  2.081038  ],
       [ 2.251048  ,  1.2899486 ,  1.0330684 ],
       [ 1.7388194 ,  0.6944476 ,  0.5631286 ],
       [ 0.19809502,  0.72179365,  1.1720095 ],
       [ 0.5430063 ,  1.9968487 ,  0.84817475],
       [ 2.8010573 ,  0.12892678,  0.21251392],
       [ 0.85309094,  2.105761  ,  0.9298663 ],
       [ 0.9364263 ,  0.6926263 , 12.047028  ],
       [ 0.92143077,  2.7572978 ,  2.0657616 ],
       [ 1.8104322 ,  1.0792102 ,  1.4097453 ],
       [ 1.4900318 ,  2.1079764 ,  1.6321301 ],
       [ 0.315403  ,  1.7979252 ,  0.5859138 ],
       [ 1.4391824 ,  3.8254228 ,  0.5350349 ],
       [ 0.99339545,  1.8183542 ,  1.0347971 ],
       [ 1.8982258 ,  0.9242348 ,  0.39538172],
       [ 0.12167022,  1.9283448 ,  0.41918126],
       [ 4.4567738 ,  0.5843716 ,  3.0589564 ],
       [ 0.4355084 ,  1.5790704 ,  1.220

In [3]:
# Step size handed to the optimizer.
start_learning_rate = 0.1
# Start from equal weights: N entries, each equal to 1/N.
params = jnp.ones((N,)) / N

In [4]:
from pyfoliopt.optimizer import egd

In [5]:
# Create the optimizer returned by pyfoliopt's `egd` with the chosen step size
# (presumably exponentiated gradient descent, a mirror-descent method; confirm
# against the pyfoliopt documentation).
optimizer = egd(start_learning_rate)
opt_state = optimizer.init(params)

# Build the gradient function once: it does not change between iterations.
# argnums=1 differentiates the loss with respect to the `portfolio` argument.
loss_grad = jax.grad(neg_log_dot, argnums=1)

# Optimization loop (simplified): one update per row of historical_prices.
for i in range(T):
    grads = loss_grad(historical_prices[i], params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # Renormalize so the weights keep summing to 1 (guards against numerical drift).
    params = params / jnp.sum(params)

    print("Optimized params:", params)
    print("Sum of weights: ", jnp.sum(params, axis=-1))

Optimized params: [0.3839505  0.30310607 0.31294343]
Sum of weights:  1.0
Optimized params: [0.40628064 0.30241165 0.2913077 ]
Sum of weights:  1.0
Optimized params: [0.3934591  0.3355461  0.27099478]
Sum of weights:  1.0
Optimized params: [0.39235774 0.31530902 0.2923333 ]
Sum of weights:  1.0000001
Optimized params: [0.40870792 0.3092059  0.2820862 ]
Sum of weights:  1.0
Optimized params: [0.4335879  0.29790625 0.2685058 ]
Sum of weights:  1.0
Optimized params: [0.4042872  0.30243897 0.2932738 ]
Sum of weights:  1.0
Optimized params: [0.38416237 0.32911697 0.28672063]
Sum of weights:  1.0
Optimized params: [0.43815294 0.2992723  0.2625748 ]
Sum of weights:  1.0
Optimized params: [0.42406756 0.3202307  0.2557018 ]
Sum of weights:  1.0000001
Optimized params: [0.38995683 0.29253802 0.31750515]
Sum of weights:  1.0
Optimized params: [0.37082285 0.30767807 0.32149914]
Sum of weights:  1.0000001
Optimized params: [0.37985858 0.2997441  0.3203973 ]
Sum of weights:  1.0
Optimized params: [0