In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


import numpy as np
import matplotlib.pyplot as plt
import copy
from ast import Param

from utils.prune import prune_weights
from utils.count_improvement import improvements

In [2]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Definimos las redes

In [4]:
net = Net().to(device)
varianzas_net = Net().to(device)

In [5]:
# Sumamos los pesos de las redes en una red nueva
# Extraer los diccionarios de estado (pesos y biases)
state_dict_red1 = net.state_dict()
state_dict_red2 = varianzas_net.state_dict()

print(state_dict_red1)

OrderedDict([('fc1.weight', tensor([[ 0.0329,  0.0145,  0.0162,  ..., -0.0349,  0.0297, -0.0134],
        [-0.0265,  0.0115, -0.0213,  ..., -0.0310, -0.0078, -0.0270],
        [-0.0020, -0.0248, -0.0323,  ...,  0.0115,  0.0282, -0.0235],
        ...,
        [-0.0202,  0.0118, -0.0123,  ..., -0.0182, -0.0240,  0.0049],
        [-0.0201, -0.0269,  0.0354,  ...,  0.0032, -0.0330, -0.0303],
        [-0.0294,  0.0118,  0.0177,  ...,  0.0267,  0.0110,  0.0111]],
       device='cuda:0')), ('fc1.bias', tensor([ 2.0693e-03,  3.3079e-02,  1.7259e-02, -1.0592e-02, -1.8672e-02,
        -2.0285e-02, -1.9722e-02, -1.0663e-02,  3.0717e-02,  2.1539e-02,
        -2.0625e-02, -2.4775e-02, -2.2755e-02, -1.7345e-02,  2.4999e-02,
        -6.6528e-03, -2.5787e-03, -2.2601e-02, -9.1202e-03,  2.6998e-02,
        -8.6464e-03,  1.0482e-02, -3.9921e-03,  2.7064e-02,  1.5890e-02,
         2.8549e-02, -3.1854e-02,  1.3721e-02,  1.2319e-02,  1.6060e-02,
        -3.6185e-03, -4.4872e-03,  2.8199e-03, -7.3405e-03, -

In [6]:
print(state_dict_red2)

OrderedDict([('fc1.weight', tensor([[-0.0202, -0.0236, -0.0277,  ..., -0.0025, -0.0215, -0.0151],
        [ 0.0003,  0.0086,  0.0023,  ...,  0.0082,  0.0102, -0.0076],
        [ 0.0143, -0.0035,  0.0010,  ...,  0.0330, -0.0228, -0.0071],
        ...,
        [ 0.0327, -0.0143,  0.0168,  ..., -0.0109, -0.0060, -0.0102],
        [-0.0205, -0.0244,  0.0044,  ...,  0.0333, -0.0039,  0.0043],
        [-0.0185,  0.0199, -0.0327,  ..., -0.0178,  0.0232,  0.0042]],
       device='cuda:0')), ('fc1.bias', tensor([-2.9930e-02, -3.0819e-02,  1.7896e-02, -2.6291e-02, -3.3947e-02,
        -6.9935e-05, -1.8060e-02, -2.9968e-02, -1.1740e-02,  3.4875e-03,
        -2.3601e-02, -3.2802e-02, -8.6999e-03,  3.6323e-03, -1.6533e-02,
        -2.4026e-02, -1.7625e-02,  2.8350e-02,  6.6661e-03, -2.7433e-02,
         7.6569e-03,  3.5379e-02,  1.8759e-02,  1.8571e-02, -6.9187e-03,
         1.4507e-02, -8.1061e-03,  3.1394e-02, -1.3479e-02,  2.5988e-02,
        -3.5602e-02, -2.0645e-02, -7.6443e-03, -1.1035e-03,  

In [12]:
# Crear un nuevo diccionario de estado donde sumamos los pesos
state_dict_suma = {}
for key in state_dict_red1:
    if state_dict_red1[key].size() == state_dict_red2[key].size():  # Asegurar que las dimensiones coincidan
        # Asegurar que las desviaciones estándar sean positivas para generar el ruido
        std_dev = torch.abs(state_dict_red2[key])
        
        # Generamos los valores aleatorios con una distribución normal usando torch.normal
        noise = torch.normal(0, std_dev)  # Media = 0, Desviación estándar = std_dev
        
        # Crear una máscara para determinar si debemos sumar o restar
        mask_negativa = state_dict_red2[key] < 0  # Máscara de valores negativos
        print(mask_negativa)
        
        # Aplicar la operación de suma o resta dependiendo de la máscara
        state_dict_suma[key] = torch.where(mask_negativa, state_dict_red1[key] - noise, state_dict_red1[key] + noise)
    else:
        # Si los tamaños no coinciden, copiamos directamente
        state_dict_suma[key] = state_dict_red1[key]

# Crear una nueva red o modificar una existente con los pesos sumados
varied_net = Net().to(device)
varied_net.load_state_dict(state_dict_suma)


tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False, False, False,  ..., False, False,  True],
        [False,  True, False,  ..., False,  True,  True],
        ...,
        [False,  True, False,  ...,  True,  True,  True],
        [ True,  True, False,  ..., False,  True, False],
        [ True, False,  True,  ...,  True, False, False]], device='cuda:0')
tensor([ True,  True, False,  True,  True,  True,  True,  True,  True, False,
         True,  True,  True, False,  True,  True,  True, False, False,  True,
        False, False, False, False,  True, False,  True, False,  True, False,
         True,  True,  True,  True, False, False,  True, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False,  True, False, False, False,  True, False,  True, False, False,
        False, False,  True, False,  True,  True, False,  True, False, False,
         True,  True,  True, False,  True, False,  True,  True, False,  Tru

<All keys matched successfully>

In [13]:
print(varied_net.state_dict())

OrderedDict([('fc1.weight', tensor([[ 0.0266,  0.0530,  0.0454,  ..., -0.0381,  0.0181, -0.0044],
        [-0.0259,  0.0091, -0.0181,  ..., -0.0351, -0.0011, -0.0231],
        [-0.0099, -0.0208, -0.0306,  ..., -0.0040,  0.0320, -0.0076],
        ...,
        [-0.0295,  0.0128, -0.0085,  ..., -0.0241, -0.0286,  0.0038],
        [-0.0078, -0.0085,  0.0293,  ..., -0.0078, -0.0314, -0.0325],
        [-0.0372,  0.0023,  0.0272,  ...,  0.0127,  0.0481,  0.0197]],
       device='cuda:0')), ('fc1.bias', tensor([ 8.0336e-03,  1.5723e-02,  2.2015e-02, -1.8940e-02, -1.0206e-02,
        -2.0172e-02, -2.8258e-02,  1.9947e-02,  2.6587e-02,  2.1092e-02,
        -4.5499e-02, -2.0760e-02, -4.5640e-02, -1.2931e-02, -4.9966e-03,
        -3.9620e-04,  4.1057e-03, -6.5178e-02, -3.8012e-03,  1.7038e-02,
        -1.2878e-02,  5.5952e-02, -1.0810e-02,  4.7487e-02,  7.6296e-03,
         3.0221e-02, -3.6920e-02,  1.1840e-02,  4.9490e-03,  6.2418e-03,
         3.8920e-02, -8.1679e-03,  2.4282e-03, -6.3142e-03,  

Cargamos los datos

In [5]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 16004834.76it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 489279.00it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3793480.99it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3930375.24it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [6]:
# Definimos la función de pérdida para calcular el error
criterion = nn.CrossEntropyLoss()

In [9]:
# train loop
train_loss = []
test_accuracies = []
epoch = 1
accuracy_threshold = 0.6

while True:
    print(f"Epoch {epoch}")
    if epoch <= 9:
        pruned_net = prune_weights(net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()

        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break

    else:
        # Sumamos los pesos de las redes en una red nueva
        # Extraer los diccionarios de estado (pesos y biases)
        state_dict_red1 = net.state_dict()
        state_dict_red2 = varianzas_net.state_dict()

        # Crear un nuevo diccionario de estado donde sumamos los pesos
        state_dict_suma = {}
        for key in state_dict_red1:
            state_dict_suma[key] = state_dict_red1[key] + state_dict_red2[key]

        # Crear una nueva red o modificar una existente con los pesos sumados
        varied_net = Net().to(device)
        varied_net.load_state_dict(state_dict_suma)

        #print(varied_net.state_dict())

        pruned_net = prune_weights(varied_net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()

        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break

        print(test_accuracies)

        # actualizmos el vector de varianzas
        if improvements(test_accuracies) == 0:## +mejoras que peoras
            print("+mejoras que peoras")
            with torch.no_grad():
                for param in varianzas_net.parameters():
                    param *= (1/0.82)

        else: ## -mejoras que peoras
            print(" -mejoras que peoras")
            with torch.no_grad():
                for param in varianzas_net.parameters():
                    param *= 0.82





Epoch 1
Train loss: 2.3101668157048825
Test accuracy: 0.0896
Epoch 2
Train loss: 2.2946895693919296
Test accuracy: 0.1141
[0.0896, 0.1141]
Peoras
 -mejoras que peoras
Epoch 3
Train loss: 2.29320497426397
Test accuracy: 0.129
[0.0896, 0.1141, 0.129]
Peoras
 -mejoras que peoras
Epoch 4
Train loss: 2.2939458206010017
Test accuracy: 0.1106
[0.0896, 0.1141, 0.129, 0.1106]
Peoras
 -mejoras que peoras
Epoch 5
Train loss: 2.296246580223539
Test accuracy: 0.1015
[0.0896, 0.1141, 0.129, 0.1106, 0.1015]
Mejoras
+mejoras que peoras
Epoch 6
Train loss: 2.293943423452154
Test accuracy: 0.1106
[0.0896, 0.1141, 0.129, 0.1106, 0.1015, 0.1106]
Peoras
 -mejoras que peoras
Epoch 7
Train loss: 2.296235585263543
Test accuracy: 0.1015
[0.0896, 0.1141, 0.129, 0.1106, 0.1015, 0.1106, 0.1015]
Mejoras
+mejoras que peoras
Epoch 8


KeyboardInterrupt: 

In [None]:
print(test_accuracies)