In [None]:
from transformers import BertTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, f1_score
import torch
import torch.nn.utils.prune as prune
from torch.nn.functional import cross_entropy
from pyswarm import pso

# Load IMDb dataset
dataset = load_dataset('imdb')

# Use a small subset of the dataset for faster experimentation
small_train_dataset = dataset['train'].shuffle(seed=42).select(range(1000))
small_test_dataset = dataset['test'].shuffle(seed=42).select(range(1000))

# Load TinyBERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-tiny')
model = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-tiny', num_labels=2)

# Make all model parameters contiguous
for param in model.parameters():
    param.data = param.data.contiguous()

# Tokenize data with fixed sequence length
def tokenize_function(examples):
    return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=128)

tokenized_train = small_train_dataset.map(tokenize_function, batched=True)
tokenized_test = small_test_dataset.map(tokenize_function, batched=True)

# Remove columns that are not tensors
tokenized_train = tokenized_train.remove_columns(["text"])
tokenized_test = tokenized_test.remove_columns(["text"])

# Define compute metrics function with additional metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # Convert logits and labels to tensors to compute cross-entropy loss
    logits_tensor = torch.tensor(logits)
    labels_tensor = torch.tensor(labels)

    accuracy = accuracy_score(labels, predictions)
    recall = recall_score(labels, predictions, average="binary")
    f1 = f1_score(labels, predictions, average="binary")
    loss = cross_entropy(logits_tensor, labels_tensor).item()  # Cross-entropy loss calculation

    return {
        "accuracy": accuracy,
        "recall": recall,
        "f1": f1,
        "loss": loss
    }

# Define PSO fitness function (objective function)
def fitness_function(hyperparams):
    lr, batch_size, epochs = hyperparams
    batch_size = int(batch_size)
    epochs = int(epochs)

    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=lr,
        evaluation_strategy="steps",
        eval_steps=50,
        save_steps=50,
        logging_steps=50,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_test,
        compute_metrics=compute_metrics
    )

    for param in model.parameters():
        param.data = param.data.contiguous()

    trainer.train()
    eval_results = trainer.evaluate()
    return -eval_results['eval_accuracy']  # Minimize negative accuracy for PSO

# PSO to optimize hyperparameters
lb = [1e-5, 8, 1]
ub = [1e-3, 16, 2]
best_hyperparams, _ = pso(fitness_function, lb, ub, swarmsize=5, maxiter=3)

print("Best hyperparameters:", best_hyperparams)
final_lr, final_batch_size, final_epochs = best_hyperparams
final_batch_size = int(final_batch_size)
final_epochs = int(final_epochs)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=final_epochs,
    per_device_train_batch_size=final_batch_size,
    learning_rate=final_lr,
    evaluation_strategy="epoch"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    compute_metrics=compute_metrics
)

for param in model.parameters():
    param.data = param.data.contiguous()

# Apply structured pruning on entire channels of linear layers
def apply_structured_pruning(model, amount=0.4):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.ln_structured(module, name='weight', amount=amount, n=2, dim=0)
            prune.remove(module, 'weight')

# Apply structured pruning with 40% sparsity
apply_structured_pruning(model, amount=0.4)

# Train the final pruned model and evaluate with added metrics
trainer.train()
final_results = trainer.evaluate()
print("Final evaluation after structured pruning:", final_results)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.6897,0.667675,0.628,0.627049,0.621951




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.6274,0.671813,0.631,0.875,0.698283
100,0.5414,0.610344,0.686,0.735656,0.695736




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.3235,0.641385,0.697,0.745902,0.706111




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.3973,0.763328,0.711,0.854508,0.742654
100,0.3882,0.576943,0.722,0.797131,0.736742




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.2478,1.684048,0.626,0.954918,0.713629
100,0.3771,1.078143,0.71,0.663934,0.690832




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.1684,1.183388,0.689,0.915984,0.741909
100,0.272,0.942802,0.726,0.793033,0.73855




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0931,1.156278,0.733,0.702869,0.719832




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.022,1.30978,0.737,0.739754,0.732995




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0982,1.142066,0.717,0.631148,0.685206
100,0.2601,1.22488,0.684,0.858607,0.72617
150,0.0218,1.507363,0.725,0.756148,0.728529




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0999,1.991304,0.718,0.618852,0.681716




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0337,2.242895,0.705,0.891393,0.746781
100,0.1795,1.59365,0.73,0.678279,0.7103




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0229,2.577767,0.72,0.678279,0.70276




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0299,2.708285,0.724,0.64959,0.696703
100,0.1673,2.325746,0.733,0.711066,0.722164




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0,3.294989,0.736,0.803279,0.748092
100,0.3619,2.269863,0.732,0.786885,0.741313
150,0.034,2.229881,0.731,0.790984,0.741595




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0,3.497745,0.73,0.786885,0.739884




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0001,4.455726,0.708,0.856557,0.741135
100,0.0692,3.187846,0.73,0.688525,0.713376




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0084,3.80882,0.731,0.821721,0.748833
100,0.0268,3.257816,0.736,0.727459,0.728953
150,0.0143,3.190976,0.736,0.815574,0.750943




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0,4.09306,0.72,0.625,0.685393
100,0.0,3.94911,0.729,0.768443,0.734574
150,0.0,3.9518,0.73,0.770492,0.735812




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0,4.557057,0.728,0.854508,0.754069
100,0.1587,4.294619,0.682,0.438525,0.573727
150,0.0353,3.856544,0.716,0.637295,0.686534
200,0.0557,3.768021,0.72,0.651639,0.694323




Step,Training Loss,Validation Loss,Accuracy,Recall,F1
50,0.0066,5.766375,0.689,0.485656,0.603822


Stopping search: maximum iterations reached --> 3
Best hyperparameters: [7.10991852e-04 8.16467595e+00 1.96990985e+00]




Epoch,Training Loss,Validation Loss,Accuracy,Recall,F1
1,No log,1.549414,0.742,0.758197,0.741483


Final evaluation after structured pruning: {'eval_loss': 1.549413800239563, 'eval_accuracy': 0.742, 'eval_recall': 0.7581967213114754, 'eval_f1': 0.7414829659318637, 'eval_runtime': 0.689, 'eval_samples_per_second': 1451.29, 'eval_steps_per_second': 181.411, 'epoch': 1.0}


In [None]:
# Define the directory where the model and tokenizer will be saved
output_dir = "./saved_model"

# Save the model and tokenizer after training
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# Optionally, you can also save the training arguments and optimizer states if needed
trainer.save_model(output_dir)  # This will save the model, tokenizer, and additional components