In [1]:
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import traceback

import numpy as np

%matplotlib widget
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix

import tensorflow as tf

from tensorflow.keras.models     import Model
from tensorflow.keras.layers     import Input, Dense, Lambda, Concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras            import backend as K
from tensorflow.keras.callbacks  import EarlyStopping

from scipy.ndimage import sobel

In [2]:
%%html
<style>
    /* Hide the default buttons and the jQuery-UI dialog title bar that the
       notebook draws inside output areas (presumably to declutter the
       interactive game figure below). */
    .output_wrapper button.btn.btn-default, 
    .output_wrapper .ui-dialog-titlebar { 
        display: none; 
    }
</style>

In [3]:
class View:
    """Base class for a single panel of the game figure.

    A view owns one matplotlib axis, created in a gridspec slot of the game's
    figure, and keeps a reference to the game so it can read state when drawing.
    """

    def __init__(self, game, grid):
        """Create this panel's axis in slot ``grid`` of ``game.figure`` and select it."""
        self.game = game
        figure = game.figure
        self.axis = figure.add_subplot(grid)
        # Call the base implementation explicitly: subclasses override update()
        # with code that reads handles they only create after this constructor.
        View.update(self)

    def update(self):
        """Make this panel's axis the pyplot current axis so plt.* calls target it."""
        plt.sca(self.axis)

In [4]:
class ViewBoard(View):
    """Main board panel: territory heat-map, contour lines, player pieces and crosshair.

    Everything is drawn in data coordinates on [-1, 1] x [-1, 1]. The heat-map and
    contours are derived from ``game.territory_predictions``; the piece markers
    from ``game.player_pieces``.
    """

    def __init__(self, game, axis):
        View.__init__(self, game, axis)
        
        plt.xlim([-1,1])
        plt.ylim([-1,1])
        plt.axis('off')
        # extent=(left, right, bottom, top) with bottom=1, top=-1: under imshow's
        # default origin='upper' this places row i of the prediction grid at
        # increasing y, the same orientation the contour plot uses.
        self.territory_predictions_handle = plt.imshow(self.get_territory_predictions(), 
                                                       extent=(-1,1,1,-1), vmin=-0.5, vmax=0.5, 
                                                       interpolation='bilinear', cmap=self.game.player_colour_map)
        
        self.territory_predictions_contour_handles = self.get_contour_handles()
        
        # One scatter per player holding all of that player's pieces.
        self.player_scatter_handles = []
        for player_idx in range(self.game.num_players):
            hnd = plt.scatter(self.game.player_pieces[player_idx][:,0], 
                              self.game.player_pieces[player_idx][:,1], 
                              s=250, linewidths=[2], c=[self.game.player_colours[player_idx]], edgecolors='w')
            self.player_scatter_handles += [ hnd ]
            
        # One small black marker per player; update_player_scatter() puts it on
        # that player's first (oldest) piece.
        self.player_scatter_end_handles = []
        for player_idx in range(self.game.num_players):
            hnd = plt.scatter([], [], s=50, c='k')
            self.player_scatter_end_handles += [ hnd ]
            
        # Crosshair that follows the mouse; repositioned by update_cursor().
        self.cursor_x_handle = plt.axvline(x=0., color='w')
        self.cursor_y_handle = plt.axhline(y=0., color='w')
    
    def update_territory_predictions(self):
        """Refresh the heat-map from ``game.territory_predictions`` and redraw the contours."""
        # Local import: only needed to tell which kind of ContourSet this matplotlib returns.
        import matplotlib.collections as mcoll

        View.update(self)
        
        self.territory_predictions_handle.set_data(self.get_territory_predictions())
        
        for handle in self.territory_predictions_contour_handles:
            if isinstance(handle, mcoll.Collection):
                # Matplotlib >= 3.8: a ContourSet is itself a single Collection and
                # its per-level ``.collections`` list is deprecated (later removed).
                handle.remove()
            else:
                # Older matplotlib: a ContourSet holds one Collection per level.
                for elem in handle.collections:
                    elem.remove()
        
        self.territory_predictions_contour_handles = self.get_contour_handles()
        
    def update_player_scatter(self):
        """Move every player's markers to their current pieces."""
        View.update(self)
        
        for player_idx in range(self.game.num_players):
            self.player_scatter_handles[player_idx].set_offsets(self.game.player_pieces[player_idx])
            
        for player_idx in range(self.game.num_players):
            if (self.game.player_pieces[player_idx].shape[0] > 0):
                # Row 0 is the oldest piece, i.e. the one dropped next once the
                # player exceeds the game's piece limit.
                self.player_scatter_end_handles[player_idx].set_offsets(self.game.player_pieces[player_idx][0,:])
            
    def update_cursor(self):
        """Move the crosshair to (game.cursor_x, game.cursor_y).

        axvline/axhline span their length in axes-fraction coordinates, hence the
        fixed [0, 1] extents.
        """
        self.cursor_x_handle.set_data([self.game.cursor_x, self.game.cursor_x], [0, 1])
        self.cursor_y_handle.set_data([0, 1], [self.game.cursor_y, self.game.cursor_y])
        
    def get_territory_predictions(self):
        """Map predictions from [0, 1] onto [-0.4, 0.4], centred on the colour map's midpoint."""
        return (self.game.territory_predictions-0.5)*0.8
    
    def get_contour_handles(self):
        """Draw white iso-lines of the recentred predictions; returns the created ContourSets."""
        x = np.linspace(-1,1,self.game.territory_resolution)
        levels = np.linspace(-0.5, 0.5, 8)
        
        z = self.get_territory_predictions()
        
        return [
            plt.contour(x, x, z, levels=levels, linewidths=1.5, colors='w'),
            
            #plt.contour(x, x, z, 
            #            levels=levels, linewidths=1, 
            #            cmap=self.game.player_colour_map)
        ]

