In [3]:
#!/usr/bin/env python
import wuml


##	We generated synthetic data for regression with 4 dimensions where
##	x1 and x2 have a positive influence
##	x3 has no influence
##	x4 has a negative influence
#
#	The key point of this example is to show that if we use Gaussian-distributed data instead of
#	uniform data, the explanation labels no longer have the correct signs. (The magnitudes still make sense.)
#	This suggests that perhaps we should always map the data to a uniform distribution.

# Load the Gaussian synthetic regression set; the 'label' column is the
# target and is declared as a continuous value.
data = wuml.wData(
	xpath='../../data/shap_regress_example_gaussian.csv',
	batch_size=20,
	label_type='continuous',
	label_column_name='label',
	row_id_with_label=0,
)


# With loss='mse' the explainer creates a regression network and
# explains it instance-wise.
EXP = wuml.explainer(
	data,
	loss='mse',
	networkStructure=[
		(100, 'relu'),
		(100, 'relu'),
		(1, 'none'),
	],
	max_epoch=150,
	learning_rate=0.001,
	print_network_training_status=False,
)


# Report the regression quality: true labels side by side with predictions
Ŷ = EXP.net(data, output_type='ndarray')
SR_train = wuml.summarize_regression_result(data.Y, Ŷ)
print(SR_train.true_vs_predict(print_result=False))


Network Info:
	Learning rate: 0.001
	Max number of epochs: 150
	Cost Function: mse
	Train Loop Callback: None
	Cuda Available: True
	Network Structure
		Linear(in_features=4, out_features=100, bias=True) , relu
		Linear(in_features=100, out_features=100, bias=True) , relu
		Linear(in_features=100, out_features=1, bias=True) , none

Avg error: 0.1383

['y' 'ŷ']
[-18.04 -18.  ]
[  7.69   7.43]
[-16.01 -15.98]
[  1.14   1.34]
[ 12.33  12.29]
[ -8.95  -9.09]
[  8.02   8.04]
[  8.09   8.01]
[ -5.95  -6.13]
[ 11.07  11.2 ]
[-28.4  -28.26]
[-10.19 -10.5 ]
[-11.34 -11.45]
[ -6.92  -6.66]
[  3.29   3.11]
[  4.38   4.49]
[ -3.71  -3.83]
[ -5.6   -5.66]
[  7.6    7.89]
[ -4.9   -4.82]
[  1.58   1.65]
[ 13.54  13.47]
[-18.2  -17.93]
[ -4.42  -4.41]
[ -1.66  -1.76]
[ -4.1   -4.  ]
[ -1.35  -1.26]
[  5.06   5.13]
[-11.76 -11.95]
[ -4.65  -5.05]



In [4]:
# Show the explanation results.
# Calling the explainer on the dataset returns the values printed below: a 2-D
# array with 4 entries per row -- presumably one row per sample and one column
# per input feature (x1..x4); TODO confirm against wuml.explainer.
explanation = EXP(data)	# outputs the weight importance (signed; see the printed matrix)
print(explanation)

  0%|          | 0/30 [00:00<?, ?it/s]

[[  6.131    0.3491   0.3403 -25.0978]
 [  7.5935  -1.5995   0.0205   1.1335]
 [ -0.4465   0.0581   0.0831 -15.9543]
 [  3.7118   0.188   -0.0655  -2.7714]
 [  7.1184  -0.2533  -0.098    5.2478]
 [-14.0462   0.1772   0.0659   4.4375]
 [  8.7578  -2.3954  -0.0049   1.4002]
 [  8.3315   2.7062   0.06    -3.3639]
 [ -4.2112  -0.948   -0.0493  -1.2052]
 [  6.8344   2.0732  -0.2046   2.2138]
 [ -4.8004  -0.7342  -2.0454 -20.9634]
 [ -2.3959  -0.2576  -0.6701  -7.4583]
 [ -8.1887  -0.0442  -0.2969  -3.2003]
 [ -2.4816  -0.7023   0.      -3.7596]
 [  0.3662   0.194   -0.1493   2.4192]
 [ -3.8864  -0.1323  -0.1625   8.3939]
 [  0.8205  -0.1832  -0.9343  -3.8095]
 [ -4.8955   0.023   -0.1446  -0.9184]
 [  4.6805  -1.4442  -0.0546   4.4297]
 [ -4.4958  -0.2119   0.0011  -0.3891]
 [ -6.898    0.59     0.0183   7.662 ]
 [  8.3319   3.4807   0.1055   1.2695]
 [ -5.6383   0.3458  -0.1172 -12.801 ]
 [  1.1386   0.9947   0.0959  -6.9236]
 [  0.0636   1.5234   0.0508  -3.6794]
 [  7.6236  -1.9248  -0.2