In [None]:
import pandas as pd
import pickle
import numpy as np

import os
from pathlib import Path

# Data location. Override with the CRYPTO_DATA_DIR environment variable instead of
# editing an absolute, user-specific path; the default keeps the original folder.
DATA_DIR = Path(os.environ.get("CRYPTO_DATA_DIR", r"C:\Users\ahmed\Downloads"))
log_return_path = DATA_DIR / "ref_log_return (1).pkl"
price_path = DATA_DIR / "ref_price (1).pkl"

# SECURITY: pickle.load can execute arbitrary code -- only load files you trust.
with open(log_return_path, 'rb') as f:
    log_returns = pickle.load(f)

with open(price_path, 'rb') as f:
    prices = pickle.load(f)

log_returns = np.array(log_returns)
prices = np.array(prices)

# Collapse all leading axes so each row holds one observation of the three cryptos.
# NOTE(review): assumes the arrays are stored crypto-last (..., 3); reshape only checks
# that the total size is divisible by 3.
log_returns_reshaped = log_returns.reshape(-1, 3)
log_return_df = pd.DataFrame(log_returns_reshaped, columns=['Crypto1_LogRet', 'Crypto2_LogRet', 'Crypto3_LogRet'])

prices_reshaped = prices.reshape(-1, 3)
price_df = pd.DataFrame(prices_reshaped, columns=['Crypto1_Price', 'Crypto2_Price', 'Crypto3_Price'])

# Quick data-quality / scale summary per crypto.
log_return_info = {
    "Shape": log_returns.shape,
    "Null Values": np.isnan(log_returns).sum(),
    "Per-Crypto Mean": log_return_df.mean().to_dict(),
    "Per-Crypto Std Dev": log_return_df.std().to_dict(),
    "Min/Max": log_return_df.describe().loc[['min', 'max']].to_dict()
}

price_info = {
    "Shape": prices.shape,
    "Null Values": np.isnan(prices).sum(),
    "Per-Crypto Mean": price_df.mean().to_dict(),
    "Per-Crypto Std Dev": price_df.std().to_dict(),
    "Min/Max": price_df.describe().loc[['min', 'max']].to_dict()
}

(log_return_info, price_info)


In [None]:
import pandas as pd
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt

import os
from pathlib import Path

# Data location. Override with the CRYPTO_DATA_DIR environment variable instead of
# editing an absolute, user-specific path; the default keeps the original folder.
DATA_DIR = Path(os.environ.get("CRYPTO_DATA_DIR", r"C:\Users\ahmed\Downloads"))
log_return_path = DATA_DIR / "ref_log_return (1).pkl"
price_path = DATA_DIR / "ref_price (1).pkl"

# SECURITY: pickle.load can execute arbitrary code -- only load files you trust.
with open(log_return_path, 'rb') as f:
    log_returns = pickle.load(f)

with open(price_path, 'rb') as f:
    prices = pickle.load(f)

log_returns = np.array(log_returns)
prices = np.array(prices)

log_return_shape = log_returns.shape
price_shape = prices.shape

# Per-crypto statistics pooled over axes 0 and 1 (days and hours for an [N, T, 3] array).
log_return_mean = np.mean(log_returns, axis=(0, 1))
log_return_std = np.std(log_returns, axis=(0, 1))

price_mean = np.mean(prices, axis=(0, 1))
price_std = np.std(prices, axis=(0, 1))

# Seeded generator so the plotted days are reproducible. Never request more distinct
# days than exist, because choice(..., replace=False) would raise.
plot_rng = np.random.default_rng(42)
n_plot = min(3, log_returns.shape[0])
sample_indices = plot_rng.choice(log_returns.shape[0], size=n_plot, replace=False)

for i, day_idx in enumerate(sample_indices):
    plt.figure()
    for j in range(log_returns.shape[2]):
        plt.plot(log_returns[day_idx, :, j], label=f'Crypto {j+1}')
    plt.title(f'Log-Return Series Sample {i+1}')
    plt.xlabel('Hour')
    plt.ylabel('Log Return')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

for i, day_idx in enumerate(sample_indices):
    print(f"Sample {i+1} initial prices: {prices[day_idx, 0]}")

(log_return_shape, price_shape, log_return_mean, log_return_std, price_mean, price_std)


In [None]:
import pandas as pd
import pickle
import numpy as np

import os
from pathlib import Path

# Data location. Override with the CRYPTO_DATA_DIR environment variable instead of
# editing an absolute, user-specific path; the default keeps the original folder.
DATA_DIR = Path(os.environ.get("CRYPTO_DATA_DIR", r"C:\Users\ahmed\Downloads"))
log_return_path = DATA_DIR / "ref_log_return (1).pkl"
price_path = DATA_DIR / "ref_price (1).pkl"

# SECURITY: pickle.load can execute arbitrary code -- only load files you trust.
with open(log_return_path, 'rb') as f:
    log_returns = pickle.load(f)

with open(price_path, 'rb') as f:
    prices = pickle.load(f)

log_returns = np.array(log_returns)
prices = np.array(prices)

# Inspect one day: its hourly log-returns and the first price row of that day.
sample_day_index = 0
log_return_df = pd.DataFrame(log_returns[sample_day_index], columns=['Crypto1_LogRet', 'Crypto2_LogRet', 'Crypto3_LogRet'])
print(f"Log Returns for Day {sample_day_index}:\n")
print(log_return_df)

price_row = prices[sample_day_index][0]  # first price row of that day; expected shape (3,)
price_df = pd.DataFrame([price_row], columns=['Crypto1_Price', 'Crypto2_Price', 'Crypto3_Price'])
print(f"\nInitial Prices for Day {sample_day_index}:\n")
print(price_df)


In [None]:
from sklearn.model_selection import train_test_split

# 1. Clamp extreme outliers to [-0.1, 0.1]
log_returns_clamped = np.clip(log_returns, -0.1, 0.1)

# 2. Split day indices FIRST so the scaler is fitted on training data only.
#    Previously mean/std were computed before the split, leaking validation data.
#    train_test_split's permutation depends only on the sample count and
#    random_state, so this reproduces the same train/val assignment and order.
train_idx, val_idx = train_test_split(
    np.arange(log_returns_clamped.shape[0]), test_size=0.2, random_state=42
)

# 3. Standardize (z-score) per crypto, pooling days and hours of the TRAINING split
#    (axis=(0, 1) keeps one mean/std per crypto on the last axis).
train_clamped = log_returns_clamped[train_idx]
mean = train_clamped.mean(axis=(0, 1), keepdims=True)
std = train_clamped.std(axis=(0, 1), keepdims=True)
log_returns_norm = (log_returns_clamped - mean) / std

# Save the mean and std for use during generation/evaluation (de-normalisation)
np.save("log_return_mean.npy", mean)
np.save("log_return_std.npy", std)

train_data = log_returns_norm[train_idx]
val_data = log_returns_norm[val_idx]

print("Train shape:", train_data.shape)
print("Val shape:", val_data.shape)


In [None]:
import matplotlib.pyplot as plt

# Overlay five randomly drawn (normalized) training days for Crypto1 (asset index 0).
fig, ax = plt.subplots(figsize=(12, 6))
for i in range(5):
    sample_idx = np.random.randint(0, train_data.shape[0])
    ax.plot(train_data[sample_idx, :, 0], label=f'Sample {i+1}')
ax.set_title('Sample Log-Return Paths for Crypto1 (Normalized)')
ax.set_xlabel('Hour')
ax.set_ylabel('Log Return (Standardized)')
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# Distribution of every normalized value (all days, hours and cryptos pooled).
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(train_data.ravel(), bins=100, color='skyblue', edgecolor='k')
ax.set_title("Histogram of Normalized Log-Returns (All Cryptos)")
ax.set_xlabel("Value")
ax.set_ylabel("Frequency")
ax.grid(True)
plt.show()


In [None]:
import seaborn as sns

# One row per (day, hour) observation, one column per crypto.
reshaped = train_data.reshape(-1, 3)
df_box = pd.DataFrame(reshaped, columns=["Crypto1", "Crypto2", "Crypto3"])

fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(data=df_box, ax=ax)
ax.set_title("Boxplot of Normalized Log-Returns per Crypto")
ax.set_ylabel("Normalized Log-Return")
ax.grid(True)
plt.show()


In [None]:
# Pearson correlation between the three cryptos' pooled normalized returns.
corr_matrix = df_box.corr()

