In [None]:
import os
import random

import numpy as np
import torch

from functions import models
from functions import txt_gen
from functions import data_view
from functions import train_model
from configs.plantseed_config import config
from functions.dataloaders import ListDataset
from functions.vis_error import error_checking

In [None]:
# Experiment configuration: start from the project defaults and override the
# hyper-parameters tuned for this run.
opt = config()
opt.batch_size = 32         # mini-batch size used by both DataLoaders below
opt.LR = 0.002              # learning rate handed to the SGD optimizer
opt.epochs = 5              # forwarded to train_model.train as params['epoch']
opt.use_checkpoint = True   # forwarded to train_model.train; semantics are defined there

# Seed every RNG this notebook can reach so that shuffling and any random
# initialisation are repeatable across "Restart & Run All".
# NOTE(review): GPU kernels may still be non-deterministic; this only fixes
# the seeds, it does not force deterministic algorithms.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Run the project's txt_gen helper on the "<home_loc>input/" directory.
# NOTE(review): presumably this writes the train/valid listing files that
# ListDataset reads below — confirm in functions/txt_gen.py.
txt_gen.gen_train_valid(opt.home_loc+"input/")

In [None]:
# Build the train / validation DataLoaders from the project's ListDataset and
# collect them (plus their sizes) in the dicts consumed by the training code.
print("Load dataloaders....")
trainset = ListDataset(opt, train = "train")
train_loader = torch.utils.data.DataLoader(trainset, batch_size = opt.batch_size, shuffle=True, num_workers = 2)
valset = ListDataset(opt, train = "valid")
# NOTE(review): the validation loader is shuffled too. That does not change
# aggregate metrics, but confirm whether downstream visualisation relies on it
# before switching it off.
valid_loader = torch.utils.data.DataLoader(valset, batch_size = opt.batch_size, shuffle=True, num_workers = 2)
dataloaders = {'train':train_loader, 'valid':valid_loader}
dataset_sizes = {'train': len(train_loader.dataset), 'valid': len(valid_loader.dataset)}


def _class_sort_key(item):
    """Order names with a leading integer token numerically, all others after them."""
    head = item.partition(' ')[0]
    # Only call int() when the whole leading token is numeric, so a name such
    # as "3d foo" sorts with the non-numeric names instead of raising ValueError.
    return (int(head) if head.isdecimal() else float('inf'), item)


# Class names are the sub-directories of the training folder. Plain files and
# hidden folders (e.g. ".ipynb_checkpoints", ".DS_Store") are skipped so they
# cannot shift the class indices.
train_dir = opt.home_loc + "input/train"
classes = [
    entry for entry in os.listdir(train_dir)
    if not entry.startswith('.') and os.path.isdir(os.path.join(train_dir, entry))
]
classes = sorted(classes, key = _class_sort_key)
print(classes)

In [None]:
# Visual sanity check of the input pipeline using the project's data_view helper.
# NOTE(review): judging by the name this displays un-normalised sample images
# labelled with `classes` — confirm in functions/data_view.py.
data_view.vis_unnormalise(dataloaders, classes)

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Classification loss shared by the training and validation passes.
criterion = torch.nn.CrossEntropyLoss()

# Project-provided ResNet-50 constructor, moved to the selected device.
# NOTE(review): the positional 12 is presumably the number of output classes —
# confirm against functions/models.py and len(classes).
model_conv = models.resnet50(12, pretrained = True).to(device)

# SGD over all model parameters; every hyper-parameter comes from `opt`.
sgd_settings = dict(
    lr=opt.LR,
    momentum=opt.momentum,
    nesterov=opt.nesterov,
    weight_decay=opt.weight_decay,
)
optimizer = torch.optim.SGD(model_conv.parameters(), **sgd_settings)

# Step decay: the learning rate is multiplied by 0.3 every 5 scheduler steps.
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.3)

In [None]:
# Bundle everything train_model.train needs into a single mapping.
# Key names are part of that function's contract and must not change.
params = dict(
    model=model_conv,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=exp_lr_scheduler,
    dataloaders=dataloaders,
    dataset_sizes=dataset_sizes,
    use_checkpoint=opt.use_checkpoint,
    epoch=opt.epochs,
    device=device,
)

In [None]:
# Run training with the bundled settings. The call returns the model plus five
# history sequences; the loss, accuracy and learning-rate histories are plotted
# in the cells at the end of the notebook.
# NOTE(review): presumably one history entry per epoch — confirm in functions/train_model.py.
model_conv, train_loss, val_loss, train_acc, val_acc, get_lr = train_model.train(params)

In [None]:
# Per-class evaluation of the trained model via the project helper.
# NOTE(review): which loader(s) it evaluates and whether it prints or returns
# the result is not visible here — confirm in functions/train_model.py.
train_model.classwise_accuracy(model_conv, dataloaders, classes, device)

In [None]:
# NOTE(review): imported mid-notebook and never referenced by name afterwards,
# so it is presumably executed for its import-time side effects — confirm, and
# consider an explicit function call so re-running this cell re-runs the work
# (a repeated `import` is a no-op once the module is cached).
import inference

In [None]:
# Helper object used below to inspect individual predictions.
# NOTE(review): it takes no arguments, so it presumably loads results written
# elsewhere (possibly by the `inference` import) — confirm in functions/vis_error.py.
exm = error_checking()

In [None]:
# Inspect predictions for the "Black-grass" class: 5 images, 3 per row.
# NOTE(review): "worst" presumably means the least confident / most wrong
# samples — confirm the ranking criterion in functions/vis_error.py.
exm.worst_prediction("Black-grass", num=5, imgs_per_row =3)

In [None]:
# Inspect predictions for the "Loose Silky-bent" class.
# NOTE(review): the positional 15 and 4 presumably map to `num` and
# `imgs_per_row` as in the worst_prediction call — confirm the signature and
# prefer keyword arguments for consistency.
exm.best_prediction("Loose Silky-bent", 15, 4)

In [None]:
from matplotlib import pyplot as plt

# x-axis shared by all history plots: one integer position per recorded entry,
# numbered 1..N. (np.linspace(0, N, num=N) spaces N points over [0, N] with a
# step of N/(N-1), which puts the points at non-integer "epochs".)
# NOTE(review): assumes train_loss holds one value per epoch — confirm in
# functions/train_model.py.
epchs = np.arange(1, len(train_loss) + 1)

# Training vs. validation loss over the run.
plt.figure(1)
plt.plot(epchs, train_loss, label='train_loss')
plt.plot(epchs, val_loss, label='val_loss')
plt.xlabel("epochs")
plt.ylabel("Loss")
plt.title("Training vs. validation loss")
plt.legend()
plt.show()
plt.close()  # release the figure so repeated runs do not accumulate figures

In [None]:
# Learning-rate history returned by training, drawn on figure number 2.
lr_fig = plt.figure(2)
lr_ax = lr_fig.add_subplot(111)
lr_ax.plot(epchs, get_lr)
lr_ax.set_xlabel("epochs")
lr_ax.set_ylabel("LR")
plt.show()
plt.close()

In [None]:
# Training vs. validation accuracy histories, drawn on figure number 3.
acc_fig = plt.figure(3)
acc_ax = acc_fig.add_subplot(111)
for history, series_name in ((train_acc, 'train_acc'), (val_acc, 'val_acc')):
    acc_ax.plot(epchs, history, label = series_name)
acc_ax.set_xlabel("epochs")
acc_ax.set_ylabel("accuracy")
acc_ax.legend()
plt.show()
plt.close()