In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


import numpy as np
import matplotlib.pyplot as plt
import copy
from ast import Param

In [2]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
        

net = Net().to(device)

Definimos una red y le copiamos los pesos en una lista

In [4]:
varianzas = Net().to(device)

varianzas_weights = []
for param_varianzas, param_net in zip(varianzas.parameters(), net.parameters()):
    varianzas_weights.extend(param_varianzas.data.clone().flatten().tolist())
print(varianzas.state_dict())


OrderedDict({'fc1.weight': tensor([[ 0.0304, -0.0331,  0.0005,  ..., -0.0162, -0.0085, -0.0248],
        [ 0.0231,  0.0208, -0.0241,  ...,  0.0052, -0.0138, -0.0281],
        [ 0.0019,  0.0106,  0.0052,  ...,  0.0116, -0.0243, -0.0261],
        ...,
        [-0.0103,  0.0176,  0.0029,  ...,  0.0226, -0.0179,  0.0319],
        [ 0.0121, -0.0344, -0.0202,  ...,  0.0271,  0.0110,  0.0354],
        [ 0.0138,  0.0149, -0.0064,  ...,  0.0077,  0.0253,  0.0236]],
       device='cuda:0'), 'fc1.bias': tensor([ 4.4834e-03,  2.6346e-03,  3.1669e-02,  2.5477e-02, -3.3289e-02,
        -8.5888e-03,  9.3318e-03,  2.5616e-02,  3.5045e-02,  2.1531e-02,
         1.9039e-02,  3.1133e-02,  2.8574e-02,  2.9702e-02,  9.7054e-03,
        -2.3100e-03,  9.2436e-05,  1.0286e-02, -3.2472e-02,  2.8419e-02,
         2.9794e-02, -1.7768e-02,  2.4497e-02, -1.3386e-03,  3.5932e-03,
        -2.0014e-02,  2.3991e-02, -1.8655e-02, -4.0346e-03,  1.1647e-02,
        -1.1759e-02, -2.4971e-03,  2.5795e-02, -2.2754e-02, -2.7

In [5]:
print(varianzas_weights)

[0.030441921204328537, -0.03313577175140381, 0.0004920437932014465, -0.024555722251534462, -0.015248393639922142, 0.010074857622385025, 0.018874432891607285, 0.016369342803955078, 0.013306848704814911, 0.013377416878938675, -0.01731564849615097, 0.01621868833899498, 0.028983216732740402, 0.0021269842982292175, -0.019943008199334145, -0.035063859075307846, -0.0017068088054656982, -0.0313526876270771, 0.03369644656777382, -0.022315319627523422, 0.03553677722811699, 0.022075287997722626, -0.012568926438689232, -0.012643985450267792, -0.02577032521367073, 0.02236248552799225, 0.02899354323744774, -0.006996802985668182, 0.004771728068590164, 0.03487760201096535, 0.02196580544114113, 0.0032808594405651093, 0.0002179853618144989, -0.012343799695372581, 0.0038627348840236664, -0.00244826078414917, -0.00802166573703289, 0.014753229916095734, 0.018075447529554367, 0.008446641266345978, -0.027879813686013222, 0.02195144072175026, -0.011231040582060814, 0.02116941660642624, -0.029343338683247566, 

Cargamos los datos

In [None]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [None]:
# Definimos la función de pérdida y los optimizadores
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

In [None]:
def prune_weights(net):
    """This function prunes the 70% of the lowest weights in each layer of the network and sets them to zero"""
    for name, module in net.named_modules():
        if isinstance(module, nn.Linear):
            weights = module.weight.data.cpu().numpy()
            threshold = np.percentile(np.abs(weights), 70)
            weights[np.abs(weights) < threshold] = 0
            module.weight.data = torch.from_numpy(weights).to(device)
    return net  