In [1]:
import time
import torch
import json
import os
import pickle
import argparse
from typing import Dict, List, Tuple
import numpy as np
import random
import uuid

def read_global_model(file_path):
    """Load and return the object pickled in the file at ``file_path``.

    Args:
        file_path: Path of the pickle file; it is opened in binary read mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file,
    # so this must only be pointed at files from a trusted source.
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

In [2]:
# Load the pickled object and pull out the entry stored under
# 'global_model_params' (iterated as a sequence of arrays in later cells).
data = read_global_model('database_SNN_noise_propor_dot7/global_model_permanent.pkl')
params = data['global_model_params']

In [14]:
# Number of scalar entries in the first parameter array.
print(len(params[0].flatten()))

1728


In [15]:
# Concatenate the flattened copies of every parameter array into one 1-D vector.
flattened_params = np.concatenate([p.flatten() for p in params])

In [20]:
# Indices that order the flat vector by absolute value, largest first
# (np.argsort sorts ascending, so the result is reversed).
indices_flattened_params_desc_sort_abs_value = np.argsort(np.abs(flattened_params))[::-1]

In [21]:
# Signed values reordered by descending magnitude.
flattened_params_desc_sort = flattened_params[indices_flattened_params_desc_sort_abs_value]

In [40]:
# Total number of entries, then everything except the last int(n / 10)
# entries, i.e. all but the smallest-magnitude tenth.
# NOTE: for n < 10 the stop index is -0 == 0, so the slice would be empty.
print(len(flattened_params_desc_sort))
print(flattened_params_desc_sort[0:-int(len(flattened_params_desc_sort) / 10)])

128818322
[ 3.66600000e+03  3.66600000e+03  3.66600000e+03 ...  8.25705659e-03
 -8.25705659e-03 -8.25705566e-03]


In [58]:

def compress_parameters(parameters: List[np.ndarray], compression_rate: float) -> List[np.ndarray]:
    """Sparsify a list of arrays by global magnitude pruning.

    Despite its name, ``compression_rate`` is the fraction of entries that is
    *kept*: a single threshold is taken at the ``1 - compression_rate``
    quantile of the absolute values pooled across all arrays, and every entry
    whose magnitude is below that threshold is replaced by 0. Entries equal to
    the threshold are kept, so ties can make the kept fraction slightly larger
    than requested.

    Args:
        parameters: Arrays to prune. They are never modified in place.
        compression_rate: Fraction of entries to keep. ``1`` keeps everything;
            any value ``<= 0`` uses an infinite threshold, which zeroes every
            entry (only +/-inf entries would survive the comparison).

    Returns:
        A new list with one array per input array, each with the shape of its
        input and sub-threshold entries set to 0. An empty input list yields
        an empty list.

    Raises:
        ValueError: If ``compression_rate`` is greater than 1.
    """
    # np.quantile rejects q < 0 with a message about "quantiles"; fail early
    # with a message phrased in terms of this function's own argument.
    if compression_rate > 1:
        raise ValueError(
            "compression_rate is the fraction of entries kept and must be "
            f"at most 1, got {compression_rate!r}"
        )

    # np.concatenate raises on an empty sequence; there is nothing to prune.
    if len(parameters) == 0:
        return []

    if compression_rate > 0:
        # Only the magnitudes are needed for the threshold, so take abs()
        # per array and pool the results instead of pooling signed copies.
        magnitudes = np.concatenate([np.abs(p).ravel() for p in parameters])
        if magnitudes.size == 0:
            # Every array is empty: a quantile is undefined, return copies.
            return [p.copy() for p in parameters]
        threshold = np.quantile(magnitudes, 1 - compression_rate)
    else:
        # Keep nothing: no finite magnitude reaches an infinite threshold.
        # The pooled vector is not needed here, so it is never built.
        threshold = np.inf

    # np.where builds a fresh array, leaving the inputs untouched.
    return [np.where(np.abs(p) >= threshold, p, 0) for p in parameters]

In [59]:
# Second, independent load of the same pickle file; the result is bound to
# data2 / params2 rather than reusing data / params.
data2 = read_global_model('database_SNN_noise_propor_dot7/global_model_permanent.pkl')
params2 = data2['global_model_params']

In [53]:
# Prints the repr of the whole parameter list (every array), which produces
# very large cell output.
print(params2)

[array([[[[-0.23282641,  0.6375502 , -0.00367537],
         [ 0.19462684,  0.13028947, -0.16674395],
         [ 0.1165335 ,  0.32420838,  0.27131054]],

        [[-0.06379094,  0.01429567,  0.04863561],
         [ 0.42089146, -0.35786644,  0.11061637],
         [ 0.2114875 , -0.11639033, -0.02849376]],

        [[-0.06378027, -0.35963267,  0.03818979],
         [-0.21626393,  0.1955364 , -0.04629376],
         [ 0.24144451, -0.4951545 , -0.25317335]]],


       [[[ 0.14027815,  0.29869395,  0.46900004],
         [-0.31544578,  0.11596999,  0.40508536],
         [-0.07832726, -0.01217095,  0.2912489 ]],

        [[ 0.14242707,  0.3510918 , -0.34579703],
         [-0.11926011,  0.02315711, -0.05745   ],
         [-0.28821948, -0.2872924 ,  0.2304738 ]],

        [[-0.08494649,  0.18396494,  0.08875465],
         [-0.517625  , -0.22853261,  0.16533321],
         [ 0.05800114,  0.14155464, -0.23429401]]],


       [[[-0.25014985, -0.40326715, -0.41298553],
         [ 0.05868469,  0.2836221

In [60]:
# With 0.8 the threshold is the 0.2 quantile of |params2|, so roughly the
# largest-magnitude 80% of entries are kept and the rest become 0.
compressed_parameters = compress_parameters(params2, 0.8)

In [61]:
# Concatenate the flattened compressed arrays into one 1-D vector.
flattened_compressed_parameters = np.concatenate([p.flatten() for p in compressed_parameters])

In [62]:
# Order the compressed values by descending absolute value: reversed
# ascending argsort gives the indices, fancy indexing applies them.
indices_flattened_compressed_params_desc_sort_abs_value = np.argsort(np.abs(flattened_compressed_parameters))[::-1]
flattened_compressed_params_desc_sort = flattened_compressed_parameters[indices_flattened_compressed_params_desc_sort_abs_value]

In [70]:
# Total number of entries in the sorted compressed vector.
print(len(flattened_compressed_params_desc_sort))
# Stop index is -int(n / 10 * 2) + 1, so the slice drops the last
# int(n / 10 * 2) - 1 entries (about the smallest-magnitude 20%).
print(flattened_compressed_params_desc_sort[0:-int(len(flattened_compressed_params_desc_sort) / 10 * 2) + 1])
# Everything except the final entry.
print(flattened_compressed_params_desc_sort[0:-1])

128818322
[3.66600000e+03 3.66600000e+03 3.66600000e+03 ... 1.66661106e-02
 1.66661106e-02 1.66661106e-02]
[3666. 3666. 3666. ...    0.    0.    0.]


In [72]:
# Show the container type returned by compress_parameters.
print(type(compressed_parameters))

<class 'list'>


In [74]:
def count_differences(arr1: np.ndarray, arr2: np.ndarray) -> int:
    """Count the positions at which two equally shaped arrays differ.

    Comparison is element-wise ``!=``, so a NaN counts as different from
    everything, including another NaN at the same position.

    Args:
        arr1: First array.
        arr2: Second array; must have exactly the same shape as ``arr1``.

    Returns:
        The number of element positions where the two arrays are not equal.

    Raises:
        ValueError: If the two arrays do not have the same shape.
    """
    if arr1.shape == arr2.shape:
        # Boolean mask of mismatching positions; summing it counts the Trues.
        mismatch_mask = arr1 != arr2
        return np.sum(mismatch_mask)

    raise ValueError("Arrays must have the same shape")

In [75]:
# Sum the per-array mismatch counts between the original and compressed
# parameters, pairing the arrays position by position.
total_diff = 0
for item1, item2 in zip(params2, compressed_parameters):
    total_diff += count_differences(item1, item2)

# Report the mismatches as a fraction of all entries in the compressed vector.
print(total_diff / len(flattened_compressed_parameters))

0.19999998913198078
