# Pytorch_memlab
**A library for memory profiling. Uses torch.cuda.memory_stats() inside.**

In [1]:
!pip3 install torch==1.6.0 transformers==3.2.0 pytorch-memlab==0.2.4

You should consider upgrading via the '/home/lexi/.venv/bin/python3 -m pip install --upgrade pip' command.[0m


In [2]:
import torch
from pytorch_memlab import LineProfiler, MemReporter, profile
from transformers import BertForTokenClassification, BertTokenizerFast

In [3]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
model = BertForTokenClassification.from_pretrained(
                'bert-base-cased',
                num_labels=10
).cuda()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

In [5]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

# Memory Reporter

We can inspect the memory used by the model tensors.

In [6]:
reporter = MemReporter(model)
reporter.report(device=device)

Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
Tensor207                                           (1, 512)     4.00K
bert.embeddings.word_embeddings.weight          (28996, 768)    84.95M
bert.embeddings.position_embeddings.weight          (512, 768)     1.50M
bert.embeddings.token_type_embeddings.weight            (2, 768)     6.00K
bert.embeddings.LayerNorm.weight                      (768,)     3.00K
bert.embeddings.LayerNorm.bias                        (768,)     3.00K
bert.encoder.layer.0.attention.self.query.weight          (768, 768)     2.25M
bert.encoder.layer.0.attention.self.query.bias              (768,)     3.00K
bert.encoder.layer.0.attention.self.key.weight          (768, 768)     2.25M
bert.encoder.layer.0.attention.self.key.bias              (768,)     3.00K
bert.encoder.layer.0.attention.self.value.weight          (768, 768)     2.25M
bert.encoder



In [7]:
data = tokenizer(['This is a sentence'], return_tensors='pt').to(device)
labels = torch.Tensor([1] * len(data.input_ids[0])).to(dtype=torch.long).cuda()

In [8]:
loss, logits = model(data.input_ids, token_type_ids=None, attention_mask=data.attention_mask, labels=labels)
loss.backward()

We calculated the gradients and they are now shown in memory inspection.

In [9]:
reporter = MemReporter(model)
reporter.report(device=device)

Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
bert.embeddings.word_embeddings.weight          (28996, 768)    84.95M
bert.embeddings.word_embeddings.weight.grad        (28996, 768)    84.95M
bert.embeddings.position_embeddings.weight          (512, 768)     1.50M
bert.embeddings.position_embeddings.weight.grad          (512, 768)     1.50M
bert.embeddings.token_type_embeddings.weight            (2, 768)     6.00K
bert.embeddings.token_type_embeddings.weight.grad            (2, 768)     6.00K
bert.embeddings.LayerNorm.weight                      (768,)     3.00K
bert.embeddings.LayerNorm.weight.grad                 (768,)     3.00K
bert.embeddings.LayerNorm.bias                        (768,)     3.00K
bert.embeddings.LayerNorm.bias.grad                   (768,)     3.00K
bert.pooler.dense.weight                          (768, 768)     2.25M
bert.pooler.dense.bias   

## Shared parameters

We can also see that some variables are shared: reused memory is shown by '->'

In [10]:
%reset -f

In [11]:
import torch
from pytorch_memlab import LineProfiler, MemReporter
device = torch.device('cuda:0')

In [12]:
# use verbose=True to see reused memory
lstm = torch.nn.LSTM(1024, 1024).cuda()
reporter = MemReporter(lstm)
reporter.report(device=device, verbose=True)

Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
weight_ih_l0                                    (4096, 1024)    32.03M
weight_hh_l0(->weight_ih_l0)                    (4096, 1024)     0.00B
bias_ih_l0(->weight_ih_l0)                           (4096,)     0.00B
bias_hh_l0(->weight_ih_l0)                           (4096,)     0.00B
-------------------------------------------------------------------------------
Total Tensors: 8396800 	Used Memory: 32.03M
The allocated memory on cuda:0: 32.03M
-------------------------------------------------------------------------------


## Leaking memory

Sometimes used memory and allocated memory are not equal. This is due to memory leaks, the fact of which you can see but unfortunately not inspect. In the example below *input_tensor + 2* is a temporary operation result which is stored but not shown in memory inspection.

(Actually, if you try to run this notebook on torch==1.10.2 + transformers==4.17.0 + pytorch-memlab==0.2.4, memory leakage is gone - you'll see Tensor2 to account for a temporary result.)

In [13]:
%reset -f

In [14]:
import torch
from pytorch_memlab import LineProfiler, MemReporter

linear = torch.nn.Linear(1024, 1024).cuda()
input_tensor = torch.Tensor(512, 1024).cuda()
reporter = MemReporter(linear)
reporter.report()

out = linear(input_tensor * (input_tensor + 2)).mean()
reporter.report()

Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
weight                                          (1024, 1024)     4.00M
bias                                                 (1024,)     4.00K
Tensor0                                          (512, 1024)     2.00M
-------------------------------------------------------------------------------
Total Tensors: 1573888 	Used Memory: 6.00M
The allocated memory on cuda:0: 6.00M
-------------------------------------------------------------------------------
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
weight                                          (1024, 1024)     4.00M
bias                                                 (1024,)     4.00K
Tensor0                                          (512, 1024)     2

# Line Profiler

Line profiler can show memory usage line by line.

In [15]:
%reset -f

In [16]:
import torch
from pytorch_memlab import LineProfiler, MemReporter, profile
from transformers import BertForTokenClassification, BertTokenizerFast, BertModel

### A simple case

In [17]:
def inner():
    torch.nn.Linear(100, 100).cuda()

def outer():
    linear = torch.nn.Linear(100, 100).cuda()
    linear2 = torch.nn.Linear(100, 100).cuda()
    inner()

with LineProfiler(outer, inner) as prof:
    outer()
prof.display()

active_bytes,reserved_bytes,line,code
all,all,Unnamed: 2_level_1,Unnamed: 3_level_1
peak,peak,Unnamed: 2_level_2,Unnamed: 3_level_2
0.00B,0.00B,4,def outer():
40.00K,2.00M,5,"linear = torch.nn.Linear(100, 100).cuda()"
80.00K,2.00M,6,"linear2 = torch.nn.Linear(100, 100).cuda()"
120.00K,2.00M,7,inner()

active_bytes,reserved_bytes,line,code
all,all,Unnamed: 2_level_1,Unnamed: 3_level_1
peak,peak,Unnamed: 2_level_2,Unnamed: 3_level_2
80.00K,2.00M,1,def inner():
120.00K,2.00M,2,"torch.nn.Linear(100, 100).cuda()"


### Trying to profile BERT

In [18]:
def initialize_model():
    model = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=10).cuda()
    return model

In [19]:
def get_data():
    device = torch.device('cuda:0')
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
    data = tokenizer(['This is a sentence'], return_tensors='pt').to(device)
    labels = torch.Tensor([1] * len(data.input_ids[0])).to(dtype=torch.long).cuda()
    return data, labels

In [20]:
def run_model():
    model = initialize_model()
    data, labels = get_data()
    loss, logits = model(data.input_ids, token_type_ids=None, attention_mask=data.attention_mask, labels=labels)
    return loss

In [21]:
with LineProfiler(run_model, initialize_model, get_data) as prof:
    run_model()
prof.display()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

active_bytes,reserved_bytes,line,code
all,all,Unnamed: 2_level_1,Unnamed: 3_level_1
peak,peak,Unnamed: 2_level_2,Unnamed: 3_level_2
0.00B,0.00B,1,def run_model():
413.70M,468.00M,2,model = initialize_model()
413.71M,468.00M,3,"data, labels = get_data()"
417.21M,472.00M,4,"loss, logits = model(data.input_ids, token_type_ids=None, attention_mask=data.attention_mask, labels=labels)"
417.19M,472.00M,5,return loss

active_bytes,reserved_bytes,line,code
all,all,Unnamed: 2_level_1,Unnamed: 3_level_1
peak,peak,Unnamed: 2_level_2,Unnamed: 3_level_2
0.00B,0.00B,1,def initialize_model():
413.70M,468.00M,2,"model = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=10).cuda()"
413.70M,468.00M,3,return model

active_bytes,reserved_bytes,line,code
all,all,Unnamed: 2_level_1,Unnamed: 3_level_1
peak,peak,Unnamed: 2_level_2,Unnamed: 3_level_2
413.70M,468.00M,1,def get_data():
413.70M,468.00M,2,device = torch.device('cuda:0')
413.70M,468.00M,3,tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
413.71M,468.00M,4,"data = tokenizer(['This is a sentence'], return_tensors='pt').to(device)"
413.71M,468.00M,5,labels = torch.Tensor([1] * len(data.input_ids[0])).to(dtype=torch.long).cuda()
413.71M,468.00M,6,"return data, labels"


Not very much useful data. Let's try to look at the BERT forward function...

### BERT forward function

In [22]:
%reset -f

In [23]:
import torch
from pytorch_memlab import LineProfiler, MemReporter, profile
from transformers import BertForTokenClassification, BertTokenizerFast

class ProfiledBertForTokenClassification(BertForTokenClassification):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, *args, **kwargs):
        with LineProfiler(super().forward) as prof:
            result = super().forward(*args, **kwargs)
        # jupyter display stops working here, so I had to print stats
        print(prof.display())
        return result

model = ProfiledBertForTokenClassification.from_pretrained('bert-base-cased', num_labels=10).cuda()
device = torch.device('cuda:0')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
data = tokenizer(['This is a sentence'], return_tensors='pt').to(device)
labels = torch.Tensor([1] * len(data.input_ids[0])).to(dtype=torch.long).cuda()
loss, logits = model(data.input_ids, token_type_ids=None, attention_mask=data.attention_mask, labels=labels)

Some weights of the model checkpoint at bert-base-cased were not used when initializing ProfiledBertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing ProfiledBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing ProfiledBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ProfiledBertForTokenClassification were not initialized from the m

## BertForTokenClassification.forward

active_bytes reserved_bytes  line code                                                                                                  
         all            all                                                                                                             
        peak           peak                                                                                                             
     413.71M        468.00M  1463      @add_start_docstrings_to_callable(BERT_INP...                                                    
                             1464      @add_code_sample_docstrings(              ...                                                    
                             1465          tokenizer_class=_TOKENIZER_FOR_DOC,   ...                                                    
                             1466          checkpoint="bert-base-uncased",       ...                                                    
  

Unfortunately, to get deeper - in self.bert layers, for example - we'll need to wrap self.bert call (line 1490) in LineProfiler.