In [None]:
import numpy as np
import os
import copy
import subprocess
from sklearn.metrics import accuracy_score
import renom as rm
from renom.cuda import set_cuda_active
set_cuda_active(True)
import renom_img
from renom_img.api.classification.resnext import ResNeXt50
from renom_img.api.utility.augmentation import Augmentation
from renom_img.api.utility.augmentation.process import *
from datetime import datetime

In [None]:
# Directory that receives the per-run text logs. exist_ok=True keeps this cell
# re-runnable: plain os.mkdir raises FileExistsError on the second execution.
os.makedirs('model_logs', exist_ok=True)
# Today's date as YYYY-MM-DD; it becomes part of the log file name.
date = str(datetime.now().date())
# Dataset location. The CALTECH101_ROOT environment variable overrides the
# default so the notebook is not tied to one machine's directory layout.
root = os.environ.get('CALTECH101_ROOT', '/mnt/research/dataset/Caltech/101_ObjectCategories')

In [None]:
def prepare_data(path, train_ratio=0.8, seed=None):
    """Collect image paths and labels under ``path`` and split them randomly.

    Each sub-directory of ``path`` is treated as one class. The first entry in
    sorted order is dropped (NOTE(review): presumably to skip a non-class
    entry such as a background folder or hidden file -- confirm against the
    dataset layout).

    Args:
        path: Directory whose sub-directories each hold the images of one class.
        train_ratio: Fraction of all images assigned to the training split.
        seed: Optional seed for the shuffle. ``None`` uses numpy's global
            random state, so the split differs between runs.

    Returns:
        Tuple ``(class_map, train_paths, train_labels, valid_paths,
        valid_labels)``. Labels are integer indices into ``class_map``.
    """
    # Use the argument rather than the module-level `root`, so the function
    # really operates on the directory it was asked to read.
    class_map = sorted(os.listdir(path))[1:]

    image_path_list = []
    label_list = []

    for label, class_name in enumerate(class_map):
        class_dir = os.path.join(path, class_name)
        # Sorted so that a seeded shuffle is reproducible regardless of the
        # order in which the filesystem lists the files.
        img_files = sorted(os.listdir(class_dir))
        image_path_list.extend(os.path.join(class_dir, fname) for fname in img_files)
        label_list.extend([label] * len(img_files))

    N = len(image_path_list)
    rng = np.random if seed is None else np.random.RandomState(seed)
    perm = rng.permutation(N)
    train_N = int(N * train_ratio)
    train_idx, valid_idx = perm[:train_N], perm[train_N:]

    train_image_path_list = [image_path_list[p] for p in train_idx]
    train_label_path_list = [label_list[p] for p in train_idx]

    valid_image_path_list = [image_path_list[p] for p in valid_idx]
    valid_label_path_list = [label_list[p] for p in valid_idx]

    return class_map, train_image_path_list, train_label_path_list, valid_image_path_list, valid_label_path_list

In [None]:
# Build the class list and the train / validation splits from the dataset root.
(cmap,
 train_x, train_y,
 valid_x, valid_y) = prepare_data(root)

In [None]:
# ResNeXt50 classifier over the dataset's class list, built with
# load_pretrained_weight and train_whole_network both switched on.
model = ResNeXt50(
    cmap,
    load_pretrained_weight=True,
    train_whole_network=True,
)

In [None]:
# Augmentation pipeline handed to model.fit(); the processes are listed in the
# same order as before: Shift, RandomCrop, Flip, ContrastNorm.
aug = Augmentation([Shift(10, 10), RandomCrop(padding=4), Flip(), ContrastNorm()])

