In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

In [20]:
class LogisticNet(torch.nn.Module):
    def __init__(self, D_in, D_out):
        super(LogisticNet, self).__init__()
        self.linear = nn.Linear(D_in, D_out)

    def forward(self, x):
        lin = self.linear(x)
        return lin

In [21]:
def train(model, loss_func, optimizer, trX, trY):
    x = Variable(trX, requires_grad=False)
    y = Variable(trY, requires_grad=False)
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_func(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.data[0]

In [22]:
def predict(model, x_val):
    output = model.forward(x_val)
    return output.data.numpy().argmax(axis=1)

In [23]:
def valid(model, loss_func, valX, valY):
    x = Variable(valX, requires_grad=False)
    y = Variable(valY, requires_grad=False)

    outputs = model(x)
    val_loss = loss_func(outputs, y)
    # calculate accuracy
    _, predY = torch.max(outputs.data, 1)
    correct = (predY == y.data).sum()
    val_acc = float(correct) / y.size(0)
    return val_loss.data[0], val_acc

In [24]:
digits = load_digits()
data = digits['data']
target = digits['target']
# separate data
trX, teX, trY, teY = train_test_split(data, target, test_size=0.2, random_state=0)

n_samples = trX.shape[0]
input_dim = trX.shape[1]
n_classes = 10
model = LogisticNet(input_dim, n_classes)
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
loss_func = nn.CrossEntropyLoss()

trX = torch.from_numpy(trX).float()
teX = torch.from_numpy(teX).float()
trY = torch.from_numpy(trY.astype(np.int64))
teY = torch.from_numpy(teY.astype(np.int64))

N_EPOCHS = 300

for epoch in range(N_EPOCHS):
    loss = train(model, loss_func, optimizer, trX, trY)
    val_loss, val_acc = valid(model, loss_func, teX, teY)
    print 'val loss:%.3f val acc:%.3f' % (val_loss, val_acc)


val loss:6.147 val acc:0.078
val loss:6.078 val acc:0.078
val loss:5.981 val acc:0.081
val loss:5.862 val acc:0.081
val loss:5.724 val acc:0.081
val loss:5.574 val acc:0.081
val loss:5.415 val acc:0.086
val loss:5.252 val acc:0.083
val loss:5.088 val acc:0.086
val loss:4.928 val acc:0.083
val loss:4.774 val acc:0.081
val loss:4.630 val acc:0.081
val loss:4.498 val acc:0.072
val loss:4.380 val acc:0.072
val loss:4.274 val acc:0.083
val loss:4.183 val acc:0.086
val loss:4.104 val acc:0.094
val loss:4.036 val acc:0.089
val loss:3.976 val acc:0.106
val loss:3.923 val acc:0.114
val loss:3.875 val acc:0.122
val loss:3.828 val acc:0.128
val loss:3.783 val acc:0.128
val loss:3.737 val acc:0.136
val loss:3.689 val acc:0.136
val loss:3.639 val acc:0.144
val loss:3.587 val acc:0.147
val loss:3.531 val acc:0.158
val loss:3.474 val acc:0.186
val loss:3.414 val acc:0.197
val loss:3.352 val acc:0.208
val loss:3.289 val acc:0.211
val loss:3.225 val acc:0.217
val loss:3.161 val acc:0.222
val loss:3.097