fig, ax = plt.subplots(figsize=(6, 4))
sns.heatmap(corr_matrix, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
ax.set_title("Correlation Between Cryptos (Normalized Returns)")
plt.show()


In [None]:
import numpy as np
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
from scipy.stats import skew, kurtosis

# Per-crypto 1-D training series: training days concatenated end to end
# (hour order preserved within each day). train_data is [days, hours, cryptos].
# NOTE(review): train_test_split shuffles by default, so adjacent days in `train_data`
# are not chronologically adjacent and the ARIMA fit sees artificial jumps at day
# boundaries -- confirm whether a chronologically ordered series is intended.
crypto_series = {
    f"Crypto{i+1}": train_data[:, :, i].reshape(-1) for i in range(train_data.shape[2])
}

def fit_best_arima(series, p_max=3, q_max=3):
    """Grid-search ARIMA(p, 0, q) orders and return the fit with the lowest AIC.

    Args:
        series: 1-D sequence of observations.
        p_max: inclusive upper bound for the AR order.
        q_max: inclusive upper bound for the MA order.

    Returns:
        The fitted results object with the smallest ``aic``.

    Raises:
        RuntimeError: if no (p, q) combination could be fitted. The last fitting
            error is chained as the cause instead of surfacing as ``None[1]``.
    """
    best_aic, best_res, last_err = None, None, None
    for p in range(p_max + 1):
        for q in range(q_max + 1):
            try:
                res = ARIMA(series, order=(p, 0, q)).fit()
                if best_aic is None or res.aic < best_aic:
                    best_aic, best_res = res.aic, res
            except Exception as err:  # single orders may fail to converge; keep searching
                last_err = err
                continue
    if best_res is None:
        raise RuntimeError(
            f"No ARIMA(p,0,q) model with p<={p_max}, q<={q_max} could be fitted"
        ) from last_err
    return best_res

def simulate_arima(res, T, M=500, burn=50):
    """Draw ``M`` simulated paths of length ``T`` from a fitted model ``res``.

    Each path is simulated for ``T + burn`` steps via ``res.simulate`` and the
    first ``burn`` values are dropped as burn-in.

    Returns:
        ndarray of shape [M, T].
    """
    paths = [np.asarray(res.simulate(T + burn))[burn:] for _ in range(M)]
    return np.vstack(paths)

def var_es(x, alpha):
    """Empirical VaR and ES of array ``x`` at level ``alpha``.

    Returns ``(q, es)``. ``q`` is the ``alpha``-quantile of ``x``. ``es`` is the mean of
    the values ``<= q``, falling back to ``q`` itself when no value qualifies.
    """
    q = np.quantile(x, alpha)
    in_tail = x <= q
    if not in_tail.any():
        return q, q
    return q, x[in_tail].mean()

# Simulation size mirrors the validation set: one path per validation day, one day long.
T = val_data.shape[1]  # hours per day
M = val_data.shape[0]  # number of days

results = {}
for name, series in crypto_series.items():
    # Best ARIMA(p,0,q) by AIC, then synthetic paths shaped [M, T] pooled into one sample.
    res = fit_best_arima(series)
    synth = simulate_arima(res, T, M=M)
    flat = synth.ravel()

    metrics = {"Skew": skew(flat), "Kurtosis": kurtosis(flat, fisher=True)}  # excess kurtosis
    for a in (0.01, 0.05, 0.10):
        q, es = var_es(flat, a)
        metrics[f"VaR@{a}"] = q
        metrics[f"ES@{a}"] = es
    results[name] = metrics

# Rows = cryptos, columns = metrics.
arima_df = pd.DataFrame(results).T
print(arima_df)
arima_df.to_csv("arima_metrics.csv")


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

class CryptoLogReturnDataset(Dataset):
    """Map-style dataset over an array of paths; item ``i`` is ``data_array[i]`` as float32."""

    def __init__(self, data_array):
        # Copy once into a float32 tensor so __getitem__ is a plain index.
        self.data = torch.tensor(data_array, dtype=torch.float32)

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Datasets over the normalized train/val arrays
train_dataset = CryptoLogReturnDataset(train_data)
val_dataset = CryptoLogReturnDataset(val_data)

# Reference "real" sample for the evaluation cells
real_data = train_data

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Dimensionality parameters
seq_len, feature_dim = train_data.shape[1], train_data.shape[2]
input_dim = seq_len * feature_dim   # flattened sample size
noise_dim = 128                     # latent size

# NOTE(review): not referenced in the visible cells.
lambda_tail = 5


In [None]:


# === Generator ===
class Generator(nn.Module):
    """MLP generator mapping a latent vector to a ``[batch, seq_len, feature_dim]`` path.

    Args:
        noise_dim: size of the latent input ``z``.
        output_dim: total number of generated values per sample. The default 72
            corresponds to 24 time steps x 3 assets.
        feature_dim: values per time step (assets). The sequence length is inferred
            as ``output_dim // feature_dim``.

    Raises:
        ValueError: if ``output_dim`` is not divisible by ``feature_dim``.
    """
    def __init__(self, noise_dim, output_dim=72, feature_dim=3):
        super(Generator, self).__init__()
        if output_dim % feature_dim:
            raise ValueError(
                f"output_dim ({output_dim}) must be divisible by feature_dim ({feature_dim})"
            )
        self.feature_dim = feature_dim
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, output_dim)
        )

    def forward(self, z):
        """Map ``z`` [batch, noise_dim] to [batch, output_dim // feature_dim, feature_dim]."""
        out = self.model(z)
        # Small Gaussian jitter on the output. It is applied in eval mode as well.
        noise = 0.05 * torch.randn_like(out)
        out = out + noise
        return out.view(z.size(0), -1, self.feature_dim)

# === Discriminator ===
class Discriminator(nn.Module):
    """MLP critic: flattens each sample to ``input_dim`` features and returns one raw logit."""

    def __init__(self, input_dim=72):
        super().__init__()
        layers = []
        prev_width = input_dim
        for width in (128, 64):
            layers += [nn.Linear(prev_width, width), nn.LeakyReLU(0.2)]
            prev_width = width
        layers.append(nn.Linear(prev_width, 1))  # logit; no sigmoid (paired with BCEWithLogitsLoss)
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        flat = x.view(len(x), -1)
        return self.model(flat)


In [None]:
# MLP GAN sizes: latent width and flattened sample size (24 time steps x 3 assets).
noise_dim = 128
output_dim = 72

# === Instantiate models (generator first, then discriminator) ===
generator = Generator(noise_dim, output_dim).to(device)
discriminator = Discriminator(output_dim).to(device)

# The discriminator emits raw logits, so use the logits form of BCE.
criterion = nn.BCEWithLogitsLoss()

# === Adam with the common GAN betas ===
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)
g_optimizer = torch.optim.Adam(generator.parameters(), **adam_kwargs)



In [None]:
import torch
import torch.nn as nn
import numpy as np
from scipy.stats import skew as sp_skew, kurtosis as sp_kurt


import copy

# NOTE(review): seeding here happens after `generator`/`discriminator` were created in an
# earlier cell, so it fixes the training randomness but not the weight initialisation.
torch.manual_seed(42)
np.random.seed(42)

# --- models ---
generator.train()
discriminator.train()

# --- loss & opt (standard GAN); fresh optimizers start with empty Adam state ---
criterion   = nn.BCEWithLogitsLoss()
g_optimizer = torch.optim.Adam(generator.parameters(),     lr=2e-4, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# EMA copy of the generator for smoother sampling/evaluation (frozen: never trained directly).
use_ema   = True
ema_decay = 0.999
if use_ema:
    G_ema = copy.deepcopy(generator).to(device)
    for p in G_ema.parameters():
        p.requires_grad_(False)

@torch.no_grad()
def ema_update(target, source, decay=0.999):
    """Blend ``source`` into ``target`` as an exponential moving average, in place.

    Parameters are updated as ``t = decay * t + (1 - decay) * s``. Buffers (e.g.
    BatchNorm ``running_mean``/``running_var``/``num_batches_tracked``) are not trained
    by the optimizer and are copied verbatim. Without this, an EMA model used in
    ``eval()`` mode would keep normalising with its initial statistics.

    Assumes ``target`` and ``source`` share the same architecture (pairs are zipped in order).
    """
    for p_t, p_s in zip(target.parameters(), source.parameters()):
        p_t.data.mul_(decay).add_(p_s.data, alpha=1.0 - decay)
    for b_t, b_s in zip(target.buffers(), source.buffers()):
        b_t.copy_(b_s)

# --- training ---
# --- Vanilla GAN training loop (adversarial loss only) ---
# Per mini-batch: one discriminator step, then one generator step, then (if use_ema)
# an EMA update of G_ema from the generator's weights.
EPOCHS = 50
for epoch in range(EPOCHS):
    generator.train(); discriminator.train()
    for real_batch in train_loader:
        real_batch = real_batch.to(device)  # [B, T, A]
        B = real_batch.size(0)  # read per batch; the final batch may be smaller


        # 1) Train D

        # Fakes are produced without grad: this step only updates the discriminator.
        z = torch.randn(B, noise_dim, device=device)
        with torch.no_grad():
            fake_detach = generator(z)

        real_labels = torch.empty(B, 1, device=device).uniform_(0.9, 1.0)  # label smoothing
        fake_labels = torch.empty(B, 1, device=device).uniform_(0.0, 0.1)

        d_optimizer.zero_grad(set_to_none=True)
        d_real = discriminator(real_batch)
        d_fake = discriminator(fake_detach)
        d_loss = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)
        d_loss.backward()
        d_optimizer.step()


        # 2) Train G (adv only)

        z = torch.randn(B, noise_dim, device=device)
        fake_batch = generator(z)

        g_optimizer.zero_grad(set_to_none=True)
        d_fake = discriminator(fake_batch)
        # try to fool D as "real"
        # NOTE(review): the target reuses the smoothed `real_labels` (0.9-1.0), not 1.0.
        g_loss = criterion(d_fake, real_labels)
        # This backward also writes grads into the discriminator's parameters; they are
        # cleared by d_optimizer.zero_grad at the start of the next D step.
        g_loss.backward()
        nn.utils.clip_grad_norm_(generator.parameters(), max_norm=5.0)  # optional, harmless
        g_optimizer.step()

        if use_ema:
            ema_update(G_ema, generator, ema_decay)

    # Reported losses come from the epoch's last mini-batch only.
    print(f"Epoch {epoch+1}/{EPOCHS} | D: {d_loss.item():.4f} | G: {g_loss.item():.4f}")