In [None]:
def end_function(*args):
    """Epoch-end callback: log metrics when the validation loss sets a new best.

    Expects the positional arguments ``(epoch, model, train_loss_list,
    valid_loss_list)``. When the latest validation loss is strictly lower than
    every earlier one, the model is evaluated on the module-level
    ``valid_x`` / ``valid_y`` and one line is appended to the run's log file.
    The first epoch is never logged because there is no earlier loss to
    compare against.

    Calls with fewer than four arguments are ignored.
    """
    # The previous guard (len(args) > 0) still indexed args[1..3] and raised
    # IndexError on short calls; require all four values up front.
    if len(args) < 4:
        return
    epoch, model, train_list, validation_loss_list = args[:4]

    # A "new best" needs at least one previous epoch.
    if len(validation_loss_list) < 2:
        return

    current_loss = validation_loss_list[-1]
    # min() over the earlier epochs replaces the deepcopy + sort of the whole
    # history; only the smallest previous value is needed.
    best_previous_loss = min(validation_loss_list[:-1])

    if current_loss < best_previous_loss:
        predicted = model.predict(valid_x)
        accuracy = accuracy_score(valid_y, predicted)
        # Context manager closes the log even if formatting or writing fails.
        with open('model_logs/resnext50@'+date+'.txt','a+') as fp:
            fp.write('Epoch: {:03d} Train Loss: {:3.2f}  Valid Loss: {:3.2f} Accuracy: {:3.2f} \n'.format(epoch,float(train_list[-1]),float(validation_loss_list[-1]),float(accuracy)))

In [None]:
# Hyperparameters
total_epoch = 100
batch = 28
imsize = model.imsize
multiscale = None
# Class object of the model's optimizer; the log writer turns it into text.
optimizer = model._opt.__class__
# Plain class names of the augmentation steps. Splitting str(cls) on '.' left
# a trailing "'>" from the class repr (e.g. "Shift'>"), so use __name__.
augmentation = [type(process).__name__ for process in aug._process_list]
evaluation_matrix = "Accuracy"
dataset = "Caltech_101"
standard = 0.0
load_pretrained=True
train_whole=True
renom_v = rm.__version__
renom_img_v = renom_img.__version__
# str() of the bytes output gives the repr form "b'<hash>\n'"; the log writer
# slices it with [2:-3], so this representation is kept unchanged on purpose.
commit_id = str(subprocess.check_output(['git','rev-parse','HEAD']))

In [None]:
# write hyperparameters to file
# Optimizer class name for the log. str(cls).split('.')[-1] leaves a trailing
# "'>" from the class repr, so prefer __name__ and fall back to the old
# expression only if the object has no __name__.
optimizer_name = getattr(optimizer, '__name__', str(optimizer).split('.')[-1])
# The context manager closes the file even if one of the writes raises.
with open('model_logs/resnext50@'+date+'.txt','a+') as fp:
    # commit_id is the str() of a bytes object ("b'<hash>\n'"); [2:-3] strips
    # the leading b' and the trailing \n' to leave the bare hash.
    fp.write('Commit Hash: '+commit_id[2:-3]+'\nReNom version: '+renom_v+'\nReNomIMG version: '+renom_img_v)
    fp.write('\nExpected score: {:3.2f}\n'.format(float(standard)))
    fp.write('\n===================================================Hyperparameters=======================================================\n')
    fp.write('\nTotal epoch: {:03d}\nBatch size: {:03d}\nImage size: ({:03d},{:03d})'.format(total_epoch,batch,imsize[0],imsize[1]))
    fp.write('\nMultiscale: '+str(multiscale)+'\nOptimizer: '+optimizer_name+'\nAugmentation: '+str(augmentation))
    fp.write('\nEvaluation matrix: '+str(evaluation_matrix)+'\nDataset: '+str(dataset))
    fp.write('\nLoad Pretrained weight: '+str(load_pretrained)+'\nTrain whole network: '+str(train_whole))
    fp.write('\n==========================================================================================================================\n\n')

In [None]:
# Train with the configured batch size, epoch count and augmentation;
# end_function is passed as the end-of-epoch callback.
model.fit(train_x,train_y,valid_x,valid_y,batch_size=batch,epoch=total_epoch,augmentation=aug,callback_end_epoch=end_function)

# Reached only when fit() returns without raising: mark the run as complete.
# The context manager guarantees the log file is closed after the write.
with open('model_logs/resnext50@'+date+'.txt','a') as fp:
    fp.write('\nSuccess')