Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
justinpinkney committed Nov 6, 2019
0 parents commit ba5f825
Show file tree
Hide file tree
Showing 56 changed files with 1,802 additions and 0 deletions.
49 changes: 49 additions & 0 deletions +tests/+integration/NetworkTrainingTest.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
classdef NetworkTrainingTest < matlab.unittest.TestCase

% Copyright 2018 The MathWorks, Inc.

methods (Test)
function testNumberNetwork(testCase)
nSamples = 10;
initialChannels = 4;
imageSize = [64, 64];

[train, ~] = sudoku.training.getNumberData(nSamples, false);

options = trainingOptions('sgdm', ...
'ExecutionEnvironment', 'cpu', ...
'MaxEpochs', 2, ...
'MiniBatchSize', 64);

layers = sudoku.training.vggLike(initialChannels, imageSize);
net = trainNetwork(train, layers, options);
end

function testSegmentationNetwork(testCase)
inputSize = [64, 64, 3];
numClasses = 2;
networkDepth = 2;
trainFraction = 0.1;

%% Get the training data
[imagesTrain, labelsTrain] = sudoku.training.getSudokuData(trainFraction, false);

train = pixelLabelImageDatastore(imagesTrain, labelsTrain, ...
'OutputSize', inputSize(1:2));

%% Setup the network
layers = segnetLayers(inputSize, numClasses, networkDepth);
layers = sudoku.training.weightLossByFrequency(layers, train);

opts = trainingOptions('sgdm', ...
'InitialLearnRate', 0.005, ...
'MaxEpochs', 2, ...
'MiniBatchSize', 2);

%% Train
net = trainNetwork(train, layers, opts); %#ok<NASGU>


end
end
end
128 changes: 128 additions & 0 deletions +tests/+unit/MapToCornerTest.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
classdef MapToCornerTest < matlab.unittest.TestCase

% Copyright 2018 The MathWorks, Inc.

methods (Test)
function testLineIntersectionOrigin(testCase)
line1 = [0, 0];
line2 = [0, pi/2];

intersection = sudoku.intersect(line1, line2);

testCase.verifyEqual(intersection, [0, 0]);
end

function testLineIntersectionPositive(testCase)
line1 = [100, 0];
line2 = [100, pi/2];

intersection = sudoku.intersect(line1, line2);

testCase.verifyEqual(intersection, [100, 100]);
end

function testLineIntersectionNegative(testCase)
line1 = [100, 0];
line2 = [-100, -pi/2];

intersection = sudoku.intersect(line1, line2);

testCase.verifyEqual(intersection, [100, 100]);
end

function testLineIntersectionParallel(testCase)
line1 = [0, 0];
line2 = [100, 0];

intersect = @() sudoku.intersect(line1, line2);

testCase.verifyError(intersect, 'sudoku:intersect:noIntersection');
end

function testMultiLineIntersect(testCase)
lines = [0, 0;
100, 0;
0, pi/2;
100, pi/2];

intersections = sudoku.intersectAll(lines);

testCase.verifyEqual(size(intersections, 1), 6);
testCase.verifyTrue(testCase.pointInArray(intersections, [0, 0], 1e-6));
testCase.verifyTrue(testCase.pointInArray(intersections, [100, 0], 1e-6));
testCase.verifyTrue(testCase.pointInArray(intersections, [0, 100], 1e-6));
testCase.verifyTrue(testCase.pointInArray(intersections, [100, 100], 1e-6));
end

function testMultiLineIntersectParallel(testCase)
lines = [0, 0;
100, 0;];

intersections = sudoku.intersectAll(lines);

testCase.verifyEqual(size(intersections, 1), 1);
testCase.verifyEqual(intersections, [NaN, NaN]);
end

function testMapToIntersections(testCase)
im = imread('+tests/Label_1.png');
map = im == 1;

lines = sudoku.getMapLines(map);
intersections = sudoku.intersectAll(lines);

testCase.verifyEqual(size(intersections, 1), 6);
testCase.verifyTrue(testCase.pointInArray(intersections, [2180, 788], 1));
testCase.verifyTrue(testCase.pointInArray(intersections, [2781, 943], 1));
testCase.verifyTrue(testCase.pointInArray(intersections, [1933, 1215], 1));
testCase.verifyTrue(testCase.pointInArray(intersections, [2596, 1412], 1));
end

function testMapToSortedIntersections(testCase)
im = imread('+tests/Label_1.png');
map = im == 1;
expectedIntersections = [2180, 788;
2781, 943;
2596, 1412;
1933, 1215;];

