In [1]:
# Reference from https://towardsdatascience.com/journey-to-the-center-of-multi-label-classification-384c40229bff
import pandas as pd
import re
from nltk.stem.porter import PorterStemmer
from sklearn.model_selection import train_test_split
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer

In [2]:
def tokenize(text):
    text = re.sub('[^a-zA-Z ]', '', text)
    text = re.sub('\s+', ' ', text).strip().lower()
    return text

def remove_stopwords(text):
    stop_words = stopwords.words('english')
    re_stop_words = re.compile(r'\b(' + '|'.join(stop_words) + ')\s')
    return re_stop_words.sub('', text)

def stem_text(text):
    words = text.split(' ')
    stemmer = PorterStemmer()
    return ' '.join([stemmer.stem(word) for word in words])

In [3]:
reuters_df = pd.read_csv('reuters_new.csv')

In [4]:
reuters_df['text'] = reuters_df['text'].apply(tokenize).apply(remove_stopwords).apply(stem_text)

In [43]:
X_train, X_test, y_train, y_test = train_test_split(reuters_df['text'], reuters_df.iloc[:, 2:], test_size = 0.3)

In [44]:
# vectorize data
vectorizer = TfidfVectorizer(analyzer='word')

xs = {'train': [], 'test': []}
xs['train'] = vectorizer.fit_transform(X_train).toarray()
xs['test'] = vectorizer.transform(X_test).toarray()

ys = {'train': [], 'test': []}
ys['train'] = y_train.values
ys['test'] = y_test.values

  if hasattr(X, 'dtype') and np.issubdtype(X.dtype, np.float):


In [45]:
categories = ['corn', 'cotton', 'rice', 'soybean', 'wheat']

In [46]:
print('Total')
print(len(reuters_df))
print('Training dataset')
print(len(X_train))
print('Testing dataset')
print(len(X_test))
len(vectorizer.get_feature_names())

Total
560
Training dataset
392
Testing dataset
168


3827

In [47]:
print('Training dataset')
print(y_train.sum())
print('\nTesting dataset')
print(y_test.sum())

y_train['sum'] = y_train.sum(axis = 1)
print('Training')
print(y_train['sum'].value_counts())
y_test['sum'] = y_test.sum(axis = 1)
print('\nTesting')
print(y_test['sum'].value_counts())

Training dataset
corn       149
cotton      43
rice        43
soybean     77
wheat      209
dtype: int64

Testing dataset
corn       74
cotton     19
rice       24
soybean    34
wheat      78
dtype: int64
Training
1    298
2     66
3     23
4      3
5      2
Name: sum, dtype: int64

Testing
1    132
2     18
3     13
4      3
5      2
Name: sum, dtype: int64


In [48]:
# https://pystruct.github.io/auto_examples/multi_label.html#sphx-glr-auto-examples-multi-label-py
import itertools
import numpy as np
import pandas as pd
from sklearn.metrics import hamming_loss
from pystruct.learners import OneSlackSSVM
from pystruct.models import MultiLabelClf
from pystruct.datasets import load_scene

In [49]:
n_labels = len(categories)
print(n_labels)
full = np.vstack([x for x in itertools.combinations(range(n_labels), 2)])

5


In [54]:
full_model = MultiLabelClf(edges=full, inference_method='qpbo')
full_model.n_labels

In [55]:
full_ssvm = OneSlackSSVM(full_model, inference_cache=50, C=.1, tol=0.01)

In [56]:
print('Fitting full model...')
full_ssvm.fit(xs['train'], ys['train'])

Fitting full model...


OneSlackSSVM(C=0.1, break_on_bad=False, cache_tol='auto',
       check_constraints=False, inactive_threshold=1e-05,
       inactive_window=50, inference_cache=50, logger=None, max_iter=10000,
       model=MultiLabelClf(n_states: 2, inference_method: qpbo), n_jobs=1,
       negativity_constraint=None, show_loss_every=0, switch_to=None,
       tol=0.01, verbose=0)

In [57]:
print("Training loss full model: %f"
      % hamming_loss(ys['train'], np.vstack(full_ssvm.predict(xs['train']))))
print("Test loss full model: %f"
      % hamming_loss(ys['test'], np.vstack(full_ssvm.predict(xs['test']))))

Training loss full model: 0.067347
Test loss full model: 0.134524


In [58]:
predictions = full_ssvm.predict(xs['test'])

