In [7]:
#!/usr/bin/env python

import wuml
import torch
import torch.nn as nn

#	Training a neural network boils down to 4 steps
#		1. Define a network structure
#			Example: This is a 3 layer network with 100 node width
#				networkStructure=[(100,'relu'),(100,'relu'),(1,'softmax')]
#		2. Define a cost function
#		3. Call train()
#		4. Display results

# Build the wuml data object from the wine feature file and its label file,
# using a batch_size of 20 and declaring the labels as discrete.
data = wuml.wData(
	xpath='../../data/wine.csv',
	ypath='../../data/wine_label.csv',
	batch_size=20,
	label_type='discrete',
)

def costFunction(x, y, ŷ, ind):
	"""Return the cross-entropy loss of the predictions ŷ against the labels y.

	Args:
		x: not used by this function.
		y: target class labels; per the original note this is 1-dimensional.
		ŷ: network outputs; per the original note this is 20x3 for this data
			(presumably batch size x number of classes).
		ind: not used by this function.

	Returns:
		The value produced by nn.CrossEntropyLoss (default settings) for (ŷ, y).
	"""
	# PyTorch quirk: the loss takes the predictions first and the labels second,
	# and the two do not have the same number of dimensions.
	criterion = nn.CrossEntropyLoss()
	return criterion(ŷ, y)

# Classification with PyTorch's cross-entropy needs integer class labels,
# which is why Y_dataType is set to torch.LongTensor (64-bit integers).
# The first argument is the cost function: for classification the string 'CE'
# can be given directly; the costFunction defined in this cell can be passed
# in its place to use a custom loss.
bNet = wuml.basicNetwork(
	'CE',
	data,
	networkStructure=[(100,'relu'), (100,'relu'), (3,'none')],
	Y_dataType=torch.LongTensor,
	max_epoch=100,
	learning_rate=0.001,
)
bNet.train(print_status=False)

#	Report results
#	Predictions are made on data.X, i.e. the same data the network was trained on.
Ŷ = bNet(
	data.X,
	output_type='ndarray',
	out_structural='1d_labels',
)
CR = wuml.summarize_classification_result(data.Y, Ŷ)
#	NOTE(review): the label printed below says "Accuracy" while the method called
#	is avg_error() -- confirm which metric wuml actually returns here.
print('\nAccuracy : %.3f\n\n' % CR.avg_error())
#	print_result=False is passed and the returned value is printed explicitly.
print(
	CR.true_vs_predict(
		sort_based_on_label=True,
		print_result=False,
	)
)

Network Info:
	Learning rate: 0.001
	Max number of epochs: 100
	Cost Function: CE
	Train Loop Callback: None
	Cuda Available: True
	Network Structure
		Linear(in_features=13, out_features=100, bias=True) , relu
		Linear(in_features=100, out_features=100, bias=True) , relu
		Linear(in_features=100, out_features=3, bias=True) , none


Accuracy : 0.758


Avg error: 0.7584

['y' 'ŷ']
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 0]
[0 1]
[0 1]
[0 1]
[0 0]
[0 0]
[0 1]
[0 0]
[0 1]
[0 1]
[0 1]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 1]
[0 1]
[0 1]
[0 0]
[0 0]
[0 1]
[0 0]
[0 0]
[0 1]
[0 0]
[0 0]
[0 0]
[0 0]
[0 1]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 0]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 1]
[0 0]
[0 1]
[0 0]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]
[1 1]