In [None]:
# Draw 5000 samples from the trained MLP generator (BatchNorm in inference mode).
n_synth = 5000
generator.eval()
with torch.no_grad():
    noise = torch.randn(n_synth, noise_dim).to(device)
    fake = generator(noise)
    synthetic_batch = fake.cpu().numpy()  # shape: (5000, 24, 3)


In [None]:
# Pool every (day, hour, asset) value into one flat 1-D sample per source.
real_returns = np.ravel(real_data)
synthetic_returns = np.ravel(synthetic_batch)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Marginal densities of the pooled, standardized returns.
# `synthetic_returns` was sampled from the MLP `generator` trained above (not the
# Tail-GAN defined later), so the labels say "vanilla GAN".
fig, ax = plt.subplots(figsize=(10, 5))
sns.kdeplot(real_returns, label='Real', bw_adjust=1.5, ax=ax)
sns.kdeplot(synthetic_returns, label='Synthetic (vanilla GAN)', bw_adjust=1.5, ax=ax)
ax.set_title("Log-Return Distribution: Real vs Vanilla GAN")
ax.set_xlabel("Log Return (standardized)")
ax.set_ylabel("Density")
ax.legend()
ax.grid()
plt.show()


In [None]:
def sample_synthetic(G, n_samples=5000, noise_dim=128, device="cpu"):
    """Draw ``n_samples`` outputs from generator ``G`` with standard-normal latents.

    ``G`` is switched to eval mode and left there. The forward pass runs under
    ``torch.no_grad()`` so no autograd graph is built for the (large) batch.

    Returns:
        ``G(z)`` as a NumPy array on the CPU.
    """
    G.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, noise_dim, device=device)
        synth = G(z).cpu().numpy()
    return synth

def evaluate_gan(real_data, synthetic_data, alpha=0.05, fisher_excess=True):
    """Print and return per-asset tail/moment statistics for real vs synthetic data.

    Inputs may be 1-D (one asset), 2-D ``[samples, assets]`` or 3-D ``[N, T, A]``
    (days and hours are pooled per asset).

    Args:
        real_data: array-like of real values.
        synthetic_data: array-like of generated values with the same number of assets.
        alpha: tail probability for VaR / ES (0.05 -> 5% tail).
        fisher_excess: report excess kurtosis (normal == 0) when True.

    Returns:
        dict of per-asset arrays: VaR, ES, skewness and kurtosis for real and synthetic.

    Raises:
        ValueError: if the inputs have a different number of assets.
    """

    def to_flat(x):
        x = np.asarray(x)
        if x.ndim == 1:
            x = x[:, None]
        elif x.ndim == 3:
            x = x.reshape(-1, x.shape[2])
        return x

    real_flat  = to_flat(real_data)
    synth_flat = to_flat(synthetic_data)
    A = real_flat.shape[1]
    if synth_flat.shape[1] != A:
        raise ValueError(
            f"real data has {A} assets but synthetic data has {synth_flat.shape[1]}"
        )

    # Row labels follow the requested alpha instead of a hard-coded "5%".
    pct = f"{alpha * 100:g}%"

    def row(label, values):
        return f"{label:<12}| " + " | ".join([f"{v:10.5f}" for v in values])

    print(f"{'Metric':<12} | " + " | ".join([f"Asset {i+1:>2}" for i in range(A)]))
    print("-" * (14 + A*14))

    # VaR: alpha-quantile per asset
    real_var  = [np.quantile(real_flat[:, i], alpha)  for i in range(A)]
    synth_var = [np.quantile(synth_flat[:, i], alpha) for i in range(A)]
    print(row(f"VaR ({pct})", real_var))
    print(row("", synth_var))

    # ES: mean of values at or below the VaR
    real_es  = [real_flat[:, i][real_flat[:, i]  <= real_var[i]].mean()  for i in range(A)]
    synth_es = [synth_flat[:, i][synth_flat[:, i] <= synth_var[i]].mean() for i in range(A)]
    print(row(f"ES ({pct})", real_es))
    print(row("", synth_es))

    # Skew (bias-corrected)
    real_sk = [sp_skew(real_flat[:, i],  bias=False) for i in range(A)]
    syn_sk  = [sp_skew(synth_flat[:, i], bias=False) for i in range(A)]
    print(row("Skewness", real_sk))
    print(row("", syn_sk))

    # Kurtosis (bias-corrected; excess when fisher_excess)
    real_ku = [sp_kurt(real_flat[:, i],  bias=False, fisher=fisher_excess) for i in range(A)]
    syn_ku  = [sp_kurt(synth_flat[:, i], bias=False, fisher=fisher_excess) for i in range(A)]
    print(row("Kurtosis", real_ku))
    print(row("", syn_ku))

    return {
        "VaR_real": np.array(real_var),   "VaR_synth": np.array(synth_var),
        "ES_real":  np.array(real_es),    "ES_synth":  np.array(synth_es),
        "Skew_real": np.array(real_sk),   "Skew_synth": np.array(syn_sk),
        "Kurt_real": np.array(real_ku),   "Kurt_synth": np.array(syn_ku),
    }


# Evaluate the EMA-smoothed generator when EMA is enabled, otherwise the live one.
if use_ema:
    G_eval = G_ema
else:
    G_eval = generator

synthetic = sample_synthetic(
    G_eval,
    n_samples=5000,
    noise_dim=noise_dim,
    device=device,
)

_ = evaluate_gan(real_data, synthetic, alpha=0.05, fisher_excess=True)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import spectral_norm
from torch.distributions import StudentT


# Data

class CryptoLogReturnDataset(Dataset):
    """Float32 map-style dataset: ``ds[i]`` returns row ``i`` of the wrapped array.

    NOTE(review): this re-defines the class from the vanilla-GAN cell with the same behavior.
    """

    def __init__(self, data_array):
        tensor = torch.tensor(data_array, dtype=torch.float32)
        self.data = tensor

    def __len__(self):
        count = len(self.data)
        return count

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = CryptoLogReturnDataset(train_data)
val_dataset   = CryptoLogReturnDataset(val_data)

# drop_last on the training loader keeps every training batch full-size.
loader_kwargs = dict(batch_size=64)
train_loader = DataLoader(train_dataset, shuffle=True,  drop_last=True,  **loader_kwargs)
val_loader   = DataLoader(val_dataset,   shuffle=False, drop_last=False, **loader_kwargs)

seq_len = train_data.shape[1]       # e.g., 24
feature_dim = train_data.shape[2]   # e.g., 3
noise_dim = 128


In [None]:
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
from torch.distributions import StudentT
import os


# 1) Small, stable TCN blocks

class ResTCNBlock(nn.Module):
    """Residual 1D conv block with dilation; GroupNorm + GELU for small batches.

    Two kernel-3 dilated convs with ``padding=dilation`` keep the time length unchanged,
    so the output is ``GELU(x + net(x))`` with the same [B, C, T] shape as ``x``.
    ``channels`` must be divisible by ``min(8, channels)`` (GroupNorm requirement).
    """
    def __init__(self, channels: int, dilation: int = 1, dropout: float = 0.05):
        super().__init__()
        pad = dilation
        self.net = nn.Sequential(
            nn.Conv1d(channels, channels, 3, padding=pad, dilation=dilation),
            nn.GroupNorm(num_groups=min(8, channels), num_channels=channels),
            nn.GELU(),
            nn.Dropout(dropout),

            nn.Conv1d(channels, channels, 3, padding=pad, dilation=dilation),
            nn.GroupNorm(num_groups=min(8, channels), num_channels=channels),
        )
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(x + self.net(x))

