<a href="https://colab.research.google.com/github/cabamarcos/SuperMask/blob/main/prueba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import copy

In [24]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


Creamos las redes

In [25]:
# Definimos las dos redes convolucionales
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Inicializamos las dos redes y las movemos a la GPU si está disponible
net = Net().to(device)
mask = Mask().to(device)

Cargamos los datos

In [26]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [27]:
# Definimos la función de pérdida y los optimizadores
criterion = nn.CrossEntropyLoss()
optimizer_mask = optim.SGD(mask.parameters(), lr=0.01, momentum=0.9)


In [28]:
def print_parameters(layer_name, params):
    print(f"--- {layer_name} ---")
    print(params)

def apply_mask(net, mask):
    # Aplica la máscara a la red: selecciona el 30% de los pesos más altos y desactiva el resto
    net_masked = copy.deepcopy(net)
    with torch.no_grad():
        for net_name, net_param, mask_name, mask_param in zip(net_masked.state_dict(), net_masked.parameters(), mask.state_dict(), mask.parameters()):
            #print_parameters("Mask before applying", mask_param.data)
            #print_parameters("Net before applying mask", net_param.data)

            mask_data = mask_param.data.abs()
            threshold = torch.quantile(mask_data, 0.7)
            mask_applied = (mask_data >= threshold).float()

            #print_parameters("Mask applied (binary)", mask_applied)
            net_param.data *= mask_applied

            #print_parameters("Net after applying mask", net_param.data)
            #print("\n")

    return net_masked


In [32]:
loss

tensor(2.3027, device='cuda:0')

In [33]:
loss_mask_net


tensor(2.3027, device='cuda:0')

In [38]:
loss_mask_mask

tensor(2.3047, device='cuda:0', grad_fn=<NllLossBackward0>)

In [36]:
loss_mask

tensor(2.3027, device='cuda:0', grad_fn=<AddBackward0>)

In [31]:
from ast import Param
# Listas para almacenar las pérdidas de entrenamiento y las precisiones de validación
train_loss = []
test_accuracies = []
epochs = 10
accuracy_threshold = 0.6

for epoch in range(epochs):
    # Entrenamiento de la red
    print(f"Epoch {epoch + 1}")

    # Establecemos el modo de entrenamiento
    net.train()
    mask.train()


    #print("Net before mask: ", net.state_dict())
    #print("Mask before: ", mask.state_dict())

    # Aplicamos la máscara a la red
    net_masked = apply_mask(net, mask)
    #print("net_masked: ", net_masked.state_dict())
    #print("Net after mask: ", net.state_dict())
    print("Mask after: ", mask.state_dict())


    running_loss = 0.0

    for images, labels in train_loader:
      images, labels = images.to(device), labels.to(device)

      # Paso 1: Pasar imágenes por la red net y calcular la pérdida sin retropropagarla
      outputs_net = net_masked(images).detach()
      loss = criterion(outputs_net, labels)

      # Paso 2: Retropropagar la pérdida con respecto a los parámetros de la red mask
      optimizer_mask.zero_grad()  # Limpiar los gradientes acumulados de la máscara
      outputs_mask = mask(images)  # Obtener la salida de la máscara (ficticio)

      # Calcular el loss de mask borrando el loss real de mask

      loss_mask_net = loss # Error de net
      loss_mask_mask = criterion(outputs_mask, labels) # Error de mask
      loss_mask = loss_mask_net * 1 + loss_mask_mask * (1 - 1) # Combinar errores, alpha en [0,1]


      #loss_mask = criterion(outputs_mask, labels)*0.0  + loss
      loss_mask.backward(retain_graph=True)  # Calcular gradientes
      #grad_mask = torch.autograd.grad(loss_mask, mask.parameters(), create_graph=True, allow_unused=True) # Retropropagar gradientes



      #Actualizamos los gradientes manualmente
      #for name, param in mask.named_parameters():
       # if param.grad is not None:
        #  param.grad = param.grad.clone()


      optimizer_mask.step()  # Actualizar los parámetros de la máscara

      # Limpiamos gradientes para la proxima iteraccion
      net_masked.zero_grad()
      mask.zero_grad()

      running_loss += loss.item()

    train_loss.append(running_loss / len(train_loader))
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')


    # Verificar que los parámetros de NetB se han actualizado
    for name, param in mask.named_parameters():
        if param.requires_grad:
            print(name, param.data)
########################################


    # test de la red
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    test_accuracies.append(accuracy)

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}, test Accuracy: {accuracy}')

    # Paramos el entrenamiento si la precisión en validación supera el 60%
    if accuracy > accuracy_threshold:
        break

Epoch 1
Mask after:  OrderedDict([('conv1.weight', tensor([[[[-0.2264, -0.0179, -0.3108],
          [-0.2368,  0.0597,  0.1807],
          [-0.0581, -0.1834,  0.2415]]],


        [[[-0.0363,  0.1337, -0.2548],
          [ 0.0784, -0.2212,  0.3144],
          [ 0.2928,  0.0450,  0.0645]]],


        [[[ 0.0783, -0.0750, -0.2760],
          [ 0.0120,  0.1506,  0.0065],
          [-0.1925,  0.2197,  0.3105]]],


        [[[-0.2301, -0.1796, -0.0402],
          [-0.1485,  0.3068, -0.0125],
          [-0.2651,  0.1425, -0.0033]]],


        [[[-0.0233,  0.0096, -0.0571],
          [ 0.1836,  0.0893,  0.2298],
          [ 0.2486, -0.2881, -0.1489]]],


        [[[ 0.2190,  0.0459,  0.3121],
          [-0.3266, -0.2854, -0.1122],
          [ 0.3170,  0.1061,  0.2828]]],


        [[[ 0.2379, -0.1954,  0.2226],
          [-0.0323, -0.1462,  0.0234],
          [ 0.0657,  0.3278, -0.0653]]],


        [[[ 0.3152,  0.3051,  0.1340],
          [ 0.0927,  0.0885, -0.0712],
          [ 0.2317, -0.1

KeyboardInterrupt: 

In [None]:
loss.item()

In [None]:
loss

In [None]:
loss_mask.item()