Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix LLM QA #167

Merged
merged 8 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion llmtune/cli/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from llmtune.finetune.lora import LoRAFinetune
from llmtune.inference.lora import LoRAInference
from llmtune.pydantic_models.config_model import Config
from llmtune.qa.generics import LLMTestSuite, QaTestRegistry
from llmtune.qa.generics import LLMTestSuite
from llmtune.qa.qa_tests import QaTestRegistry
from llmtune.ui.rich_ui import RichUI
from llmtune.utils.ablation_utils import generate_permutations
from llmtune.utils.save_utils import DirectoryHelper
Expand Down Expand Up @@ -92,6 +93,7 @@ def run_one_experiment(config: Config, config_path: Path) -> None:
tests = QaTestRegistry.create_tests_from_list(llm_tests)
test_suite = LLMTestSuite.from_csv(results_file_path, tests)
test_suite.save_test_results(qa_file_path)
test_suite.print_test_results()


@app.command("run")
Expand Down
18 changes: 10 additions & 8 deletions llmtune/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ lora:
r: 32
lora_alpha: 64
lora_dropout: 0.1
target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
- up_proj
- down_proj
- gate_proj
target_modules: "all-linear"
# To target specific LoRA modules instead of all linear layers, uncomment the list below:
# target_modules:
# - q_proj
# - v_proj
# - k_proj
# - o_proj
# - up_proj
# - down_proj
# - gate_proj

# Training -------------------
training:
Expand Down
5 changes: 4 additions & 1 deletion llmtune/pydantic_models/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ class LoraConfig(BaseModel):
lora_alpha: Optional[int] = Field(16, description="The alpha parameter for Lora scaling")
bias: Optional[str] = Field("none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'")
lora_dropout: Optional[float] = Field(0.1, description="The dropout probability for Lora layers")
target_modules: Optional[List[str]] = Field(None, description="The names of the modules to apply Lora to")
target_modules: Optional[Union[List[str], Literal["all-linear"]]] = Field(
"all-linear", description="The names of the modules to apply Lora to"
)
fan_in_fan_out: Optional[bool] = Field(
False,
description="Flag to indicate if the layer to replace stores weight like (fan_in, fan_out)",
Expand Down Expand Up @@ -242,3 +244,4 @@ class Config(BaseModel):
lora: LoraConfig
training: TrainingConfig
inference: InferenceConfig
qa: QaConfig
38 changes: 17 additions & 21 deletions llmtune/qa/generics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import statistics
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Union

import pandas as pd
Expand All @@ -18,23 +19,6 @@ def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[f
pass


class QaTestRegistry:
    """Registry mapping string test names to LLMQaTest classes.

    Test classes self-register via the ``@QaTestRegistry.register(...)``
    class decorator; callers then materialize test instances by name.
    """

    # name -> test class; shared across the process (class-level mutable state)
    registry = {}

    @classmethod
    def register(cls, *names):
        """Return a class decorator that registers the wrapped class under each name.

        Args:
            *names: One or more aliases under which the test class is looked up.
        """

        def inner_wrapper(wrapped_class):
            for name in names:
                cls.registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    @classmethod
    def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]:
        """Instantiate one registered test per name in ``test_names``.

        Raises:
            KeyError: If a name was never registered.
        """
        # BUG FIX: previously called cls.create_test(test), a method that was
        # never defined — instantiate the registered class directly instead.
        return [cls.registry[test]() for test in test_names]


class LLMTestSuite:
def __init__(
self,
Expand All @@ -51,11 +35,17 @@ def __init__(
self._results = {}

@staticmethod
def from_csv(
    file_path: str,
    tests: List[LLMQaTest],
    prompt_col: str = "Prompt",
    gold_col: str = "Ground Truth",
    pred_col: str = "Predicted",  # annotated for consistency with sibling params
) -> "LLMTestSuite":
    """Build an LLMTestSuite from a CSV of prompts, references, and predictions.

    Args:
        file_path: Path to the CSV file (e.g. the inference results file).
        tests: QA tests to run over each row of the CSV.
        prompt_col: Name of the column holding input prompts.
        gold_col: Name of the column holding ground-truth answers.
        pred_col: Name of the column holding model predictions.

    Returns:
        An LLMTestSuite wrapping the parsed columns.
    """
    results_df = pd.read_csv(file_path)
    prompts = results_df[prompt_col].tolist()
    ground_truths = results_df[gold_col].tolist()
    model_preds = results_df[pred_col].tolist()
    return LLMTestSuite(tests, prompts, ground_truths, model_preds)

def run_tests(self) -> Dict[str, List[Union[float, int, bool]]]:
Expand Down Expand Up @@ -84,5 +74,11 @@ def print_test_results(self):

def save_test_results(self, path: str):
    """Write the collected test results to a CSV file at ``path``.

    Creates any missing parent directories first.

    Args:
        path: Destination CSV file path.
    """
    # Removed stale "TODO: save these!" — the results are saved below.
    out_path = Path(path)  # keep the str parameter; use a separate Path local
    # exist_ok=True makes a prior exists() check redundant; also avoid
    # shadowing the builtin `dir` as the old code did.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # NOTE(review): assumes self.test_results is a column-name -> values
    # mapping suitable for DataFrame construction — confirm against run_tests.
    resultant_dataframe = pd.DataFrame(self.test_results)
    resultant_dataframe.to_csv(out_path, index=False)
19 changes: 18 additions & 1 deletion llmtune/qa/qa_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rouge_score import rouge_scorer
from transformers import DistilBertModel, DistilBertTokenizer

from llmtune.qa.generics import LLMQaTest, QaTestRegistry
from llmtune.qa.generics import LLMQaTest


model_name = "distilbert-base-uncased"
Expand All @@ -21,6 +21,23 @@
nltk.download("averaged_perceptron_tagger")


class QaTestRegistry:
    """Name-to-class registry for LLMQaTest implementations.

    Test classes attach themselves with the ``@QaTestRegistry.register(...)``
    decorator; instances are later created by looking up those names.
    """

    # alias -> registered test class, shared at class level
    registry = {}

    @classmethod
    def register(cls, *names):
        """Return a decorator that records the decorated class under every alias."""

        def add_to_registry(test_cls):
            for alias in names:
                cls.registry[alias] = test_cls
            return test_cls

        return add_to_registry

    @classmethod
    def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]:
        """Instantiate one registered test object per requested name."""
        return [cls.registry[alias]() for alias in test_names]


@QaTestRegistry.register("summary_length")
class LengthTest(LLMQaTest):
@property
Expand Down