**Author:** Boris Kundu

**Problem Statement:** Train and compare different optimizers.

**Dataset:** Iris

In [17]:
#Import packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import datasets

In [18]:
#Read data
iris = datasets.load_iris()

In [19]:
#Define input parameters
n1 = len(iris.feature_names)  # input size
k = len(iris.target_names)    # output size
n2 = 5                        # hidden layer size

In [20]:
#Class to define model
class Model(nn.Module):
    #Initialize
    def __init__(self, datasize, hiddensize, outputsize):
        super(Model, self).__init__()
        self.layer1 = nn.Linear(datasize, hiddensize)
        self.layer2 = nn.Linear(hiddensize, outputsize)
    #Feed forward
    def forward(self, x):
        x = F.sigmoid(self.layer1(x))
        return self.layer2(x)

In [21]:
#Define inputs and output
X = torch.tensor(iris["data"], dtype=torch.float)
target = torch.tensor(iris["target"], dtype=torch.long)

In [22]:
#Define system parameters
alpha = 0.9 #Momentum
eta = 0.01 #Learning rate
epochs = 1000 #Iterations

In [23]:
#Initialize model
model = Model(n1, n2, k)

In [24]:
#Define different optimizers for comparison
adaGrad = optim.Adagrad(model.parameters(), lr=eta)
rmsProp = optim.RMSprop(model.parameters(), lr=eta)
adam = optim.Adam(model.parameters(), lr=eta)
adamW = optim.AdamW(model.parameters(), lr=eta)

In [25]:
#Make predictions
def predict(features,target_class,my_model,msg):
    o2 = my_model(X)
    ypred = o2.argmax(axis=1)
    print(f'Predictions using {msg} are:\n{ypred}')
    matches = torch.eq(ypred, target).int().sum()
    print(f'Matches using {msg} are:{matches.item()}')

In [26]:
#Train model using optimizer
def train(features,target_class,my_model,opt,msg):
    for i in range(epochs):
        o2 = my_model(features)
        L = F.cross_entropy(o2, target_class)
        if (i%100 == 0):
            print(f'Loss:{L.item()} at Epoch:{i}')
        opt.zero_grad()
        L.backward()
        opt.step()
    #Predict
    predict(features,target_class,my_model,msg)

In [27]:
#Train AdaGrad
train(X,target,model,adaGrad,'AdaGrad')

Loss:1.0684064626693726 at Epoch:0
Loss:0.8466913104057312 at Epoch:100
Loss:0.7325554490089417 at Epoch:200
Loss:0.667316198348999 at Epoch:300
Loss:0.6239688992500305 at Epoch:400
Loss:0.5922978520393372 at Epoch:500
Loss:0.5676078796386719 at Epoch:600
Loss:0.5474175214767456 at Epoch:700
Loss:0.5302877426147461 at Epoch:800
Loss:0.5153290629386902 at Epoch:900
Predictions using AdaGrad are:
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1,
        2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2])
Matches using AdaGrad are:145


In [28]:
#Train RMSProp
train(X,target,model,rmsProp,'RMSProp')

Loss:0.5019610524177551 at Epoch:0
Loss:0.15814290940761566 at Epoch:100
Loss:0.1010393351316452 at Epoch:200
Loss:0.07857422530651093 at Epoch:300
Loss:0.0671488493680954 at Epoch:400
Loss:0.060310568660497665 at Epoch:500
Loss:0.05609855800867081 at Epoch:600
Loss:0.05328028276562691 at Epoch:700
Loss:0.05130861699581146 at Epoch:800
Loss:0.049885619431734085 at Epoch:900
Predictions using RMSProp are:
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1,
        2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2])
Matches using RMSProp are:146


In [29]:
#Train Adam
train(X,target,model,adam,'Adam')

Loss:0.048825278878211975 at Epoch:0
Loss:0.04443224146962166 at Epoch:100
Loss:0.042826540768146515 at Epoch:200
Loss:0.04165118187665939 at Epoch:300
Loss:0.040784478187561035 at Epoch:400
Loss:0.04017601162195206 at Epoch:500
Loss:0.039768002927303314 at Epoch:600
Loss:0.039468422532081604 at Epoch:700
Loss:0.03916464000940323 at Epoch:800
Loss:0.038777656853199005 at Epoch:900
Predictions using Adam are:
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2])
Matches using Adam are:148


In [30]:
#Train AdamW
train(X,target,model,adamW,'AdamW')

Loss:0.03830048069357872 at Epoch:0
Loss:0.03827093914151192 at Epoch:100
Loss:0.03821292519569397 at Epoch:200
Loss:0.038118284195661545 at Epoch:300
Loss:0.03798876702785492 at Epoch:400
Loss:0.03782551735639572 at Epoch:500
Loss:0.037630002945661545 at Epoch:600
Loss:0.03740396723151207 at Epoch:700
Loss:0.03714967146515846 at Epoch:800
Loss:0.036869149655103683 at Epoch:900
Predictions using AdamW are:
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2])
Matches using AdamW are:148
