<a href="https://colab.research.google.com/github/msatkun/MSc_Project/blob/main/Pruning_Test2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# install torchvision
!pip install torch torchvision
!pip install -U torchtext==0.6
!pip install -U spacy



In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchtext.datasets import IMDB
from torchtext.data import Field, LabelField, TabularDataset, BucketIterator


In [7]:
import spacy
# https://stackoverflow.com/questions/66087475/chatterbot-error-oserror-e941-cant-find-model-en
nlp = spacy.load("en_core_web_sm")

In [10]:
# Define preprocessing fields which are the text and labels from the dataset
TEXT = Field(sequential=True,lower=True)
LABEL = LabelField(dtype=torch.float)

# loading IMDB dataset
train_data, test_data = IMDB.splits(TEXT, LABEL)

# building a vocabulary using the training data and load pre-trained word embeddings from GloVe
TEXT.build_vocab(train_data, max_size=10000, vectors="glove.6B.100d",
                 unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

.vector_cache/glove.6B.zip: 862MB [02:41, 5.35MB/s]                           
100%|█████████▉| 399999/400000 [00:25<00:00, 15824.59it/s]


In [11]:
# create data iterators
batch_size = 64
# bucketiterator helps batch sequences of simliar lengths
train_iterator, test_iterator = BucketIterator.splits(
    (train_data, test_data), batch_size=batch_size)

# printing vocabulary size
print(f"Vocabulary size: {len(TEXT.vocab)}")

Vocabulary size: 10002


In [19]:
# simple feedforward NN
class simplemodel(nn.Module):
  def __init__(self, vocab_size, embedding_dim, input_size, hidden_size, output_size):
    super(simplemodel, self).__init__()
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.fc1 = nn.Linear(input_size, hidden_size)
    self.fc2 = nn.Linear(hidden_size, output_size)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x = self.fc1(x)
    x = self.sigmoid(x)
    x = self.fc2(x)
    return x

In [20]:
input_size = 10002
hidden_size = 128
output_size = 1
model = simplemodel(input_size, hidden_size, output_size)

TypeError: ignored

In [21]:
# training loop
num_epochs = 5
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
  for example in train_data:
    inputs = example.text
    labels = example.label

    # Convert tokenized words to tensor using the vocabulary
    inputs = [TEXT.vocab.stoi[token] for token in inputs]
    inputs = torch.tensor(inputs)

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

RuntimeError: ignored

In [None]:
# defining pruning parameters
prune_percent = 20  # Percentage of connections to prune

# pruning the model's connections
parameters_to_prune = [(model.fc1, 'weight'), (model.fc2, 'weight')]

prune.global_unstructured(parameters_to_prune,
                          pruning_method=prune.L1Unstructured,
                          amount=prune_percent)

# removing the pruned connections from the model
prune.remove(model.fc1, 'weight')
prune.remove(model.fc2, 'weight')

In [None]:
# Evaluate the pruned model
correct = 0
total = 0
with torch.no_grad():
  for inputs, labels in test_data:
    outputs = model(inputs)
    predicted = (outputs > 0.5).float()
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Test Accuracy: {accuracy:.4f}')

In [None]:
# visualising sparsity
def count_parameters(model):
  return sum(p.numel() for p in model.parameters())

def count_zero_parameters(model):
  return sum(p.numel() - p.nonzero().size(0) for p in model.parameters())

total_params = count_parameters(model)
total_zero_params = count_zero_parameters(model)
sparsity = total_zero_params / total_params

print(f"Total parameters: {total_params}")
print(f"Total zero parameters: {total_zero_params}")
print(f"Sparsity: {sparsity:.4f}")