12 changes: 7 additions & 5 deletions src/lighteval/main_nanotron.py
@@ -8,6 +8,7 @@
from lighteval.evaluator import evaluate, make_results_table
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.hierarchical_logger import hlog, htrack, htrack_block
from lighteval.models.model_config import EnvConfig
from lighteval.models.model_loader import ModelInfo
from lighteval.models.nanotron_model import NanotronLightevalModel
from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
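The newly imported `EnvConfig` replaces the ad-hoc cache-dir/token plumbing further down in this file. A minimal sketch of how it is built and read, assuming it is the small token/cache-dir container from `lighteval.models.model_config` whose two fields are consumed later in this diff (the concrete values here are placeholders):

```python
from lighteval.models.model_config import EnvConfig

# Placeholder values; main() builds this from TOKEN and the resolved cache_dir.
env_config = EnvConfig(token="hf_xxx", cache_dir="/scratch")

# Downstream code reads the two fields instead of os.getenv(...) and hard-coded paths.
print(env_config.cache_dir)          # "/scratch"
print(env_config.token is not None)  # True when a hub token is configured
```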
@@ -35,7 +36,7 @@

@htrack()
def main(
local_config_path: str,
checkpoint_config_path: str,
lighteval_config_path: Optional[str] = None,
cache_dir: str = None,
config_cls: Type = Config,
@@ -45,16 +46,16 @@ def main(
if cache_dir is None:
cache_dir = CACHE_DIR

# env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)

dist.initialize_torch_distributed()

with htrack_block("get config"):
if not local_config_path.endswith(".yaml"):
if not checkpoint_config_path.endswith(".yaml"):
raise ValueError("The checkpoint path should point to a YAML file")

nanotron_config: config_cls = get_config_from_file(
local_config_path,
checkpoint_config_path,
config_class=config_cls,
model_config_class=model_config_cls,
skip_unused_config_keys=True,
Expand Down Expand Up @@ -91,7 +92,7 @@ def main(
with htrack_block("Model loading"):
# We need to load the model in the main process first to avoid downloading the model multiple times
model = NanotronLightevalModel(
checkpoint_path=os.path.dirname(local_config_path),
checkpoint_path=os.path.dirname(checkpoint_config_path),
model_args=nanotron_config.model,
tokenizer=nanotron_config.tokenizer,
parallel_context=parallel_context,
@@ -101,6 +102,7 @@
cache_dir=os.environ.get("HF_HOME", "/scratch"),
debug_one_layer_model=False,
model_class=model_cls,
env_config=env_config,
)
model_info = ModelInfo(model_name=f"{nanotron_config.general.run}/{nanotron_config.general.step}")
evaluation_tracker.general_config_logger.log_model_info(model_info)
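Caller-side, the visible change in `main()` is the renamed first argument; a minimal invocation sketch (paths are hypothetical, and `cache_dir=None` falls back to `CACHE_DIR` as shown above):

```python
from lighteval.main_nanotron import main

# The first argument is now the checkpoint's config YAML (formerly `local_config_path`);
# the checkpoint directory itself is derived with os.path.dirname(checkpoint_config_path).
main(
    checkpoint_config_path="checkpoints/10000/config.yaml",  # hypothetical path
    lighteval_config_path="lighteval_config.yaml",           # hypothetical, optional
    cache_dir=None,                                          # falls back to CACHE_DIR
)
```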
15 changes: 8 additions & 7 deletions src/lighteval/models/nanotron_model.py
@@ -33,6 +33,7 @@
LoglikelihoodSingleTokenDataset,
)
from lighteval.models.base_model import LightevalModel
from lighteval.models.model_config import EnvConfig
from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
from lighteval.tasks.requests import (
GreedyUntilRequest,
@@ -71,9 +72,9 @@ def __init__(
add_special_tokens: Optional[bool] = True,
dtype: Optional[Union[str, torch.dtype]] = None,
trust_remote_code: bool = False,
cache_dir: str = "/scratch",
debug_one_layer_model: bool = False,
model_class: Optional[Type] = None,
env_config: EnvConfig = None,
):
"""Initializes a nanotron model for evaluation.
Args:
@@ -119,7 +120,7 @@ def __init__(
self._add_special_tokens = add_special_tokens
self._tokenizer = self._create_auto_tokenizer(
pretrained=tokenizer.tokenizer_name_or_path,
cache_dir=cache_dir,
env_config=env_config,
trust_remote_code=trust_remote_code,
)
self._tokenizer.model_max_length = self.max_length
@@ -206,24 +207,24 @@ def _create_auto_tokenizer(
*,
pretrained: str,
tokenizer: Optional[str] = None,
cache_dir: str = "/scratch",
env_config: EnvConfig = None,
trust_remote_code: bool = False,
) -> transformers.PreTrainedTokenizer:
"""Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""

try:
tokenizer = AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
cache_dir=cache_dir,
token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
cache_dir=env_config.cache_dir,
token=env_config.token,
trust_remote_code=trust_remote_code,
)
except RecursionError:
tokenizer = AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
cache_dir=cache_dir,
cache_dir=env_config.cache_dir,
token=env_config.token,
unk_token="<unk>",
token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
trust_remote_code=trust_remote_code,
)
tokenizer.pad_token = tokenizer.eos_token
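With `env_config` threaded through `__init__` into `_create_auto_tokenizer`, the tokenizer is now loaded with the cache directory and hub token from `EnvConfig` instead of a hard-coded `/scratch` path and `os.getenv("HUGGING_FACE_HUB_TOKEN")`. A standalone sketch of the equivalent call (the tokenizer name and token are placeholders):

```python
from transformers import AutoTokenizer

from lighteval.models.model_config import EnvConfig

env_config = EnvConfig(token="hf_xxx", cache_dir="/scratch")  # placeholder values

# Mirrors the happy path of _create_auto_tokenizer after this change.
tokenizer = AutoTokenizer.from_pretrained(
    "gpt2",                          # placeholder for tokenizer.tokenizer_name_or_path
    cache_dir=env_config.cache_dir,
    token=env_config.token,
    trust_remote_code=False,
)
tokenizer.pad_token = tokenizer.eos_token  # as in the last line of the function above
```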
28 changes: 22 additions & 6 deletions src/lighteval/utils.py
@@ -50,26 +50,42 @@ def rec(nest: dict, prefix: str, into: dict):
return flat


def clean_s3_links(key, value):
def clean_s3_links(value: str) -> str:
"""Cleans and formats s3 bucket links for better display in the result table (nanotron models)

Args:
value (str): path to clean

Returns:
str : cleaned path
"""
s3_bucket, s3_prefix = str(value).replace("s3://", "").split("/", maxsplit=1)
if not s3_prefix.endswith("/"):
s3_prefix += "/"
link_str = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?prefix={s3_prefix}"
value = f'<a href="{link_str}" target="_blank"> {value} </a>'
return key, value
return value


def obj_to_markdown(obj, convert_s3_links: bool = True) -> str:
"""Convert a (potentially nested) dataclass object or a dict in a readable markdown string for logging"""
from pytablewriter import MarkdownTableWriter

if is_dataclass(obj):
obj = asdict(obj)
config_dict = flatten_dict(obj)
config_markdown = "| Key | Value |\n| --- | --- |\n"

md_writer = MarkdownTableWriter()
md_writer.headers = ["Key", "Value"]

values = []
for key, value in config_dict.items():
if convert_s3_links and "s3://" in str(value):
key, value = clean_s3_links(key, value)
config_markdown += f"| {key} | {value} |\n"
return config_markdown
value = clean_s3_links(value)
values.append([key, value])
md_writer.value_matrix = values

return md_writer.dumps()


def sanitize_numpy(example_dict: dict) -> dict:
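`obj_to_markdown` now delegates table construction to `pytablewriter` instead of concatenating a markdown string by hand, and `clean_s3_links` takes and returns only the value. A short usage sketch with a hypothetical flattened config dict:

```python
from pytablewriter import MarkdownTableWriter

from lighteval.utils import clean_s3_links

config_dict = {  # hypothetical flattened config values
    "general.run": "debug-run",
    "checkpoints.checkpoints_path": "s3://my-bucket/checkpoints",
}

md_writer = MarkdownTableWriter()
md_writer.headers = ["Key", "Value"]
md_writer.value_matrix = [
    [key, clean_s3_links(value) if "s3://" in str(value) else value]
    for key, value in config_dict.items()
]

# Renders a | Key | Value | markdown table; the s3 path is wrapped in an S3-console link.
print(md_writer.dumps())
```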