Permalink
144 lines (120 sloc) 3.35 KB
// What it does:
//
// This example uses the Tensorflow (https://www.tensorflow.org/) deep learning framework
// to classify whatever is in front of the camera.
//
// Download the Tensorflow "Inception" model and descriptions file from:
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
//
// Extract the tensorflow_inception_graph.pb model file from the .zip file.
//
// Also extract the imagenet_comp_graph_label_strings.txt file with the descriptions.
//
// How to run (the trailing backend and target arguments are optional):
//
// go run ./cmd/tf-classifier/main.go 0 ~/Downloads/tensorflow_inception_graph.pb ~/Downloads/imagenet_comp_graph_label_strings.txt opencv cpu
//
// +build example
package main
import (
"bufio"
"fmt"
"image"
"image/color"
"os"
"gocv.io/x/gocv"
)
// main runs the classifier demo: it reads the command-line arguments, loads
// the label descriptions and the network model, then classifies every frame
// read from the capture device and draws the most probable label on the frame.
//
// Command line: camera ID, model file, descriptions file, then optionally a
// backend name and a target name. The loop ends when a key is pressed in the
// window or the device stops returning frames.
func main() {
	if len(os.Args) < 4 {
		fmt.Println("How to run:\ntf-classifier [camera ID] [modelfile] [descriptionsfile] ([backend] [target])")
		return
	}

	// parse args
	deviceID := os.Args[1]
	model := os.Args[2]
	descr := os.Args[3]

	descriptions, err := readDescriptions(descr)
	if err != nil {
		// Report the underlying cause as well as the path.
		fmt.Printf("Error reading descriptions file: %v: %v\n", descr, err)
		return
	}

	// Optional 4th and 5th arguments select the DNN backend and target.
	backend := gocv.NetBackendDefault
	if len(os.Args) > 4 {
		backend = gocv.ParseNetBackend(os.Args[4])
	}

	target := gocv.NetTargetCPU
	if len(os.Args) > 5 {
		target = gocv.ParseNetTarget(os.Args[5])
	}

	// open capture device
	webcam, err := gocv.OpenVideoCapture(deviceID)
	if err != nil {
		fmt.Printf("Error opening video capture device: %v: %v\n", deviceID, err)
		return
	}
	defer webcam.Close()

	window := gocv.NewWindow("Tensorflow Classifier")
	defer window.Close()

	img := gocv.NewMat()
	defer img.Close()

	// open DNN classifier
	net := gocv.ReadNet(model, "")
	if net.Empty() {
		fmt.Printf("Error reading network model : %v\n", model)
		return
	}
	defer net.Close()
	net.SetPreferableBackend(gocv.NetBackendType(backend))
	net.SetPreferableTarget(gocv.NetTargetType(target))

	statusColor := color.RGBA{0, 255, 0, 0}
	fmt.Printf("Start reading device: %v\n", deviceID)

	for {
		if ok := webcam.Read(&img); !ok {
			fmt.Printf("Device closed: %v\n", deviceID)
			return
		}
		if img.Empty() {
			continue
		}

		// convert image Mat to 224x224 blob that the classifier can analyze
		blob := gocv.BlobFromImage(img, 1.0, image.Pt(224, 224), gocv.NewScalar(0, 0, 0, 0), true, false)

		// feed the blob into the classifier
		net.SetInput(blob, "input")

		// run a forward pass thru the network
		prob := net.Forward("softmax2")

		// flatten the results into a single-row, single-channel matrix
		probMat := prob.Reshape(1, 1)

		// determine the most probable classification
		_, maxVal, _, maxLoc := gocv.MinMaxLoc(probMat)

		// The per-frame Mats are no longer needed once the maximum has been
		// copied out into Go values, so release them right away.
		probMat.Close()
		prob.Close()
		blob.Close()

		// Look the label up against the number of descriptions actually
		// loaded, so a short or empty descriptions file cannot cause an
		// out-of-range panic.
		desc := "Unknown"
		if idx := maxLoc.X; idx >= 0 && idx < len(descriptions) {
			desc = descriptions[idx]
		}

		// No trailing newline: the text is drawn onto the frame, not printed.
		status := fmt.Sprintf("description: %v, maxVal: %v", desc, maxVal)
		gocv.PutText(&img, status, image.Pt(10, 20), gocv.FontHersheyPlain, 1.2, statusColor, 2)

		window.IMShow(img)
		if window.WaitKey(1) >= 0 {
			break
		}
	}
}
// readDescriptions opens the file at path and returns its contents as a
// slice holding one entry per line, in file order.
//
// If the file cannot be opened, the slice is nil and the open error is
// returned. If scanning stops early, the lines collected so far are returned
// together with the scanner's error.
func readDescriptions(path string) ([]string, error) {
	src, openErr := os.Open(path)
	if openErr != nil {
		return nil, openErr
	}
	defer src.Close()

	var labels []string
	lineReader := bufio.NewScanner(src)
	for {
		if !lineReader.Scan() {
			break
		}
		labels = append(labels, lineReader.Text())
	}

	return labels, lineReader.Err()
}