In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
import os
import numpy as np

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sparsity = 0.9
batch_size = 64

In [3]:
# Dataset Animals10
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

data_dir = "./animals10"

train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, "train"), transform=transform)
test_dataset = datasets.ImageFolder(root=os.path.join(data_dir, "test"), transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
net = models.alexnet(pretrained=False)
net.classifier[6] = nn.Linear(net.classifier[6].in_features, 10)
mask_model = models.alexnet(pretrained=False)
mask_model.classifier[6] = nn.Linear(mask_model.classifier[6].in_features, 10)

net.load_state_dict(torch.load("net.pth"))
mask_model.load_state_dict(torch.load("mask.pth"))
net.to(device)
mask_model.to(device)

  net.load_state_dict(torch.load("net.pth"))
  mask_model.load_state_dict(torch.load("mask.pth"))


AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [5]:
def create_mask_from_mask_model(mask_model, keep_ratio=0.1):
    all_scores = []

    for name, param in mask_model.named_parameters():
        if 'weight' in name:
            all_scores.append(param.data.cpu().abs().flatten().numpy())

    all_scores = np.concatenate(all_scores)
    threshold = np.percentile(all_scores, 100 * (1 - keep_ratio))

    mask = {}
    for name, param in mask_model.named_parameters():
        if 'weight' in name:
            mask_tensor = (param.data.abs() > threshold).float().to(device)
            mask[name] = mask_tensor

    return mask

mask = create_mask_from_mask_model(mask_model, keep_ratio=0.1)

In [6]:
for name, param in net.named_parameters():
    if name in mask:
        param.data *= mask[name]  # pruning: poner a cero
        param.register_hook(lambda grad, name=name: grad * mask[name])

In [7]:
#imprimir numero de parametros que están a cero
def count_zero_params(model):
    zero_count = 0
    total_count = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            zero_count += (param.data == 0).sum().item()
            total_count += param.numel()
    return zero_count, total_count

In [8]:
initial_weights = {name: param.clone().detach() for name, param in net.named_parameters() if name in mask}

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
def evaluate(model, dataloader):
    model.eval()
    correct, total = 0, 0
    total_loss = 0.0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            loss = criterion(out, y)
            total_loss += loss.item()
            _, pred = out.max(1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return accuracy, avg_loss

for epoch in range(1, 21):
    zero_count, total_count = count_zero_params(net)
    print(f"Number of zero parameters: {zero_count} out of {total_count} ({100 * zero_count / total_count:.2f}%)")
    net.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    train_loss = running_loss / len(train_loader)
    train_acc, _ = evaluate(net, train_loader)
    test_acc, test_loss = evaluate(net, test_loader)

    print(f"\n📊 [Epoch {epoch}]")
    print(f"   🔧 Train Accuracy : {train_acc*100:.2f}% |  Train Loss: {train_loss:.4f}")
    print(f"   🧪 Test Accuracy  : {test_acc*100:.2f}% | Test Loss: {test_loss:.4f}")

    # Verificar que pesos en 0 no se modifican
    for name, param in net.named_parameters():
        if name in mask:
            original = initial_weights[name]
            current = param.detach()
            changed = ((original == 0) & (current != 0)).sum().item()
            if changed > 0:
                print(f"⚠️ Warning: {changed} pesos enmascarados en '{name}' han cambiado.")


Number of zero parameters: 51331909 out of 57035456 (90.00%)

📊 [Epoch 1]
   🔧 Train Accuracy : 28.89% |  Train Loss: 2.7435
   🧪 Test Accuracy  : 28.93% | Test Loss: 2.0582
Number of zero parameters: 51331909 out of 57035456 (90.00%)

📊 [Epoch 2]
   🔧 Train Accuracy : 35.44% |  Train Loss: 1.9464
   🧪 Test Accuracy  : 35.43% | Test Loss: 1.7753
Number of zero parameters: 51331909 out of 57035456 (90.00%)

📊 [Epoch 3]
   🔧 Train Accuracy : 38.70% |  Train Loss: 1.7605
   🧪 Test Accuracy  : 39.15% | Test Loss: 1.6588
Number of zero parameters: 51331909 out of 57035456 (90.00%)

📊 [Epoch 4]
   🔧 Train Accuracy : 49.28% |  Train Loss: 1.6166
   🧪 Test Accuracy  : 48.14% | Test Loss: 1.4865
Number of zero parameters: 51331909 out of 57035456 (90.00%)

📊 [Epoch 5]
   🔧 Train Accuracy : 54.01% |  Train Loss: 1.4318
   🧪 Test Accuracy  : 52.30% | Test Loss: 1.3764
Number of zero parameters: 51331909 out of 57035456 (90.00%)

📊 [Epoch 6]
   🔧 Train Accuracy : 56.80% |  Train Loss: 1.3099
   🧪 

In [10]:
torch.save(net.state_dict(), "masked_net_trained.pth")