In [1]:
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from utils import plot_batch, calculate_loss_and_accuracy
from FruitModel import FruitModel

In [2]:
batch_size = 512

In [3]:
train_transforms = transforms.Compose([
    transforms.ColorJitter(.15,.15,.15,.15),
    transforms.RandomGrayscale(p=0.15),
    transforms.Pad(25, fill=(255,255,255)),
    transforms.RandomCrop(100),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(25),
    transforms.ToTensor()
])

In [4]:
validation_transforms = transforms.Compose([
    transforms.ToTensor()
])

In [5]:
train_data = ImageFolder('fruits-360/Training', validation_transforms)

In [6]:
test_data = ImageFolder('fruits-360/Test', validation_transforms)

In [7]:
train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=0, shuffle=True)

In [8]:
test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=0, shuffle=True)

batch, labels = next(iter(train_loader))
plot_batch(batch, 3, 3)

In [9]:
model = FruitModel()
model.load_state_dict(torch.load('models/fruit_light_net.pt')['state_dict'])
model = model.cuda()

In [10]:
criterion = nn.CrossEntropyLoss()

In [11]:
learning_rate = 0.001
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

In [12]:
def train(epochs):
    
    print("Trainning is about to start...")

    for epoch in range(1, epochs+1):
        print('---- EPOCH {} ----'.format(epoch))
        step = 0
        for i, (inputs, labels) in enumerate(train_loader):
            
            model.train()

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            step += 1

            if step % 100 == 0:
                print('[{}] Step {}: Loss: {}'.format(epoch, step, loss))
        model.eval()
        _, train_acc = calculate_loss_and_accuracy(train_loader, model, criterion)
        _, test_acc = calculate_loss_and_accuracy(test_loader, model, criterion)
        print('Epoch {} - Train Acc: {:.2f} Validation Acc: {:.2f}'.format(epoch, train_acc, test_acc))
    return test_acc

In [13]:
acc = train(10)

Trainning is about to start...
---- EPOCH 1 ----
[1] Step 100: Loss: 4.3764495849609375
Epoch 1 - Train Acc: 58.14 Validation Acc: 54.10
---- EPOCH 2 ----
[2] Step 100: Loss: 3.9979331493377686
Epoch 2 - Train Acc: 84.44 Validation Acc: 76.76
---- EPOCH 3 ----
[3] Step 100: Loss: 3.633695363998413
Epoch 3 - Train Acc: 93.36 Validation Acc: 83.72
---- EPOCH 4 ----
[4] Step 100: Loss: 3.3659820556640625
Epoch 4 - Train Acc: 95.90 Validation Acc: 88.54
---- EPOCH 5 ----
[5] Step 100: Loss: 3.167135238647461
Epoch 5 - Train Acc: 96.16 Validation Acc: 90.04
---- EPOCH 6 ----
[6] Step 100: Loss: 3.336150646209717
Epoch 6 - Train Acc: 97.53 Validation Acc: 91.67
---- EPOCH 7 ----
[7] Step 100: Loss: 3.1100547313690186
Epoch 7 - Train Acc: 98.18 Validation Acc: 92.71
---- EPOCH 8 ----
[8] Step 100: Loss: 3.0926501750946045
Epoch 8 - Train Acc: 98.63 Validation Acc: 93.23
---- EPOCH 9 ----
[9] Step 100: Loss: 3.067779064178467
Epoch 9 - Train Acc: 98.89 Validation Acc: 92.97
---- EPOCH 10 ----


In [15]:
torch.save({'state_dict':model.state_dict(), 'acc': acc}, 'models/fruit_light_net.pt')