<a href="https://colab.research.google.com/github/msatkun/MSc_Project/blob/main/Pruning_Test1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# install missing packages
# resource used: https://stackoverflow.com/questions/66549818/getting-importerror-when-using-torchtext
!pip install -U torchtext==0.6



In [23]:
# import the IMBD dataset from pytorch
# https://pytorch.org/text/stable/datasets.html
import spacy
import torch
import en_core_web_sm
from torchtext.datasets import IMDB
from torchtext.data import Field, LabelField, TabularDataset, BucketIterator

# defining the text and label fields
txt = Field(tokenize=lambda x: x.split(), lower=True, include_lengths=True)
lbl = LabelField(sequential=False, use_vocab=False)

#loading the dataset
train_data, test_data = IMDB.splits(txt, lbl)

In [8]:
# for the training set we need to build a vocabulary set
# https://snyk.io/advisor/python/torchtext/functions/torchtext.data.BucketIterator.splits
txt.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')
lbl.build_vocab(train_data)

.vector_cache/glove.6B.zip: 862MB [02:41, 5.35MB/s]                           
100%|█████████▉| 399999/400000 [00:26<00:00, 14980.54it/s]


In [9]:
# creating iterators for batching
# https://snyk.io/advisor/python/torchtext/functions/torchtext.data.BucketIterator.splits
BATCH_SIZE = 64

train_iterator, test_iterator = BucketIterator.splits((train_data, test_data),
                                                      batch_size=BATCH_SIZE,
                                                      sort_key=lambda x: len(x.text),
                                                      sort_within_batch=False)

In [10]:
# importing modules for pruning
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.utils.prune as prune

In [11]:
# simple feedforward neural network
class simpleNN(nn.Module):
  def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
    super(simpleNN, self).__init__()
    self.embedding = nn.Embedding(input_dim, embedding_dim)
    self.fc = nn.Linear(embedding_dim, hidden_dim)
    self.out = nn.Linear(hidden_dim, output_dim)

  def forward(self, text):
   embedded = self.embedding(text)
   hidden = self.fc(embedded)
   output = self.out(hidden)
   return output

In [12]:
# setting up the first model for training
INPUT_DIM = len(txt.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
OUTPUT_DIM = 1

model = simpleNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)

In [13]:
# function to apply pruning on a model
def apply_pruning(model, pruning_rate):
  parameters_to_prune = []
  for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
      parameters_to_prune.append((module, 'weight'))

  prune.global_unstructured(parameters_to_prune,
                            pruning_method=prune.L1Unstructured,
                            amount=pruning_rate)

In [14]:
# apply pruning to the model
PRUNING_RATE = 0.3 # prune 30% of the weights
apply_pruning(model, PRUNING_RATE)

In [18]:
# training loop for a pruned model
optimizer = optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()

def train(model, iterator, optimizer, criterion):
  model.train()
  for batch in iterator:
    optimizer.zero_grad()
    predictions = model(batch.text).squeeze(1)
    loss = criterion(predictions, batch.label)
    loss.backward()
    optimizer.step()

In [24]:
# training the pruned model
N_EPOCHS = 5
for epoch in range(N_EPOCHS):
  train(model, train_iterator, optimizer, criterion)

ValueError: ignored

In [None]:
# create a function to evaluate the pruned model
def evaluate(model, iterator, criterion):
  model.eval()
  total_loss = 0.0
  correct = 0
  with torch.no_grad():
    for batch in iterator:
      predictions = model(batch.text).squeeze(1)
      loss = criterion(predictions, batch.label)
      total_loss += loss.item()
      preds = torch.round(torch.sigmoid(predictions))
      correct += (preds == batch.label).sum().item()
  return total_loss / len(iterator), correct / len(iterator.dataset) * 100

In [None]:
# evalute the pruned model
pruned_loss, pruned_accuracy = evaluate(model, test_iterator, criterion)
print(f"Pruned Model - Test Loss: {pruned_loss:.3f}, Test Accuracy: {pruned_accuracy:.2f}%")

In [None]:
# visualise the sparity of the model
import matplotlib.pyplot as plt

def show_sparsity(model):
  sparsity = []

  plt.xlabel('layers')
  plt.ylabel('sparsity')
  plt.title('Sparsity of Pruned Model Layers')
  plt.show()