<a href="https://colab.research.google.com/github/kiplangatkorir/Hierarchial-Compression-With-KANs/blob/main/simplified_KAN_compressor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install torch



In [2]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.decomposition import PCA

In [5]:
class MemoryEfficientCompression:
    def __init__(self, model, compression_ratio=0.5):
        self.model = model
        self.compression_ratio = compression_ratio
        self.compressed_state = None
        self.pca_models = {}

    def compress(self):
        compressed_state = {}
        for name, param in self.model.named_parameters():
            if param.dim() > 1:  # Only compress 2D+ tensors
                shape = param.shape
                flattened = param.data.flatten().numpy()
                # Calculate n_components based on the minimum of shape dimensions
                n_components = max(1, int(min(shape) * self.compression_ratio)) # Changed this line

                pca = PCA(n_components=n_components)
                compressed = pca.fit_transform(flattened.reshape(-1, shape[1])) # Changed this line

                compressed_state[name] = {
                    'compressed': torch.from_numpy(compressed).float(),
                    'shape': shape,
                    'mean': torch.from_numpy(pca.mean_).float(),
                    'components': torch.from_numpy(pca.components_).float()
                }
                self.pca_models[name] = pca
            else:
                compressed_state[name] = param.data

        self.compressed_state = compressed_state
        return compressed_state

    def decompress(self):
        if self.compressed_state is None:
            raise ValueError("Model hasn't been compressed yet.")

        decompressed_state = {}
        for name, compressed_data in self.compressed_state.items():
            if isinstance(compressed_data, dict):  # Compressed tensor
                pca = self.pca_models[name]
                decompressed = pca.inverse_transform(compressed_data['compressed'].numpy())
                decompressed = torch.from_numpy(decompressed).float().view(compressed_data['shape'])
                decompressed_state[name] = decompressed
            else:  # Uncompressed tensor
                decompressed_state[name] = compressed_data

        return decompressed_state

    def apply_compressed_weights(self):
        decompressed_state = self.decompress()
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                param.copy_(decompressed_state[name])

def compress_model(model, compression_ratio=0.5):
    compressor = MemoryEfficientCompression(model, compression_ratio)
    compressed_state = compressor.compress()

    total_params = sum(p.numel() for p in model.parameters())
    compressed_params = sum(c['compressed'].numel() for c in compressed_state.values() if isinstance(c, dict))
    compressed_params += sum(c.numel() for c in compressed_state.values() if not isinstance(c, dict))

    print(f"Original parameters: {total_params}")
    print(f"Compressed parameters: {compressed_params}")
    print(f"Compression ratio: {compressed_params / total_params:.2f}")

    return compressor

In [6]:
# Example usage
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

compressor = compress_model(model, compression_ratio=0.5)

# To use the compressed model:
compressor.apply_compressed_weights()
print("Compression and decompression complete.")

Original parameters: 235146
Compressed parameters: 41404
Compression ratio: 0.18
Compression and decompression complete.
