In [56]:
import os
import re
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import vowpalwabbit

**Loading the [20 Newsgroups](https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html) dataset for a 20-class classification problem**

In [63]:
# Create the output directory with the standard library instead of `!mkdir -p`,
# so the cell does not depend on a POSIX shell; exist_ok keeps re-runs harmless.
os.makedirs("RL", exist_ok=True)

newsgroups = fetch_20newsgroups()

all_documents = newsgroups["data"]
# Encode the targets and shift them by one so the smallest label is 1 rather than 0;
# the VW training log further down reports the labels as 1-indexed.
topic_encoder = LabelEncoder()
all_targets_mult = topic_encoder.fit_transform(newsgroups["target"]) + 1

**Converting Text to VW format**

In [64]:
def to_vw_format(document, label=None):
    """Render one document as a single line of Vowpal Wabbit input.

    The line has the form ``<label> |text <token> <token> ...\\n``.

    Args:
        document: Raw text of the document.
        label: Label written in front of the ``|text`` namespace. ``None``
            (the default) writes no label, producing an unlabeled example.
            Any other value, including ``0``, is written via ``str``.

    Returns:
        The formatted line, terminated by a newline.
    """
    # Compare against None explicitly: `label or ""` would also drop a label of 0.
    label_prefix = "" if label is None else str(label)
    # Lower-case the text and keep runs of three or more word characters; characters
    # outside \w (such as ':' and '|') therefore never end up inside a token.
    tokens = re.findall(r"\w{3,}", document.lower())
    return label_prefix + " |text " + " ".join(tokens) + "\n"

# Preview the VW line produced for a single document.
# NOTE(review): `text` and `target` are not assigned by any earlier cell of this notebook;
# they first appear as loop variables in the cell that writes the train/test files further
# down, so this call only works with leftover kernel state. In that loop `target` is a
# numeric label, so comparing it with "rec.autos" would always give -1 -- confirm which
# sample and label were intended here.
to_vw_format(text, 1 if target == "rec.autos" else -1)

'-1 |text subject win storm from srini shannon tisl ukans edu srini seetharam reply srini shannon tisl ukans edu srini seetharam distribution world organization elec comp eng univ kansas nntp posting host erlang tisl ukans edu originator srini erlang lines anyone have any info the video sound card from sigma designs called win storm they also have another card called the legend 24lx any info would appreciated incuding performance pricing and availability thanks srini\n'

In [65]:
# Directory, relative to the working directory, used for the VW data and prediction files.
PATH_TO_WRITE_DATA = "./RL"

**Make the train/test split and convert all the data to VW format**

In [66]:
# Split the first 10,000 documents into train and test parts; random_state pins the shuffle.
train_documents, test_documents, train_labels, test_labels = train_test_split(
    all_documents[:10000], all_targets_mult[:10000], random_state=7
)

# Open both files as UTF-8 explicitly: \w in to_vw_format can keep non-ASCII word
# characters, and the platform default encoding is not the same on every system.
with open(
    os.path.join(PATH_TO_WRITE_DATA, "20news_train_mult.vw"), "w", encoding="utf-8"
) as vw_train_data:
    # Training lines carry their label.
    for text, target in zip(train_documents, train_labels):
        vw_train_data.write(to_vw_format(text, target))
with open(
    os.path.join(PATH_TO_WRITE_DATA, "20news_test_mult.vw"), "w", encoding="utf-8"
) as vw_test_data:
    # Test lines are written without a label; test_labels is kept for scoring later.
    for text in test_documents:
        vw_test_data.write(to_vw_format(text))

**Training**

In [67]:
# Train a one-against-all multiclass model over 20 classes (--oaa 20) with hinge loss on the
# training file and save the resulting regressor with -f.
# (An earlier binary variant of this command trained on RL/20news_train.vw without --oaa.)
!vw --oaa 20 ./RL/20news_train_mult.vw -f ./RL/20news_model_mult.vw --loss_function=hinge

final_regressor = ./RL/20news_model_mult.vw
using no cache
Reading datafile = ./RL/20news_train_mult.vw
num sources = 1
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
Enabled learners: gd, scorer-identity, oaa
Input label = MULTICLASS
Output pred = MULTICLASS
average  since         example        example        current        current  current
loss     last          counter         weight          label        predict features
[32m[info][m label 20 found -- labels are now considered 1-indexed.
1.000000 1.000000            1            1.0             20              1     2251
1.000000 1.000000            2            2.0             13             20       43
1.000000 1.000000            4            4.0             18             15      158
1.000000 1.000000            8            8.0              2              6       95
0.875000 0.750000           16           16.0              8             10      234
0.843750 0.812500           32           32.0        

**Inference**

In [68]:
# Score the unlabeled test file with the saved model: -i names the model to load, -t runs in
# test-only mode, -d names the input data and -p the file that receives the predictions.
!vw -i RL/20news_model_mult.vw -t -d RL/20news_test_mult.vw -p RL/20news_test_predictions_mult.txt

only testing
predictions = RL/20news_test_predictions_mult.txt
using no cache
Reading datafile = RL/20news_test_mult.vw
num sources = 1
Num weight bits = 18
learning rate = 0.5
initial_t = 7500
power_t = 0.5
Enabled learners: gd, scorer-identity, oaa
Input label = MULTICLASS
Output pred = MULTICLASS
average  since         example        example        current        current  current
loss     last          counter         weight          label        predict features
n.a.     n.a.                1            1.0        unknown              2      124
n.a.     n.a.                2            2.0        unknown              5      159
n.a.     n.a.                4            4.0        unknown             12      184
n.a.     n.a.                8            8.0        unknown             18      199
n.a.     n.a.               16           16.0        unknown              5      176
n.a.     n.a.               32           32.0        unknown             16       36
n.a.     n.a.      

**Results**

In [69]:
# Read the prediction file back, converting each line to a float.
predictions_path = os.path.join(PATH_TO_WRITE_DATA, "20news_test_predictions_mult.txt")
test_prediction_mult = []
with open(predictions_path) as pred_file:
    for line in pred_file:
        test_prediction_mult.append(float(line))

In [70]:
# Share of test documents whose predicted label equals the true label.
test_accuracy = accuracy_score(test_labels, test_prediction_mult)
print(test_accuracy)

0.8652


In [71]:
# Row 0 of the confusion matrix holds the test documents whose true class is the first one;
# list every *other* class that received at least one of them.
# Assumes every class occurs in test_labels or the predictions, so that matrix index i lines
# up with newsgroups["target_names"][i] -- TODO confirm if the data subset changes.
M = confusion_matrix(test_labels, test_prediction_mult)
for i in np.flatnonzero(M[0, :] > 0):
    # Skip the diagonal explicitly. Dropping the first non-zero column instead would hide
    # a real confusion whenever the first class has no correct predictions at all.
    if i == 0:
        continue
    print(newsgroups["target_names"][i], M[0, i])

comp.os.ms-windows.misc 1
comp.windows.x 1
misc.forsale 1
sci.med 1
soc.religion.christian 2
talk.religion.misc 6
