In [1]:
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torch
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from utils import plot_batch, calculate_loss_and_accuracy
from FruitModel import FruitModel

In [2]:
# Number of images per mini-batch; both DataLoaders in this notebook are built with this value.
batch_size = 512

In [3]:
# Augmentation pipeline for the training images. The steps run in list order
# and the last one converts the PIL image into a tensor.
_train_steps = [
    # Small random changes in brightness, contrast, saturation and hue.
    transforms.ColorJitter(brightness=.05, contrast=.05, saturation=.05, hue=.05),
    transforms.RandomGrayscale(p=0.1),
    # Pad with white on every side, then crop back to 100x100 at a random
    # offset, which shifts the image content around.
    transforms.Pad(padding=15, fill=(255, 255, 255)),
    transforms.RandomCrop(size=100),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
]
train_transforms = transforms.Compose(_train_steps)

In [4]:
# Evaluation images get no augmentation: they are only converted to tensors.
validation_transforms = transforms.Compose([transforms.ToTensor()])

In [5]:
# Training split, read from one sub-folder per class, with augmentation applied on load.
train_data = ImageFolder(
    root='fruits-360/Training',
    transform=train_transforms,
)

In [6]:
# Held-out split, read from one sub-folder per class, converted to tensors only.
test_data = ImageFolder(
    root='fruits-360/Test',
    transform=validation_transforms,
)

In [7]:
# Shuffled mini-batches of the training set, loaded in the main process (num_workers=0).
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [8]:
# Loader over the held-out images, configured the same way as the training loader.
# NOTE(review): shuffle=True is unusual for evaluation; confirm whether
# calculate_loss_and_accuracy relies on it (e.g. by sampling only some batches)
# before changing it.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Draw one augmented training batch and hand it to plot_batch as a visual check.
batch, labels = next(iter(train_loader))
plot_batch(batch, 3, 3)

In [9]:
# Rebuild the network and resume from the checkpoint stored under 'state_dict'
# in models/fruit_light_net.pt.
model = FruitModel()
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a machine
# without CUDA; the weights are moved to the GPU below when one is available.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load('models/fruit_light_net.pt', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
# Use the same availability check that train() applies to each batch, so the
# model and its inputs always end up on the same device.
if torch.cuda.is_available():
    model = model.cuda()

In [10]:
# Loss function; train() applies it to the model outputs and the batch labels.
criterion = nn.CrossEntropyLoss()

In [11]:
# SGD with momentum over all model parameters; the step size is kept in its own
# variable so it is easy to tune.
learning_rate = 1e-4
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=0.9,
)

In [12]:
def train(epochs):
    """Train the notebook-level ``model`` for ``epochs`` passes over ``train_loader``.

    Uses the global ``model``, ``optimizer``, ``criterion``, ``train_loader``
    and ``test_loader``. The loss of every 100th step is printed. After each
    epoch the model is switched to eval mode and ``calculate_loss_and_accuracy``
    is called for both loaders; the second value it returns is printed as the
    accuracy. ``train_loader`` applies the random training augmentations, so
    the reported train accuracy is measured on augmented images.

    Args:
        epochs: Number of epochs to run.

    Returns:
        The value reported for ``test_loader`` after the last epoch, or
        ``None`` when ``epochs`` is less than 1 and no epoch ran.
    """
    print("Trainning is about to start...")

    # Defined up front so the final return cannot hit an unbound local when
    # the epoch loop does not run at all.
    test_acc = None
    # The availability of CUDA does not change between batches; check it once.
    use_cuda = torch.cuda.is_available()

    for epoch in range(1, epochs+1):
        print('---- EPOCH {} ----'.format(epoch))

        # The evaluation at the end of the previous epoch left the model in
        # eval mode, so switch back once per epoch instead of once per batch.
        model.train()

        for step, (inputs, labels) in enumerate(train_loader, start=1):
            if use_cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                # .item() extracts the Python float so the log line does not
                # depend on how a tensor formats itself.
                print('[{}] Step {}: Loss: {}'.format(epoch, step, loss.item()))

        model.eval()
        # No gradients are needed for evaluation; skipping the autograd graph
        # keeps memory use down while the metrics are computed.
        with torch.no_grad():
            _, train_acc = calculate_loss_and_accuracy(train_loader, model, criterion)
            _, test_acc = calculate_loss_and_accuracy(test_loader, model, criterion)
        print('Epoch {} - Train Acc: {:.2f} Validation Acc: {:.2f}'.format(epoch, train_acc, test_acc))
    return test_acc

