<a href="https://colab.research.google.com/github/cabamarcos/SuperMask/blob/main/prueba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [11]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


Creamos las redes

In [12]:
# Definimos las dos redes convolucionales
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Inicializamos las dos redes y las movemos a la GPU si está disponible
net = Net().to(device)
mask = Mask().to(device)

Cargamos los datos

In [13]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [14]:
# Definimos la función de pérdida y los optimizadores
criterion = nn.CrossEntropyLoss()
optimizer_net = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
optimizer_mask = optim.SGD(mask.parameters(), lr=0.01, momentum=0.9)


In [15]:
def print_parameters(layer_name, params):
    print(f"--- {layer_name} ---")
    print(params)

def apply_mask(net, mask):
    # Aplica la máscara a la red: selecciona el 30% de los pesos más altos y desactiva el resto
    with torch.no_grad():
        for net_name, net_param, mask_name, mask_param in zip(net.state_dict(), net.parameters(), mask.state_dict(), mask.parameters()):
            #print_parameters("Mask before applying", mask_param.data)
            #print_parameters("Net before applying mask", net_param.data)

            mask_data = mask_param.data.abs()
            threshold = torch.quantile(mask_data, 0.7)
            mask_applied = (mask_data >= threshold).float()

            #print_parameters("Mask applied (binary)", mask_applied)
            net_param.data[mask_applied == 0] = 0

            #print_parameters("Net after applying mask", net_param.data)
            #print("\n")


In [16]:
# Listas para almacenar las pérdidas de entrenamiento y las precisiones de validación
train_loss = []
test_accuracies = []
epochs = 10
accuracy_threshold = 0.6

for epoch in range(epochs):
    # Entrenamiento de la red
    print(f"Epoch {epoch + 1}")

    # Establecemos el modo de entrenamiento
    net.train()
    mask.train()


    print("Net before mask: ", net.state_dict())
    print("Mask before: ", mask.state_dict())
    # Aplicamos la máscara a la red
    apply_mask(net, mask)

    print("Net after mask: ", net.state_dict())
    print("Mask after: ", mask.state_dict())


    running_loss = 0.0

    for images, labels in train_loader:
      images, labels = images.to(device), labels.to(device)

      # Paso 1: Pasar imágenes por la red net y calcular la pérdida sin retropropagarla
      outputs_net = net(images)
      loss = criterion(outputs_net, labels)

      # Paso 2: Retropropagar la pérdida con respecto a los parámetros de la red mask
      optimizer_mask.zero_grad()  # Limpiar los gradientes acumulados de la máscara
      grads_mask = torch.autograd.grad(loss, mask.parameters(), retain_graph=True)  # Calcular gradientes para la máscara

      # Asignar estos gradientes a los parámetros de la máscara
      for param_mask, grad in zip(mask.parameters(), grads_mask):
          param_mask.grad = grad

      # Actualizar los parámetros de la máscara
      optimizer_mask.step()

      running_loss += loss.item()

    train_loss.append(running_loss / len(train_loader))



    net.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer_net.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)#Esto Calcula los gradientes de la pérdida con respecto a los parámetros de la red.
        loss.backward()
        optimizer_net.step()

        running_loss += loss.item()


    train_losses.append(running_loss / len(train_loader))

    # Retropropagación del error en la máscara
    mask.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer_mask.zero_grad()
        outputs = mask(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_mask.step()

    # Aplicamos la máscara a la red
    apply_mask(net, mask)

    # Validación de la red
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    val_accuracies.append(accuracy)

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}, Validation Accuracy: {accuracy}')

    # Paramos el entrenamiento si la precisión en validación supera el 60%
    if accuracy > accuracy_threshold:
        break

Epoch 1
Net before mask:  OrderedDict([('conv1.weight', tensor([[[[ 0.0887,  0.0811,  0.1841],
          [ 0.2283, -0.2773, -0.0260],
          [ 0.1160, -0.1249, -0.2254]]],


        [[[-0.3107, -0.0977, -0.2051],
          [-0.2652, -0.1892, -0.1145],
          [ 0.2163, -0.2305, -0.0159]]],


        [[[-0.1579, -0.1253, -0.2349],
          [ 0.1173,  0.2511, -0.0014],
          [-0.0918, -0.3029,  0.0929]]],


        [[[ 0.1986, -0.0488, -0.0941],
          [ 0.1163,  0.2670, -0.0021],
          [ 0.1897,  0.1011, -0.0235]]],


        [[[-0.0720,  0.2416, -0.1326],
          [ 0.1249,  0.2409, -0.0124],
          [ 0.2400, -0.2293, -0.2769]]],


        [[[-0.0175, -0.3277, -0.1691],
          [ 0.2669, -0.1532,  0.1149],
          [ 0.2264,  0.0915,  0.0282]]],


        [[[-0.2020,  0.2691, -0.2954],
          [-0.0019, -0.0966,  0.2004],
          [-0.0026,  0.0568,  0.0942]]],


        [[[ 0.2708,  0.2307, -0.2637],
          [-0.0468, -0.1220, -0.2471],
          [-0.3133,

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

In [None]:
for mask_name, mask_param in zip(mask.state_dict(), mask.parameters()):

    mask_data = mask_param.data.abs()
    threshold = torch.quantile(mask_data, 0.7)
    mask_applied = (mask_data >= threshold).float()

    # Imprimir los pesos de la máscara aplicada
    print(f"Name: {mask_name}, Shape: {mask_applied.shape}")
    print(mask_applied)
    print()

In [None]:
for var_name in net.state_dict():
    print(var_name, "\t", net.state_dict()[var_name])