In [78]:
#Enable multiple outputs per cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

#Autoreload imported modules
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Test HF Hub Search functions

In [60]:

# Hugging Face Hub repo ids used as fixtures by the test cells in this notebook.
MISTRAL_REPO = "mistralai/Mistral-7B-v0.1"  # search tests expect a PyTorch index only
CANDLE_MISTRAL_REPO = "lmz/candle-mistral"  # search tests expect 2 safetensors files and no index
ZEPHYR_REPO = "HuggingFaceH4/zephyr-7b-alpha"  # search tests expect both indexes and 8 safetensors files
DINOV2_REPO = "facebook/dinov2-base"  # used by the PyTorch -> safetensors conversion test


In [3]:
from functools import partial

from model_checking import find_pytorch_index, find_safetensor_files, find_safetensors_index

def test_search_func(*, func, repo, expected):
    found_files = func(repo)
    assert len(found_files) == expected
    return found_files

# Bind each search helper to test_search_func; callers then only pass
# `repo=` and `expected=` (both keyword-only).
test_pytorch_index = partial(test_search_func, func=find_pytorch_index)
test_safetensors_index = partial(test_search_func, func=find_safetensors_index)
test_safetensor_files = partial(test_search_func, func=find_safetensor_files)



  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Mistral-7B-v0.1: expect one PyTorch index, no safetensors index and no safetensors files.
test_pytorch_index(repo=MISTRAL_REPO, expected=1)
test_safetensors_index(repo=MISTRAL_REPO, expected=0)
test_safetensor_files(repo=MISTRAL_REPO, expected=0)


['mistralai/Mistral-7B-v0.1/pytorch_model.bin.index.json']

[]

[]

In [5]:
# candle-mistral: expect no index of either kind, but two safetensors files.
test_pytorch_index(repo=CANDLE_MISTRAL_REPO, expected=0)
test_safetensors_index(repo=CANDLE_MISTRAL_REPO, expected=0)
test_safetensor_files(repo=CANDLE_MISTRAL_REPO, expected=2)


[]

[]

['lmz/candle-mistral/pytorch_model-00001-of-00002.safetensors',
 'lmz/candle-mistral/pytorch_model-00002-of-00002.safetensors']

In [6]:
# zephyr-7b-alpha: expect both a PyTorch and a safetensors index, plus eight safetensors files.
test_pytorch_index(repo=ZEPHYR_REPO, expected=1)
test_safetensors_index(repo=ZEPHYR_REPO, expected=1)
test_safetensor_files(repo=ZEPHYR_REPO, expected=8)


['HuggingFaceH4/zephyr-7b-alpha/pytorch_model.bin.index.json']

['HuggingFaceH4/zephyr-7b-alpha/model.safetensors.index.json']

['HuggingFaceH4/zephyr-7b-alpha/model-00001-of-00008.safetensors',
 'HuggingFaceH4/zephyr-7b-alpha/model-00002-of-00008.safetensors',
 'HuggingFaceH4/zephyr-7b-alpha/model-00003-of-00008.safetensors',
 'HuggingFaceH4/zephyr-7b-alpha/model-00004-of-00008.safetensors',
 'HuggingFaceH4/zephyr-7b-alpha/model-00005-of-00008.safetensors',
 'HuggingFaceH4/zephyr-7b-alpha/model-00006-of-00008.safetensors',
 'HuggingFaceH4/zephyr-7b-alpha/model-00007-of-00008.safetensors',
 'HuggingFaceH4/zephyr-7b-alpha/model-00008-of-00008.safetensors']

### Test Safe Tensor Index Download

In [7]:
from model_checking import get_safetensor_index

def test_safetensor_index_download(repo):
    """Fetch the safetensors index for ``repo`` and check its top-level keys.

    Returns the object produced by ``get_safetensor_index`` once it has been
    verified to contain both ``'metadata'`` and ``'weight_map'``.
    """
    index = get_safetensor_index(repo)
    for required_key in ('metadata', 'weight_map'):
        assert required_key in index
    return index

# Run the check against the Zephyr repo and keep the resulting index.
zephyr_index = test_safetensor_index_download(repo=ZEPHYR_REPO)

### Test Safe Tensor Headers
- Parses the headers of the safetensors files hosted in a HF Hub repo
- Returns the aggregated result as a dict that maps each tensor name to its shape and dtype

In [8]:
from model_checking import get_safetensor_headers

def test_safetensor_header_parsing(repo, expected_num_tensors: int):
    """Check that the merged safetensors headers of ``repo`` hold the expected number of entries.

    Args:
        repo: Repo id handed to ``get_safetensor_headers`` (called with ``merge=True``).
        expected_num_tensors: Required number of entries once ``'__metadata__'`` is excluded.
    """
    tensor_headers = get_safetensor_headers(repo, merge=True)
    # Exclude the '__metadata__' entry, when there is one, from the count.
    tensor_headers.pop('__metadata__', None)
    assert len(tensor_headers) == expected_num_tensors