class TailGANGenerator(nn.Module):
    """Latent vector -> [B, seq_len, feature_dim] path via a linear stem and dilated TCN blocks.

    Args:
        noise_dim: latent size.
        seq_len: number of generated time steps.
        feature_dim: number of assets per time step.
        hidden_channels: TCN width.
        num_blocks: number of residual blocks. Block ``i`` uses dilation ``2**i``.
        dropout: dropout inside each block.
    """

    def __init__(self, noise_dim=128, seq_len=24, feature_dim=3,
                 hidden_channels=128, num_blocks=4, dropout=0.05):
        super().__init__()
        self.seq_len = seq_len
        self.feature_dim = feature_dim
        self.hidden = hidden_channels

        self.fc = nn.Linear(noise_dim, hidden_channels * seq_len)
        # Dilations double per block (1, 2, 4, 8, ...). Computing them, rather than
        # slicing a fixed 4-element list, means num_blocks > 4 is honoured instead of
        # being silently capped at 4.
        dilations = [2 ** i for i in range(num_blocks)]
        self.blocks = nn.ModuleList([ResTCNBlock(hidden_channels, d, dropout) for d in dilations])
        self.to_out = nn.Conv1d(hidden_channels, feature_dim, kernel_size=1)
        self.out_scale = nn.Parameter(torch.tensor(1.0))  # learnable global output scale

        self._init()

    def _init(self):
        """Xavier-uniform weights and zero biases for every Conv1d/Linear layer."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def forward(self, z):
        B = z.size(0)
        x = self.fc(z).view(B, self.hidden, self.seq_len)   # [B, C, T]
        for blk in self.blocks:
            x = blk(x)
        x = self.to_out(x) * self.out_scale                 # [B, A, T]
        return x.permute(0, 2, 1)                           # [B, T, A]

class MinibatchStdDev(nn.Module):
    """Append one channel holding the batch-wide average std (a diversity cue for D).

    Input ``[B, C, T]`` -> output ``[B, C + 1, T]``. Batches with fewer than two
    samples are returned unchanged.
    """
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        batch, _, length = x.shape
        if batch < 2:
            return x
        centred = x - x.mean(dim=0, keepdim=True)
        per_feature_std = torch.sqrt(centred.pow(2).mean(dim=0) + self.eps)
        summary = per_feature_std.mean().view(1, 1, 1).expand(batch, 1, length)
        return torch.cat([x, summary], dim=1)

class TailGANDiscriminator(nn.Module):
    """Spectral-normalised 1D-conv discriminator for [B, T, A] paths; returns [B, 1] logits.

    Pipeline: strided conv -> conv -> minibatch-std channel -> strided conv ->
    global average pool -> linear. Each stride-2 conv roughly halves the time axis.

    Args:
        seq_len: accepted for API symmetry; not used by this module.
        feature_dim: number of assets (input channels).
        base_channels: width ``C``; the last conv outputs ``2*C`` channels.
        dropout: dropout after the middle and last convs.
    """

    def __init__(self, seq_len=24, feature_dim=3, base_channels=64, dropout=0.05):
        super().__init__()
        C = base_channels
        self.conv_in  = spectral_norm(nn.Conv1d(feature_dim, C,   3, padding=1, stride=2))
        self.conv_mid = spectral_norm(nn.Conv1d(C, C,             3, padding=1))
        self.mbstd    = MinibatchStdDev()
        # C+1 input channels: mbstd appends one statistics channel.
        self.conv_out = spectral_norm(nn.Conv1d(C+1, 2*C,         3, padding=1, stride=2))
        self.act  = nn.LeakyReLU(0.2)
        self.drop = nn.Dropout(dropout)
        self.gap  = nn.AdaptiveAvgPool1d(1)
        self.fc   = spectral_norm(nn.Linear(2*C, 1))
        self._init()

    def _init(self):
        # Kaiming (LeakyReLU slope 0.2) for convs, Xavier for the linear head, zero biases.
        # NOTE(review): under spectral_norm the trainable tensor is `weight_orig`; `m.weight`
        # here is the attribute spectral_norm set when wrapping. This in-place init
        # presumably reaches `weight_orig` through shared storage -- verify.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, a=0.2)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def forward(self, x):
        # Shape comments assume T=24.
        x = x.permute(0,2,1)                 # [B, T, A] -> [B, A, T]
        x = self.act(self.conv_in(x))        # [B, C, 12]
        x = self.drop(self.act(self.conv_mid(x)))  # [B, C, 12]
        x = self.mbstd(x)                    # [B, C+1, 12] (unchanged if B < 2)
        x = self.drop(self.act(self.conv_out(x)))  # [B, 2C, 6]
        x = self.gap(x).squeeze(-1)          # [B, 2C]
        return self.fc(x)                    # [B, 1] raw logits

In [None]:
def expected_shortfall(x, alpha=0.05, dim=0):
    """Differentiable empirical ES: mean of the ``k`` smallest values along ``dim``.

    ``k = max(1, floor(alpha * x.size(dim)))``, so at least one value is always averaged.
    """
    n_tail = int(alpha * x.size(dim))
    if n_tail < 1:
        n_tail = 1
    ordered = torch.sort(x, dim=dim).values
    return ordered.narrow(dim, 0, n_tail).mean(dim=dim)

def es_matching_loss_multi(real, fake, alphas=(0.01,0.05,0.10)):
    """Mean absolute ES gap between ``fake`` and ``real``, averaged over ``alphas`` and assets.

    Both inputs are [B, T, A]; days and time steps are pooled so ES is computed per
    asset over B*T values.
    """
    B, T, A = real.shape
    flat_real = real.reshape(B * T, A)
    flat_fake = fake.reshape(B * T, A)
    gaps = []
    for level in alphas:
        es_real = expected_shortfall(flat_real, level, dim=0)
        es_fake = expected_shortfall(flat_fake, level, dim=0)
        gaps.append((es_fake - es_real).abs().mean())
    return torch.stack(gaps).mean()

def quantile_pinball_loss(y_pred, y_true, q=0.05):
    """Mean pinball (quantile) loss at level ``q``: ``max(q*e, (q-1)*e)`` with ``e = y_true - y_pred``."""
    residual = y_true - y_pred
    return torch.maximum(q * residual, (q - 1) * residual).mean()

def pinball_multi(fake, real,
                  qs_lower=(0.01,0.05,0.10), w_lower=(0.6,0.3,0.1),
                  qs_upper=(0.90,),         w_upper=(1.0,)):
    """Weighted pinball losses between fake and real values at one random time step per sample.

    For each sample ``b`` a time index ``t_b`` is drawn uniformly from [0, T), and the
    [B, A] slices ``real[b, t_b]`` / ``fake[b, t_b]`` are compared.

    Returns:
        ``(low, up)``. ``low`` is the weighted sum over ``qs_lower``; ``up`` is the weighted
        sum over ``qs_upper``, evaluated on negated values at level ``1 - q``.
    """
    B, T, A = real.shape
    t_idx = torch.randint(0, T, (B,), device=real.device)
    rows = torch.arange(B)
    real_t = real[rows, t_idx, :]  # [B,A]
    fake_t = fake[rows, t_idx, :]  # [B,A]

    lower_terms = [w * quantile_pinball_loss(fake_t, real_t, q=q) for q, w in zip(qs_lower, w_lower)]
    upper_terms = [w * quantile_pinball_loss(-fake_t, -real_t, q=1 - q) for q, w in zip(qs_upper, w_upper)]
    return torch.stack(lower_terms).sum(), torch.stack(upper_terms).sum()

def _centered_moments(x):
    m = x.mean(dim=0); xc = x - m
    v = (xc**2).mean(dim=0) + 1e-12
    s = torch.sqrt(v)
    m3 = (xc**3).mean(dim=0)
    m4 = (xc**4).mean(dim=0)
    skew = m3 / (s**3)
    kurt_excess = m4 / (v**2) - 3.0
    return skew, kurt_excess

def _compress_kurt(k):
    # compress very large excess-kurtosis to stabilise optimisation
    return torch.sign(k) * torch.log1p(k.abs())

def moment_matching_losses(real, fake):
    """L1 gaps in per-asset skewness and log1p-compressed excess kurtosis.

    Inputs are [B, T, A]; days and time steps are pooled per asset.
    Returns ``(loss_skew, loss_kurt)``, each averaged over assets.
    """
    B, T, A = real.shape
    flat_real = real.reshape(B * T, A)
    flat_fake = fake.reshape(B * T, A)
    skew_real, kurt_real = _centered_moments(flat_real)
    skew_fake, kurt_fake = _centered_moments(flat_fake)
    skew_gap = (skew_fake - skew_real).abs().mean()
    kurt_gap = (_compress_kurt(kurt_fake) - _compress_kurt(kurt_real)).abs().mean()
    return skew_gap, kurt_gap

In [None]:
gen_kwargs = dict(noise_dim=noise_dim, seq_len=seq_len, feature_dim=feature_dim)
G = TailGANGenerator(**gen_kwargs).to(device)
D = TailGANDiscriminator(seq_len=seq_len, feature_dim=feature_dim).to(device)

criterion = nn.BCEWithLogitsLoss()

# Two time-scale update rule: D learns at twice G's learning rate.
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))

# Frozen EMA copy of G, initialised to G's current weights.
use_ema = True
ema_decay = 0.999
G_ema = TailGANGenerator(**gen_kwargs).to(device)
G_ema.load_state_dict(G.state_dict())
G_ema.requires_grad_(False)

@torch.no_grad()
def ema_update(target, source, decay=ema_decay):
    """Blend ``source`` into ``target`` as an exponential moving average, in place.

    Parameters: ``t = decay * t + (1 - decay) * s``. Buffers (e.g. normalisation running
    statistics) are copied verbatim so an EMA model in ``eval()`` mode normalises with
    current statistics. The default ``decay`` is bound to the global ``ema_decay`` when
    this ``def`` runs. This re-defines (shadows) any earlier ``ema_update``.
    """
    for p_t, p_s in zip(target.parameters(), source.parameters()):
        p_t.data.mul_(decay).add_(p_s.data, alpha=1-decay)
    for b_t, b_s in zip(target.buffers(), source.buffers()):
        b_t.copy_(b_s)

# Student-t latent (heavier tails)
student_t = StudentT(df=3.0)  # heavy-tailed latent prior (3 degrees of freedom)

def sample_noise(B, dim, device):
    """Draw a [B, dim] Student-t latent batch (sampled on the CPU) and move it to ``device``."""
    draws = student_t.sample(torch.Size([B, dim]))
    return draws.to(device)


# Generator loss weights for the auxiliary tail / moment terms.
LAMBDA_ES = 12     # expected-shortfall matching
W_PIN_L   = 0.75   # lower-tail pinball
W_PIN_U   = 0.10   # upper-tail pinball
W_SKEW    = 1.0    # skewness gap
W_KURT    = 0.20   # compressed-kurtosis gap

EPOCHS = 30

# Checkpointing and early stopping (patience counted in epochs).
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
patience, save_every = 15, 10
best_score, best_epoch, no_improve = float("inf"), -1, 0

for epoch in range(EPOCHS):
    G.train(); D.train()
    # Running sums of the tail/moment losses. The early-stopping score is the epoch
    # mean rather than the value from the last (noisy) mini-batch.
    sum_es = sum_skew = sum_kurt = 0.0
    n_batches = 0
    for real_batch in train_loader:
        real_batch = real_batch.to(device)  # [B,T,A]
        B = real_batch.size(0)

        # ---- D step ----
        z = sample_noise(B, noise_dim, device)
        with torch.no_grad():
            fake_detach = G(z)

        # Label smoothing: real targets in [0.9, 1.0], fake targets in [0.0, 0.1].
        real_labels = torch.empty(B,1,device=device).uniform_(0.9,1.0)
        fake_labels = torch.empty(B,1,device=device).uniform_(0.0,0.1)

        opt_D.zero_grad(set_to_none=True)
        d_real = D(real_batch)
        d_fake = D(fake_detach)
        d_loss = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)
        d_loss.backward(); opt_D.step()

        # ---- G step: adversarial loss + tail / moment matching terms ----
        z = sample_noise(B, noise_dim, device)
        fake_batch = G(z)

        opt_G.zero_grad(set_to_none=True)
        d_fake = D(fake_batch)
        gan_loss = criterion(d_fake, real_labels)

        es_loss = es_matching_loss_multi(real_batch, fake_batch, alphas=(0.01,0.05,0.10))
        q_low, q_up = pinball_multi(fake_batch, real_batch,
                                    qs_lower=(0.01,0.05,0.10), w_lower=(0.6,0.3,0.1),
                                    qs_upper=(0.90,),         w_upper=(1.0,))
        skew_loss, kurt_loss = moment_matching_losses(real_batch, fake_batch)

        g_loss = (gan_loss
                  + LAMBDA_ES * es_loss
                  + W_PIN_L   * q_low
                  + W_PIN_U   * q_up
                  + W_SKEW    * skew_loss
                  + W_KURT    * kurt_loss)

        g_loss.backward()
        nn.utils.clip_grad_norm_(G.parameters(), max_norm=5.0)
        opt_G.step()

        # NOTE(review): with decay 0.999 the EMA needs on the order of 1000 G steps to
        # move away from its initial weights -- confirm the run is long enough.
        if use_ema:
            ema_update(G_ema, G)

        sum_es += es_loss.item()
        sum_skew += skew_loss.item()
        sum_kurt += kurt_loss.item()
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError(
            "train_loader produced no batches (training set smaller than the batch "
            "size with drop_last=True?)"
        )

    # Early-stopping score: epoch-averaged training losses.
    # NOTE(review): a held-out (validation) score would be a less biased criterion.
    mean_es = sum_es / n_batches
    mean_skew = sum_skew / n_batches
    mean_kurt = sum_kurt / n_batches
    current_score = mean_es + 0.5*mean_skew + 0.5*mean_kurt

    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"D: {d_loss.item():.4f} | G: {g_loss.item():.4f} | "
          f"ES: {es_loss.item():.4f} | Qlow: {q_low.item():.4f} | Qup: {q_up.item():.4f} | "
          f"Sk: {skew_loss.item():.4f} | Ku: {kurt_loss.item():.4f} (last batch) | "
          f"score(epoch avg): {current_score:.4f}")

    if current_score < best_score - 1e-4:
        best_score, best_epoch, no_improve = current_score, epoch+1, 0
        torch.save((G_ema if use_ema else G).state_dict(), f"{ckpt_dir}/G_best.pt")
        torch.save(D.state_dict(), f"{ckpt_dir}/D_best.pt")
        print(f"✅ New BEST @ epoch {best_epoch} | score={best_score:.4f}")
    else:
        no_improve += 1

    if (epoch + 1) % save_every == 0:
        torch.save((G_ema if use_ema else G).state_dict(), f"{ckpt_dir}/G_epoch{epoch+1}.pt")

    if no_improve >= patience:
        print(f"⏹ Early stopping at epoch {epoch+1}. Best was {best_epoch} (score={best_score:.4f}).")
        break

print(f"Done. Best epoch: {best_epoch} with score={best_score:.4f}")

In [None]:
import numpy as np
import torch
from scipy.stats import skew as sp_skew, kurtosis as sp_kurt

@torch.no_grad()
def _sample_noise(n_samples, noise_dim, device, latent="student_t", df=3.0):
    if latent.lower() == "student_t":
        from torch.distributions import StudentT
        return StudentT(df=df).sample((n_samples, noise_dim)).to(device)
    return torch.randn(n_samples, noise_dim, device=device)

def _to_numpy_nta(x):

    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    x = np.asarray(x)
    assert x.ndim == 3, f"Expected [N,T,A], got shape {x.shape}"
    return x

@torch.no_grad()
def evaluate_gan_tail_metrics(
    generator,
    real_data,
    noise_dim=128,
    n_samples=4096,
    device="cpu",
    alpha=0.05,
    latent="student_t",
    df=3.0,
    fisher_excess=True,
    asset_names=None,
    print_table=True
):
    """
    Sample synthetic paths from `generator` and compare per-asset VaR/ES/Skew/Kurt
    against `real_data`. Also reports coverage: P[synth <= VaR_real].

    Values are pooled over days and time steps per asset. VaR is the ``alpha``-quantile,
    ES the mean of values at or below it, skew/kurtosis are bias-corrected. For a
    well-calibrated generator, coverage should be close to ``alpha``.

    Args:
        generator: module mapping [n_samples, noise_dim] latents to [n_samples, T, A].
        real_data: array or tensor of shape [N, T, A].
        noise_dim, n_samples, latent, df: latent sampling settings ("student_t" draws
            Student-t noise with ``df`` degrees of freedom; anything else is Gaussian).
        device: device on which latents are created.
        alpha: tail probability.
        fisher_excess: report excess kurtosis (normal == 0) when True.
        asset_names: optional column labels, one per asset.
        print_table: print a formatted comparison table.

    Returns:
        dict of per-asset arrays, plus ``"synthetic"`` (the sampled array).

    Raises:
        ValueError: if ``asset_names`` or the generator output disagree with the
            number of assets in ``real_data``.
    """
    # --- data to numpy ---
    real_np = _to_numpy_nta(real_data)
    A = real_np.shape[2]
    if asset_names is not None and len(asset_names) != A:
        raise ValueError(f"asset_names has {len(asset_names)} entries but real_data has {A} assets")

    # --- sample synthetic in eval mode, then restore the caller's mode so evaluating
    #     mid-training does not leave dropout switched off ---
    was_training = generator.training
    generator.eval()
    try:
        z = _sample_noise(n_samples, noise_dim, device, latent=latent, df=df)
        synth_np = generator(z).detach().cpu().numpy()
    finally:
        generator.train(was_training)

    # reshape(-1, A) would silently scramble columns if the asset axis differed but
    # the total size happened to be divisible by A.
    if synth_np.ndim != 3 or synth_np.shape[2] != A:
        raise ValueError(f"generator output shape {synth_np.shape} does not match [*, *, {A}]")

    real_flat  = real_np.reshape(-1, A)
    synth_flat = synth_np.reshape(-1, A)

    # --- compute stats ---
    real_var  = np.array([np.quantile(real_flat[:, i],  alpha) for i in range(A)])
    synth_var = np.array([np.quantile(synth_flat[:, i], alpha) for i in range(A)])

    real_es = np.array([
        real_flat[:, i][real_flat[:, i] <= real_var[i]].mean()  for i in range(A)
    ])
    synth_es = np.array([
        synth_flat[:, i][synth_flat[:, i] <= synth_var[i]].mean() for i in range(A)
    ])

    real_sk = np.array([sp_skew(real_flat[:, i],  bias=False) for i in range(A)])
    syn_sk  = np.array([sp_skew(synth_flat[:, i], bias=False) for i in range(A)])

    real_ku = np.array([sp_kurt(real_flat[:, i],  bias=False, fisher=fisher_excess) for i in range(A)])
    syn_ku  = np.array([sp_kurt(synth_flat[:, i], bias=False, fisher=fisher_excess) for i in range(A)])

    # Fraction of synthetic values breaching the REAL VaR (target: alpha).
    coverage = np.array([
        (synth_flat[:, i] <= real_var[i]).mean() for i in range(A)
    ])

    if print_table:
        cols = [f"Asset {i+1:>2}" for i in range(A)] if asset_names is None else [f"{n}" for n in asset_names]
        print(f"{'Metric':<12} | " + " | ".join([f"{c:>12}" for c in cols]))
        print("-" * (14 + A*15))

        def row(label, arr): return f"{label:<12} | " + " | ".join([f"{v:12.5f}" for v in arr])

        print(row("VaR_real",  real_var))
        print(row("VaR_synth", synth_var))
        print(row("ES_real",   real_es))
        print(row("ES_synth",  synth_es))
        print(row("Skew_real", real_sk))
        print(row("Skew_synth",syn_sk))
        print(row("Kurt_real", real_ku))
        print(row("Kurt_synth",syn_ku))
        print(row("Coverage@", np.full(A, alpha)))
        print(row("Cov_synth", coverage))

    return {
        "VaR_real":  real_var,   "VaR_synth":  synth_var,
        "ES_real":   real_es,    "ES_synth":   synth_es,
        "Skew_real": real_sk,    "Skew_synth": syn_sk,
        "Kurt_real": real_ku,    "Kurt_synth": syn_ku,
        "Coverage_realVaR": coverage,
        "synthetic": synth_np
    }



# Evaluate the EMA generator when one exists, otherwise the raw generator.
# NOTE(review): this is the final in-memory model, not the G_best.pt checkpoint.
eval_generator = G_ema if 'G_ema' in globals() else generator
n_assets = real_data.shape[2]

metrics = evaluate_gan_tail_metrics(
    generator=eval_generator,
    real_data=real_data,
    noise_dim=noise_dim,
    n_samples=4096,
    device=device,
    alpha=0.05,
    latent="student_t",   # match the Student-t latent used in training
    df=3.0,
    fisher_excess=True,
    asset_names=[f"Asset {k + 1}" for k in range(n_assets)],
    print_table=True,
)


In [None]:
import numpy as np, torch, matplotlib.pyplot as plt

# ---------- auto-pick REAL & SYNTHETIC ----------
def _as_np(x): return x if isinstance(x, np.ndarray) else x.detach().cpu().numpy()

g = globals()

# ---- real sample: first candidate that exists, in priority order ----
for key in ["real_data", "val_data", "train_data", "real_returns"]:
    if key in g:
        real = _as_np(g[key]); break
else:
    raise NameError("Define one of real_data / val_data / train_data / real_returns (shape [N,T,A]).")

# ---- synthetic sample ----
# Prefer the sample returned by the latest evaluate_gan_tail_metrics call
# (`metrics["synthetic"]`): the bare `synthetic` name may hold an older model's sample
# and previously took precedence. Only 3-D [N,T,A] candidates are accepted, so flattened
# arrays such as `synthetic_returns` are skipped instead of failing the shape check.
def _is_nta(v):
    return isinstance(v, (np.ndarray, torch.Tensor)) and v.ndim == 3

_synth = None
_metrics = g.get("metrics")
if isinstance(_metrics, dict) and "synthetic" in _metrics:
    _synth = _as_np(_metrics["synthetic"])
else:
    for key in ["synthetic", "synthetic_returns", "synthetic_batch", "vanilla_synth", "cgan_synth"]:
        if _is_nta(g.get(key)):
            _synth = _as_np(g[key]); break

if _synth is None:
    # Fallback: sample from the first generator found. Underscore-prefixed names avoid
    # clobbering the notebook-wide `G`, `noise_dim`, `device`, `N`, `T`, `A`.
    _gen = next((g[k] for k in ("generator_ema", "G_ema", "generator", "G") if g.get(k) is not None), None)
    if _gen is None:
        raise NameError("No synthetic data found and no generator in memory.")
    _noise_dim = g.get("noise_dim", 128)
    _device = g.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    _gen = _gen.to(_device).eval()
    # NOTE(review): Gaussian latent; a generator trained on heavier-tailed (e.g.
    # Student-t) noise would be evaluated off its training distribution here.
    with torch.no_grad():
        _z = torch.randn(real.shape[0], _noise_dim, device=_device)
        _synth = _gen(_z).cpu().numpy()

synthetic = _synth

assert real.ndim == synthetic.ndim == 3 and real.shape[2] == synthetic.shape[2], "Shapes must be [N,T,A] with same A."

asset = 0

r = real[:, :, asset].reshape(-1)
s = synthetic[:, :, asset].reshape(-1)

# 1) Lower-tail Q–Q (0–10%): 40 evenly spaced probabilities in [0, 0.10].
qs = np.linspace(0.0, 0.10, 40)
rq, sq = np.quantile(r, qs), np.quantile(s, qs)

fig, ax = plt.subplots(figsize=(5.2, 5.2))
ax.plot(rq, sq, lw=2, label="Synthetic")
ax.plot(rq, rq, "--", lw=2, label="Real (y=x)")
ax.grid(alpha=0.4, linestyle="--")
ax.set_title(f"Lower-tail Q–Q (0–10%) — Asset {asset+1}")
ax.set_xlabel("Real quantiles")
ax.set_ylabel("Synthetic quantiles")
ax.legend()
fig.tight_layout()
plt.show()

# 2) ES(α) curve (1–10%)
def _es(x, a):
    q = np.quantile(x, a)
    return x[x <= q].mean()

# Sweep the tail level from 1% to 10% and compare the real vs. synthetic ES curves.
alphas = np.linspace(0.01, 0.10, 30)
es_real = np.array([_es(r, level) for level in alphas])
es_syn = np.array([_es(s, level) for level in alphas])

alpha_pct = alphas * 100
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(alpha_pct, es_real, lw=2, label="Real")
ax.plot(alpha_pct, es_syn, lw=2, label="Synthetic")
ax.set_title(f"ES(α) Curve — Asset {asset+1}")
ax.set_xlabel("Alpha (%)")
ax.set_ylabel("ES(α)")
ax.legend()
fig.tight_layout()
plt.show()


In [None]:

# One-hot label per training sequence, cycling through 3 classes by position (row % 3).
# NOTE(review): each train_data item already holds every asset ([N,T,A]); confirm this
# cyclic labelling is what is intended.
labels = np.array([[int(col == row % 3) for col in range(3)] for row in range(len(train_data))])


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import spectral_norm


# Config / hyperparameters

DEVICE        = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED          = 42            # fixed seed so Restart & Run All reproduces the results
BATCH_SIZE    = 64
EPOCHS        = 100
NOISE_DIM     = 32            # size of the latent vector z fed to the generator
ALPHA         = 0.05          # tail level α used by the expected-shortfall loss
LAMBDA_TAIL   = 10.0          # weight of the ES-matching term in the generator loss
PINBALL_W     = 0.5           # weight of the pinball (quantile) term in the generator loss
WINDOW_W      = 12            # leading time steps used to build each condition vector
USE_EMA       = True          # keep an exponential-moving-average copy of the generator
EMA_DECAY     = 0.999
LR_D, LR_G    = 2e-4, 1e-4    # Adam learning rates: discriminator, generator

# Seed the RNGs the data loaders, noise sampling and weight init draw from.
np.random.seed(SEED)
torch.manual_seed(SEED)



# Training tensor layout: N sequences x T time steps x A assets.
N, T, A = train_data.shape
assert WINDOW_W <= T, "Window length WINDOW_W must be <= sequence length T."



def _raw_condition(data, W):
    """Un-normalized condition features: flattened first-W-step window plus its per-asset std."""
    cond_window = data[:, :W, :]
    cond_vol = cond_window.std(axis=1)
    return np.concatenate([cond_window.reshape(data.shape[0], -1), cond_vol], axis=1)


def build_condition(data, W, ref=None):
    """Build z-scored condition vectors from the first ``W`` time steps of each sequence.

    Parameters
    ----------
    data : ndarray, shape (N, T, A)
    W : int
        Number of leading time steps used as the conditioning window.
    ref : ndarray, optional
        Array whose condition features supply the normalization mean/std.
        Defaults to ``data`` itself. Pass the training array when building
        validation/test conditions, so every split is scaled with the same
        statistics the generator was trained on.

    Returns
    -------
    ndarray, shape (N, W*A + A), float32
    """
    cond = _raw_condition(data, W)
    ref_cond = cond if ref is None else _raw_condition(ref, W)
    mean = ref_cond.mean(axis=0, keepdims=True)
    std = ref_cond.std(axis=0, keepdims=True) + 1e-8  # epsilon guards constant features
    return ((cond - mean) / std).astype(np.float32)


train_cond = build_condition(train_data, WINDOW_W)
# Scale validation conditions with the *training* statistics (consistent inputs, no val-set leakage).
val_cond   = build_condition(val_data, WINDOW_W, ref=train_data)

COND_DIM = train_cond.shape[1]
INPUT_DIM = T * A

In [None]:
# Dataset & DataLoader
# -----------------------------
class ConditionalCryptoDataset(Dataset):
    """Pairs each sequence with its condition vector for DataLoader batching.

    Parameters
    ----------
    data_array : array-like
        Sequences, indexed along the first axis; stored as float32.
    condition_array : array-like
        One condition row per sequence; stored as float32.

    Raises
    ------
    ValueError
        If the two arrays do not have the same number of rows.
    """

    def __init__(self, data_array, condition_array):
        self.data = torch.tensor(data_array, dtype=torch.float32)
        self.cond = torch.tensor(condition_array, dtype=torch.float32)
        # Reject misaligned inputs up-front instead of failing with an IndexError mid-epoch.
        if len(self.data) != len(self.cond):
            raise ValueError(
                f"data_array has {len(self.data)} rows but condition_array has {len(self.cond)}"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.cond[idx]

# One dataset per split; each item is a (sequence, condition) pair.
train_ds, val_ds = (
    ConditionalCryptoDataset(train_data, train_cond),
    ConditionalCryptoDataset(val_data, val_cond),
)

# Training: shuffled, ragged last batch dropped. Validation: ordered, every sample kept.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [None]:
# Minibatch-StdDev trick (FC)
# -----------------------------
class MinibatchStdDev(nn.Module):
    """Append one feature column holding the mean across-batch standard deviation.

    Assumes ``x`` is ``(B, F)`` (as produced by the discriminator's feature
    stack). The output is always ``(B, F + 1)``, so the layer that follows can
    rely on a fixed width.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps  # added to the variance before sqrt for numerical stability

    def forward(self, x):
        B = x.size(0)
        if B < 2:
            # A spread cannot be estimated from a single sample, but the extra column
            # must still exist or the following Linear(F + 1, ...) fails. Use zero.
            return torch.cat([x, x.new_zeros(B, 1)], dim=1)
        mean = x.mean(dim=0, keepdim=True)
        var = (x - mean).pow(2).mean(dim=0, keepdim=True)  # biased per-feature variance
        std = torch.sqrt(var + self.eps)
        std_feat = std.mean().expand(B, 1)  # same scalar broadcast to every sample
        return torch.cat([x, std_feat], dim=1)

