# Optimizer

In [45]:
from torchvision import datasets, models, transforms
import torch.optim as optim
import torch.nn as nn
from torchvision.transforms import *
from torch.utils.data import DataLoader
import torch

num_classes = 3
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

params_to_update = model.parameters()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params_to_update, lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

image_folder = datasets.ImageFolder('./shapes/train', transform=transforms.Compose([Resize(224), ToTensor()]))
dataloader = DataLoader(image_folder, batch_size=2, shuffle=True, num_workers=4)

num_epochs = 50

for epoch in range(num_epochs):
    optimizer.step()
    scheduler.step()
    model.train()
    
    running_loss = 0.0
    running_corrects = 0
    
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
        
            loss.backward()
            optimizer.step()
            
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = running_corrects.double() / len(dataloader)
        
    print(f'epoch {epoch}/{num_epochs} : {epoch_loss}, {epoch_acc}')

epoch 0/50 : 2.55140141248703, 0.7333333333333333
epoch 1/50 : 2.15029292901357, 0.8
epoch 2/50 : 2.1754257758458455, 0.8
epoch 3/50 : 2.1179486989974974, 1.0
epoch 4/50 : 1.9582115093866983, 1.0666666666666667
epoch 5/50 : 1.957203201452891, 1.0666666666666667
epoch 6/50 : 1.9669291893641154, 1.0666666666666667
epoch 7/50 : 2.1212442954381308, 0.8
epoch 8/50 : 2.0924832503000896, 1.2
epoch 9/50 : 2.0810962915420532, 1.0
epoch 10/50 : 1.9156936168670655, 1.1333333333333333
epoch 11/50 : 2.0182487964630127, 1.0666666666666667
epoch 12/50 : 2.1095704793930055, 0.9333333333333333
epoch 13/50 : 1.9573211352030435, 1.2
epoch 14/50 : 1.9201882243156434, 1.2
epoch 15/50 : 2.1121373494466145, 0.8666666666666667
epoch 16/50 : 2.174876602490743, 0.8666666666666667
epoch 17/50 : 1.9423249204953512, 1.2666666666666666
epoch 18/50 : 2.035244107246399, 1.0
epoch 19/50 : 1.9820697466532389, 1.0
epoch 20/50 : 2.1847590764363605, 0.8666666666666667
epoch 21/50 : 2.0501793225606284, 1.2
epoch 22/50 : 2.