In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from collections import Counter
import time
import numpy as np
import os
from itertools import product
import math
from datetime import datetime
from IPython.display import clear_output
from torch.linalg import norm

# Define the LeNet-5 architecture
class LeNet5(nn.Module):
    """
    LeNet-5 style CNN for single-channel 28x28 images producing 10 class scores.

    Layout: conv(1->6, 5x5) -> tanh -> avgpool(2) -> conv(6->16, 5x5) -> tanh
    -> avgpool(2) -> flatten (16*4*4) -> fc 120 -> tanh -> fc 84 -> tanh -> fc 10.
    A 28x28 input shrinks to 24 -> 12 -> 8 -> 4 pixels per side, hence 16*4*4
    features into fc1.
    """

    def __init__(self):
        super().__init__()
        # First convolutional layer: input 1 channel, output 6 channels, kernel size 5x5
        self.conv1 = nn.Conv2d(1, 6, 5)
        # Average pooling layer with kernel size 2x2 and stride 2 (shared by both stages)
        self.pool = nn.AvgPool2d(2, 2)
        # Second convolutional layer: input 6 channels, output 16 channels, kernel size 5x5
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Fully connected layers
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # Flattened input size 16*4*4, output 120
        self.fc2 = nn.Linear(120, 84)  # Output 84 neurons
        self.fc3 = nn.Linear(84, 10)  # Output 10 classes

    def forward(self, x):
        """
        Map a (batch, 1, 28, 28) tensor to (batch, 10) raw class scores (logits).
        """
        x = self.pool(torch.tanh(self.conv1(x)))
        x = self.pool(torch.tanh(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. Unlike view(-1, 256),
        # an input of the wrong spatial size now fails in fc1 instead of silently
        # being re-split into a different number of rows.
        x = torch.flatten(x, 1)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        # Output layer without activation (raw scores for classification)
        return self.fc3(x)

def compute_entropy(string):
    """
    Return the total Shannon information of a sequence, in bits.

    This is the per-symbol empirical entropy H = -sum(p * log2(p)) over the
    distinct elements, multiplied by the sequence length. Any iterable of
    hashable items works (the training loop passes a list of floats).
    An empty sequence yields 0.
    """
    n_symbols = len(string)
    bits_per_symbol = 0
    for count in Counter(string).values():
        p = count / n_symbols
        bits_per_symbol -= p * math.log2(p)
    return bits_per_symbol * n_symbols
            
def test_accuracy(model, dataloader, device):
    """
    Function to calculate the accuracy of a model on a given dataloader.
    """
    correct, total = 0, 0
    with torch.no_grad():  # Disable gradient computation for evaluation
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)  # Move data to the appropriate device
            outputs = model(images)  # Get model predictions
            _, predicted = torch.max(outputs.data, 1)  # Get the class with the highest probability
            total += labels.size(0)  # Update total number of samples
            correct += (predicted == labels).sum().item()  # Count correct predictions
    
    accuracy = 100 * correct / total  # Compute accuracy percentage
    return accuracy

def knapsack_specialized_pruning_parallel(xi, v, w, C, delta):
    """
    Solve the per-weight assignment subproblem for all M weights at once.

    For every weight w[m], a row x[m, :] of coefficients over the C levels `v`
    is built:
      * w[m] > v[-1]  -> x[m, -1] = 1
      * w[m] < v[0]   -> x[m, 0] = 1
      * otherwise     -> either a single coefficient clamp(round(w/v[i0], 5), 0, 1)
                         at level i0, or a two-level combination theta / (1 - theta)
                         at levels (i0, i1), with theta*v[i0] + (1-theta)*v[i1] = w[m].
    Levels i0 / i1 are picked among the breakpoints found in Step 1 whenever
    one qualifies.

    Args:
        xi: 1-D tensor with one value per level (length C). `delta` is subtracted
            from it before use. Presumably sorted ascending (callers sort it) —
            TODO confirm whether that is required.
        v: 1-D tensor of the C levels. The multiplier step uses v[1] - v[0] as the
           spacing between any two levels, i.e. v is assumed uniformly spaced.
           Step 1 divides by v[j] - v[b] for j > b, so v is presumably strictly
           increasing.
        w: 1-D tensor of the M weights.
        C: number of levels.
        delta: scalar shift subtracted from xi and added back to the objective.

    Returns:
        x: (M, C) coefficient tensor.
        lambda_opt: (M,) multipliers rounded to 5 decimals. They are 0 for rows
            whose number of nonzero coefficients (|x| > 1e-6) is neither 1 nor 2.
        objective_values: (M,) tensor equal to delta + x @ (xi - delta).

    NOTE(review): x, x_plus, x_mid and lambda_opt are created on the default
    device. Passing CUDA tensors for xi / v / w will therefore likely fail with a
    device-mismatch error. Rows that need two levels but have no second
    breakpoint (i1 == i0) are left all-zero in x.
    """

    xi = xi - delta

    M = w.shape[0]

    # Step 1: compute x_plus, the 0/1 indicator of the breakpoint levels.
    # Starting from level 0, repeatedly jump to the later level j that minimizes
    # the slope (xi[j] - xi[b]) / (v[j] - v[b]). This walks the lower convex hull
    # of the points (v[j], xi[j]). The last level C - 1 is always a breakpoint.
    b_list = []
    b = 0
    while True:
        delta_xi = (xi[b + 1:] - xi[b])
        delta_v = (v[b + 1:] - v[b])
        # Parsed as `(argmin(...) + 1 + b_list[-1]) if b_list else 0`.
        # On the first pass b_list is empty, so b stays 0 and the slopes just
        # computed are discarded. On later passes b_list[-1] == b, so the argmin
        # offset within xi[b + 1:] is turned back into an absolute level index.
        b = torch.argmin(delta_xi / delta_v) + 1 + b_list[-1] if b_list else 0
        if b != C - 1:
            b_list.append(int(b))
        if b + 1 > C - 1:
            break
    b_list.append(C - 1)
    x_plus = torch.zeros(C, dtype=torch.int32)
    x_plus[torch.tensor(b_list)] = 1

    # Preallocations (lambda_opt is re-created before the multiplier step below)
    x = torch.zeros(M, C)
    lambda_opt = torch.zeros(M)

    # Step 2: classify each weight's subproblem by where w falls relative to v
    v0 = v[0]
    v_last = v[-1]
    mask_small = w < v0
    mask_large = w > v_last
    mask_mid = (~mask_small) & (~mask_large)

    # CASE: w > v[-1] -> all mass on the last level
    x[mask_large, -1] = 1

    # CASE: w < v[0] -> all mass on the first level
    x[mask_small, 0] = 1

    # INTERMEDIATE CASE: v[0] <= w <= v[-1]
    if mask_mid.any():
        M_mid = mask_mid.sum() # number of weights with v[0] <= w <= v[-1]
        w_mid = w[mask_mid]
        # Order the levels: those with xi/v < 0 by decreasing ratio, then those
        # with xi/v >= 0 by increasing ratio. b_vector is that permutation of 0..C-1.
        ratio = xi / v
        neg_indices = torch.where(ratio < 0)[0]
        neg_sorted = neg_indices[torch.argsort(ratio[neg_indices], descending=True)]
        pos_indices = torch.where(ratio >= 0)[0]
        pos_sorted = pos_indices[torch.argsort(ratio[pos_indices])]
        b_vector = torch.cat([neg_sorted, pos_sorted], dim=0)
        # (M_mid, C) matrix; column k refers to level b_vector[k]
        ratio_b = w_mid[:, None] / v[b_vector]
        x_plus_b = x_plus[b_vector].bool()
        # Candidate level: same sign as the weight (w / v >= 0) and a breakpoint
        cond1 = (ratio_b >= 0) & x_plus_b
        # Pick the first candidate position in b_vector order. If a row has no
        # candidate, every entry is inf and argmin falls back to position 0.
        valid_i0 = cond1.float() * torch.arange(C)[None, :]
        valid_i0[~cond1] = float('inf')
        i0_pos = valid_i0.argmin(dim=1)
        i0 = b_vector[i0_pos]
        v_i0 = v[i0]
        x_single = w_mid / v_i0
        invalid_i0 = x_plus[i0] == 0  # NOTE(review): computed but never used below
        # Weights whose single coefficient w / v[i0] would exceed 1 are split over two levels
        use_two = x_single > 1
        i1 = torch.full_like(i0, fill_value=-1)
        if use_two.any():
            # i1 = smallest breakpoint level index greater than i0 (sentinel C if none)
            b_vector_exp = b_vector.unsqueeze(0).expand(M_mid, -1)
            i0_exp = i0.unsqueeze(1).expand_as(b_vector_exp)
            x_plus_mask = x_plus[b_vector_exp] == 1
            greater_mask = b_vector_exp > i0_exp
            valid_mask = x_plus_mask & greater_mask
            masked_b_vector = torch.where(valid_mask, b_vector_exp, torch.full_like(b_vector_exp, C))
            i1_candidate, _ = masked_b_vector.min(dim=1)
            i1_candidate_use_two = i1_candidate[use_two]
            valid_i1_use_two = i1_candidate_use_two < C
            i0_use_two = i0[use_two]
            # Rows without such a breakpoint fall back to i1 = i0
            i1[use_two] = torch.where(valid_i1_use_two, i1_candidate_use_two, i0_use_two)

        # Build x_mid (rows of x for the intermediate weights)
        x_mid = torch.zeros(M_mid, C)

        # Case: a single index is used; coefficient w / v[i0], rounded and clamped to [0, 1]
        mask_one = ~use_two
        rows_one = torch.where(mask_one)[0]
        cols_one = i0[mask_one]
        x_mid[rows_one, cols_one] = torch.clamp(torch.round(w_mid[mask_one] / v[cols_one], decimals=5), 0.0, 1.0)

        # Case: convex combination of levels i0 and i1 reproducing w exactly
        mask_two = use_two & (i1 != i0)
        rows_two = torch.where(mask_two)[0]
        idx0 = i0[mask_two]
        idx1 = i1[mask_two]
        v0 = v[idx0]
        v1 = v[idx1]
        w_sel = w_mid[mask_two]
        theta = (w_sel - v1) / (v0 - v1)
        x_mid[rows_two, idx0] = torch.round(theta, decimals=5)
        x_mid[rows_two, idx1] = torch.round(1 - theta, decimals=5)

        x[mask_mid] = x_mid
        
    # === Vectorized computation of the multipliers ===
    # An entry counts as nonzero when |x| > eps
    eps = 1e-6
    nz_mask = torch.abs(x) > eps
    nz_counts = nz_mask.sum(dim=1)
    lambda_opt = torch.zeros(x.shape[0])

    # Case: exactly 1 nonzero entry, at level i -> lambda = -xi[i] / v[i]
    m1 = torch.where(nz_counts == 1)[0]
    if m1.numel() > 0:
        submask = nz_mask[m1]
        indices = submask.nonzero(as_tuple=False) 
        i = indices[:, 1]
        lambda_opt[m1] = -xi[i] / v[i]
        lambda_opt[m1] = torch.round(lambda_opt[m1], decimals=5)

    # Case: exactly 2 nonzero entries, at levels i < j ->
    # lambda = -(xi[j] - xi[i]) / ((j - i) * (v[1] - v[0])), i.e. the negative
    # slope of xi between the two levels, assuming uniformly spaced levels
    m2 = torch.where(nz_counts == 2)[0]
    if m2.numel() > 0:
        # nonzero() lists entries in row-major order, so consecutive pairs belong
        # to the same row; column 1 of each entry is the level index
        indices = nz_mask[m2].nonzero().reshape(-1, 2)
        grouped = indices.view(-1, 2, 2)
        i = grouped[:, 0, 1]
        j = grouped[:, 1, 1]
        delta_xi = xi[j] - xi[i]
        delta_idx = j - i
        passo = v[1] - v[0]  # level spacing ("passo" = step)
        lambda_opt[m2] = -delta_xi / (delta_idx * passo)
        lambda_opt[m2] = torch.round(lambda_opt[m2], decimals=5)

    # xi here is the shifted value (xi - delta) from the top of the function
    objective_values = delta + x @ xi
    
    return x, lambda_opt, objective_values

def FISTA(xi, v, w, C, subgradient_step, delta, max_iterations):
    """
    Run a FISTA-style accelerated super-gradient ascent on the level values xi.

    Each iteration:
      1. solves the knapsack subproblem at the current xi;
      2. forms the super-gradient g = sum_m x[m, :] - c_star, with
         c_star = clamp(2**xi / e, 0, M) and M = len(w);
      3. takes a momentum-extrapolated step y + g / subgradient_step;
      4. re-sorts xi.

    Args:
        xi (torch.Tensor): 1-D tensor of C values, one per level (starting point).
        v (torch.Tensor): 1-D tensor of the C quantization levels.
        w (torch.Tensor): 1-D tensor of the weights to assign to levels.
        C (int): number of quantization levels (length of xi and v).
        subgradient_step (float): inverse step size; the step added is g / subgradient_step.
        delta (float): offset forwarded to the knapsack subproblem.
        max_iterations (int): number of iterations to run (must be >= 1).

    Returns:
        tuple: (xi, lambda_plus, x_i_star, phi).
            xi is the updated, sorted vector. lambda_plus (per-weight multipliers),
            x_i_star (assignment matrix) and phi (objective value) come from the
            last iteration's subproblem, i.e. they are evaluated at the xi *before*
            the final update.

    Raises:
        ValueError: if max_iterations < 1 (nothing would be computed to return).
    """
    if max_iterations < 1:
        raise ValueError(f"max_iterations must be >= 1, got {max_iterations}")

    upper_c = w.size(0)  # a level cannot hold more than M = len(w) weights
    log2 = torch.log(torch.tensor(2))

    # State for the FISTA momentum (restarted on every call)
    xi_prev = xi.clone()
    t_prev = torch.tensor(1.0)

    for _ in range(max_iterations):
        # Solve the knapsack-like subproblem for the current xi
        x_i_star, lambda_plus, _phi_plus = knapsack_specialized_pruning_parallel(xi, v, w, C, delta)
        sum_x_star = torch.sum(x_i_star, dim=0)  # per-level usage

        # Optimal c values: exp(ln 2 * xi - 1) = 2**xi / e, capped to [0, M]
        c_star = torch.clamp(torch.exp(log2 * xi - 1), min=0, max=upper_c)

        # Super-gradient
        # NOTE(review): g is evaluated at xi, whereas textbook FISTA evaluates the
        # gradient at the extrapolated point y — confirm this is intended.
        g = -(c_star - sum_x_star)

        # Objective: sum c*log2(c) - xi.c + xi.sum_x
        phi1 = torch.sum(c_star * torch.log(c_star) / log2)
        phi2 = -torch.sum(xi * c_star)
        phi3 = torch.sum(xi * sum_x_star)
        phi = phi1 + phi2 + phi3

        # Momentum extrapolation followed by the ascent step
        t_current = (1 + torch.sqrt(1 + 4 * t_prev**2)) / 2
        y = xi + ((t_prev - 1) / t_current) * (xi - xi_prev)
        xi_next = y + (1 / subgradient_step) * g

        xi_prev = xi.clone()
        t_prev = t_current
        # Keep xi sorted ascending
        xi = torch.sort(xi_next)[0]

    return xi, lambda_plus, x_i_star, phi

def ProximalBM(xi, v, w, C, zeta, subgradient_step, delta, max_iterations):
    """
    Bundle-style iterative update of the level values xi.

    Each iteration:
      1. solves the knapsack subproblem at the current xi and computes the
         super-gradient g and objective phi;
      2. stores (xi, phi, g) in a bundle holding at most the 5 most recent points;
      3. scores every bundle point by its linearization at xi plus
         (zeta / 2) * ||xi - point||^2;
      4. steps from the highest-scoring point along its super-gradient with
         length 1 / zeta, clipping the result to [0.01, M] (M = len(w)).
    Iteration stops early once the proposed xi moves by less than 1e-5 in norm.

    Args:
        xi (torch.Tensor): 1-D tensor of C values, one per level (starting point).
        v (torch.Tensor): 1-D tensor of the C quantization levels.
        w (torch.Tensor): 1-D tensor of the weights to assign to levels.
        C (int): number of quantization levels.
        zeta (float): proximal weight; also sets the step length 1 / zeta.
        subgradient_step (float): unused here; accepted so the signature mirrors FISTA.
        delta (float): offset forwarded to the knapsack subproblem.
        max_iterations (int): maximum number of iterations (must be >= 1).

    Returns:
        tuple: (xi, lambda_plus, x_i_star, phi).
            lambda_plus, x_i_star and phi come from the last subproblem solved.

    Raises:
        ValueError: if max_iterations < 1 (nothing would be computed to return).
    """
    if max_iterations < 1:
        raise ValueError(f"max_iterations must be >= 1, got {max_iterations}")

    upper_c = w.size(0)  # a level cannot hold more than M = len(w) weights
    log2 = torch.log(torch.tensor(2))

    # Parameters for the bundle method
    epsilon = 1e-5  # stop when the proposed step is shorter than this
    bundle_size = 5  # maximum number of stored points
    bundle = []  # (xi, phi, g) triples, oldest first

    for _ in range(max_iterations):
        # Solve the knapsack problem for the current xi
        x_i_star, lambda_plus, _phi_plus = knapsack_specialized_pruning_parallel(xi, v, w, C, delta)
        sum_x_star = torch.sum(x_i_star, dim=0)  # per-level usage

        # Optimal c values: exp(ln 2 * xi - 1) = 2**xi / e, capped to [0, M]
        c_star = torch.clamp(torch.exp(log2 * xi - 1), min=0, max=upper_c)

        # Super-gradient
        g = -(c_star - sum_x_star)

        # Objective: sum c*log2(c) - xi.c + xi.sum_x
        phi1 = torch.sum(c_star * torch.log(c_star) / log2)
        phi2 = -torch.sum(xi * c_star)
        phi3 = torch.sum(xi * sum_x_star)
        phi = phi1 + phi2 + phi3

        # Add the current point to the bundle, evicting the oldest when full
        bundle.append((xi.clone(), phi, g.clone()))
        if len(bundle) > bundle_size:
            bundle.pop(0)

        bundle_points = torch.stack([item[0] for item in bundle])
        bundle_phis = torch.stack([item[1] for item in bundle])  # phi entries are 0-d tensors
        bundle_gradients = torch.stack([item[2] for item in bundle])

        # Linear model of each bundle point evaluated at xi, plus the proximal term
        diff = xi - bundle_points
        model_phi = bundle_phis + torch.sum(bundle_gradients * diff, dim=1)
        proximal_term = (zeta / 2) * norm(diff, dim=1)**2
        subproblem_objective = model_phi + proximal_term

        # NOTE(review): this selects the bundle point with the LARGEST score. With the
        # proximal term added rather than subtracted, this favours points far from
        # xi — confirm the intended sign.
        best_idx = torch.argmax(subproblem_objective)
        xi_next = bundle_points[best_idx] + (1 / zeta) * bundle_gradients[best_idx]

        # Clip xi to enforce constraints
        xi_next = torch.clamp(xi_next, min=0.01, max=upper_c)

        # Check for convergence
        if norm(xi_next - xi) < epsilon:
            break

        xi = xi_next.clone()

    return xi, lambda_plus, x_i_star, phi

def initialize_weights(model, min_w, max_w):
    """
    Re-initialize every parameter of `model` in place from U(min_w, max_w).

    All tensors returned by model.parameters() are overwritten, biases included,
    in model.parameters() order.

    Args:
        model (torch.nn.Module): network whose parameters are overwritten.
        min_w (float): lower bound of the uniform distribution.
        max_w (float): upper bound of the uniform distribution.

    Returns:
        None
    """
    fill_uniform = torch.nn.init.uniform_
    params = list(model.parameters())
    for tensor in params:
        fill_uniform(tensor, min_w, max_w)

def train_and_evaluate(C, lr, lambda_reg, alpha, subgradient_step, w0, r, 
                       target_acc, target_entr, min_xi, max_xi, n_epochs, device, 
                       train_optimizer, entropy_optimizer, trainloader, testloader, delta):
    """
    Train a LeNet5 with cross-entropy plus an entropy-oriented regularizer,
    evaluating accuracy and weight entropy after every epoch.

    The regularization strength lambda_reg is split by alpha:
      * alpha * lambda_reg is the optimizer's weight decay;
      * (1 - alpha) * lambda_reg scales the per-weight multipliers returned by
        the entropy optimizer (FISTA or ProximalBM), which are added to the gradients.

    Args:
        C: number of quantization levels.
        lr: learning rate.
        lambda_reg: total regularization strength.
        alpha: fraction of lambda_reg applied as weight decay.
        subgradient_step: step parameter forwarded to the entropy optimizer.
        w0, r: parameters are initialized uniformly in [w0 - r, w0 + r]. The C
            levels are evenly spaced from w0 - r with spacing 2r / C.
        target_acc, target_entr: a checkpoint is saved whenever an epoch reaches
            accuracy >= target_acc and entropy <= target_entr. The targets then
            tighten to the values achieved.
        min_xi, max_xi: range of the random initialization of xi.
        n_epochs: maximum number of epochs.
        device: device for the model and the data batches.
        train_optimizer: 'A' (Adam) or 'S' (SGD with momentum 0.9).
        entropy_optimizer: 'F' (FISTA) or 'PM' (proximal bundle method).
        trainloader, testloader: iterables of (inputs, labels) batches.
        delta: offset forwarded to the knapsack subproblem.

    Returns:
        tuple: (accuracy, entropy, target_acc, target_entr) of the last evaluated
        epoch, together with the possibly tightened targets.

    Raises:
        ValueError: for an unknown train_optimizer or entropy_optimizer code.
    """
    # Validate option codes up front. Otherwise `optimizer` / `beta_tensor` would be
    # undefined and training would die with a NameError mid-run.
    if train_optimizer not in ('A', 'S'):
        raise ValueError(f"train_optimizer must be 'A' (Adam) or 'S' (SGD), got {train_optimizer!r}")
    if entropy_optimizer not in ('F', 'PM'):
        raise ValueError(f"entropy_optimizer must be 'F' (FISTA) or 'PM' (proximal bundle), got {entropy_optimizer!r}")

    model = LeNet5().to(device)
    criterion = nn.CrossEntropyLoss()

    weight_decay = lambda_reg * alpha  # standard (L2) share of the regularization
    if train_optimizer == 'A':
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

    # Parameters initialization
    min_w, max_w = w0 - r, w0 + r
    v = torch.linspace(min_w, max_w - (max_w - min_w)/C, steps=C)
    initialize_weights(model, min_w, max_w)
    # knapsack_specialized_pruning_parallel allocates its outputs on the default
    # (CPU) device, and v above is on the CPU too. The entropy subproblem therefore
    # runs entirely on the CPU (xi and the weight snapshot included). Only the
    # resulting multipliers are moved to `device`; mixing a CUDA xi/w with those
    # CPU tensors raises a device-mismatch error.
    xi = min_xi + (max_xi - min_xi) * torch.rand(C)
    xi = torch.sort(xi)[0]
    entropy, accuracy = 0, 0
    accuracies, entropies = [], []
    zeta, l = 50000, 0.5  # proximal weight for 'PM' and its per-batch growth factor

    for epoch in range(n_epochs):
        start_time = time.time()

        for batch in trainloader:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # CPU snapshot of all parameters, flattened in model.parameters() order
            w = torch.cat([param.detach().reshape(-1) for param in model.parameters()]).cpu()

            zeta *= 1 + l
            l = l / 1.5
            if entropy_optimizer == 'F':
                xi, beta_tensor, _, _ = FISTA(xi, v, w, C, subgradient_step, delta, max_iterations=15)
            else:  # 'PM' (validated above)
                xi, beta_tensor, _, _ = ProximalBM(xi, v, w, C, zeta, subgradient_step, delta, max_iterations=15)

            # Update of ∇ɸ: gradient = ∇loss + (1 - alpha) * lambda_reg * beta.
            # beta holds one multiplier per weight, in the same flattened order as w.
            loss.backward()
            entropic_grad = ((1 - alpha) * lambda_reg * beta_tensor).to(device)
            offset = 0
            for param in model.parameters():
                numel = param.numel()
                extra = entropic_grad[offset:offset + numel].view_as(param)
                if param.grad is None:
                    param.grad = extra.clone()
                else:
                    param.grad.add_(extra)
                offset += numel
            optimizer.step()

        w = torch.cat([param.detach().reshape(-1) for param in model.parameters()])

        entropy = round(compute_entropy(w.tolist())) + 1
        entropies.append(entropy)
        accuracy = test_accuracy(model, testloader, device)
        accuracies.append(accuracy)
        
        print(f"C={C}, lr={lr}, lambda_reg={lambda_reg}, "
              f"alpha={alpha}, subgradient_step={subgradient_step}, w0={w0}, r={r}, "
              f"target_acc={target_acc}, target_entr={target_entr}, "
              f"min_xi={min_xi}, max_xi={max_xi}, n_epochs={n_epochs}, train_optimizer={train_optimizer} "
              f"entropy_optimizer={entropy_optimizer}")
        print("\nEpoch:", epoch+1)
        print("\nAccuracies:", accuracies)
        print("\nEntropies:", entropies)
        print("\nMax Accuracy:", max(accuracies))
        print("Min entropy:", min(entropies))

        # Save a model that beats both targets, then tighten the targets
        if accuracy >= target_acc and entropy <= target_entr:
            print("💥💥💥💥💥💥💥\n💥ATTENTION!💥\n💥💥💥💥💥💥💥")
            save_dir = "BestModelsBeforeQuantization"
            os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories
            torch.save(model.state_dict(), f"{save_dir}/C{C}_r{round(r*1000)}.pth")
            target_acc = accuracy
            target_entr = entropy
        
        print("-"*60)
        
        # Entropy exit conditions
        if epoch > 20 and entropy > 600000:
            print("Entropy is not decreasing enough! (A)")
            return accuracy, entropy, target_acc, target_entr
        if epoch > 50 and all(e > 200000 for e in entropies[-4:]):
            print("Entropy is not decreasing enough! (B)")
            return accuracy, entropy, target_acc, target_entr

        # Accuracy exit conditions
        if epoch == 1 and accuracies[-1] < 70:
            print("Accuracy is too low! (C)")
            return accuracy, entropy, target_acc, target_entr
        if epoch > 10 and all(a < 90 for a in accuracies[-4:]):
            print("Accuracy is too low! (D)")
            return accuracy, entropy, target_acc, target_entr

        training_time = time.time() - start_time
        print(f"Time taken for a epoch: {training_time:.2f} seconds\n")
              
    return accuracy, entropy, target_acc, target_entr

# Select the computing device: use GPU if available, otherwise fallback to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch is re-seeded with this value before every grid combination, so each run
# (xi/weight initialization, shuffling) is reproducible on its own
SEED = 0
# Define a transformation: convert images to tensors
transform = transforms.Compose([transforms.ToTensor()])
# Load the MNIST training dataset with the defined transformation
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# DataLoader for the training set: batch size 64, shuffling enabled, 4 worker processes
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
# Load the MNIST test dataset with the same transformation
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# DataLoader for the test set: batch size 1000, no shuffling, 4 worker processes
testloader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=4)

