
Commit

attempted to make multioutput work
Kevin Maik Jablonka committed Oct 7, 2023
1 parent c76d89b commit f8934c2
Showing 2 changed files with 45 additions and 18 deletions.
3 changes: 2 additions & 1 deletion src/gptchem/formatter.py
@@ -1125,6 +1125,7 @@ def __init__(
self.num_digits = num_digits
self.bins = None
self.encoding = encoding
self._label_set = None

@property
def class_names(self) -> List[int]:
@@ -1181,7 +1182,7 @@ def format_many(self, df: pd.DataFrame) -> pd.DataFrame:
if self.encoding:
encoded = self.encoding.batch_encode(representation)
decoded = self.encoding.batch_decode(encoded)
self._allowed_tokens = list(set(decoded))
self._label_set = list(set(decoded))
prop = df[self.property_columns].values

if self.num_classes is not None:
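For context, the renamed `_label_set` holds the distinct label strings after a round trip through the formatter's encoder. A minimal sketch of that idea, using tiktoken directly (the formatter's own `encoding` wrapper exposes `batch_encode`/`batch_decode` instead, and the label values here are made up):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
labels = ["0", "1", "2", "1", "0"]       # hypothetical class labels
encoded = enc.encode_batch(labels)       # token ids for each label
decoded = enc.decode_batch(encoded)      # strings the model can actually emit
label_set = list(set(decoded))           # roughly what the new _label_set holds
print(label_set)                         # e.g. ['0', '1', '2'] in some order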
60 changes: 43 additions & 17 deletions src/gptchem/gpt_classifier.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Union

import numpy as np
import pandas as pd
@@ -7,8 +7,8 @@
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB

from gptchem.extractor import ClassificationExtractor
from gptchem.formatter import ClassificationFormatter
from gptchem.extractor import ClassificationExtractor, MultiOutputExtractor
from gptchem.formatter import ClassificationFormatter, MultiOutputClassificationFormatter
from gptchem.querier import Querier
from gptchem.tuner import Tuner
import tiktoken
@@ -22,41 +22,58 @@ class GPTClassifier:

def __init__(
self,
property_name: str,
tuner: Tuner,
property_name: Union[str, List[str]],
tuner: Optional[Tuner] = None,
querier_settings: Optional[dict] = None,
extractor: ClassificationExtractor = ClassificationExtractor(),
extractor: Optional[ClassificationExtractor] = None,
save_valid_file: bool = False,
bias_token: bool = True,
):
"""Initialize a GPTClassifier.
Args:
property_name (str): Name of the property to be predicted.
property_name (Union[str, List[str]]): Name of the property to be predicted.
This will be part of the prompt.
A list of strings can be provided to predict multiple properties
(requires a `MultiOutputClassificationFormatter` and `MultiOutputExtractor`).
tuner (Tuner): Tuner object to be used for fine tuning.
This specifies the model to be used and the fine-tuning settings.
Defaults to None. If None, a default tuner will be used.
This default Tuner will use the `ada` model.
querier_settings (Optional[dict], optional): Settings for the querier.
Defaults to None.
extractor (ClassificationExtractor, optional): Callable object that can extract
integers from the completions produced by the querier.
Defaults to ClassificationExtractor().
Defaults to None. If None, a default extractor will be used.
save_valid_file (bool, optional): Whether to save the validation file.
Defaults to False.
bias_token (bool, optional): Whether to add bias to tokens
to ensure that only the relevant tokens are generated.
"""
self.property_name = property_name
self.tuner = tuner
self.tuner = tuner if tuner is not None else Tuner()
self.querier_setting = (
querier_settings if querier_settings is not None else {"max_tokens": 3}
)
self.extractor = extractor
self.formatter = ClassificationFormatter(
representation_column="repr",
label_column="prop",
property_name=property_name,
num_classes=None,
if extractor is None:
if isinstance(property_name, str):
extractor = ClassificationExtractor()
else:
extractor = MultiOutputExtractor()
self.formatter = (
ClassificationFormatter(
representation_column="repr",
label_column="prop",
property_name=property_name,
num_classes=None,
)
if isinstance(property_name, str)
else MultiOutputClassificationFormatter(
representation_column="repr",
label_columns=property_name,
property_names=property_name,
num_classes=None,
)
)
self.model_name = None
self.tune_res = None
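A minimal usage sketch of the revised constructor, based only on the signature above (the property names and explicit Tuner are illustrative):

from gptchem.gpt_classifier import GPTClassifier
from gptchem.tuner import Tuner

# Single-output: a string property name keeps the default
# ClassificationFormatter / ClassificationExtractor pair.
single = GPTClassifier(property_name="transition wavelength", tuner=Tuner())

# Multi-output: a list of property names switches to
# MultiOutputClassificationFormatter / MultiOutputExtractor.
multi = GPTClassifier(property_name=["transition wavelength", "solubility class"])

The second call relies on the new fallback introduced in this commit: with tuner=None, a default Tuner is constructed.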
@@ -80,8 +97,17 @@ def from_finetune_id(cls, finetune_id: str, **kwargs):

def _prepare_df(self, X: ArrayLike, y: ArrayLike):
rows = []
for i in range(len(X)):
rows.append({"repr": X[i], "prop": y[i]})
# if y has a single column, add one "prop" column
# otherwise, add one column per property name
if y.ndim == 1:
for i in range(len(X)):
rows.append({"repr": X[i], "prop": y[i]})
else:
for i in range(len(X)):
row = {"repr": X[i]}
y_dict = dict(zip(self.property_name, y[i]))
row.update(y_dict)
rows.append(row)
return pd.DataFrame(rows)

def fit(self, X: ArrayLike, y: ArrayLike) -> None:
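To make the new _prepare_df branches concrete, here is a sketch of the frames they produce for 1-D versus 2-D targets (illustrative SMILES and labels; note the branch assumes y exposes .ndim, i.e. a NumPy array):

import numpy as np
import pandas as pd

X = np.array(["CCO", "c1ccccc1"])
property_names = ["solubility class", "toxicity class"]   # hypothetical

# y.ndim == 1 -> a single "prop" column
y_single = np.array([0, 1])
df_single = pd.DataFrame([{"repr": x, "prop": y} for x, y in zip(X, y_single)])
print(df_single.columns.tolist())   # ['repr', 'prop']

# y.ndim == 2 -> one column per property name
y_multi = np.array([[0, 1], [1, 0]])
df_multi = pd.DataFrame(
    [{"repr": x, **dict(zip(property_names, row))} for x, row in zip(X, y_multi)]
)
print(df_multi.columns.tolist())    # ['repr', 'solubility class', 'toxicity class']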
