In [None]:
# %pip install --user "jax[cpu]"
# %pip install --user flax
# %pip install --user optax

In [None]:
import optax
import flax
import jax

import matplotlib.pyplot as plt
import scipy.stats as st
# not a fan of the FLAX/JAX nn imports, especially considering most models in this repository use PyTorch nn
import flax.linen as ln
import jax.numpy as jnp
import seaborn as sns
import pandas as pd
import numpy as np

from flax.training import train_state, checkpoints
from torch.utils.data import Dataset, DataLoader
from typing import Callable, Any, Tuple
from functools import partial
from tqdm import tqdm

In [None]:
class Linear(ln.Module):
    """Bias-free dense layer that projects its input onto a single output unit.

    Attributes:
        initializer: callable handed to the dense layer as ``kernel_init``.
    """
    initializer : Callable[[Any, Tuple[int], Any], Any]

    @ln.compact
    def __call__(self, x):
        """Return the single-feature dense projection of ``x``."""
        dense = ln.Dense(features=1, use_bias=False, kernel_init=self.initializer)
        return dense(x)
    
# https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.ipynb#scrollTo=3qjS60Zl-I2_

In [None]:
class GaussianReturns(Dataset):
    """Synthetic dataset of multivariate-normal return vectors.

    Samples are generated in chunks of ``noise_batch`` rows. When ``cov_noise``
    is given, every chunk is drawn from its own randomly perturbed covariance
    matrix; otherwise all chunks share ``cov``.

    Args:
        size: total number of rows to generate.
        cov: square covariance matrix, shape ``(n, n)``.
        mean: mean vector, shape ``(n,)``.
        cov_noise: scale of the Gaussian perturbation added to ``cov`` for each
            chunk, or ``None`` for no perturbation.
        noise_batch: number of rows drawn per (possibly perturbed) covariance.
        reshuffle: if true, rows are permuted after generation so that chunks
            with different covariances are interleaved.
        seed: optional integer seed. When given, all random draws come from a
            private ``RandomState`` and are reproducible; when ``None`` the
            global NumPy random state is used.
    """

    def __init__(self, size, cov, mean, cov_noise=None, noise_batch=32, reshuffle=True, seed=None):
        super().__init__()
        cov = np.asarray(cov)
        mean = np.asarray(mean)
        # Must be set before _generate_returns, which reads it.
        self._rng = None if seed is None else np.random.RandomState(seed)
        self._data = self._generate_returns(size, mean, cov, cov_noise, noise_batch, reshuffle)

    def _random_source(self):
        """Return the seeded generator if there is one, else the global ``np.random`` module."""
        rng = getattr(self, '_rng', None)
        return np.random if rng is None else rng

    def _noisy_cov(self, cov, cov_noise):
        """Return ``cov`` plus symmetric Gaussian noise scaled by ``cov_noise``, clipped to [0, 1]."""
        # `size=` is required here: a positional shape would be taken as the mean.
        noise = self._random_source().normal(size=cov.shape)
        # Symmetrise so the perturbed matrix is still a valid (symmetric) covariance.
        noise = (noise + noise.T) / 2
        return np.clip(cov + noise * cov_noise, 0, 1)

    def _generate_returns(self, size, mean, cov, cov_noise, noise_batch, reshuffle):
        """Draw ``size`` rows in chunks of ``noise_batch`` and return them as a ``(size, n)`` array."""
        assert cov.shape[0] == cov.shape[1] == mean.shape[0], f'{cov.shape} | {mean.shape}'
        n_assets = cov.shape[0]
        n_even, size_uneven = divmod(size, noise_batch)
        chunk_sizes = [noise_batch] * n_even
        if size_uneven:
            chunk_sizes.append(size_uneven)

        # None makes scipy fall back to the global NumPy random state.
        seeded_rng = getattr(self, '_rng', None)
        # The empty seed array keeps concatenate valid when size == 0.
        samples = [np.empty((0, n_assets))]
        for chunk in chunk_sizes:
            # should provide option to input list of cov_noise offsets
            ncov = cov if cov_noise is None else self._noisy_cov(cov, cov_noise)
            sample = st.multivariate_normal(mean, ncov).rvs(size=chunk, random_state=seeded_rng)
            # scipy squeezes singleton dimensions (chunk == 1 or n_assets == 1).
            samples.append(np.reshape(sample, (chunk, n_assets)))

        return_array = np.concatenate(samples, axis=0)
        assert return_array.shape == (size, n_assets)
        if reshuffle:
            return_array = self._random_source().permutation(return_array)
        return return_array

    def __len__(self):
        """Number of generated rows."""
        return self._data.shape[0]

    @property
    def shape(self):
        """Shape of the underlying sample array."""
        return self._data.shape

    def __getitem__(self, idx):
        """Return the row(s) of the sample array selected by ``idx``."""
        return self._data[idx]
    
