In [None]:
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from jax import random
from jax.scipy.integrate import trapezoid

# Show which backend JAX selected for this session.
print(jax.default_backend())

# Root JAX PRNG key and one split of it.  None of the cells visible below
# consume these keys (the synthetic data is drawn with NumPy).
key = random.PRNGKey(seed=42)
new_key, subkey = random.split(key, num=2)

In [None]:
# Example two-asset weight vector (10% / 90%).  No later visible cell reads
# this name; S_k takes its own scalar `b` argument.
b = jnp.asarray([0.1, 0.9])

In [None]:
# Two synthetic price series of 1000 samples each: a constant floor plus
# scaled log-normal noise.  A seeded Generator replaces the unseeded global
# np.random state so the draws (and every figure derived from them) are
# reproducible under Restart & Run All.
rng = np.random.default_rng(42)
x1 = 100 + 100*rng.lognormal(sigma=0.5, size=(1000, ))
x2 = 200 + 50*rng.lognormal(sigma=0.5, size=(1000, ))

In [None]:
# Stack the two 1-D series into one (2, N) array.
# Note the order: row 0 is x2 and row 1 is x1.
x_all = jnp.stack((x2, x1), axis=0)

In [None]:
# Price paths of the two synthetic assets (row 0 of x_all is x2, row 1 is x1).
fig, ax = plt.subplots()
ax.plot(x_all[0], label="x2 (row 0)")
ax.plot(x_all[1], label="x1 (row 1)")
ax.set(title="Synthetic price series", xlabel="time step", ylabel="price")
ax.legend();

In [None]:
# Convert the price levels to per-period stock multipliers x[t+1] / x[t].
#
# The ratios are rebuilt from x1 / x2 instead of from x_all itself, so
# re-running this cell always yields the same multipliers rather than taking
# ratios of ratios.  Row order matches x_all: row 0 is x2, row 1 is x1.

prices = jnp.vstack((x2, x1))
x_all = prices[:, 1:] / prices[:, :-1]

In [None]:
# Per-period multipliers of the two assets (row 0 of x_all is x2, row 1 is x1).
fig, ax = plt.subplots()
ax.plot(x_all[0], label="x2 (row 0)")
ax.plot(x_all[1], label="x1 (row 1)")
ax.set(title="Stock multipliers", xlabel="time step", ylabel="x[t+1] / x[t]")
ax.legend();

In [None]:
@jax.jit
def S_k(b, x_k):
    """
    Wealth multiplier of a constant-rebalanced two-asset portfolio up to time k.

    b   = scalar weight held in the first row of x_k; the second row
          receives 1 - b
    x_k = matrix of stock multiples, shape = (2, K)

    Returns the product over the K columns of b*x_k[0] + (1 - b)*x_k[1].
    """
    weights = jnp.array([b, 1. - b])
    # Trailing axis lets the two weights broadcast across the K columns.
    per_period = jnp.sum(jnp.expand_dims(weights, -1) * x_k, axis=0)
    return jnp.prod(per_period)

In [None]:
# Hindsight optimal portfolio
#
# Evaluate the terminal wealth of every constant-rebalanced weight on an even
# grid over [0, 1].  linspace includes the endpoint b = 1 (np.arange(0, 1, 0.05)
# stopped at 0.95), and jax.vmap evaluates the whole grid in one call instead
# of mapping over a Python list with jax.tree_map, which has been deprecated
# and removed from newer JAX releases.

b_hos = np.linspace(0., 1., 21)
S_k_ho = partial(S_k, x_k = x_all)
S_ho = jax.vmap(S_k_ho)(jnp.asarray(b_hos))

In [None]:
# Terminal wealth as a function of the constant weight b placed on row 0.
fig, ax = plt.subplots()
ax.plot(np.array(b_hos), S_ho, "ro-")
ax.set(
    title="Hindsight wealth of constant-rebalanced portfolios",
    xlabel="b (weight on row 0 of x_all)",
    ylabel="terminal wealth multiplier",
);

In [None]:
# Check extreme portfolios line up with the single-asset products:
# b = 1 holds only row 0 of x_all, b = 0 holds only row 1.
#
# S_k is jitted while x_all.prod runs eagerly, so the two products may be
# accumulated in a different order; compare with a relative tolerance rather
# than exact floating-point equality.

np.testing.assert_allclose(S_k(1, x_all), x_all.prod(axis=1)[0], rtol=1e-3)
np.testing.assert_allclose(S_k(0, x_all), x_all.prod(axis=1)[1], rtol=1e-3)

In [None]:
# Calculation of the universal portfolio fraction for period k+1.
# At t = 0 we assume an even [0.5, 0.5] split; after that, the fraction for
# period k+1 is computed from x_1..x_k only, i.e. without observing x_(k+1).

In [None]:
@jax.jit
def num_body_fun(i, x):
    """Numerator integrand at grid index i: b * S_k(b, x) with b = i/100."""
    b = i/100
    return b*S_k(b, x)

@jax.jit
def denom_body_fun(i, x):
    """Denominator integrand at grid index i: S_k(b, x) with b = i/100."""
    b = i/100
    return S_k(b, x)

In [None]:
# idx = list(jnp.arange(0, 101, 1))
# tmp1 = jax.tree_map(num_body_fun, idx, list(x_all[:, 0:2]))
# num = jnp.trapz(jnp.array(tmp1))

In [None]:
@jax.jit
def b_k_next(x_k):
    """
    Universal-portfolio weight for the period after the ones in x_k.

    Computes  b[k+1] = int_0^1 b*S_k(b) db / int_0^1 S_k(b) db  with the
    trapezoid rule on a 101-point grid b = 0, 0.01, ..., 1, where
    S_k(b) = prod_t (b*x_k[0, t] + (1 - b)*x_k[1, t]).

    x_k = matrix of stock multiples observed so far, shape = (2, K);
          multiples are assumed to be non-negative.

    Returns a scalar: the weight on row 0 (row 1 receives 1 - b).  With no
    observations (K = 0) the result is 0.5.

    The whole grid is evaluated as one array expression instead of mapping
    Python-level calls over a list with jax.tree_map / jnp.trapz, both of
    which are deprecated and removed in newer JAX releases.
    """
    grid = jnp.linspace(0., 1., 101)
    weights = jnp.stack((grid, 1. - grid))                        # (2, 101)

    # Per-period portfolio multiple for every grid weight: (101, K)
    period_returns = (weights[:, :, None] * x_k[:, None, :]).sum(axis=0)

    # Accumulate wealth in log space and rescale by the largest value: the
    # ratio of the two integrals is unchanged by a common factor, and this
    # keeps long horizons from overflowing to inf/inf.
    log_wealth = jnp.log(period_returns).sum(axis=1)              # (101,)
    rel_wealth = jnp.exp(log_wealth - log_wealth.max())

    num = trapezoid(grid * rel_wealth, x=grid)
    denom = trapezoid(rel_wealth, x=grid)

    return num/denom

In [None]:
b_k_next(x_all[:, 0:20])

In [None]:
#  Universal portfolio amounts
#
# b_univ[0] is the initial 0.5 split; b_univ[k] for k >= 1 is the universal
# weight computed from the first k columns of x_all, i.e.
#   int_0^1 b*S_k(b) db / int_0^1 S_k(b) db   on the grid b = 0, 0.01, ..., 1.
#
# The wealth of every grid weight over every prefix x_all[:, :k] is a
# cumulative product along the time axis, so all weights are obtained in one
# vectorized pass.  Calling the jitted b_k_next once per prefix would
# recompile it for each of the x_all.shape[1] distinct input shapes.

grid = jnp.linspace(0., 1., 101)
grid_weights = jnp.stack((grid, 1. - grid))                          # (2, 101)

# Per-period portfolio multiple for each grid weight: (101, T)
period_returns = (grid_weights[:, :, None] * x_all[:, None, :]).sum(axis=0)

# Running log-wealth per grid weight, rescaled per column; the ratio of the
# two integrals is unaffected and the exponentials cannot overflow.
log_wealth = jnp.cumsum(jnp.log(period_returns), axis=1)
rel_wealth = jnp.exp(log_wealth - log_wealth.max(axis=0, keepdims=True))

b_next = (trapezoid(grid[:, None] * rel_wealth, x=grid, axis=0)
          / trapezoid(rel_wealth, x=grid, axis=0))                   # (T,)

b_univ = jnp.concatenate((jnp.array([0.5]), b_next))

In [None]:
# Row 0: universal weight on the first asset row; row 1: the complement 1 - b.
b_univ_all = jnp.stack([jnp.asarray(b_univ), 1 - jnp.asarray(b_univ)], axis=0)

In [None]:
# Sanity check: expect (2, x_all.shape[1] + 1), i.e. the initial 0.5 split
# plus one weight per prefix of x_all.
b_univ_all.shape

In [None]:
# Prepend a column of unit multipliers so x_all_upd has one more column than
# x_all (column j of x_all_upd is x_all[:, j-1] for j >= 1).
# NOTE(review): b_univ[j] is computed from x_all[:, :j]; multiplying it
# element-wise with column j of this array applies each weight to a period it
# has already observed.
x_all_upd = jnp.concatenate((jnp.ones((2, 1)), x_all), axis=1)

In [None]:
# Sanity check: expect (2, x_all.shape[1] + 1) because of the prepended column.
x_all_upd.shape

In [None]:
def S_n(b_mat, x_mat):
    """
    Running wealth multiplier of a time-varying portfolio.

    b_mat = portfolio weights, one column per period
    x_mat = stock multiples, same shape as b_mat

    Column t contributes the portfolio multiple sum_i b_mat[i, t]*x_mat[i, t];
    the cumulative product of those multiples is returned, so entry t is the
    wealth after periods 0..t.
    """
    per_period = (b_mat * x_mat).sum(axis=0)
    return per_period.cumprod()

# Wealth multiplier using b_univ_all
#
# b_univ_all[:, j] is computed from x_all[:, :j] only, so it is the weight
# held during period j and must be paired with x_all[:, j].  Pairing it with
# x_all_upd[:, j] (= x_all[:, j-1]) would apply every weight to a multiplier
# it has already observed, which is look-ahead.  The last column of
# b_univ_all is the weight for the next, not yet observed, period and is left
# out.  A leading 1 (initial wealth) keeps S_univ at x_all.shape[1] + 1 entries.

S_univ = jnp.concatenate((jnp.ones(1), S_n(b_univ_all[:, :-1], x_all)))

In [None]:
# Running wealth multiplier of the universal portfolio.
fig, ax = plt.subplots()
ax.plot(S_univ)
ax.set(
    title="Universal portfolio wealth",
    xlabel="time step",
    ylabel="wealth multiplier",
);

In [None]:
# TODO: we need to calculate wealth multiplier at t using b

In [None]:
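# TODO: given some x[:, 0:L] we can calculate b_k_next for that single L.
# To cover every prefix x[:, 0:L] in one JAX call, note that S_k over
# successive prefixes is a cumulative product along the time axis
# (jnp.cumprod), so no Python loop over L is needed.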

In [None]:
# def num_body_fun(i, val, x):
#     return (i/100)*S_k(i/100, x) # + val

# def denom_body_fun(i, val, x):
#     return S_k(i/100, x) + val

# def b_k_next(x_k):
#     """
#     b_k_next = b[k+1] optimal
#     x_k = x[k] observed upto now
#     b[k+1] = \int_{0, 1} b*S(k, b)db / \int_{0, 1} S(k, b) db
#     """
#     num_body_func = partial(num_body_fun, x=x_k)
#     denom_body_func = partial(num_body_fun, x=x_k)
    
#     # Solve the integrals numerically
    
#     numerator = jax.lax.fori_loop(0, 101, num_body_func, 0.)
#     denominator = jax.lax.fori_loop(0, 101, denom_body_func, 0.)
    
#     return numerator/denominator