In [None]:
# 1.canny边缘检测  2.mask   3.霍夫变换   4.离群值过滤    5.最小二乘拟合     6.绘制直线
import math
from re import L
from turtle import right
import cv2
import numpy as np
import matplotlib.pyplot as plt
from regex import T
from skimage import morphology
from sympy import true


def edge(frame):
    """Return the Canny edge map of *frame*.

    A 3-dimensional image (assumed BGR colour) is first converted to
    grayscale; any other input is used unchanged. The grayscale image is
    smoothed with a 5x5 Gaussian kernel and passed to Canny with hysteresis
    thresholds 50 and 150.
    """
    is_colour = len(frame.shape) == 3
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) if is_colour else frame
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)  # suppress noise before edge detection
    return cv2.Canny(smoothed, 50, 150)

def hsv(frame):
    """Return a binary mask of yellow and white pixels in a BGR *frame*.

    The frame is blurred with a 5x5 Gaussian and converted to HSV. Pixels inside
    either colour range are set to 255:

    * yellow: H 26-34, S 43-255, V 46-255
    * white:  H 0-180, S 0-30,  V 200-255

    The mask is then cleaned morphologically: erosion with a 5x5 square removes
    small specks, dilation with an 11x11 square fills gaps, and a final 5x5
    erosion shrinks the result back.
    """
    blurred = cv2.GaussianBlur(frame, (5, 5), 0)
    hsv_img = cv2.cvtColor(blurred, cv2.COLOR_BGR2HSV)
    lower_yellow = np.array([26, 43, 46])
    upper_yellow = np.array([34, 255, 255])
    lower_white = np.array([0, 0, 200])
    upper_white = np.array([180, 30, 255])

    yellow_mask = cv2.inRange(hsv_img, lower_yellow, upper_yellow)
    white_mask = cv2.inRange(hsv_img, lower_white, upper_white)
    # Combine with OR: adding two uint8 masks wraps 255 + 255 to 254 wherever
    # the ranges overlap.
    mask = cv2.bitwise_or(yellow_mask, white_mask)

    mask = morphology.erosion(mask, morphology.square(width=5))
    mask = morphology.dilation(mask, morphology.square(11))
    mask = morphology.erosion(mask, morphology.square(5))
    return mask


def calculate_slope(line):
    """Return the slope dy/dx of a line segment.

    Two layouts are accepted:

    * ``cv2.HoughLinesP`` style, where ``line[0]`` is ``(x1, y1, x2, y2)``;
    * a pair of points ``[[x1, y1], [x2, y2]]`` (as produced by
      ``least_squares_fit``).

    A vertical segment (dx == 0) yields 0 rather than an infinite slope.
    """
    first = line[0]
    if len(first) == 4:
        x_start, y_start, x_end, y_end = first
    else:
        x_start, y_start = first[0], first[1]
        x_end, y_end = line[1][0], line[1][1]
    dx = x_end - x_start
    dy = y_end - y_start
    return 0 if dx == 0 else dy / dx


def reject_abnormal_lines(lines, threshold, max_x_deviation=200):
    """Remove outlier segments from *lines* in place and return the same list.

    Two passes are applied:

    1. Position: every segment whose start x (``line[0][0]``) lies more than
       *max_x_deviation* pixels from the mean start x is dropped.
    2. Slope: the segment whose slope is farthest from the current mean slope
       is dropped repeatedly, until that distance is at most *threshold*.

    Args:
        lines: ``list`` of segments in ``cv2.HoughLinesP`` layout (each
            ``line[0]`` is ``(x1, y1, x2, y2)``). It is modified in place.
        threshold: maximum allowed ``|slope - mean slope|``.
        max_x_deviation: maximum allowed ``|x1 - mean x1|`` in pixels. The
            default of 200 is the value that used to be hard-coded.

    Returns:
        The same list object after filtering. Callers may rely on either the
        in-place mutation or the return value.
    """
    if len(lines) == 0:
        return lines

    x_starts = np.array([line[0][0] for line in lines])
    x_offsets = np.abs(x_starts - np.mean(x_starts))
    # Slice assignment keeps the caller's list object (in-place semantics).
    lines[:] = [
        line for line, offset in zip(lines, x_offsets) if offset <= max_x_deviation
    ]

    slopes = [calculate_slope(line) for line in lines]
    while lines:
        mean = np.mean(slopes)
        deviations = [abs(s - mean) for s in slopes]
        worst = int(np.argmax(deviations))
        if deviations[worst] <= threshold:
            break
        slopes.pop(worst)
        lines.pop(worst)
    return lines


