In [None]:
from strikepoint.driver import LeptonDriver
from IPython.display import Image, display
from picamera2 import Picamera2
from time import sleep

import cv2
import numpy as np



In [None]:
def showFrame(frame):
    """Render an image array inline in the notebook as a PNG.

    Raises:
        RuntimeError: if OpenCV cannot encode ``frame`` as PNG.
    """
    success, encoded = cv2.imencode('.png', frame)
    if success:
        display(Image(data=encoded.tobytes()))
        return
    raise RuntimeError("Could not encode frame to PNG")


In [None]:
from strikepoint.producer import FrameProducer
from threading import Thread

# Grab one combined frame from the live thermal + visual producer and show it.
with Picamera2() as picam, LeptonDriver() as driver:
    driver.setLogFile('logsFromDevNotebook.log')
    driver.startPolling()
    # TODO: Use Picamera2 to set up the camera properly before handing it to
    # the producer (the camera object above is opened but not configured here).
    producer = FrameProducer(driver, picam, fps=9, depth=4)
    try:
        image = producer.getFrame()
        showFrame(image)
        # NOTE(review): lastThermalRaw / lastThermalDiff / lastVisualRaw are
        # not synchronized with getFrame(), so showing them here may display
        # a different frame than `image`.
        # showFrame(producer.lastThermalRaw)
        # showFrame(producer.lastThermalDiff)
        # showFrame(producer.lastVisualRaw)
    finally:
        # Always stop the producer, even if getFrame/showFrame raise, so it is
        # not left running after the driver and camera contexts close.
        producer.stop()

In [None]:
from strikepoint.frames import FrameReader


def detect_circles(frame, *,
                   showTarget=None, dp=1.2, minDist=30,
                   param1=100, param2=30,
                   minRadius=5, maxRadius=120):
    """
    Detect circles with HoughCircles. Accepts color (BGR), gray or float frames.

    Non-uint8 input is min-max scaled to 0..255 (HoughCircles requires 8-bit
    input), colour input is converted to grayscale, and the result is
    median-blurred with a 5x5 kernel before detection.

    Args:
        frame: image array; 3-channel BGR, single-channel, or any numeric dtype.
        showTarget: when equal to the number of circles found, every detected
            circle is printed.
        dp, minDist, param1, param2, minRadius, maxRadius: forwarded to
            ``cv2.HoughCircles``.

    Returns:
        A list with one length-3 integer array ``[x, y, r]`` per detected
        circle; an empty list when no circle is found.
    """
    img = np.asarray(frame)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.dtype != np.uint8:
        # Min-max scale any non-8-bit input into 0..255.
        img = img.astype(np.float32)
        lo, hi = float(img.min()), float(img.max())
        scale = 255.0 / (hi - lo) if hi > lo else 0.0
        img = np.clip((img - lo) * scale, 0, 255).astype(np.uint8)
    gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blurred = cv2.medianBlur(gray, 5)
    circles = cv2.HoughCircles(blurred,
                               cv2.HOUGH_GRADIENT,
                               dp, minDist,
                               param1=param1, param2=param2,
                               minRadius=minRadius, maxRadius=maxRadius)

    if circles is None:
        return list()

    # HoughCircles returns shape (1, N, 3) floats; round to pixel integers.
    circles = np.round(circles[0]).astype(int)
    if showTarget is not None and showTarget == len(circles):
        # Report every circle, not just the last one iterated.
        for (x, y, r) in circles:
            print(f"found circle at ({x}, {y}), r={r}")
    return list(circles)