lines = sudoku.getMapLines(map);
intersections = sudoku.intersectAll(lines);
intersections = sudoku.selectAndSort(intersections);

testCase.verifyEqual(size(intersections, 1), 4);
testCase.verifyEqual(intersections, expectedIntersections, 'AbsTol', 1);
end

function testSelectAndSortRemoveParallel(testCase)
intersections = [0, 0;
100, 0;
NaN, NaN;
100, 100];

intersections = sudoku.selectAndSort(intersections);

testCase.verifyEqual(size(intersections, 1), 3);
end

function testSelectAndSortRemoveDistant(testCase)
intersections = [0, 0;
100, 0;
0, 100;
100, 100;
500, 1000];

intersections = sudoku.selectAndSort(intersections);

testCase.verifyEqual(size(intersections, 1), 4);
end
end

methods (Static)
function tf = pointInArray(points, testPoint, tolerance)
testResult = sum(sum(abs(points - testPoint) < tolerance, 2) == 2);
tf = testResult == 1;
end
end

end
42 changes: 42 additions & 0 deletions +tests/+unit/ReadDataTest.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
classdef ReadDataTest < matlab.unittest.TestCase

% Copyright 2018 The MathWorks, Inc.

methods (Test)
function testImportOneLabel(testCase)
data = sudoku.training.readNumberLabels('+tests/data1');

testCase.verifyEqual(length(data), 1);
testCase.verifyEqual(size(data('0001')), [9, 9]);
end

function testImportTwoLabels(testCase)
data = sudoku.training.readNumberLabels('+tests/data2');

testCase.verifyEqual(length(data), 2);
testCase.verifyEqual(size(data('0001')), [9, 9]);
testCase.verifyEqual(size(data('0002')), [9, 9]);
end

function testImportBadNumbers(testCase)
importData = @() sudoku.training.readNumberLabels('+tests/bad_data1');

testCase.verifyError(importData, 'sudoku:BadNumberData');
end

function testImportBadLabels(testCase)
importData = @() sudoku.training.readNumberLabels('+tests/bad_data2');

testCase.verifyError(importData, 'sudoku:DuplicateLabel');
end

function testNameParseing(testCase)
testName = 'C:\one\two\0003_04.jpg';

[sudokuNumber, repeat] = sudoku.training.parseFilename(testName);

testCase.verifyEqual(sudokuNumber, '0003');
testCase.verifyEqual(repeat, '04');
end
end
end
34 changes: 34 additions & 0 deletions +tests/+unit/UndistortTest.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
classdef UndistortTest < matlab.unittest.TestCase

% Copyright 2018 The MathWorks, Inc.

methods (Test)

function testUndistortion(testCase)
im = imread('+tests/0001_01.jpg');
imagePoints = [2180, 788; ...
2781, 943; ...
1933, 1215; ...
2596, 1412];

imagePoints = imagePoints + 20*(0.5 - 1*rand(4,2));

boxWidth = 32;
fullWidth = 9*boxWidth;
worldPoints = [0, 0; ...
fullWidth, 0; ...
0, fullWidth; ...
fullWidth, fullWidth];
outputImage = sudoku.undistort(im, imagePoints, worldPoints);

newIm = mat2cell(outputImage, ...
repmat(boxWidth, 1, 9), ...
repmat(boxWidth, 1, 9), ...
3);
montage(newIm)
% TODO finish this test
end

end

end
Binary file added +tests/0001_01.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added +tests/Label_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 10 additions & 0 deletions +tests/bad_data1
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
0001
-5 -3 4 6 -7 8 9 1 2
-6 7 2 -1 -9 -5 3 4 8
1 -9 -8 3 4 2 5 -6 7
-8 5 9 7 -6 4 2 -3
-4 2 6 -8 5 -3 7 9 -1
-7 1 3 9 -2 4 8 5 -6
9 -6 1 5 3 7 -2 -8 4
2 8 7 -4 -1 -9 6 3 -5
3 4 5 2 -8 6 1 -7 -9
21 changes: 21 additions & 0 deletions +tests/bad_data2
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
0001
-5 -3 4 6 -7 8 9 1 2
-6 7 2 -1 -9 -5 3 4 8
1 -9 -8 3 4 2 5 -6 7
-8 5 9 7 -6 1 4 2 -3
-4 2 6 -8 5 -3 7 9 -1
-7 1 3 9 -2 4 8 5 -6
9 -6 1 5 3 7 -2 -8 4
2 8 7 -4 -1 -9 6 3 -5
3 4 5 2 -8 6 1 -7 -9

