Improve Pooling tokenizer load method, closes #499
davidmezzetti committed Jul 8, 2023
1 parent b35f5c7 commit 197f681
Showing 2 changed files with 28 additions and 7 deletions.
29 changes: 26 additions & 3 deletions src/python/txtai/models/models.py
@@ -6,7 +6,14 @@

 import torch

-from transformers import AutoConfig, AutoModel, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+)

 from .onnx import OnnxModel

@@ -186,12 +193,28 @@ def load(path, config=None, task="default"):
         return models[task](path) if task in models else path

     @staticmethod
-    def task(path):
+    def tokenizer(path, **kwargs):
+        """
+        Loads a tokenizer from path.
+
+        Args:
+            path: path to tokenizer
+            kwargs: optional additional keyword arguments
+
+        Returns:
+            tokenizer
+        """
+
+        return AutoTokenizer.from_pretrained(path, **kwargs) if isinstance(path, str) else path
+
+    @staticmethod
+    def task(path, **kwargs):
         """
         Attempts to detect the model task from path.

         Args:
             path: path to model
+            kwargs: optional additional keyword arguments

         Returns:
             inferred model task
@@ -202,7 +225,7 @@ def task(path):
         if isinstance(path, (list, tuple)) and hasattr(path[0], "config"):
             config = path[0].config
         elif isinstance(path, str):
-            config = AutoConfig.from_pretrained(path)
+            config = AutoConfig.from_pretrained(path, **kwargs)

         # Attempt to resolve task using configuration
         task = None
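For context, a minimal sketch of how the new Models.tokenizer method behaves. The bert-base-uncased model name and the txtai.models import path are illustrative assumptions, not part of this commit:

from txtai.models import Models

# A string path is loaded via AutoTokenizer.from_pretrained
tokenizer = Models.tokenizer("bert-base-uncased")

# Extra keyword arguments are forwarded to from_pretrained, e.g. a revision
tokenizer = Models.tokenizer("bert-base-uncased", revision="main")

# A pre-loaded tokenizer instance passes through unchanged
assert Models.tokenizer(tokenizer) is tokenizer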
6 changes: 2 additions & 4 deletions src/python/txtai/models/pooling.py
@@ -7,8 +7,6 @@

 from torch import nn

-from transformers import AutoTokenizer
-
 from .models import Models


@@ -31,7 +29,7 @@ def __init__(self, path, device, tokenizer=None, maxlength=None):
         super().__init__()

         self.model = Models.load(path)
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer if tokenizer else path)
+        self.tokenizer = Models.tokenizer(tokenizer if tokenizer else path)
         self.device = Models.device(device)

         # Detect unbounded tokenizer typically found in older models
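With this change, Pooling resolves its tokenizer through Models.tokenizer, so both a path string and a pre-loaded tokenizer instance work. A hedged sketch; the model paths are illustrative, and device=-1 assumes txtai's convention of mapping negative device ids to CPU:

from txtai.models import Pooling

# Tokenizer defaults to the model path
pooling = Pooling("sentence-transformers/all-MiniLM-L6-v2", device=-1)

# Or load the tokenizer from a separate path
pooling = Pooling("sentence-transformers/all-MiniLM-L6-v2", device=-1, tokenizer="bert-base-uncased")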
@@ -60,7 +58,7 @@ def encode(self, documents, batch=32):

         # Sort document indices from largest to smallest to enable efficient batching
         # This performance tweak matches logic in sentence-transformers
-        lengths = np.argsort([-len(x) for x in documents])
+        lengths = np.argsort([-len(x) if x else 0 for x in documents])
         documents = [documents[x] for x in lengths]

         for chunk in self.chunk(documents, batch):
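The encode change guards the batching sort against empty or None documents; previously -len(x) raised a TypeError for None entries. A standalone illustration with NumPy:

import numpy as np

documents = ["a longer document", None, "short"]

# Old expression fails: object of type 'NoneType' has no len()
# lengths = np.argsort([-len(x) for x in documents])

# New expression sorts falsy entries as zero-length, longest first
lengths = np.argsort([-len(x) if x else 0 for x in documents])
print(lengths)  # [0 2 1]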