np.set_printoptions(precision=6)

# Grid search. Keys must match train_and_evaluate's parameter names exactly,
# because each combination is passed as keyword arguments.
param_grid = {
    "C": [6, 256],  # Number of buckets of quantization
    "lr": [0.0007], # Learning rate for the optimizer
    "lambda_reg": [0.0015], # Regularization factor
    "alpha": [0.533], # Percentage of standard regularization wrt entropic one 
    "subgradient_step": [1e5],  # Step size for subgradient
    "w0": [-0.11], # Initial weight parameters
    "r": [1.1],
    "target_acc": [98.99], # Target accuracy percentage
    "target_entr": [0.99602e6], # Target entropy threshold 
    "min_xi": [0], # lower bound for xi initialization
    "max_xi": [1],  # upper bound for xi initialization
    "n_epochs": [100], # Number of training epochs
    "device": [device], # Computing device (GPU or CPU)
    "train_optimizer": ['A'],  # 'A' for Adam, and 'S' for SGD
    "entropy_optimizer": ['F'], # 'F' for FISTA, 'PM' for proximal bundle
    "trainloader": [trainloader],  # Training data loader
    "testloader": [testloader], # Test data loader
    "delta": [0.1, 0.2, 0.3, 0.4]
}

# Tightest (accuracy, entropy) targets reached so far for each checkpoint file.
# train_and_evaluate names its checkpoint only by C and round(r * 1000). Without
# carrying the tightened targets forward, a later combination (e.g. another delta)
# could overwrite a better checkpoint with one that merely meets the initial targets.
best_targets = {}

