In [1]:
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from utils import plot_batch, calculate_loss_and_accuracy
from FruitModel import FruitModel

In [2]:
# Number of images per mini-batch; shared by the training and test DataLoaders below.
batch_size = 512

In [3]:
# Augmentation pipeline applied to every training image, in order.
# Colour and geometric perturbations run on the loaded image; ToTensor comes last.
train_augmentations = [
    # Small random shifts of brightness, contrast, saturation and hue.
    transforms.ColorJitter(brightness=.1, contrast=.1, saturation=.1, hue=.1),
    # Convert roughly one image in ten to grayscale.
    transforms.RandomGrayscale(p=0.1),
    # Add a 20 px white border on every side, then cut a random 100x100 window.
    # NOTE(review): presumably the source images are 100x100, making this a
    # random translation of up to 20 px - confirm against the dataset.
    transforms.Pad(padding=20, fill=(255, 255, 255)),
    transforms.RandomCrop(size=100),
    transforms.RandomHorizontalFlip(p=0.5),
    # Rotate by a random angle in [-20, 20] degrees.
    transforms.RandomRotation(degrees=20),
    transforms.ToTensor(),
]
train_transforms = transforms.Compose(train_augmentations)

In [4]:
# Evaluation images get no augmentation: they are only converted to tensors.
validation_transforms = transforms.Compose([transforms.ToTensor()])

In [5]:
# Training split, read from a class-per-subfolder layout and augmented on the fly.
train_data = ImageFolder(root='fruits-360/Training', transform=train_transforms)

In [6]:
# Held-out split; uses the augmentation-free validation transforms.
test_data = ImageFolder(root='fruits-360/Test', transform=validation_transforms)

In [7]:
# Mini-batch iterator over the training set; reshuffled every epoch and
# loaded in the main process (num_workers=0).
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [8]:
# Mini-batch iterator over the test set.
# NOTE(review): shuffle=True is unusual for an evaluation loader; it only matters
# if calculate_loss_and_accuracy looks at a subset of batches - confirm in utils.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Sanity check: pull one augmented training batch and display some of it.
# NOTE(review): the two 3s are presumably the grid rows/columns - confirm in utils.plot_batch.
sample_batches = iter(train_loader)
batch, labels = next(sample_batches)
plot_batch(batch, 3, 3)

In [9]:
# Use the GPU when one is present and fall back to the CPU otherwise, so this cell
# agrees with the torch.cuda.is_available() guard used inside train().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = FruitModel()
# Resume from the saved checkpoint. map_location remaps the stored tensors onto
# `device`, so a checkpoint written during a GPU run still loads on a CPU-only machine.
model.load_state_dict(torch.load('models/fruit_net.pt', map_location=device))
model = model.to(device)

In [10]:
# Classification loss; train() applies it to the model outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()

In [11]:
# Starting learning rate; later training stages use smaller rates derived from it.
learning_rate = 0.001
# SGD with momentum over all model parameters. train() reads this global at call
# time, so rebinding `optimizer` between calls changes what the next call uses.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

In [12]:
def train(epochs):
    """Train the notebook-level ``model`` and report accuracy after every epoch.

    Uses the globals ``model``, ``optimizer``, ``criterion``, ``train_loader`` and
    ``test_loader``. They are looked up when the function runs, so rebinding
    ``optimizer`` between calls takes effect on the next call.

    Args:
        epochs: Number of full passes over ``train_loader``.

    Returns:
        The validation accuracy reported by ``calculate_loss_and_accuracy`` after
        the last epoch, or ``None`` when ``epochs`` is less than 1 (previously this
        case raised ``UnboundLocalError``).
    """
    print("Trainning is about to start...")

    # Defined up front so the final return is valid even when no epoch runs.
    test_acc = None

    for epoch in range(1, epochs+1):
        print('---- EPOCH {} ----'.format(epoch))

        # Follow the model's actual device instead of assuming CUDA whenever it
        # is available; this keeps inputs and weights on the same device.
        device = next(model.parameters()).device

        # Evaluation below switches to eval mode, so re-enable training mode once
        # per epoch rather than on every batch.
        model.train()

        for step, (inputs, labels) in enumerate(train_loader, start=1):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                # .item() logs a plain float instead of a tensor tied to the graph.
                print('[{}] Step {}: Loss: {}'.format(epoch, step, loss.item()))

        model.eval()
        # No gradients are needed for evaluation; skipping them saves memory.
        # NOTE(review): train accuracy is measured through train_loader, i.e. on
        # augmented images, so it can read lower than validation accuracy.
        with torch.no_grad():
            _, train_acc = calculate_loss_and_accuracy(train_loader, model, criterion)
            _, test_acc = calculate_loss_and_accuracy(test_loader, model, criterion)
        print('Epoch {} - Train Acc: {:.2f} Validation Acc: {:.2f}'.format(epoch, train_acc, test_acc))
    return test_acc

In [13]:
# Step-decay schedule: several equal-length training stages, each run at one
# third of the previous stage's learning rate.
EPOCHS_PER_STAGE = 30
NUM_STAGES = 4
LR_DECAY = 3

# Work on a local copy so re-running this cell always restarts from the base
# `learning_rate` instead of shrinking it further on every execution.
stage_lr = learning_rate
for _ in range(NUM_STAGES):
    # A new optimizer per stage applies the new rate; as in the original
    # hand-unrolled version, this also starts the stage with empty momentum buffers.
    # train() picks up the rebound global `optimizer` when it is called.
    optimizer = torch.optim.SGD(model.parameters(), lr=stage_lr, momentum=0.9)
    train(EPOCHS_PER_STAGE)
    stage_lr = stage_lr / LR_DECAY

Trainning is about to start...
---- EPOCH 1 ----
[1] Step 100: Loss: 4.8897833824157715
Epoch 1 - Train Acc: 5.66 Validation Acc: 7.03
---- EPOCH 2 ----
[2] Step 100: Loss: 4.8199334144592285
Epoch 2 - Train Acc: 11.20 Validation Acc: 13.87
---- EPOCH 3 ----
[3] Step 100: Loss: 4.736672878265381
Epoch 3 - Train Acc: 16.73 Validation Acc: 24.22
---- EPOCH 4 ----
[4] Step 100: Loss: 4.692636013031006
Epoch 4 - Train Acc: 19.60 Validation Acc: 25.91
---- EPOCH 5 ----
[5] Step 100: Loss: 4.5778889656066895
Epoch 5 - Train Acc: 27.08 Validation Acc: 39.97
---- EPOCH 6 ----
[6] Step 100: Loss: 4.493751049041748
Epoch 6 - Train Acc: 33.20 Validation Acc: 45.44
---- EPOCH 7 ----
[7] Step 100: Loss: 4.486231327056885
Epoch 7 - Train Acc: 34.11 Validation Acc: 47.46
---- EPOCH 8 ----
[8] Step 100: Loss: 4.46613073348999
Epoch 8 - Train Acc: 39.19 Validation Acc: 54.62
---- EPOCH 9 ----
[9] Step 100: Loss: 4.417186737060547
Epoch 9 - Train Acc: 43.75 Validation Acc: 63.48
---- EPOCH 10 ----
[10] 

KeyboardInterrupt: 

In [14]:
# Write the current weights back to the same path the model was loaded from,
# overwriting the previous checkpoint so the next run resumes from this state.
torch.save(model.state_dict(), 'models/fruit_net.pt')