In [None]:
import math
import cv2
import numpy as np
from skimage import morphology


def edge(frame):
    """Return the Canny edge map of ``frame``.

    A 3-dimensional (colour) input is converted from BGR to grayscale first;
    any other input is used as-is. The grayscale image is smoothed with a
    5x5 Gaussian kernel before Canny with hysteresis thresholds 50 / 150.
    """
    gray = frame if len(frame.shape) != 3 else cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    return cv2.Canny(smoothed, 50, 150)

def hsv(frame):
    """Return a binary mask of yellow and white pixels in a BGR frame.

    Steps: 5x5 Gaussian blur, BGR -> HSV conversion, thresholding for
    yellow (H 26-34, S 43-255, V 46-255) and white (H 0-180, S 0-45,
    V 201-255), then morphological clean-up: erosion (5x5) to drop speckle
    noise, dilation (11x11) to close gaps, erosion (5x5) to restore width.

    Args:
        frame: 3-channel BGR image.

    Returns:
        uint8 mask that is 255 where a yellow/white pixel survives, else 0.
    """
    blurred = cv2.GaussianBlur(frame, (5, 5), 0)
    hsv_img = cv2.cvtColor(blurred, cv2.COLOR_BGR2HSV)
    lower_yellow = np.array([26, 43, 46])
    upper_yellow = np.array([34, 255, 255])
    lower_white = np.array([0, 0, 201])
    upper_white = np.array([180, 45, 255])

    # bitwise_or rather than "+": inRange yields 0/255 uint8 masks, and a pixel
    # inside both ranges (H 26-34, S 43-45, V >= 201) would wrap to 254 with "+".
    mask = cv2.bitwise_or(
        cv2.inRange(hsv_img, lower_yellow, upper_yellow),
        cv2.inRange(hsv_img, lower_white, upper_white),
    )
    # Opening-like clean-up: erode away noise, dilate to merge, erode back.
    mask = morphology.erosion(mask, morphology.square(5))
    mask = morphology.dilation(mask, morphology.square(11))
    mask = morphology.erosion(mask, morphology.square(5))
    return mask


def calculate_slope(line):
    """Return the slope dy/dx of a segment.

    Two layouts are accepted:
      * HoughLinesP style ``[[x1, y1, x2, y2]]`` (first element has 4 values);
      * a pair of points ``[[x1, y1], [x2, y2]]``.

    A vertical segment (dx == 0) yields 0 instead of raising. Note that this
    makes vertical segments indistinguishable from horizontal ones; callers
    in this file rely on that value.
    """
    head = line[0]
    if len(head) == 4:
        x_a, y_a, x_b, y_b = head
    else:
        x_a, y_a = head[0], head[1]
        x_b, y_b = line[1][0], line[1][1]
    run = x_b - x_a
    return 0 if run == 0 else (y_b - y_a) / run


def reject_abnormal_lines(lines, threshold, max_x_offset=100):
    """Remove outlier segments from ``lines`` in place and return the same list.

    Two passes are applied:
      1. Position filter: drop every segment whose first x coordinate
         (``line[0][0]``) is more than ``max_x_offset`` pixels from the mean of
         those x coordinates.
      2. Slope filter: repeatedly drop the segment whose slope is farthest from
         the mean slope of the remaining segments, until that largest deviation
         is no longer greater than ``threshold``.

    Args:
        lines: mutable list of HoughLinesP-style segments ``[[x1, y1, x2, y2]]``.
            It is modified in place, so callers may ignore the return value.
        threshold: maximum allowed ``|slope - mean slope|``.
        max_x_offset: maximum allowed ``|x1 - mean x1|`` in pixels (default 100).

    Returns:
        ``lines`` itself, after filtering.
    """
    if len(lines) == 0:
        return lines

    # Pass 1: position filter (slice assignment keeps the same list object).
    x_starts = np.array([line[0][0] for line in lines])
    offsets = np.abs(x_starts - np.mean(x_starts))
    lines[:] = [line for line, off in zip(lines, offsets) if not off > max_x_offset]

    # Pass 2: iteratively drop the worst slope outlier.
    slopes = [calculate_slope(line) for line in lines]
    while slopes:
        mean_slope = np.mean(slopes)
        deviations = [abs(s - mean_slope) for s in slopes]
        worst = int(np.argmax(deviations))
        if deviations[worst] > threshold:
            del slopes[worst]
            del lines[worst]
        else:
            break
    return lines


