In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

In [20]:
class LogisticNet(nn.Module):
    """Multinomial logistic regression: one affine layer, no activation.

    ``forward`` returns raw class scores (logits). No softmax is applied
    here, so the model is meant to be paired with a loss that consumes
    logits directly, such as ``nn.CrossEntropyLoss``.

    Args:
        D_in: number of input features per sample.
        D_out: number of output scores (classes) per sample.
    """

    def __init__(self, D_in, D_out):
        super(LogisticNet, self).__init__()
        self.linear = nn.Linear(D_in, D_out)

    def forward(self, x):
        # Logits only; normalisation is left to the loss function.
        return self.linear(x)

In [21]:
def train(model, loss_func, optimizer, trX, trY):
    """Run one full-batch optimisation step and return the loss.

    Args:
        model: module mapping ``trX`` to predictions.
        loss_func: callable ``loss_func(predictions, targets)`` returning a
            scalar tensor.
        optimizer: optimizer constructed over ``model``'s parameters.
        trX: input tensor for the whole training set (used as one batch).
        trY: target tensor matching ``trX``.

    Returns:
        The loss computed *before* the parameter update, as a Python float.
    """
    # Tensors carry autograd state directly since PyTorch 0.4; the old
    # ``Variable(..., requires_grad=False)`` wrapping is a deprecated no-op.
    optimizer.zero_grad()
    y_pred = model(trX)
    loss = loss_func(y_pred, trY)
    loss.backward()
    optimizer.step()
    # The loss is a 0-dim tensor: ``loss.data[0]`` raises IndexError on
    # PyTorch >= 0.4, ``item()`` is the supported way to get the scalar.
    return loss.item()

In [23]:
def valid(model, loss_func, valX, valY):
    """Evaluate ``model`` on a held-out set without updating it.

    The model is put in eval mode for the duration of the call and then
    returned to whatever train/eval mode the caller had it in. No autograd
    graph is built.

    Args:
        model: ``nn.Module`` mapping ``valX`` to per-class scores, one column
            per class (arg-max over dim 1 is taken as the prediction).
        loss_func: callable ``loss_func(outputs, targets)`` returning a
            scalar tensor.
        valX: input tensor for the validation set.
        valY: class-label tensor with one entry per row of ``valX``.

    Returns:
        ``(val_loss, val_acc)`` as Python floats: the loss value and the
        fraction of rows whose arg-max prediction equals ``valY``.
    """
    was_training = model.training
    # Eval mode matters for layers such as dropout / batch-norm.
    model.eval()
    try:
        # Gradients are never used here, so skip building the graph.
        with torch.no_grad():
            outputs = model(valX)
            val_loss = loss_func(outputs, valY)
            # calculate accuracy
            _, predY = torch.max(outputs, 1)
            correct = (predY == valY).sum().item()
    finally:
        model.train(was_training)
    val_acc = float(correct) / valY.size(0)
    # ``val_loss`` is a 0-dim tensor; ``.data[0]`` fails on PyTorch >= 0.4.
    return val_loss.item(), val_acc

In [25]:
# Seed torch so the weight initialisation (and hence the printed learning
# curve) is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

digits = load_digits()
data = digits['data']
target = digits['target']
# separate data: hold out 20% for validation; random_state fixes the split
trX, teX, trY, teY = train_test_split(data, target, test_size=0.2, random_state=SEED)

n_samples = trX.shape[0]
input_dim = trX.shape[1]
# Derive the class count from the dataset rather than hard-coding 10.
n_classes = len(digits['target_names'])
model = LogisticNet(input_dim, n_classes)
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
loss_func = nn.CrossEntropyLoss()

# Features -> float32; labels -> int64, which CrossEntropyLoss requires.
trX = torch.from_numpy(trX).float()
teX = torch.from_numpy(teX).float()
trY = torch.from_numpy(trY.astype(np.int64))
teY = torch.from_numpy(teY.astype(np.int64))

N_EPOCHS = 200

for epoch in range(N_EPOCHS):
    # One full-batch step per epoch, then evaluate on the held-out split.
    train(model, loss_func, optimizer, trX, trY)
    val_loss, val_acc = valid(model, loss_func, teX, teY)
    # print() function: the Python 2 print statement is a SyntaxError on Python 3.
    print('val loss:%.3f val acc:%.3f' % (val_loss, val_acc))


val loss:9.123 val acc:0.106
val loss:8.922 val acc:0.106
val loss:8.652 val acc:0.119
val loss:8.338 val acc:0.122
val loss:8.005 val acc:0.131
val loss:7.672 val acc:0.136
val loss:7.353 val acc:0.131
val loss:7.054 val acc:0.122
val loss:6.777 val acc:0.119
val loss:6.518 val acc:0.108
val loss:6.278 val acc:0.111
val loss:6.054 val acc:0.111
val loss:5.844 val acc:0.111
val loss:5.649 val acc:0.106
val loss:5.466 val acc:0.097
val loss:5.293 val acc:0.103
val loss:5.128 val acc:0.111
val loss:4.971 val acc:0.117
val loss:4.820 val acc:0.128
val loss:4.673 val acc:0.128
val loss:4.531 val acc:0.136
val loss:4.392 val acc:0.139
val loss:4.256 val acc:0.144
val loss:4.122 val acc:0.150
val loss:3.991 val acc:0.164
val loss:3.863 val acc:0.169
val loss:3.737 val acc:0.172
val loss:3.613 val acc:0.189
val loss:3.492 val acc:0.197
val loss:3.374 val acc:0.203
val loss:3.260 val acc:0.217
val loss:3.150 val acc:0.219
val loss:3.045 val acc:0.211
val loss:2.946 val acc:0.217
val loss:2.852