# Kelly Criterion

### **Problem Statement - What fraction of our wealth should we bet?**

Winning probability = p

Losing probability = q

If win, fraction of bet gained = g

If lose, fraction of bet lost = l


### Idea: 

If win: 
$$S = S + S*f*g = (1+fg)S$$

If lose:
$$S = S - S*f*l = (1-fl)S$$

Upon playing n games, we have growth rate of 
$$r = (1+fg)^{np}(1-fl)^{nq}$$


We want to maximize r, thus we can find the local maxima, ie. 
$$\frac{d\ln r}{df} = \frac{npg}{1+fg} - \frac{nql}{1-fl} = 0$$
$$ (1-fl)(pg) - (1+fg)(ql) = 0 $$
$$ pg - flpg - ql - fgql = 0$$
$$ flg(p+q) = pg - ql$$
$$ f = \frac{pg - ql}{lg} = \frac{p}{l} - \frac{q}{g} $$

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# --- Simulation parameters ---------------------------------------------------
sim = 1000                      # number of independent wealth paths
n = 50                          # bets placed along each path
p = 0.3                         # probability that a single bet wins
q = 1 - p                       # probability that a single bet loses
g = 5                           # fraction of the stake gained on a win (bet $1 → win $5)
l = 1                           # fraction of the stake lost on a loss (bet $1 → lose $1)
b = g / l                       # payoff odds: amount gained per unit lost
kelly_fraction = p / l - q / g  # share of current wealth staked on every bet

initial_wealth = 1e7

# One row per simulation; column k holds the wealth after k bets
# (column 0 is the starting wealth).
experiments = np.zeros((sim, n + 1))
experiments[:, 0] = initial_wealth

np.random.seed(42)
for path in experiments:
    bankroll = initial_wealth
    for bet in range(1, n + 1):
        won = np.random.random() < p
        if won:
            # Stake is kelly_fraction of the bankroll and pays g per unit staked.
            bankroll += bankroll * kelly_fraction * g
        else:
            # Stake is kelly_fraction of the bankroll and loses l per unit staked.
            bankroll -= bankroll * kelly_fraction * l

        # Wealth is floored at zero.
        bankroll = max(bankroll, 0)
        path[bet] = bankroll

In [21]:
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import numpy as np

# Data to visualise.
#
# Reuse the `experiments` array left behind by an earlier simulation cell when
# there is one (rows = simulations, columns = wealth after each bet, column 0 =
# starting wealth).  Only when nothing has been simulated yet do we fall back
# to synthetic demo data, so the plots never silently replace real results.
experiments = globals().get("experiments")
if experiments is None:
    sim = 50                # number of demo paths
    n = 100                 # steps per demo path
    initial_wealth = 1e7

    # Demo data: a multiplicative random walk whose per-step return is drawn
    # from a normal distribution with mean 0.001 and standard deviation 0.02.
    np.random.seed(42)
    experiments = np.empty((sim, n + 1))
    experiments[:, 0] = initial_wealth
    experiments[:, 1:] = 1 + np.random.normal(0.001, 0.02, size=(sim, n))
    # A running product along each row turns the per-step growth factors into
    # wealth levels.
    experiments = np.cumprod(experiments, axis=1)
else:
    experiments = np.asarray(experiments, dtype=float)
    # Keep the variables used by the summary text in sync with the data that
    # is actually plotted.
    sim = experiments.shape[0]
    n = experiments.shape[1] - 1
    # Every path is expected to start from the same wealth; take it from the
    # first path.
    initial_wealth = float(experiments[0, 0])

# Build the 2x2 dashboard grid.  Only the bottom-left panel is given a
# secondary y-axis.
panel_titles = (
    'Wealth Trajectories', 'Distribution of Final Wealth',
    'Growth Rate Distribution', 'Cumulative Statistics',
)
panel_specs = [
    [{"secondary_y": False}, {"secondary_y": False}],
    [{"secondary_y": True}, {"secondary_y": False}],
]
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=panel_titles,
    specs=panel_specs,
    vertical_spacing=0.12,
    horizontal_spacing=0.1,
)

