Skip to content

Commit

Permalink
set up pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
bethanyconnolly committed Mar 1, 2023
1 parent 4fcc99a commit df9af1c
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 6 deletions.
35 changes: 35 additions & 0 deletions model/model_training/pyproject.toml
@@ -0,0 +1,35 @@
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[project]
name = "model_training"
description = "Open Assistant Model Training Module"
version = "1.0.0"
authors = [
{ name = "LAION-AI", email = "contact@laion.ai" }
]
dependencies = [
"accelerate==0.15.0",
"bitsandbytes==0.36.0.post2",
"datasets==2.8.0",
"deepspeed==0.7.7",
"evaluate==0.4.0",
"gdown",
#"git+https://github.com/CarperAI/trlx.git@b91da7b03d8e9fa0c0d6dce10a8f2611aca3013f",
"nltk==3.8.1",
"numpy>=1.22.4",
"py7zr",
"scikit-learn==1.2.0",
"sentencepiece==0.1.97",
"torch>=1.12.1",
"transformers==4.25.1",
"wandb==0.13.7",
]

[tool.setuptools]
py-modules = []

[tool.black]
line-length = 120
target-version = ['py310']
Empty file.
13 changes: 8 additions & 5 deletions model/model_training/tests/test_datasets.py
@@ -1,10 +1,13 @@
from argparse import Namespace

from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
import pytest


# TODO:
@pytest.mark.skip(reason="Cannot import glibcxx in pytest")
def test_all_datasets():
from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
qa_base = QA_DATASETS
summarize_base = SUMMARIZATION_DATASETS
others = ["prompt_dialogue", "webgpt", "soda", "joke", "instruct_tuning", "explain_prosocial", "prosocial_dialogue"]
Expand All @@ -20,8 +23,10 @@ def test_all_datasets():
for idx in range(min(len(eval), 1000)):
eval[idx]


@pytest.mark.skip(reason="Cannot import glibcxx in pytest")
def test_collate_fn():
from custom_datasets import QA_DATASETS, SUMMARIZATION_DATASETS, get_one_dataset
from custom_datasets.dialogue_collator import DialogueDataCollator
from torch.utils.data import ConcatDataset, DataLoader
from utils import get_tokenizer

Expand Down Expand Up @@ -49,5 +54,3 @@ def test_collate_fn():
for batch in dataloader:
assert batch["targets"].shape[1] <= 620


test_collate_fn()
1 change: 0 additions & 1 deletion model/model_training/tests/test_utils.py
Expand Up @@ -7,7 +7,6 @@
def test_tokenizer():
get_tokenizer(Namespace(model_name="Salesforce/codegen-2B-multi", cache_dir=".cache"))
get_tokenizer(Namespace(model_name="facebook/galactica-1.3b", cache_dir=".cache"))
get_tokenizer(Namespace(model_name="", cache_dir=".cache"))


def test_tokenizer_successful_match():
Expand Down

0 comments on commit df9af1c

Please sign in to comment.