# Bagnet Training
This notebook trains a bagnet17 model on the plant dataset

In [1]:
import time
import os

import bagnets.pytorchnet
import numpy as np
import pandas as pd
from plotnine import *
import sklearn.model_selection
from sklearn.metrics import accuracy_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import torch.backends.cudnn as cudnn
import tqdm

In [2]:
seed = 42

np.random.seed(seed)                                                                       
torch.manual_seed(seed)                                                                    
if torch.backends.cudnn.enabled:                                                                
    torch.backends.cudnn.deterministic = True                                                   
    torch.backends.cudnn.benchmark = False

In [None]:
data_dir = '../data/'

data_transform = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])

image_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform = data_transform)

train_indices, tune_indices = sklearn.model_selection.train_test_split(np.arange(len(image_dataset)), 
                                                                       test_size=.1, random_state=seed)

train_data = torch.utils.data.Subset(image_dataset, train_indices)
tune_data = torch.utils.data.Subset(image_dataset, tune_indices)

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)
tune_dataloader = torch.utils.data.DataLoader(tune_data)

dataset_size = len(image_dataset)
class_names = image_dataset.classes

device = torch.device("cuda")

model = bagnet.pytorchnet.bagnet17(num_classes = 12)

criterion = nn.CrossEntropyLoss()
criterion = criterion.cpu()
optimizer = optim.Adam(model.parameters(), lr=0.01)

model.to(device)
overall_step = 0
epochs = 50

best_tune_loss = None
tune_losses = []
train_losses = []
tune_accuracies = []

for epoch in tqdm.tqdm_notebook(range(epochs)):  # loop over the dataset multiple times
    train_loss = 0
    tune_loss = 0
    tune_correct = 0
    for data in tqdm.tqdm_notebook(train_dataloader, total=len(train_dataloader)):
        # get the inputs
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        
    
    for data in tune_dataloader:
        with torch.no_grad():
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            predicted = torch.argmax(outputs.data, 1)
            
            tune_loss += criterion(outputs, labels)
            tune_correct += accuracy_score(predicted.cpu(), labels.cpu(), normalize=False)
    
    tune_losses.append(tune_loss)
    train_losses.append(train_loss)

    accuracy = tune_correct / len(tune_data)
    tune_accuracies.append(accuracy)
    
    print('Epoch:\t{}\tTrain Loss:\t{}\tTune Loss:\t{}\tTune Acc:\t{}'.format(epoch, train_loss, 
                                                                              tune_loss, accuracy))
            
    # Save model
    if best_tune_loss is None or tune_loss < best_tune_loss:
        best_tune_loss = tune_loss
        torch.save(model.state_dict(), '../results/trained_bagnet.pkl')

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

HBox(children=(IntProgress(value=0, max=268), HTML(value='')))


Epoch:	0	Train Loss:	548.6915876865387	Tune Loss:	1468.51416015625	Tune Acc:	0.12421052631578948


HBox(children=(IntProgress(value=0, max=268), HTML(value='')))


Epoch:	1	Train Loss:	402.03591734170914	Tune Loss:	1301.36865234375	Tune Acc:	0.14105263157894737


HBox(children=(IntProgress(value=0, max=268), HTML(value='')))


Epoch:	2	Train Loss:	315.882990449667	Tune Loss:	1551.898193359375	Tune Acc:	0.11578947368421053


HBox(children=(IntProgress(value=0, max=268), HTML(value='')))

In [None]:
metrics = {'epochs': list(range(epochs)), 'tune_losses': tune_losses, 'train_losses': train_losses,
          'tune_accuracies': tune_accuracies}
metric_df = pd.DataFrame.from_dict(metrics)

ggplot(metric_df, aes(x = 'epochs', y = 'tune_losses')) + geom_line() + ggtitle('Tune loss')

In [None]:
ggplot(metric_df, aes(x = 'epochs', y = 'train_losses')) + geom_line() + ggtitle('Train loss')

In [None]:
ggplot(metric_df, aes(x = 'epochs', y = 'tune_accuracies')) + geom_line() + ggtitle('Tune accuracy')