# This collate function is taken from the JAX tutorial with PyTorch Data Loading
# https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays instead of torch tensors.

    Array samples are stacked along a new leading axis; tuple/list samples are
    collated field by field (recursively) into a list; anything else is
    converted with ``np.array``.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # zip(*batch) regroups the samples so each field is collated on its own.
        return list(map(numpy_collate, zip(*batch)))
    return np.array(batch)

# Two-asset toy problem; `bias` is passed to GaussianReturns as the mean vector.
cov = [[.01, .001], [.001, .03]] # covariance matrix
bias = [.001, .002] # expected return

# 256 samples; cov_noise is left at its default (None), so no covariance noise is requested.
dataset = GaussianReturns(256, cov, bias)
# Batches of 128 rows, shuffled by the loader and collated with numpy_collate.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=numpy_collate)

In [None]:
# NOTE: handling every initialisation method in one function is poor design; each method should eventually be split into its own function
def initialize_params(params, in_place=True, method='even', weight_layer='Dense_0', **reference_weights):
    """Overwrite the kernel of ``weight_layer`` with portfolio weights that sum to one.

    Args:
        params: parameter tree; ``params['params'][weight_layer]['kernel']``
            must be an ``(n, 1)`` array.
        in_place: if true the given tree is modified and returned; otherwise a
            deep copy is modified and returned.
        method: one of
            ``'even'``         - every weight is ``1 / n`` (no reference needed);
            ``'exact'``        - use ``exact_weights`` as given;
            ``'average'``      - midpoint of ``min_weights`` and ``max_weights``;
            ``'proportional'`` - ``totals`` normalised by its sum.
        weight_layer: name of the layer whose kernel is replaced.
        **reference_weights: the keyword arguments required by ``method``.

    Returns:
        The parameter tree with the kernel replaced.

    Raises:
        KeyError: a keyword argument required by ``method`` is missing.
        ValueError: ``method`` is not one of the supported names.
        AssertionError: the resulting weights have the wrong shape or do not
            sum to one.
    """
    # Imported locally because the notebook's import cell does not provide it.
    from copy import deepcopy

    w = params['params'][weight_layer]['kernel']
    assert w.ndim == 2 and w.shape[1] == 1
    if method == 'even':
        # no reference required
        n_assets = w.shape[0]
        weights = jnp.full_like(w, 1 / n_assets)
    elif method == 'exact':
        if 'exact_weights' not in reference_weights:
            raise KeyError('Must pass desired weights as `exact_weights` keyword argument')
        weights = jnp.asarray(reference_weights['exact_weights'], dtype=w.dtype)
    elif method == 'average':
        if 'min_weights' not in reference_weights:
            raise KeyError('Must pass desired min weights as `min_weights` keyword argument')
        if 'max_weights' not in reference_weights:
            raise KeyError('Must pass desired max weights as `max_weights` keyword argument')
        min_w = jnp.asarray(reference_weights['min_weights'], dtype=w.dtype)
        max_w = jnp.asarray(reference_weights['max_weights'], dtype=w.dtype)
        # Parentheses matter: the midpoint, not min + max / 2.
        weights = (min_w + max_w) / 2
    elif method == 'proportional':
        if 'totals' not in reference_weights:
            raise KeyError('Must pass desired array for proportionality as `totals` keyword argument')
        total_w = jnp.asarray(reference_weights['totals'], dtype=w.dtype)
        weights = total_w / total_w.sum()
    else:
        raise ValueError(
            f"Unknown method {method!r}; expected 'even', 'exact', 'average' or 'proportional'"
        )

    weights = jnp.expand_dims(weights, axis=1) if weights.ndim == 1 else weights
    assert w.shape == weights.shape, f'{w.shape} != {weights.shape}'
    # Tolerant comparison: floating-point weights rarely sum to exactly 1.
    assert jnp.isclose(weights.sum(), 1, atol=1e-6), f'weights sum to {weights.sum()}, expected 1'
    if not in_place:
        params = deepcopy(params)

    params['params'][weight_layer]['kernel'] = weights
    return params

