In [4]:
import sys
sys.path.append("/home/miri/Documents/bachelorthesis/part1")
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
import jax
import jax.numpy as jnp
import argparse
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from jax.experimental import mesh_utils
from utils import *
from tqdm import tqdm
import os
import copy
import shutil

# Benchmark configuration: VGG11 on CIFAR-10 with the "sgdm" optimizer and the
# "global_center_std_uncenter" weight normalization applied every step.
settings = {
    "num_devices": 1,
    "num_experiments_per_device": 3,
    "random_key": 42,
    "num_steps": 1000,
    "save_args": {
        "save_states_every": 5000,
        "save_train_stats_every": 1000,
        "save_test_stats_every": 1000,
        "save_grad_every": 5000,
        "save_hessian_every": -1,
    },
    "model": {"model": "vgg11", "num_classes": 10, "activation_fn": "relu"},
    "dataset": {"dataset": "cifar10", "batch_size": 128, "dataset_path": "../datasets/"},
    "optimizer": {"optimizer": "sgdm", "lr": 0.001, "lambda_wd": 0, "momentum": 0.9, "apply_wd_every": 1},
    "norm": {
        "change_scale": "identity",
        "norm_fn": "global_center_std_uncenter",
        "norm_multiply": 0.2,
        "norm_every": 1,
        "reverse_norms": False,
    },
    "at_step": 0,
}
args = dict_to_namespace(settings)

# Only a fresh initialization is supported. Nothing below restores weights,
# batch stats or optimizer state for a later step, so any other value would
# leave `weights` undefined and crash later with a NameError. Fail fast, before
# the dataset and model are loaded.
if args.at_step != 0:
    raise NotImplementedError(
        "Resuming from a checkpoint is not supported here; expected at_step == 0, got {0}".format(args.at_step)
    )

# CPU device on which all parameters are created before they are distributed
# to the other devices (either CPUs or GPUs).
default_cpu_device = jax.devices("cpu")[0]

# 1-D device mesh with axis name 'd'. P('d') partitions the leading axis of an
# array across that mesh axis.
devices = mesh_utils.create_device_mesh((args.num_devices,))
mesh = Mesh(devices, axis_names=("d",))
named_sharding = NamedSharding(mesh, P("d"))

# Root PRNG key; all further keys are split off from it.
split_key = jax.random.key(args.random_key)

# NOTE(review): the meaning of get_dataset's second argument is not visible here.
ds_train, ds_test = get_dataset(args, False)

model, layer_depth_dict, num_layers = get_model(args)

# Initialize a throwaway parameter set on the CPU (dummy 1x32x32x3 input) only
# to obtain the parameter tree, i.e. the key paths of all layers, that
# get_optimizer receives as helper_weights.
with jax.default_device(default_cpu_device):
    helper_weights = model.init(jax.random.key(0), jnp.ones((1, 32, 32, 3)))["params"]

optimizer = get_optimizer(args, helper_weights=helper_weights)

# Initialize the model. Split num_devices * num_experiments_per_device + 1 keys:
# the last one becomes the new split_key, the rest are handed to get_states.
keys = jax.random.split(split_key, num=args.num_devices * args.num_experiments_per_device + 1)
sk, split_key = keys[:-1], keys[-1]
weights, batch_stats, optimizer_state = get_states(
    model.init, optimizer.init, jnp.ones((1, 32, 32, 3)), sk, default_cpu_device
)

# Place the states on the device mesh via the utils device_put helper
# (presumably sharded according to named_sharding — confirm in utils).
weights, batch_stats, optimizer_state = device_put(named_sharding, weights, batch_stats, optimizer_state)

# The normalization schemes proposed by Niehaus et al. 2024 rescale each conv
# kernel using the standard deviation it had before training. The two variants
# differ only in which axes that reference std is computed over.
if args.norm.norm_fn in ("center_std_uncenter", "global_center_std_uncenter"):
    if args.at_step != 0:
        # The reference stds are taken from the freshly initialized weights.
        # Recovering them when resuming at a later step is not implemented.
        raise NotImplementedError(
            "norm_fn={0!r} requires at_step == 0 to record the initial weight std".format(args.norm.norm_fn)
        )
    std_weights = weights