In [10]:
# Both repos are expected to expose the same number of tensor entries in their headers.
EXPECTED_MISTRAL_TENSORS = 291
test_safetensor_header_parsing(repo=CANDLE_MISTRAL_REPO, expected_num_tensors=EXPECTED_MISTRAL_TENSORS)
test_safetensor_header_parsing(repo=ZEPHYR_REPO, expected_num_tensors=EXPECTED_MISTRAL_TENSORS)

### Test the PyTorch to safetensors conversion

In [14]:
from model_checking import convert_pytorch_to_safetensors

def test_pt_conversion(repo, outdir):
    """Convert a repo's PyTorch checkpoints to safetensors and compare tensor names.

    Any existing ``outdir`` is deleted first, then ``convert_pytorch_to_safetensors``
    is run with ``force=True`` and ``upload_to_hub=False``. Every file found in
    ``outdir`` is paired, in sorted order, with a ``pytorch_model*.bin`` file from
    the local Hugging Face hub cache, and each pair must expose the same tensor names.

    Args:
        repo: Hub repo id, e.g. ``"org/name"``.
        outdir: Directory the converted files are written to. It is removed
            first if it already exists.

    Raises:
        AssertionError: If ``outdir`` is not created, if the number of converted
            files differs from the number of cached PyTorch files, or if a pair
            of files disagrees on its tensor names.
    """
    import os
    from pathlib import Path
    import shutil
    from safetensors.torch import load_file
    import torch

    if os.path.exists(outdir):
        shutil.rmtree(outdir)
    _ = convert_pytorch_to_safetensors(repo, outdir=outdir, force=True, upload_to_hub=False)

    assert os.path.exists(outdir), f"conversion did not create {outdir!r}"

    safetensor_files = sorted(os.listdir(outdir))

    # Locate the hub cache the same way huggingface_hub documents it, instead of
    # assuming the default location: HF_HUB_CACHE (or the legacy
    # HUGGINGFACE_HUB_CACHE) wins, otherwise "<HF_HOME>/hub", where HF_HOME
    # defaults to "$XDG_CACHE_HOME/huggingface" or "~/.cache/huggingface".
    hub_cache = (
        os.environ.get("HF_HUB_CACHE")
        or os.environ.get("HUGGINGFACE_HUB_CACHE")
        or os.path.join(
            os.environ.get("HF_HOME")
            or os.path.join(os.environ.get("XDG_CACHE_HOME") or "~/.cache", "huggingface"),
            "hub",
        )
    )
    cached_model_path = Path(hub_cache).expanduser()
    # Cached model repos live in "models--<org>--<name>". Matching the exact
    # folder avoids picking up other cached repos whose name merely contains
    # this model's name.
    repo_cache_dir = cached_model_path / f"models--{repo.replace('/', '--')}"
    pt_files = sorted(f.as_posix() for f in repo_cache_dir.glob("**/pytorch_model*.bin"))
    assert len(safetensor_files) == len(pt_files), (
        f"{len(safetensor_files)} converted file(s) in {outdir!r} but "
        f"{len(pt_files)} PyTorch file(s) under {repo_cache_dir.as_posix()!r}"
    )

    for pt, st in zip(pt_files, safetensor_files):
        # NOTE(security): torch.load unpickles the checkpoint, so only run this
        # against repos you trust. map_location="cpu" lets checkpoints that were
        # saved from a GPU load on machines without one.
        pt_weight_map = torch.load(pt, map_location="cpu")
        safetensor_weight_map = load_file(os.path.join(outdir, st))
        assert len(pt_weight_map) == len(safetensor_weight_map), (
            f"{pt} holds {len(pt_weight_map)} tensors, {st} holds {len(safetensor_weight_map)}"
        )
        assert pt_weight_map.keys() == safetensor_weight_map.keys(), (
            f"tensor names differ between {pt} and {st}"
        )
            

In [15]:
# Convert the Zephyr checkpoints into ./zephyr (an existing folder is wiped first).
test_pt_conversion(ZEPHYR_REPO, outdir="zephyr")

Converting PyTorch files to SafenTensor format: 100%|██████████| 8/8 [00:27<00:00,  3.38s/it]


Saved safetensors to zephyr


In [16]:
# Convert the DINOv2 checkpoint into ./dinov2 (an existing folder is wiped first).
test_pt_conversion(DINOV2_REPO, outdir="dinov2")

Converting PyTorch files to SafenTensor format: 100%|██████████| 1/1 [00:00<00:00,  2.13it/s]

Saved safetensors to dinov2





In [13]:
from model_checking import _upload_to_hub
from huggingface_hub import HfApi

