-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit ba5f825
Showing
56 changed files
with
1,802 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
classdef NetworkTrainingTest < matlab.unittest.TestCase | ||
|
||
% Copyright 2018 The MathWorks, Inc. | ||
|
||
methods (Test) | ||
function testNumberNetwork(testCase) | ||
nSamples = 10; | ||
initialChannels = 4; | ||
imageSize = [64, 64]; | ||
|
||
[train, ~] = sudoku.training.getNumberData(nSamples, false); | ||
|
||
options = trainingOptions('sgdm', ... | ||
'ExecutionEnvironment', 'cpu', ... | ||
'MaxEpochs', 2, ... | ||
'MiniBatchSize', 64); | ||
|
||
layers = sudoku.training.vggLike(initialChannels, imageSize); | ||
net = trainNetwork(train, layers, options); | ||
end | ||
|
||
function testSegmentationNetwork(testCase) | ||
inputSize = [64, 64, 3]; | ||
numClasses = 2; | ||
networkDepth = 2; | ||
trainFraction = 0.1; | ||
|
||
%% Get the training data | ||
[imagesTrain, labelsTrain] = sudoku.training.getSudokuData(trainFraction, false); | ||
|
||
train = pixelLabelImageDatastore(imagesTrain, labelsTrain, ... | ||
'OutputSize', inputSize(1:2)); | ||
|
||
%% Setup the network | ||
layers = segnetLayers(inputSize, numClasses, networkDepth); | ||
layers = sudoku.training.weightLossByFrequency(layers, train); | ||
|
||
opts = trainingOptions('sgdm', ... | ||
'InitialLearnRate', 0.005, ... | ||
'MaxEpochs', 2, ... | ||
'MiniBatchSize', 2); | ||
|
||
%% Train | ||
net = trainNetwork(train, layers, opts); %#ok<NASGU> | ||
|
||
|
||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
classdef MapToCornerTest < matlab.unittest.TestCase | ||
|
||
% Copyright 2018 The MathWorks, Inc. | ||
|
||
methods (Test) | ||
function testLineIntersectionOrigin(testCase) | ||
line1 = [0, 0]; | ||
line2 = [0, pi/2]; | ||
|
||
intersection = sudoku.intersect(line1, line2); | ||
|
||
testCase.verifyEqual(intersection, [0, 0]); | ||
end | ||
|
||
function testLineIntersectionPositive(testCase) | ||
line1 = [100, 0]; | ||
line2 = [100, pi/2]; | ||
|
||
intersection = sudoku.intersect(line1, line2); | ||
|
||
testCase.verifyEqual(intersection, [100, 100]); | ||
end | ||
|
||
function testLineIntersectionNegative(testCase) | ||
line1 = [100, 0]; | ||
line2 = [-100, -pi/2]; | ||
|
||
intersection = sudoku.intersect(line1, line2); | ||
|
||
testCase.verifyEqual(intersection, [100, 100]); | ||
end | ||
|
||
function testLineIntersectionParallel(testCase) | ||
line1 = [0, 0]; | ||
line2 = [100, 0]; | ||
|
||
intersect = @() sudoku.intersect(line1, line2); | ||
|
||
testCase.verifyError(intersect, 'sudoku:intersect:noIntersection'); | ||
end | ||
|
||
function testMultiLineIntersect(testCase) | ||
lines = [0, 0; | ||
100, 0; | ||
0, pi/2; | ||
100, pi/2]; | ||
|
||
intersections = sudoku.intersectAll(lines); | ||
|
||
testCase.verifyEqual(size(intersections, 1), 6); | ||
testCase.verifyTrue(testCase.pointInArray(intersections, [0, 0], 1e-6)); | ||
testCase.verifyTrue(testCase.pointInArray(intersections, [100, 0], 1e-6)); | ||
testCase.verifyTrue(testCase.pointInArray(intersections, [0, 100], 1e-6)); | ||
testCase.verifyTrue(testCase.pointInArray(intersections, [100, 100], 1e-6)); | ||
end | ||
|
||
function testMultiLineIntersectParallel(testCase) | ||
lines = [0, 0; | ||
100, 0;]; | ||
|
||
intersections = sudoku.intersectAll(lines); | ||
|
||
testCase.verifyEqual(size(intersections, 1), 1); | ||
testCase.verifyEqual(intersections, [NaN, NaN]); | ||
end | ||
|
||
function testMapToIntersections(testCase) | ||
im = imread('+tests/Label_1.png'); | ||
map = im == 1; | ||
|
||
lines = sudoku.getMapLines(map); | ||
intersections = sudoku.intersectAll(lines); | ||
|
||
testCase.verifyEqual(size(intersections, 1), 6); | ||
testCase.verifyTrue(testCase.pointInArray(intersections, [2180, 788], 1)); | ||
testCase.verifyTrue(testCase.pointInArray(intersections, [2781, 943], 1)); | ||
testCase.verifyTrue(testCase.pointInArray(intersections, [1933, 1215], 1)); | ||
testCase.verifyTrue(testCase.pointInArray(intersections, [2596, 1412], 1)); | ||
end | ||
|
||
function testMapToSortedIntersections(testCase) | ||
im = imread('+tests/Label_1.png'); | ||
map = im == 1; | ||
expectedIntersections = [2180, 788; | ||
2781, 943; | ||
2596, 1412; | ||
1933, 1215;]; | ||
|
||
lines = sudoku.getMapLines(map); | ||
intersections = sudoku.intersectAll(lines); | ||
intersections = sudoku.selectAndSort(intersections); | ||
|
||
testCase.verifyEqual(size(intersections, 1), 4); | ||
testCase.verifyEqual(intersections, expectedIntersections, 'AbsTol', 1); | ||
end | ||
|
||
function testSelectAndSortRemoveParallel(testCase) | ||
intersections = [0, 0; | ||
100, 0; | ||
NaN, NaN; | ||
100, 100]; | ||
|
||
intersections = sudoku.selectAndSort(intersections); | ||
|
||
testCase.verifyEqual(size(intersections, 1), 3); | ||
end | ||
|
||
function testSelectAndSortRemoveDistant(testCase) | ||
intersections = [0, 0; | ||
100, 0; | ||
0, 100; | ||
100, 100; | ||
500, 1000]; | ||
|
||
intersections = sudoku.selectAndSort(intersections); | ||
|
||
testCase.verifyEqual(size(intersections, 1), 4); | ||
end | ||
end | ||
|
||
methods (Static) | ||
function tf = pointInArray(points, testPoint, tolerance) | ||
testResult = sum(sum(abs(points - testPoint) < tolerance, 2) == 2); | ||
tf = testResult == 1; | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
classdef ReadDataTest < matlab.unittest.TestCase | ||
|
||
% Copyright 2018 The MathWorks, Inc. | ||
|
||
methods (Test) | ||
function testImportOneLabel(testCase) | ||
data = sudoku.training.readNumberLabels('+tests/data1'); | ||
|
||
testCase.verifyEqual(length(data), 1); | ||
testCase.verifyEqual(size(data('0001')), [9, 9]); | ||
end | ||
|
||
function testImportTwoLabels(testCase) | ||
data = sudoku.training.readNumberLabels('+tests/data2'); | ||
|
||
testCase.verifyEqual(length(data), 2); | ||
testCase.verifyEqual(size(data('0001')), [9, 9]); | ||
testCase.verifyEqual(size(data('0002')), [9, 9]); | ||
end | ||
|
||
function testImportBadNumbers(testCase) | ||
importData = @() sudoku.training.readNumberLabels('+tests/bad_data1'); | ||
|
||
testCase.verifyError(importData, 'sudoku:BadNumberData'); | ||
end | ||
|
||
function testImportBadLabels(testCase) | ||
importData = @() sudoku.training.readNumberLabels('+tests/bad_data2'); | ||
|
||
testCase.verifyError(importData, 'sudoku:DuplicateLabel'); | ||
end | ||
|
||
function testNameParseing(testCase) | ||
testName = 'C:\one\two\0003_04.jpg'; | ||
|
||
[sudokuNumber, repeat] = sudoku.training.parseFilename(testName); | ||
|
||
testCase.verifyEqual(sudokuNumber, '0003'); | ||
testCase.verifyEqual(repeat, '04'); | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
classdef UndistortTest < matlab.unittest.TestCase | ||
|
||
% Copyright 2018 The MathWorks, Inc. | ||
|
||
methods (Test) | ||
|
||
function testUndistortion(testCase) | ||
im = imread('+tests/0001_01.jpg'); | ||
imagePoints = [2180, 788; ... | ||
2781, 943; ... | ||
1933, 1215; ... | ||
2596, 1412]; | ||
|
||
imagePoints = imagePoints + 20*(0.5 - 1*rand(4,2)); | ||
|
||
boxWidth = 32; | ||
fullWidth = 9*boxWidth; | ||
worldPoints = [0, 0; ... | ||
fullWidth, 0; ... | ||
0, fullWidth; ... | ||
fullWidth, fullWidth]; | ||
outputImage = sudoku.undistort(im, imagePoints, worldPoints); | ||
|
||
newIm = mat2cell(outputImage, ... | ||
repmat(boxWidth, 1, 9), ... | ||
repmat(boxWidth, 1, 9), ... | ||
3); | ||
montage(newIm) | ||
% TODO finish this test | ||
end | ||
|
||
end | ||
|
||
end |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
0001 | ||
-5 -3 4 6 -7 8 9 1 2 | ||
-6 7 2 -1 -9 -5 3 4 8 | ||
1 -9 -8 3 4 2 5 -6 7 | ||
-8 5 9 7 -6 4 2 -3 | ||
-4 2 6 -8 5 -3 7 9 -1 | ||
-7 1 3 9 -2 4 8 5 -6 | ||
9 -6 1 5 3 7 -2 -8 4 | ||
2 8 7 -4 -1 -9 6 3 -5 | ||
3 4 5 2 -8 6 1 -7 -9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
0001 | ||
-5 -3 4 6 -7 8 9 1 2 | ||
-6 7 2 -1 -9 -5 3 4 8 | ||
1 -9 -8 3 4 2 5 -6 7 | ||
-8 5 9 7 -6 1 4 2 -3 | ||
-4 2 6 -8 5 -3 7 9 -1 | ||
-7 1 3 9 -2 4 8 5 -6 | ||
9 -6 1 5 3 7 -2 -8 4 | ||
2 8 7 -4 -1 -9 6 3 -5 | ||
3 4 5 2 -8 6 1 -7 -9 | ||
|
||
0001 | ||
-5 -3 4 6 -7 8 9 1 2 | ||
-6 7 2 -1 -9 -5 3 4 8 | ||
1 -9 -8 3 4 2 5 -6 7 | ||
-8 5 9 7 -6 1 4 2 -3 | ||
-4 2 6 -8 5 -3 7 9 -1 | ||
-7 1 3 9 -2 4 8 5 -6 | ||
9 -6 1 5 3 7 -2 -8 4 | ||
2 8 7 -4 -1 -9 6 3 -5 | ||
3 4 5 2 -8 6 1 -7 -9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
0001 | ||
-5 -3 4 6 -7 8 9 1 2 | ||
-6 7 2 -1 -9 -5 3 4 8 | ||
1 -9 -8 3 4 2 5 -6 7 | ||
-8 5 9 7 -6 1 4 2 -3 | ||
-4 2 6 -8 5 -3 7 9 -1 | ||
-7 1 3 9 -2 4 8 5 -6 | ||
9 -6 1 5 3 7 -2 -8 4 | ||
2 8 7 -4 -1 -9 6 3 -5 | ||
3 4 5 2 -8 6 1 -7 -9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
0001 | ||
-5 -3 4 6 -7 8 9 1 2 | ||
-6 7 2 -1 -9 -5 3 4 8 | ||
1 -9 -8 3 4 2 5 -6 7 | ||
-8 5 9 7 -6 1 4 2 -3 | ||
-4 2 6 -8 5 -3 7 9 -1 | ||
-7 1 3 9 -2 4 8 5 -6 | ||
9 -6 1 5 3 7 -2 -8 4 | ||
2 8 7 -4 -1 -9 6 3 -5 | ||
3 4 5 2 -8 6 1 -7 -9 | ||
|
||
0002 | ||
-5 -3 4 6 -7 8 9 1 2 | ||
-6 7 2 -1 -9 -5 3 4 8 | ||
1 -9 -8 3 4 2 5 -6 7 | ||
-8 5 9 7 -6 1 4 2 -3 | ||
-4 2 6 -8 5 -3 7 9 -1 | ||
-7 1 3 9 -2 4 8 5 -6 | ||
9 -6 1 5 3 7 -2 -8 4 | ||
2 8 7 -4 -1 -9 6 3 -5 | ||
3 4 5 2 -8 6 1 -7 -9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
*.mp4 | ||
*.asv | ||
*.avi | ||
models/ | ||
raw_data/ | ||
number_data/ | ||
checkpoints/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Deep Sudoku Solver | ||
|
||
__Takes an uncontrolled image of a sudoku puzzle, identifies the location, reads the puzzle, and solves it.__ | ||
|
||
This example was originally put together for the [UK MATLAB Expo](https://www.matlabexpo.com/uk) 2018, for a talk entitled _Computer Vision and Image processing with MATLAB ([video](https://www.mathworks.com/videos/image-processing-and-computer-vision-with-matlab-1541003708736.html), [blog post(https://blogs.mathworks.com/deep-learning/2018/11/15/sudoku-solver-image-processing-and-deep-learning/)])_. It is intended to demonstrate the use of a combination of deep learning and image procesing to solve a computer vision problem. | ||
|
||
## Getting started | ||
|
||
- Get a copy of the code either by cloning the repository or downloading a .zip | ||
- Run the example live script getting_started.mlx | ||
|
||
## Details | ||
|
||
Broadly the algorithm is divided into four distinct steps: | ||
|
||
1. Find the sudoku puzzle in an image using deep learning (sematic segmentation) | ||
2. Extracts each of the 81 number boxes in the puzzle using image processing. | ||
3. Read the number contained in each box using deep learning. | ||
4. Solve the puzzle using opimisation. | ||
|
||
For more details see the original [Expo talk](https://www.mathworks.com/videos/image-processing-and-computer-vision-with-matlab-1541003708736.html). | ||
|
||
![](presentation/reprojected_result.jpg) | ||
|
||
## Usage | ||
|
||
- Install my navigating to the top level directory then running `install()` to add the required folders to the MATLAB path. | ||
- Run `setupData()` to fetch the required training data from my public drive. | ||
- Run `sudoku.trainSemanticSegmentation(filename)` | ||
where `filename` is the name under which the trainded network will be saved (in the `models/` folder). | ||
- Run `sudoku.trainNumberNetwork(filename)` | ||
- TODO run on a single image | ||
|
||
## Contributing | ||
|
||
Please file any bug reports or questions as [GitHub issues](https://github.com/mathworks/deep-sudoku-solver/issues). | ||
|
||
_Copyright 2018 The MathWorks, Inc._ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
function install() | ||
|
||
% Copyright 2018 The MathWorks, Inc. | ||
|
||
thisPath = fileparts(mfilename('fullpath')); | ||
addpath(fullfile(thisPath, 'src', 'sudoku')); | ||
end |
Oops, something went wrong.