for combination, values in enumerate(product(*param_grid.values()), start=1):
    params = dict(zip(param_grid, values))
    checkpoint_key = (params["C"], round(params["r"] * 1000))
    if checkpoint_key in best_targets:
        best_acc, best_entr = best_targets[checkpoint_key]
        params["target_acc"] = max(params["target_acc"], best_acc)
        params["target_entr"] = min(params["target_entr"], best_entr)

    torch.manual_seed(SEED)

    # Start training
    start_time = time.time()
    accuracy, entropy, target_acc, target_entr = train_and_evaluate(**params)
    best_targets[checkpoint_key] = (target_acc, target_entr)

    training_time = time.time() - start_time
    print(f'Time spent to train the model: {training_time:.2f} seconds\n')

C=6, lr=0.0007, lambda_reg=0.0015, alpha=0.533, subgradient_step=100000.0, w0=-0.11, r=1.1, target_acc=98.99, target_entr=996020.0, min_xi=0, max_xi=1, n_epochs=100, train_optimizer=A entropy_optimizer=F

Epoch: 1

Accuracies: [56.05]

Entropies: [685823]

Max Accuracy: 56.05
Min entropy: 685823
------------------------------------------------------------
Time taken for a epoch: 46.83 seconds



KeyboardInterrupt: 

In [2]:
# Presumably a preview of candidate values for the `r` grid parameter:
# 10 values starting at 1.1 in steps of 0.002. round(..., 3) strips
# floating-point representation noise from the displayed values.
[round(1.1 + i * 0.002, 3) for i in range(10)]

[1.1, 1.102, 1.104, 1.106, 1.108, 1.11, 1.112, 1.114, 1.116, 1.118]