Add special token modification capability #7166

Merged
merged 18 commits on May 9, 2024
92 changes: 72 additions & 20 deletions gguf-py/scripts/gguf-new-metadata.py
100644 → 100755
@@ -7,7 +7,8 @@
from pathlib import Path

import numpy as np
from typing import Any, Sequence
from tqdm import tqdm
from typing import Any, Sequence, NamedTuple

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
@@ -18,6 +19,12 @@
logger = logging.getLogger("gguf-new-metadata")


class MetadataDetails(NamedTuple):
    type: gguf.GGUFValueType
    value: Any
    description: str = ''


def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
    if np.uint32(1) == np.uint32(1).newbyteorder("<"):
        # Host is little endian
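For reference, a minimal sketch (with a made-up token ID and token value) of how the new `MetadataDetails` tuple pairs a GGUF value type with a value and an optional description that shows up in log output:

```python
from typing import Any, NamedTuple

import gguf


class MetadataDetails(NamedTuple):
    type: gguf.GGUFValueType
    value: Any
    description: str = ''


# Hypothetical BOS token ID (the real value depends on the model's vocabulary)
bos = MetadataDetails(gguf.GGUFValueType.UINT32, 1, '= <s>')
print(bos.value, bos.description)  # -> 1 = <s>
```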
@@ -59,7 +66,16 @@ def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
    return decode_field(field)


def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, str], remove_metadata: Sequence[str]) -> None:
def find_token(token_list: Sequence[int], token: str) -> Sequence[int]:
    token_ids = [index for index, value in enumerate(token_list) if value == token]

    if len(token_ids) == 0:
        raise LookupError(f'Unable to find "{token}" in token list!')

    return token_ids


def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None:
    for field in reader.fields.values():
        # Suppress virtual fields and fields written by GGUFWriter
        if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
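A quick illustration of `find_token` on a made-up token list (the real list comes from the `tokenizer.ggml.tokens` field); it returns every matching ID so duplicates can be reported:

```python
def find_token(token_list, token):
    # Same logic as the helper above, copied here to make the sketch self-contained
    token_ids = [index for index, value in enumerate(token_list) if value == token]
    if len(token_ids) == 0:
        raise LookupError(f'Unable to find "{token}" in token list!')
    return token_ids


token_list = ['<unk>', '<s>', '</s>', '<s>']   # made-up vocabulary
print(find_token(token_list, '<s>'))            # -> [1, 3]; the first ID wins
# find_token(token_list, '<pad>')               # would raise LookupError
```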
@@ -75,54 +91,64 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None:
            logger.debug(f'Removing {field.name}')
            continue

        old_val = decode_field(field)
        old_val = MetadataDetails(field.types[0], decode_field(field))
        val = new_metadata.get(field.name, old_val)

        if field.name in new_metadata:
            logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"')
            logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}')
            del new_metadata[field.name]
        elif val is not None:
        elif val.value is not None:
            logger.debug(f'Copying {field.name}')

        if val is not None:
        if val.value is not None:
            writer.add_key(field.name)
            writer.add_val(val, field.types[0])
            writer.add_val(val.value, val.type)

    if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
        logger.debug('Adding chat template(s)')
        writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE])
        writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value)
        del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]

    # TODO: Support other types than string?
    for key, val in new_metadata.items():
        logger.debug(f'Adding {key}: {val}')
        logger.debug(f'Adding {key}: "{val.value}" {val.description}')
        writer.add_key(key)
        writer.add_val(val, gguf.GGUFValueType.STRING)
        writer.add_val(val.value, val.type)

    total_bytes = 0

    for tensor in reader.tensors:
        total_bytes += tensor.n_bytes
        # Dimensions are written in reverse order, so flip them first
        shape = np.flipud(tensor.shape).tolist()
        writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)

    bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_ti_data_to_file()

    for tensor in reader.tensors:
        writer.write_tensor_data(tensor.data)
        bar.update(tensor.n_bytes)

    writer.close()


def main() -> None:
    tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in gguf.Keys.Tokenizer.__dict__.keys() if not n.startswith('_'))
    token_names = dict((n.split('.')[-1][:-len('_token_id')], n) for n in tokenizer_metadata if n.endswith('_token_id'))

    parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
    parser.add_argument("input", type=Path, help="GGUF format model input filename")
    parser.add_argument("output", type=Path, help="GGUF format model output filename")
    parser.add_argument("--general-name", type=str, help="The models general.name")
    parser.add_argument("--general-description", type=str, help="The models general.description")
    parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)")
    parser.add_argument("--chat-template-config", type=Path, help="Config file (tokenizer_config.json) containing chat template(s)")
    parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model")
    parser.add_argument("--general-name", type=str, help="The models general.name", metavar='"name"')
    parser.add_argument("--general-description", type=str, help="The models general.description", metavar='"Description ..."')
    parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."')
    parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json')
    parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url')
    parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '"<token>"'))
    parser.add_argument("--special-token-by-id", action="append", type=str, help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0'))
    parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
    parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
    args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])
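For anyone unfamiliar with the argparse pattern used for the new options: `action="append"` combined with `nargs=2` collects each repeated option as a two-element list. A minimal sketch:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--special-token", action="append", type=str, nargs=2)

# Each occurrence of the option contributes one [name, token] pair
args = parser.parse_args(["--special-token", "bos", "<s>",
                          "--special-token", "eos", "</s>"])
print(args.special_token)  # -> [['bos', '<s>'], ['eos', '</s>']]
```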
@@ -133,20 +159,20 @@ def main() -> None:
    remove_metadata = args.remove_metadata or []

    if args.general_name:
        new_metadata[gguf.Keys.General.NAME] = args.general_name
        new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name)

    if args.general_description:
        new_metadata[gguf.Keys.General.DESCRIPTION] = args.general_description
        new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description)

    if args.chat_template:
        new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template
        new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template)

    if args.chat_template_config:
        with open(args.chat_template_config, 'r') as fp:
            config = json.load(fp)
            template = config.get('chat_template')
            if template:
                new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = template
                new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template)

    if remove_metadata:
        logger.warning('*** Warning *** Warning *** Warning **')
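A small sketch of the `--chat-template` branch above: a value starting with `[` is parsed as JSON (e.g. an array of named templates), anything else is stored as a raw template string. The template contents below are made up:

```python
import json

raw = '{% for m in messages %}{{ m.content }}{% endfor %}'       # stored as-is
multi = '[{"name": "default", "template": "{{ bos_token }}..."}]'  # parsed first

for value in (raw, multi):
    parsed = json.loads(value) if value.startswith('[') else value
    print(type(parsed).__name__)  # -> str, then list
```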
@@ -166,6 +192,32 @@ def main() -> None:
    arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
    endianess = get_byteorder(reader)

    token_list = get_field_data(reader, gguf.Keys.Tokenizer.LIST) or []

    for name, token in args.special_token or []:
        if name not in token_names:
            logger.warning(f'Unknown special token "{name}", ignoring...')
        else:
            ids = find_token(token_list, token)
            new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}')

            if len(ids) > 1:
                logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:')
                logger.warning(', '.join(str(i) for i in ids))

    for name, id_string in args.special_token_by_id or []:
        if name not in token_names:
            logger.warning(f'Unknown special token "{name}", ignoring...')
        elif not id_string.isdecimal():
            raise LookupError(f'Token ID "{id_string}" is not a valid ID!')
        else:
            id_int = int(id_string)

            if id_int >= 0 and id_int < len(token_list):
                new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, id_int, f'= {token_list[id_int]}')
            else:
                raise LookupError(f'Token ID {id_int} is not within token list!')

    if os.path.isfile(args.output) and not args.force:
        logger.warning('*** Warning *** Warning *** Warning **')
        logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
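Putting the pieces together, a hedged sketch of how a `--special-token bos "<s>"` pair is resolved to a metadata entry (the vocabulary and ID below are made up; `token_names` is built in main() and maps e.g. `bos` to `tokenizer.ggml.bos_token_id`):

```python
token_names = {'bos': 'tokenizer.ggml.bos_token_id'}  # subset, for illustration
token_list = ['<unk>', '<s>', '</s>']                 # made-up vocabulary

name, token = 'bos', '<s>'                            # as passed to --special-token
ids = [i for i, v in enumerate(token_list) if v == token]   # find_token logic
print(f'{token_names[name]} = {ids[0]}')  # -> tokenizer.ggml.bos_token_id = 1
```

On the command line this would look something like `gguf-new-metadata.py in.gguf out.gguf --special-token bos "<s>"` (assumed invocation), with `--special-token-by-id` as the escape hatch when the same token text appears at multiple IDs.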