In [13]:
# Run 30 epochs and keep the accuracy train() returns for the checkpoint.
# `acc` is only assigned if the call runs to completion.
acc = train(epochs=30)

Trainning is about to start...
---- EPOCH 1 ----
[1] Step 100: Loss: 1.9974522590637207
Epoch 1 - Train Acc: 88.28 Validation Acc: 95.18
---- EPOCH 2 ----
[2] Step 100: Loss: 1.8689663410186768
Epoch 2 - Train Acc: 86.65 Validation Acc: 94.14
---- EPOCH 3 ----
[3] Step 100: Loss: 1.779833197593689
Epoch 3 - Train Acc: 88.48 Validation Acc: 94.27
---- EPOCH 4 ----
[4] Step 100: Loss: 1.8072493076324463
Epoch 4 - Train Acc: 87.57 Validation Acc: 93.49
---- EPOCH 5 ----
[5] Step 100: Loss: 1.607529878616333
Epoch 5 - Train Acc: 88.67 Validation Acc: 94.40
---- EPOCH 6 ----
[6] Step 100: Loss: 1.4780793190002441
Epoch 6 - Train Acc: 87.50 Validation Acc: 94.08
---- EPOCH 7 ----
[7] Step 100: Loss: 1.5266075134277344
Epoch 7 - Train Acc: 87.43 Validation Acc: 94.92
---- EPOCH 8 ----
[8] Step 100: Loss: 1.4918668270111084
Epoch 8 - Train Acc: 87.76 Validation Acc: 94.34
---- EPOCH 9 ----
[9] Step 100: Loss: 1.497037410736084
Epoch 9 - Train Acc: 88.09 Validation Acc: 94.79
---- EPOCH 10 ----

KeyboardInterrupt: 

In [16]:
# Write the current weights and the last accuracy back to the checkpoint file
# that the model-loading cell reads from.
torch.save(
    obj=dict(state_dict=model.state_dict(), acc=acc),
    f='models/fruit_light_net.pt',
)

In [11]:
# Display the class-name -> label-index mapping of the test dataset.
test_data.class_to_idx

{'Apple Braeburn': 0,
 'Apple Crimson Snow': 1,
 'Apple Golden 1': 2,
 'Apple Golden 2': 3,
 'Apple Golden 3': 4,
 'Apple Granny Smith': 5,
 'Apple Pink Lady': 6,
 'Apple Red 1': 7,
 'Apple Red 2': 8,
 'Apple Red 3': 9,
 'Apple Red Delicious': 10,
 'Apple Red Yellow 1': 11,
 'Apple Red Yellow 2': 12,
 'Apricot': 13,
 'Avocado': 14,
 'Avocado ripe': 15,
 'Banana': 16,
 'Banana Lady Finger': 17,
 'Banana Red': 18,
 'Beetroot': 19,
 'Blueberry': 20,
 'Cactus fruit': 21,
 'Cantaloupe 1': 22,
 'Cantaloupe 2': 23,
 'Carambula': 24,
 'Cauliflower': 25,
 'Cherry 1': 26,
 'Cherry 2': 27,
 'Cherry Rainier': 28,
 'Cherry Wax Black': 29,
 'Cherry Wax Red': 30,
 'Cherry Wax Yellow': 31,
 'Chestnut': 32,
 'Clementine': 33,
 'Cocos': 34,
 'Corn': 35,
 'Corn Husk': 36,
 'Cucumber Ripe': 37,
 'Cucumber Ripe 2': 38,
 'Dates': 39,
 'Eggplant': 40,
 'Fig': 41,
 'Ginger Root': 42,
 'Granadilla': 43,
 'Grape Blue': 44,
 'Grape Pink': 45,
 'Grape White': 46,
 'Grape White 2': 47,
 'Grape White 3': 48,
 'Grap

In [14]:
# Display the value returned by train(). The name only exists once the
# `acc = train(...)` cell has run to completion; an interrupted run leaves it undefined.
acc

NameError: name 'acc' is not defined