From f8934c2adcf5ac57b134997038ab142771de51fd Mon Sep 17 00:00:00 2001
From: Kevin Maik Jablonka
Date: Sat, 7 Oct 2023 19:14:50 +0200
Subject: [PATCH] attempted to make multioutput work

---
 src/gptchem/formatter.py      |  3 +-
 src/gptchem/gpt_classifier.py | 63 ++++++++++++++++++++++++++---------
 2 files changed, 48 insertions(+), 18 deletions(-)

diff --git a/src/gptchem/formatter.py b/src/gptchem/formatter.py
index 21834c9d..07077565 100644
--- a/src/gptchem/formatter.py
+++ b/src/gptchem/formatter.py
@@ -1125,6 +1125,7 @@ def __init__(
         self.num_digits = num_digits
         self.bins = None
         self.encoding = encoding
+        self._label_set = None
 
     @property
     def class_names(self) -> List[int]:
@@ -1181,7 +1182,7 @@ def format_many(self, df: pd.DataFrame) -> pd.DataFrame:
         if self.encoding:
             encoded = self.encoding.batch_encode(representation)
             decoded = self.encoding.batch_decode(encoded)
-            self._allowed_tokens = list(set(decoded))
+            self._label_set = list(set(decoded))
 
         prop = df[self.property_columns].values
         if self.num_classes is not None:
diff --git a/src/gptchem/gpt_classifier.py b/src/gptchem/gpt_classifier.py
index 725b6ca4..3d9daeaf 100644
--- a/src/gptchem/gpt_classifier.py
+++ b/src/gptchem/gpt_classifier.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -7,8 +7,8 @@
 from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.naive_bayes import MultinomialNB
 
-from gptchem.extractor import ClassificationExtractor
-from gptchem.formatter import ClassificationFormatter
+from gptchem.extractor import ClassificationExtractor, MultiOutputExtractor
+from gptchem.formatter import ClassificationFormatter, MultiOutputClassificationFormatter
 from gptchem.querier import Querier
 from gptchem.tuner import Tuner
 import tiktoken
@@ -22,41 +22,59 @@
 class GPTClassifier:
     def __init__(
         self,
-        property_name: str,
-        tuner: Tuner,
+        property_name: Union[str, List[str]],
+        tuner: Optional[Tuner] = None,
         querier_settings: Optional[dict] = None,
-        extractor: ClassificationExtractor = ClassificationExtractor(),
+        extractor: Optional[ClassificationExtractor] = None,
         save_valid_file: bool = False,
         bias_token: bool = True,
     ):
         """Initialize a GPTClassifier.
 
         Args:
-            property_name (str): Name of the property to be predicted.
+            property_name (Union[str, List[str]]): Name of the property to be predicted.
                 This will be part of the prompt.
+                A list of strings can be provided to predict multiple properties
+                (requires a `MultiOutputClassificationFormatter` and `MultiOutputExtractor`).
             tuner (Tuner): Tuner object to be used for fine tuning.
                 This specifies the model to be used and the fine-tuning settings.
+                Defaults to None. If None, a default tuner will be used.
+                This default Tuner will use the `ada` model.
             querier_settings (Optional[dict], optional): Settings for the querier.
                 Defaults to None.
             extractor (ClassificationExtractor, optional): Callable object that can extract
                 integers from the completions produced by the querier.
-                Defaults to ClassificationExtractor().
+                Defaults to None. If None, a default extractor will be used.
             save_valid_file (bool, optional): Whether to save the validation file.
                 Defaults to False.
             bias_tokens (bool, optional): Whether to add bias to tokens
                 to ensure that only the relevant tokens are generated.
         """
         self.property_name = property_name
-        self.tuner = tuner
+        self.tuner = tuner if tuner is not None else Tuner()
         self.querier_setting = (
             querier_settings if querier_settings is not None else {"max_tokens": 3}
         )
-        self.extractor = extractor
-        self.formatter = ClassificationFormatter(
-            representation_column="repr",
-            label_column="prop",
-            property_name=property_name,
-            num_classes=None,
+        if extractor is None:
+            if isinstance(property_name, str):
+                extractor = ClassificationExtractor()
+            else:
+                extractor = MultiOutputExtractor()
+        self.extractor = extractor
+        self.formatter = (
+            ClassificationFormatter(
+                representation_column="repr",
+                label_column="prop",
+                property_name=property_name,
+                num_classes=None,
+            )
+            if isinstance(property_name, str)
+            else MultiOutputClassificationFormatter(
+                representation_column="repr",
+                label_columns=property_name,
+                property_names=property_name,
+                num_classes=None,
+            )
         )
         self.model_name = None
         self.tune_res = None
@@ -80,8 +98,19 @@ def from_finetune_id(cls, finetune_id: str, **kwargs):
 
     def _prepare_df(self, X: ArrayLike, y: ArrayLike):
         rows = []
-        for i in range(len(X)):
-            rows.append({"repr": X[i], "prop": y[i]})
+        # coerce so that `.ndim` also works when callers pass plain lists
+        y = np.asarray(y)
+        # if y is one column we add one column "prop"
+        # else we add one column per property
+        if y.ndim == 1:
+            for i in range(len(X)):
+                rows.append({"repr": X[i], "prop": y[i]})
+        else:
+            for i in range(len(X)):
+                row = {"repr": X[i]}
+                y_dict = dict(zip(self.property_name, y[i]))
+                row.update(y_dict)
+                rows.append(row)
         return pd.DataFrame(rows)
 
     def fit(self, X: ArrayLike, y: ArrayLike) -> None: