<a href="https://colab.research.google.com/github/cabamarcos/SuperMask/blob/main/prueba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


Creamos las redes

In [3]:
# Definimos las dos redes convolucionales
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Inicializamos las dos redes y las movemos a la GPU si está disponible
net = Net().to(device)
mask = Mask().to(device)

Cargamos los datos

In [4]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [5]:
# Definimos la función de pérdida y los optimizadores
criterion = nn.CrossEntropyLoss()
optimizer_net = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
optimizer_mask = optim.SGD(mask.parameters(), lr=0.01, momentum=0.9)


In [6]:
def print_parameters(layer_name, params):
    print(f"--- {layer_name} ---")
    print(params)

def apply_mask(net, mask):
    # Aplica la máscara a la red: selecciona el 30% de los pesos más altos y desactiva el resto
    with torch.no_grad():
        for net_name, net_param, mask_name, mask_param in zip(net.state_dict(), net.parameters(), mask.state_dict(), mask.parameters()):
            print_parameters("Mask before applying", mask_param.data)
            print_parameters("Net before applying mask", net_param.data)

            mask_data = mask_param.data.abs()
            threshold = torch.quantile(mask_data, 0.7)
            mask_applied = (mask_data >= threshold).float()

            print_parameters("Mask applied (binary)", mask_applied)
            net_param.data[mask_applied == 0] = 0

            print_parameters("Net after applying mask", net_param.data)
            print("\n")


In [7]:
# Listas para almacenar las pérdidas de entrenamiento y las precisiones de validación
train_losses = []
val_accuracies = []
epochs = 10
accuracy_threshold = 0.6

for epoch in range(epochs):
    # Entrenamiento de la red
    net.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer_net.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_net.step()

        running_loss += loss.item()


    train_losses.append(running_loss / len(train_loader))

    # Retropropagación del error en la máscara
    mask.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer_mask.zero_grad()
        outputs = mask(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_mask.step()

    # Aplicamos la máscara a la red
    apply_mask(net, mask)

    # Validación de la red
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    val_accuracies.append(accuracy)

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}, Validation Accuracy: {accuracy}')

    # Paramos el entrenamiento si la precisión en validación supera el 60%
    if accuracy > accuracy_threshold:
        break

--- Mask before applying ---
tensor([[[[-2.4490e-01, -1.5014e-01, -1.7546e-02],
          [-4.3704e-02, -3.2765e-01, -2.9462e-01],
          [-1.6410e-01,  1.6711e-01,  2.9621e-01]]],


        [[[ 3.0637e-01, -2.4663e-04, -7.8967e-02],
          [-1.3604e-01,  4.6451e-03,  2.6309e-02],
          [-2.1282e-02,  3.4126e-01,  1.4583e-01]]],


        [[[ 6.9449e-02,  3.1037e-01,  9.3735e-02],
          [-1.4655e-01, -2.5977e-01,  3.4955e-02],
          [ 4.6439e-02, -2.4189e-01, -2.1095e-02]]],


        [[[ 4.5984e-01,  3.5539e-01,  3.1896e-01],
          [ 4.9786e-02, -4.2887e-02,  2.9038e-01],
          [-4.6624e-01, -4.4480e-01, -2.5963e-01]]],


        [[[ 1.4127e-01,  3.8384e-01, -1.5490e-02],
          [ 3.3560e-01,  4.9874e-01,  4.1612e-01],
          [ 2.9740e-01,  4.7644e-02,  3.4876e-01]]],


        [[[ 1.8271e-02,  1.9640e-02,  5.9162e-03],
          [ 2.8896e-01, -5.3717e-03, -3.2405e-01],
          [-3.7980e-02,  1.9569e-01, -3.2188e-01]]],


        [[[ 1.0056e-01,  1.99

In [9]:
for mask_name, mask_param in zip(mask.state_dict(), mask.parameters()):

    mask_data = mask_param.data.abs()
    threshold = torch.quantile(mask_data, 0.7)
    mask_applied = (mask_data >= threshold).float()

    # Imprimir los pesos de la máscara aplicada
    print(f"Name: {mask_name}, Shape: {mask_applied.shape}")
    print(mask_applied)
    print()

Name: conv1.weight, Shape: torch.Size([32, 1, 3, 3])
tensor([[[[0., 0., 0.],
          [0., 1., 1.],
          [0., 0., 1.]]],


        [[[1., 0., 0.],
          [0., 0., 0.],
          [0., 1., 0.]]],


        [[[0., 1., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[0., 1., 0.],
          [0., 1., 1.],
          [0., 0., 1.]]],


        [[[0., 0., 0.],
          [1., 0., 1.],
          [0., 0., 1.]]],


        [[[0., 0., 1.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[0., 0., 1.],
          [0., 0., 0.],
          [0., 0., 1.]]],


        [[[0., 1., 0.],
          [0., 1., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 1.],
          [0., 0., 1.],
          [1., 0., 0.]]],


        [[[1., 0., 1.],
          [1., 0., 0.],
          [1., 0., 0.]]],


        [[[1., 1., 0.],
          [0., 0., 0.],
          [0., 1., 1.]]],


        [[[1., 0., 0.],
          [

In [10]:
for var_name in net.state_dict():
    print(var_name, "\t", net.state_dict()[var_name])

conv1.weight 	 tensor([[[[ 6.3885e-03, -3.3876e-02, -5.3445e-02],
          [ 2.4904e-02,  2.5092e-01,  1.5653e-01],
          [-3.3412e-02, -3.9517e-02,  3.3785e-02]]],


        [[[-4.3152e-01, -9.8103e-03,  3.6888e-02],
          [-3.7973e-02, -3.0596e-02,  4.0125e-02],
          [-3.0867e-02, -2.1032e-01,  6.4338e-02]]],


        [[[-4.5691e-03,  2.3915e-01,  1.7124e-02],
          [ 4.2205e-02,  5.1494e-02,  1.6485e-02],
          [ 3.4034e-02,  2.7736e-02, -4.9360e-03]]],


        [[[-3.5523e-01, -2.1569e-01,  2.2272e-02],
          [-3.1559e-02, -3.4600e-02, -2.2038e-01],
          [-1.4232e-01, -9.1369e-02, -2.0227e-02]]],


        [[[-2.7018e-02, -2.4891e-01, -2.0091e-02],
          [-1.3881e-01, -3.9471e-01,  1.9481e-01],
          [-4.5995e-02, -4.8298e-02, -1.7681e-01]]],


        [[[ 3.1643e-02,  7.9653e-02,  8.6217e-02],
          [-1.1867e-01, -2.1388e-02, -1.8078e-01],
          [-5.5955e-02, -4.1379e-02, -3.7870e-01]]],


        [[[ 7.1706e-02,  8.4158e-02,  3.061