In [102]:
import math
import json
from pprint import pprint


def calculate_width(P, F, C, A, B):
    """
    Calculate the width (W) of layers in a neural network given the desired number of parameters.
    Parameters:
        - P: int, total number of parameters
        - F: int, number of features (input size)
        - C: int, number of classes (output size)
        - A: int, number of fold layers
        - B: int, total number of linear layers (including the first and last)
    Returns:
        - W: int, calculated width of the layers
    """
    # Coefficients for the quadratic equation
    a = (B - 2)
    b = (F + A + C)
    c = -P

    # Solve the quadratic equation: aW^2 + bW + c = 0
    discriminant = b**2 - 4 * a * c
    if discriminant < 0:
        return 0
    
    try:
        W = (-b + math.sqrt(discriminant)) / (2 * a)
    except ZeroDivisionError:
        W = P / (F + A + C)
    W = max(1, round(W))
    return W



def count_params(n_features, n_classes, min_size=4, scale=2, repeat=6):
    n_params = []
    for i in range(repeat):
        n_params.append(int(n_features * n_classes * min_size * scale**i))
    return n_params



def layer_widths(layers:list, n_features:int, n_classes:int, 
                 min_size:int=8, scale:int=4, repeat:int=6, verbose=0) -> list:
    
    count_linear = sum([1 for layer in layers if "linear" in layer.lower()])
    count_fold = sum([1 for layer in layers if "fold" in layer.lower()])
    
    widths = []
    for i in range(repeat):
        n_params = int(n_features * n_classes * min_size * scale**i)
        widths.append(calculate_width(P=n_params, F=n_features, C=n_classes, A=count_fold, B=count_linear))
    if verbose > 0:
        print("Biggest model:", int(n_params))
    return widths
        
widths = layer_widths(["Linear", "Fold", "Linear", ], 784, 10)
width = widths[0]
print(widths)

[79, 316, 1262, 5049, 20197, 80787]


In [None]:
ablation_archs = {}
learning_rate = 0.001
soft_fold = True
has_stretch = True
crease = None
fold_in = False
leak = 0
repeat = 5
architectures = [["Linear0", "Linear", "Linear1"], 
                 ["Linear0", "Fold", "Linear1"], 
                 ["Linear0", "Fold", "Linear", "Fold", "Linear1"], 
                 ["Linear0", "Fold", "Linear", "Fold", "Linear", "Fold", "Linear1"], 
                 ["Fold", "Fold", "Fold", "Fold", "Fold", "Fold", "Fold", "Fold", "Linear1"]]
arch_names = ["FullLinear", "21Alt", "32Alt", "43Alt", "FullFold"]
dataset_dims =  [54,        784,        28,     1024,       784]
dataset_ccs =   [7,         10,         2,      10,         10]
dataset_names = ["Cover",   "Digits",   "Higgs", "Cifar10", "Fashion"]

for aname, architecture in zip(arch_names, architectures):
    for dname, input_dim, class_count in zip(dataset_names, dataset_dims, dataset_ccs):
        
        n_params = count_params(input_dim, class_count, min_size=4, scale=2, repeat=5)
        widths = layer_widths(architecture, input_dim, class_count)
        
        for n_param, width in zip(n_params, widths):
 
            name = f"{aname}_{dname}_{n_param}"
            if soft_fold:
                architecture = [layer if layer != "Fold" else "Soft" + layer for layer in architecture]
            arch = {"learning_rate": learning_rate, 
                    "repeat": repeat,
                    "structure": []}
            
            for layer in architecture:
                if "Linear" in layer:
                    if layer[-1] == "0":
                        inf = input_dim
                        out = width
                    elif layer[-1] == "1":
                        inf = width
                        out = class_count
                    else:
                        inf = width
                        out = width
                    arch["structure"].append({"params": {"in_features": inf, 
                                                         "out_features": out}, 
                                              "type": "Linear"})
                else:
                    arch["structure"].append({"params": {"has_stretch": has_stretch,
                                                         "width": width,
                                                         "crease": crease,
                                                         }, 
                                              "type": layer})
            arch["string"] = ["".join(filter(str.isalpha, layer)) for layer in architecture]
            ablation_archs[name] = arch
# save ablation architectures
with open("BenchmarkTests/ablation_archs.json", "w") as f:
    json.dump(ablation_archs, f)

pprint(ablation_archs)

{'21Alt_Cifar10_163840': {'learning_rate': 0.001,
                          'repeat': 5,
                          'string': ['Linear', 'SoftFold', 'Linear'],
                          'structure': [{'params': {'in_features': 1024,
                                                    'out_features': 1266},
                                         'type': 'Linear'},
                                        {'params': {'crease': None,
                                                    'has_stretch': True,
                                                    'width': 1266},
                                         'type': 'SoftFold'},
                                        {'params': {'in_features': 1266,
                                                    'out_features': 10},
                                         'type': 'Linear'}]},
 '21Alt_Cifar10_327680': {'learning_rate': 0.001,
                          'repeat': 5,
                          'string': ['Linear', 'SoftFold', 'Linea

In [73]:
old_archs = json.load(open("BenchmarkTests/architectures.json", "r"))
read_idx = 15
for i, (name, info) in enumerate(old_archs.items()):
    if i == read_idx:
        pprint(info)
        break

{'learning_rate': 0.01,
 'repeat': 3,
 'string': ['SoftFold', 'SoftFold', 'SoftFold', 'Linear'],
 'structure': [{'params': {'has_stretch': False, 'width': 1.1},
                'type': 'SoftFold'},
               {'params': {'has_stretch': False, 'width': 1.1},
                'type': 'SoftFold'},
               {'params': {'has_stretch': False, 'width': 1.1},
                'type': 'SoftFold'},
               {'params': {'in_features': 1.1, 'out_features': 1.1},
                'type': 'Linear'}]}