def even_init(key, size, dtype, **kwargs):
    """Kernel initializer that gives every input an equal weight of ``1 / size[0]``.

    Follows the ``(key, shape, dtype)`` initializer calling convention so it
    can be passed directly as a ``kernel_init``.

    Args:
        key: PRNG key; unused because the initialisation is deterministic.
        size: shape of the array to create, ``(n,)`` or ``(n, 1)``.
        dtype: dtype of the returned array.
        **kwargs: accepted for call compatibility and ignored.

    Returns:
        Array of shape ``size`` filled with ``1 / size[0]``.
    """
    assert len(size) == 1 or size[1] == 1, f'Invalid shape {size}'
    even_w = 1 / size[0]
    # Return the initialised array itself, not an initializer function.
    return jnp.full(size, even_w, dtype=dtype)

In [None]:
# Fixed seed; split into a key carried forward, one for the dummy input and one for parameter init.
rng = jax.random.PRNGKey(42)
rng, inpt_rng, init_rng = jax.random.split(rng, 3)

# Random (8, 2) input used only to call model.init below.
inpt_array = jax.random.normal(inpt_rng, (8, 2))

# Linear model whose kernel is initialised with Glorot-normal; init is run through jax.jit.
model = Linear(ln.initializers.glorot_normal())
params = jax.jit(model.init)(init_rng, inpt_array)

In [None]:
# AdamW from optax with learning rate 0.1; no other hyper-parameters are set here.
optimizer = optax.adamw(learning_rate=.1)

In [None]:
# Bundle the model's apply function, its initial parameters and the optimizer into one TrainState.
model_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
)

In [None]:
def sharpe_loss(
    state,
    params,
    data,
    constraints=((.6, .8), (.2, .4)),
    gamma=(.5, .2, .2, 1.2),
    min_var=.00001,
    min_cvar=-.05
):
    """Negative Sharpe ratio of the portfolio plus quadratic constraint penalties.

    Args:
        state: object exposing ``apply_fn(params, data)``, which returns the
            portfolio return for every row of ``data``.
        params: parameter tree; the weights are read from
            ``params['params']['Dense_0']['kernel']``.
        data: array of asset returns with one row per observation.
        constraints: one ``(min_weight, max_weight)`` pair per asset.
        gamma: penalty coefficients for, in order: weights not summing to one,
            weights below their minimum, weights above their maximum, and CVaR
            below ``min_cvar``.
        min_var: lower bound applied to the standard deviation in the
            denominator, to avoid division by (almost) zero.
        min_cvar: lowest acceptable value for the mean of the worst 5% of
            portfolio returns.

    Returns:
        Scalar loss (to be minimised).
    """
    # Number of observations in the 5% tail; at least one so the tail mean is defined.
    k = max(1, int(data.shape[0] * .05))
    min_weights, max_weights = (jnp.asarray(bound) for bound in zip(*constraints))
    pfl_returns = state.apply_fn(params, data).reshape(-1)
    exp_mean = pfl_returns.mean(axis=0)
    exp_std = pfl_returns.std(axis=0)
    # Mean of the k smallest returns. Sorting the values (not their indices)
    # gives the actual tail returns and keeps the term differentiable.
    exp_cvar = jnp.sort(pfl_returns)[:k].mean()

    # Flatten the (n, 1) kernel so it lines up element-wise with the (n,) bounds.
    w = params['params']['Dense_0']['kernel'].ravel()

    # sharpe, minimized instead of maximized; the denominator is floored at min_var
    loss = -exp_mean / jnp.maximum(exp_std, min_var)
    loss += gamma[0] * (jnp.sum(w) - 1) ** 2
    loss += gamma[1] * (jnp.minimum(w - min_weights, 0).sum() ** 2)
    loss += gamma[2] * (jnp.minimum(max_weights - w, 0).sum() ** 2)
    loss += gamma[3] * (jnp.minimum(exp_cvar - min_cvar, 0).sum() ** 2)
    return loss

