In [11]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pickle
import glob
import os
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any, List, Tuple, Optional
from dataclasses import dataclass
import json

class GeneralizedNewtonSchulz(nn.Module):
    def __init__(self, degree: int = 3, num_iterations: int = 5, init_coefficients: Optional[List[float]] = None):
        super().__init__()
        self.degree = degree
        self.num_iterations = num_iterations
        if init_coefficients is None:
            self.coefficients = [nn.Parameter(torch.randn(self.num_polynomial_terms)) for _ in range(self.num_iterations)]
        else:
            self.verify_init_coefficients_shape(init_coefficients)
            self.coefficients = nn.ParameterList([
                nn.Parameter(torch.tensor(layer_coeff) if init_coefficients is not None 
                            else torch.randn(self.num_polynomial_terms))
                for layer_coeff in (init_coefficients or [None] * num_iterations)
            ])
        self.initial_scale = 1.1
    
    @property
    def num_polynomial_terms(self) -> int:
        return (self.degree + 1) // 2

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        X = X / torch.norm(X, p='fro', dim=(-2, -1), keepdim=True) * 1.1

        for layer_idx in range(self.num_iterations):
            XT = X.transpose(-2, -1)
            XTX = torch.bmm(XT, X)
            terms = [X]
            
            for _ in range(self.num_polynomial_terms - 1):
                terms.append(torch.bmm(terms[-1], XTX))
            
            X = sum(coeff * term for coeff, term in zip(self.coefficients[layer_idx], terms))
            
        return X

    def print_polynomial_at_layer(self, layer_idx: int) -> str:
        terms = [f"{coeff:.3f}X(X^TX)^{i}" if i else f"{coeff:.3f}X" 
                for i, coeff in enumerate(self.coefficients[layer_idx].detach().cpu().tolist())
                if abs(coeff) > 1e-6]
        return " + ".join(terms)

    def evaluate_scalar_at_layer(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor:
        powers = torch.tensor([2 * i + 1 for i in range(len(self.coefficients))])
        return sum(coeff * torch.pow(x, power) for coeff, power in zip(self.coefficients[layer_idx], powers))
    
    def evaluate_scalar_across_layers(self, x: torch.Tensor) -> torch.Tensor:
        for layer_idx in range(self.num_iterations):
            x = self.evaluate_scalar_at_layer(x, layer_idx)
        return x

    def verify_init_coefficients_shape(self, initial_coefficients: List[float]):
        assert len(initial_coefficients) == self.num_iterations
        for layer_coeff in initial_coefficients:
            assert len(layer_coeff) == self.num_polynomial_terms

class SingularValuesDataset(Dataset):
    def __init__(self, checkpoint_dir: str, layer_filter: Optional[List[str]] = None):
        """
        Dataset for loading singular values from JSON files.
        
        Args:
            checkpoint_dir: Directory containing JSON files with singular values
            layer_filter: Optional list of layer names to include. If None, all layers are used.
        """
        self.checkpoint_dir = checkpoint_dir
        self.layer_filter = layer_filter
        
        # Get all JSON files in the directory
        file_paths = sorted(
            [os.path.join(checkpoint_dir, f) for f in os.listdir(checkpoint_dir) if f.endswith('.json')],
            key=lambda x: int(os.path.basename(x).split('.')[0])  # Sort by minibatch index
        )
        
        if not file_paths:
            raise ValueError(f"No JSON files found in {checkpoint_dir}")
        
        # Load all singular values into memory
        self.all_singular_values = []
        
        # Load the first file to get layer names if no filter is provided
        with open(file_paths[0], 'r') as f:
            data = json.load(f)
            if self.layer_filter is None:
                self.layer_filter = list(data.keys())
        
        # Load all files and extract singular values
        print(f"Loading singular values from {len(file_paths)} files...")
        for file_path in file_paths:
            with open(file_path, 'r') as f:
                data = json.load(f)
            
            # Flatten all singular values across layers into a single pool
            for values in data.values():
                if len(values) > 0:  # Only add non-empty lists
                    self.all_singular_values.extend(values)
        
        # Convert to tensor for efficiency
        self.all_singular_values = torch.tensor(self.all_singular_values, dtype=torch.float32)
        print(f"Loaded {len(self.all_singular_values)} singular values")
    
    def __len__(self) -> int:
        return len(self.all_singular_values)
    
    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Get a single singular value as a matrix for the model.
        We construct a diagonal matrix with the singular value.
        
        Returns:
            torch.Tensor: A tensor of shape (dim, dim) representing a diagonal matrix
        """
        # Create a small diagonal matrix with the singular value
        dim = 10  # Arbitrary dimension, adjust as needed
        sv = self.all_singular_values[idx]
        
        # Create diagonal matrix with the singular value
        matrix = torch.zeros(dim, dim)
        matrix[0, 0] = sv
        
        # Add some small random noise to other elements to avoid exact zeros
        # This makes the problem more realistic while preserving the dominant singular value
        noise = torch.randn(dim, dim) * 1e-6
        matrix = matrix + noise
        
        return matrix

def create_dataloader(checkpoint_dirs: List[str], batch_size: int = 8) -> DataLoader:
    """
    Create a DataLoader for singular values from one or more checkpoint directories.
    
    Args:
        checkpoint_dirs: List of directories containing JSON files with singular values
        batch_size: Batch size for the DataLoader
        
    Returns:
        DataLoader: DataLoader for the SingularValuesDataset
    """
    datasets = []
    
    for checkpoint_dir in checkpoint_dirs:
        datasets.append(SingularValuesDataset(checkpoint_dir))
    
    # Combine all datasets
    if len(datasets) == 1:
        combined_dataset = datasets[0]
    else:
        combined_dataset = torch.utils.data.ConcatDataset(datasets)
    
    return DataLoader(
        combined_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True
    )

def iterate(fn, num_iterations: int = 10):
    def wrapper(x):
        result = x
        for _ in range(num_iterations):
            result = fn(result)
        return result
    return wrapper

def norm_of_xtx_minus_i(X):
    I = torch.eye(X.size(-1)).to(X.device)
    return torch.norm(X.transpose(-2, -1) @ X - I, p='fro')

def derivative_at_zero(model):
    # product of zeroth element of each layer's coefficients
    return torch.prod(torch.stack([coeff[0] for coeff in model.coefficients]), dim=0).pow(2.0/model.num_iterations)


def train_newton_schulz(config: Dict[str, Any]):
    device = torch.device(config['device'])
    
    model = GeneralizedNewtonSchulz(
        degree=config['degree'],
        num_iterations=config['num_iterations'],
        init_coefficients=config['init_coefficients']
    ).to(device)
    
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['learning_rate'],
        betas=config['adam_betas']
    )
    
    dataloader = create_dataloader(config['checkpoint_dirs'], config['batch_size'])
    
    def plot_fn(fn, title):
        fig, ax = plt.subplots(figsize=(10, 6))
        x = torch.linspace(-0.5, 2.0, 10000).to(device)
        y = fn(x).cpu().detach()
        ax.plot(x.cpu().numpy(), y.numpy())
        
        # Handle multi-line title
        ax.set_title(title, y=1.05, pad=10)
        ax.set_ylim(-1, 3)
        ax.axhline(y=0, color='k', linestyle='-', alpha=0.3)
        ax.axvline(x=0, color='k', linestyle='-', alpha=0.3)
        
        # Adjust layout to prevent title cutoff
        fig.tight_layout()
        
        plt.close()
    
    for epoch in range(config['num_epochs']):
        epoch_loss = 0
        num_batches = 0
        total_orthogonality_loss = 0
        total_derivative_loss = 0

        for i, matrices in enumerate(dataloader):
            matrices = matrices.to(device)
            output = model(matrices)
            orthogonality_loss = norm_of_xtx_minus_i(output)
            derivative_loss = derivative_at_zero(model)
            loss = orthogonality_loss - config['alpha'] * derivative_loss
            
            # Backprop and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Track metrics
            epoch_loss += loss.item()
            total_orthogonality_loss += orthogonality_loss.item()
            total_derivative_loss += derivative_loss.item()
            num_batches += 1
            
            if i % 10 == 0:
                print(f"Batch {i}: Ortho loss {orthogonality_loss.item():.6f}, Deriv loss {derivative_loss.item():.6f}, loss {loss.item():.6f}")
        
        # End of epoch reporting
        avg_loss = epoch_loss / num_batches
        avg_ortho_loss = total_orthogonality_loss / num_batches
        avg_deriv_loss = total_derivative_loss / num_batches
        
        print(f"Epoch {epoch}: Avg loss {avg_loss:.6f}, Avg ortho loss {avg_ortho_loss:.6f}, Avg deriv loss {avg_deriv_loss:.6f}")


In [12]:
default_config = {
    "device": "cuda",
    "checkpoint_dirs": ["singular_values/42155f96-2e99-4f36-8745-dd903b02b18f"],
    "degree": 5,
    "num_iterations": 7,
    "learning_rate": 1e-2,
    "adam_betas": (0.9, 0.9),
    "batch_size": 128,
    "num_epochs": 200,
    "alpha": 0,
    "init_coefficients": [
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
        (2.8366, -3.0525, 1.2012),
        (2.8366, -3.0525, 1.2012),
    ]
}

# Run single training
train_newton_schulz(default_config)

Loading singular values from 1646 files...
Loaded 87646208 singular values
Batch 0: Ortho loss 33.862671, Deriv loss 10.667524, loss 33.862671
Batch 10: Ortho loss 33.742851, Deriv loss 11.037832, loss 33.742851
Batch 20: Ortho loss 33.920349, Deriv loss 11.169354, loss 33.920349
Batch 30: Ortho loss 33.928726, Deriv loss 11.439039, loss 33.928726
Batch 40: Ortho loss nan, Deriv loss nan, loss nan
Batch 50: Ortho loss nan, Deriv loss nan, loss nan
Batch 60: Ortho loss nan, Deriv loss nan, loss nan
Batch 70: Ortho loss nan, Deriv loss nan, loss nan
Batch 80: Ortho loss nan, Deriv loss nan, loss nan
Batch 90: Ortho loss nan, Deriv loss nan, loss nan
Batch 100: Ortho loss nan, Deriv loss nan, loss nan
Batch 110: Ortho loss nan, Deriv loss nan, loss nan
Batch 120: Ortho loss nan, Deriv loss nan, loss nan
Batch 130: Ortho loss nan, Deriv loss nan, loss nan
Batch 140: Ortho loss nan, Deriv loss nan, loss nan
Batch 150: Ortho loss nan, Deriv loss nan, loss nan
Batch 160: Ortho loss nan, Deriv

KeyboardInterrupt: 