# -----------------------------
# Conditional Generator / Discriminator
# -----------------------------
class ConditionalGenerator(nn.Module):
    """MLP generator mapping (noise z, condition c) to a tensor of shape (batch, T, A).

    Parameters
    ----------
    noise_dim, cond_dim : int
        Sizes of the noise vector ``z`` and the condition vector ``c``.
    output_dim : int
        Flat output size; must equal ``seq_len * n_assets``.
    seq_len, n_assets : int, optional
        Output sequence length and asset count. When omitted they fall back to
        the notebook-level globals ``T`` and ``A``, so existing
        ``ConditionalGenerator(noise_dim, cond_dim, output_dim)`` calls keep working.

    Raises
    ------
    ValueError
        If ``seq_len * n_assets != output_dim``.
    """

    def __init__(self, noise_dim, cond_dim, output_dim, seq_len=None, n_assets=None):
        super().__init__()
        in_dim = noise_dim + cond_dim
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 512),    nn.LeakyReLU(0.2),
            nn.Linear(512, 256),    nn.LeakyReLU(0.2),
            nn.Linear(256, output_dim)
        )
        # The globals T / A are only read when the explicit arguments are omitted.
        self.T = T if seq_len is None else seq_len
        self.A = A if n_assets is None else n_assets
        if self.T * self.A != output_dim:
            raise ValueError(
                f"output_dim={output_dim} does not match seq_len*n_assets={self.T}*{self.A}"
            )

    def forward(self, z, c):
        x = torch.cat([z, c], dim=1)
        out = self.net(x)
        return out.view(z.size(0), self.T, self.A)

