Skip to content

Commit

Permalink
class weigths work now
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Maik Jablonka committed Oct 8, 2023
1 parent b6e7c5b commit eac3633
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions src/gptchem/gpt_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
extractor: Optional[ClassificationExtractor] = None,
save_valid_file: bool = False,
bias_token: bool = True,
class_weights: Optional[dict] = None,
):
"""Initialize a GPTClassifier.
Expand All @@ -49,6 +50,10 @@ def __init__(
Defaults to False.
bias_tokens (bool, optional): Whether to add bias to tokens
to ensure that only the relevant tokens are generated.
Defaults to True.
class_weights (Optional[dict], optional): Class weights to be used for inference.
Defaults to None. If None, classes will be weighted equally.
Ensure that the weights add up to 1.
"""
self.property_name = property_name
self.tuner = tuner if tuner is not None else Tuner()
Expand Down Expand Up @@ -82,14 +87,25 @@ def __init__(
self.save_valid_file = save_valid_file
self.bias_token = bias_token
self._input_shape = None
self._class_weights = class_weights

def _get_bias_dict(self):
bias_dict = {}
bias = 100
encoding = tiktoken.encoding_for_model(self.tuner.base_model)
if self._class_weights is not None:
default_weight = 1 / len(self._class_weights)
else:
default_weight = 1
if self.bias_token:
encoding = tiktoken.encoding_for_model(self.tuner.base_model)
for char in self.formatter.allowed_characters:
for token in encoding.encode(char):
bias_dict[token] = 100
bias_dict[token] = 100 * default_weight

if self._class_weights is not None:
for class_, weight in self._class_weights.items():
encoded = encoding.encode(str(class_))
bias_dict[encoded[0]] = bias * weight
return bias_dict

@classmethod
Expand Down Expand Up @@ -147,7 +163,6 @@ def predict(self, X: ArrayLike) -> ArrayLike:

querier = Querier(self.model_name, **self.querier_setting, logit_bias=self._get_bias_dict())
completions = querier(formatted)
print(completions)
extracted = self.extractor(completions)
return extracted

Expand Down

0 comments on commit eac3633

Please sign in to comment.