In [5]:
class ViewTrain(View):
    """Panel plotting the per-epoch training accuracy of the most recent training run."""

    def __init__(self, game, axis):
        View.__init__(self, game, axis)

        plt.title(self.get_title())
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        epochs, accuracy = self.get_train_accuracy()
        (self.train_accuracy_handle,) = plt.plot(epochs, accuracy)

    def update(self):
        """Refresh the title and the curve, then rescale the axis to the new data."""
        View.update(self)

        plt.title(self.get_title())
        epochs, accuracy = self.get_train_accuracy()
        self.train_accuracy_handle.set_data(epochs, accuracy)

        axis = self.axis
        axis.relim()
        axis.autoscale_view()

    def get_title(self):
        """Title text; includes the final epoch's accuracy once any history exists."""
        history = self.game.train_accuracy
        if len(history) > 0:
            return 'Train Accuracy %.2f' % (history[-1])
        return 'Train Accuracy'

    def get_train_accuracy(self):
        """Return (epoch indices, accuracy values) for the stored training history."""
        history = self.game.train_accuracy
        epochs = [epoch for epoch in range(len(history))]
        return epochs, history

In [6]:
class ViewConfusion(View):
    """Panel showing the model's confusion matrix on the players' pieces, with counts."""

    def __init__(self, game, axis):
        View.__init__(self, game, axis)

        plt.title('Confusion Matrix')
        plt.xlabel('Predicted label')
        plt.ylabel('True label')

        matrix = self.game.confusion_matrix
        peak = matrix.max()
        # Counts are negated and given symmetric colour limits, so a zero count
        # sits at the colour map's midpoint.
        self.confusion_matrix_handle = plt.imshow(-matrix,
                                                  vmin=-peak,
                                                  vmax=peak,
                                                  interpolation='nearest',
                                                  cmap=self.game.player_colour_map)

        tick_positions = np.arange(self.game.num_players)
        plt.xticks(tick_positions, self.game.player_names)
        plt.yticks(tick_positions, self.game.player_names, rotation=45)

        n_rows, n_cols = matrix.shape[0], matrix.shape[1]
        self.text_handles = {}
        for row in range(n_rows):
            for col in range(n_cols):
                # White disc behind each count keeps the number legible on any cell colour.
                plt.scatter(col, row - 0.05, s=1000, c='w')
                label = self.get_confusion_matrix(row, col)
                self.text_handles[(row, col)] = plt.text(col, row, label,
                                                         horizontalalignment='center', color='black')

    def update(self):
        """Push the game's current confusion matrix into the image and the count labels."""
        View.update(self)

        matrix = self.game.confusion_matrix
        peak = matrix.max()
        self.confusion_matrix_handle.set_data(-matrix)
        self.confusion_matrix_handle.set_clim(vmin=-peak, vmax=peak)

        for row, col in np.ndindex(matrix.shape[0], matrix.shape[1]):
            self.text_handles[(row, col)].set_text(self.get_confusion_matrix(row, col))

    def get_confusion_matrix(self, i, j):
        """Cell (i, j) of the game's confusion matrix formatted as an integer string."""
        count = self.game.confusion_matrix[i, j]
        return '%d' % count

