In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


import numpy as np
import matplotlib.pyplot as plt
import copy
from ast import Param

from utils.prune import prune_weights

In [7]:
# Select the compute device: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [8]:
class Net(nn.Module):
    """Fully connected classifier with layer sizes 784 -> 256 -> 128 -> 10.

    ReLU is applied after the first two linear layers; the last layer's
    output is returned as-is (no activation).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in fc1, fc2, fc3 order so parameter order is stable.
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Flatten each sample of `x` to one row and run it through the three layers.

        Args:
            x: tensor whose first dimension is the batch; the remaining
                dimensions must hold 784 values per sample in total.

        Returns:
            Tensor of shape (batch, 10) produced by `fc3`.
        """
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
        



Definimos dos redes, `net` y `varianzas_net`, y copiamos todos los parámetros de `varianzas_net`, aplanados, en la lista `varianzas`.

In [10]:
# Two independently initialised networks, both moved to the selected device.
net = Net().to(device)
varianzas_net = Net().to(device)

# Flatten every parameter tensor of varianzas_net, in parameters() order,
# into a single Python list of floats.
varianzas = [
    value
    for tensor in varianzas_net.parameters()
    for value in tensor.data.clone().flatten().tolist()
]
print(varianzas_net.state_dict())


OrderedDict({'fc1.weight': tensor([[-0.0029,  0.0290, -0.0343,  ..., -0.0217, -0.0126,  0.0237],
        [-0.0031, -0.0088,  0.0077,  ...,  0.0128, -0.0248, -0.0353],
        [ 0.0029, -0.0096, -0.0051,  ..., -0.0028,  0.0191,  0.0127],
        ...,
        [ 0.0166, -0.0056, -0.0026,  ..., -0.0068, -0.0181,  0.0081],
        [ 0.0257, -0.0076,  0.0179,  ..., -0.0251,  0.0145,  0.0244],
        [-0.0201,  0.0275, -0.0145,  ...,  0.0045, -0.0289, -0.0089]],
       device='cuda:0'), 'fc1.bias': tensor([ 4.2715e-03,  9.3399e-03, -3.0948e-02,  1.4484e-02, -3.1076e-02,
        -5.2877e-03,  3.4754e-02,  2.7919e-03,  3.4709e-02,  3.3284e-02,
        -2.9984e-02,  1.7072e-02,  2.7890e-02,  3.2782e-02,  1.5063e-02,
        -2.5879e-02, -1.5474e-02, -8.1720e-03,  1.1609e-02, -2.2175e-02,
         3.0783e-03,  2.5014e-02, -3.2224e-02,  6.8078e-03,  1.9075e-02,
         4.2264e-03, -1.4352e-02, -5.5661e-04,  1.1375e-02,  2.0536e-02,
        -3.3622e-04, -1.8502e-02, -3.5701e-02, -3.2197e-02, -5.9

In [16]:
# Sanity check: number of flattened values, then the first value, one per line
print(len(varianzas), varianzas[0], sep="\n")

235146
-0.0029386505484580994


In [15]:
# Count the scalar entries of every parameter tensor that requires gradients
num_weights = sum(
    map(
        torch.Tensor.numel,
        filter(lambda tensor: tensor.requires_grad, varianzas_net.parameters()),
    )
)
print(f"Number of weights: {num_weights}")

Number of weights: 235146


Cargamos los datos

In [17]:
# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalise with mean 0.1307 and std 0.3081
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# MNIST train and test splits, downloaded into ./data when missing
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Loaders: shuffled batches of 64 for training, ordered batches of 1000 for testing
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 9602572.60it/s] 


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 284388.13it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2410801.86it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4437579.49it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [18]:
# Loss function used to compute the error: cross-entropy between the network
# outputs and the integer class labels
criterion = nn.CrossEntropyLoss()

In [None]:
# Evaluation loop. No optimizer or backward pass is used anywhere in this cell:
#   epoch 1  -> prune `net`, measure its train loss and test accuracy, and stop
#               early when the accuracy exceeds `accuracy_threshold`;
#   epoch 2+ -> deep-copy `net`, add `varianzas` to the copy's parameters,
#               prune the copy and measure its train loss.


def _mean_batch_loss(model, loader):
    """Return the mean of `criterion` over all batches of `loader`.

    Runs under `torch.no_grad()` because the loss is only reported, never
    back-propagated, so building autograd graphs would just waste memory.
    """
    running_loss = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            running_loss += criterion(model(images), labels).item()
    return running_loss / len(loader)


def _accuracy(model, loader):
    """Return the fraction of samples in `loader` whose arg-max output equals the label."""
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            _, predicted = torch.max(model(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total


def _add_to_parameters(model, deltas):
    """Add the flat sequence `deltas` to the parameters of `model`, in place.

    `deltas` is consumed in `model.parameters()` order, one value per scalar
    parameter (the layout produced by flattening each parameter in turn).

    Raises:
        ValueError: if `len(deltas)` differs from the number of scalar parameters.
    """
    params = list(model.parameters())
    expected = sum(p.numel() for p in params)
    if len(deltas) != expected:
        raise ValueError(
            f"Expected {expected} delta values, got {len(deltas)}"
        )

    offset = 0
    with torch.no_grad():
        for p in params:
            count = p.numel()
            chunk = torch.as_tensor(
                deltas[offset:offset + count], dtype=p.dtype, device=p.device
            )
            p.add_(chunk.view_as(p))
            offset += count


train_loss = []
test_accuracies = []
epochs = 10
accuracy_threshold = 0.6

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    first_epoch = epoch == 0

    if first_epoch:
        pruned_net = prune_weights(net)
    else:
        # nn.Module has no `weights` attribute, so the perturbation is applied
        # to a deep copy through parameters(); `net` itself is not modified here.
        # NOTE(review): `net` and `varianzas` are not changed by this loop, so
        # unless prune_weights mutates its argument or is stochastic, every
        # later epoch repeats the same measurement.
        varied_net = copy.deepcopy(net)
        _add_to_parameters(varied_net, varianzas)
        pruned_net = prune_weights(varied_net)

    # Mean train loss of the pruned candidate over the whole training set.
    epoch_loss = _mean_batch_loss(pruned_net, train_loader)
    train_loss.append(epoch_loss)
    print(f"Train loss: {epoch_loss}")

    # NOTE(review): only the first epoch is evaluated on the test set and can
    # stop early; confirm whether the perturbed candidates should be evaluated too.
    if first_epoch:
        test_accuracy = _accuracy(pruned_net, test_loader)
        test_accuracies.append(test_accuracy)
        print(f"Test accuracy: {test_accuracy}")

        if test_accuracy > accuracy_threshold:
            break

