<a href="https://colab.research.google.com/github/cabamarcos/SuperMask/blob/main/checkpoint2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import copy
from ast import Param

In [2]:
# Verificar si la GPU está disponible y establecer el dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


Creamos las redes

# Definimos las dos redes convolucionales
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 5 * 5, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Inicializamos las dos redes y las movemos a la GPU si está disponible
net = Net().to(device)
mask = Mask().to(device)

In [3]:
# Definimos las dos redes convolucionales
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)  # Aplana la imagen de entrada
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class Mask(nn.Module):
    def __init__(self):
        super(Mask, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)  # Aplana la imagen de entrada
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Inicializamos las dos redes y las movemos a la GPU si está disponible
net = Net().to(device)
mask = Mask().to(device)

Cargamos los datos

In [4]:
# Definimos el transform para los datos de MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Cargamos el dataset de MNIST
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Definimos los DataLoaders para los conjuntos de entrenamiento y prueba
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

In [5]:
# Definimos la función de pérdida y los optimizadores
criterion = nn.CrossEntropyLoss()
optimizer_mask = optim.SGD(mask.parameters(), lr=0.01, momentum=0.9)


In [6]:
def print_parameters(layer_name, params):
    print(f"--- {layer_name} ---")
    print(params)

def apply_mask(net, mask):
    # Aplica la máscara a la red: selecciona el 30% de los pesos más altos y desactiva el resto
    net_masked = copy.deepcopy(net)
    with torch.no_grad():
        for net_name, net_param, mask_name, mask_param in zip(net_masked.state_dict(), net_masked.parameters(), mask.state_dict(), mask.parameters()):
            #print_parameters("Mask before applying", mask_param.data)
            #print_parameters("Net before applying mask", net_param.data)

            mask_data = mask_param.data.abs()
            threshold = torch.quantile(mask_data, 0.7)
            mask_applied = (mask_data >= threshold).float()

            #print_parameters("Mask applied (binary)", mask_applied)
            net_param.data *= mask_applied

            #print_parameters("Net after applying mask", net_param.data)
            #print("\n")

    return net_masked


In [7]:
for param in net.parameters():
    param.requires_grad = False

In [8]:
# Listas para almacenar las pérdidas de entrenamiento y las precisiones de validación
train_loss = []
test_accuracies = []
epochs = 1
accuracy_threshold = 0.6

for epoch in range(epochs):
    # Entrenamiento de la red
    print(f"Epoch {epoch + 1}")

    # Establecemos el modo de entrenamiento
    net.train()
    mask.train()


    #print("Net before mask: ", net.state_dict())
    #print("Mask before: ", mask.state_dict())

    # Aplicamos la máscara a la red
    net_masked = apply_mask(net, mask)
    #for param in net_masked.parameters():
     #param.requires_grad = False
    #print("net_masked: ", net_masked.state_dict())
    #print("Net after mask: ", net.state_dict())
    #print("Mask after: ", mask.state_dict())


    running_loss = 0.0

    for images, labels in train_loader:
      images, labels = images.to(device), labels.to(device)

      outputs_net = net_masked(images)
      loss_net = criterion(outputs_net, labels)

      outputs_mask = mask(images) # Obtener la salida de la máscara (ficticio)
      loss_mask = criterion(outputs_mask, labels)

      loss = loss_net + loss_mask*0

      optimizer_mask.zero_grad()

      loss.backward()

      # Verificar y imprimir gradientes antes de actualizar
      #for name, param in net.named_parameters():
        #print(f"Gradiente de {name}: {param.grad}")

      optimizer_mask.step()  # Actualizar los parámetros de la máscara

      running_loss += loss_net.item()

    train_loss.append(running_loss / len(train_loader))

    # Verificar que los parámetros de NetB se han actualizado
    #for name, param in mask.named_parameters():
     #   if param.requires_grad:
      #      print(name, param.data)
########################################


    # test de la red
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    test_accuracies.append(accuracy)

    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}, test Accuracy: {accuracy}')

    # Paramos el entrenamiento si la precisión en validación supera el 60%
    if accuracy > accuracy_threshold:
        break

Epoch 1
Gradiente de fc1.weight: None
Gradiente de fc1.bias: None
Gradiente de fc2.weight: None
Gradiente de fc2.bias: None
Gradiente de fc3.weight: None
Gradiente de fc3.bias: None
Gradiente de fc1.weight: None
Gradiente de fc1.bias: None
Gradiente de fc2.weight: None
Gradiente de fc2.bias: None
Gradiente de fc3.weight: None
Gradiente de fc3.bias: None
Gradiente de fc1.weight: None
Gradiente de fc1.bias: None
Gradiente de fc2.weight: None
Gradiente de fc2.bias: None
Gradiente de fc3.weight: None
Gradiente de fc3.bias: None
Gradiente de fc1.weight: None
Gradiente de fc1.bias: None
Gradiente de fc2.weight: None
Gradiente de fc2.bias: None
Gradiente de fc3.weight: None
Gradiente de fc3.bias: None
Gradiente de fc1.weight: None
Gradiente de fc1.bias: None
Gradiente de fc2.weight: None
Gradiente de fc2.bias: None
Gradiente de fc3.weight: None
Gradiente de fc3.bias: None
Gradiente de fc1.weight: None
Gradiente de fc1.bias: None
Gradiente de fc2.weight: None
Gradiente de fc2.bias: None
Gradie

KeyboardInterrupt: 

In [None]:
loss_net.item() = 0

In [None]:
loss_mask

In [None]:
loss