Skip to content

Commit

Permalink
Merge pull request #5719 from Masao-Someki/easy_to_ez
Browse files Browse the repository at this point in the history
Modified easy to ez
  • Loading branch information
mergify[bot] committed Mar 27, 2024
2 parents 7265e7d + 23047f6 commit fa822d5
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 22 deletions.
6 changes: 3 additions & 3 deletions espnetez/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import yaml

from espnetez.task import get_easy_task
from espnetez.task import get_ez_task


def convert_none_to_None(dic):
Expand All @@ -14,7 +14,7 @@ def convert_none_to_None(dic):


def from_yaml(task, path):
task_class = get_easy_task(task)
task_class = get_ez_task(task)
with open(path, "r") as f:
config = yaml.load(f, Loader=yaml.Loader)

Expand All @@ -30,7 +30,7 @@ def from_yaml(task, path):
def update_finetune_config(task, pretrain_config, path):
with open(path, "r") as f:
finetune_config = yaml.load(f, Loader=yaml.Loader)
default_config = get_easy_task(task).get_default_config()
default_config = get_ez_task(task).get_default_config()

# update pretrain_config with finetune_config
# and update distributed related configs to the default.
Expand Down
2 changes: 1 addition & 1 deletion espnetez/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from espnet2.train.dataset import AbsDataset


class ESPnetEasyDataset(AbsDataset):
class ESPnetEZDataset(AbsDataset):
def __init__(self, dataset, data_info):
self.dataset = dataset
self.data_info = data_info
Expand Down
16 changes: 8 additions & 8 deletions espnetez/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# ESPnet-Easy Task class
# ESPnet-EZ Task class
# This class is a wrapper for Task classes to support custom datasets.
import argparse
import logging
Expand Down Expand Up @@ -64,13 +64,13 @@
)


def get_easy_task(task_name: str, use_custom_dataset: bool = False) -> AbsTask:
def get_ez_task(task_name: str, use_custom_dataset: bool = False) -> AbsTask:
task_class = TASK_CLASSES[task_name]

if use_custom_dataset:
return get_easy_task_with_dataset(task_name)
return get_ez_task_with_dataset(task_name)

class ESPnetEasyTask(task_class):
class ESPnetEZTask(task_class):
build_model_fn = None

@classmethod
Expand All @@ -80,13 +80,13 @@ def build_model(cls, args=None):
else:
return task_class.build_model(args=args)

return ESPnetEasyTask
return ESPnetEZTask


def get_easy_task_with_dataset(task_name: str) -> AbsTask:
def get_ez_task_with_dataset(task_name: str) -> AbsTask:
task_class = TASK_CLASSES[task_name]

class ESPnetEasyDataTask(task_class):
class ESPnetEZDataTask(task_class):
build_model_fn = None
train_dataset = None
valid_dataset = None
Expand Down Expand Up @@ -317,4 +317,4 @@ def build_streaming_iterator(
**kwargs,
)

return ESPnetEasyDataTask
return ESPnetEZDataTask
8 changes: 4 additions & 4 deletions espnetez/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from argparse import Namespace

from espnetez.task import get_easy_task
from espnetez.task import get_ez_task


def check_argument(
Expand Down Expand Up @@ -111,18 +111,18 @@ def __init__(
)

if train_dataset is not None and valid_dataset is not None:
self.task_class = get_easy_task(task, use_custom_dataset=True)
self.task_class = get_ez_task(task, use_custom_dataset=True)
self.task_class.train_dataset = train_dataset
self.task_class.valid_dataset = valid_dataset
elif train_dataloader is not None and valid_dataloader is not None:
self.task_class = get_easy_task(task, use_custom_dataset=True)
self.task_class = get_ez_task(task, use_custom_dataset=True)
self.task_class.train_dataloader = train_dataloader
self.task_class.valid_dataloader = valid_dataloader
else:
assert data_info is not None, "data_info should be provided."
assert train_dump_dir is not None, "Please provide train_dump_dir."
assert valid_dump_dir is not None, "Please provide valid_dump_dir."
self.task_class = get_easy_task(task)
self.task_class = get_ez_task(task)
train_dpnt = []
valid_dpnt = []
for k, v in data_info.items():
Expand Down
12 changes: 6 additions & 6 deletions test/espnetez/test_ez.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,13 @@ def test_join_dumps():

@pytest.mark.parametrize("task_name,task_class", TASK_CLASSES)
def test_task(task_name, task_class):
task = ez.task.get_easy_task(task_name)
task = ez.task.get_ez_task(task_name)
assert issubclass(task, task_class)


@pytest.mark.parametrize("task_name,task_class", TASK_CLASSES)
def test_task_with_dataset(task_name, task_class):
task = ez.task.get_easy_task_with_dataset(task_name)
task = ez.task.get_ez_task_with_dataset(task_name)
assert issubclass(task, task_class)


Expand Down Expand Up @@ -175,10 +175,10 @@ def test_load_config(task_name, task_class):
config_path = Path(temp_dir) / "config.yaml"
config_path.write_text("""task: {task_name}""")
default_config = task_class.get_default_config()
easy_config = ez.config.from_yaml(task_name, config_path)
ez_config = ez.config.from_yaml(task_name, config_path)

for k in default_config.keys():
assert default_config[k] == easy_config[k]
assert default_config[k] == ez_config[k]


@pytest.mark.parametrize("task_name,task_class", TASK_CLASSES)
Expand All @@ -188,11 +188,11 @@ def test_update_finetune_config(task_name, task_class):
config_path = Path(temp_dir) / "config.yaml"
config_path.write_text("""use_lora: true""")
pretrain_config = task_class.get_default_config()
easy_config = ez.config.update_finetune_config(
ez_config = ez.config.update_finetune_config(
task_name, pretrain_config, config_path
)

for k, v in easy_config.items():
for k, v in ez_config.items():
if k != "use_lora":
assert v == pretrain_config[k]
else:
Expand Down

0 comments on commit fa822d5

Please sign in to comment.