In [1]:
import pandas as pd
import numpy as np

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, hamming_loss, f1_score

from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier

from skmultilearn.problem_transform import BinaryRelevance

In [2]:
pd.set_option('display.max_colwidth', -1)

In [3]:
df = pd.read_csv('resources/civil_court_orders_to_classifier.csv.gz', compression='gzip') 

## Data preparation

In [4]:
df.loc[:, ['vectors', 'pca', 'money', 'descr_articles']] = df.loc[:, ['vectors', 'pca', 'money', 'descr_articles']].applymap(eval)

In [5]:
vectors = df['vectors'].tolist()
articles = df['descr_articles'].tolist()

In [6]:
mlb = MultiLabelBinarizer()
binarized_articles = mlb.fit_transform(articles)

In [7]:
print(f"Number of unique articles in the dataset = {mlb.classes_.shape[0]}")

Number of unique articles in the dataset = 876


In [8]:
mlb.classes_

array(['1', '1.2', '1.3', '1.4', '1.5', '1.7', '10', '10.6', '100',
       '1005', '101', '101.4', '102', '103', '103.', '104', '104.1',
       '1041', '105', '1050', '106', '1064', '1065', '1068', '1069',
       '107', '1070', '1071', '1072', '1073', '1074', '1079', '108',
       '1080', '1081', '1082', '1083', '1084', '1085', '1086', '1088',
       '1089', '109', '109.1', '1091', '1094', '1095', '1096', '1098',
       '1099', '11', '11.1', '11.10', '11.2', '11.3', '11.5', '11.8',
       '11.9', '110', '1100', '1101', '1102', '1103', '1104', '1105',
       '1107', '1109', '111', '1110', '1111', '1112', '1113', '1114',
       '1115', '1116', '1117', '1118', '1119', '112', '1120', '1124',
       '1125', '1126', '1127', '1128', '1129', '113', '1130', '1131',
       '1132', '114', '1141', '1142', '1143', '1144', '1146', '1148',
       '1149', '115', '1150', '1151', '1152', '1153', '1154', '1155',
       '1156', '1157', '1158', '1159', '116', '1161', '1162', '1165',
       '1168', '1169', 

## Multi-label classification

In [9]:
x, y = np.array(vectors), binarized_articles

In [10]:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=88, shuffle=True)

In [11]:
x_train.shape, x_test.shape

((5963, 100), (2556, 100))

In [12]:
y_train.shape, y_test.shape

((5963, 876), (2556, 876))

In [13]:
%%time

classifier = BinaryRelevance(GaussianNB())
classifier.fit(x_train, y_train)
predictions = classifier.predict(x_test)

Wall time: 20.8 s


In [14]:
print(f"F1-score = {f1_score(y_test, predictions, average='micro'):.4f}")

F1-score = 0.4078


In [16]:
results = pd.DataFrame()
results['True'] = pd.Series(mlb.inverse_transform(y_test))
results['Predicted'] = pd.Series(mlb.inverse_transform(predictions))
results['True subset'] = results.apply(lambda row: set(row['Predicted']) & set(row['True']), axis=1)
results.head(15)

Unnamed: 0,True,Predicted,True subset
0,"(131, 164, 165, 167, 17, 19, 25, 29, 35, 422, 432, 549, 550, 551, 554, 558)","(131, 165, 194, 209, 218, 35, 352, 50, 549, 551)","{165, 549, 551, 131, 35}"
1,"(1, 12.1, 129, 130, 14.1, 15, 194, 210, 284, 333.36, 35, 79)","(1, 1151, 1152, 12.1, 130, 14.1, 18, 284, 79)","{12.1, 284, 14.1, 1, 79, 130}"
2,"(167, 194, 198, 39)","(233, 309, 310, 330, 333, 333.19, 450, 809, 810, 811, 819, 88, 98)",{}
3,"(1, 103, 12, 16, 194, 3, 39, 4, 46, 5, 67, 7)","(1, 10, 100, 103, 11, 12, 123, 13, 15, 16, 167, 17, 18, 194, 3, 4, 41, 454, 46, 469, 470, 5, 67, 7, 8, 9)","{4, 1, 103, 5, 16, 46, 194, 12, 67, 3, 7}"
4,"(1, 11, 14, 19, 194, 39, 56, 6, 7)","(11, 173, 19, 194, 198, 2, 212, 28, 29, 34, 39, 42, 7, 8)","{11, 194, 19, 39, 7}"
5,"(12, 123, 194, 2, 209, 212, 28, 57)","(103, 11, 12, 151, 194, 2, 212, 213, 3, 34, 39, 7, 8)","{2, 12, 194, 212}"
6,"(309, 310, 333.19, 434, 438, 56, 807, 809, 810, 811, 819, 98)","(113, 117, 233, 309, 310, 319, 329, 330, 331, 333, 395, 420, 421, 422, 428, 432, 433, 434, 435, 438, 450, 56, 6.1, 67, 807, 809, 810, 811, 819, 820, 88, 96, 98)","{98, 434, 309, 811, 438, 810, 310, 807, 809, 819, 56}"
7,"(333, 810, 811, 819, 98)","(113, 123, 160, 167, 194, 233, 309, 310, 421, 432, 434, 56, 809, 810, 811, 819, 98)","{819, 810, 811, 98}"
8,"(173, 220, 264, 39)","(1099, 1101, 134, 135, 151, 193, 194, 220, 237, 24, 333.36, 39, 51, 54, 6, 61, 77)","{39, 220}"
9,"(10, 100, 103, 12, 13, 15, 167, 18, 194, 206, 333, 45, 56, 57, 67, 8, 88, 94)","(1, 10, 100, 103, 12, 13, 15, 150, 167, 18, 194, 206, 333, 333.19, 41, 67, 8, 88, 94)","{13, 8, 10, 103, 167, 100, 333, 94, 88, 194, 12, 18, 67, 15, 206}"
