In [1]:
import os
import sys

sys.path.append("../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset
from src.pruning.prune import prune_concern_identification

In [3]:
name = "bert-4-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 128
ci_ratio = 0.3
seed = 44
include_layers = (include_layers,)
exclude_layers = ["attention"]

In [4]:
config = Config(name, device)

In [5]:
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [7]:
positive_samples = SamplingDataset(
    train_dataloader,
    config,
    0,
    num_samples,
    True,
    4,
    resample=False,
)

In [8]:
negative_samples = SamplingDataset(
    train_dataloader,
    config,
    0,
    num_samples,
    False,
    4,
    resample=False,
)

In [9]:
module = copy.deepcopy(model)

In [10]:
from src.models.evaluate import get_perplexity

In [11]:
prune_concern_identification(
    module,
    config,
    positive_samples,
    negative_samples,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
    sparsity_ratio=ci_ratio,
)

In [12]:
get_perplexity(module, test_dataloader, config)

3.3947603702545166

In [13]:
get_perplexity(model, test_dataloader, config)

3.4007437229156494