def least_squares_fit(lines):
    """Fit one straight line y = a*x + b through every endpoint of ``lines``.

    Args:
        lines: iterable of segments in HoughLinesP layout, each ``[[x1, y1, x2, y2]]``.

    Returns:
        int64 array ``[[x_min, y(x_min)], [x_max, y(x_max)]]``: the fitted line
        evaluated at the smallest and largest endpoint x. The int64 cast
        truncates the fitted y values toward zero.
    """
    segments = [seg[0] for seg in lines]
    xs = np.array([value for s in segments for value in (s[0], s[2])])
    ys = np.array([value for s in segments for value in (s[1], s[3])])
    coeffs = np.polyfit(xs, ys, deg=1)  # [slope, intercept]
    x_left = np.min(xs)
    x_right = np.max(xs)
    return np.array(
        [(x_left, np.polyval(coeffs, x_left)), (x_right, np.polyval(coeffs, x_right))],
        dtype=np.int64,
    )


def is_vertical(left_lines, right_lines):
    """Compare the angles of two segments.

    Returns False when the two slopes have the same sign (product > 0).
    Otherwise returns True iff ``pi + atan(k_right) - atan(k_left)`` exceeds
    ``2*pi/5``, where k is the slope from ``calculate_slope``.

    NOTE(review): the original docstring said "reject vertical segments", but
    the code compares the two segments' angles as described above. The
    function is not called anywhere in the visible code.
    """
    slope_left = calculate_slope(left_lines)
    slope_right = calculate_slope(right_lines)
    if slope_left * slope_right > 0:
        return False
    opening = math.pi + math.atan(slope_right) - math.atan(slope_left)
    return opening > math.pi * 2 / 5
    
    
def is_cross(left_lines, right_lines, x_range=(0, 950), y_range=(250, 540)):
    """Return True if the lines through two segments intersect inside a window.

    Each argument holds two endpoints ``[[x1, y1], [x2, y2]]``; the segments
    are extended to infinite lines. The intersection is computed with the
    determinant form, so vertical lines are handled.

    Args:
        left_lines, right_lines: the two segments.
        x_range, y_range: open intervals ``(lo, hi)`` the intersection must lie
            in. The defaults equal the region of interest used by ``get_line``.

    Returns:
        True when the intersection lies strictly inside both intervals. False
        when it lies outside, when the lines are parallel, or when either
        segment is degenerate (coincident endpoints); in those last two cases
        there is no unique intersection.
    """
    x1, y1 = map(float, left_lines[0])
    x2, y2 = map(float, left_lines[1])
    x3, y3 = map(float, right_lines[0])
    x4, y4 = map(float, right_lines[1])

    denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    if denom == 0:
        return False  # parallel/coincident lines or a zero-length segment

    cross_left = x1 * y2 - y1 * x2
    cross_right = x3 * y4 - y3 * x4
    x = (cross_left * (x3 - x4) - (x1 - x2) * cross_right) / denom
    y = (cross_left * (y3 - y4) - (y1 - y2) * cross_right) / denom
    return x_range[0] < x < x_range[1] and y_range[0] < y < y_range[1]
    
def distance(x, y, line):
    """Return the perpendicular distance from point (x, y) to the line through ``line``.

    Args:
        x, y: point coordinates.
        line: two endpoints ``[[x1, y1], [x2, y2]]``; the segment is extended
            to an infinite line.

    Returns:
        ``|dx*(y1 - y) - (x1 - x)*dy| / hypot(dx, dy)``. Unlike the
        slope-intercept form, this handles vertical lines. If the endpoints
        coincide, no line is defined and the distance to that point is returned.
    """
    px, py = float(x), float(y)
    x1, y1 = (float(v) for v in line[0])
    x2, y2 = (float(v) for v in line[1])
    dx = x2 - x1
    dy = y2 - y1
    length = math.hypot(dx, dy)
    if length == 0:
        return math.hypot(px - x1, py - y1)
    return abs(dx * (y1 - py) - (x1 - px) * dy) / length