class ConditionalDiscriminator(nn.Module):
    """Spectral-normalized MLP that scores each (sequence, condition) pair with one logit."""

    def __init__(self, input_dim, cond_dim):
        super().__init__()
        hidden = 64
        self.feat = nn.Sequential(
            spectral_norm(nn.Linear(input_dim + cond_dim, 128)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(128, hidden)),
            nn.LeakyReLU(0.2),
        )
        self.mbstd = MinibatchStdDev()
        # +1 input for the batch-spread column that MinibatchStdDev appends.
        self.out = spectral_norm(nn.Linear(hidden + 1, 1))

    def forward(self, x, c):
        flat = x.view(x.size(0), -1)
        features = self.mbstd(self.feat(torch.cat([flat, c], dim=1)))
        return self.out(features)

# Generator first, then discriminator, both placed on the training device.
G = ConditionalGenerator(noise_dim=NOISE_DIM, cond_dim=COND_DIM, output_dim=INPUT_DIM).to(DEVICE)
D = ConditionalDiscriminator(input_dim=INPUT_DIM, cond_dim=COND_DIM).to(DEVICE)

In [None]:

def expected_shortfall(x, alpha=ALPHA, dim=0):
    """Differentiable empirical expected shortfall of ``x`` along ``dim``.

    Averages the ``k = max(1, int(alpha * n))`` smallest entries along ``dim``
    (``n = x.size(dim)``); that dimension is reduced away.
    """
    n_obs = x.size(dim)
    k = max(1, int(alpha * n_obs))
    ordered = torch.sort(x, dim=dim).values
    return ordered.narrow(dim, 0, k).mean(dim=dim)

