In [None]:
%reload_ext autoreload
%autoreload 2

In [1]:
from fastai.vision.all import *
from fastai.callback.tracker import SaveModelCallback

from gepcore.utils import convolution, cell_graph
from gepcore.entity import Gene, Chromosome
from gepcore.symbol import PrimitiveSet
from nas_seg.seg_model import *
from nas_seg.utils import get_mask, overall_acc
from nas_seg.isprs_dataset import img_to_mask, mask_to_img
from gepnet.utils import count_parameters

from pygraphviz import AGraph
import glob
from skimage import io
from sklearn.metrics import confusion_matrix

torch.backends.cudnn.benchmark = True

if torch.cuda.is_available():
  print("Great! Good to go!")
else:
  print('CUDA is not up!')

Great! Good to go!


In [10]:
labels = np.array(["imp. surf.", "buildings", "low veg.", "trees", "cars", "clutter"])
num_classes = len(labels) 

# define primitive set
pset = PrimitiveSet('cnn')
# add cellular encoding program symbols
pset.add_program_symbol(cell_graph.end)
pset.add_program_symbol(cell_graph.seq)
pset.add_program_symbol(cell_graph.cpo)
pset.add_program_symbol(cell_graph.cpi)

# add convolutional operations symbols
conv_symbol = convolution.get_symbol()
pset.add_cell_symbol(conv_symbol.sepconv3x3)
pset.add_cell_symbol(conv_symbol.sepconv5x5)
pset.add_cell_symbol(conv_symbol.dilconv3x3)
pset.add_cell_symbol(conv_symbol.dilconv5x5)
pset.add_cell_symbol(conv_symbol.maxpool3x3)
pset.add_cell_symbol(conv_symbol.avgpool3x3)

def gene_gen():
    return Gene(pset, 2)

ch = Chromosome(gene_gen, 4)
graph, comp_graphs = cell_graph.generate_comp_graph(ch)

cell_graph.save_graph(graph, 'nas_seg/comp_graphs/')
cell_graph.draw_graph(graph, 'nas_seg/comp_graphs/')

conf = arch_config(comp_graphs=comp_graphs,
                   channels=24,
                   input_size=256,
                   classes=num_classes)

net = Network(conf)
print(count_parameters(net))

38.781048


In [None]:
# net

In [11]:
labels = np.array(["imp. surf.", "buildings", "low veg.", "trees", "cars", "clutter"])
num_classes = len(labels) 

graph = [AGraph(g) for g in glob.glob('nas_seg/comp_graphs/*.dot')]
_, comp_graphs = cell_graph.generate_comp_graph(graph)

conf = arch_config(comp_graphs=comp_graphs, channels=24, classes=len(labels), input_size=128)

# net = get_net(conf)
net = Network(conf)
print(count_parameters(net))

38.781048


In [12]:
# dataset path
window_size = 128

dataset = 'Vaihingen' #'Potsdam'
if dataset == 'Potsdam':
    dataset_dir = Path.home()/'rs_imagery/ISPRS_DATASETS/{}'.format(dataset)
    trainset_dir = dataset.lower() + '_{}'.format(window_size) 
elif dataset == 'Vaihingen':
    dataset_dir = Path.home()/'rs_imagery/ISPRS_DATASETS/{}'.format(dataset)
    trainset_dir = dataset.lower() + '_{}'.format(window_size) 

# training set path
data_path = dataset_dir/'{}'.format(trainset_dir)
img_path = data_path/'images/train'

# img_dir = get_image_files(img_path)
# img = img_dir[9]
# msk = get_mask(img)
# msk = PILImage.create(mask_to_img(io.imread(msk)))
# img = PILImage.create(io.imread(img))

# img.show(), msk.show()

In [13]:
#[0.4769, 0.3227, 0.3191], [0.1967, 0.1358, 0.1300] -- 256
#[0.4776, 0.3226, 0.3189], [0.1816, 0.1224, 0.1185] -- 128

data = DataBlock(blocks=(ImageBlock, MaskBlock(codes=labels)),
    get_items=get_image_files,
    get_y=get_mask,
    splitter=RandomSplitter(seed=42),
    batch_tfms=[*aug_transforms(flip_vert=True, size=window_size), 
                Normalize.from_stats([0.4776, 0.3226, 0.3189], [0.1816, 0.1224, 0.1185])])

dls = data.dataloaders(img_path, bs=10)

