In [2]:
import torch
import numpy as np
from sklearn import datasets
from matplotlib import pyplot as plt

# keeps double precision - not needed if we use built-in function log_softmax
# torch.set_default_tensor_type(torch.DoubleTensor)

In [3]:
source = datasets.load_digits()
data = source.data
target = source.target

In [4]:
features = 64
classes = 10

design = np.insert(data, 0, 1., 1)
onehot = (np.arange(classes) == target[:, None])

# the default torch precision is float32 and int64 respectively
DESIGN = torch.tensor(design, dtype = torch.float32)
TARGET = torch.tensor(target, dtype = torch.int64)

print(design.dtype, target.dtype)
print(DESIGN.dtype, TARGET.dtype)

float64 int64
torch.float32 torch.int64


In [7]:
PARAM = torch.zeros(1+ features, classes, requires_grad = True)

opt = torch.optim.SGD([PARAM], lr = 10**(-4))
for epoch in range(1000):
    
    ACTIVATION = DESIGN @ PARAM
    
    VALUE = torch.argmax( ACTIVATION, 1)
    HIT = (VALUE == TARGET)
    ACC = torch.mean(HIT.float())
    
    LOSS = torch.nn.functional.cross_entropy(ACTIVATION, TARGET, reduction = "sum")
        
    LOSS.backward()
    opt.step()
    opt.zero_grad()
    
    print("%5d %12.3f %4.3f" % (epoch, LOSS, ACC), flush = True)

    0     4137.681 0.100
    1      649.797 0.884
    2    12075.448 0.336
    3    70934.688 0.469
    4   113185.977 0.360
    5   147726.484 0.251
    6   124789.055 0.417
    7   120977.617 0.386
    8    95676.625 0.298
    9    77149.336 0.297
   10    90201.805 0.403
   11    65161.609 0.606
   12    32901.043 0.648
   13    15186.608 0.641
   14    16719.033 0.770
   15     3437.151 0.869
   16     2830.863 0.890
   17     2769.283 0.874
   18     4733.464 0.869
   19     4185.443 0.835
   20     9211.494 0.839
   21     4076.286 0.863
   22     2610.437 0.898
   23     1527.203 0.928
   24     1182.090 0.940
   25      920.574 0.950
   26      862.211 0.953
   27      808.708 0.954
   28      772.095 0.955
   29      738.572 0.955
   30      708.423 0.958
   31      681.311 0.959
   32      657.220 0.960
   33      635.690 0.963
   34      616.342 0.964
   35      598.695 0.966
   36      582.465 0.965
   37      567.384 0.966
   38      553.299 0.966
   39      540.060 0.967


  328       30.590 0.998
  329       30.388 0.998
  330       30.188 0.998
  331       29.988 0.998
  332       29.791 0.998
  333       29.594 0.998
  334       29.399 0.998
  335       29.205 0.998
  336       29.012 0.998
  337       28.821 0.998
  338       28.630 0.998
  339       28.441 0.998
  340       28.254 0.998
  341       28.067 0.998
  342       27.881 0.998
  343       27.697 0.998
  344       27.514 0.998
  345       27.332 0.998
  346       27.151 0.998
  347       26.971 0.998
  348       26.792 0.998
  349       26.615 0.998
  350       26.438 0.998
  351       26.263 0.998
  352       26.089 0.998
  353       25.915 0.998
  354       25.743 0.998
  355       25.572 0.998
  356       25.402 0.998
  357       25.233 0.998
  358       25.065 0.998
  359       24.898 0.998
  360       24.732 0.998
  361       24.567 0.998
  362       24.403 0.998
  363       24.240 0.998
  364       24.078 0.998
  365       23.917 0.998
  366       23.757 0.998
  367       23.598 0.998


  656        6.739 1.000
  657        6.725 1.000
  658        6.711 1.000
  659        6.697 1.000
  660        6.683 1.000
  661        6.670 1.000
  662        6.656 1.000
  663        6.643 1.000
  664        6.629 1.000
  665        6.616 1.000
  666        6.603 1.000
  667        6.589 1.000
  668        6.576 1.000
  669        6.563 1.000
  670        6.550 1.000
  671        6.537 1.000
  672        6.524 1.000
  673        6.511 1.000
  674        6.499 1.000
  675        6.486 1.000
  676        6.473 1.000
  677        6.461 1.000
  678        6.448 1.000
  679        6.436 1.000
  680        6.423 1.000
  681        6.411 1.000
  682        6.399 1.000
  683        6.387 1.000
  684        6.374 1.000
  685        6.362 1.000
  686        6.350 1.000
  687        6.338 1.000
  688        6.326 1.000
  689        6.314 1.000
  690        6.303 1.000
  691        6.291 1.000
  692        6.279 1.000
  693        6.268 1.000
  694        6.256 1.000
  695        6.244 1.000


  983        4.239 1.000
  984        4.235 1.000
  985        4.231 1.000
  986        4.227 1.000
  987        4.223 1.000
  988        4.218 1.000
  989        4.214 1.000
  990        4.210 1.000
  991        4.206 1.000
  992        4.202 1.000
  993        4.197 1.000
  994        4.193 1.000
  995        4.189 1.000
  996        4.185 1.000
  997        4.181 1.000
  998        4.177 1.000
  999        4.173 1.000