def es_matching_loss(real, fake, alpha=ALPHA):
    """Mean absolute gap between the per-asset expected shortfall of ``fake`` and ``real``.

    ``real`` is unpacked as ``(B, T, A)``. Batch and time are pooled for both
    inputs, and the ES is taken per asset (last axis).
    """
    n_batch, n_time, n_assets = real.shape
    pooled_real = real.reshape(n_batch * n_time, n_assets)
    pooled_fake = fake.reshape(n_batch * n_time, n_assets)
    es_real = expected_shortfall(pooled_real, alpha=alpha, dim=0)
    es_fake = expected_shortfall(pooled_fake, alpha=alpha, dim=0)
    return (es_fake - es_real).abs().mean()

def quantile_pinball_loss(y_pred, y_true, q=0.05):
    """Mean pinball (quantile) loss of ``y_pred`` as a q-quantile estimate of ``y_true``."""
    residual = y_true - y_pred
    under = q * residual
    over = (q - 1) * residual
    return torch.maximum(under, over).mean()


# Optimizers, loss, EMA

criterion = nn.BCEWithLogitsLoss()  # D returns raw logits
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G, betas=(0.5, 0.999))


@torch.no_grad()
def ema_update(target, source, decay=EMA_DECAY):
    """In-place EMA step per parameter: target <- decay * target + (1 - decay) * source."""
    for p_t, p_s in zip(target.parameters(), source.parameters()):
        p_t.mul_(decay).add_(p_s, alpha=1 - decay)


if USE_EMA:
    G_ema = ConditionalGenerator(NOISE_DIM, COND_DIM, INPUT_DIM).to(DEVICE)
    G_ema.load_state_dict(G.state_dict())
    # G_ema is in no optimizer; it is only written by ema_update, so it needs no gradients.
    G_ema.requires_grad_(False)
else:
    # Clear it explicitly so later "G_ema if ... is not None else G" checks cannot pick
    # up a stale EMA generator left in the kernel by an earlier run.
    G_ema = None


In [None]:
for epoch in range(EPOCHS):
    G.train(); D.train()
    for real_batch, cond in train_loader:
        real_batch = real_batch.to(DEVICE)
        cond       = cond.to(DEVICE)
        B = real_batch.size(0)

        # ======== Train D ========
        z = torch.randn(B, NOISE_DIM, device=DEVICE)
        with torch.no_grad():
            fake_batch = G(z, cond)  # no generator graph needed for the D step

        # label smoothing: noisy targets, real in [0.9, 1.0], fake in [0.0, 0.1]
        real_labels = torch.empty(B, 1, device=DEVICE).uniform_(0.9, 1.0)
        fake_labels = torch.empty(B, 1, device=DEVICE).uniform_(0.0, 0.1)

        opt_D.zero_grad(set_to_none=True)
        d_real = D(real_batch, cond)
        d_fake = D(fake_batch, cond)
        d_loss = criterion(d_real, real_labels) + criterion(d_fake, fake_labels)
        d_loss.backward()
        opt_D.step()

        # ======== Train G (GAN + Tail) ========
        z = torch.randn(B, NOISE_DIM, device=DEVICE)
        fake_batch = G(z, cond)

        opt_G.zero_grad(set_to_none=True)
        d_fake = D(fake_batch, cond)
        gan_loss = criterion(d_fake, real_labels)  # G wants its fakes scored as real

        es_loss = es_matching_loss(real_batch, fake_batch, alpha=ALPHA)

        # Quantile (VaR) matching: score the per-asset ALPHA-quantile of the fake batch
        # against every real observation with the pinball loss. Its minimiser is the real
        # ALPHA-quantile, so only the fake tail level is pulled into place. Pairing each
        # fake value with an independent real value would instead drag every generated
        # return toward the real quantile and collapse the distribution.
        n_assets = real_batch.size(-1)
        real_flat = real_batch.reshape(-1, n_assets)
        fake_flat = fake_batch.reshape(-1, n_assets)
        fake_var = torch.quantile(fake_flat, ALPHA, dim=0, keepdim=True)  # (1, A)
        q05_loss = quantile_pinball_loss(fake_var, real_flat, q=ALPHA)

        g_loss = gan_loss + LAMBDA_TAIL * es_loss + PINBALL_W * q05_loss
        g_loss.backward()
        nn.utils.clip_grad_norm_(G.parameters(), max_norm=5.0)
        opt_G.step()

        if USE_EMA:
            ema_update(G_ema, G, EMA_DECAY)

    print(f"Epoch {epoch+1}/{EPOCHS} | D: {d_loss.item():.4f} | G: {g_loss.item():.4f} | ES: {es_loss.item():.4f}")


