In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


import numpy as np
import matplotlib.pyplot as plt
import copy
from ast import Param

from utils.prune import prune_weights
from utils.count_improvement import improvements

In [2]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Definimos las redes

In [4]:
net = Net().to(device)
varianzas_net = Net().to(device)

Cargamos los datos

In [5]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [6]:
# Definimos la función de pérdida para calcular el error
criterion = nn.CrossEntropyLoss()

In [7]:
# train loop
train_loss = []
test_accuracies = []
epoch = 1
accuracy_threshold = 0.4


while True:
    print(f"Epoch {epoch}")
    if epoch <= 9:

        state_dict_red1 = net.state_dict()
        state_dict_red2 = varianzas_net.state_dict()

        # Crear un nuevo diccionario de estado donde sumamos los pesos
        state_dict_suma = {}
        for key in state_dict_red1:
            if state_dict_red1[key].size() == state_dict_red2[key].size():  # Asegurar que las dimensiones coincidan
                # Asegurar que las desviaciones estándar sean positivas para generar el ruido
                std_dev = torch.abs(state_dict_red2[key])
                
                # Generamos los valores aleatorios con una distribución normal usando torch.normal
                noise = torch.normal(0, std_dev)  # Media = 0, Desviación estándar = std_dev
                
                # Crear una máscara para determinar si debemos sumar o restar
                mask_negativa = state_dict_red2[key] < 0  # Máscara de valores negativos
                
                # Aplicar la operación de suma o resta dependiendo de la máscara
                state_dict_suma[key] = torch.where(mask_negativa, state_dict_red1[key] - noise, state_dict_red1[key] + noise)
            else:
                # Si los tamaños no coinciden, copiamos directamente
                state_dict_suma[key] = state_dict_red1[key]

        # Crear una nueva red o modificar una existente con los pesos sumados
        varied_net = Net().to(device)
        varied_net.load_state_dict(state_dict_suma)

        pruned_net = prune_weights(varied_net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()

        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break

    else:
        # Sumamos los pesos de las redes en una red nueva
        # Extraer los diccionarios de estado (pesos y biases)
        state_dict_red1 = net.state_dict()
        state_dict_red2 = varianzas_net.state_dict()

        # Crear un nuevo diccionario de estado donde sumamos los pesos
        state_dict_suma = {}
        for key in state_dict_red1:
            if state_dict_red1[key].size() == state_dict_red2[key].size():  # Asegurar que las dimensiones coincidan
                # Asegurar que las desviaciones estándar sean positivas para generar el ruido
                std_dev = torch.abs(state_dict_red2[key])
                
                # Generamos los valores aleatorios con una distribución normal usando torch.normal
                noise = torch.normal(0, std_dev)  # Media = 0, Desviación estándar = std_dev
                
                # Crear una máscara para determinar si debemos sumar o restar
                mask_negativa = state_dict_red2[key] < 0  # Máscara de valores negativos
                
                # Aplicar la operación de suma o resta dependiendo de la máscara
                state_dict_suma[key] = torch.where(mask_negativa, state_dict_red1[key] - noise, state_dict_red1[key] + noise)
            else:
                # Si los tamaños no coinciden, copiamos directamente
                state_dict_suma[key] = state_dict_red1[key]


        # Crear una nueva red o modificar una existente con los pesos sumados
        varied_net = Net().to(device)
        varied_net.load_state_dict(state_dict_suma)

        #print(varied_net.state_dict())

        pruned_net = prune_weights(varied_net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()

        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break

        # actualizmos el vector de varianzas
        if improvements(train_loss) > 2:## +mejoras que peoras
            with torch.no_grad():
                for param in varianzas_net.parameters():
                    param *= (1/0.82)

        elif improvements(train_loss) < 2: ## -mejoras que peoras
            with torch.no_grad():
                for param in varianzas_net.parameters():
                    param *= 0.82
        
        else:
            pass
    
    epoch += 1

Epoch 1
Train loss: 2.3277546030117757
Test accuracy: 0.0974
Epoch 2
Train loss: 2.3238662994746715
Test accuracy: 0.087
Epoch 3
Train loss: 2.3173148293993364
Test accuracy: 0.1072
Epoch 4
Train loss: 2.3382061407255974
Test accuracy: 0.0902
Epoch 5
Train loss: 2.313588749625281
Test accuracy: 0.0859
Epoch 6
Train loss: 2.3150027115970278
Test accuracy: 0.1269
Epoch 7
Train loss: 2.311042940184506
Test accuracy: 0.136
Epoch 8
Train loss: 2.323499440384318
Test accuracy: 0.093
Epoch 9
Train loss: 2.3145109956452585
Test accuracy: 0.0922
Epoch 10
Train loss: 2.3345805678540454
Test accuracy: 0.0507
Epoch 11
Train loss: 2.3541113760933947
Test accuracy: 0.1103
Epoch 12
Train loss: 2.4060182886591344
Test accuracy: 0.0903
Epoch 13
Train loss: 2.5024724436212957
Test accuracy: 0.0748
Epoch 14
Train loss: 2.8167611019951955
Test accuracy: 0.1319
Epoch 15
Train loss: 2.9181126226494305
Test accuracy: 0.1329
Epoch 16
Train loss: 2.5016650632500395
Test accuracy: 0.1589
Epoch 17
Train loss: 2.

KeyboardInterrupt: 

In [8]:
print(max(test_accuracies))

0.1589


In [12]:
print(train_loss)

[2.3089392459722977, 2.3089258322583586, 2.3089331533354738, 2.3089323369170556, 2.308943989688654, 2.3089327558017234, 2.3089386578053555, 2.308926501507952, 2.308933406496353, 2.331363138867848, 2.3078895896228393, 2.440253779323879, 2.4666561267015013, 2.79058767127584, 3.3400989626008055, 3.010180982953704, 4.811453668801769, 4.418321354556948, 10.217527754525387, 18.82480004296374, 13.772105462007177, 21.99386759556687, 45.70997622475695, 64.74080042239191, 143.58107632309643, 139.96319021294113, 440.29322326920436, 250.359760536568, 480.15460940363056, 1241.5257479864906, 1448.8484698785649, 738.5807276483793, 1341.848469618287, 3422.6838277397887, 5481.837048105594, 5314.774295636077, 7646.559327733542, 8059.600765529218, 10266.443307839985, 6653.3415038021385, 24605.685163746002, 22956.6271446895, 31916.929951942297, 67563.10882945762, 80467.56999600213, 94896.8792810501, 102030.40131929638, 83748.10826725746, 259069.01502531982, 180015.7687566631, 423983.3984874734, 340118.085