# ========== Plot 1: Wealth Trajectories (Main Chart) ==========
bet_axis = list(range(experiments.shape[1]))
faint_line = dict(width=1, color='rgba(100, 149, 237, 0.4)')  # Cornflower blue
path_hover = 'Bet: %{x}<br>Wealth: $%{y:,.0f}<br>Return: %{customdata:.1%}<extra></extra>'

# Draw at most the first 50 paths as faint individual lines.
for path_no, path in enumerate(experiments[:50], start=1):
    fig.add_trace(
        go.Scatter(
            x=bet_axis,
            y=path,
            mode='lines',
            line=faint_line,
            name=f'Simulation {path_no}',
            showlegend=False,
            hovertemplate=path_hover,
            # Return relative to the starting wealth, shown in the hover box.
            customdata=(path / initial_wealth - 1).tolist(),
        ),
        row=1, col=1
    )

# Overlay the cross-sectional mean and median at each bet number.
mean_trajectory = experiments.mean(axis=0)
median_trajectory = np.median(experiments, axis=0)

summary_lines = (
    (
        'Average',
        mean_trajectory,
        dict(width=3, color='#FF6B6B'),  # Coral red
        'Bet: %{x}<br>Avg Wealth: $%{y:,.0f}<extra></extra>',
    ),
    (
        'Median',
        median_trajectory,
        dict(width=3, color='#4ECDC4', dash='dash'),  # Turquoise
        'Bet: %{x}<br>Median Wealth: $%{y:,.0f}<extra></extra>',
    ),
)
for legend_name, trajectory, line_style, hover_text in summary_lines:
    fig.add_trace(
        go.Scatter(
            x=list(range(len(trajectory))),
            y=trajectory,
            mode='lines',
            line=line_style,
            name=legend_name,
            hovertemplate=hover_text,
        ),
        row=1, col=1
    )

# ========== Plot 2: Distribution of Final Wealth ==========
# Wealth at the end of every simulated path (last column).
final_wealths = experiments[:, -1]
fig.add_trace(
    go.Histogram(
        x=final_wealths,
        nbinsx=30,
        name='Final Wealth',
        marker_color='#95E1D3',  # Light teal
        opacity=0.8,
        hovertemplate='Wealth: $%{x:,.0f}<br>Count: %{y}<extra></extra>'
    ),
    row=1, col=2
)

# Summary markers drawn on top of the histogram.  The dict is deliberately not
# called `stats`: this script later imports `scipy.stats` under that name.
wealth_stats = {
    'Mean': np.mean(final_wealths),
    'Median': np.median(final_wealths),
    '10th Pctl': np.percentile(final_wealths, 10),
    '90th Pctl': np.percentile(final_wealths, 90)
}

colors = {'Mean': '#FF6B6B', 'Median': '#4ECDC4', '10th Pctl': '#FFD166', '90th Pctl': '#118AB2'}
for rank, (stat_name, stat_value) in enumerate(wealth_stats.items()):
    fig.add_vline(
        x=stat_value,
        line_dash="dash",
        line_color=colors[stat_name],
        opacity=0.7,
        row=1, col=2
    )

    # Label the line.  When row/col are given, plotly re-targets xref/yref to
    # that subplot's axes, so a "paper" yref would be replaced by the count
    # axis and y would be read as a bar height.  A "y domain" reference keeps
    # y as a 0-1 fraction of this panel's height.  Labels are staggered so
    # nearby statistics do not overprint each other.
    fig.add_annotation(
        x=stat_value,
        y=0.97 - 0.07 * rank,
        yref="y domain",
        text=f"{stat_name}: ${stat_value:,.0f}",
        showarrow=False,
        font=dict(size=10, color=colors[stat_name]),
        row=1, col=2
    )

# ========== Plot 3: Growth Rate Distribution ==========
# Per-bet log growth: difference of consecutive log-wealth values, pooled over
# every path.  A wealth of exactly 0 produces -inf/nan here, so those entries
# are dropped before plotting and density estimation.
with np.errstate(divide="ignore", invalid="ignore"):
    growth_rates = np.diff(np.log(experiments), axis=1).flatten()
