# Softmax exercise

*Complete and hand in this completed worksheet (including its outputs and any supporting code outside of the worksheet) with your assignment submission. For more details see the [assignments page](http://vision.stanford.edu/teaching/cs231n/assignments.html) on the course website.*

This exercise is analogous to the SVM exercise. You will:

- implement a fully-vectorized **loss function** for the Softmax classifier
- implement the fully-vectorized expression for its **analytic gradient**
- **check your implementation** with numerical gradient
- use a validation set to **tune the learning rate and regularization** strength
- **optimize** the loss function with **SGD**
- **visualize** the final learned weights


In [1]:
import random
import numpy as np
from cs231n.data_utils import load_CIFAR10
import matplotlib.pyplot as plt

from __future__ import print_function

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading extenrnal modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=1000, num_dev=500):
    """
    Load the CIFAR-10 dataset from disk and perform preprocessing to prepare
    it for the linear classifier. These are the same steps as we used for the
    SVM, but condensed to a single function.  
    """
    # Load the raw CIFAR-10 data
    cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
    
    # subsample the data
    mask = list(range(num_training, num_training + num_validation))
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = list(range(num_training))
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = list(range(num_test))
    X_test = X_test[mask]
    y_test = y_test[mask]
    mask = np.random.choice(num_training, num_dev, replace=False)
    X_dev = X_train[mask]
    y_dev = y_train[mask]
    
    # Preprocessing: reshape the image data into rows
    X_train = np.reshape(X_train, (X_train.shape[0], -1))
    X_val = np.reshape(X_val, (X_val.shape[0], -1))
    X_test = np.reshape(X_test, (X_test.shape[0], -1))
    X_dev = np.reshape(X_dev, (X_dev.shape[0], -1))
    
    # Normalize the data: subtract the mean image
    mean_image = np.mean(X_train, axis = 0)
    X_train -= mean_image
    X_val -= mean_image
    X_test -= mean_image
    X_dev -= mean_image
    
    # add bias dimension and transform into columns
    X_train = np.hstack([X_train, np.ones((X_train.shape[0], 1))])
    X_val = np.hstack([X_val, np.ones((X_val.shape[0], 1))])
    X_test = np.hstack([X_test, np.ones((X_test.shape[0], 1))])
    X_dev = np.hstack([X_dev, np.ones((X_dev.shape[0], 1))])
    
    return X_train, y_train, X_val, y_val, X_test, y_test, X_dev, y_dev


