In [2]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class RBCModel:
    """Simple Real Business Cycle Model"""
    def __init__(self, alpha=0.36, beta=0.96, delta=0.1, rho=0.918, sigma_e=0.014):
        """Initialize RBC model parameters"""
        self.alpha = alpha      # Capital share
        self.beta = beta        # Discount factor
        self.delta = delta      # Depreciation rate
        self.rho = rho          # Productivity shock persistence
        self.sigma_e = sigma_e  # Shock standard deviation

        # Compute steady state
        self._compute_steady_state()

    def _compute_steady_state(self):
        """Calculate model's steady state values"""
        r = 1/self.beta - 1 + self.delta
        self.k_y_ratio = self.alpha / r
        self.z_ss = 1.0  # Productivity steady state
        self.n_ss = 0.33  # Labor steady state (approximated)
        self.y_ss = (self.k_y_ratio ** (self.alpha / (1 - self.alpha))) * self.n_ss
        self.k_ss = self.k_y_ratio * self.y_ss
        self.c_ss = self.y_ss - self.delta * self.k_ss

    def production(self, z, k, n):
        """Cobb-Douglas production function"""
        return z * (k ** self.alpha) * (n ** (1 - self.alpha))

class PolicyNetwork(nn.Module):
    """Neural network to approximate consumption policy"""
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.network = nn.Sequential(
            nn.Linear(2, 16),  # Input: capital, productivity
            nn.Sigmoid(),
            nn.Linear(16, 16),
            nn.Sigmoid(),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Compute consumption share"""
        return self.network(x)

    def solve_model(self, k, z):
        """Solve for consumption and labor given state"""
        # Consumption share
        phi = self(torch.cat([k, z], dim=1))
        
        # Labor from first-order condition
        val1 = self.model.alpha * (1 - 0.66)
        val2 = (1 - 0.66) * phi
        n = val1 / (val1 + val2)
        
        # Output and consumption
        y = self.model.production(z, k, n)
        c = phi * y
        
        return c, n

def train_neural_network(model, epochs=5000, lr=1e-3):
    """Train neural network to solve RBC model"""
    # Initialize network and optimizer
    network = PolicyNetwork(model)
    optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    
    # Training parameters
    T = 1000  # Number of samples
    
    for epoch in range(epochs):
        # Generate random states
        k = torch.normal(mean=model.k_ss, std=0.1*model.k_ss, size=(T, 1))
        z = torch.ones_like(k) * model.z_ss
        
        # Draw innovation shocks
        eps1 = torch.normal(mean=0, std=model.sigma_e, size=(T, 1))
        eps2 = torch.normal(mean=0, std=model.sigma_e, size=(T, 1))
        
        # Clear gradients
        optimizer.zero_grad()
        
        # Compute current and next period values
        c, n = network.solve_model(k, z)
        
        # Compute next period states
        k_next = (1 - model.delta) * k + (model.production(z, k, n) - c)
        z_next1 = torch.exp(torch.log(z) * model.rho + eps1)
        z_next2 = torch.exp(torch.log(z) * model.rho + eps2)
        
        # Compute next period decisions
        c_next1, _ = network.solve_model(k_next, z_next1)
        c_next2, _ = network.solve_model(k_next, z_next2)
        
        # Euler equation residuals
        rhs1 = model.beta * (c / c_next1) * (model.alpha * model.production(z_next1, k_next, n) / k_next + 1 - model.delta)
        rhs2 = model.beta * (c / c_next2) * (model.alpha * model.production(z_next2, k_next, n) / k_next + 1 - model.delta)
        
        # Loss function (All-in-One expectation)
        loss = torch.mean((rhs1 - 1.0)**2 * (rhs2 - 1.0)**2)
        
        # Backpropagate
        loss.backward()
        optimizer.step()
        
        # Print progress
        if epoch % 500 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item():.6f}')
    
    return network

import plotly.graph_objs as go
from plotly.subplots import make_subplots

def plot_policy_functions_plotly(model, network):
    """Plot policy functions using Plotly"""
    try:
        # Generate grid of capital values
        k_grid = torch.linspace(model.k_ss * 0.8, model.k_ss * 1.2, 100).reshape(-1, 1)
        z = torch.ones_like(k_grid) * model.z_ss
        
        # Compute policy functions
        with torch.no_grad():
            c_nn, n_nn = network.solve_model(k_grid, z)
        
        # Convert to numpy for plotting
        k_grid_np = k_grid.numpy().flatten()
        c_nn_np = c_nn.numpy().flatten()
        n_nn_np = n_nn.numpy().flatten()
        
        # Create subplot
        fig = make_subplots(rows=1, cols=2, 
                             subplot_titles=('Consumption Policy', 'Labor Policy'))
        
        # Consumption plot
        fig.add_trace(
            go.Scatter(x=k_grid_np, y=c_nn_np, 
                       name='Consumption',
                       line=dict(color='blue')),
            row=1, col=1
        )
        
        # Steady state line for consumption
        fig.add_shape(
            type="line",
            x0=model.k_ss, y0=min(c_nn_np), 
            x1=model.k_ss, y1=max(c_nn_np),
            line=dict(color="Red", width=2, dash="dash"),
            row=1, col=1
        )
        
        # Labor plot
        fig.add_trace(
            go.Scatter(x=k_grid_np, y=n_nn_np, 
                       name='Labor',
                       line=dict(color='green')),
            row=1, col=2
        )
        
        # Steady state line for labor
        fig.add_shape(
            type="line",
            x0=model.k_ss, y0=min(n_nn_np), 
            x1=model.k_ss, y1=max(n_nn_np),
            line=dict(color="Red", width=2, dash="dash"),
            row=1, col=2
        )
        
        # Update layout
        fig.update_layout(
            title='RBC Model Policy Functions',
            height=500, width=1000,
            showlegend=True
        )
        
        # Update x and y axis labels
        fig.update_xaxes(title_text="Capital", row=1, col=1)
        fig.update_xaxes(title_text="Capital", row=1, col=2)
        fig.update_yaxes(title_text="Consumption", row=1, col=1)
        fig.update_yaxes(title_text="Labor", row=1, col=2)
        
        # Save the plot
        fig.write_html("policy_functions.html")
        print("Interactive plot saved as policy_functions.html")
        
        return fig
    
    except Exception as e:
        print(f"Error in Plotly plotting: {e}")
        return None

# Note: To use this, you'll need to install plotly
# pip install plotly

def main():
    # Create RBC model
    model = RBCModel()
    
    # Train neural network
    network = train_neural_network(model)
    
    # Plot policy functions
    plot_policy_functions_plotly(model, network)

if __name__ == "__main__":
    main()

Epoch 0, Loss: 0.000001
Epoch 500, Loss: 0.000000
Epoch 1000, Loss: 0.000000
Epoch 1500, Loss: 0.000000
Epoch 2000, Loss: 0.000000
Epoch 2500, Loss: 0.000000
Epoch 3000, Loss: 0.000000
Epoch 3500, Loss: 0.000000
Epoch 4000, Loss: 0.000000
Epoch 4500, Loss: 0.000000
Interactive plot saved as policy_functions.html