def _draw_segment(img, segment, color=(255, 0, 0), thickness=5):
    """Draw segment ``[[x1, y1], [x2, y2]]`` on ``img`` in place, skipping degenerate ones.

    A segment whose endpoints coincide has no direction. In ``get_line`` that
    is the ``[[0, 0], [0, 0]]`` placeholder for a missing lane side, and
    drawing it would leave a stray dot in the top-left corner. Points are cast
    to Python ints because OpenCV drawing functions reject some NumPy scalar types.
    """
    start = (int(segment[0][0]), int(segment[0][1]))
    end = (int(segment[1][0]), int(segment[1][1]))
    if start == end:
        return
    cv2.line(img, start, end, color=color, thickness=thickness)


def get_line(frame):
    """Detect lane lines in a BGR road frame and draw them onto it.

    Pipeline:
      1. Build a colour mask (``hsv``) and a Canny edge map (``edge``).
      2. Keep only edges in the region of interest x in [0, 950], y in [250, 540].
      3. Paint red (BGR 0, 0, 255) the edges whose connected component has
         area > 200 px and that also lie on yellow/white pixels.
      4. Run probabilistic Hough on the 3x3-dilated ROI edges restricted to the
         colour mask. Split the segments into two groups by slope sign, or
         around the mean slope if one group is empty. Reject outliers and
         least-squares fit one segment per group.
      5. If both fits exist, differ in angle by less than pi/10, and the left
         fit lies within 100 px of a merged fit, draw only the merged segment.
         Otherwise, if their lines cross inside the ROI, drop the shorter
         one, then draw what remains in BGR (255, 0, 0).

    Args:
        frame: 3-channel BGR image. It is drawn on in place.

    Returns:
        ``frame`` itself, with the overlays drawn on it.
    """
    img0 = frame  # all drawing below modifies the caller's frame
    empty_segment = [[0, 0], [0, 0]]  # placeholder for a missing lane side

    hsv_mask = hsv(frame)
    edge_mask = edge(frame)

    # Keep only edges inside the region of interest. fillPoly needs int32 points;
    # NumPy's default int dtype is platform-dependent.
    roi_polygon = np.array([[[0, 250], [950, 250], [950, 540], [0, 540]]], dtype=np.int32)
    roi_mask = cv2.fillPoly(np.zeros_like(edge_mask), roi_polygon, color=255)
    roi_masked_edge_img = cv2.bitwise_and(edge_mask, roi_mask)

    # Area filter: keep connected edge components larger than 200 px
    # (label 0 is the background and is never kept).
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        roi_masked_edge_img, connectivity=8
    )
    big_labels = np.flatnonzero(stats[1:num_labels, cv2.CC_STAT_AREA] > 200) + 1
    filtered = np.zeros_like(roi_masked_edge_img)
    filtered[np.isin(labels, big_labels)] = 255

    # Red overlay of the large edges that also lie on lane colours.
    hsvedge = cv2.bitwise_and(filtered, hsv_mask)
    hsvedge = morphology.dilation(hsvedge, morphology.square(5))
    img0[hsvedge > 0] = [0, 0, 255]

    # Hough input: ROI edges thickened by a 3x3 dilation, restricted to lane colours.
    dia_masked_edge_img = morphology.dilation(roi_masked_edge_img, morphology.square(3))
    hsv_masked_edge_img = cv2.bitwise_and(hsv_mask, dia_masked_edge_img)
    lines = cv2.HoughLinesP(
        hsv_masked_edge_img, 1, np.pi / 180, 15, minLineLength=100, maxLineGap=30
    )
    if lines is None:
        return img0

    # Split by slope sign. When one side is empty (e.g. in a curve, both
    # boundaries lean the same way), split around the mean slope instead.
    # Segments whose slope equals the mean exactly land in neither group.
    slopes = [calculate_slope(line) for line in lines]
    left_lines = [line for line, k in zip(lines, slopes) if k < 0]
    right_lines = [line for line, k in zip(lines, slopes) if k >= 0]
    if not left_lines or not right_lines:
        mean_slope = np.mean(slopes)
        left_lines = [line for line, k in zip(lines, slopes) if k > mean_slope]
        right_lines = [line for line, k in zip(lines, slopes) if k < mean_slope]

    # Both calls filter their list in place.
    reject_abnormal_lines(left_lines, threshold=0.1)
    reject_abnormal_lines(right_lines, threshold=0.1)

    left_missing = not left_lines
    right_missing = not right_lines
    left_seg = empty_segment if left_missing else least_squares_fit(left_lines)
    right_seg = empty_segment if right_missing else least_squares_fit(right_lines)

    fit = False
    if not (left_missing or right_missing):
        kl = calculate_slope(left_seg)
        kr = calculate_slope(right_seg)
        # Map each slope to an angle in (0, pi] so the two can be compared.
        thetal = math.atan(kl) if kl > 0 else math.pi + math.atan(kl)
        thetar = math.atan(kr) if kr > 0 else math.pi + math.atan(kr)
        if abs(thetal - thetar) < math.pi / 10:
            # Nearly parallel: fit one segment through all four endpoints and
            # draw only that when the left fit lies within 100 px of it.
            lr = [[left_seg.flatten().tolist()], [right_seg.flatten().tolist()]]
            lr_lines = least_squares_fit(lr)
            if distance(left_seg[0][0], left_seg[0][1], lr_lines) < 100:
                _draw_segment(img0, lr_lines)
                fit = True
        if not fit and is_cross(left_seg, right_seg):
            # The two lines intersect inside the ROI: keep only the longer segment.
            left_length = math.hypot(
                left_seg[0][0] - left_seg[1][0], left_seg[0][1] - left_seg[1][1]
            )
            right_length = math.hypot(
                right_seg[0][0] - right_seg[1][0], right_seg[0][1] - right_seg[1][1]
            )
            if left_length < right_length:
                left_seg = empty_segment
            else:
                right_seg = empty_segment

    if not fit:
        _draw_segment(img0, left_seg)
        _draw_segment(img0, right_seg)
    return img0


