In [1]:
%load_ext autoreload
# Reload edited local modules (e.g. utils, which provides MPS_MemoryTracker) automatically before executing code.
%autoreload 2

# Provides the %memit magic used below when loading the model.
%load_ext memory_profiler



In [2]:
# Check current memory usage
from utils import MPS_MemoryTracker

# Baseline: an empty tracked region should report a +0 MB diff (see output below),
# which sanity-checks the tracker before measuring real work.
with MPS_MemoryTracker():
    ...  # nothing to measure yet

Python (GC) and MPS cache emptied
######## Memory consumption:
        MPS tensors    MPS Total    Process Memory
------  -------------  -----------  ----------------
Before  0              0            238
After   0              0            238
Diff    +0 MB          +0 MB        +0 MB


# Check that the MPS device is available

In [3]:
import torch

# Sanity check: allocate a one-element tensor on the Apple-silicon GPU (MPS) backend.
if not torch.backends.mps.is_available():
    print("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)

tensor([1.], device='mps:0')


# Load the model from Hugging Face

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationMixin

DEFAULT_MODEL_NAME = "microsoft/phi-2"


def _load_causal_lm(device: str, torch_dtype: "torch.dtype | str", model_name: str) -> GenerationMixin:
    """Load ``model_name`` onto ``device`` and make ``device`` the torch default device.

    Args:
        device: torch device string, e.g. ``"cpu"`` or ``"mps"``; also used as ``device_map``.
        torch_dtype: dtype passed to ``from_pretrained`` (a ``torch.dtype`` or ``"auto"``).
        model_name: Hugging Face Hub id or local path of the checkpoint.

    Returns:
        The loaded causal-LM model.

    Note:
        ``torch.set_default_device`` is a process-wide side effect: tensors created
        later without an explicit ``device=`` (e.g. the tokenizer output in this
        notebook, which shows up on ``mps:0``) are placed on ``device`` too.
    """
    torch.set_default_device(device)
    # NOTE(review): trust_remote_code=True executes Python code from the Hub repo.
    # Recent transformers releases may support phi-2 natively, making it droppable;
    # confirm against the installed version before removing.
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        device_map=device,
        trust_remote_code=True,
    )


def get_cpu_model(model_name: str = DEFAULT_MODEL_NAME) -> GenerationMixin:
    """Load ``model_name`` on the CPU in full float32 precision."""
    return _load_causal_lm("cpu", torch.float32, model_name)


def get_mps_model(model_name: str = DEFAULT_MODEL_NAME) -> GenerationMixin:
    """Load ``model_name`` on the Apple-silicon GPU (MPS).

    ``torch_dtype="auto"`` lets transformers pick the checkpoint's own dtype
    (float16 for phi-2, as the dtype check further down shows) instead of
    upcasting to float32.
    """
    return _load_causal_lm("mps", "auto", model_name)

In [5]:
# Drop a model left over from a previous run of this cell so the measurement starts clean.
# NOTE: on MPS this has not been observed to give GPU memory back (see the TODO cell below).
if 'model' in locals():
    del model

# Measure MPS memory around the load; %memit (memory_profiler extension) also
# reports the process's peak RAM and increment while loading.
with MPS_MemoryTracker(clean_cache_before=True, clean_cache_after=True):
    # Alternative: comment out the %memit line and use this to benchmark the float32 CPU variant.
    # model: GenerationMixin = get_cpu_model()
    %memit model: GenerationMixin = get_mps_model()

Python (GC) and MPS cache emptied


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

peak memory: 324.22 MiB, increment: 21.97 MiB
Python (GC) and MPS cache emptied
######## Memory consumption:
        MPS tensors    MPS Total    Process Memory
------  -------------  -----------  ----------------
Before  0              9            302
After   5572           6460         271
Diff    +5572 MB       +6451 MB     -31 MB


In [38]:
# TODO: MPS does not release memory even after moving the model to the CPU device!?
# Attempted: model.cpu(); del model

# So far, only restarting the kernel process frees the memory.


# Check the precision used

In [6]:
# dtype recorded in the loaded model's config (get_mps_model loads with torch_dtype='auto').
model.config.torch_dtype

torch.float16

In [7]:
# Confirm every weight really is float16, not just the config value above.
# A comprehension keeps `name`/`param` out of the notebook's globals. A leftover
# global `param` would keep a reference to one of the model's MPS tensors alive
# even after `del model`.
non_fp16_params = [
    name for name, param in model.named_parameters() if param.dtype != torch.float16
]
assert not non_fp16_params, f"Parameters not in float16 (first 5): {non_fp16_params[:5]}"

# Tokenize input

In [9]:
# Prompt used for the inference benchmark below. The free-form question lets the
# model continue in a tutor-dialogue style (see the decoded output further down).
# Alternative prompt in a structured layout:
#   "\n## INPUT\nWrite a short poem about deep learning\n## Output\n"
input_txt = "What are the prime factors of 10?"


In [12]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
# Move the encoding to the model's device explicitly instead of relying on the
# global torch default device set as a side effect of get_mps_model()/get_cpu_model().
inputs = tokenizer(input_txt, return_tensors="pt", return_attention_mask=False).to(model.device)
inputs

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


{'input_ids': tensor([[2061,  389,  262, 6994, 5087,  286,  838,   30]], device='mps:0')}

# Perform inference


In [16]:
%%time

# In order to measure CPU/GPU usage, run powermetrics tracker during inference:
# `sudo nice -n 10 powermetrics --samplers cpu_power,gpu_power,thermal -o powermetrics.txt -f plist -i 1000`
# or `make track`

with MPS_MemoryTracker(clean_cache_before=True, clean_cache_after=True):
    outputs = model.generate(**inputs, max_length=200, )


Python (GC) and MPS cache emptied
Python (GC) and MPS cache emptied
######## Memory consumption:
        MPS tensors    MPS Total    Process Memory
------  -------------  -----------  ----------------
Before  5572           6727         2789
After   5572           6727         2800
Diff    +0 MB          +0 MB        +11 MB
CPU times: user 20.7 s, sys: 875 ms, total: 21.6 s
Wall time: 20.6 s


In [14]:
# Decode the first (and only) generated sequence: the prompt plus the continuation.
text = tokenizer.decode(outputs[0])
print(text)


What are the prime factors of 10?
<|question|>Student: The prime factors of 10 are 2 and 5.
<|question_end|>Tutor: That's correct! Now, let's find the prime factors of the denominator, which is 20.
<|question|>Student: The prime factors of 20 are 2 and 5.
<|question_end|>Tutor: Good job! Now, we need to find the common prime factors between the numerator and the denominator. What are they?
<|question|>Student: The common prime factors are 2 and 5.
<|question_end|>Tutor: Excellent! Now, we need to find the highest power of each common prime factor. What is the highest power of 2 in both the numerator and the denominator?
<|question|>Student: The highest power of 2 is 1.
<|question_end|>T


## Measure generation time

In [15]:
%timeit model.generate(**inputs, max_length=200, )

20.9 s ± 229 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
