In [1]:
MAX_TOKENS = 200  # upper bound on the number of tokens generated per inference call (used by the inference/timing cells)

In [2]:
# Echo the tunable parameters so the saved output records what this run used.
print('Check params used:')
print(f'MAX_TOKENS = {MAX_TOKENS}')

Check params used:
MAX_TOKENS = 200


In [3]:
# Re-import edited modules (e.g. utils, phi2_mlx) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Provides the %memit / %%memit magics used below to report peak memory.
%load_ext memory_profiler

In [4]:
# Check output - should be ARM!
# If it is i386 (and you have an M-series machine) then you are using a non-native Python.
# Switch to a native Python build; a good way to do this is with Conda.
#
# Query the kernel's own interpreter: `!python -c ...` would run whatever `python` is
# first on PATH, which is not necessarily the interpreter backing this notebook.
import platform

print(platform.processor())


arm


In [5]:
from utils import MPS_MemoryTracker

# Baseline reading: entering and leaving the tracker with an empty body reports the
# memory already in use before any model work happens.
print("Check current memory usage")
with MPS_MemoryTracker():
    ...

Check current memory usage
Cache emptied: python (GC) and MPS 
######## Memory consumption ########:
        MPS tensors    MPS Total    Process Memory    Total System Memory
------  -------------  -----------  ----------------  ---------------------
Before  0              0            245               14292
After   0              0            245               14292
Diff    +0 MB          +0 MB        +0 MB             +0 MB


# Download weights and convert them into MLX format

In [6]:
%%memit
 
from phi2_mlx import convert
# download weights and convert them into MLX format

print("Downloading weights and convertin to MLX format:")

with MPS_MemoryTracker(clean_cache_before=True, clean_cache_after=True):
    convert()

# Unfortunately, this function leaks memory, which could be seen in the ouput of memory tracker 

# objgraph.show_growth()

Downloading weights and convertin to MLX format:
Cache emptied: python (GC) and MPS 


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Cache emptied: python (GC) and MPS 
######## Memory consumption ########:
        MPS tensors    MPS Total    Process Memory    Total System Memory
------  -------------  -----------  ----------------  ---------------------
Before  0              0            267               14324
After   0              0            2947              17789
Diff    +0 MB          +0 MB        +2680 MB          +3465 MB
peak memory: 9383.31 MiB, increment: 9138.03 MiB


# Load the model

In [7]:
from phi2_mlx import load_model, generate, get_tokenizer


In [8]:
%%memit

if 'model' in locals():
    print('Delete existing model obj')
    del model

print("Loading of MLX model:")
with MPS_MemoryTracker():
    model = load_model()

Loading of MLX model:
Cache emptied: python (GC) and MPS 
######## Memory consumption ########:
        MPS tensors    MPS Total    Process Memory    Total System Memory
------  -------------  -----------  ----------------  ---------------------
Before  0              0            2947              17815
After   0              5560         6822              19793
Diff    +0 MB          +5560 MB     +3875 MB          +1978 MB
peak memory: 7301.50 MiB, increment: 4353.78 MiB


# Tokenize input

In [9]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

# Prompt used by the inference and timing cells below; numpy ids, no attention mask.
input_txt = 'What are the prime factors of 10?'
inputs = tokenizer(input_txt, return_tensors="np", return_attention_mask=False)

inputs


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


{'input_ids': array([[2061,  389,  262, 6994, 5087,  286,  838,   30]])}

# Model inference

In [10]:
import mlx.core as mx