In [59]:
def display(predictions):
    to_display = []
    for (actual, prediction) in zip(ys['test'], predictions):
        actual_cats = []
        predicted_cats = []
        for (idx, category) in enumerate(actual):
            if category != 0:
                actual_cats.append(categories[idx])

        for (idx, category) in enumerate(prediction):
            if category != 0:
                predicted_cats.append(categories[idx])
        to_display.append((actual_cats, predicted_cats))
    return to_display

In [72]:
to_display = display(predictions)
to_display

[(['corn'], ['corn']),
 (['soybean'], ['soybean']),
 (['corn'], ['corn']),
 (['corn', 'wheat'], ['corn']),
 (['soybean'], []),
 (['rice'], ['rice']),
 (['corn'], ['wheat']),
 (['corn', 'wheat'], ['corn']),
 (['rice'], []),
 (['wheat'], ['wheat']),
 (['corn', 'cotton', 'rice', 'wheat'], ['corn']),
 (['corn'], ['corn']),
 (['wheat'], ['wheat']),
 (['corn'], ['corn']),
 (['corn'], ['corn']),
 (['rice'], []),
 (['soybean'], ['soybean']),
 (['wheat'], ['wheat']),
 (['corn', 'soybean', 'wheat'], ['corn']),
 (['corn'], ['corn']),
 (['corn', 'soybean'], ['corn']),
 (['wheat'], ['wheat']),
 (['rice', 'wheat'], ['wheat']),
 (['wheat'], ['wheat']),
 (['cotton'], []),
 (['wheat'], ['wheat']),
 (['corn'], ['corn']),
 (['corn', 'cotton', 'wheat'], ['wheat']),
 (['corn'], ['corn']),
 (['corn'], ['corn']),
 (['corn'], ['corn']),
 (['corn'], ['wheat']),
 (['cotton'], []),
 (['corn'], ['corn']),
 (['wheat'], ['wheat']),
 (['cotton'], []),
 (['wheat'], ['wheat']),
 (['corn'], []),
 (['corn'], []),
 (['co

In [73]:
fobj = open('example_text.txt', mode = 'r')
data = fobj.read()
fobj.close()
example_text = stem_text(remove_stopwords(tokenize(data)))
values = vectorizer.transform([example_text]).toarray()

  if hasattr(X, 'dtype') and np.issubdtype(X.dtype, np.float):


In [86]:
example_prediction = full_ssvm.predict(values)
example_prediction_values = []
for (idx, category) in enumerate(example_prediction[0]):
    if category != 0:
        example_prediction_values.append(categories[idx])
print(example_prediction_values)


['corn']


In [87]:
values_list = list(values[0])

In [88]:
list_val = [val for val in values_list if val != 0]
str_val = ''
for val in list_val:
    str_val = '{}, {:.4f}'.format(str_val, val)
str_val

', 0.0517, 0.0757, 0.0414, 0.0510, 0.0594, 0.1003, 0.1101, 0.0717, 0.0939, 0.0763, 0.0440, 0.0717, 0.0560, 0.1003, 0.0661, 0.1003, 0.2115, 0.2113, 0.0450, 0.2532, 0.0429, 0.0731, 0.0360, 0.0560, 0.0828, 0.1564, 0.0606, 0.1385, 0.1763, 0.0329, 0.0893, 0.0803, 0.0489, 0.0893, 0.0460, 0.0871, 0.0844, 0.1418, 0.0717, 0.1066, 0.0565, 0.0546, 0.0782, 0.0692, 0.0893, 0.0546, 0.0880, 0.0620, 0.1003, 0.0767, 0.0463, 0.0420, 0.0613, 0.0628, 0.1984, 0.2007, 0.0420, 0.0782, 0.0582, 0.0652, 0.0746, 0.0537, 0.0560, 0.1003, 0.1003, 0.0939, 0.0696, 0.0542, 0.0455, 0.0893, 0.1026, 0.0661, 0.0565, 0.1464, 0.0161, 0.0365, 0.1723, 0.0746, 0.0939, 0.0893, 0.0576, 0.0782, 0.0939, 0.1241, 0.0704, 0.0939, 0.0857, 0.1588, 0.1003, 0.0985, 0.0717, 0.1877, 0.0764, 0.0803, 0.0801, 0.1489, 0.0571, 0.1003, 0.1435, 0.1072, 0.0828, 0.0798, 0.0652, 0.0972, 0.0904'