def _is_conv_kernel(path):
    """Return True if the leaf at `path` is a conv kernel (the only leaves that get normalized)."""
    return substrings_in_path(path, "conv", "kernel")


def _per_channel_std(x):
    """Std of `x` over every axis except the last, with dims kept so it broadcasts against `x`."""
    return jnp.std(x, axis=tuple(range(x.ndim - 1)), keepdims=True)


def _global_std(x):
    """Single std over all entries of `x`, with dims kept so it broadcasts against `x`."""
    return jnp.std(x, keepdims=True)


# Reducer used to compute the reference std for each std-based norm_fn.
# Any other norm_fn name is applied without a reference std.
_reference_std_fns = {
    "center_std_uncenter": _per_channel_std,
    "global_center_std_uncenter": _global_std,
}

# Resolve the configured normalization function once, instead of once per leaf
# every time norm_fn is traced.
_apply_norm = get_norm_fn(args.norm.norm_fn)
_reference_std = _reference_std_fns.get(args.norm.norm_fn)

if _reference_std is not None:
    # Standard deviations of the initial conv kernels. jax.vmap maps over the
    # leading axis of each leaf (presumably one slice per experiment — confirm
    # against get_states). Non-conv leaves get None.
    target_std = tree_map_with_path(
        lambda s, w: jax.vmap(_reference_std)(w) if _is_conv_kernel(s) else None,
        std_weights,
    )
    # Applies the normalization to every conv kernel of `tree`, passing its
    # reference std. All other leaves are returned unchanged.
    norm_fn = jax.jit(
        lambda tree: tree_map_with_path(
            lambda s, w, std: _apply_norm(w, args.norm.norm_multiply, std) if _is_conv_kernel(s) else w,
            tree,
            target_std,
        )
    )
else:
    # Applies the normalization to every conv kernel of `tree`. All other
    # leaves are returned unchanged.
    norm_fn = jax.jit(
        lambda tree: tree_map_with_path(
            lambda s, w: _apply_norm(w, args.norm.norm_multiply) if _is_conv_kernel(s) else w,
            tree,
        )
    )

# The normalized weights are blended into the current weights as
#     new_params = (1 - s) * params + s * params_normed
# where the blend factor s = change_scale(n, N, l, L) may depend on
#     n -> current step,       N -> max steps,
#     l -> current layer,      L -> max layers.
# get_change_scale turns the configured name (settings["norm"]["change_scale"],
# "identity" here) into that callable.
change_scale = get_change_scale(args.norm.change_scale)


def change_fn(w, normed_w, n, N, l, L):
    """Blend the weights `w` with their normalized version `normed_w`.

    `normed_w` is weighted by ``change_scale(n, N, l, L)`` and `w` by the
    complement ``1 - change_scale(n, N, l, L)``.
    """
    blend = change_scale(n, N, l, L)
    keep = 1 - blend
    return keep * w + blend * normed_w

def _layerwise_stepscale(params, normed_params, n, N, layer_depth_dict, L):
    """Blend every conv kernel in `params` with its counterpart in `normed_params`.

    `params`, `normed_params` and `layer_depth_dict` are mapped together leaf by
    leaf. For a conv-kernel leaf, `change_fn` receives that leaf's entry from
    `layer_depth_dict` as the layer index l. Every other leaf of `params` is
    returned unchanged.
    """
    def _blend_leaf(path, w, normed_w, depth):
        if not substrings_in_path(path, "conv", "kernel"):
            return w
        return change_fn(w, normed_w, n, N, depth, L)

    return tree_map_with_path(_blend_leaf, params, normed_params, layer_depth_dict)


# Jitted entry point used by the training loop.
layerwise_stepscale_fn = jax.jit(_layerwise_stepscale)

