In [1]:
%load_ext autoreload
%autoreload 2

%load_ext memory_profiler


In [2]:
USE_MPS = True  # if False - use CPU
MAX_TOKENS = 200  # restriction for inference


In [3]:
print('Check params:')
print(f'{MAX_TOKENS=}')
print(f'{USE_MPS=}')

Check params:
MAX_TOKENS=200
USE_MPS=True


In [4]:
# Check current memory usage
from utils import MPS_MemoryTracker

with MPS_MemoryTracker():
    print('Check current memory:')

Cache emptied: python (GC) and MPS 
Check current memory:
######## Memory consumption ########:
        MPS tensors    MPS Total    Process Memory    Total System Memory
------  -------------  -----------  ----------------  ---------------------
Before  0              0            244               16549
After   0              0            244               16549
Diff    +0 MB          +0 MB        +0 MB             +0 MB


# Check MPS device available

In [5]:
import torch
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

tensor([1.], device='mps:0')


# Load model from HF

In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationMixin

def get_cpu_model() -> GenerationMixin: 
    torch.set_default_device("cpu")
    return AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2", 
        torch_dtype=torch.float32,  # You can't use half-precision for CPU
        device_map="cpu", trust_remote_code=True
    )
def get_mps_model() -> GenerationMixin: 
    torch.set_default_device("mps")
    return AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2", torch_dtype='auto',#torch.float32, 
        device_map="mps",
        trust_remote_code=True
    )

In [7]:

if 'model' in locals():
    # if you need to clean existing model from memory. But, doesn't work for MPS:(
    del model  # noqa: F821

print('Loading of PyTorch model:')
model = None# to reserve variable outside context manager
with MPS_MemoryTracker(clean_cache_before=True, clean_cache_after=True):
    %memit model: GenerationMixin = get_mps_model() if USE_MPS else get_cpu_model()

Loading of PyTorch model:
Cache emptied: python (GC) and MPS 


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

peak memory: 327.94 MiB, increment: 22.08 MiB
Cache emptied: python (GC) and MPS 
######## Memory consumption ########:
        MPS tensors    MPS Total    Process Memory    Total System Memory
------  -------------  -----------  ----------------  ---------------------
Before  0              11           305               16583
After   5572           6460         267               20252
Diff    +5572 MB       +6449 MB     -38 MB            +3669 MB


In [8]:
# TODO: MPS doesn't release memory even after transferring model to cpu devide!?
# model.cpu(); del model

# So, only restarting process helps


# Check precision used

In [9]:
print('dtype:')
model.config.torch_dtype

dtype:


torch.float16

In [10]:
for name, param in model.named_parameters():
    assert param.dtype == (torch.float16 if USE_MPS else torch.float32)

# Tokenize input

In [11]:
# input_txt = """
# ## INPUT
# Write a short poem about deep learning
# ## Output
# """
input_txt = 'What are the prime factors of 10?'


In [12]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
inputs = tokenizer(input_txt, return_tensors="pt", return_attention_mask=False)
inputs

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


{'input_ids': tensor([[2061,  389,  262, 6994, 5087,  286,  838,   30]], device='mps:0')}

In [13]:
inputs['input_ids'].device.type == ('mps' if USE_MPS else 'cpu')

True

# Perform inference


In [14]:
# %%time

# In order to measure CPU/GPU usage, run powermetrics tracker during inference:
# `sudo nice -n 10 powermetrics --samplers cpu_power,gpu_power,thermal -o powermetrics.txt -f plist -i 1000`
# or `make track`

print('Model inference:')

outputs = None  # to reserve variable outside context manager
with MPS_MemoryTracker(clean_cache_before=True, clean_cache_after=True):
    outputs = model.generate(**inputs, max_length=MAX_TOKENS, )
    
print("Check outputs size:", outputs.shape)
# assert outputs[1, 200] == MAX_TOKENS



Model inference:
Cache emptied: python (GC) and MPS 
Cache emptied: python (GC) and MPS 
######## Memory consumption ########:
        MPS tensors    MPS Total    Process Memory    Total System Memory
------  -------------  -----------  ----------------  ---------------------
Before  5572           6460         297               20333
After   5572           6727         2705              22576
Diff    +0 MB          +267 MB      +2408 MB          +2243 MB
Check outputs size: torch.Size([1, 200])


In [15]:
text = tokenizer.batch_decode(outputs)[0]
print(text)


What are the prime factors of 10?
<|question|>Student: The prime factors of 10 are 2 and 5.
<|question_end|>Tutor: That's correct! Now, let's find the prime factors of the denominator, which is 20.
<|question|>Student: The prime factors of 20 are 2 and 5.
<|question_end|>Tutor: Good job! Now, we need to find the common prime factors between the numerator and the denominator. What are they?
<|question|>Student: The common prime factors are 2 and 5.
<|question_end|>Tutor: Excellent! Now, we need to find the highest power of each common prime factor. What is the highest power of 2 in both the numerator and the denominator?
<|question|>Student: The highest power of 2 is 1.
<|question_end|>T


## Measure time

In [16]:
print('Measure time of inference:')
%timeit model.generate(**inputs, max_length=MAX_TOKENS, )

Measure time of inference:
21.6 s ± 521 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
