<a href="https://colab.research.google.com/github/falseywinchnet/AI_STUFF/blob/main/Lorenz_RCO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
from torch.optim.optimizer import Optimizer
import torch
import math

class RCO_Batching(Optimizer):
    """
    Runge-Kutta-Chebyshev Optimizer (RCO) with batching support, featuring self-decaying learning rate
    batching necessitates lowering the learn rate- this may not beat adam
    it also required clamping, depending on the problem.. YRMV
    """
    def __init__(self, model, max_batch_size=None):
        self.model = model
        self.max_batch_size = max_batch_size
        self.scaling_factor = 0.9

        # Pre-compute constants
        pi = torch.tensor(math.pi)
        self.w1 = self.w2 = 4/3
        self.w3 = self.w4 =  4/3

        # Trig constants
        self.cos_3pi8 = torch.cos(3*pi/8)
        self.cos_pi8 = torch.cos(pi/8)

    def compute_loss(self, x, y):
        y_pred = self.model(x)
        return torch.mean((y_pred - y)**2)

    def _split_batch(self, x, y):
        """Split large batches into smaller ones if needed"""
        if self.max_batch_size is None or x.size(0) <= self.max_batch_size:
            return [(x, y)]

        num_splits = (x.size(0) + self.max_batch_size - 1) // self.max_batch_size
        x_splits = torch.split(x, self.max_batch_size)
        y_splits = torch.split(y, self.max_batch_size)
        return list(zip(x_splits, y_splits))

    def step(self, x, y):
        """
        Performs a single optimization step using combined RK4-Chebyshev method.
        """
        # Initial k1
        loss = self.compute_loss(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Fixed line

        k1 = [p.grad.clone() for p in self.model.parameters()]

        # Store original params
        orig_params = [p.data.clone() for p in self.model.parameters()]

        # RK4 steps
        for p, k in zip(self.model.parameters(), k1):
            p.data -= k * ((6/  x.size(0)))
        loss = self.compute_loss(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Fixed line

        k2_rk4 = [p.grad.clone() for p in self.model.parameters()]

        # Reset and move to k3 position
        for p, orig in zip(self.model.parameters(), orig_params):
            p.data.copy_(orig)
        for p, k in zip(self.model.parameters(), k2_rk4):
            p.data -= k * ((4/  x.size(0)))
        loss = self.compute_loss(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Fixed line

        k3_rk4 = [p.grad.clone() for p in self.model.parameters()]

        # Reset and move to k4
        for p, orig in zip(self.model.parameters(), orig_params):
            p.data.copy_(orig)
        for p, k in zip(self.model.parameters(), k3_rk4):
            p.data -= k * ((12/  x.size(0)))
        loss = self.compute_loss(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Fixed line

        k4_rk4 = [p.grad.clone() for p in self.model.parameters()]

        # Chebyshev steps
        # Reset for Chebyshev
        for p, orig in zip(self.model.parameters(), orig_params):
            p.data.copy_(orig)

        # Chebyshev nodes with learning rate
        c1 = ((12/  x.size(0))) * (1 + self.cos_3pi8)/2
        c2 = ((12/  x.size(0))) * (1 + self.cos_pi8)/2
        c3 = ((12/  x.size(0))) * (1 - self.cos_pi8)/2

        # k2 Cheb
        for p, k in zip(self.model.parameters(), k1):
            p.data -= k * c1
        loss = self.compute_loss(x, y)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Fixed line

        k2_cheb = [p.grad.clone() for p in self.model.parameters()]

        # k3 Cheb
        for p, orig in zip(self.model.parameters(), orig_params):
            p.data.copy_(orig)
        for p, k in zip(self.model.parameters(), k2_cheb):
            p.data -= k * c2
        loss = self.compute_loss(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Fixed line

        k3_cheb = [p.grad.clone() for p in self.model.parameters()]

        # k4 Cheb
        for p, orig in zip(self.model.parameters(), orig_params):
            p.data.copy_(orig)
        for p, k in zip(self.model.parameters(), k3_cheb):
            p.data -= k * c3
        loss = self.compute_loss(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Fixed line

        k4_cheb = [p.grad.clone() for p in self.model.parameters()]

        # Combine RK4 result
        rk4_update = []
        for k1_p, k2_p, k3_p, k4_p in zip(k1, k2_rk4, k3_rk4, k4_rk4):
            update = (k1_p + 2*k2_p + 3*k3_p + k4_p)/((6))
            rk4_update.append(update)

        # Combine Cheb result
        cheb_update = []
        for k1_p, k2_p, k3_p, k4_p in zip(k1, k2_cheb, k3_cheb, k4_cheb):
            update = (self.w1*k1_p + self.w2*k2_p + self.w3*k3_p + self.w4*k4_p)/(16)
            cheb_update.append(update)

        # Average the two methods and apply final update
        for i, (p, rk4_u, cheb_u) in enumerate(zip(self.model.parameters(), rk4_update, cheb_update)):
            update = (rk4_u + cheb_u)/2
            p.data -=( update * ((12/  x.size(0))) ) *self.scaling_factor
            p.grad.zero_()

        final_loss = self.compute_loss(x, y)
        return final_loss.item()

In [10]:
import torch
import torch.nn as nn
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

# First, let's create our ground truth data generator
class LorenzSystem:
    def __init__(self, sigma=10.0, rho=28.0, beta=8/3):
        self.sigma = sigma
        self.rho = rho
        self.beta = beta

    def derivatives(self, state, t):
        x, y, z = state
        dx = self.sigma * (y - x)
        dy = x * (self.rho - z) - y
        dz = x * y - self.beta * z
        return [dx, dy, dz]

    def generate_trajectory(self, initial_state, t_span, n_points=1000):
        t = np.linspace(t_span[0], t_span[1], n_points)
        trajectory = odeint(self.derivatives, initial_state, t)
        return t, trajectory

# Neural network to predict Lorenz dynamics
class LorenzPredictor(nn.Module):
    def __init__(self, hidden_size=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(3, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 3)
        )

    def forward(self, x):
        return self.net(x)

# Generate training data
def generate_training_data(n_trajectories=100, n_points=1000, t_span=(0, 10)):
    lorenz = LorenzSystem()
    all_trajectories = []

    for _ in range(n_trajectories):
        # Random initial conditions
        initial_state = np.random.randn(3) * 0.1
        _, trajectory = lorenz.generate_trajectory(initial_state, t_span, n_points)
        all_trajectories.append(trajectory)

    return np.vstack(all_trajectories)

# Training setup
def prepare_training_data(trajectories, sequence_length=10):
    X, y = [], []
    for i in range(len(trajectories) - sequence_length):
        X.append(trajectories[i])
        y.append(trajectories[i + 1])  # Predict next state
    return torch.FloatTensor(X), torch.FloatTensor(y)

# Training function
# Modify training function to add gradient clipping
def train_model(model, optimizer, train_X, train_y, n_epochs=1, batch_size=128):
    criterion = nn.MSELoss()
    batch_losses = []

    # Move model to GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    for epoch in range(n_epochs):
        # Shuffle data at start of each epoch
        indices = torch.randperm(len(train_X))
        train_X = train_X[indices]
        train_y = train_y[indices]

        for i in range(0, len(train_X), batch_size):
            batch_X = train_X[i:i+batch_size].to(device)
            batch_y = train_y[i:i+batch_size].to(device)

            def closure():
                optimizer.zero_grad()
                pred = model(batch_X)
                loss = criterion(pred, batch_y)
                loss.backward()
                return loss

            # For RCO
            if isinstance(optimizer, RCO_Batching):
                loss = optimizer.step(batch_X, batch_y)
            else:  # For other optimizers like Adam
                loss = optimizer.step(closure)

            batch_losses.append(loss if isinstance(loss, float) else loss.item())

            if i % (batch_size * 10) == 0:
                avg_loss = sum(batch_losses[-10:]) / min(10, len(batch_losses))
                print(f'Epoch {epoch}, Batch {i//batch_size}, Avg Loss: {avg_loss:.6f}')

    return batch_losses



# Main experiment
def run_experiment():
    # Generate data
    trajectories = generate_training_data()
    train_X, train_y = prepare_training_data(trajectories)

    # Models and optimizers
    model_rco = LorenzPredictor().cuda()
    model_adam = LorenzPredictor().cuda()

    optimizer_rco = RCO_Batching(model_rco)
    optimizer_adam = torch.optim.Adam(model_adam.parameters(), lr=1e-3)

    # Train both models
    losses_rco = train_model(model_rco, optimizer_rco, train_X, train_y)
    losses_adam = train_model(model_adam, optimizer_adam, train_X, train_y)

    return losses_rco, losses_adam


def plot_results(losses_rco, losses_adam):
    plt.figure(figsize=(10, 6))
    plt.plot(losses_rco, label='RCO')
    plt.plot(losses_adam, label='Adam')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Single Epoch Loss Trajectory: RCO vs Adam')
    plt.legend()
    plt.grid(True)
    plt.show()

In [11]:
losses_rco, losses_adam = run_experiment()
plot_results(losses_rco, losses_adam)

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx