In [None]:
import random

from fastai.vision import *
from fastai.distributed import *
from fastai.metrics import error_rate
from fastai.callbacks import SaveModelCallback

# Quick sanity check that PyTorch can see a CUDA device before any GPU work starts.
print("Great! Good to go!" if torch.cuda.is_available() else 'CUDA is not up!')

In [None]:
from gepcore.utils import cell_graph, convolution
from gepcore.entity import Gene, Chromosome
from gepcore.symbol import PrimitiveSet
from gepnet.model import get_gepnet, arch_config
from gepnet.utils import count_parameters
# Keep cuDNN's autotuner off: with benchmark=True cuDNN may pick different
# convolution algorithms run to run, which undermines the deterministic setup
# used for training (cudnn.deterministic=True / benchmark=False in the seeding cell).
torch.backends.cudnn.benchmark = False

In [None]:
from pygraphviz import AGraph
import glob

In [None]:
# Load the chromosome's cell graphs from the .dot files in `gpath` and rebuild the
# computational graph that arch_config/get_gepnet turn into the network.
gpath = '/home/cliff/ResearchProjects/geppy_nn/mlj_experiments/3-2-train/3-2-seed-2/best/indv_10'
dot_files = glob.glob(gpath + '/*.dot')
# glob returns [] (no error) for a wrong or empty directory; fail here with a clear
# message instead of somewhere inside generate_comp_graph.
if not dot_files:
    raise FileNotFoundError(f'No .dot graph files found in {gpath}')
# NOTE(review): glob order is filesystem-dependent; if gene order matters for the
# architecture, consider sorted() here and in the testing cell.
graph = [AGraph(g) for g in dot_files]
_, comp_graph = cell_graph.generate_comp_graph(graph)

#cell_graph.draw_graph(graph, 'nb_graphs/rs/run_4')
print(comp_graph)

In [None]:
# # generate new chromosome
# # define primitive set
# pset = PrimitiveSet('cnn')

# # add cellular encoding program symbols
# pset.add_program_symbol(cell_graph.end)
# pset.add_program_symbol(cell_graph.seq)
# pset.add_program_symbol(cell_graph.cpo)
# pset.add_program_symbol(cell_graph.cpi)

# # add convolutional operations symbols
# conv_symbol = convolution.get_symbol()
# pset.add_cell_symbol(conv_symbol.conv1x1)
# pset.add_cell_symbol(conv_symbol.conv3x3)
# pset.add_cell_symbol(conv_symbol.dwconv3x3)
# #pset.add_cell_symbol(conv_symbol.conv1x3)
# #pset.add_cell_symbol(conv_symbol.conv3x1)
# #pset.add_cell_symbol(conv_symbol.maxpool3x3)

# def gene_gen():
#     return Gene(pset, 3)

# ch = Chromosome(gene_gen, 4)
# graph, comp_graph = cell_graph.generate_comp_graph(ch)

# cell_graph.save_graph(graph, 'nb_graphs/rs/run_4')
# cell_graph.draw_graph(graph, 'nb_graphs/rs/run_4')

In [None]:
# Seed every RNG the pipeline may draw from and force deterministic cuDNN so runs
# are repeatable. `seed` is also reused for the train/valid split below.
seed = 331
random.seed(seed)        # stdlib RNG — presumably used by fastai's random transforms; confirm
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Build the network described by the decoded computational graph and report its size.
# These settings must match the ones used in the testing section so the checkpoint loads.
conf = arch_config(
    comp_graph=comp_graph,
    classes=30,
    channels=16,
    width_coeff=1.0,
    depth_coeff=1.0,
    repeat_list=[3, 2, 1, 2],
)

net = get_gepnet(conf)
count_parameters(net)

In [None]:
path = Path("/home/cliff/rs_imagery/AID")
tfms = get_transforms(flip_vert=True, max_lighting=0.1, max_zoom=1.05, max_warp=0.)