0001
-5 -3 4 6 -7 8 9 1 2
-6 7 2 -1 -9 -5 3 4 8
1 -9 -8 3 4 2 5 -6 7
-8 5 9 7 -6 1 4 2 -3
-4 2 6 -8 5 -3 7 9 -1
-7 1 3 9 -2 4 8 5 -6
9 -6 1 5 3 7 -2 -8 4
2 8 7 -4 -1 -9 6 3 -5
3 4 5 2 -8 6 1 -7 -9
10 changes: 10 additions & 0 deletions +tests/data1
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
0001
-5 -3 4 6 -7 8 9 1 2
-6 7 2 -1 -9 -5 3 4 8
1 -9 -8 3 4 2 5 -6 7
-8 5 9 7 -6 1 4 2 -3
-4 2 6 -8 5 -3 7 9 -1
-7 1 3 9 -2 4 8 5 -6
9 -6 1 5 3 7 -2 -8 4
2 8 7 -4 -1 -9 6 3 -5
3 4 5 2 -8 6 1 -7 -9
21 changes: 21 additions & 0 deletions +tests/data2
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
0001
-5 -3 4 6 -7 8 9 1 2
-6 7 2 -1 -9 -5 3 4 8
1 -9 -8 3 4 2 5 -6 7
-8 5 9 7 -6 1 4 2 -3
-4 2 6 -8 5 -3 7 9 -1
-7 1 3 9 -2 4 8 5 -6
9 -6 1 5 3 7 -2 -8 4
2 8 7 -4 -1 -9 6 3 -5
3 4 5 2 -8 6 1 -7 -9

0002
-5 -3 4 6 -7 8 9 1 2
-6 7 2 -1 -9 -5 3 4 8
1 -9 -8 3 4 2 5 -6 7
-8 5 9 7 -6 1 4 2 -3
-4 2 6 -8 5 -3 7 9 -1
-7 1 3 9 -2 4 8 5 -6
9 -6 1 5 3 7 -2 -8 4
2 8 7 -4 -1 -9 6 3 -5
3 4 5 2 -8 6 1 -7 -9
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
*.mp4
*.asv
*.avi
models/
raw_data/
number_data/
checkpoints/
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Deep Sudoku Solver

__Takes an uncontrolled image of a sudoku puzzle, identifies the location, reads the puzzle, and solves it.__

This example was originally put together for the [UK MATLAB Expo](https://www.matlabexpo.com/uk) 2018, for a talk entitled _Computer Vision and Image processing with MATLAB ([video](https://www.mathworks.com/videos/image-processing-and-computer-vision-with-matlab-1541003708736.html), [blog post(https://blogs.mathworks.com/deep-learning/2018/11/15/sudoku-solver-image-processing-and-deep-learning/)])_. It is intended to demonstrate the use of a combination of deep learning and image procesing to solve a computer vision problem.

## Getting started

- Get a copy of the code either by cloning the repository or downloading a .zip
- Run the example live script getting_started.mlx

## Details

Broadly the algorithm is divided into four distinct steps:

1. Find the sudoku puzzle in an image using deep learning (sematic segmentation)
2. Extracts each of the 81 number boxes in the puzzle using image processing.
3. Read the number contained in each box using deep learning.
4. Solve the puzzle using opimisation.

For more details see the original [Expo talk](https://www.mathworks.com/videos/image-processing-and-computer-vision-with-matlab-1541003708736.html).

![](presentation/reprojected_result.jpg)

## Usage

- Install my navigating to the top level directory then running `install()` to add the required folders to the MATLAB path.
- Run `setupData()` to fetch the required training data from my public drive.
- Run `sudoku.trainSemanticSegmentation(filename)`
where `filename` is the name under which the trainded network will be saved (in the `models/` folder).
- Run `sudoku.trainNumberNetwork(filename)`
- TODO run on a single image

## Contributing

Please file any bug reports or questions as [GitHub issues](https://github.com/mathworks/deep-sudoku-solver/issues).

_Copyright 2018 The MathWorks, Inc._
7 changes: 7 additions & 0 deletions install.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function install()

% Copyright 2018 The MathWorks, Inc.

thisPath = fileparts(mfilename('fullpath'));
addpath(fullfile(thisPath, 'src', 'sudoku'));
end
Loading

0 comments on commit ba5f825

Please sign in to comment.