In [1]:
import torch
from torch.nn import Linear
import torchvision
import actuallysparse.converter as converter
import actuallysparse.layers as layers
import pretrained
import copy

In [6]:
# pruning_device is fixed to the CPU.
pruning_device = "cpu"

# Training uses CUDA when PyTorch reports an available device, otherwise the CPU.
if torch.cuda.is_available():
    training_device = "cuda"
else:
    training_device = "cpu"

In [7]:
# Loading models
# Each checkpoint is loaded with map_location set to the CPU, so the resulting
# objects live on the CPU no matter which device they were saved from.
# NOTE(review): torch.load relies on pickle, which can execute arbitrary code while
# deserializing — only load checkpoint files that come from a trusted source.
pretrained_model = torch.load(".weights/full/pretrained", map_location=torch.device('cpu'))
extra_trained = torch.load(".weights/full/extra_trained", map_location=torch.device('cpu'))
pruned_model = torch.load(".weights/full/pruned", map_location=torch.device('cpu'))
very_pruned_model = torch.load(".weights/full/very_pruned", map_location=torch.device('cpu'))

In [11]:
# Get the training and test dataloaders (returned in that order) from the project helper.
dataloader_train, dataloader_test = pretrained.load_cifar10_dataloaders()

Files already downloaded and verified
Files already downloaded and verified


In [12]:
def eval_accuracy(model, dataset="train", dataloader=None, device=None):
    """Return the classification accuracy of ``model`` over one pass of a dataloader.

    The model is evaluated in eval mode with gradients disabled, so layers whose
    behaviour depends on the training flag (e.g. Dropout, BatchNorm) cannot skew
    the measurement. The training/eval flag of every submodule is restored
    afterwards.

    Args:
        model: Module mapping a batch of inputs to per-class scores along dim 1.
        dataset: ``"train"`` selects the notebook-level ``dataloader_train``; any
            other value selects ``dataloader_test``. Ignored when ``dataloader``
            is given.
        dataloader: Optional iterable of ``(inputs, labels)`` batches to evaluate
            on instead of the notebook-level dataloaders.
        device: Optional device to run on; defaults to the notebook-level
            ``training_device``.

    Returns:
        A 0-dim tensor holding ``correct / total``.

    Raises:
        ValueError: If the dataloader yields no labels at all.

    Note:
        The model is moved to the target device and left there.
    """
    if dataloader is None:
        dataloader = dataloader_train if dataset == "train" else dataloader_test
    if device is None:
        device = training_device

    # Remember each submodule's flag so mixed train/eval setups are restored exactly.
    previous_modes = [(module, module.training) for module in model.modules()]
    model.to(device)
    model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                pred = torch.argmax(model(inputs), dim=1)

                total += labels.numel()
                correct += torch.sum(pred.eq(labels))
    finally:
        for module, was_training in previous_modes:
            module.training = was_training

    if total == 0:
        raise ValueError("eval_accuracy: the dataloader produced no samples")
    return correct / total

In [4]:
def convolutional_pass(model, dataloader=None):
    """Precompute the classifier inputs for every batch of a dataloader.

    Each batch is passed through ``model.features`` and ``model.avgpool`` and
    flattened to ``(batch, -1)``. The pass runs with gradients disabled and with
    the model in eval mode, so caching features neither depends on nor alters
    training-only state (e.g. Dropout masks, BatchNorm running statistics). The
    training/eval flag of every submodule is restored afterwards.

    Args:
        model: Module exposing ``features`` and ``avgpool`` submodules.
        dataloader: Optional iterable of ``(inputs, labels)`` batches; defaults
            to the notebook-level ``dataloader_train``.

    Returns:
        A list with one ``[flattened_features, labels]`` entry per batch. When
        the model has parameters, both tensors are on the device of its first
        parameter.
    """
    if dataloader is None:
        dataloader = dataloader_train

    # Feed batches on the device the model currently lives on; a model without
    # parameters gives no hint, so the batches are then left where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    passed_data = []
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                if device is not None:
                    inputs, labels = inputs.to(device), labels.to(device)
                pooled = model.avgpool(model.features(inputs))
                inputs_for_sparse = pooled.reshape(pooled.size(0), -1)
                passed_data.append([inputs_for_sparse, labels])
    finally:
        for module, was_training in previous_modes:
            module.training = was_training
    return passed_data