bs = 128
# A seeded 10% of the 'train' folder is held out for validation; the 'test' folder
# is only used in the testing section.
data = (ImageList.from_folder(path/'train')
        .split_by_rand_pct(valid_pct=0.1, seed=seed)
        .label_from_folder()
        .transform(tfms, size=224)
        .databunch(bs=bs, num_workers=num_cpus())
        .normalize())

# normalize() with no stats derives them from a training batch (DataBunch.batch_stats),
# so evaluation cannot recompute the same values. Persist them next to the checkpoint
# so the testing section can normalize exactly as training did.
torch.save(data.stats, Path(gpath)/'aid-model-s322-stats.pt')

In [None]:
# Learner with mixup augmentation and fp16 training. model_dir=gpath (absolute), so
# checkpoints are presumably written next to the graph files — the testing cell loads from gpath.
#model_dir = '/home/cliff/ResearchProjects/models/random_search/'
learn = Learner(data, net, model_dir=gpath, metrics=[error_rate, accuracy])
learn = learn.mixup()
learn.to_fp16()

In [None]:
# Sweep the learning rate up to LR_FIND_END_LR and plot loss vs. LR to guide the
# max_lr passed to fit_one_cycle.
LR_FIND_END_LR = 100
learn.lr_find(end_lr=LR_FIND_END_LR)
learn.recorder.plot()

In [None]:
# Save a checkpoint whenever validation accuracy improves; the testing section
# reloads it by this exact name.
# NOTE(review): the name says "s322" while `seed` is 331 — confirm which run this records.
cb = SaveModelCallback(learn, name='aid-model-s322', monitor='accuracy', every='improvement')
learn.fit_one_cycle(cyc_len=500, max_lr=1e-2, wd=4e-4, callbacks=[cb])

In [None]:
################################# Testing ########################################

In [None]:
# Evaluate the saved checkpoint on the AID 'test' folder.
# The architecture below must match the training configuration exactly; otherwise
# loading the checkpoint is expected to fail.
gpath = '/home/cliff/ResearchProjects/geppy_nn/mlj_experiments/3-2-train/3-2-seed-2/best/indv_10'
dot_files = glob.glob(gpath + '/*.dot')
# glob returns [] (no error) for a wrong or empty directory; fail with a clear message.
if not dot_files:
    raise FileNotFoundError(f'No .dot graph files found in {gpath}')
graph = [AGraph(g) for g in dot_files]
_, comp_graph = cell_graph.generate_comp_graph(graph)


conf = arch_config(comp_graph=comp_graph,
                   depth_coeff=1.0,
                   width_coeff=1.0,
                   channels=16,
                   repeat_list=[3, 2, 1, 2],
                   classes=30)

net = get_gepnet(conf)

print(count_parameters(net), '\n')


tfms = get_transforms(do_flip=False)
path = Path("/home/cliff/rs_imagery/AID/")

bs = 128

# normalize() with no stats estimates them from a training batch of THIS DataBunch,
# which differs from the one used in training (full 'train' folder here vs. a 90% split
# there), so the model would see differently normalized inputs. Prefer the stats saved
# at training time; fall back to the old behavior if they are missing.
stats_file = Path(gpath)/'aid-model-s322-stats.pt'
norm_stats = torch.load(stats_file) if stats_file.exists() else None
if norm_stats is None:
    print(f'WARNING: {stats_file} not found; normalizing with batch statistics, '
          'which may not match training.')

data = (ImageList.from_folder(path)
        .split_by_folder(train='train', valid='test')
        .label_from_folder()
        .transform(tfms, size=224)
        .databunch(bs=bs, num_workers=num_cpus())
        .normalize(norm_stats))

model = Learner(data, net, metrics=[accuracy, error_rate]).load(gpath+'/aid-model-s322')
# validate() returns [valid_loss, *metrics] in the order given above.
_, acc, _ = model.validate()
print('%.2f' %(acc.item()*100))