import time

# Steps 1..WARMUP_STEPS are excluded from the statistics. Presumably they are
# dominated by the first-call JIT compilation of the jitted functions used in a
# step (norm_fn and layerwise_stepscale_fn are jax.jit above).
WARMUP_STEPS = 5


def _norm_due(step):
    """Return True if the weight normalization should run before `step`'s update.

    `start_after` / `stop_after` are optional bounds that the settings dict
    above does not define. A missing or None value therefore means "unbounded".
    A `norm_every` of -1 disables normalization.
    """
    start_after = getattr(args.norm, "start_after", None)
    stop_after = getattr(args.norm, "stop_after", None)
    if start_after is not None and step < start_after:
        return False
    if stop_after is not None and step > stop_after:
        return False
    norm_every = args.norm.norm_every
    # Test the -1 sentinel before the modulo.
    return norm_every != -1 and step % norm_every == 0


def _timed_steps(weights, batch_stats, optimizer_state, split_key, with_norm):
    """Run up to `args.num_steps` training steps on `ds_train` and time each one.

    If `with_norm` is True, the normalization (norm_fn blended in via
    layerwise_stepscale_fn) is applied before the gradient step whenever
    `_norm_due(step)` is True. It is part of the timed region.

    Returns (weights, batch_stats, optimizer_state, split_key, step_times):
    the updated training state, the advanced PRNG key, and the wall-clock
    durations in seconds of every step after WARMUP_STEPS.
    """
    num_keys = args.num_devices * args.num_experiments_per_device + 1
    step_times = []
    for i, (img, lbl) in zip(tqdm(range(1, args.num_steps + 1)), ds_train):
        # Generate new random keys for this step (outside the timed region).
        keys = jax.random.split(split_key, num=num_keys)
        sk, split_key = keys[:-1], keys[-1]

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        if with_norm and _norm_due(i):
            weights = layerwise_stepscale_fn(
                weights, norm_fn(weights), i, args.num_steps, layer_depth_dict, num_layers
            )

        # TODO: weight decay (args.optimizer.apply_wd_every) is not applied in this benchmark.

        grad, aux = get_grad_fn(weights, batch_stats, img, lbl, sk, model.apply)
        batch_stats = aux["batch_stats"]
        weights, optimizer_state = update_states_fn(weights, grad, optimizer_state, optimizer.update)
        # JAX dispatches work asynchronously. Wait for the results so the
        # measurement covers the computation, not just its dispatch.
        tree_map(lambda x: x.block_until_ready(), weights)
        tree_map(lambda x: x.block_until_ready(), optimizer_state)
        end = time.perf_counter()
        if i > WARMUP_STEPS:
            step_times.append(end - start)
    return weights, batch_stats, optimizer_state, split_key, step_times


def _report_mean(label, step_times):
    """Print "<label> <mean step time>", or a notice if no step was timed."""
    if not step_times:
        print("{0}: no steps timed (num_steps must exceed {1})".format(label, WARMUP_STEPS))
        return
    print("{0} {1}".format(label, sum(step_times) / len(step_times)))


# Baseline: plain training steps without weight normalization.
weights, batch_stats, optimizer_state, split_key, times_step = _timed_steps(
    weights, batch_stats, optimizer_state, split_key, with_norm=False
)
_report_mean("Mean step without norm", times_step)

# Same loop with the normalization applied before the training step. Note that
# it continues from the state reached by the baseline run.
weights, batch_stats, optimizer_state, split_key, times_step = _timed_steps(
    weights, batch_stats, optimizer_state, split_key, with_norm=True
)
_report_mean("Mean step with norm", times_step)


2025-01-24 16:37:15.264569: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2025-01-24 16:37:15.276233: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2025-01-24 16:37:15.289716: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. I

Mean step without norm 0.021856519085678025


100%|██████████| 1000/1000 [00:27<00:00, 37.02it/s]

Mean step with norm 0.02315261735388981



