In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


import numpy as np
import matplotlib.pyplot as plt
import copy
from ast import Param

from utils.prune import prune_weights
from utils.count_improvement import improvements

In [9]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [10]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Definimos las redes

In [11]:
net = Net().to(device)
varianzas_net = Net().to(device)

Cargamos los datos

In [12]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [13]:
# Definimos la función de pérdida para calcular el error
criterion = nn.CrossEntropyLoss()

In [14]:
# train loop
train_loss = []
test_accuracies = []
epoch = 1
accuracy_threshold = 0.6

while True:
    print(f"Epoch {epoch}")
    if epoch <= 9:
        pruned_net = prune_weights(net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()

        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break

    else:
        # Sumamos los pesos de las redes en una red nueva
        # Extraer los diccionarios de estado (pesos y biases)
        state_dict_red1 = net.state_dict()
        state_dict_red2 = varianzas_net.state_dict()

        # Crear un nuevo diccionario de estado donde sumamos los pesos
        state_dict_suma = {}
        for key in state_dict_red1:
            if state_dict_red1[key].size() == state_dict_red2[key].size():  # Asegurar que las dimensiones coincidan
                # Asegurar que las desviaciones estándar sean positivas para generar el ruido
                std_dev = torch.abs(state_dict_red2[key])
                
                # Generamos los valores aleatorios con una distribución normal usando torch.normal
                noise = torch.normal(0, std_dev)  # Media = 0, Desviación estándar = std_dev
                
                # Crear una máscara para determinar si debemos sumar o restar
                mask_negativa = state_dict_red2[key] < 0  # Máscara de valores negativos
                
                # Aplicar la operación de suma o resta dependiendo de la máscara
                state_dict_suma[key] = torch.where(mask_negativa, state_dict_red1[key] - noise, state_dict_red1[key] + noise)
            else:
                # Si los tamaños no coinciden, copiamos directamente
                state_dict_suma[key] = state_dict_red1[key]


        # Crear una nueva red o modificar una existente con los pesos sumados
        varied_net = Net().to(device)
        varied_net.load_state_dict(state_dict_suma)

        #print(varied_net.state_dict())

        pruned_net = prune_weights(varied_net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()

        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break

        # actualizmos el vector de varianzas
        if improvements(train_loss) > 2:## +mejoras que peoras
            with torch.no_grad():
                for param in varianzas_net.parameters():
                    param *= (1/0.82)

        elif improvements(train_loss) < 2: ## -mejoras que peoras
            with torch.no_grad():
                for param in varianzas_net.parameters():
                    param *= 0.82
        
        else:
            pass
    
    epoch += 1

Epoch 1
Train loss: 2.29743431016072
Test accuracy: 0.1018
Epoch 2
Train loss: 2.297431904369834
Test accuracy: 0.1018
Epoch 3
Train loss: 2.2974305574827865
Test accuracy: 0.1018
Epoch 4
Train loss: 2.2974323141041086
Test accuracy: 0.1018
Epoch 5
Train loss: 2.2974322398842526
Test accuracy: 0.1018
Epoch 6
Train loss: 2.297430905197729
Test accuracy: 0.1018
Epoch 7
Train loss: 2.2974418597434885
Test accuracy: 0.1018
Epoch 8
Train loss: 2.297443937136929
Test accuracy: 0.1018
Epoch 9
Train loss: 2.297441573031167
Test accuracy: 0.1018
Epoch 10
Train loss: 2.3018326761880155
Test accuracy: 0.1037
Epoch 11
Train loss: 2.340897086332602
Test accuracy: 0.104
Epoch 12
Train loss: 2.3837192025520144
Test accuracy: 0.1248
Epoch 13
Train loss: 2.363578462397366
Test accuracy: 0.0958
Epoch 14
Train loss: 2.685446994645255
Test accuracy: 0.1193
Epoch 15
Train loss: 3.2704972745513103
Test accuracy: 0.1243
Epoch 16
Train loss: 4.104560978885399
Test accuracy: 0.055
Epoch 17
Train loss: 3.237177

KeyboardInterrupt: 

In [15]:
print(test_accuracies)

[0.1018, 0.1018, 0.1018, 0.1018, 0.1018, 0.1018, 0.1018, 0.1018, 0.1018, 0.1037, 0.104, 0.1248, 0.0958, 0.1193, 0.1243, 0.055, 0.1008, 0.1269, 0.1048, 0.1264, 0.1287, 0.1414, 0.0512, 0.1119, 0.1534, 0.1225, 0.15, 0.1042, 0.0608, 0.0766, 0.1183, 0.1133, 0.1203, 0.1039, 0.082, 0.1012, 0.0983, 0.1011, 0.07, 0.0864, 0.0702, 0.1205, 0.1163, 0.0911, 0.0873, 0.0954, 0.1004]


In [None]:
print(train_loss)