In [1]:
from transformers import BertModel, BertConfig
import torch

def get_model_size(model):
    """Return the storage footprint of ``model`` in units of 1024**2 bytes.

    The footprint is the byte size of every tensor yielded by
    ``model.parameters()`` plus every tensor yielded by ``model.buffers()``,
    where a tensor's byte size is ``nelement() * element_size()``.
    """
    def tensor_bytes(tensor):
        # Number of elements times bytes per element for one tensor.
        return tensor.nelement() * tensor.element_size()

    total_bytes = sum(map(tensor_bytes, model.parameters()))
    total_bytes += sum(map(tensor_bytes, model.buffers()))
    return total_bytes / 1024**2

# Load the pretrained checkpoint identified by `model_name` through
# `BertModel.from_pretrained`.
model_name = "google/bert_uncased_L-6_H-512_A-8"
model = BertModel.from_pretrained(model_name)

# Combined size of the model's parameters and buffers, in units of 1024**2
# bytes (the value returned by get_model_size).
model_size_mb = get_model_size(model)

print(f"Model: {model_name}")
print(f"Size: {model_size_mb:.2f} MB")

# Total number of parameter elements; buffers are not counted here, although
# they are included in the size printed above.
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params:,}")

# Report the architecture hyperparameters stored on the loaded model's config.
config = model.config
print("\nModel Configuration:")
print(f"Hidden Size: {config.hidden_size}")
print(f"Number of Hidden Layers: {config.num_hidden_layers}")
print(f"Number of Attention Heads: {config.num_attention_heads}")

  from .autonotebook import tqdm as notebook_tqdm


Model: google/bert_uncased_L-6_H-512_A-8
Size: 133.78 MB
Number of parameters: 35,068,416

Model Configuration:
Hidden Size: 512
Number of Hidden Layers: 6
Number of Attention Heads: 8


In [2]:
from transformers import DistilBertModel, DistilBertConfig
import torch

def get_model_size(model):
    """Return the storage footprint of ``model`` in units of 1024**2 bytes.

    The footprint is the byte size of every tensor yielded by
    ``model.parameters()`` plus every tensor yielded by ``model.buffers()``,
    where a tensor's byte size is ``nelement() * element_size()``.
    """
    def tensor_bytes(tensor):
        # Number of elements times bytes per element for one tensor.
        return tensor.nelement() * tensor.element_size()

    total_bytes = sum(map(tensor_bytes, model.parameters()))
    total_bytes += sum(map(tensor_bytes, model.buffers()))
    return total_bytes / 1024**2

# Load the pretrained checkpoint identified by `model_name` through
# `DistilBertModel.from_pretrained`.
model_name = "distilbert-base-uncased"
model = DistilBertModel.from_pretrained(model_name)

# Combined size of the model's parameters and buffers, in units of 1024**2
# bytes (the value returned by get_model_size).
model_size_mb = get_model_size(model)

print(f"Model: {model_name}")
print(f"Size: {model_size_mb:.2f} MB")

# Total number of parameter elements; buffers are not counted here, although
# they are included in the size printed above.
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params:,}")

# Report the architecture hyperparameters stored on the loaded model's config.
config = model.config
print("\nModel Configuration:")
print(f"Hidden Size: {config.hidden_size}")
print(f"Number of Hidden Layers: {config.num_hidden_layers}")
print(f"Number of Attention Heads: {config.num_attention_heads}")

Model: distilbert-base-uncased
Size: 253.16 MB
Number of parameters: 66,362,880

Model Configuration:
Hidden Size: 768
Number of Hidden Layers: 6
Number of Attention Heads: 12


In [5]:
import torch
from typing import Iterable, List
# Indices of the special symbols. They must stay in the same order as the
# `special_symbols` list in `Lang.build_vocab`, which places those symbols at
# the front of the vocabulary. The <unk> position markers are not listed here:
# they occupy the last indices, starting at len(vocab).
PAD_IDX, BOS_IDX, EOS_IDX, URL_IDX, EMAIL_IDX, PHONE_IDX, TGT_UNK_IDX = 0, 1, 2, 3, 4, 5, 6
# TGT_UNK_IDX is used for tokens that appear in neither the vocab nor the source tokens.

