Weights & Biases publication (#2)
vrigal committed Nov 17, 2023
1 parent 58a30d9 commit e82c609
Showing 9 changed files with 243 additions and 101 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
@@ -18,4 +18,4 @@ jobs:
- name: install pre-commit
run: pip install pre-commit
- name: run pre-commit
run: pre-commit run -a
run: pre-commit run -a --show-diff-on-failure
1 change: 1 addition & 0 deletions .gitignore
@@ -161,3 +161,4 @@ cython_debug/

# Parser output dir
output
wandb
9 changes: 9 additions & 0 deletions .isort.cfg
@@ -0,0 +1,9 @@
[settings]
# Compatible with black
multi_line_output = 3
include_trailing_comma = True
force_grid_wrap = 0
use_parentheses = True
line_length = 120

known_third_party = PyYAML,wandb
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -3,7 +3,7 @@ repos:
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", "--line-length", "120"]
args: ["--profile", "black", "--line-length", "120", "--project", "translations_parser"]
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
1 change: 1 addition & 0 deletions requirements.txt
@@ -1 +1,2 @@
PyYAML==6.0.1
wandb==0.16.0
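
The wandb==0.16.0 pin backs the new WandB publisher. Since publishers.py is among the changed files but not shown in this excerpt, here is a minimal, hypothetical sketch of the underlying wandb calls such a publisher can build on; the project, group, and metric names are illustrative.

# Hypothetical sketch only; not the actual publishers.py implementation.
import wandb

run = wandb.init(project="my-translations", group="en-fr", name="baseline", tags=["cli"])
run.log({"validation/chrf": 55.3}, step=1000)  # metrics keyed by name, indexed by update step
run.finish()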
36 changes: 32 additions & 4 deletions translations_parser/cli.py
@@ -2,6 +2,7 @@
from pathlib import Path

from translations_parser.parser import TrainingParser
from translations_parser.publishers import CSVExport, WandB


def get_args():
@@ -20,14 +21,41 @@ def get_args():
type=Path,
default=Path(__file__).parent.parent / "output",
)
parser.add_argument(
"--wandb-project",
help="Publish the training run to a Weight & Biases project.",
default=None,
)
parser.add_argument(
"--wandb-group",
help="Add the training run to a Weight & Biases group e.g. by language pair or experiment.",
default=None,
)
parser.add_argument(
"--wandb-run-name",
help="Use a custom name for the Weight & Biases run.",
default=None,
)
return parser.parse_args()


def main():
args = get_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
with args.input_file.open("r") as f:
lines = (line.strip() for line in f.readlines())
parser = TrainingParser(lines)
parser.parse()
args.output_dir.mkdir(parents=True, exist_ok=True)
parser.csv_export(args.output_dir)
publishers = [CSVExport(output_dir=args.output_dir)]
if args.wandb_project:
publishers.append(
WandB(
project=args.wandb_project,
group=args.wandb_group,
tags=["cli"],
name=args.wandb_run_name,
config={
"logs_file": args.input_file,
},
)
)
parser = TrainingParser(lines, publishers=publishers)
parser.run()
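
For reference, a minimal sketch of driving the new publisher wiring programmatically, mirroring main() above; the log path, project, group, and run names are hypothetical.

# Sketch mirroring main() above; paths and names are illustrative.
from pathlib import Path

from translations_parser.parser import TrainingParser
from translations_parser.publishers import CSVExport, WandB

logs_file = Path("output/train.log")  # hypothetical log file
output_dir = Path("output")
output_dir.mkdir(parents=True, exist_ok=True)

with logs_file.open("r") as f:
    lines = (line.strip() for line in f)
    publishers = [
        CSVExport(output_dir=output_dir),
        WandB(
            project="my-translations",
            group="en-fr",
            tags=["cli"],
            name="baseline",
            config={"logs_file": logs_file},
        ),
    ]
    parser = TrainingParser(lines, publishers=publishers)
    parser.run()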
40 changes: 40 additions & 0 deletions translations_parser/data.py
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from datetime import datetime
from typing import List


@dataclass
class TrainingEpoch:
epoch: int
up: int
sen: int
cost: float
time: float
rate: float
gnorm: float


@dataclass
class ValidationEpoch:
epoch: int
up: int
chrf: float
ce_mean_words: float
bleu_detok: float


@dataclass
class TrainingLog:
"""Results from the parsing of a training log file"""

# Runtime configuration
configuration: dict
training: List[TrainingEpoch]
validation: List[ValidationEpoch]
# Dict of log lines indexed by their header (e.g. marian, data, memory)
logs: dict
run_date: datetime

@property
def logs_str(self):
return "\n".join("".join(f"[{key}] {val}\n" for val in values) for key, values in self.logs.items())
166 changes: 71 additions & 95 deletions translations_parser/parser.py
@@ -1,13 +1,14 @@
import csv
import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from collections.abc import Iterable, Sequence
from datetime import datetime
from typing import List

import yaml

from translations_parser.data import TrainingEpoch, TrainingLog, ValidationEpoch
from translations_parser.publishers import Publisher

logging.basicConfig(
level=logging.INFO,
format="[%(levelname)s] %(message)s",
@@ -29,54 +30,28 @@
MARIAN_MAJOR, MARIAN_MINOR = 1, 10


@dataclass
class TrainingEpoch:
epoch: str
up: str
sen: str
cost: str
time: int
rate: str
gnorm: str


@dataclass
class ValidationEpoch:
epoch: str
up: str
chrf: str
ce_mean_words: str
bleu_detok: str


@dataclass
class TrainingLog:
"""Results from the parsing of a training log file"""

# Marian information
info: dict
# Runtime configuration
configuration: dict
training: List[TrainingEpoch]
validation: List[ValidationEpoch]
# Dict of log lines indexed by their header (e.g. marian, data, memory)
logs: dict


class TrainingParser:
def __init__(self, logs_iter):
def __init__(self, logs_iter: Iterable[str], publishers: Sequence[Publisher]):
# Iterable reading logs lines
self.logs_iter = logs_iter
self._current_index = 0
self.parsed = False
self.config = {}
self.indexed_logs = defaultdict(list)
# List of TrainingEpoch
self.training = []
# List of ValidationEpoch
self.validation = []
# Dict mapping (epoch, up) to values parsed on multiple lines
self.validation_entries = defaultdict(dict)
self._validation_entries = defaultdict(dict)
# Marian execution data
self.version = None
self.version_hash = None
self.release_date = None
self.run_date = None
self.description = None
# Data publication after parsing logs
self.publishers = publishers

def get_headers(self, line):
"""
@@ -105,35 +80,43 @@ def parse_training_log(self, headers, text):
if not match:
return
values = match.groupdict()
# Transform sen value from 1,234,567 to 1234567
values["sen"] = values["sen"].replace(",", "")
self.training.append(TrainingEpoch(**values))
# Rewrite the sen value from 1,234,567 to 1_234_567, which Python can parse as an int
values["sen"] = values["sen"].replace(",", "_")
# Transform values to match output types
values = {k: TrainingEpoch.__annotations__[k](v) for k, v in values.items()}
training_epoch = TrainingEpoch(**values)
self.training.append(training_epoch)
return training_epoch
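
The coercion through TrainingEpoch.__annotations__ works because each regex group is a string and each dataclass annotation is a concrete type (int, float) that doubles as a conversion callable; a failed conversion raises the ValueError the run loop below catches. A self-contained sketch of the pattern:

# Standalone sketch of the __annotations__ coercion pattern used above.
from dataclasses import dataclass

@dataclass
class Epoch:
    up: int
    rate: float

raw = {"up": "1000", "rate": "25034.9"}  # regex groups are always strings
typed = {k: Epoch.__annotations__[k](v) for k, v in raw.items()}
print(Epoch(**typed))  # Epoch(up=1000, rate=25034.9)

Note this only holds while the annotations are real types rather than strings, i.e. without from __future__ import annotations.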

def parse_validation_log(self, headers, text):
if ("valid",) not in headers:
return
match = VALIDATION_RE.match(text)
if not match:
if ("valid",) not in headers or not (match := VALIDATION_RE.match(text)):
return
epoch, up, key, val = match.groups()
# Rewrite dashed keys (e.g. ce-mean-words) to match ValidationEpoch field names
key = key.replace("-", "_")
self.validation_entries[(epoch, up)].update({key: val})
# Transform values to match output types
epoch, up = int(epoch), int(up)
val = ValidationEpoch.__annotations__[key](val)
self._validation_entries[(epoch, up)].update({key: val})
return (epoch, up)
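
Marian prints each validation metric on its own line, so results are accumulated per (epoch, up) key and only turned into ValidationEpoch objects once complete; a sketch of that accumulation with made-up values:

# Sketch of the accumulation behind _validation_entries; values are made up.
from collections import defaultdict

entries = defaultdict(dict)
for epoch, up, key, val in [
    (1, 1000, "chrf", 55.3),
    (1, 1000, "ce_mean_words", 4.2),
    (1, 1000, "bleu_detok", 32.1),
]:
    entries[(epoch, up)][key] = val
# Once chrf, ce_mean_words and bleu_detok are all present,
# build_validation_epochs can emit a ValidationEpoch for (1, 1000).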

def _iter_log_entries(self):
for index, line in enumerate(self.logs_iter, start=1):
for line in self.logs_iter:
self._current_index += 1
headers, position = self.get_headers(line)
timestamp = next((ts for ts in map(self.check_task_timestamp_header, headers) if ts), None)
if timestamp is None:
logger.debug(f"Skipping line {index} : Headers does not match [task <timestamp>]")
logger.debug(f"Skipping line {self._current_index} : Headers does not match [task <timestamp>]")
continue
elif self.run_date is None:
self.run_date = timestamp
text = line[position:]

# Record logs depending on Marian headers
if len(headers) >= 2:
# First is task timestamp, second is marian timestamp
_, _, *marian_tags = headers
tag = "_".join(*marian_tags) if marian_tags else "_default"
tag = "_".join(*marian_tags) if marian_tags else "_"
self.indexed_logs[tag].append(text)

yield headers, text
@@ -161,6 +144,7 @@ def _parse(self):
if ("marian",) not in headers:
break
desc.append(text)
self.description = " ".join(desc)

# Try to parse all following config lines as YAML
config_yaml = ""
@@ -177,38 +161,30 @@

# Iterate until the end of file to find training or validation logs
while True:
if train := self.parse_training_log(headers, text):
self.training.append(train)
elif val := self.parse_validation_log(headers, text):
self.validation.append(val)
try:
headers, text = next(logs_iter)
try:
training = self.parse_training_log(headers, text)
if not training:
self.parse_validation_log(headers, text)
except ValueError as e:
logger.warning(f"Line {self._current_index} could not be stored: {e}.")
headers, text = next(logs_iter)
finally:
headers, text = next(logs_iter)
except StopIteration:
break

count = sum(len(vals) for vals in self.indexed_logs.values())
logger.info(f"Successfully parsed {count} lines")
logger.info(f"Found {len(self.training)} training entries")
logger.info(f"Found {len(list(self.validation))} validation entries")
# Build validation epochs from matched log entries
for validation in self.build_validation_epochs():
self.validation.append(validation)
self.parsed = True

def parse(self):
"""
Parse the log lines
A StopIteration can be raised if some required lines are never found
"""
try:
self._parse()
except StopIteration:
raise ValueError("Logs file ended up unexpectedly")

@property
def validation(self):
def build_validation_epochs(self):
"""
Build ValidationEpoch objects from accumulated entries,
as each validation result spans multiple log lines
"""
for (epoch, up), parsed in self.validation_entries.items():
for (epoch, up), parsed in self._validation_entries.items():
# Ensure required keys have been parsed
diff = set(("chrf", "ce_mean_words", "bleu_detok")) - set(parsed.keys())
if diff:
@@ -221,32 +197,32 @@ def output(self):
if not self.parsed:
raise Exception("Please run the parser before reading the output")
return TrainingLog(
info=self.info,
run_date=self.run_date,
configuration=self.config,
training=self.training,
validation=list(self.validation),
logs=self.indexed_logs,
)

def csv_export(self, output_dir):
assert output_dir.is_dir(), "Output must be a valid directory"
# Publish two files, validation.csv and training.csv
training_output = output_dir / "training.csv"
if training_output.exists():
print(f"Output file {training_output} exists, skipping.")
else:
with open(training_output, "w") as f:
writer = csv.DictWriter(f, fieldnames=TrainingEpoch.__annotations__)
writer.writeheader()
for entry in self.training:
writer.writerow(vars(entry))

validation_output = output_dir / "validation.csv"
if validation_output.exists():
print(f"Output file {validation_output} exists, skipping.")
else:
with open(validation_output, "w") as f:
writer = csv.DictWriter(f, fieldnames=ValidationEpoch.__annotations__)
writer.writeheader()
for entry in self.validation:
writer.writerow(vars(entry))
def publish(self, publisher):
logger.info(f"Publishing data using {publisher.__class__.__name__}")
publisher.publish(self.output)
publisher.close()

def run(self):
"""
Parse the log lines.
"""
try:
self._parse()
except StopIteration:
# A StopIteration can be raised if some required lines are never found.
raise ValueError("Logs file ended up unexpectedly")

count = sum(len(vals) for vals in self.indexed_logs.values())
logger.info(f"Successfully parsed {count} lines")
logger.info(f"Found {len(self.training)} training entries")
logger.info(f"Found {len(self.validation)} validation entries")

for publisher in self.publishers:
self.publish(publisher)
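
The run() method only requires publishers to expose publish() and close(). Since publishers.py is not shown in this excerpt, here is a hypothetical minimal publisher inferred from TrainingParser.publish() above:

# Hypothetical publisher; the real interface lives in translations_parser/publishers.py.
from translations_parser.data import TrainingLog

class StdoutPublisher:
    def publish(self, training_log: TrainingLog) -> None:
        print(
            f"{len(training_log.training)} training epochs, "
            f"{len(training_log.validation)} validation epochs"
        )

    def close(self) -> None:
        pass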