def test_model_upload(repo_id, upload_folder):
    """Upload ``upload_folder`` to ``repo_id`` and check that the repo exists afterwards.

    Args:
        repo_id: Target Hub repo id, forwarded to ``_upload_to_hub``.
        upload_folder: Local folder forwarded to ``_upload_to_hub``.

    Raises:
        AssertionError: If the Hub does not report ``repo_id`` as existing
            once the upload call has returned.
    """
    _upload_to_hub(upload_folder, repo_id)
    # Without a check this test could never fail on a missing repo.
    assert HfApi().repo_exists(repo_id), f"{repo_id!r} does not exist on the Hub after upload"

# Inputs for the upload test.
repo_id = "jeromeku/test-upload"  # target repo on the Hub
upload_folder = 'dinov2'  # same folder name the DINOv2 conversion test writes to
api = HfApi()  # client used to check that the repo exists

In [14]:
# Upload the folder, then confirm the target repo exists on the Hub.
_upload_to_hub(upload_folder, repo_id)
assert api.repo_exists(repo_id)

### Test model trace checking

In [None]:

from model_checking import parse_trace, get_safetensor_headers

# Parse the Mistral trace fixture and fetch the safetensors headers of the
# candle-mistral repo.
mistral_trace_path = '../fixtures/mistral.trace.json'
mistral_traces = parse_trace(mistral_trace_path)
mistral_tensor_map = get_safetensor_headers(CANDLE_MISTRAL_REPO)

# Both mappings must expose exactly the same set of keys.
assert set(mistral_tensor_map.keys()) == set(mistral_traces.keys())


In [117]:
# Report every tensor whose header disagrees with its trace entry.
# Entries are paired by tensor name: zipping the two `.values()` views would pair
# them by insertion order, which need not be the same in both mappings.
for name, m in mistral_tensor_map.items():
    mtrace = mistral_traces.get(name)
    if mtrace is None:
        print(f"{name}: no matching trace entry")
        continue
    # One line per tensor, even when both shape and dtype differ.
    if m['shape'] != mtrace['shape'] or m['dtype'] != mtrace['dtype']:
        print(name, m, mtrace)
# TODO: turn the report into hard checks once the dtype mismatch is resolved
# (the recorded output shows 'BF16' in the headers against 'f32' in the trace).
#assert all(m['shape'] == mistral_traces[name]['shape'] for name, m in mistral_tensor_map.items())
#assert all(m['dtype'] == mistral_traces[name]['dtype'] for name, m in mistral_tensor_map.items())


{'dtype': 'BF16', 'shape': [32000, 4096], 'data_offsets': [0, 262144000]} {'shape': [32000, 4096], 'dtype': 'f32'}
{'dtype': 'BF16', 'shape': [32000, 4096], 'data_offsets': [0, 262144000]} {'shape': [32000, 4096], 'dtype': 'f32'}
{'dtype': 'BF16', 'shape': [4096], 'data_offsets': [262144000, 262152192]} {'shape': [4096], 'dtype': 'f32'}
{'dtype': 'BF16', 'shape': [4096, 14336], 'data_offsets': [262152192, 379592704]} {'shape': [4096, 14336], 'dtype': 'f32'}
{'dtype': 'BF16', 'shape': [14336, 4096], 'data_offsets': [379592704, 497033216]} {'shape': [14336, 4096], 'dtype': 'f32'}
{'dtype': 'BF16', 'shape': [14336, 4096], 'data_offsets': [497033216, 614473728]} {'shape': [14336, 4096], 'dtype': 'f32'}
{'dtype': 'BF16', 'shape': [4096], 'data_offsets': [614473728, 614481920]} {'shape': [4096], 'dtype': 'f32'}
{'dtype': 'BF16', 'shape': [1024, 4096], 'data_offsets': [614481920, 622870528]} {'shape': [1024, 4096], 'dtype': 'f32'}
{'dtype': 'BF16', 'shape': [4096, 4096], 'data_offsets': [6228

In [103]:
# Zephyr must expose the same tensors as the Mistral trace, and each tensor must
# have the same shape and dtype as in the candle-mistral headers.
zephyr_tensor_map = get_safetensor_headers(ZEPHYR_REPO)

assert len(set(zephyr_tensor_map.keys()).symmetric_difference(mistral_traces.keys())) == 0
assert zephyr_tensor_map.keys() == mistral_tensor_map.keys(), (
    "zephyr and candle-mistral headers expose different tensor names"
)
# Compare by tensor name. The two repos are split into a different number of
# files, so positional pairing of `.values()` could line up unrelated tensors.
for name, z in zephyr_tensor_map.items():
    m = mistral_tensor_map[name]
    assert z['shape'] == m['shape'], f"{name}: shape {z['shape']} != {m['shape']}"
    assert z['dtype'] == m['dtype'], f"{name}: dtype {z['dtype']} != {m['dtype']}"
