feat: Add custom embedder #2236

Open
wants to merge 35 commits into
base: master

Conversation

vonodiripsa
Contributor

Related Issues/PRs

#xxx

What changes are proposed in this pull request?

Briefly describe the changes included in this Pull Request.

How is this patch tested?

  • I have written tests (not required for typo or doc fix) and confirmed the proposed feature/bug-fix/change works.

Does this PR change any dependencies?

  • No. You can skip this section.
  • Yes. Make sure the dependencies are resolved correctly, and list changes here.

Does this PR add a new feature? If so, have you added samples on website?

  • No. You can skip this section.
  • Yes. Make sure you have added samples following the steps below.
  1. Find the corresponding markdown file for your new feature in the website/docs/documentation folder.
    Make sure you choose the correct class estimators/transformers and namespace.
  2. Follow the pattern in the markdown file and add another section for your new API, including pyspark, scala (and potentially .NET) samples.
  3. Make sure the DocTable points to the correct API link.
  4. Navigate to the website folder and run yarn run start to make sure the website renders correctly.
  5. Don't forget to add <!--pytest-codeblocks:cont--> before each Python code block to enable auto-tests for Python samples.
  6. Make sure the WebsiteSamplesTests job passes in the pipeline.

self,
inputCol=None,
outputCol=None,
useTRTFlag=None,
Collaborator

nit: useTRTFlag -> runtime: "cpu", "gpu", "tensorrt", default cpu
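
A minimal sketch of how a single runtime setting could replace the boolean flag. It is plain Python so that it stands alone; the names `SUPPORTED_RUNTIMES`, `DEFAULT_RUNTIME`, and `normalize_runtime` are hypothetical and not part of this PR. Only the three values and the cpu default come from the comment above.

```python
from typing import Optional

# Allowed values and default are taken from the review comment.
# The constant and function names are hypothetical.
SUPPORTED_RUNTIMES = ("cpu", "gpu", "tensorrt")
DEFAULT_RUNTIME = "cpu"


def normalize_runtime(value: Optional[str] = None) -> str:
    """Return a validated, lower-cased runtime name, falling back to the default."""
    if value is None:
        return DEFAULT_RUNTIME
    runtime = str(value).strip().lower()
    if runtime not in SUPPORTED_RUNTIMES:
        raise ValueError(
            f"Unsupported runtime {value!r}; expected one of {SUPPORTED_RUNTIMES}"
        )
    return runtime
```

A helper like this could plausibly run when the parameter is set, so an unsupported value fails early instead of at inference time.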


# Define additional parameters
useTRT = Param(Params._dummy(), "useTRT", "True if use TRT acceleration")
driverOnly = Param(
Collaborator

nit: remove driverOnly code

Comment on lines 210 to 211
inputCol="combined",
outputCol="embeddings",
Collaborator

look at other examples in the library for proper defaults for these columns

Comment on lines 306 to 308
for batch_size in [64, 32, 16, 8, 4, 2, 1]:
    for sentence_length in [20, 300, 512]:
        yield (batch_size, sentence_length)
Collaborator

make these magic numbers parameters with defaults

"""
Create a data loader with synthetic data using Faker.
"""
faker = Faker()
Collaborator

nit: let's try to remove this dependency

for sentence_length in [20, 300, 512]:
    yield (batch_size, sentence_length)

def get_dataloader(repeat_times: int = 2):
Collaborator

nit: _get_dataloader

func, dataloader=tqdm(get_dataloader(), total=total_batches), config=conf
)

def run_on_driver(self, queries, spark):
Collaborator

likewise: _run_on_driver

"""
return self._defaultCopy(extra)

def load_data_food_reviews(self, spark, path=None, limit=1000):
Collaborator

move this code into the demo

Comment on lines 15 to 30
class SuppressLogging:
    def __init__(self):
        self._original_stderr = None

    def start(self):
        """Start suppressing logging by redirecting sys.stderr to /dev/null."""
        if self._original_stderr is None:
            self._original_stderr = sys.stderr
            sys.stderr = open('/dev/null', 'w')

    def stop(self):
        """Stop suppressing logging and restore sys.stderr."""
        if self._original_stderr is not None:
            sys.stderr.close()
            sys.stderr = self._original_stderr
            self._original_stderr = None
Collaborator

remove

FloatType,
)

class EmbeddingTransformer(Transformer, HasInputCol, HasOutputCol):
Collaborator

nit: rename to HuggingFaceSentenceEmbedder

Also name the file HuggingFaceSentenceEmbedder.py

Comment on lines 202 to 203
modelName="intfloat/e5-large-v2",
moduleName="e5-large-v2",
Collaborator

nit: no defaults here, and try to make this moduleName thing go away

Initialize the EmbeddingTransformer with input/output columns and optional TRT flag.
"""
super(EmbeddingTransformer, self).__init__()
self._setDefault(
Collaborator

/databricks/python/bin/pip install --extra-index-url https://pypi.nvidia.com cudf-cu11~=${RAPIDS_VERSION} cuml-cu11~=${RAPIDS_VERSION} pylibraft-cu11~=${RAPIDS_VERSION} rmm-cu11~=${RAPIDS_VERSION}

# install model navigator
/databricks/python/bin/pip install --extra-index-url https://pypi.nvidia.com onnxruntime-gpu==1.16.3 "tensorrt==9.3.0.post12.dev1" "triton-model-navigator<1" "sentence_transformers~=2.2.2" "faker" "urllib3<2"
Collaborator

nit: remove faker

Azure Pipelines successfully started running 1 pipeline(s).

@mhamilton723
Collaborator

/azp run

@bvonodiripsa
Contributor

/azp run

4 participants