In [None]:
import numpy as np, torch
from scipy.stats import skew, kurtosis

def _print_metric_rows(label, real_vals, synth_vals):
    """Print one metric as two aligned rows: real values beside the label, synthetic below."""
    print(f"{label:<12}| " + " | ".join(f"{v:10.5f}" for v in real_vals))
    print(f"{'':<12}| " + " | ".join(f"{v:10.5f}" for v in synth_vals))


@torch.no_grad()
def evaluate_cgan_table(
    generator,
    val_loader,
    noise_dim,
    device,
    alpha=0.05,
    fisher_excess=True,
    latent="gaussian",
    df=3.0
):
    """Compare per-asset tail/shape statistics of real vs. cGAN-generated sequences.

    Collects every ``(x, cond)`` batch of ``val_loader``, generates one synthetic
    sequence per condition row, pools batch and time per asset, and prints
    VaR, ES, skewness and kurtosis. Each metric gets two rows: real on top,
    synthetic below.

    Parameters
    ----------
    generator : nn.Module
        Called as ``generator(z, c)``; moved to ``device`` and switched to eval mode.
    val_loader : iterable of (x, cond)
        ``x`` batches are concatenated into a 3-D ``(N, T, A)`` array.
    noise_dim : int
        Size of the latent vector ``z``.
    device : torch.device or str
    alpha : float
        Tail level for VaR / ES; also shown in the table labels.
    fisher_excess : bool
        Forwarded to ``scipy.stats.kurtosis(fisher=...)``.
    latent : str
        "student_t" (case-insensitive) samples z from a Student-t; anything else uses N(0, 1).
    df : float
        Degrees of freedom for the Student-t latent.

    Returns
    -------
    dict
        ``"synthetic"`` (generated array) plus per-asset arrays
        ``VaR_real/VaR_synth``, ``ES_real/ES_synth``, ``Skew_real/Skew_synth``,
        ``Kurt_real/Kurt_synth``.
    """
    # ---- gather real + cond ----
    Xs, Cs = [], []
    for x, c in val_loader:
        Xs.append(x.cpu().numpy())
        Cs.append(c.cpu().numpy())
    real_np = np.concatenate(Xs, 0)
    cond_np = np.concatenate(Cs, 0)
    N, T, A = real_np.shape

    # ---- generate synthetic with the same conditions ----
    generator = generator.to(device).eval()
    use_student_t = latent.lower() == "student_t"
    if use_student_t:
        from torch.distributions import StudentT
        t_dist = StudentT(df=df)  # built once, sampled per chunk
    synth_chunks = []
    bs = 512
    for i in range(0, N, bs):
        c = torch.tensor(cond_np[i:i+bs], dtype=torch.float32, device=device)
        if use_student_t:
            z = t_dist.sample((c.size(0), noise_dim)).to(device)
        else:
            z = torch.randn(c.size(0), noise_dim, device=device)
        synth_chunks.append(generator(z, c).cpu().numpy())
    synth_np = np.concatenate(synth_chunks, 0)

    # ---- flatten over N*T for per-asset stats ----
    real_flat = real_np.reshape(-1, A)
    synth_flat = synth_np.reshape(-1, A)

    real_var = [np.quantile(real_flat[:, i], alpha) for i in range(A)]
    synth_var = [np.quantile(synth_flat[:, i], alpha) for i in range(A)]
    real_es = [real_flat[:, i][real_flat[:, i] <= real_var[i]].mean() for i in range(A)]
    synth_es = [synth_flat[:, i][synth_flat[:, i] <= synth_var[i]].mean() for i in range(A)]
    real_sk = [skew(real_flat[:, i], bias=False) for i in range(A)]
    syn_sk = [skew(synth_flat[:, i], bias=False) for i in range(A)]
    real_ku = [kurtosis(real_flat[:, i], bias=False, fisher=fisher_excess) for i in range(A)]
    syn_ku = [kurtosis(synth_flat[:, i], bias=False, fisher=fisher_excess) for i in range(A)]

    # ---- print table (labels reflect the actual alpha; columns are 10 wide) ----
    pct = format(alpha * 100, "g")
    col_w = 10
    print(f"{'Metric':<12}| " + " | ".join(f"{'Asset ' + str(i + 1):>{col_w}}" for i in range(A)))
    print("-" * (14 + A * col_w + (A - 1) * 3))
    _print_metric_rows(f"VaR ({pct}%)", real_var, synth_var)
    _print_metric_rows(f"ES ({pct}%)", real_es, synth_es)
    _print_metric_rows("Skewness", real_sk, syn_sk)
    _print_metric_rows("Kurtosis", real_ku, syn_ku)

    return {
        "synthetic": synth_np,
        "VaR_real": np.array(real_var),   "VaR_synth": np.array(synth_var),
        "ES_real":  np.array(real_es),    "ES_synth":  np.array(synth_es),
        "Skew_real": np.array(real_sk),   "Skew_synth": np.array(syn_sk),
        "Kurt_real": np.array(real_ku),   "Kurt_synth": np.array(syn_ku),
    }


# Prefer the EMA generator when one exists and is set; otherwise use the raw generator.
G_eval = G if globals().get("G_ema") is None else G_ema

metrics_cgan = evaluate_cgan_table(
    generator=G_eval,
    val_loader=val_loader,
    noise_dim=NOISE_DIM,
    device=DEVICE,
    alpha=0.05,
    fisher_excess=True,
    latent="gaussian",
)


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

plt.rcParams["figure.dpi"] = 140  # higher-resolution inline figures

# --- helpers ---
@torch.no_grad()
def collect_real_and_cond(loader):
    """Concatenate every batch of ``loader`` into ``(real, cond)`` NumPy arrays.

    Accepted batch forms:
      * ``(x, c)`` tuple/list: used as-is;
      * 1-element ``(x,)`` tuple/list, or a bare tensor ``x``: paired with an
        all-zero ``(B, 1)`` placeholder condition.

    Raises
    ------
    ValueError
        For tuple/list batches with any other number of items.
    """
    Xs, Cs = [], []
    for batch in loader:
        if isinstance(batch, (tuple, list)):
            if len(batch) == 2:
                x, c = batch
            elif len(batch) == 1:
                x = batch[0]
                c = torch.zeros(x.size(0), 1)
            else:
                raise ValueError(f"Expected batches of (x, cond) or (x,), got {len(batch)} items.")
        else:
            x, c = batch, torch.zeros(batch.size(0), 1)
        Xs.append(x.cpu().numpy())
        Cs.append(c.cpu().numpy())
    return np.concatenate(Xs, 0), np.concatenate(Cs, 0)

@torch.no_grad()
def sample_cgan(G, cond_np, noise_dim, device):
    """Draw one Gaussian-noise sample from ``G(z, c)`` per condition row, in chunks of 512.

    ``G`` is switched to eval mode. Returns the concatenated NumPy outputs.
    """
    G.eval()
    chunk = 512
    pieces = []
    for start in range(0, cond_np.shape[0], chunk):
        c = torch.tensor(cond_np[start:start + chunk], dtype=torch.float32, device=device)
        z = torch.randn(c.size(0), noise_dim, device=device)
        pieces.append(G(z, c).cpu().numpy())
    return np.concatenate(pieces, 0)

def momentum_pnl_from_returns(rets):
    """Per-sequence PnL of a one-step sign-momentum strategy.

    The position at step t is +1 if the return at t-1 was strictly positive,
    otherwise -1. PnL is the sum of position * return over steps 1..T-1.

    rets: [N, T, A] log-returns  ->  returns [N, A]
    """
    prev_up = rets[:, :-1, :] > 0
    position = np.where(prev_up, 1, -1)
    return (position * rets[:, 1:, :]).sum(axis=1)

# --- 1) gather real & cond from validation ---
real_np, cond_np = collect_real_and_cond(val_loader)
N, T, A = real_np.shape

# --- 2) generate CGAN synthetic with the same conditions (EMA generator when available) ---
G_eval = G if globals().get("G_ema") is None else G_ema
synthetic_np = sample_cgan(G_eval, cond_np, NOISE_DIM, DEVICE)

# --- 3) compute PnL per sequence ---
real_pnl = momentum_pnl_from_returns(real_np)
synth_pnl = momentum_pnl_from_returns(synthetic_np)

# --- 4) simple PnL hist overlays: one figure per asset, shared bin edges ---
for a in range(A):
    rp = real_pnl[:, a]
    sp = synth_pnl[:, a]
    mn = float(min(rp.min(), sp.min()))
    mx = float(max(rp.max(), sp.max()))
    bins = np.linspace(mn, mx, 80)

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(rp, bins=bins, alpha=0.6, density=True, label="Real PnL")
    ax.hist(sp, bins=bins, alpha=0.6, density=True, label="CGAN PnL")
    ax.set_title(f"PnL Distribution — Asset {a+1}")
    ax.set_xlabel("Strategy PnL")
    ax.set_ylabel("Density")
    ax.legend()
    fig.tight_layout()
    plt.show()