In [3]:
def train_prune_loop(model, data, optimizer, criterion, max_epochs=2000, epochs_to_prune=15):
    """Train ``model.classifier`` on precomputed batches, pruning it periodically.

    Args:
        model: Module whose ``classifier`` is indexable and whose first classifier
            layer exposes ``in_features``.
        data: Re-iterable collection of ``(classifier_inputs, labels)`` batches;
            it is traversed once per epoch.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        max_epochs: Number of epochs to run.
        epochs_to_prune: ``layers.prune_model`` is called before every epoch whose
            index is a multiple of this value (including epoch 0). ``0`` or
            ``None`` disables pruning.

    Raises:
        ValueError: If an epoch yields no batches, since there is no loss to report.

    Side effects:
        Puts the whole model in training mode and prints a line per pruning step
        and per epoch (the loss shown is that of the epoch's last batch).
    """
    in_classifier_features = model.classifier[0].in_features

    # Create the dummy input on the classifier's own device so pruning does not
    # mix devices when the model has been moved off the CPU.
    first_param = next(model.classifier.parameters(), None)
    dummy_device = first_param.device if first_param is not None else None
    dummy_input = torch.ones(in_classifier_features, device=dummy_device)

    model.train()
    for epoch in range(max_epochs):
        # A falsy interval means "never prune" and avoids a modulo-by-zero.
        if epochs_to_prune and epoch % epochs_to_prune == 0:
            layers.prune_model(model.classifier, dummy_input)
            print("Pruned!")

        last_loss = None
        for inputs_for_sparse, labels in data:
            optimizer.zero_grad()
            outputs = model.classifier(inputs_for_sparse)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            last_loss = loss

        if last_loss is None:
            raise ValueError(f"train_prune_loop: no batches to train on in epoch {epoch}")
        print(f"Epoch:{epoch}, loss:{last_loss}")

In [8]:
# Work on a deep copy so the loaded `pruned_model` object is left unmodified.
pruned_model_coo = copy.deepcopy(pruned_model)

# Replace the copy's classifier with the converter's output for (Linear, 'coo').
# NOTE(review): presumably this swaps Linear layers for sparse COO layers —
# confirm against converter.convert_model.
pruned_model_coo.classifier = converter.convert_model(
    pruned_model_coo.classifier,
    Linear,
    'coo',
)

# Collect the direct children that are exactly SparseLayer (subclasses do not
# match the identity check on the type), then set k = 0.02 on each of them.
sparse_children = [
    module
    for module in pruned_model_coo.classifier.children()
    if type(module) is layers.SparseLayer
]
for child in sparse_children:
    child.set_k(0.02)

In [9]:
# Adam receives only the classifier's parameters, so this optimizer never updates
# any other part of the model.
optimizer = torch.optim.Adam(pruned_model_coo.classifier.parameters(), lr=0.001)
# Cross-entropy loss used as the training criterion.
criterion = torch.nn.CrossEntropyLoss()

In [15]:
# Run the training dataloader through the model's features and avgpool stages once
# and keep the flattened results (with their labels) as the classifier's inputs.
passed_data = convolutional_pass(pruned_model_coo)

In [41]:
# Train the classifier on the cached batches for at most 59 epochs; epochs_to_prune
# keeps its default of 15.
train_prune_loop(pruned_model_coo, passed_data, optimizer, criterion, max_epochs=59)

Pruned!
Epoch:0, loss:0.006521960254758596
Epoch:1, loss:0.004795870278030634
Epoch:2, loss:0.006124990060925484
Epoch:3, loss:0.008161676116287708
Epoch:4, loss:0.006097306497395039
Epoch:5, loss:0.01150859147310257
Epoch:6, loss:0.008162113837897778
Epoch:7, loss:0.0070147328078746796
Epoch:8, loss:0.0071838791482150555
Epoch:9, loss:0.005080321803689003
Epoch:10, loss:0.004910708405077457
Epoch:11, loss:0.006445376668125391
Epoch:12, loss:0.008207106962800026
Epoch:13, loss:0.01086659636348486
Epoch:14, loss:0.010599001310765743
Pruned!
Epoch:15, loss:0.015457019209861755
Epoch:16, loss:0.004612844903022051
Epoch:17, loss:0.008892080746591091
Epoch:18, loss:0.0054727233946323395
Epoch:19, loss:0.00818454660475254
Epoch:20, loss:0.010733014903962612
Epoch:21, loss:0.008907218463718891
Epoch:22, loss:0.009959240444004536
Epoch:23, loss:0.009082750417292118
Epoch:24, loss:0.0062526194378733635
Epoch:25, loss:0.00681970315054059


KeyboardInterrupt: 

In [None]:
# Switch the model to evaluation mode.
pruned_model_coo.eval()

In [14]:
# Accuracy of the converted model on the test dataloader.
eval_accuracy(pruned_model_coo, dataset="test")

tensor(0.8237)