if __name__ == "__main__":
    def frame(a):
        """Run ``get_line`` on image ``frame/<a>.jpg`` and display it until a key is pressed."""
        path = f'frame/{a}.jpg'
        img = cv2.imread(path)
        if img is None:  # imread signals a missing/unreadable file by returning None
            raise FileNotFoundError(path)
        out = get_line(img)
        cv2.imshow('frame', out)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    def video():
        """Annotate every frame of video.mp4 with ``get_line`` and write traffic_video.mp4."""
        cap = cv2.VideoCapture("video.mp4")
        if not cap.isOpened():
            raise FileNotFoundError("video.mp4")
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        out = cv2.VideoWriter(
            "traffic_video.mp4",
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (frame_width, frame_height),
        )
        try:
            i = 0
            while cap.isOpened():
                ret, img = cap.read()
                if not ret:
                    break
                print("processed_frame: ", i)
                i += 1
                out.write(get_line(img))
        finally:
            # Release both handles even if processing a frame raises.
            cap.release()
            out.release()
            cv2.destroyAllWindows()

    # frame(121)
    video()

processed_frame:  0
1.5022935514760682
processed_frame:  1
1.517386674874328
processed_frame:  2
1.4243319189391594
processed_frame:  3
1.515046972699835
processed_frame:  4
1.5369083881077894
processed_frame:  5
1.5192064273901131
processed_frame:  6
1.5125229007447265
processed_frame:  7
1.5059141613078735
processed_frame:  8
1.5246479427630495
processed_frame:  9
1.5389209359207603
processed_frame:  10
0.21053166487363384
processed_frame:  11
0.06981348663322862
processed_frame:  12
0.11288346890174594
processed_frame:  13
0.05370013596342382
processed_frame:  14
0.07327392592138232
processed_frame:  15
0.10526799923315538
processed_frame:  16
0.133229467335849
processed_frame:  17
0.12252192487691094
processed_frame:  18
0.10558512069436865
processed_frame:  19
0.1564827300744414
processed_frame:  20
0.0884552916156047
processed_frame:  21
0.11762718985731174
processed_frame:  22
0.08274633403691073
processed_frame:  23
0.15153736407302687
processed_frame:  24
0.12361315492023472
p

  x=(b2-b1)/(k1-k2)


processed_frame:  171
1.6260014886582064
processed_frame:  172
1.6301793624092693
processed_frame:  173
1.6455192537363144
processed_frame:  174
1.6426078499163557
processed_frame:  175
1.6377758239937703
processed_frame:  176
1.685629727233734
processed_frame:  177
1.6391459283377325
processed_frame:  178
1.6490034653035088
processed_frame:  179
1.654569270538663
