In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


import numpy as np
import matplotlib.pyplot as plt
import copy
from ast import Param

from utils.prune import prune_weights
from utils.count_improvement import improvements

In [2]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
        



In [11]:
net = Net().to(device)
varianzas_net = Net().to(device)

# Extraer los diccionarios de estado (pesos y biases)
state_dict_red1 = net.state_dict()
state_dict_red2 = varianzas_net.state_dict()

# Crear un nuevo diccionario de estado donde sumamos los pesos
state_dict_suma = {}
for key in state_dict_red1:
    state_dict_suma[key] = state_dict_red1[key] + state_dict_red2[key]

# Crear una nueva red o modificar una existente con los pesos sumados
red3 = Net().to(device)
red3.load_state_dict(state_dict_suma)

print(red3.state_dict())


OrderedDict([('fc1.weight', tensor([[-0.0630,  0.0289, -0.0100,  ..., -0.0022, -0.0578,  0.0271],
        [-0.0269,  0.0447, -0.0009,  ..., -0.0132, -0.0047, -0.0447],
        [-0.0237,  0.0021,  0.0435,  ...,  0.0304,  0.0384, -0.0447],
        ...,
        [ 0.0027,  0.0225,  0.0119,  ...,  0.0070, -0.0048, -0.0374],
        [ 0.0326,  0.0086, -0.0353,  ..., -0.0260,  0.0090, -0.0086],
        [-0.0257, -0.0326, -0.0043,  ...,  0.0217, -0.0585,  0.0570]],
       device='cuda:0')), ('fc1.bias', tensor([ 1.6353e-02, -2.1737e-02,  1.0240e-02, -2.5188e-02, -3.3184e-02,
         3.1251e-03, -3.7318e-02, -2.9135e-02, -3.3493e-02, -5.7243e-03,
        -3.9931e-02,  1.0838e-02, -2.6046e-02,  3.5433e-02, -1.0883e-03,
         5.0756e-03, -5.4094e-03,  1.9611e-02,  3.4448e-03, -9.7918e-03,
         4.3054e-02, -2.1449e-02,  2.1230e-03, -4.0123e-02,  2.7753e-02,
        -1.4180e-02, -1.4833e-02, -9.3382e-03, -2.6001e-02, -3.9731e-03,
        -4.6212e-03, -5.5549e-03, -2.5561e-02, -3.3738e-02, -

In [12]:
print(net.state_dict())

OrderedDict([('fc1.weight', tensor([[-0.0293,  0.0172, -0.0075,  ...,  0.0057, -0.0331, -0.0066],
        [-0.0291,  0.0268,  0.0014,  ..., -0.0108,  0.0194, -0.0259],
        [-0.0005,  0.0163,  0.0255,  ...,  0.0258,  0.0292, -0.0332],
        ...,
        [-0.0226,  0.0129, -0.0102,  ..., -0.0075, -0.0148, -0.0327],
        [ 0.0136,  0.0265, -0.0100,  ..., -0.0154,  0.0184,  0.0021],
        [ 0.0028, -0.0006, -0.0122,  ..., -0.0098, -0.0243,  0.0236]],
       device='cuda:0')), ('fc1.bias', tensor([ 0.0344, -0.0039, -0.0114, -0.0172, -0.0211, -0.0276, -0.0354, -0.0004,
        -0.0015, -0.0166, -0.0116, -0.0093, -0.0121,  0.0113,  0.0253,  0.0234,
         0.0114, -0.0083,  0.0298,  0.0020,  0.0305, -0.0308, -0.0148, -0.0189,
         0.0158,  0.0118, -0.0335,  0.0124,  0.0091, -0.0154, -0.0039,  0.0160,
         0.0059, -0.0298, -0.0037, -0.0031, -0.0248, -0.0284, -0.0223,  0.0130,
        -0.0342,  0.0019, -0.0222,  0.0047, -0.0260, -0.0286,  0.0043, -0.0319,
         0.0255, -0

In [13]:
print(varianzas_net.state_dict())

OrderedDict([('fc1.weight', tensor([[-0.0337,  0.0117, -0.0025,  ..., -0.0079, -0.0247,  0.0338],
        [ 0.0021,  0.0179, -0.0023,  ..., -0.0024, -0.0240, -0.0188],
        [-0.0232, -0.0142,  0.0179,  ...,  0.0046,  0.0092, -0.0115],
        ...,
        [ 0.0253,  0.0096,  0.0221,  ...,  0.0144,  0.0101, -0.0047],
        [ 0.0190, -0.0179, -0.0253,  ..., -0.0106, -0.0094, -0.0107],
        [-0.0285, -0.0320,  0.0079,  ...,  0.0315, -0.0342,  0.0334]],
       device='cuda:0')), ('fc1.bias', tensor([-1.8049e-02, -1.7880e-02,  2.1617e-02, -7.9563e-03, -1.2093e-02,
         3.0767e-02, -1.8917e-03, -2.8758e-02, -3.1957e-02,  1.0830e-02,
        -2.8293e-02,  2.0118e-02, -1.3978e-02,  2.4127e-02, -2.6401e-02,
        -1.8338e-02, -1.6814e-02,  2.7943e-02, -2.6335e-02, -1.1781e-02,
         1.2599e-02,  9.3972e-03,  1.6908e-02, -2.1192e-02,  1.1978e-02,
        -2.5978e-02,  1.8632e-02, -2.1782e-02, -3.5125e-02,  1.1440e-02,
        -6.7544e-04, -2.1564e-02, -3.1449e-02, -3.9266e-03, -

In [14]:
with torch.no_grad():
    for param in varianzas_net.parameters():
        param *= 0.82

In [15]:
print(varianzas_net.state_dict())

OrderedDict([('fc1.weight', tensor([[-0.0276,  0.0096, -0.0020,  ..., -0.0065, -0.0203,  0.0277],
        [ 0.0018,  0.0147, -0.0019,  ..., -0.0020, -0.0197, -0.0154],
        [-0.0190, -0.0117,  0.0147,  ...,  0.0038,  0.0076, -0.0094],
        ...,
        [ 0.0207,  0.0079,  0.0181,  ...,  0.0118,  0.0082, -0.0038],
        [ 0.0156, -0.0147, -0.0208,  ..., -0.0087, -0.0077, -0.0088],
        [-0.0233, -0.0262,  0.0065,  ...,  0.0259, -0.0280,  0.0274]],
       device='cuda:0')), ('fc1.bias', tensor([-1.4800e-02, -1.4662e-02,  1.7726e-02, -6.5241e-03, -9.9165e-03,
         2.5229e-02, -1.5512e-03, -2.3582e-02, -2.6205e-02,  8.8803e-03,
        -2.3200e-02,  1.6497e-02, -1.1462e-02,  1.9784e-02, -2.1649e-02,
        -1.5037e-02, -1.3788e-02,  2.2913e-02, -2.1595e-02, -9.6602e-03,
         1.0331e-02,  7.7057e-03,  1.3865e-02, -1.7377e-02,  9.8221e-03,
        -2.1302e-02,  1.5278e-02, -1.7861e-02, -2.8802e-02,  9.3810e-03,
        -5.5386e-04, -1.7682e-02, -2.5788e-02, -3.2198e-03, -

Definimos una red y le copiamos los pesos en una lista

In [6]:
net = Net().to(device)
varianzas_net = Net().to(device)

varianzas = []
for param in varianzas_net.parameters():
    varianzas.extend(param.data.clone().flatten().tolist())
print(varianzas_net.state_dict())


OrderedDict([('fc1.weight', tensor([[ 0.0183,  0.0323,  0.0195,  ..., -0.0197, -0.0136,  0.0116],
        [-0.0203,  0.0343,  0.0126,  ..., -0.0135,  0.0137, -0.0252],
        [ 0.0218, -0.0224, -0.0071,  ..., -0.0172,  0.0285, -0.0032],
        ...,
        [ 0.0176, -0.0034, -0.0225,  ...,  0.0046, -0.0051, -0.0052],
        [ 0.0082, -0.0039,  0.0242,  ..., -0.0116, -0.0223,  0.0338],
        [ 0.0210,  0.0131,  0.0139,  ...,  0.0256,  0.0317, -0.0064]],
       device='cuda:0')), ('fc1.bias', tensor([ 1.9351e-02, -3.1154e-02, -2.7616e-02, -2.2990e-03,  3.2229e-02,
        -2.8787e-02, -9.8190e-03, -2.9529e-04, -2.3027e-02, -2.0231e-02,
        -1.8741e-02, -1.1691e-02,  5.1121e-03,  3.8970e-03,  2.8642e-02,
         3.8653e-03,  1.6965e-03,  8.7548e-03,  3.4946e-02, -1.6820e-03,
        -6.3705e-03,  2.2441e-02,  9.5980e-03, -1.3547e-02, -1.9347e-02,
         1.0524e-02,  7.2295e-03, -2.8469e-02, -3.2139e-02,  3.2589e-03,
         2.0906e-02, -7.5973e-03, -2.9912e-02, -3.0602e-03,  

In [7]:
print(len(varianzas))
print(varianzas[0])

235146
0.01834365725517273


In [8]:
# Calculate the number of weights
num_weights = sum(p.numel() for p in varianzas_net.parameters() if p.requires_grad)
print(f"Number of weights: {num_weights}")

Number of weights: 235146


Cargamos los datos

In [9]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [10]:
# Definimos la función de pérdida para calcular el error
criterion = nn.CrossEntropyLoss()

In [11]:
# train loop
train_loss = []
test_accuracies = []
epochs = 10
accuracy_threshold = 0.6

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    if epoch + 1 == 1:
        pruned_net = prune_weights(net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()
        
        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break
    
    else:
        # Crear una copia de la red neuronal original
        varied_net = copy.deepcopy(net)
        print(varied_net.state_dict())

        # Actualizar los pesos de la red neuronal copiada
        with torch.no_grad():  # Desactiva el cálculo del gradiente para evitar que se almacenen gradientes innecesarios
            for param, varianza in zip(varied_net.parameters(), varianzas):
                param += varianza  # Sumar varianza a cada peso

        print(varied_net.state_dict())

        pruned_net = prune_weights(varied_net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()
        
        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break

        # actualizmos el vector de varianzas
        if improvements(train_loss) == 0:## +mejoras que peoras
            varianzas = [varianza * (1/0.82) for varianza in varianzas]
        
        elif improvements(train_loss) == 1: ## -mejoras que peoras
            varianzas = [varianza * 0.82 for varianza in varianzas]

        



Epoch 1
Train loss: 2.3013094569574286
Test accuracy: 0.1112
Epoch 2
OrderedDict([('fc1.weight', tensor([[-0.0243, -0.0327, -0.0067,  ..., -0.0165,  0.0229,  0.0104],
        [-0.0287,  0.0047, -0.0076,  ...,  0.0352,  0.0335,  0.0347],
        [ 0.0283,  0.0193,  0.0223,  ..., -0.0224, -0.0089, -0.0021],
        ...,
        [-0.0124, -0.0337, -0.0138,  ..., -0.0259,  0.0286, -0.0274],
        [ 0.0281,  0.0156, -0.0259,  ..., -0.0348,  0.0324,  0.0122],
        [-0.0151, -0.0169,  0.0270,  ...,  0.0207,  0.0337,  0.0083]],
       device='cuda:0')), ('fc1.bias', tensor([-1.7072e-02,  1.5967e-02,  6.9215e-03, -1.3616e-03,  7.5412e-03,
        -8.4647e-03, -3.5553e-02, -2.7519e-02, -4.8625e-03,  1.4957e-02,
         7.5889e-03,  2.2177e-05,  3.4757e-02, -1.4625e-02,  1.4038e-02,
        -3.0996e-02, -2.8554e-02, -3.4068e-02, -1.7054e-02, -1.0238e-02,
         2.0614e-02, -2.8742e-02, -3.0275e-02, -9.5916e-03,  1.2566e-02,
        -1.8234e-04,  2.6475e-03, -9.8234e-03, -3.3385e-02, -1.17

KeyboardInterrupt: 

In [12]:
print(test_accuracies)

[0.1112, 0.1116, 0.1639, 0.1116]