growth_rates = growth_rates[np.isfinite(growth_rates)]

fig.add_trace(
    go.Histogram(
        x=growth_rates,
        nbinsx=40,
        name='Growth Rates',
        marker_color='#FFD166',  # Yellow
        opacity=0.8,
        # The histogram's y value is a bin count (the primary axis is titled "Count").
        hovertemplate='Growth: %{x:.3%}<br>Count: %{y}<extra></extra>'
    ),
    row=2, col=1
)

# Add KDE curve
from scipy import stats

# gaussian_kde needs at least two points that are not all identical.
if growth_rates.size > 1 and np.ptp(growth_rates) > 0:
    kde = stats.gaussian_kde(growth_rates)
    x_kde = np.linspace(growth_rates.min(), growth_rates.max(), 200)
    y_kde = kde(x_kde)

    fig.add_trace(
        go.Scatter(
            x=x_kde,
            y=y_kde,
            mode='lines',
            line=dict(width=2, color='#E76F51'),  # Terracotta
            name='KDE',
            hovertemplate='Growth: %{x:.3%}<br>Density: %{y:.3f}<extra></extra>'
        ),
        # Density values are on a different scale than bin counts, so the curve
        # goes on this panel's secondary y-axis.  The axis is chosen here via
        # secondary_y rather than by a `yaxis=` value on the trace.
        row=2, col=1, secondary_y=True
    )

# ========== Plot 4: Cumulative Statistics ==========
# Cross-sectional statistics over all simulations at each bet number.
bet_numbers = list(range(experiments.shape[1]))
cumulative_mean = experiments.mean(axis=0)
cumulative_std = experiments.std(axis=0)
cumulative_min = experiments.min(axis=0)
cumulative_max = experiments.max(axis=0)

# Shaded mean ± 1σ band: walk the upper edge left-to-right, then the lower edge
# right-to-left, so that fill='toself' closes the polygon.
fig.add_trace(
    go.Scatter(
        x=bet_numbers + bet_numbers[::-1],
        y=list(cumulative_mean + cumulative_std) + list(cumulative_mean - cumulative_std)[::-1],
        fill='toself',
        fillcolor='rgba(100, 149, 237, 0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        name='Mean ± 1σ',
        showlegend=True,
        # An empty <extra></extra> tag hides the secondary hover box.
        hovertemplate='Bet: %{x}<br>Wealth Range<extra></extra>'
    ),
    row=2, col=2
)

# Add mean line
fig.add_trace(
    go.Scatter(
        x=bet_numbers,
        y=cumulative_mean,
        mode='lines',
        line=dict(width=2, color='#264653'),  # Dark blue-green
        name='Mean',
        hovertemplate='Bet: %{x}<br>Mean: $%{y:,.0f}<extra></extra>'
    ),
    row=2, col=2
)

# Add min/max bounds (highest and lowest wealth across paths at each bet)
fig.add_trace(
    go.Scatter(
        x=bet_numbers,
        y=cumulative_max,
        mode='lines',
        line=dict(width=1, color='rgba(230, 57, 70, 0.5)', dash='dot'),
        name='Maximum',
        hovertemplate='Bet: %{x}<br>Max: $%{y:,.0f}<extra></extra>'
    ),
    row=2, col=2
)

fig.add_trace(
    go.Scatter(
        x=bet_numbers,
        y=cumulative_min,
        mode='lines',
        line=dict(width=1, color='rgba(57, 230, 100, 0.5)', dash='dot'),
        name='Minimum',
        hovertemplate='Bet: %{x}<br>Min: $%{y:,.0f}<extra></extra>'
    ),
    row=2, col=2
)

# ========== Update Layout and Axes ==========
# No figure-level title is set; each panel is labelled by its subplot title.
legend_style = {
    "orientation": "h",
    "yanchor": "bottom",
    "y": 1.02,
    "xanchor": "right",
    "x": 1,
}
fig.update_layout(
    template='plotly_white',
    height=900,
    showlegend=True,
    legend=legend_style,
    hovermode='x unified',
    plot_bgcolor='rgba(245, 247, 250, 0.8)',
    paper_bgcolor='white',
)

