# %% [markdown]
# # Trading Agent Analysis
# 
# Analyzing the performance of our reinforcement-learning (RL) trading agent: episode returns, learning progress, and per-resource price changes.

# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Load the saved evaluation arrays from the results/ directory:
# the return series and the price history (presumably time x resource — confirm).
returns, prices = (
    np.load('results/returns.npy'),
    np.load('results/prices.npy'),
)

# %% [markdown]
# ## Return Analysis

# %%
# Raw return per episode, drawn on an explicit Axes of a fresh 12x6 figure.
ax = plt.figure(figsize=(12, 6)).gca()
ax.plot(returns)
ax.set_title('Raw Returns Over Time')
ax.set_xlabel('Episode')
ax.set_ylabel('Return')
plt.show()

# Headline statistics over all episodes.
# NOTE: np.std defaults to the population standard deviation (ddof=0).
print(f"Mean return: {np.mean(returns):,.2f}")
print(f"Std deviation: {np.std(returns):,.2f}")
print(f"Max return: {np.max(returns):,.2f}")
print(f"Min return: {np.min(returns):,.2f}")

# %% [markdown]
# ## Learning Progress Analysis

# %%
# Learning progress: rolling mean of returns with a +/-1 rolling-std band.
# NOTE: pandas' rolling .std() is the sample std (ddof=1), whereas np.std
# defaults to the population std (ddof=0), so the two are not identical.
window = 50
returns_series = pd.Series(returns)
rolling = returns_series.rolling(window=window)
rolling_mean = rolling.mean()
rolling_std = rolling.std()
# With the default min_periods (= window), the first window-1 entries are NaN,
# so both curves only start once a full window of episodes is available.
episodes = returns_series.index

plt.figure(figsize=(12, 6))
plt.plot(episodes, rolling_mean, label='Mean Return')
plt.fill_between(episodes,
                 rolling_mean - rolling_std,
                 rolling_mean + rolling_std,
                 alpha=0.2,
                 label='±1 Std')
plt.title(f'Learning Progress ({window}-Episode Window)')
plt.xlabel('Episode')
plt.ylabel('Return')
plt.legend()
plt.show()

# %% [markdown]
# ## Price Impact Analysis

# %%
# Distribution of step-to-step price changes, one overlaid histogram per resource.
# Assumes prices is (time, resource) — TODO confirm; a 1-D price series is
# treated as a single resource instead of crashing on .shape[1].
price_changes = np.diff(prices, axis=0)
if price_changes.ndim == 1:
    price_changes = price_changes[:, np.newaxis]

# Shared bin edges so the overlaid histograms use identical bar widths and their
# counts are directly comparable (each histplot call would otherwise choose its
# own bins). Non-finite values are excluded, since they would make the
# automatic range computation fail.
finite_changes = price_changes[np.isfinite(price_changes)]
bin_edges = np.histogram_bin_edges(finite_changes, bins='auto')

plt.figure(figsize=(12, 6))
for i, resource_changes in enumerate(price_changes.T):
    sns.histplot(resource_changes, bins=bin_edges, label=f'Resource {i}', alpha=0.6)
plt.title('Distribution of Price Changes by Resource')
plt.xlabel('Price Change')
plt.legend()
plt.show()

# Pairwise Pearson correlation between the resources' price changes.
# np.atleast_2d keeps the single-resource case (np.corrcoef returns a 0-d
# value there) plottable as a 1x1 matrix. A resource whose price never changes
# has zero variance and produces NaN entries (numpy emits a RuntimeWarning).
corr = np.atleast_2d(np.corrcoef(price_changes.T))
resource_labels = [f'Resource {i}' for i in range(corr.shape[0])]

plt.figure(figsize=(8, 6))
# Pin the diverging colormap to the full correlation range [-1, 1] so that its
# neutral middle color means "uncorrelated" regardless of the data's own range.
sns.heatmap(corr, annot=True, fmt='.2f', cmap='coolwarm', vmin=-1, vmax=1,
            xticklabels=resource_labels, yticklabels=resource_labels)
plt.title('Price Change Correlations')
plt.show()