def least_squares_fit(lines):
    """Fit one straight line through every endpoint of *lines*.

    Args:
        lines: non-empty sequence of segments in ``cv2.HoughLinesP`` layout
            (each ``line[0]`` is ``(x1, y1, x2, y2)``).

    Returns:
        ``np.int64`` array ``[[x_min, y(x_min)], [x_max, y(x_max)]]``. This is
        the least-squares line y = a*x + b evaluated at the smallest and
        largest endpoint x, rounded to the nearest pixel. If all endpoints
        share the same x, y = f(x) cannot be fitted, so the vertical segment
        ``[[x, y_min], [x, y_max]]`` is returned instead.
    """
    x_coords = np.ravel([[line[0][0], line[0][2]] for line in lines])
    y_coords = np.ravel([[line[0][1], line[0][3]] for line in lines])

    x_min, x_max = np.min(x_coords), np.max(x_coords)
    if x_min == x_max:
        # A degree-1 polyfit is rank-deficient here and its coefficients are meaningless.
        points = [(x_min, np.min(y_coords)), (x_max, np.max(y_coords))]
    else:
        poly = np.polyfit(x_coords, y_coords, deg=1)
        points = [(x_min, np.polyval(poly, x_min)), (x_max, np.polyval(poly, x_max))]
    # Round before casting: a plain int cast truncates (19.9999 would become 19).
    return np.rint(np.array(points, dtype=np.float64)).astype(np.int64)


def is_vertical(left_lines, right_lines):
    """Predicate on the two fitted lane lines; it does not remove anything.

    With k_l = slope(left_lines) and k_r = slope(right_lines):

    * if k_l * k_r > 0 (strictly the same sign), return False;
    * otherwise compute delta = pi + atan(k_r) - atan(k_l) and return True
      exactly when delta > 2*pi/5 (72 degrees).

    Both arguments are accepted in any layout understood by
    ``calculate_slope``. NOTE(review): despite the name, this tests the angle
    between the two lines rather than whether either one is vertical.
    """
    slope_left = calculate_slope(left_lines)
    slope_right = calculate_slope(right_lines)
    if slope_left * slope_right > 0:
        return False
    opening = math.pi + math.atan(slope_right) - math.atan(slope_left)
    return opening > math.pi * 2 / 5


def get_line(frame):
    """Detect the left and right lane lines in *frame* and draw them on a copy.

    Pipeline:

    1. Build a colour mask (``hsv``) and a Canny edge map (``edge``).
    2. Restrict the edges to a fixed region of interest.
    3. Paint large lane-coloured edge components red.
    4. Run a probabilistic Hough transform and split segments into left/right.
    5. Reject outliers and fit one line per side by least squares.
    6. Apply sanity checks, then draw the result.

    Args:
        frame: BGR image (as returned by ``cv2.imread`` or
            ``VideoCapture.read``). The ROI polygon (x 0-950, y 250-540) and the
            fallback left/right split column (x = 480) are hard-coded. They
            appear tuned for roughly 960x540 input; TODO confirm before using
            other resolutions.

    Returns:
        A copy of *frame* with lane-coloured edge pixels painted red and each
        detected lane line drawn in blue (BGR ``(255, 0, 0)``, 5 px). If Hough
        finds no segments, only the red overlay is applied. *frame* itself is
        not modified.
    """
    # Work on a copy so the overlay is not written into the caller's array.
    img0 = frame.copy()
    hsv_mask = hsv(frame)
    edge_mask = edge(frame)

    # ROI mask. fillPoly expects int32 (CV_32S) points, while numpy's default
    # integer is int64 on most 64-bit Linux/macOS builds.
    roi_polygon = np.array(
        [[[0, 250], [950, 250], [950, 540], [0, 540]]], dtype=np.int32
    )
    roi_mask = cv2.fillPoly(np.zeros_like(edge_mask), roi_polygon, color=255)
    roi_masked_edge_img = cv2.bitwise_and(edge_mask, roi_mask)

    # Area filter (overlay only): keep 8-connected edge components larger than
    # 200 px. A per-label lookup table replaces one full-image pass per label.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        roi_masked_edge_img, connectivity=8
    )
    keep = stats[:, cv2.CC_STAT_AREA] > 200
    keep[0] = False  # label 0 is the background
    filtered = np.zeros_like(roi_masked_edge_img)
    filtered[keep[labels]] = 255
    hsvedge = cv2.bitwise_and(filtered, hsv_mask)
    hsvedge = morphology.dilation(hsvedge, morphology.square(5))
    img0[hsvedge > 0] = [0, 0, 255]

    # Hough input: ROI edges (not area-filtered), dilated 3x3, restricted to lane colours.
    dia_masked_edge_img = morphology.dilation(roi_masked_edge_img, morphology.square(3))
    hsv_masked_edge_img = cv2.bitwise_and(hsv_mask, dia_masked_edge_img)

    lines = cv2.HoughLinesP(
        hsv_masked_edge_img, 1, np.pi / 180, 15, minLineLength=50, maxLineGap=30
    )
    if lines is None:
        return img0

    # Split by slope sign. If either side ends up empty, fall back to
    # splitting by position around column 480.
    left_lines = [line for line in lines if calculate_slope(line) >= 0]
    right_lines = [line for line in lines if calculate_slope(line) < 0]
    if not left_lines or not right_lines:
        right_lines = [line for line in lines if line[0][0] > 480]
        left_lines = [line for line in lines if line[0][2] < 480]

    left_lines = reject_abnormal_lines(left_lines, threshold=0.1)
    right_lines = reject_abnormal_lines(right_lines, threshold=0.1)

    # Zero-length placeholder for a side with no detection (slope 0).
    placeholder = [[0, 0], [0, 0]]
    lzero = len(left_lines) == 0
    rzero = len(right_lines) == 0
    left_lines = placeholder if lzero else least_squares_fit(left_lines)
    right_lines = placeholder if rzero else least_squares_fit(right_lines)

    # Two fits whose slopes share a strict sign are not treated as a left/right
    # pair; the left one is discarded.
    if calculate_slope(left_lines) * calculate_slope(right_lines) > 0:
        left_lines, lzero = placeholder, True

    # When is_vertical rejects the pair, keep only the longer line.
    if not (lzero or rzero) and is_vertical(left_lines, right_lines):
        left_length = math.hypot(
            left_lines[0][0] - left_lines[1][0], left_lines[0][1] - left_lines[1][1]
        )
        right_length = math.hypot(
            right_lines[0][0] - right_lines[1][0], right_lines[0][1] - right_lines[1][1]
        )
        if left_length < right_length:
            left_lines, lzero = placeholder, True
        else:
            right_lines, rzero = placeholder, True

    # Draw only real detections. Drawing the placeholder with thickness 5
    # leaves a blue dot in the top-left corner.
    for points, missing in ((left_lines, lzero), (right_lines, rzero)):
        if not missing:
            start = (int(points[0][0]), int(points[0][1]))
            end = (int(points[1][0]), int(points[1][1]))
            cv2.line(img0, start, end, color=(255, 0, 0), thickness=5)
    return img0