def cvar_loss(state, params, data):
    """Ratio of the mean portfolio return to its distance from the 5% tail mean.

    Args:
        state: object exposing ``apply_fn(params, data)``, which returns the
            portfolio return for every row of ``data``.
        params: parameter tree passed through to ``apply_fn``.
        data: array of asset returns with one row per observation.

    Returns:
        Scalar ``mean / (mean - cvar)``, where ``cvar`` is the mean of the
        worst 5% of portfolio returns.
    """
    # Number of observations in the 5% tail; at least one so the tail mean is defined.
    k = max(1, int(data.shape[0] * .05))
    pfl_returns = state.apply_fn(params, data).reshape(-1)
    exp_mean = jnp.mean(pfl_returns)
    # Mean of the k smallest returns (the values themselves, not their indices).
    exp_cvar = jnp.sort(pfl_returns)[:k].mean()

    # NOTE(review): unlike sharpe_loss this ratio is not negated, so minimising
    # it lowers return per unit of tail risk - confirm the intended sign.
    loss = exp_mean / (exp_mean - exp_cvar)
    return loss

In [None]:
@jax.jit
def train_step(state, batch):
    """Run one optimisation step on ``batch`` and return ``(new_state, loss)``."""
    # Differentiate sharpe_loss with respect to its second argument (the parameters).
    loss_and_grad = jax.value_and_grad(sharpe_loss, argnums=1, has_aux=False)
    loss, grads = loss_and_grad(state, state.params, batch)
    return state.apply_gradients(grads=grads), loss

In [None]:
def train_model(state, dataloader, num_epochs=100):
    """Train for ``num_epochs`` passes over ``dataloader`` and plot the mean loss per epoch.

    Returns the training state after the last step.
    """
    loss_history = []
    for _ in tqdm(range(num_epochs)):
        batch_losses = []
        for batch in dataloader:
            state, step_loss = train_step(state, batch)
            batch_losses.append(step_loss)

        # One point per epoch: the average over that epoch's batches.
        epoch_mean = jnp.mean(jnp.array(batch_losses))
        loss_history.append(epoch_mean)

    plt.figure(figsize=(15,5))
    plt.plot(loss_history)
    plt.show()
    return state

def single_batch_train_model(state, data, num_epochs=100):
    """Train on the full ``data`` array as a single batch for ``num_epochs`` steps.

    Prints the data shape and the final weights, displays and plots the
    per-asset weight trajectories (first 10 steps omitted), plots the loss
    curve, and returns the training state after the last step.
    """
    weight_history = []
    loss_history = []
    print(data.shape)
    for _ in tqdm(range(num_epochs)):
        state, step_loss = train_step(state, data)
        kernel = state.params['params']['Dense_0']['kernel']
        weight_history.append(kernel.ravel())
        loss_history.append(step_loss)

    print(weight_history[-1])
    # Transpose the history so each column holds one asset's weight over time.
    per_asset = {}
    for i, weight_arr in enumerate(zip(*weight_history)):
        per_asset[f'A{i}'] = weight_arr[10:]
    df = pd.DataFrame(per_asset)
    display(df)
    plt.figure(figsize=(15,5))
    sns.lineplot(data=df.astype(np.float32))
    plt.show()

    plt.figure(figsize=(15,5))
    plt.plot(loss_history)
    plt.show()
    return state

In [None]:
# Full-batch training: the dataset's raw sample array is passed directly, bypassing the dataloader.
trained_model_state = single_batch_train_model(
    model_state, 
    dataset._data, 
    num_epochs=1000
)

In [None]:
# Persist the trained state under my_checkpoints/ with the prefix 'my_model'.
# NOTE(review): step=100 does not match the 1000 epochs trained above - confirm the intended step label.
checkpoints.save_checkpoint(
    ckpt_dir='my_checkpoints/',
    target=trained_model_state,
    step=100,
    prefix='my_model',
    overwrite=True,
)

In [None]:
# Reload from the same directory and prefix; the untrained model_state is passed as the
# target, presumably as the structural template for the restored state - TODO confirm.
loaded_model_state = checkpoints.restore_checkpoint(
    ckpt_dir='my_checkpoints/',
    target=model_state,
    prefix='my_model',
)

In [None]:
# Bind the trained parameters to the model so it can be called directly on data.
trained_model = model.bind(trained_model_state.params)
# `output` is not used further in the visible cells.
output = trained_model(dataset._data)
# NOTE(review): listing the bound module's attributes looks like leftover exploration.
display(dir(trained_model))