In [7]:
class ViewPie(View):
    """Pie chart of the share of the board currently assigned to each player."""

    def __init__(self, game, axis):
        View.__init__(self, game, axis)
        self._draw()

    def _draw(self):
        """Draw the title and the pie wedges on the current axis."""
        plt.title(self.get_title())
        self.pie_handle = plt.pie(self.game.territory_proportions, colors=self.game.player_colours)

    def update(self):
        """Clear the axis and redraw; wedges are rebuilt rather than edited in place."""
        View.update(self)
        plt.cla()
        self._draw()

    def get_title(self):
        """Title listing each player's territory share to two decimals, separated by ' | '."""
        shares = ('%.2f' % (share) for share in self.game.territory_proportions)
        return 'Territory ' + ' | '.join(shares)

In [8]:
class ViewScore(View):
    """Line plot of each player's cumulative territory score, one point per round."""

    def __init__(self, game, axis):
        View.__init__(self, game, axis)

        plt.title('Score')
        plt.xlabel('Round')
        plt.ylabel('Territory')

        self.round_score_handles = [self._plot_player(player_idx)
                                    for player_idx in range(self.game.num_players)]

    def _plot_player(self, player_idx):
        """Plot one player's score history in that player's colour; return the line."""
        rounds, scores = self.get_round_score(player_idx)
        (line,) = plt.plot(rounds, scores, '-o', color=self.game.player_colours[player_idx])
        return line

    def update(self):
        """Refresh every player's line and rescale the axis to fit."""
        View.update(self)

        for player_idx in range(self.game.num_players):
            rounds, scores = self.get_round_score(player_idx)
            self.round_score_handles[player_idx].set_data(rounds, scores)

        axis = self.axis
        axis.relim()
        axis.autoscale_view()

    def get_round_score(self, player_idx):
        """Return (round indices, score history) for one player."""
        history = self.game.round_scores[player_idx]
        return [round_idx for round_idx in range(len(history))], history

In [9]:
class ViewMessage(View):
    """Text-only banner panel showing the game's current status lines."""

    def __init__(self, game, axis):
        View.__init__(self, game, axis)

        plt.axis('off')
        plt.xlim([-1, 1])
        plt.ylim([-1, 1])
        # (0, 0) is the centre of the [-1, 1] x [-1, 1] data window set above.
        self.text_handle = plt.text(0, 0, self.get_message(),
                                    horizontalalignment='center', color='black')

    def update(self):
        """Replace the banner text with the game's current message."""
        View.update(self)
        new_text = self.get_message()
        self.text_handle.set_text(new_text)

    def get_message(self):
        """Join the entries of ``game.message`` into one newline-separated string."""
        separator = '\n'
        return separator.join(self.game.message)