if __name__ == "__main__":

    def frame(a):
        """Run the lane detector on ``frame/<a>.jpg`` and display the result."""
        image = cv2.imread(f'frame/{a}.jpg')
        if image is None:  # imread returns None instead of raising on a bad path
            raise FileNotFoundError(f"cannot read frame/{a}.jpg")
        out = get_line(image)
        cv2.imshow('frame', out)
        cv2.waitKey(0)
        cv2.destroyAllWindows()  # must be called; a bare reference is a no-op

    def video():
        """Process ``video.mp4`` frame by frame and write ``traffic_video.mp4``."""
        cap = cv2.VideoCapture("video.mp4")  # or use a camera: cv2.VideoCapture(0)
        if not cap.isOpened():
            raise FileNotFoundError("cannot open video.mp4")

        # The output keeps the input's frame size and frame rate.
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        out = cv2.VideoWriter(
            "traffic_video.mp4",
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (frame_width, frame_height),
        )
        # Release capture and writer even if processing a frame raises.
        try:
            i = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                print("processed_frame: ", i)
                i += 1
                out.write(get_line(frame))
        finally:
            cap.release()
            out.release()
            cv2.destroyAllWindows()

    # frame(30)
    video()

processed_frame:  0
processed_frame:  1
processed_frame:  2
processed_frame:  3
processed_frame:  4
processed_frame:  5
processed_frame:  6
processed_frame:  7
processed_frame:  8
processed_frame:  9
processed_frame:  10
processed_frame:  11
processed_frame:  12
processed_frame:  13
processed_frame:  14
processed_frame:  15
processed_frame:  16
processed_frame:  17
processed_frame:  18
processed_frame:  19
processed_frame:  20
processed_frame:  21
processed_frame:  22
processed_frame:  23
processed_frame:  24
processed_frame:  25
processed_frame:  26
processed_frame:  27
processed_frame:  28
processed_frame:  29
processed_frame:  30
processed_frame:  31
processed_frame:  32
processed_frame:  33
processed_frame:  34
processed_frame:  35
processed_frame:  36
processed_frame:  37
processed_frame:  38
processed_frame:  39
processed_frame:  40
processed_frame:  41
processed_frame:  42
processed_frame:  43
processed_frame:  44
processed_frame:  45
processed_frame:  46
processed_frame:  47
pr