# Axis titles and tick formatting, keyed by (row, col) of the panel.
x_axis_settings = {
    (1, 1): dict(title_text="Bet Number"),
    (1, 2): dict(title_text="Final Wealth ($)", tickprefix="$"),
    (2, 1): dict(title_text="Growth Rate", tickformat=".1%"),
    (2, 2): dict(title_text="Bet Number"),
}
y_axis_settings = {
    (1, 1): dict(title_text="Wealth ($)", tickprefix="$"),
    (1, 2): dict(title_text="Count"),
    (2, 1): dict(title_text="Count"),
    (2, 2): dict(title_text="Wealth ($)", tickprefix="$"),
}
for (panel_row, panel_col), settings in x_axis_settings.items():
    fig.update_xaxes(row=panel_row, col=panel_col, **settings)
for (panel_row, panel_col), settings in y_axis_settings.items():
    fig.update_yaxes(row=panel_row, col=panel_col, **settings)

# Applied after the loop above so the secondary axis of the bottom-left panel
# ends up titled "Density".
fig.update_yaxes(title_text="Density", row=2, col=1, secondary_y=True)

# Add annotations with summary statistics

# Maximum drawdown of a path is its worst fall from the highest wealth reached
# *so far*, so it needs the running peak.  Dividing the overall minimum by the
# overall maximum would also count a low that happened before the high.
running_peak = np.maximum.accumulate(experiments, axis=1)
max_drawdowns = (experiments / running_peak - 1).min(axis=1)  # one value per path, <= 0

summary_text = f"""
<b>Summary Statistics</b><br>
Initial Wealth: ${initial_wealth:,.0f}<br>
Simulations: {sim}<br>
Bets per Simulation: {n}<br>
Mean Final Wealth: ${np.mean(final_wealths):,.0f}<br>
Median Final Wealth: ${np.median(final_wealths):,.0f}<br>
Volatility (Std): ${np.std(final_wealths):,.0f}<br>
Avg Max Drawdown: {max_drawdowns.mean():.1%}
"""

# Boxed text placed near the top-left corner of the whole figure
# (paper coordinates, not tied to any one panel).
fig.add_annotation(
    text=summary_text,
    xref="paper", yref="paper",
    x=0.02, y=0.98,
    showarrow=False,
    align="left",
    bgcolor="rgba(255, 255, 255, 0.8)",
    bordercolor="black",
    borderwidth=1,
    font=dict(size=11, family="Courier New, monospace")
)

# Add interactive buttons that switch the y-axes between linear and log scale.
#
# Y-axes are numbered in reading order across the subplot grid, and the
# secondary axis of the bottom-left panel takes a number of its own:
#   yaxis  -> Wealth Trajectories
#   yaxis2 -> Distribution of Final Wealth
#   yaxis3 -> Growth Rate Distribution (counts)
#   yaxis4 -> Growth Rate Distribution (secondary, density)
#   yaxis5 -> Cumulative Statistics
# Growth rates can be negative and are already in log units, so both axes of
# that panel always stay linear.
toggled_axes = ("yaxis", "yaxis2", "yaxis5")
always_linear_axes = ("yaxis3", "yaxis4")


def _scale_button(label, axis_type):
    """Return an updatemenu button that sets the toggled y-axes to *axis_type*."""
    layout_changes = {f"{axis}.type": axis_type for axis in toggled_axes}
    layout_changes.update({f"{axis}.type": "linear" for axis in always_linear_axes})
    return dict(
        label=label,
        method="update",
        # Every trace stays visible; only the axis types change.
        args=[{"visible": [True] * len(fig.data)}, layout_changes],
    )


fig.update_layout(
    updatemenus=[
        dict(
            type="buttons",
            direction="right",
            x=0.5,
            y=1.15,
            xanchor="center",
            showactive=True,
            buttons=[
                _scale_button("Linear Scale", "linear"),
                _scale_button("Log Scale", "log"),
            ],
        )
    ]
)

# Show the figure
fig.show()