In [10]:
class TerritoryModel:
    """Binary classifier estimating which player owns a board position.

    The input is an (x, y) position. The output is a sigmoid probability that
    the position belongs to player 1, because training labels are the owning
    player's index.
    """

    def __init__(self, game):
        self.game = game
        
        inputs   = Input(shape=(2,))
    
        # Hand-crafted features next to the raw coordinates: sin(x), sin(y),
        # x^2, y^2 and the product x*y, giving the small MLP non-linear inputs.
        f_sin    = Lambda(lambda x : K.sin(x))(inputs)
        f_sq     = Lambda(lambda x : K.square(x))(inputs)
        f_corr   = Lambda(lambda x : K.prod(x, axis=-1, keepdims=True))(inputs)
        features = Concatenate()([inputs, f_sin, f_sq, f_corr])

        x        = Dense(10, activation='relu'   )(features) 
        x        = Dense(10, activation='relu'   )(x) 
        outputs  = Dense(1,  activation='sigmoid')(x) 

        self.model = Model(inputs=inputs, outputs=outputs)
        self.model.compile(Adam(), loss='binary_crossentropy', metrics=['binary_accuracy'])
    
    def reset(self):
        """Re-draw every layer's kernel and bias from that layer's initializer.

        Uses only ``K.get_value``/``K.set_value``/``K.int_shape`` instead of the
        TF1-only ``K.get_session()``. Biases are reset too, not just kernels.
        Optimizer state (Adam moments) is not reset.
        """
        for layer in self.model.layers:
            for weight_name, init_name in (('kernel', 'kernel_initializer'),
                                           ('bias', 'bias_initializer')):
                weight = getattr(layer, weight_name, None)
                initializer = getattr(layer, init_name, None)
                # Layers without trainable weights (Input, Lambda, Concatenate) are skipped.
                if weight is None or initializer is None:
                    continue
                fresh = initializer(K.int_shape(weight))
                K.set_value(weight, K.get_value(fresh))
    
    def train(self):
        """Fit for ``game.train_steps`` epochs, each a single freshly sampled batch.

        Each batch draws an equal number of pieces per player (with replacement).
        Each piece is labelled with its owner's index and jittered by Gaussian
        noise whose std-dev is ``game.train_variance``. The per-epoch
        ``binary_accuracy`` history is stored in ``game.train_accuracy``.

        Raises:
            ValueError: if some player has no pieces to sample from.
        """
        game = self.game
        # Equal share of the batch for every player (was hard-coded as // 2).
        per_player = game.train_batch_size // game.num_players

        for player_idx in range(game.num_players):
            if game.player_pieces[player_idx].shape[0] == 0:
                raise ValueError('%s has no pieces to train on' % (game.player_names[player_idx],))
        
        def gen():
            while True:
                points = []
                labels = []
                for player_idx in range(game.num_players):
                    pieces = game.player_pieces[player_idx]
                    idxs = np.random.choice(pieces.shape[0], size=(per_player,))
                    points += [pieces[idxs,:]]
                    labels += [player_idx] * per_player
                points = np.concatenate(points, axis=0)
                
                points += np.random.normal(scale=game.train_variance, size=points.shape)
                
                labels = np.expand_dims(np.array(labels, dtype=np.float32), axis=-1) 
                
                yield points, labels
        
        # NOTE(review): fit_generator is deprecated in TF 2.x, where model.fit accepts generators.
        h = self.model.fit_generator(gen(), steps_per_epoch=1, epochs=game.train_steps, verbose=0)
        game.train_accuracy = h.history['binary_accuracy']
    
    def predict(self, inputs):
        """Predict for an array of positions of shape (..., 2); returns shape (..., 1).

        The leading dimensions are flattened for the model and restored on output.
        Reshaping with -1 also handles 1-D input of shape (2,), for which
        ``np.prod(())`` would be a float, and empty input.
        """
        shape  = inputs.shape[:-1]
        flat   = np.reshape(inputs, (-1, inputs.shape[-1]))
        
        predictions = self.model.predict(flat)
        return np.reshape(predictions, shape+(-1,))
        

