In [17]:
import cv2
import numpy as np
from skimage.registration import phase_cross_correlation
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

PATH = r"/home/moritz/Nextcloud/Uni/Physik Master/Semester 2/IMIP/fringe_detection"

# Avoid Qt conflicts with OpenCV
matplotlib.use('Agg')  

def crop_frames(frames, crop_region):
    (xmin, xmax), (ymin, ymax) = crop_region
    return frames[:, ymin:ymax, xmin:xmax]

def subtract_min(frames):
    """Subtract minimum frame and ensure non-negative values."""
    # Calculate minimum in float64 for precision
    min_frame = np.min(frames, axis=0)
    
    # Save min frame visualization
    plt.imshow(min_frame.squeeze(), cmap='gray')
    plt.colorbar()
    plt.title('Min Frame')
    plt.savefig("min_frame.png", dpi=150, bbox_inches='tight')
    plt.close()
    
    # Subtract min and handle negative values efficiently
    frames_float = frames.astype(np.float64) - min_frame

    max_val = np.amax(frames_float)
    frames_float *= 255.0 / max_val 

    return frames_float

def preprocess_video(input_path):
    cap = cv2.VideoCapture(input_path)
    
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    ret, first_frame = cap.read()
    first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
    
    height, width = first_frame.shape
    
    frames = np.empty((total_frames, height, width), dtype=np.uint8)
    frames[0] = first_frame
    
    frame_idx = 1
    with tqdm(total=total_frames-1, desc="Loading frames") as pbar:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            # Convert to grayscale immediately
            if len(frame.shape) == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            
            frames[frame_idx] = frame
            frame_idx += 1
            pbar.update(1)
    
    cap.release()
    
    frames = frames[..., np.newaxis]

    frames = subtract_min(frames)
    
    return frames, fps

def video_from_frames(frames, output_path, fps=30):
    height, width = frames.shape[1:3]
    frames_bgr = np.repeat(frames, 3, axis=3)  # Convert to BGR by repeating the single channel
    print(frames_bgr.shape)

    
    fourcc = cv2.VideoWriter.fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, int(fps), (width, height))
    
    with tqdm(total=len(frames_bgr), desc="Writing video") as pbar:
        for frame in frames_bgr:
            out.write(frame.astype(np.uint8))
            pbar.update(1)
    
    out.release()


def get_phase_cross_correlation_shifts(frames, upsample_factor=100):
    if len(frames.shape) == 4 and frames.shape[3] == 1:
        # Remove channel dimension for processing
        frames_2d = frames.squeeze(axis=3)
    else:
        frames_2d = frames
    
    relative_shifts = [np.array([0.0, 0.0])]
    absolute_shifts = [np.array([0.0, 0.0])]
    
    # Convert frames to float64 once for better precision
    frames_float = frames_2d.astype(np.float64)
    
    # Process consecutive frames with progress bar
    with tqdm(total=len(frames_float)-1, desc="Computing shifts") as pbar:
        for i in range(len(frames_float)-1):
            reference = frames_float[i]
            current = frames_float[i+1]
            
            # Calculate phase correlation
            shift, _, _ = phase_cross_correlation(
                reference, current, 
                upsample_factor=upsample_factor
            )
            
            relative_shifts.append(shift)
            absolute_shifts.append(absolute_shifts[-1] + shift)
            pbar.update(1)
    
    return np.array(absolute_shifts), np.array(relative_shifts)


def make_pretty_plot(total_shifts):
    plt.figure(figsize=(10, 5))
    plt.xlim(0, len(total_shifts))
    # custom x scale
    plt.autoscale(axis='y')
    plt.plot(total_shifts)
    # vlines at 400 and 800
    line_1 = 467
    line_2 = 863
    plt.axvline(x=line_1, color='g', linestyle='--', label='Pump On')
    plt.axvline(x=line_2, color='g', linestyle='--', label='Pump Off')
    # x ticks in seconds (30fps)
    num_labels = 10
    fps = 30
    max_seconds = len(total_shifts) / fps
    x_ticks = np.linspace(0, max_seconds, num_labels)
    x_labels = [f"{int(tick)}" for tick in x_ticks]
    plt.xticks(np.arange(0, len(total_shifts), step=int(len(total_shifts) / num_labels)), x_labels)
    # get y ticks ans set labels to be const*tick
    maxima_per_second = 12
    wavelength = 600e-9  # 600 nm
    const = maxima_per_second * wavelength / 2 * 1e6  # convert to micrometers
    y_ticks = plt.yticks()[0]
    ticks = [int(tick)*const for tick in y_ticks]
    y_labels = [f"{tick:.1f}" for tick in ticks]
    plt.yticks(y_ticks, y_labels)
    # annotate
    plt.annotate('Pump Off', xy=(line_1-300, max(total_shifts) * 0.9), color='red', fontsize=12)
    plt.annotate('Pump On', xy=(line_1+135, max(total_shifts) * 0.9), color='green', fontsize=12)
    plt.annotate('Pump Off', xy=(line_2+135, max(total_shifts) * 0.9), color='red', fontsize=12)
    # background color in first 400 frames ligh red, the light green, then light red again
    plt.axvspan(0, line_1, color='lightcoral', alpha=0.3)
    plt.axvspan(line_1, line_2, color='lightgreen', alpha=0.3)
    plt.axvspan(line_2, len(total_shifts), color='lightcoral', alpha=0.3) 
    # labeling and stuff
    plt.xlabel("Time (s)")
    plt.ylabel("Total Shift (um)")
    plt.title("Total Phase Correlation Shift Over Time")
    plt.grid()
    plt.tight_layout()
    #show / save
    plt.savefig("pretty_plot.png", dpi=150, bbox_inches='tight')
    plt.show()
    plt.close()

try:
    make_pretty_plot(total_shifts)
except NameError:
    pass

  plt.show()


In [4]:
vid = r"data/1ul_pump_on_off_2.mp4"
vid = os.path.join(PATH, vid)
out = r"data/1ul_pump_on_off_2_cropped.mp4"
out = os.path.join(PATH, out)
crop_region = None
crop_regions = [((0, 400), (0, 100)),
                ]

frames, fps = preprocess_video(vid)
frames = frames[:1250]  # Limit to first 1250 frames as ROI

video_from_frames(frames, out, fps=int(fps))

Loading frames: 100%|██████████| 1370/1370 [00:01<00:00, 1123.22it/s]


(1250, 480, 640, 3)


Writing video: 100%|██████████| 1250/1250 [00:14<00:00, 86.18it/s] 


In [5]:
frames_cropped = [crop_frames(frames, crop_region) for crop_region in crop_regions]
shifts = []
for cropped_frames in frames_cropped:
    absolute_shifts, relative_shifts = get_phase_cross_correlation_shifts(cropped_frames)
    shifts.append(absolute_shifts)

# calculate total shift (pythagoras)
total_shifts = [np.linalg.norm(shift, axis=1) for shift in shifts]
total_shifts = np.sum(total_shifts, axis=0)

Computing shifts: 100%|██████████| 1249/1249 [00:09<00:00, 128.07it/s]


In [10]:
make_pretty_plot(total_shifts)

  plt.show()
