# Analysis of NTK Correction Term Scaling Laws with respect to L

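We analyze how the spectral radius of the NTK correction term scales with network depth ($L$) by fitting a power law (a linear regression in log-log space) for each configuration. The same fit is also reported with respect to $N$.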

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import seaborn as sns
from scipy.stats import linregress

In [None]:
# Directory containing the saved results. The default is the original
# machine-specific location; set the NTK_CORRECTIONS_DIR environment variable
# to run the notebook elsewhere without editing this cell.
PATH_TO_DATA = os.environ.get(
    "NTK_CORRECTIONS_DIR",
    "/home/janis/STG3A/deeperorwider/experiments/data/large_ntk_corrections",
)

# os.listdir returns entries in arbitrary order; sorting keeps the order of
# `data` (and therefore of the printed result tables) reproducible across runs.
files = sorted(os.listdir(PATH_TO_DATA))
data = []

for f in files:
    # Only files whose names start with "ntk_correction_" are result files.
    if f.startswith("ntk_correction_"):
        # NOTE(security): allow_pickle=True unpickles arbitrary Python objects,
        # which can execute code -- only point PATH_TO_DATA at trusted files.
        # .item() unwraps the single object stored in the loaded 0-d array
        # (presumably a dict of run parameters and results -- the analysis
        # below indexes it by string keys).
        d = np.load(os.path.join(PATH_TO_DATA, f), allow_pickle=True).item()
        data.append(d)

## Scaling Analysis by Configuration

In [None]:
def _fit_loglog_slopes(groups, header, min_points=5):
    """Print a table of log-log regression slopes and return them.

    Parameters
    ----------
    groups : dict
        Maps a configuration tuple to a list of ``(x, y)`` pairs.
    header : str
        Title line printed above the table.
    min_points : int, optional
        Minimum number of points a group needs in order to be fitted.

    Returns
    -------
    dict
        Maps every fitted configuration to ``(slope, r_squared)``.
    """
    print(header)
    print("-" * 50)
    print("Config\t\t\tSlope\t\tR^2")
    print("-" * 50)

    slopes = {}
    for config, values in groups.items():
        if len(values) < min_points:
            continue
        # linregress raises ValueError when every x is identical (e.g. several
        # runs at a single depth); skip such groups instead of aborting.
        if len({x_val for x_val, _ in values}) < 2:
            continue

        ordered = sorted(values)  # sort once so x and y stay paired
        # Power law y ~ x**slope  <=>  log y = slope * log x + const.
        # Assumes x and y are strictly positive; np.log yields -inf/nan otherwise.
        x = np.log([point[0] for point in ordered])
        y = np.log([point[1] for point in ordered])
        slope, _, r_value, _, _ = linregress(x, y)
        r_squared = r_value**2
        slopes[config] = (slope, r_squared)
        print(f"{config}\t{slope:.3f}\t\t{r_squared:.3f}")

    return slopes


def analyze_scaling(data, min_points=5):
    """Estimate power-law exponents of the mean spectral radius in L and in N.

    Records are grouped twice: by ``(N, D_IN, M)`` to fit the scaling in ``L``,
    and by ``(L, D_IN, M)`` to fit the scaling in ``N``. For every group with
    at least ``min_points`` points (and at least two distinct x values) a
    linear regression of ``log(mean_spectral_radius)`` on ``log(x)`` is run,
    and the results are printed as two tables.

    Parameters
    ----------
    data : iterable of dict
        Each record must provide the keys ``'N'``, ``'D_IN'``, ``'M'``,
        ``'L'`` and ``'mean_spectral_radius'``.
    min_points : int, optional
        Minimum number of points required for a fit (default 5).

    Returns
    -------
    tuple of dict
        ``(L_slopes, N_slopes)``, each mapping a configuration tuple to
        ``(slope, r_squared)``.
    """
    L_groups = {}
    N_groups = {}
    for d in data:
        radius = d['mean_spectral_radius']
        # Same record feeds both groupings: vary L at fixed (N, D_IN, M),
        # and vary N at fixed (L, D_IN, M).
        L_groups.setdefault((d['N'], d['D_IN'], d['M']), []).append((d['L'], radius))
        N_groups.setdefault((d['L'], d['D_IN'], d['M']), []).append((d['N'], radius))

    L_slopes = _fit_loglog_slopes(
        L_groups, "\nSlopes for L scaling (N, D_IN, M):", min_points
    )
    N_slopes = _fit_loglog_slopes(
        N_groups, "\nSlopes for N scaling (L, D_IN, M):", min_points
    )

    return L_slopes, N_slopes

In [None]:
# Run the log-log scaling fits on the loaded records. analyze_scaling prints
# one table for the L grouping and one for the N grouping, and returns the two
# dictionaries of per-configuration (slope, R^2) values bound here.
print("Analyzing scaling laws...")
L_slopes, N_slopes = analyze_scaling(data)

Analyzing scaling laws...

Slopes for L scaling (N, D_IN, M):
--------------------------------------------------
Config			Slope		R^2
--------------------------------------------------
(128, 20, 10)	1.315		0.985
(8, 50, 200)	1.158		0.993
(8, 100, 20)	1.092		0.993
(10, 100, 30)	1.248		0.993
(16, 20, 10)	1.123		0.988
(10, 50, 10)	1.093		0.992
(10, 100, 20)	1.147		0.991
(8, 50, 2000)	1.171		0.992
(16, 50, 500)	1.214		0.989
(25, 50, 10)	1.201		0.984
(32, 50, 100)	1.238		0.987
(32, 50, 10)	1.191		0.976
(16, 100, 20)	1.210		0.989
(32, 20, 20)	1.254		0.981
(8, 20, 60)	1.245		0.993
(32, 50, 200)	1.330		0.990
(8, 500, 20)	1.230		0.994
(8, 50, 20)	1.076		0.993
(64, 50, 10)	1.268		0.986
(16, 100, 10)	1.141		0.990
(10, 100, 10)	1.108		0.992
(8, 50, 1000)	1.159		0.994
(8, 20, 500)	1.182		0.993
(8, 20, 10)	1.129		0.886
(8, 20, 40)	1.113		0.989
(16, 100, 200)	1.207		0.990
(8, 50, 500)	1.173		0.993
(8, 200, 1000)	1.204		0.994
(16, 20, 30)	1.169		0.986
(16, 200, 10)	1.211		0.990
(8, 200, 20)	1.136		0.99