In [11]:
class Game:
    """Two-player territory game ("Regionnaires") and its interactive figure.

    Players alternately click on the board to place pieces. After every full
    round (``pieces_per_round`` clicks per player) a classifier is trained to
    separate the players' pieces. The share of the board it assigns to each
    player is then added to that player's cumulative score.
    """
    
    def __init__(self):
        
        # Board sampling grid: territory_grid[i, j] == (xs[j], xs[i]) over [-1, 1]^2.
        self.territory_resolution = 100
        xs     = np.linspace(-1, 1, self.territory_resolution)
        xv, yv = np.meshgrid(xs, xs)
        self.territory_grid = np.stack([xv, yv], axis=-1)
        
        self.message           = ['Nothing has failed... yet...']
        
        # Player Attributes. The territory model is a single-output binary
        # classifier, so the game logic assumes exactly two players.
        self.num_players       = 2
        self.player_names      = ['Red Player', 'Blue Player']
        self.player_colour_map = plt.get_cmap('RdBu')
        self.player_colours    = np.array([self.player_colour_map(64), 
                                           self.player_colour_map(255-64)])
        # Game State
        self.turn = 0                                 # clicks placed so far
        self.player = 0                               # index of the player to move
        self.pieces_per_round = 3
        self.max_pieces = self.pieces_per_round * 2   # older pieces beyond this are dropped
        # Every player starts with an equal share as their round-0 score.
        self.round_scores = [[1. / self.num_players] for _ in range(self.num_players)]
            
        self.player_pieces = [np.empty((0, 2), dtype=np.float32) for _ in range(self.num_players)]
        self.territory_proportions = [0.5, 0.5]
        self.confusion_matrix      = np.zeros((self.num_players, self.num_players))
       
        # Model Params
        self.train_accuracy = []
        self.train_steps = 1000          # epochs per retrain (one batch per epoch)
        self.train_batch_size = 64
        self.train_variance = 0.1        # std-dev (np.random.normal scale) of training jitter
        self.territory_model = TerritoryModel(self)  
        
        self.territory_predictions = self.territory_model.predict(self.territory_grid)[:,:,0]
        
        # Configure GUI
        self.title  = 'Regionnaires'
        self.figw   = 980
        self.figh   = 800
        self.figdpi = 80
        # Close a figure left over from an earlier run of this cell: plt.figure(num=...)
        # would otherwise return it and stack the new axes on top of the old ones.
        plt.close(self.title)
        self.figure = plt.figure(num=self.title, facecolor='w', 
                                 figsize=(self.figw/self.figdpi, self.figh/self.figdpi), 
                                 dpi=self.figdpi, constrained_layout=True)

        self.grid_height    = 4
        self.grid_width     = 6
        self.layout         = self.figure.add_gridspec(self.grid_height, self.grid_width)
        self.view_board     = ViewBoard(    self, self.layout[-3:,   :-2])
        
        self.view_score     = ViewScore(    self, self.layout[1:2,  -2:])
        self.view_train     = ViewTrain(    self, self.layout[-1:,  -2:])
        self.view_confusion = ViewConfusion(self, self.layout[-2:-1,-2:-1])
        self.view_pie       = ViewPie(      self, self.layout[-2:-1,  -1:])
        self.view_message   = ViewMessage(  self, self.layout[:1,    :])

    def attach(self, key, func):
        """Connect ``func`` to canvas event ``key``; returns the connection id."""
        return self.figure.canvas.mpl_connect(key, func)
        
    def onclick(self, axis, x, y):
        """Place a piece for the current player and, at the end of a round, rescore.

        Clicks outside the board axis are ignored. Any exception is caught and its
        traceback is shown in the message panel, so it does not propagate into the
        matplotlib event loop.
        """
        if not (axis == self.view_board.axis): return
        
        try:
            self.message = []
            self._place_piece(x, y)
            
            if ((self.turn+1) % (self.pieces_per_round * self.num_players)) == 0:
                self._end_round()
            
            self.turn += 1
            self.player = self.turn % self.num_players
            
        except Exception:
            self.message = [traceback.format_exc()]
        
        self.view_message.update()

    def _place_piece(self, x, y):
        """Append (x, y) to the current player's pieces, keeping only the newest max_pieces."""
        pieces = np.concatenate([self.player_pieces[self.player], np.array([[x, y]])], axis=0)
        self.player_pieces[self.player] = pieces[-self.max_pieces:, :]
        
        self.view_board.update_player_scatter()
        
        self.message += ['added point to %s' % (self.player_names[self.player])]

    def _end_round(self):
        """Retrain the territory model and refresh every panel that depends on it."""
        self.message += ['retraining model']
        
        # Training continues from the previous round's weights; call
        # self.territory_model.reset() first to train from scratch instead.
        self.territory_model.train()
        self.view_train.update()
        
        self.territory_predictions = self.territory_model.predict(self.territory_grid)[:,:,0]
        self.view_board.update_territory_predictions()
        
        # How well the model separates the pieces it was trained around.
        sample_points = np.concatenate(self.player_pieces, axis=0)
        sample_labels = np.repeat(np.arange(self.num_players, dtype=np.int32),
                                  [pieces.shape[0] for pieces in self.player_pieces])
        sample_preds  = (self.territory_model.predict(sample_points)[:, 0] > 0.5).astype(np.int32)
        # Explicit labels keep the matrix num_players x num_players even if a class is absent.
        self.confusion_matrix = confusion_matrix(sample_labels, sample_preds,
                                                 labels=list(range(self.num_players)))
        self.view_confusion.update()
        
        # Fraction of board grid cells assigned to player 1 (prediction >= 0.5).
        area = np.sum(self.territory_predictions >= 0.5) / (self.territory_resolution**2)
        self.territory_proportions = [1. - area, area] 
        self.view_pie.update()
        
        # Scores are cumulative: each round adds the player's current territory share.
        for player_idx in range(self.num_players):
            history  = self.round_scores[player_idx]
            previous = history[-1] if history else 0
            history.append(previous + self.territory_proportions[player_idx])
            
        self.view_score.update()
        
    def onmove(self, axis, x, y):
        """Move the board crosshair to follow the mouse; ignored outside the board axis."""
        if not (axis == self.view_board.axis): return
        
        try:
            self.cursor_x = x
            self.cursor_y = y
            
            self.view_board.update_cursor()
            
        except Exception:
            # Must be a list of lines: ViewMessage joins the entries with newlines,
            # so a bare string would be split one character per line.
            self.message = [traceback.format_exc()]
        
        self.view_message.update()
    
game = Game()


def onclick(event):
    """Forward a matplotlib button press to the module-level game."""
    target = game
    target.onclick(event.inaxes, event.xdata, event.ydata)


def onmove(event):
    """Forward a matplotlib mouse-motion event to the module-level game."""
    target = game
    target.onmove(event.inaxes, event.xdata, event.ydata)


game.attach('button_press_event', onclick)
game.attach('motion_notify_event', onmove)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Ideas

- On each turn, let a player either place a new piece or move one of their existing pieces.