In [2]:
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def pinball_loss(predictions, targets, quantile):
    diff = targets - predictions
    loss = torch.maximum(quantile * diff, (quantile - 1) * diff)
    return loss.mean()

def compute_quantile(values, quantile):
    return np.quantile(values, quantile)

def train_and_calibrate(dataset, model, optimizer, quantiles, epochs, split_ratio=0.8):
    data_size = len(dataset)
    split_idx = int(split_ratio * data_size)
    train_set, calib_set = torch.utils.data.random_split(dataset, [split_idx, data_size - split_idx])
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    calib_loader = DataLoader(calib_set, batch_size=32, shuffle=False)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = sum(pinball_loss(outputs[:, i], targets, q) for i, q in enumerate(quantiles))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss:.4f}")

    model.eval()
    residuals = []
    with torch.no_grad():
        for inputs, targets in calib_loader:
            predictions = model(inputs).mean(dim=1)
            residuals.extend(torch.abs(targets - predictions).numpy())
    
    residuals = np.array(residuals)
    alpha = 0.1
    quantile_threshold = compute_quantile(residuals, 1 - alpha)
    
    return model, quantile_threshold

def plot_results(real_outputs, predictions, quantile_threshold):
    lower_bound = predictions - quantile_threshold
    upper_bound = predictions + quantile_threshold
    x = np.arange(len(real_outputs))
    
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=real_outputs, mode='lines', name='Real Outputs', line=dict(color='blue')))
    fig.add_trace(go.Scatter(x=x, y=lower_bound, mode='lines', name='Lower Bound', line=dict(color='red', dash='dash')))
    fig.add_trace(go.Scatter(x=x, y=upper_bound, mode='lines', name='Upper Bound', line=dict(color='green', dash='dash')))
    fig.add_trace(go.Scatter(x=np.concatenate([x, x[::-1]]),
                             y=np.concatenate([upper_bound, lower_bound[::-1]]),
                             fill='toself', fillcolor='rgba(128, 128, 128, 0.2)',
                             line=dict(color='rgba(255,255,255,0)'),
                             name='Prediction Interval'))
    fig.update_layout(title='Conformalized Quantile Regression Results',
                      xaxis_title='Sample Index',
                      yaxis_title='Value',
                      legend=dict(x=0, y=1))
    fig.show()

if __name__ == "__main__":
    np.random.seed(42)
    torch.manual_seed(42)
    num_samples = 100
    inputs = np.linspace(0, 10, num_samples).reshape(-1, 1)
    outputs = np.sin(inputs) + 0.1 * np.random.normal(size=inputs.shape)
    inputs_tensor = torch.tensor(inputs, dtype=torch.float32)
    outputs_tensor = torch.tensor(outputs, dtype=torch.float32)
    dataset = TensorDataset(inputs_tensor, outputs_tensor)
    
    class QuantileModel(nn.Module):
        def __init__(self, input_dim, quantiles):
            super(QuantileModel, self).__init__()
            self.network = nn.Sequential(
                nn.Linear(input_dim, 64),
                nn.Tanh(),
                nn.Linear(64, len(quantiles))
            )
        def forward(self, x):
            return self.network(x)

    quantiles = [0.1, 0.9]
    model = QuantileModel(input_dim=1, quantiles=quantiles)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    epochs = 100
    model, quantile_threshold = train_and_calibrate(dataset, model, optimizer, quantiles, epochs)
    model.eval()
    with torch.no_grad():
        predictions = model(inputs_tensor).mean(dim=1).numpy()
    plot_results(outputs.flatten(), predictions, quantile_threshold)


Epoch 1/100, Loss: 1.6744
Epoch 2/100, Loss: 1.0421
Epoch 3/100, Loss: 1.2489
Epoch 4/100, Loss: 1.2504
Epoch 5/100, Loss: 1.0762
Epoch 6/100, Loss: 0.8412
Epoch 7/100, Loss: 0.6894
Epoch 8/100, Loss: 0.8070
Epoch 9/100, Loss: 0.6789
Epoch 10/100, Loss: 0.6924
Epoch 11/100, Loss: 0.6903
Epoch 12/100, Loss: 0.6478
Epoch 13/100, Loss: 0.6785
Epoch 14/100, Loss: 0.6424
Epoch 15/100, Loss: 0.6420
Epoch 16/100, Loss: 0.6382
Epoch 17/100, Loss: 0.6199
Epoch 18/100, Loss: 0.6341
Epoch 19/100, Loss: 0.6181
Epoch 20/100, Loss: 0.6263
Epoch 21/100, Loss: 0.6108
Epoch 22/100, Loss: 0.6023
Epoch 23/100, Loss: 0.6022
Epoch 24/100, Loss: 0.5951
Epoch 25/100, Loss: 0.6081
Epoch 26/100, Loss: 0.6121
Epoch 27/100, Loss: 0.6183
Epoch 28/100, Loss: 0.6330
Epoch 29/100, Loss: 0.5982
Epoch 30/100, Loss: 0.6097
Epoch 31/100, Loss: 0.5972
Epoch 32/100, Loss: 0.6034
Epoch 33/100, Loss: 0.6118
Epoch 34/100, Loss: 0.6092
Epoch 35/100, Loss: 0.6261
Epoch 36/100, Loss: 0.6152
Epoch 37/100, Loss: 0.6327
Epoch 38/1