Skip to content

Commit

Permalink
fix diarise_audio params typo
Browse files Browse the repository at this point in the history
  • Loading branch information
chenxwh committed Jan 31, 2024
1 parent e499ed6 commit 1cd3472
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions predict.py
Expand Up @@ -16,22 +16,20 @@ class Predictor(BasePredictor):
def setup(self):
"""Loads whisper models into memory to make running multiple predictions efficient"""
self.model_cache = "model_cache"
local_files_only = True # set to true after the model is cached to model_cache
model_id = "openai/whisper-large-v3"
torch_dtype = torch.float16
self.device = "cuda:0"
model = WhisperForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch_dtype,
cache_dir=self.model_cache,
local_files_only=local_files_only,
).to(self.device)

tokenizer = WhisperTokenizerFast.from_pretrained(
model_id, cache_dir=self.model_cache, local_files_only=local_files_only
model_id, cache_dir=self.model_cache
)
feature_extractor = WhisperFeatureExtractor.from_pretrained(
model_id, cache_dir=self.model_cache, local_files_only=local_files_only
model_id, cache_dir=self.model_cache
)

self.pipe = pipeline(
Expand Down Expand Up @@ -90,7 +88,7 @@ def predict(
return_timestamps="word" if timestamp == "word" else True,
)

if diarize_audio:
if diarise_audio:
if self.diarization_pipeline is None:
try:
self.diarization_pipeline = Pipeline.from_pretrained(
Expand Down

0 comments on commit 1cd3472

Please sign in to comment.