def generate_tokens(inputs, temp, max_tokens,
                    eos_token_id=tokenizer.eos_token_id
                    ):
    """Sample up to ``max_tokens`` tokens from the global ``model`` for a prompt.

    Args:
        inputs: tokenizer output; only ``inputs["input_ids"]`` is used.
        temp: sampling temperature forwarded to ``generate``.
        max_tokens: maximum number of tokens to return.
        eos_token_id: generation stops as soon as this id is produced, and that
            token is not returned. The default is the tokenizer's EOS id, captured
            when this cell runs. Passing an id that is never produced (the inference
            cells use -1) makes the loop run for the full ``max_tokens``.

    Returns:
        List of the tokens yielded by ``generate``, EOS excluded.
    """
    prompt = mx.array(inputs["input_ids"])

    tokens = []
    # `range` comes first on purpose: zip() pulls from its iterators left to right.
    # With the generator first, one extra token (a full forward pass) would be
    # computed and then discarded once the range is exhausted.
    for ind, token in zip(range(max_tokens), generate(prompt, model, temp)):
        # .item() materializes the lazy MLX value so it can be compared to the EOS id.
        if token.item() == eos_token_id:
            print(f'---DEBUG--- EOS generated at {ind} position')
            break
        tokens.append(token)
    return tokens

In [11]:
# %%memit is disabled for this cell; to enable it, set TOKENIZERS_PARALLELISM=false first.

# Tip: run `make track` while this cell executes to measure CPU/GPU usage.
print('Model inference:')
# eos_token_id=-1 matches no real token id, so generation runs the full MAX_TOKENS.
with MPS_MemoryTracker(clean_cache_before=True, clean_cache_after=True):
    tokens = generate_tokens(
        inputs,
        temp=0.2,
        max_tokens=MAX_TOKENS,
        eos_token_id=-1,
    )

len(tokens)

Model inference:
Cache emptied: python (GC) and MPS 
Cache emptied: python (GC) and MPS 
######## Memory consumption ########:
        MPS tensors    MPS Total    Process Memory    Total System Memory
------  -------------  -----------  ----------------  ---------------------
Before  0              5560         6846              19833
After   0              10052        4740              23719
Diff    +0 MB          +4492 MB     -2106 MB          +3886 MB


200

# Time measurement

In [12]:
print('Time measurement for model inference:')

%timeit generate_tokens(inputs, temp=0.2, max_tokens=MAX_TOKENS, eos_token_id=-1)

Time measurement for model inference:


# Full text-to-text example

In [None]:


def generate_txt(input_txt: str, tokenizer, temp=0.2, max_tokens=MAX_TOKENS):
    """Generate a text continuation of ``input_txt`` with Phi-2 on MLX.

    An info line is printed first, then the prompt (without a trailing newline).
    The decoded completion is returned, not printed.

    Args:
        input_txt: prompt text.
        tokenizer: Hugging Face tokenizer. It encodes the prompt and decodes the
            generated ids, and its ``eos_token_id`` stops generation early.
        temp: sampling temperature.
        max_tokens: upper bound on generated tokens. The default is read from
            ``MAX_TOKENS`` when this cell runs.

    Returns:
        The decoded completion (the prompt itself is not included).
    """
    encoded = tokenizer(input_txt, return_tensors="np", return_attention_mask=False)

    print("[INFO] Generating with Phi-2 on MLX...", flush=True)
    print(input_txt, end="", flush=True)

    generated = generate_tokens(
        encoded,
        temp=temp,
        max_tokens=max_tokens,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Force evaluation of the MLX token arrays before decoding. This is likely
    # redundant, since generate_tokens already calls .item() on each token.
    mx.eval(generated)

    token_ids = [tok.item() for tok in generated]
    return tokenizer.decode(token_ids)


input_txt = """
## INPUT
Write a short poem about deep learning
## Output
"""
text = generate_txt(input_txt, tokenizer, max_tokens=1000)
print(text)


[INFO] Generating with Phi-2 on MLX...



## INPUT
Write a short poem about deep learning
## Output
---DEBUG--- EOS generated at 58 position
Deep learning is a powerful tool
That can learn from data and rules
It can recognize patterns and features
And make predictions with accuracy
Deep learning is a fascinating field
That can solve many problems and challenges
It can inspire creativity and curiosity
And make the world a better place

