In [1]:
from __future__ import print_function, division
import os
import time
import torch
import pandas as pd
import numpy as np
# For showing and formatting images
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# For importing datasets into pytorch
import torchvision.datasets as dataset

# Used for dataloaders
from torch.utils.data import DataLoader

# For pretrained resnet34 model
import torchvision.models as models

# For optimisation function
import torch.nn as nn
import torch.optim as optim

# For turning data into tensors
import torchvision.transforms as transforms

# For loss function
import torch.nn.functional as F

# Tensor to wrap data in
from torch.autograd import Variable

In [2]:
PATH = '/home/cell/data/plant_seedlings/model/'
!ls {PATH+"train"}

Black-grass  Common Chickweed  Loose Silky-bent   Shepherds Purse
Charlock     Common wheat      Maize		  Small-flowered Cranesbill
Cleavers     Fat Hen	       Scentless Mayweed  Sugar beet


In [3]:
batch_size = 16
sz = 224

In [4]:
## Image loaders
## Dataset transforms puts the images in tensor form
normalise = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_raw = dataset.ImageFolder(PATH+"train", transform=transforms.Compose([transforms.RandomResizedCrop(sz),
                                                                            transforms.ToTensor(),
                                                                           normalise]))
train_loader = DataLoader(train_raw, batch_size=batch_size, shuffle=True, num_workers=4)

valid_raw = dataset.ImageFolder(PATH+"valid", transform=transforms.Compose([transforms.CenterCrop(sz),
                                                                            transforms.ToTensor(),
                                                                           normalise]))
valid_loader = DataLoader(valid_raw, batch_size=batch_size, shuffle=False, num_workers=4)

In [5]:
## Create resnet model
resnet34=models.resnet34(pretrained=True)

## Loss function and optimiser
criterion = nn.CrossEntropyLoss().cuda()
optimiser = optim.Adam(resnet34.fc.parameters(), lr=0.001, weight_decay=0.001)

In [6]:
def train(epochs):
    #epoch=1
    resnet34.train()
    time_secs = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        start_time = time.time()
        #print(batch_idx)
        data, target = Variable(data), Variable(target)
        optimiser.zero_grad()
        output = resnet34(data)
        loss=criterion(output, target)
        loss.backward()
        optimiser.step()
        time_secs += (time.time() - start_time)
        if batch_idx % 10 == 0:
            print("Run time for 10 batches was: ", time_secs)
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx*len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.data[0]))
            time_secs = 0
            #break;

In [7]:
def validation():
    resnet34.eval()
    test_loss = 0
    correct = 0
    for data, target in valid_loader:
        data, target = Variable(data, volatile = True), Variable(target)
        output=resnet34(data)
        test_loss += criterion(output, target).data[0]
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    
    test_loss /= len(valid_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(valid_loader.dataset),
    100. * correct / len(valid_loader.dataset)))

In [8]:
## Loop through epochs training data and then testing it
for epoch in range(1,10):
    train(epoch)
    validation()

Run time for 10 batches was:  7.7427427768707275
Run time for 10 batches was:  66.25943541526794
Run time for 10 batches was:  65.76920747756958
Run time for 10 batches was:  65.48485612869263
Run time for 10 batches was:  64.98099827766418
Run time for 10 batches was:  65.25183844566345
Run time for 10 batches was:  65.10970306396484
Run time for 10 batches was:  64.9295711517334
Run time for 10 batches was:  65.06325435638428
Run time for 10 batches was:  64.72999000549316
Run time for 10 batches was:  65.13281917572021
Run time for 10 batches was:  65.19238924980164
Run time for 10 batches was:  65.51887321472168
Run time for 10 batches was:  65.15035557746887
Run time for 10 batches was:  65.2229688167572
Run time for 10 batches was:  65.34444665908813
Run time for 10 batches was:  65.31076502799988
Run time for 10 batches was:  65.144784450531
Run time for 10 batches was:  65.23147296905518
Run time for 10 batches was:  65.19735479354858
Run time for 10 batches was:  65.1530547142

Run time for 10 batches was:  63.041414976119995
Run time for 10 batches was:  62.733288526535034
Run time for 10 batches was:  62.28198742866516
Run time for 10 batches was:  62.45811414718628
Run time for 10 batches was:  61.95737075805664
Run time for 10 batches was:  62.208385944366455
Run time for 10 batches was:  61.88788318634033
Run time for 10 batches was:  61.77302360534668
Run time for 10 batches was:  62.416311264038086
Run time for 10 batches was:  62.390942096710205
Run time for 10 batches was:  62.1340537071228

Test set: Average loss: 0.1102, Accuracy: 449/944 (48%)

Run time for 10 batches was:  8.151244401931763
Run time for 10 batches was:  63.00780272483826
Run time for 10 batches was:  62.145835876464844
Run time for 10 batches was:  61.77436590194702
Run time for 10 batches was:  61.813085079193115
Run time for 10 batches was:  61.57111573219299
Run time for 10 batches was:  61.37032604217529
Run time for 10 batches was:  62.57654285430908
Run time for 10 batches 

Run time for 10 batches was:  63.123897314071655
Run time for 10 batches was:  63.100257396698
Run time for 10 batches was:  62.973812103271484
Run time for 10 batches was:  63.01951885223389
Run time for 10 batches was:  62.76961588859558
Run time for 10 batches was:  62.86982703208923
Run time for 10 batches was:  62.62620949745178
Run time for 10 batches was:  62.786614418029785
Run time for 10 batches was:  62.75418972969055
Run time for 10 batches was:  63.024214029312134
Run time for 10 batches was:  62.81673240661621
Run time for 10 batches was:  62.79929184913635
Run time for 10 batches was:  62.98816227912903
Run time for 10 batches was:  62.897167444229126
Run time for 10 batches was:  62.93545961380005
Run time for 10 batches was:  62.776084661483765
Run time for 10 batches was:  62.94812893867493
Run time for 10 batches was:  62.87740516662598
Run time for 10 batches was:  62.92909240722656
Run time for 10 batches was:  62.73447275161743
Run time for 10 batches was:  62.936