# Invoke the above function to get our data.
X_train, y_train, X_val, y_val, X_test, y_test, X_dev, y_dev = get_CIFAR10_data()
print('Train data shape: ', X_train.shape)
print('Train labels shape: ', y_train.shape)
print('Validation data shape: ', X_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', y_test.shape)
print('dev data shape: ', X_dev.shape)
print('dev labels shape: ', y_dev.shape)

Train data shape:  (49000, 3073)
Train labels shape:  (49000,)
Validation data shape:  (1000, 3073)
Validation labels shape:  (1000,)
Test data shape:  (1000, 3073)
Test labels shape:  (1000,)
dev data shape:  (500, 3073)
dev labels shape:  (500,)


## Softmax Classifier

Your code for this section will all be written inside **cs231n/classifiers/softmax.py**. 


In [11]:
# First implement the naive softmax loss function with nested loops.
# Open the file cs231n/classifiers/softmax.py and implement the
# softmax_loss_naive function.

from cs231n.classifiers.softmax import softmax_loss_naive
import time

# Generate a random softmax weight matrix and use it to compute the loss.
W = np.random.randn(3073, 10) * 0.0001
loss, grad = softmax_loss_naive(W, X_dev, y_dev, 0.0)

# As a rough sanity check, our loss should be something close to -log(0.1).
print('loss: %f' % loss)
print('sanity check: %f' % (-np.log(0.1)))

[-8.70553086  1.40058557  4.43307136 ..., -2.38971874 -1.22608474
  0.06663659]
[-12.56923659   2.0221962    6.40056576 ...,  -3.45032839  -1.77024806
   0.09621137]
[-9.62292937  1.54818083  4.90023334 ..., -2.64154996 -1.35529091
  0.07365883]
[-17.42590601   2.80355936   8.87370181 ...,  -4.78351233  -2.45426014
   0.13338681]
[-11.91914912   1.91760716   6.06952516 ...,  -3.27187561  -1.67868991
   0.09123527]
[-11.86143144   1.90832127   6.0401339  ...,  -3.25603177  -1.67056096
   0.09079347]
[ 115.14661137  -18.52531282  -58.63549893 ...,   31.60841309   16.21721924
   -0.88139114]
[-17.55915453   2.82499699   8.94155524 ...,  -4.82008982  -2.47302682
   0.13440676]
[-8.3634281   1.34554652  4.25886419 ..., -2.29580955 -1.17790307
  0.06401796]
[-17.11984536   2.75431892   8.71784816 ...,  -4.69949691  -2.41115463
   0.13104406]
[ -3.07174983   5.70658856  10.08492715 ...,  -3.02047107  -3.09829499
   0.17423753]
[-60.04908014 -34.26752596 -41.23160655 ...,   1.865466    14.0081

[  26.98196195  -88.14820084  -17.17646317 ...,   54.11673202  171.66756551
    1.78759937]
[  20.28011478  211.03396838  353.86980119 ..., -227.658804   -141.27775489
   -3.28163765]
[  5.39562529e+02   5.92991638e+02   6.04610991e+02 ...,   3.42866205e+02
   3.42449266e+02  -5.08373677e-01]
[ 165.38980445  294.23369104  365.23537397 ...,  -50.62067571  -82.35625667
    2.79501438]
[ 347.7714463   266.52058451  254.43749569 ...,   87.73333668  103.62954542
    1.80811792]
[-503.93081062 -614.50032953 -672.64073467 ...,  -20.18833823  169.27939201
   -2.42788208]
[ -81.1444873  -141.4413283  -184.71717685 ...,  181.29092385   91.66277483
   -0.493219  ]
[-281.21923273 -305.00685659 -293.63507449 ..., -299.8886567  -310.39648321
    1.81607749]
[  -3.8747105   -81.0584292  -203.80064327 ..., -104.90604671 -216.15143973
    5.13290539]
[-231.8716786  -136.72749851 -208.49576608 ...,   42.46783251 -118.03606736
   -6.48824718]
[  25.66232858  -89.49846275  -18.66121143 ...,   57.46387992 

[  6.14050154e+02   4.71626378e+02   6.62704159e+02 ...,   2.11236921e+02
   3.32762354e+02   2.07421364e-01]
[-585.97993556 -690.51247895 -785.82791394 ...,  -52.56179213  214.01311597
   -1.43067868]
[ -28.83888093  -73.37676402 -166.19810042 ...,  379.29369417  211.13366314
   -0.70523473]
[-614.79927348 -612.63620401 -614.23541899 ..., -563.65473843 -568.24360256
    0.68756468]
[ -42.30102342 -247.89105474 -545.91310206 ..., -112.8844032  -318.31549787
    4.43827813]
[-222.02348096 -210.54472198 -323.73863327 ...,  -70.05658173 -270.74732869
   -6.28699577]
[ -44.86488059 -168.44706556 -168.37423178 ...,  107.37517922  241.40172185
    4.1996355 ]
[  16.29431809  237.60446076  391.99444108 ..., -233.07051897 -139.57966326
   -0.92475724]
[ 758.35747886  853.20351485  945.30968902 ...,  512.70489202  501.4773173
   -2.21406744]
[ 134.42303403  430.97080423  594.45820324 ..., -222.49183164 -250.03363176
    1.65859828]
[  6.18100041e+02   4.74209612e+02   6.65240336e+02 ...,   2.22

[-349.35324735 -303.06296302 -453.74163073 ...,  -82.54260975 -357.3650435
   -4.55570554]
[ -98.72738872 -229.07098347 -108.09132375 ...,  199.57054316  469.51284031
    4.56301368]
[ 265.48049645  606.68754481  856.15444309 ...,   24.30495994  176.39915152
   -3.97657927]
[  895.37968812   986.65626657  1163.1964915  ...,   619.78748486
   764.58789871    -4.21662325]
[  80.16729968  435.61259622  655.47303229 ...,  -10.52023666  -84.26567727
    4.95310938]
[  4.81224570e+02   3.40714005e+02   6.92363289e+02 ...,  -9.93854295e+01
   8.93552680e+01   3.24070137e-01]
[-600.35846307 -699.45938471 -838.76877544 ..., -105.68283802  164.95032453
    1.38695821]
[  -1.96585844 -122.370516   -379.71500427 ...,  500.20188733  172.36253658
   -1.35805674]
[-410.90584415 -522.10050138 -689.49462357 ..., -780.28238818 -825.32581619
   -2.9807397 ]
[-257.23847401 -488.94199888 -894.32047888 ..., -262.77202922 -573.58586209
    5.97155398]
[-346.25186755 -299.15642783 -451.18246734 ...,  -80.2984

[ -546.77522963  -856.21644552 -1417.38199732 ...,  -483.04964892
  -867.273365       4.06153001]
[-624.26586827 -550.73879987 -712.74423835 ...,  -69.43934828 -374.14421157
   -7.05837229]
[ -99.70868444 -268.64577461 -152.98228815 ...,  136.16907341  361.19009886
    5.81400708]
[  568.94388905   944.72552987  1191.41326968 ...,   114.54605916
   206.05406963    -3.67229926]
[  877.15294855   975.42457417  1196.42670535 ...,   425.40018093
   629.62926806    -4.90958647]
[ 166.55166513  534.84435823  832.20492458 ..., -124.35845462 -120.06792186
    4.65455452]
[  756.31872394   642.34016644  1075.91788544 ...,    75.56122611
   347.93396358    -2.01120747]
[-702.45014098 -821.29510333 -942.18488383 ...,  -54.39440289  297.266188
    1.65099038]
[ -89.05392179 -165.11388738 -439.00798624 ...,  707.63811342  310.28002982
   -0.80158802]
[-306.71338156 -435.324618   -631.66139116 ..., -728.07279831 -790.86811953
    2.27197152]
[ -546.6441942   -858.33733631 -1420.80467108 ...,  -489.4

[ -616.47462935  -947.14374487 -1557.2551179  ...,  -551.94962613
 -1012.76354623     6.75903727]
[-803.58206829 -706.01067094 -845.55758739 ...,  -64.18273035 -420.87796702
   -6.49837984]
[-152.19236552 -299.3646474   -41.1249011  ...,  290.93657015  598.95940039
    4.39591926]
[  826.15530643  1273.26623675  1525.86843795 ...,   355.47447541
   422.83683776    -5.30196564]
[  937.22131092   993.36433026  1130.70399461 ...,   497.13544006
   709.64644244    -2.53134663]
[  441.33036545   827.39303909  1257.02558945 ...,    45.5032684
    31.78633317     5.36632245]
[  652.3066248    589.39608316  1043.65929225 ...,  -168.64695675
   144.59239805    -3.75515674]
[ -837.71486156  -908.67672147 -1217.14224813 ...,  -130.44532619
   331.39525185    -2.05485683]
[ -30.95823338 -294.06710505 -626.44911345 ...,  797.07401242  324.94472719
    1.8496473 ]
[ -416.0914495   -528.15679953  -669.72834628 ..., -1070.89912704
 -1130.51987759     1.7707794 ]
[ -611.94836893  -942.81183841 -1552.79

[ -581.83743337  -923.46298551 -1557.77465351 ...,  -354.85402859
  -842.62154944     8.29205421]
[-775.02471681 -657.18054234 -734.07351669 ...,  -92.96163256 -462.79457859
   -6.64826004]
[-133.81477986 -297.10089298  -41.04907198 ...,  222.1549347   519.67110687
    8.0083783 ]
[  8.90590205e+02   1.34688651e+03   1.63192351e+03 ...,   3.30702030e+02
   4.19614648e+02  -1.62600074e+00]
[ 1042.5401752   1112.5995948   1312.52357369 ...,   602.46915149
   798.75119754    -3.70195586]
[  807.42300912  1215.58460923  1708.92109263 ...,   184.42463488
   301.12212604     6.73244997]
[  611.27645523   649.53264621  1188.28782882 ...,  -236.04538002
   126.37608651    -4.17931702]
[-1007.75521276 -1065.42175047 -1354.80115819 ...,  -447.21513091
    83.52404018    -4.27803992]
[  -67.26796339  -445.2808292  -1008.59496304 ...,   899.42684143
   252.96988121    -2.00131695]
[ -7.86129739e+02  -9.36156357e+02  -1.14536264e+03 ...,  -1.10810142e+03
  -1.19661296e+03  -5.97991943e-01]
[ -573.3

[-1581.3420478  -1625.64689699 -1911.16241434 ...,  -712.62193498
  -111.30119038    -5.55683574]
[  217.514841    -198.82563499  -800.70805873 ...,  1397.54498837
   624.41713924    -1.87588716]
[-1004.38235111 -1142.995277   -1351.52006953 ..., -1251.62579765
 -1331.85325657     1.90076645]
[ -640.04882487 -1196.48392672 -2029.00514314 ...,  -658.575312
 -1190.26174477     9.28552914]
[-909.47354299 -768.91099347 -849.38815281 ...,   89.71257048 -275.90068144
   -3.53255292]
[-264.40997834 -491.7276423  -247.10462759 ...,  -15.50451086  300.65223047
    6.57653025]
[ 1207.8511441   1721.16363763  2018.0461416  ...,   396.35605835
   471.20148415    -3.66775427]
[ 1117.63913128  1286.05670469  1567.45948121 ...,   693.39718033
   930.58804892    -3.0375018 ]
[ 1109.89230167  1581.92178314  2113.93645791 ...,   302.97952652
   421.73402038     5.53051663]
[  755.63844136   837.29093695  1500.656537   ...,  -228.76661983
   182.47168889    -5.88674081]
[-1584.67859831 -1626.33933446 -19

   460.79873937     6.08383127]
[  909.27002192  1052.24350422  1803.47540141 ...,   -94.84753625
   343.46638365    -5.77612328]
[-1623.25871924 -1642.54869758 -1865.70406956 ...,  -763.34463127
  -112.09057105    -3.65014585]
[  264.3299482   -178.04269732  -865.07627285 ...,  1511.14345638
   604.56573646    -2.13485758]
[-1026.52450262 -1220.67005802 -1464.73698706 ..., -1343.89197694
 -1333.07229494     3.82268403]
[ -849.71823393 -1606.86720751 -2533.23867161 ...,  -882.43462704
 -1473.44518286     6.47950724]
[-892.59031628 -750.29434408 -920.75462168 ...,   -5.38005657 -294.02378351
   -3.26495975]
[-529.38827307 -734.56316706 -474.27712387 ..., -139.91826521  201.68506165
    7.52778254]
[ 1288.27161782  1940.34937233  2270.62820262 ...,   742.64999226
   686.48273568    -5.65432651]
[ 1191.96881877  1368.23727538  1735.98750416 ...,   618.85669421
   910.1126086     -3.13335432]
[ 1287.66192884  1805.56626793  2363.13434907 ...,   362.19442179
   458.01413668     6.2351719 ]


   506.26154351     9.73207503]
[ 1153.23869999  1343.676244    2147.26544859 ...,    78.32910908
   526.20383683    -5.16192497]
[-1981.17341118 -1990.39918284 -2238.75999697 ...,  -909.42830044
  -252.4804041     -8.24127842]
[  3.26363080e+02  -1.39318024e+02  -8.74734697e+02 ...,   1.68771661e+03
   7.05752424e+02  -1.22948889e+00]
[-1093.66608426 -1320.1072942  -1613.69454499 ..., -1358.1089357
 -1352.86699819     3.40201353]
[ -982.12632827 -1886.91310955 -2939.20806717 ..., -1041.27492156
 -1620.21486375     3.64109142]
[-899.50387845 -657.49409095 -836.80554633 ..., -250.70889626 -581.37986867
   -2.95172991]
[ -4.23267736e+02  -6.04194707e+02  -3.38286198e+02 ...,   2.34583337e-01
   3.59160820e+02   9.06702614e+00]
[ 1278.37782035  1956.79625905  2304.58847489 ...,   898.45546146
   849.40715921    -5.06067813]
[ 1295.51647894  1439.29700243  1948.34755129 ...,   518.10613795
   889.17670264    -2.75472529]
[ 1393.98523965  1925.0267183   2510.30459932 ...,   408.47163756
   

## Inline Question 1:
Why do we expect our loss to be close to -log(0.1)? Explain briefly.**

**Your answer:** *Fill this in*


In [None]:
# Complete the implementation of softmax_loss_naive and implement a (naive)
# version of the gradient that uses nested loops.
loss, grad = softmax_loss_naive(W, X_dev, y_dev, 0.0)

# As we did for the SVM, use numeric gradient checking as a debugging tool.
# The numeric gradient should be close to the analytic gradient.
from cs231n.gradient_check import grad_check_sparse
f = lambda w: softmax_loss_naive(w, X_dev, y_dev, 0.0)[0]
grad_numerical = grad_check_sparse(f, W, grad, 10)

# similar to SVM case, do another gradient check with regularization
loss, grad = softmax_loss_naive(W, X_dev, y_dev, 5e1)
f = lambda w: softmax_loss_naive(w, X_dev, y_dev, 5e1)[0]
grad_numerical = grad_check_sparse(f, W, grad, 10)

In [19]:
# Now that we have a naive implementation of the softmax loss function and its gradient,
# implement a vectorized version in softmax_loss_vectorized.
# The two versions should compute the same results, but the vectorized version should be
# much faster.
tic = time.time()
loss_naive, grad_naive = softmax_loss_naive(W, X_dev, y_dev, 0.000005)
toc = time.time()
print('naive loss: %e computed in %fs' % (loss_naive, toc - tic))

from cs231n.classifiers.softmax import softmax_loss_vectorized
tic = time.time()
loss_vectorized, grad_vectorized = softmax_loss_vectorized(W, X_dev, y_dev, 0.000005)
toc = time.time()
print('vectorized loss: %e computed in %fs' % (loss_vectorized, toc - tic))

# As we did for the SVM, we use the Frobenius norm to compare the two versions
# of the gradient.
grad_difference = np.linalg.norm(grad_naive - grad_vectorized, ord='fro')
print('Loss difference: %f' % np.abs(loss_naive - loss_vectorized))
print('Gradient difference: %f' % grad_difference)

naive loss: 2.386522e+00 computed in 0.115528s
(500, 10)
Gradient difference: 187901.075801


In [13]:
# Use the validation set to tune hyperparameters (regularization strength and
# learning rate). You should experiment with different ranges for the learning
# rates and regularization strengths; if you are careful you should be able to
# get a classification accuracy of over 0.35 on the validation set.
from cs231n.classifiers import Softmax
results = {}
best_val = -1
best_softmax = None
learning_rates = [1e-7, 5e-7]
regularization_strengths = [2.5e4, 5e4]

################################################################################
# TODO:                                                                        #
# Use the validation set to set the learning rate and regularization strength. #
# This should be identical to the validation that you did for the SVM; save    #
# the best trained softmax classifer in best_softmax.                          #
################################################################################
pass
################################################################################
#                              END OF YOUR CODE                                #
################################################################################
    
# Print out results.
for lr, reg in sorted(results):
    train_accuracy, val_accuracy = results[(lr, reg)]
    print('lr %e reg %e train accuracy: %f val accuracy: %f' % (
                lr, reg, train_accuracy, val_accuracy))
    
print('best validation accuracy achieved during cross-validation: %f' % best_val)

best validation accuracy achieved during cross-validation: -1.000000


In [None]:
# evaluate on test set
# Evaluate the best softmax on test set
y_test_pred = best_softmax.predict(X_test)
test_accuracy = np.mean(y_test == y_test_pred)
print('softmax on raw pixels final test set accuracy: %f' % (test_accuracy, ))

In [None]:
# Visualize the learned weights for each class
w = best_softmax.W[:-1,:] # strip out the bias
w = w.reshape(32, 32, 3, 10)

w_min, w_max = np.min(w), np.max(w)

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
for i in range(10):
    plt.subplot(2, 5, i + 1)
    
    # Rescale the weights to be between 0 and 255
    wimg = 255.0 * (w[:, :, :, i].squeeze() - w_min) / (w_max - w_min)
    plt.imshow(wimg.astype('uint8'))
    plt.axis('off')
    plt.title(classes[i])