In [14]:
weights = weight=torch.tensor([0.007, 0.008, 0.02, 0.02, 0.2, 1.0]).cuda()  # Weights for class balancing
loss_func = CrossEntropyLossFlat(weight=weights, axis=1)    # [0.001, 0.0009, 0.002, 0.002, 0.03, 1.0]    
save_best = SaveModelCallback(monitor='overall_acc')

learn = Learner(dls, net, wd=1e-4, metrics=overall_acc, model_dir=dataset_dir, cbs=save_best)
# learn.lr_find()

In [None]:
# learn.fit_flat_cos(20, 1e-3)
learn.fit_one_cycle(100, 1e-3)

epoch,train_loss,valid_loss,overall_acc,time
0,0.588024,0.577917,0.778113,18:48
1,0.543995,0.477555,0.811205,17:57
2,0.494655,0.461912,0.821404,17:56
3,0.450545,0.447786,0.822641,17:55
4,0.442649,0.395611,0.84464,17:54
5,0.429131,0.384202,0.850424,17:39
6,0.40117,0.351077,0.862934,17:35
7,0.391954,0.366566,0.856157,17:38
8,0.380354,0.335066,0.867553,17:38
9,0.349947,0.323866,0.871996,17:37


Better model found at epoch 0 with overall_acc value: 0.7781131863594055.
Better model found at epoch 1 with overall_acc value: 0.8112048506736755.
Better model found at epoch 2 with overall_acc value: 0.8214043974876404.
Better model found at epoch 3 with overall_acc value: 0.8226408362388611.
Better model found at epoch 4 with overall_acc value: 0.8446400761604309.
Better model found at epoch 5 with overall_acc value: 0.8504236340522766.
Better model found at epoch 6 with overall_acc value: 0.8629338145256042.
Better model found at epoch 8 with overall_acc value: 0.8675534725189209.
Better model found at epoch 9 with overall_acc value: 0.8719964623451233.
Better model found at epoch 10 with overall_acc value: 0.8753377199172974.
Better model found at epoch 11 with overall_acc value: 0.8778769969940186.
Better model found at epoch 12 with overall_acc value: 0.8789169788360596.
Better model found at epoch 14 with overall_acc value: 0.8881429433822632.
Better model found at epoch 16 wit

In [None]:
learn.export('nas_seg/model.pkl')

In [None]:
learn.show_results(alpha=1)

In [None]:
############## Testing ###################

In [None]:
graph = [AGraph(g) for g in glob.glob('nas_seg/comp_graphs/*.dot')]
_, comp_graphs = cell_graph.generate_comp_graph(graph)

conf = arch_config(comp_graphs=comp_graphs,
                   channels=32,
                   input_size=256,
                   classes=num_classes)

# net = get_net(conf)
net = Network(conf)
print(count_parameters(net))
# net

In [None]:
# testing set path
data_path = dataset_dir/'{}'.format(trainset_dir)
img_path = data_path/'images/'

def get_mask(x: Path):
    dset = x.parent.name
    path = x.parent.parent.parent/'masks'/dset
    name = x.name
    return (path/name)

# img_dir = get_image_files(img_path/'valid')
# img = img_dir[100]
# msk = get_mask(img)
# msk1 = PILImage.create(mask_to_img(io.imread(msk)))
# img = PILImage.create(io.imread(img))

# img.show(), msk.show()

In [None]:
def _parent_idxs(items, name):
    def _inner(items, name): return mask2idxs(Path(o).parent.name == name for o in items)
    return [i for n in L(name) for i in _inner(items,n)]

def parent_splitter(train_name='train', valid_name='valid'):
    "Split `items` from the grand parent folder names (`train_name` and `valid_name`)."
    def _inner(o, **kwargs):
        return _parent_idxs(o, train_name), _parent_idxs(o, valid_name)
    return _inner

In [None]:
data = DataBlock(blocks=(ImageBlock, MaskBlock(codes=labels)),
                 get_items=get_image_files,
                 get_y=get_mask,
                 splitter=parent_splitter(),
                 batch_tfms=[*aug_transforms(do_flip=False, size=window_size), 
                             Normalize.from_stats([0.4769, 0.3227, 0.3191], [0.1967, 0.1358, 0.1300])])

dls_ = data.dataloaders(img_path, bs=20)

In [None]:
model = Learner(dls_, net, wd=1e-4, metrics=overall_acc, model_dir=dataset_dir).load('model')
model.validate()

In [None]:
# model = load_learner('nas_seg/model.pkl', cpu=False)
# count_parameters(model.model)

In [None]:
# model.metrics= overall_acc
# model.dls = dls_

In [None]:
# preds, y = model.get_preds()

In [None]:
# overall_acc(preds,y)

In [None]:
# model.show_results()