In [61]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer

In [75]:
# Newsgroup identifiers (keys) paired with the short display label (values)
# used when reporting predictions
category_map = {
    'talk.politics.misc': 'Politics',
    'rec.autos': 'Autos',
    'rec.sport.hockey': 'Hockey',
    'sci.electronics': 'Electronics',
    'sci.med': 'Medicine',
}

# Load the training split of 20 Newsgroups via fetch_20newsgroups, restricted
# to the mapped newsgroups; shuffling uses a fixed random_state
training_data = fetch_20newsgroups(
    subset='train',
    categories=category_map.keys(),
    shuffle=True,
    random_state=5,
)

In [76]:
# Fit a CountVectorizer on the training documents and convert them into a
# matrix of term counts
count_vectorizer = CountVectorizer()
train_documents = training_data.data
train_tc = count_vectorizer.fit_transform(train_documents)

# Report the shape of the resulting term-count matrix
print("\nDimensions of training data:", train_tc.shape)


Dimensions of training data: (2844, 40321)


In [78]:
# Fit a tf-idf transformer on the training term counts and apply it to them
tfidf = TfidfTransformer()
train_tfidf = tfidf.fit_transform(train_tc)

# Hand-written sentences that the trained classifier will be asked to label
input_data = [
    'You need to be careful with cars when you are driving on slippery roads',
    'A lot of devices can be operated wirelessly',
    'Players need to be careful when they are close to goal posts',
]

In [79]:
# Fit a multinomial Naive Bayes model on the tf-idf features and their targets
nb_model = MultinomialNB()
classifier = nb_model.fit(train_tfidf, training_data.target)

# Convert the input sentences to term counts with the already-fitted count
# vectorizer (transform only, no refitting)
input_tc = count_vectorizer.transform(input_data)

# Apply the already-fitted tf-idf transformer to those term counts
input_tfidf = tfidf.transform(input_tc)

In [80]:
# Predict one category per input sentence from its tf-idf features; each
# predicted value is later used as an index into training_data.target_names
prediction = classifier.predict(input_tfidf)

In [81]:
# Report each input sentence together with the display label of its
# predicted newsgroup
for sentence, label_index in zip(input_data, prediction):
    # Predicted index -> newsgroup identifier -> display label
    newsgroup = training_data.target_names[label_index]
    print('\nInput:', sentence, '\nPredicted category:', category_map[newsgroup])


Input: You need to be careful with cars when you are driving on slippery roads 
Predicted category: Autos

Input: A lot of devices can be operated wirelessly 
Predicted category: Electronics

Input: Players need to be careful when they are close to goal posts 
Predicted category: Hockey
