Browse files

starting to state-machine-ify flow structure, split up data structure…

…s into separate class files, introduced state machine and state classes
  • Loading branch information...
1 parent 7ef0e99 commit b4ce29f8fdeb8686ea3f9e1404d22cf0fe828308 @cathywu committed Jan 27, 2012
Showing with 114 additions and 26 deletions.
  1. +39 −0 RingBuffer.py
  2. +20 −0 Video.py
  3. +41 −0 VideoStateMachine.py
  4. +13 −4 lkdemo.py
  5. +1 −22 main.py
View
39 RingBuffer.py
@@ -0,0 +1,39 @@
+import numpy as np
+
+class RingBuffer:
+ def __init__(self,size_max=5,default=0,data=[]):
+ self.max = size_max
+ self.default = default
+ self.data = list(data)
+ self.index = len(data) % self.max
+ self.count = len(data) # number of non-default values in RingBuffer
+ self.data.extend([default for i in range(self.max-len(self.data))])
+
+ def write(self,x):
+ "write value into ring buffer"
+ self.data[self.index] = x
+ if self.index + 1 > self.count:
+ self.count = self.index + 1
+ self.index = (self.index + 1) % self.max
+ def empty(self):
+ self.data = [default for i in self.data]
+ self.count = 0
+
+ def get(self):
+ "returns a list of elements from oldest to newest"
+ return self.data[self.index:] + self.data[:self.index]
+ def get_prev(self,i):
+ return self.data[(self.index - i) % self.max]
+ def get_last(self):
+ return self.get_prev(1)
+
+ def set_prev(self,i,value):
+ self.data[(self.index - i) % self.max] = value
+ def set_last(self,value):
+ self.set_prev(1,value)
+
+ # FIXME below functions only work for integer values right now, not objects
+ def average(self):
+ return np.average(self.data)
+ def meanmagnitude(self):
+ return np.sqrt(np.dot(self.data,self.data))/self.max
View
20 Video.py
@@ -0,0 +1,20 @@
+import matplotlib.pyplot as plt
+from SimpleCV import VirtualCamera
+
+class Video:
+ def __init__(self,path):
+ self.capture = VirtualCamera(path,"video")
+ self.im = None
+ def step(self,stepsize=1,scale=0.50):
+ for i in range(stepsize-1):
+ self.capture.getImage()
+ self.im = self.capture.getImage().copy().scale(scale)
+ return self.get_image()
+ def show(self):
+ plt.figure()
+ plt.show()
+ plt.imshow(self.im,cmap="gray")
+ def get_image(self):
+ return self.im
+ def save(self):
+ pass
View
41 VideoStateMachine.py
@@ -0,0 +1,41 @@
+from RingBuffer import RingBuffer
+
+class VideoStateMachine:
+ def __init__(self,state):
+ self.state = state
+ def next_state(self,features,frame):
+ self.state = self.state.transition(features,frame) or self.state
+ return self.current_state()
+ def current_state(self):
+ return self.state
+
+class VideoState:
+ def __init__(self,features_history=None,
+ frame_history=None):
+ self.features_history = features_history or RingBuffer(size_max=10,default=None)
+ self.frame_history = frame_history or RingBuffer(size_max=10,default=None)
+ self.flow_history = RingBuffer(size_max=10,default=None)
+ def transition(self,features,frame):
+ self.features_history.write(features)
+ self.frame_history.write(frame)
+ if self.features_history.count > 1:
+ diff = [(a-c,b-d) for ((a,b),(c,d)) in zip(self.features_history.get_prev(2),self.features_history.get_last())]
+ x,y = zip(*diff)
+ print "Current (%s,%s)" % (sum(x)/len(x),sum(y)/len(y))
+ #recent_flowX.write(sum(x)/len(x))
+ #recent_flowY.write(sum(y)/len(y))
+ #print "Recent (%s,%s)" % (recent_flowX.meanmagnitude(),recent_flowY.meanmagnitude())
+ return None
+ def get_output_frame(self):
+ return self.frame_history.get_last()
+
+#class Shift10pxDownState(VideoState):
+
+class NoiseState(VideoState):
+ def __init__(self,flow_history=None,feature_history=None,
+ frame_history=None):
+ super(NoiseState,self).__init__(flow_history=flow_history,
+ feature_history=feature_history,frame_history=frame_history)
+ def transition(self,features,frame):
+ super(NoiseState, self).transition(features,frame)
+
View
17 lkdemo.py
@@ -6,6 +6,9 @@
# import the necessary things for OpenCV and video reading
from main import Video, DATA_PATH
+from VideoStateMachine import VideoStateMachine, VideoState
+from RingBuffer import RingBuffer
+from Video import Video
from SimpleCV import cv
#############################################################################
@@ -69,14 +72,15 @@ def on_mouse (event, x, y, flags, param):
"To add/remove a feature point click it\n"
# first, create the necessary windows
- cv.NamedWindow ('LkDemo', cv.CV_WINDOW_AUTOSIZE)
+ cv.NamedWindow ('Video Stabilization', cv.CV_WINDOW_AUTOSIZE)
# register the mouse callback
- cv.SetMouseCallback ('LkDemo', on_mouse, None)
+ cv.SetMouseCallback ('Video Stabilization', on_mouse, None)
frame = video.step()
imsize = frame.size()
frame = frame.getBitmap()
+ state_machine = VideoStateMachine(VideoState())
while 1:
# do forever
@@ -128,15 +132,20 @@ def on_mouse (event, x, y, flags, param):
# we have points, so display them
# calculate the optical flow
- features, status, track_error = cv.CalcOpticalFlowPyrLK (
+ new_features, status, track_error = cv.CalcOpticalFlowPyrLK (
prev_grey, grey, prev_pyramid, pyramid,
features,
(win_size, win_size), 3,
(cv.CV_TERMCRIT_ITER|cv.CV_TERMCRIT_EPS, 20, 0.03),
flags)
+ agg_err = sum([e for e in track_error])
+ if agg_err > 1e-03:
+ print "Agg_err: %s" % agg_err
+ state = state_machine.next_state(features,image)
+ image = state.get_output_frame()
# set back the points we keep
- features = [ p for (st,p) in zip(status, features) if st]
+ features = [ p for (st,p) in zip(status, new_features) if st ]
if add_remove_pt:
# we have a point to add, so see if it is close to
View
23 main.py
@@ -8,32 +8,11 @@
import lk
import numpy as np
import time
+from Video import Video
#DATA_PATH = "/home/cathywu/Dropbox/UROP/wearable/data/exp001/compressed/iphone4s-1920r_30f_all_auto.avi"
DATA_PATH = "/home/cathywu/Dropbox/UROP/wearable/data/exp001/compressed/firefly_fw_640r_60f_320s.avi"
-lk_params = dict( winSize = (10, 10),
- maxLevel = 3,
- criteria = (cv.CV_TERMCRIT_EPS | cv.CV_TERMCRIT_ITER, 20, 0.03),
- derivLambda = 0.0 )
-class Video:
- def __init__(self,path):
- self.capture = scv.VirtualCamera(path,"video")
- self.im = None
- def step(self,stepsize=1,scale=0.50):
- for i in range(stepsize-1):
- self.capture.getImage()
- self.im = self.capture.getImage().copy().scale(scale)
- return self.get_image()
- def show(self):
- plt.figure()
- plt.show()
- plt.imshow(self.im,cmap="gray")
- def get_image(self):
- return self.im
- def save(self):
- pass
-
class Flow:
def __init__(self,im1,im2,win=10):
self.im1 = im1

0 comments on commit b4ce29f

Please sign in to comment.