Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix/bug fix UI merge #87

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions toolkit/src/finetune/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,3 @@ def finetune(self, train_dataset: Dataset):
def save_model(self) -> None:
    """Persist the fine-tuned model and its tokenizer to the run's weights directory.

    Writes both artifacts to ``self._weights_path`` via the Hugging Face
    ``save_pretrained`` convention, then reports the save location on the console.
    """
    # Save the (PEFT/LoRA) model wrapped by the trainer.
    self._trainer.model.save_pretrained(self._weights_path)
    # Tokenizer must be saved alongside the weights so the run can be reloaded standalone.
    self.tokenizer.save_pretrained(self._weights_path)

    # NOTE(review): the PR hunk header ("0 additions & 2 deletions") suggests this
    # console message may be removed by the change under review — confirm against the diff.
    self._console.print(f"Run saved at {self._weights_path}")
7 changes: 6 additions & 1 deletion toolkit/src/inference/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from rich.live import Live
from rich.text import Text
from datasets import Dataset
from transformers import AutoTokenizer
from transformers import AutoTokenizer, BitsAndBytesConfig
from peft import AutoPeftModelForCausalLM
import torch


from src.pydantic_models.config_model import Config
from src.utils.save_utils import DirectoryHelper
from src.inference.inference import Inference
Expand Down Expand Up @@ -60,6 +61,10 @@ def _get_merged_model(self, weights_path: str):
weights_path,
torch_dtype=dtype,
device_map=self.device_map,
quantization_config=(
BitsAndBytesConfig(self.config.model.bitsandbytes)
),

)

"""TODO: figure out multi-gpu
Expand Down
3 changes: 3 additions & 0 deletions toolkit/src/ui/rich_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def results_found(results_path: str):
# Display functions
@staticmethod
def inference_ground_truth_display(title: str, prompt: str, label: str):
prompt = prompt.replace("[INST]", "").replace("[/INST]", "")
label = label.replace("[INST]", "").replace("[/INST]", "")

table = Table(title=title, show_lines=True)
table.add_column("prompt")
table.add_column("ground truth")
Expand Down
16 changes: 8 additions & 8 deletions toolkit/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def run_one_experiment(config: Config) -> None:
RichUI.inference_found(results_path)

# QA -------------------------------
console.rule("[bold blue]:thinking_face: Running LLM Unit Tests")
qa_path = dir_helper.save_paths.qa
if not exists(qa_path) or not listdir(qa_path):
# TODO: Instantiate unit test classes
# TODO: Load results.csv
# TODO: Run Unit Tests
# TODO: Save Unit Test Results
pass
# console.rule("[bold blue]:thinking_face: Running LLM Unit Tests")
# qa_path = dir_helper.save_paths.qa
# if not exists(qa_path) or not listdir(qa_path):
# # TODO: Instantiate unit test classes
# # TODO: Load results.csv
# # TODO: Run Unit Tests
# # TODO: Save Unit Test Results
# pass


if __name__ == "__main__":
Expand Down