In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    """LeNet-style CNN classifier for single-channel 28x28 images.

    Shapes for an input of (N, 1, 28, 28):
      conv1 (5x5, padding 2) + ReLU + 2x2 max-pool -> (N, 20, 14, 14)
      conv2 (5x5, padding 2) + ReLU + 2x2 max-pool -> (N, 50, 7, 7)
      flatten -> fc1 (2450 -> 500) + ReLU -> fc2 (500 -> 10)

    ``forward`` returns raw logits, which is what ``nn.CrossEntropyLoss``
    expects (it applies log-softmax internally).
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 50 channels x 7 x 7 remain after two 2x2 poolings of a 28x28 input.
        self.fc1 = nn.Linear(in_features=50 * 7 * 7, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Map a batch of images to class logits of shape (N, 10)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 50 * 7 * 7)
        x = F.relu(self.fc1(x))
        # No log-softmax here: CrossEntropyLoss applies it, and argmax-based
        # predictions are unaffected by it.
        return self.fc2(x)


In [7]:
import torch
import torchvision
import torchvision.transforms as transforms

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

# Train on an accelerator when one exists, otherwise on the CPU. The model and
# every batch must live on the same device.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Seed the global RNG so weight init and shuffling order are reproducible.
torch.manual_seed(0)

# Initialize the model, loss function, and optimizer
model = LeNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Train the model
model.train()
for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    print('Epoch {} loss: {:.4f}'.format(epoch + 1, running_loss / len(trainloader)))

print('Finished Training')

KeyboardInterrupt: 

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)

# One device for the model AND every batch. Previously the model could land
# on CUDA while batches were hard-coded to MPS, and there was no CPU fallback.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Seed the global RNG so weight init and shuffling order are reproducible.
torch.manual_seed(0)

# Initialize the model, loss function, and optimizer
model = LeNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Train the model
model.train()
for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Accumulate a Python int rather than a device tensor.
        running_corrects += (outputs.argmax(dim=1) == labels).sum().item()

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = running_corrects / len(trainset)
    print('Epoch {} loss: {:.4f} accuracy: {:.4f}'.format(epoch + 1, epoch_loss, epoch_acc))

print('Finished Training')

In [None]:
def calculate_saliency(weight_i, bias_i, weight_j, bias_j, coef_j):
    """Saliency of removing hidden neuron j by merging it into neuron i.

    The criterion has the form s(i, j) = a_j^2 * ||W_i - W_j||^2 (data-free
    pruning). Here W_k is neuron k's incoming weight vector with its bias
    appended as one extra coordinate, and a_j is neuron j's outgoing weight.

    Args:
        weight_i, weight_j: incoming weight vectors of the two neurons (same shape).
        bias_i, bias_j: scalar biases of the two neurons.
        coef_j: outgoing weight of neuron j (the neuron being removed).

    Returns:
        Scalar tensor with the saliency (0 when the neurons are identical).
    """
    # The bias difference contributes one squared term. It must not be
    # broadcast-added onto every weight component.
    squared_distance = torch.sum((weight_i - weight_j) ** 2) + (bias_i - bias_j) ** 2
    return coef_j ** 2 * squared_distance

In [None]:
def get_matrix_saliency(weights, bias, coefs):
    """Build the full pairwise saliency matrix for a layer's neurons.

    Entry [i, j] is ``calculate_saliency`` for merging neuron j into neuron i,
    using ``coefs[j]`` as neuron j's outgoing weight.

    Args:
        weights: 2-D tensor with one row of incoming weights per neuron.
        bias: 1-D tensor of per-neuron biases.
        coefs: 1-D tensor of per-neuron outgoing weights.

    Returns:
        (n, n) float tensor on the default device, n = weights.shape[0],
        carrying no autograd history.
    """
    n = weights.shape[0]
    saliency_matrix = torch.empty((n, n))
    # The inputs are trainable parameters. Without no_grad, each of the n*n
    # entries written into the matrix would keep its own autograd graph (and
    # the intermediate difference vectors) alive.
    with torch.no_grad():
        for i in range(n):
            for j in range(n):
                saliency_matrix[i, j] = calculate_saliency(weights[i], bias[i], weights[j], bias[j], coefs[j])
    return saliency_matrix


def get_smallest_saliency_id(saliency_matrix):
    """Return the off-diagonal pair (i, j) with the smallest saliency.

    The diagonal is ignored because a neuron cannot be merged into itself, and
    NaN entries are skipped. Ties resolve to the first pair in row-major order.

    Args:
        saliency_matrix: square 2-D tensor; entry [i][j] scores removing
            neuron j by merging it into neuron i. It is not modified.

    Returns:
        Tuple of Python ints ``(i, j)``. If every off-diagonal entry is NaN or
        +inf, ``(0, 0)`` is returned (same as the previous scan).

    Raises:
        ValueError: if the matrix has fewer than two rows (no pair exists).
    """
    n = saliency_matrix.shape[0]
    if n < 2:
        raise ValueError("need at least two neurons to find a pair to merge")
    inf = float("inf")
    # Detached copy of the square part, so the caller's matrix stays intact.
    candidates = saliency_matrix[:, :n].detach().clone()
    candidates[torch.isnan(candidates)] = inf
    candidates.fill_diagonal_(inf)
    # argmin returns the first minimum in flattened (row-major) order.
    flat_index = int(torch.argmin(candidates))
    return divmod(flat_index, candidates.shape[1])


def update_model_and_saliency_matrix(model, saliency_matrix):
    """Run one pruning step: merge the least salient hidden neuron of ``fc1``.

    Picks the off-diagonal pair (i, j) with the smallest saliency and deletes
    hidden neuron j (row j of ``fc1``, column j of ``fc2``). Neuron j's
    outgoing weights are added to neuron i (``fc2[:, i] += fc2[:, j]``), so i
    takes over j's contribution.

    Args:
        model: module exposing ``fc1`` and ``fc2`` ``nn.Linear`` layers. It is
            modified in place.
        saliency_matrix: square matrix aligned with fc1's output neurons, where
            entry [a][b] is the saliency of merging neuron b into neuron a.

    Returns:
        A new saliency matrix with row and column j removed and the column of
        the surviving neuron recomputed, aligned with the pruned model.

    Raises:
        ValueError: if no pair of distinct neurons can be selected.

    NOTE(review): coefficients are the outgoing weights towards output neuron 0
    only (``fc2.weight[0]``), matching how the initial matrix is built in this
    notebook. Confirm whether all output neurons should contribute.
    """
    i, j = get_smallest_saliency_id(saliency_matrix)
    if i == j:
        raise ValueError("no pair of distinct neurons left to merge")
    # Position of the surviving neuron i once index j has been deleted.
    kept = i if i < j else i - 1

    with torch.no_grad():
        fc1, fc2 = model.fc1, model.fc2

        # Remove neuron j's incoming weights and bias.
        fc1.weight.data = torch.cat((fc1.weight[:j], fc1.weight[j + 1:]))
        fc1.bias.data = torch.cat((fc1.bias[:j], fc1.bias[j + 1:]))
        fc1.out_features = fc1.weight.shape[0]

        # Fold j's outgoing weights into i for every output neuron at once,
        # then drop column j. Works for any number of outputs.
        fc2.weight[:, i] += fc2.weight[:, j]
        fc2.weight.data = torch.cat((fc2.weight[:, :j], fc2.weight[:, j + 1:]), dim=1)
        fc2.in_features = fc2.weight.shape[1]

        # Any stored gradients have the old shapes and can no longer apply.
        for param in (fc1.weight, fc1.bias, fc2.weight):
            param.grad = None

        # Neuron j no longer exists: drop both its row and its column.
        saliency_matrix = torch.cat((saliency_matrix[:j], saliency_matrix[j + 1:]))
        saliency_matrix = torch.cat((saliency_matrix[:, :j], saliency_matrix[:, j + 1:]), dim=1)

        # Only the survivor's outgoing weight changed. It enters the saliency
        # only as the coefficient of the neuron being removed, i.e. in column
        # ``kept``. Row ``kept`` uses other neurons' coefficients and stays valid.
        weights = fc1.weight
        bias = fc1.bias
        coefs = fc2.weight[0]
        for k in range(saliency_matrix.shape[0]):
            saliency_matrix[k, kept] = calculate_saliency(
                weights[k], bias[k], weights[kept], bias[kept], coefs[kept]
            )

    return saliency_matrix

In [None]:
# Initial saliency matrix over fc1's hidden neurons. The removal coefficients
# are the outgoing weights towards output neuron 0.
weights, bias, coefs = model.fc1.weight, model.fc1.bias, model.fc2.weight[0]
matrix_saliency = get_matrix_saliency(weights, bias, coefs)

# Pruning: each call removes one hidden neuron, so call it as many times as
# there are neurons to remove.
N_PRUNING_STEPS = 10
for _step in range(N_PRUNING_STEPS):
    matrix_saliency = update_model_and_saliency_matrix(model, matrix_saliency)

In [None]:
# Evaluate on the MNIST test split with the same preprocessing as training.
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Send batches to wherever the (possibly pruned) model lives instead of
# hard-coding "cuda", which broke whenever the model was on MPS or the CPU.
device = next(model.parameters()).device

model.eval()
test_loss = 0.0
test_corrects = 0
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        test_loss += criterion(outputs, labels).item()
        # Python-int accumulation avoids float64 tensors, which MPS lacks.
        test_corrects += (outputs.argmax(dim=1) == labels).sum().item()
test_loss = test_loss / len(testloader)
test_acc = test_corrects / len(testset)

print(test_acc)