In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


import numpy as np
import matplotlib.pyplot as plt
import copy
from ast import Param

from utils.prune import prune_weights
from utils.count_improvement import improvements

In [11]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [12]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
        



Definimos una red y le copiamos los pesos en una lista

In [13]:
net = Net().to(device)
varianzas_net = Net().to(device)

varianzas = []
for param in varianzas_net.parameters():
    varianzas.extend(param.data.clone().flatten().tolist())
print(varianzas_net.state_dict())


OrderedDict({'fc1.weight': tensor([[-3.7824e-03, -1.2473e-02, -9.3940e-03,  ...,  2.7885e-02,
         -3.3165e-02,  1.5361e-02],
        [-7.3810e-03, -1.9430e-02,  3.0847e-02,  ...,  1.4536e-02,
         -2.7475e-02,  1.9282e-02],
        [-3.3374e-02,  1.4796e-02, -6.4544e-04,  ..., -2.5812e-02,
         -1.8412e-02, -3.0038e-02],
        ...,
        [-2.6505e-05, -1.5706e-02, -5.8061e-03,  ...,  2.8293e-02,
         -1.5007e-02,  5.4961e-03],
        [-2.0942e-02, -2.4197e-02, -8.9763e-03,  ...,  3.4299e-02,
          4.2806e-03, -2.9122e-03],
        [-2.6377e-02,  2.3919e-02, -1.2575e-02,  ..., -1.4290e-02,
         -2.0268e-02,  2.7221e-02]], device='cuda:0'), 'fc1.bias': tensor([-3.1602e-02, -1.7118e-02,  1.6287e-02, -1.1305e-02,  3.5280e-02,
         3.2186e-02, -2.4522e-02, -3.0532e-02,  3.0476e-02, -9.1163e-03,
         2.1308e-04, -8.0643e-03, -1.4271e-02, -2.4243e-03,  1.7208e-02,
         6.7622e-03, -3.4222e-02,  1.9659e-02, -3.0102e-02,  2.9544e-02,
        -9.0373e-03

In [14]:
print(len(varianzas))
print(varianzas[0])

235146
-0.003782421350479126


In [15]:
# Calculate the number of weights
num_weights = sum(p.numel() for p in varianzas_net.parameters() if p.requires_grad)
print(f"Number of weights: {num_weights}")

Number of weights: 235146


Cargamos los datos

In [16]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [17]:
# Definimos la función de pérdida para calcular el error
criterion = nn.CrossEntropyLoss()

In [18]:
# train loop
train_loss = []
test_accuracies = []
epochs = 10
accuracy_threshold = 0.6

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    if epoch + 1 == 1:
        pruned_net = prune_weights(net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()
        
        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break
    
    else:
        # Crear una copia de la red neuronal original
        varied_net = copy.deepcopy(net)
        print(varied_net.state_dict())

        # Actualizar los pesos de la red neuronal copiada
        with torch.no_grad():  # Desactiva el cálculo del gradiente para evitar que se almacenen gradientes innecesarios
            for param, varianza in zip(varied_net.parameters(), varianzas):
                param += varianza  # Sumar varianza a cada peso

        print(varied_net.state_dict())

        pruned_net = prune_weights(varied_net)

        running_loss = 0.0
        # Pasamos todas las imagenes de train por la red net
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs_net = pruned_net(images)
            loss = criterion(outputs_net, labels)

            running_loss += loss.item()
        
        train_loss.append(running_loss / len(train_loader))
        print(f"Train loss: {running_loss / len(train_loader)}")

        # Evaluamos el modelo en el conjunto de test
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                outputs_net = pruned_net(images)
                _, predicted = torch.max(outputs_net.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracies.append(correct / total)
        print(f"Test accuracy: {correct / total}")

        if correct / total > accuracy_threshold:
            break

        # actualizmos el vector de varianzas
        if improvements(train_loss) == 0:## +mejoras que peoras
            varianzas = [varianza * (1/0.82) for varianza in varianzas]
        
        elif improvements(train_loss) == 1: ## -mejoras que peoras
            varianzas = [varianza * 0.82 for varianza in varianzas]

        



Epoch 1
Train loss: 2.308339726950314
Test accuracy: 0.0924
Epoch 2
OrderedDict({'fc1.weight': tensor([[-0.0346,  0.0000,  0.0287,  ...,  0.0287,  0.0000,  0.0000],
        [ 0.0265,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0345,  ...,  0.0000,  0.0333,  0.0000],
        ...,
        [-0.0301,  0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0299],
        [-0.0321,  0.0000,  0.0000,  ...,  0.0000,  0.0297,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0'), 'fc1.bias': tensor([-1.0868e-02, -5.9395e-03, -2.5255e-02, -6.9755e-03, -2.5891e-02,
         4.1298e-03,  1.4969e-02, -8.4799e-03, -1.0652e-02,  3.3648e-02,
         1.7382e-02,  2.1708e-02,  1.1924e-02, -1.4142e-02, -2.9194e-03,
         2.2122e-02,  1.5926e-02,  7.5681e-03,  2.5220e-05, -2.0818e-02,
        -3.0350e-03, -3.1327e-02, -1.5511e-03,  1.2087e-02,  2.7704e-02,
        -1.3782e-02, -3.4459e-02,  2.0656e-02, -3.1582e-02,  2.1182e-

KeyboardInterrupt: 