In [1]:
from sklearn.datasets import fetch_20newsgroups
from transformers import pipeline
import time
from datetime import timedelta
import psutil
import ray

# Load 20newsgroups news articles that that belong to “rec.motorcycles” and “rec.sport.baseball” classes of the test set only
test_data = fetch_20newsgroups(subset='test', shuffle=False, categories=['rec.motorcycles', 'rec.sport.baseball'], remove=('headers', 'footers', 'quotes'))
# Remove empty news article texts
test_data = [text for text in test_data.data if text!='']
print('Number of text articles:', len(test_data))

"""
HuggingFace pipelines are objects that abstract most of the complex code from the library, 
offering a simple API dedicated to several tasks, including text classification.
All pipelines can use batching.
However, this is not automatically a win for performance. 
It can be either a 10x speedup or 5x slowdown depending on hardware, data and the actual model being used.
Batching is only recommended on GPU. 
If you are using CPU, don’t batch.
"""
# Init pipeline with batchsize 1 on CPU for our example
pipe = pipeline(task = 'zero-shot-classification', 
                model='typeform/distilbert-base-uncased-mnli', 
                batch_size=8, 
                device='mps')


2023-12-30 23:20:03,091	INFO util.py:159 -- Outdated packages:
  ipywidgets==7.7.2 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Number of text articles: 777


The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


In [2]:
# Predict single text
prediction = pipe(test_data[100], ['motorcycle', 'baseball'])
print('Text:', prediction['sequence'])
print('Labels:', prediction['labels'])
print('Scores:', prediction['scores'])

Text: Hey, the Lone Biker of the Apocalypse (see Raising Arizona) had flames coming
out of both his exhaust pipes. I love to toggle the kill switch on my Sportster
to produce flaming backfires, especially underneath overpasses at night (it's
loud and lights up the whole underpass!!!
Labels: ['motorcycle', 'baseball']
Scores: [0.9970590472221375, 0.002940941136330366]


In [3]:

# Predict multipe texts on single CPU and time the inference duration
start = time.time()

predictions = [pipe(text, ['motorcycle', 'baseball']) for text in test_data]
end = time.time()
print('Prediction time:', str(timedelta(seconds=end-start)))

Prediction time: 0:00:52.313775


In [5]:
num_cpus = psutil.cpu_count(logical=True)
print('Number of available CPUs:', num_cpus)

# Start Ray cluster
num_cpus = num_cpus
#ray.init(num_cpus=num_cpus, ignore_reinit_error=True)

"""
The command ray.put(x) would be run by a worker process or by the driver process (the driver process is the one running your script). 
It takes a Python object and copies it to the local object store (here local means on the same node). 
Once the object has been stored in the object store, its value cannot be changed.
In addition, ray.put(x) returns an object ID, which is essentially an ID that can be used to refer to the newly created remote object. 
If we save the object ID in a variable with x_id = ray.put(x), 
then we can pass x_id into remote functions, 
and those remote functions will operate on the corresponding remote object.
"""
pipe_id = ray.put(pipe)

# @ray.remote decorator enables to use this 
# function in distributed setting
@ray.remote
def predict(pipeline, text_data, label_names):
    return pipeline(text_data, label_names)

# Predict multipe texts on all available CPUs and time the inference duration
start = time.time()

# Run the function using multiple cores and gather the results
predictions = ray.get([predict.remote(pipe_id, text, ['motorcycle', 'baseball']) for text in test_data])

end = time.time()
print('Prediction time:', str(timedelta(seconds=end-start)))

# Stop running Ray cluster
#ray.shutdown()

Number of available CPUs: 12


2023-12-30 23:22:18,171	INFO worker.py:1612 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


Prediction time: 0:00:28.430376


In [6]:
ray.shutdown()