def findTargetCircleCount(frame, targetCount, *,
                          showTarget=None, dp=1.2, minDist=30,
                          param1=100, param2=30,
                          minRadius=10, maxRadius=50):
    """
    Detect circles with HoughCircles and keep the ``targetCount`` brightest.

    The frame is reduced to an 8-bit grayscale image (non-uint8 input is
    min-max scaled to 0..255) and median-blurred with a 5x5 kernel. Each
    detected circle is scored by the mean intensity of that blurred image
    inside the circle, and circles are ranked brightest first. Every
    candidate's score is printed. The winners are drawn (green outline, red
    centre) on the blurred image and displayed with ``showFrame``.

    Args:
        frame: image array; 3-channel BGR, single-channel, or any numeric dtype.
        targetCount: maximum number of circles to return.
        showTarget: accepted for signature parity with ``detect_circles``;
            currently unused (the result is always displayed).
        dp, minDist, param1, param2, minRadius, maxRadius: forwarded to
            ``cv2.HoughCircles``.

    Returns:
        A list of up to ``targetCount`` ``(x, y, r)`` integer tuples,
        brightest first; an empty list when no circle is found.
    """
    img = np.asarray(frame)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    if img.dtype != np.uint8:
        # HoughCircles needs 8-bit input: min-max scale anything else.
        img = img.astype(np.float32)
        lo, hi = float(img.min()), float(img.max())
        scale = 255.0 / (hi - lo) if hi > lo else 0.0
        img = np.clip((img - lo) * scale, 0, 255).astype(np.uint8)
    gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blurred = cv2.medianBlur(gray, 5)
    circles = cv2.HoughCircles(blurred, cv2.HOUGH_GRADIENT, dp, minDist,
                               param1=param1, param2=param2,
                               minRadius=minRadius, maxRadius=maxRadius)

    if circles is None:
        return list()

    scored = list()
    for (x, y, r) in np.round(circles[0]).astype(int):
        # Filled disc mask so cv2.mean only averages pixels inside the circle.
        mask = np.zeros(blurred.shape, dtype=np.uint8)
        cv2.circle(mask, (x, y), r, 255, -1)
        meanVal = cv2.mean(blurred, mask=mask)[0]
        print(f"Circle at ({x}, {y}), r={r} has mean intensity {meanVal}")
        scored.append((meanVal, (x, y, r)))

    scored.sort(key=lambda t: t[0], reverse=True)
    best = scored[:targetCount]

    vis = cv2.cvtColor(blurred, cv2.COLOR_GRAY2BGR)
    for _, (x, y, r) in best:
        cv2.circle(vis, (x, y), r, (0, 255, 0), 2)
        cv2.circle(vis, (x, y), 2, (0, 0, 255), 3)
    showFrame(vis)
    return [circle for _, circle in best]


RECORDING_PATH = "demo-three-balls.bin"
MAX_FRAMES = 25  # number of frames from the recording to analyse

with FrameReader(RECORDING_PATH) as reader:

    count = 0
    while (frame := reader.readFrame()) is not None:
        count += 1
        # Stop before doing any work on frames past the limit.
        if count > MAX_FRAMES:
            break
        # Each recorded frame is split into three equal-width panels; the
        # first two are used as the visual and thermal views, the third is
        # ignored here.
        visFrame, thermFrame, _ = np.hsplit(frame, 3)
        try:
            visCircles = findTargetCircleCount(visFrame, 3)
            thermCircles = findTargetCircleCount(thermFrame, 3)
        except RuntimeError as ex:
            # Best effort: skip the frame, but report why instead of
            # silently dropping it.
            print(f"Frame {count}: skipped ({ex})")
            continue

        # Disabled: register the thermal view onto the visual view using the
        # three detected ball centres (requires exactly three per view).
        # if len(visCircles) == 3 and len(thermCircles) == 3:
        #     thermCircles.sort(key=lambda c: (c[0], c[1]))
        #     ptsSrc = np.array(
        #         [(p[0], p[1]) for p in thermCircles], dtype=np.float32)
        #     visCircles.sort(key=lambda c: (c[0], c[1]))
        #     ptsDst = np.array(
        #         [(p[0], p[1]) for p in visCircles], dtype=np.float32)
        #     M = cv2.getAffineTransform(ptsSrc, ptsDst)
        #
        #     h, w = thermFrame.shape[:2]
        #     warped = cv2.warpAffine(
        #         thermFrame, M, (w, h), flags=cv2.INTER_LINEAR)
        #     showFrame(warped)
        #
        #     print(f"Frame {count}: Detected circles at:")
        #     for j in range(3):
        #         x1, y1, r1 = thermCircles[j]
        #         x2, y2, r2 = visCircles[j]
        #         print(f"  Ball {j}: Thermal=({x1},{y1},{r1}) Visual=({x2},{y2},{r2})")