class Lang:
    """Token <-> index vocabulary with positional out-of-vocabulary markers.

    Index layout once ``build_vocab`` has run:

    * ``0 .. len(special_symbols) - 1`` -- the special symbols, in the same
      order as the module-level ``*_IDX`` constants;
    * the (at most ``vocab_size``) most frequent corpus tokens;
    * ``num_position_markers`` trailing slots starting at ``len(self.vocab)``.
      With a single marker the slot acts as a plain ``<unk>``; with several,
      slot ``k`` stands for "the out-of-vocabulary token first seen at
      position ``k`` of the source tokens".
    """

    def __init__(self, num_position_markers = 1):
        """Store how many trailing OOV marker slots the vocabulary reserves (>= 1)."""
        assert num_position_markers >= 1
        self.num_position_markers = num_position_markers


    def build_vocab(self, 
                    data_iter: Iterable,
                    vocab_size: int):
        """Count tokens in ``data_iter`` and keep the ``vocab_size`` most frequent.

        Sets ``self.vocab`` (special symbols followed by the retained tokens),
        ``self.word2index`` (token -> index into ``self.vocab``) and
        ``self.special_symbols``.
        """
        # Local stdlib import: the defining cell does not import Counter itself.
        from collections import Counter

        # NOTE(review): `tqdm` and `yield_tokens` are not defined in this cell;
        # they must already exist in the notebook namespace when this is called.
        token_counter = Counter()
        for tokens in tqdm(yield_tokens(data_iter)):
            token_counter.update(tokens)
        # Make sure the tokens are in order of their indices to properly insert them in vocab
        special_symbols = ['<pad>', '<bos>', '<eos>', '<url>', '<email>', '<phone>', "<tgt_unk>"]
        # Stable sort by descending count, so ties keep first-seen order.
        by_frequency = sorted(token_counter, key = lambda token: -token_counter[token])
        self.vocab = special_symbols + by_frequency[:vocab_size]

        self.word2index = {token: index for index, token in enumerate(self.vocab)}
        self.special_symbols = special_symbols

    
    def __len__(self):
        """Total number of indices: vocabulary entries plus OOV marker slots."""
        return len(self.vocab) + self.num_position_markers


    def _lookup_index(self, token: str, position = 0):
        """Map one token to its index.

        In-vocabulary tokens return their ``word2index`` entry. Otherwise the
        result is marker slot ``len(self.vocab) + position``, or
        ``TGT_UNK_IDX`` when ``position`` is ``None``.
        """
        # position should start from 0
        assert position is None or position < self.num_position_markers
        if token in self.word2index:
            return self.word2index[token]
        if position is not None:
            return len(self.vocab) + position
        return TGT_UNK_IDX


    def lookup_indices(self, tokens: List[str], src_tokens: List[str] = None) -> List[int]:
        """Convert ``tokens`` to indices.

        With a single position marker every OOV token maps to the same index
        (``len(self.vocab)``) and ``src_tokens`` is ignored. With several
        markers, a token's position is its first occurrence in ``src_tokens``
        (``None`` when absent or when ``src_tokens`` is ``None``) and is
        passed on to ``_lookup_index``.
        """
        assert hasattr(self, "vocab"), "Vocab has not been built"
        if self.num_position_markers == 1:
            # disregard the position of oov token, map to the same index (index of <unk>)
            return [self._lookup_index(token) for token in tokens]

        # regard the position of oov token
        # First-occurrence position of every source token, computed once
        # instead of rescanning src_tokens for each distinct token.
        first_position = {}
        if src_tokens is not None:
            for position, src_token in enumerate(src_tokens):
                first_position.setdefault(src_token, position)

        indices = []
        cache = {}
        for token in tokens:
            if token in self.special_symbols:
                indices.append(self.special_symbols.index(token))
                continue
            if token not in cache:
                cache[token] = self._lookup_index(token, position = first_position.get(token))
            indices.append(cache[token])
        return indices

    def lookup_token(self, index: int, src_tokens: List[str] = None):
        """Convert one index back to a token.

        Vocabulary indices return the stored token. Beyond the vocabulary:
        ``"<unk>"`` in single-marker mode; otherwise the source token at the
        marker's position, or the placeholder ``"<unk-k>"`` when
        ``src_tokens`` is ``None``.
        """
        if index < len(self.vocab):
            return self.vocab[index]
        if self.num_position_markers == 1:
            # disregard position of oov token
            return "<unk>"
        marker = index - len(self.vocab)
        assert marker < self.num_position_markers
        if src_tokens is None:
            return f"<unk-{marker}>"
        return src_tokens[marker]
    
    
    def lookup_tokens(self, indices: List[int], src_tokens: List[str] = None) -> List[str]:
        """Convert a list of indices back to tokens via ``lookup_token``."""
        assert hasattr(self, "vocab"), "Vocab has not been built"
        return [self.lookup_token(index, src_tokens) for index in indices]

# Location of the supervised checkpoint to inspect.
# NOTE(review): absolute, machine-specific path; this cell only runs where the file exists.
CKPT_PATH = f"/scratch/lamdo/unsupervised_keyphrase_prediction_2022/data/supervised_checkpoints/final/1/supervised.pth"

# Deserialize the checkpoint with tensor storages mapped onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# Presumably the `Lang` class must be defined in the namespace for unpickling — TODO confirm.
CKPT = torch.load(CKPT_PATH, map_location = torch.device("cpu"))


# Entry stored in the checkpoint under the "config" key.
CONFIG = CKPT["config"]


# Sum the element counts of every value stored under the "transformer" key
# (assumes a state-dict-like mapping of tensors — TODO confirm).
num_params = sum(p.numel() for p in CKPT["transformer"].values())

# Bare expression so the notebook displays the count as the cell's output.
num_params

  CKPT = torch.load(CKPT_PATH, map_location = torch.device("cpu"))


36756296