# Bayesian Regularized Iterative Soft Thresholding Algorithm

### How to Run

The BARISTA algorithm is a class-specific, attribute-weighted Naive Bayes framework designed to mitigate overfitting and to relax the conditional independence assumption of Naive Bayes. This brief tutorial shows how to run the model.
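The sketch below makes "class-specific attribute weighting" concrete. It assumes the common formulation in which each class $c$ has its own weight $w_{c,i}$ for attribute $i$, so that $P(c \mid x) \propto P(c)\prod_i P(x_i \mid c)^{w_{c,i}}$. That formulation gives one row of weights per class and one column per attribute, matching the shape of the weight matrix printed during fitting. The exact objective BARISTA optimizes is described in the paper, not here. `fit_counts` and `weighted_posterior` are hypothetical helper names, and the data are assumed to be integer-coded discrete attributes with Laplace smoothing.

```python
import numpy as np

# Hypothetical helpers illustrating class-specific attribute weighting;
# they are NOT BARISTA's implementation.


def fit_counts(X, y, n_values, n_classes, alpha=1.0):
    """Laplace-smoothed log P(c) and log P(x_i = v | c) for integer-coded data.

    Assumes every attribute takes values in 0..n_values-1.
    """
    n_samples, n_attrs = X.shape
    class_counts = np.bincount(y, minlength=n_classes)
    log_prior = np.log((class_counts + alpha) / (n_samples + alpha * n_classes))
    # log_cond[c, i, v] = log P(x_i = v | c)
    log_cond = np.empty((n_classes, n_attrs, n_values))
    for c in range(n_classes):
        Xc = X[y == c]
        for i in range(n_attrs):
            counts = np.bincount(Xc[:, i], minlength=n_values)
            log_cond[c, i] = np.log((counts + alpha) / (len(Xc) + alpha * n_values))
    return log_prior, log_cond


def weighted_posterior(x, W, log_prior, log_cond):
    """P(c | x) proportional to P(c) * prod_i P(x_i | c) ** W[c, i]."""
    n_classes, n_attrs = W.shape
    scores = log_prior.copy()
    for c in range(n_classes):
        # Each class has its own weight for each attribute's log-likelihood.
        scores[c] += np.sum(W[c] * log_cond[c, np.arange(n_attrs), x])
    scores -= scores.max()  # shift for numerical stability before exp
    p = np.exp(scores)
    return p / p.sum()


# Toy usage on random discrete data (not the breast_w dataset).
rng = np.random.default_rng(0)
X = rng.integers(0, 3, size=(100, 4))
y = (X[:, 0] + X[:, 1] > 2).astype(int)
log_prior, log_cond = fit_counts(X, y, n_values=3, n_classes=2)
W = np.ones((2, 4))  # all-ones weights reduce to standard Naive Bayes
print(weighted_posterior(X[0], W, log_prior, log_cond))
```

Setting every weight to 1 recovers standard Naive Bayes. Setting a weight to 0 removes that attribute's influence for that class.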

**Data Importing**

Load your data from a CSV file into a pandas DataFrame.

```python
import pandas as pd

# Replace the placeholder path with the location of your CSV file.
breast_w = pd.read_csv('/filepath/breast_w.csv')
```

**Pre-processing**

Missing values are imputed with the mean or the most frequent value, depending on the attribute type. Numerical attributes are discretized using the minimum description length (MDL) discretization technique (see our paper for details). Pass in the dataset, the name of the target attribute column as a string, and a list of the numerical attribute column names (in this case there are none). Next, use the `get_data` method to obtain a design matrix $X$ and a vector of labels $y$. Finally, split the data into a training set and a testing set to evaluate generalization performance.

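```python
import preprocess
from sklearn.model_selection import train_test_split

# Impute missing values and discretize numerical attributes.
# Arguments: the DataFrame, the target column name, and the numerical columns (none here).
prep = preprocess.Preprocess(breast_w, "Class", [])
X, y = prep.get_data()

# Hold out 25% of the data for testing, then reset indices so both splits start at 0.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
X_train = X_train.reset_index(drop=True)
y_train = y_train.reset_index(drop=True)
X_test = X_test.reset_index(drop=True)
y_test = y_test.reset_index(drop=True)
```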
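The internals of the `preprocess` module are not shown in this tutorial. The following is a rough, hypothetical sketch of the two preprocessing steps described above. It imputes each column with its mean or mode. It then finds cut points with the recursive entropy split and MDL stopping criterion of Fayyad and Irani, which is what "MDL discretization" commonly refers to. The module may differ in details such as candidate cut points or tie handling. `impute` and `mdl_cut_points` are illustrative names only.

```python
import numpy as np
import pandas as pd

# Hypothetical sketch of the preprocessing steps; the `preprocess` module may differ.


def impute(df, numerical):
    """Fill missing values: column mean for numerical columns, mode otherwise."""
    df = df.copy()
    for col in df.columns:
        if col in numerical:
            df[col] = df[col].fillna(df[col].mean())
        else:
            df[col] = df[col].fillna(df[col].mode().iloc[0])
    return df


def entropy(labels):
    """Shannon entropy (bits) of a label array."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))


def mdl_cut_points(values, labels):
    """Recursive Fayyad-Irani entropy split with the MDL stopping rule."""
    order = np.argsort(values, kind="mergesort")
    v, y = np.asarray(values)[order], np.asarray(labels)[order]
    n = len(v)
    best = None
    # Try every boundary between distinct adjacent values (simple O(n^2) scan).
    for j in range(1, n):
        if v[j] == v[j - 1]:
            continue
        e = (j * entropy(y[:j]) + (n - j) * entropy(y[j:])) / n
        if best is None or e < best[0]:
            best = (e, j)
    if best is None:
        return []
    e, j = best
    ent = entropy(y)
    gain = ent - e
    k, k1, k2 = (len(np.unique(s)) for s in (y, y[:j], y[j:]))
    delta = np.log2(3 ** k - 2) - (k * ent - k1 * entropy(y[:j]) - k2 * entropy(y[j:]))
    if gain <= (np.log2(n - 1) + delta) / n:
        return []  # MDL criterion rejects the split: stop recursing
    cut = float((v[j - 1] + v[j]) / 2)
    return mdl_cut_points(v[:j], y[:j]) + [cut] + mdl_cut_points(v[j:], y[j:])


# Toy usage: two well-separated clusters, one per class.
values = np.array([1.0, 1.2, 1.4, 1.6, 5.0, 5.2, 5.4, 5.6] * 3)
labels = np.array([0, 0, 0, 0, 1, 1, 1, 1] * 3)
cuts = mdl_cut_points(values, labels)
print(cuts)                            # one cut point, midway between 1.6 and 5.0
print(np.digitize(values, cuts)[:8])   # [0 0 0 0 1 1 1 1]
```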

**Fitting the Data**

We can now fit the BARISTA model to the training data. Calling `fit` runs the optimization procedure, which prints the loss, weight matrix, and gradient at each iteration (the output below is abridged). For parameter details, see the README file.

```python
import BARISTA

barista = BARISTA.BARISTA()
# Fit with the FISTA optimization scheme and both L1 and L2 penalties.
barista.fit(training_samples=X_train, training_labels=y_train, scheme='FISTA',
            learning_rate=0.1, convergence_constant=1e-6, max_iterations=5000,
            l1_penalty=0.01, l2_penalty=0.001)
```

```text
Computing Model Parameters...
Prior Probability Distribution Computed...
Prior Probability Distribution Computed...
Initial Posterior Probability Distribution Computed...
FISTA Scheme Selected (Default)
__Optimizing__...
Initial Loss: 134.02415441350448
Iteration: 1
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9999999728938206, 2.710617945104171e-08]

Weight Matrix: [[0.9940581  0.9901933  0.99001009 0.99130065 0.99034259 0.99393587
  0.992734   0.98944736 0.99731843]
 [1.00329713 1.00258097 1.00213735 1.00354258 1.00227644 1.003067
  1.00183815 1.00395688 1.00032793]]

Gradient Norm: 0.24483661379631588

Gradient Matrix: [[ 0.04941899  0.08806701  0.08989906  0.07699352  0.08657407  0.05064131
   0.06266002  0.09552644  0.01681571]
 [-0.04297135 -0.03580969 -0.03137351 -0.04542576 -0.03276444 -0.04066996
  -0.02838146 -0.04956877 -0.01327929]]

Model Loss: 130.80171125513738

Iteration: 2
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_P
[... output truncated ...]

Iteration: 15
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9999999993061297, 6.938703111165525e-10]

Weight Matrix: [[0.8642286  0.70444428 0.70011271 0.72500988 0.7014746  0.8731658
  0.79516356 0.67558119 0.90975962]
 [1.06470455 1.02166544 1.00215741 1.07283241 1.02677844 1.03017257
  1.00144266 1.09167247 0.98960822]]

Gradient Norm: 0.07394602831727796

Gradient Matrix: [[-0.02307642  0.01149921  0.00671243  0.022792    0.01184816 -0.02811156
  -0.00399624  0.01346581  0.00832518]
 [ 0.00866452  0.02654368  0.02723065  0.00715288  0.01792566  0.02471054
   0.02280844  0.00173804  0.00373972]]

Model Loss: 77.46765920468016

Iteration: 16
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9999999992502184, 7.49781565214307e-10]

Weight Matrix: [[0.86286306 0.68559455 0.68246046 0.70364391 0.6822237  0.87383488
  0.78489902 0.6548846  0.90046145]
 [1.06159189 1.01122345 0.990566   1.070
[... output truncated ...]

Iteration: 26
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9999999305421323, 6.945786767322137e-08]

Weight Matrix: [[0.93118451 0.56504436 0.58781321 0.52595971 0.55702255 0.96806667
  0.74912568 0.52071252 0.78832314]
 [0.94525255 0.7948213  0.76769688 0.96511374 0.84787482 0.81431257
  0.79442522 1.01697364 0.90198355]]

Gradient Norm: 0.057590529739459204

Gradient Matrix: [[-0.00053785  0.02188333  0.01951003  0.02180637  0.02124079 -0.00035246
   0.00936126  0.02291017  0.01298461]
 [-0.00024061  0.01181407  0.01471458 -0.00261705  0.00884669  0.01146914
   0.01139827 -0.00667986 -0.00126171]]

Model Loss: 59.646350777120794

Iteration: 27
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9999998914294695, 1.0857053051945414e-07]

Weight Matrix: [[0.93462234 0.54844174 0.57398013 0.50493304 0.54032928 0.97322349
  0.74213461 0.50303786 0.77326169]
 [0.93261333 0.77111802 0.74275265 
[... output truncated ...]

Iteration: 41
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9999627921602859, 3.72078397140117e-05]

Weight Matrix: [[0.87692121 0.27038479 0.32886299 0.21667586 0.26542568 0.90885949
  0.59044088 0.23203937 0.47350497]
 [0.72664646 0.41653435 0.35987423 0.78171484 0.51348886 0.44934003
  0.44314566 0.88629177 0.76031347]]

Gradient Norm: 0.04048803504929215

Gradient Matrix: [[-0.00113701 -0.00298362 -0.00303916 -0.00800549 -0.00220497  0.00451854
  -0.0022555  -0.00700116  0.01291546]
 [ 0.00930944  0.01520167  0.01736483  0.0064334   0.0159649   0.01397081
   0.01339773  0.00439508 -0.00010093]]

Model Loss: 40.02816068712838

Iteration: 42
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9999337821107734, 6.621788922657063e-05]

Weight Matrix: [[0.87049175 0.25499427 0.3152101  0.20321699 0.24977786 0.89937091
  0.58089069 0.21930631 0.44732142]
 [0.70602197 0.38587288 0.32682612 0.76
[... output truncated ...]

Iteration: 55
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9977194501436579, 0.0022805498563419923]

Weight Matrix: [[0.75365968 0.12629184 0.18737593 0.10853165 0.12230118 0.73252216
  0.45785398 0.12552144 0.22819078]
 [0.55260847 0.17306881 0.09487377 0.63577482 0.25700661 0.21822431
  0.21030105 0.76753314 0.64781319]]

Gradient Norm: 0.029517758061989988

Gradient Matrix: [[ 0.00689054 -0.00967913 -0.00616012 -0.01167623 -0.00926983  0.01318252
   0.00132145 -0.01497401  0.00146785]
 [-0.00065081 -0.00319717 -0.00273803 -0.00166926  0.00504026 -0.00380833
  -0.00086412 -0.00214783  0.00135092]]

Model Loss: 34.67377825121244

Iteration: 56
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9975286859285594, 0.0024713140714406104]

Weight Matrix: [[0.7434357  0.12573989 0.18475265 0.10941704 0.12155759 0.71818657
  0.45081279 0.1282713  0.22098142]
 [0.5469541  0.16876353 0.09024017 0.
[... output truncated ...]

Iteration: 67
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9955702727312363, 0.004429727268763665]

Weight Matrix: [[0.64658674 0.13008792 0.1677827  0.11770098 0.12189865 0.59774733
  0.38980564 0.15768125 0.16546724]
 [0.50055487 0.14357894 0.06464117 0.58762511 0.16993362 0.19158831
  0.16367458 0.72164783 0.57246267]]

Gradient Norm: 0.03516352191182359

Gradient Matrix: [[ 0.00495815 -0.01349259 -0.00966358 -0.01146264 -0.01212619  0.00433841
  -0.00216127 -0.01547303 -0.00444064]
 [-0.00391598 -0.00940648 -0.0100691  -0.0036432   0.00084661 -0.00972977
  -0.00620652 -0.0037622   0.00297577]]

Model Loss: 34.32786099323668

Iteration: 68
Learning Rate: 0.1
L1_Penalty Term: 0.01
L2_Penalty Term: 0.001
Posterior Cache First Sample: [0.9954050824977806, 0.0045949175022193655]

Weight Matrix: [[0.63706465 0.13201165 0.16733067 0.11852434 0.12303947 0.58831266
  0.38467116 0.16103865 0.16183965]
 [0.4968998  0.14328617 0.06472613 0.58
[... output truncated ...]
```
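The algorithm's name refers to iterative soft thresholding, and the `fit` call above selects the FISTA scheme. The sketch below is a generic FISTA loop for an objective of the form $f(w) + \lambda_1\lVert w\rVert_1 + \frac{\lambda_2}{2}\lVert w\rVert_2^2$. Here the proximal step is soft thresholding followed by a rescaling for the L2 term. The keyword names mirror the `fit` arguments only for readability. It is an assumption that BARISTA combines its penalties this way, and the smooth loss below is a toy least-squares problem standing in for BARISTA's Naive Bayes objective.

```python
import numpy as np

# Generic FISTA sketch; keyword names mirror BARISTA.fit only for readability.


def soft_threshold(v, tau):
    """Proximal operator of tau * ||v||_1: shrink each entry toward zero by tau."""
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)


def fista(grad_f, w0, learning_rate=0.1, l1_penalty=0.01, l2_penalty=0.001,
          convergence_constant=1e-6, max_iterations=5000):
    """Minimize f(w) + l1*||w||_1 + (l2/2)*||w||_2^2 given the gradient of f."""
    w_prev = np.asarray(w0, dtype=float)
    y = w_prev.copy()  # extrapolated (momentum) point
    t = 1.0
    for _ in range(max_iterations):
        # Gradient step on the smooth part f ...
        v = y - learning_rate * grad_f(y)
        # ... then the proximal step: soft threshold for L1, then shrink for L2.
        w = soft_threshold(v, learning_rate * l1_penalty) / (1.0 + learning_rate * l2_penalty)
        # FISTA momentum update.
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        y = w + ((t - 1.0) / t_next) * (w - w_prev)
        if np.linalg.norm(w - w_prev) < convergence_constant:
            return w
        w_prev, t = w, t_next
    return w_prev


# Toy usage: sparse least squares, f(w) = ||A w - b||^2 / (2n).
rng = np.random.default_rng(0)
A = rng.normal(size=(50, 9))
w_true = np.array([1.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, -0.8, 0.0])
b = A @ w_true


def grad_f(w):
    return A.T @ (A @ w - b) / len(b)


w_hat = fista(grad_f, np.ones(9))  # start from all-ones weights
print(np.round(w_hat, 3))
```

The momentum term is what distinguishes FISTA from plain ISTA. It improves the convergence rate of the objective from $O(1/k)$ to $O(1/k^2)$, provided the step size does not exceed $1/L$, where $L$ is the Lipschitz constant of $\nabla f$.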

**Classification Performance**

Now that the model has been fit, we can classify the test instances. `predict` also computes the accuracy on the test set, which is stored in `barista.accuracy`.

```python
# Classify the test instances; predict also computes the accuracy against y_test.
barista.predict(X_test, y_test)
print("Testing Accuracy:", barista.accuracy)
```

```text
Testing Accuracy: 0.9657142857142857
```
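For reference, accuracy is the fraction of test instances whose predicted class (the class with the highest posterior) matches the true label. The small example below uses made-up posteriors rather than BARISTA output and cross-checks the result against scikit-learn's `accuracy_score`.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Made-up posteriors for four test instances (rows) over two classes (columns).
posteriors = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
y_true = np.array([0, 1, 1, 1])

y_pred = posteriors.argmax(axis=1)       # predicted class = highest posterior
acc = float(np.mean(y_pred == y_true))   # fraction of correct predictions
print(acc, accuracy_score(y_true, y_pred))  # 0.75 0.75
```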
