In [1]:
import torch.nn as nn
import numpy as np
import cv2
import torch
import os
from torch.utils.data import DataLoader, Dataset
import math
import torch.nn.functional as F
import csv
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from B_spline import BS_curve
from utils import keyP,cobb_angle_line
path_heatmap=r'D:\Project\Xiehe_Spinal_image_stitching\cobb\Heatmap'
path_image=r'D:\Project\Xiehe_Spinal_image_stitching\cobb\ke30_u7_AASCE2019-master\boostnet_labeldata'
class CobbNetDataset(Dataset):
    def __init__(self, path,path_heatmap, train=True):
        self.names = []
        self.labels = []
        self.cobb_angles=[]
        self.heatmap_names=[]
        self.scale=4#heatmap的缩放倍数
        self.train=train
        if train:
            image_path = path + "/data/training_preprocessed/"  # 原图
            heatmaps_path = path_heatmap+"/pred_training_cp/"   # 对应的热图
            names = csv.reader(open(path + "/labels/training/filenames.csv", 'r'))  # 文件名
            cobb_angles= csv.reader(open(path + "/labels/training/angles.csv", 'r')) # 对应的cobb角
            names=list(names)
            self.names=[image_path+n[0] for n in names]
            self.heatmap_names=[heatmaps_path+n[0] for n in names]
            landmarks = csv.reader(open(path + "/labels/training/landmarks.csv", 'r'))
        else:
            image_path = path + "/data/test/"
            heatmaps_path = path_heatmap+"/pred_test_cp/"
            names = csv.reader(open(path + "/labels/test/filenames.csv", 'r'))
            cobb_angles= csv.reader(open(path + "/labels/test/angles.csv", 'r'))
            names=list(names)
            self.names=[image_path+n[0] for n in names]
            self.heatmap_names=[heatmaps_path+n[0] for n in names]
            landmarks = csv.reader(open(path + "/labels/test/landmarks.csv", 'r'))
        
        for landmark_each_image in landmarks:  # 地标
            coordinate_list = []
            for coordinate in landmark_each_image:
                coordinate_list.append(float(coordinate))
            self.labels.append(coordinate_list)
            
        for cobb_each in cobb_angles:  # cobb角
            cobb_list = []
            for cobb in cobb_each:
                cobb_list.append(float(cobb))
            self.cobb_angles.append(cobb_list)
             

    def pad_img(self, img, flag=True): 
          
        h,w=img.shape[:2]
        if(flag):
            h_max=3840
            w_max=1536
        else:
            h_max=960
            w_max=384
        top = math.floor((h_max - h)/2)
        bottom = round((h_max - h)/2+0.1)
        left = math.floor((w_max - w) / 2)
        right = round((w_max - w) / 2+0.1)#四舍五入
        image_padded = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
        
        return image_padded

    def __getitem__(self, index):
        image_name = self.names[index]
        label = self.labels[index]
        cobb_angle_GT=np.array(self.cobb_angles[index])

        heatmap_name=self.heatmap_names[index]
        origin_image = cv2.imread(image_name)
        image_padded=self.pad_img(origin_image)

        target_height = image_padded.shape[0]//4 # 3840/4=960 缩小原图输入
        target_width =image_padded.shape[1]//4 # 1536/4=384
        image_resize=cv2.resize(image_padded, (target_width, target_height))#结果存在image
        heatmap = cv2.imread(heatmap_name,0)
        heatmap_padded=self.pad_img(heatmap,False)
        kp = keyP(heatmap)
        heatmap_y=[coord[0] for coord in kp]
        heatmap_x=[coord[1] for coord in kp]
        bs=BS_curve(9,3)  #10个控制点，3次B样条
        kp_pred = np.array([heatmap_y,heatmap_x]).T # 基于热图计算B样条参数
        paras = bs.estimate_parameters(kp_pred) # B样条参数
        knots = bs.get_knots() # 节点
        if bs.check():
            cp = bs.approximation(kp_pred) # 控制点
        uq = np.linspace(0,1,34)
        y_c = np.array(bs.bs(uq)) # 计算B样条曲线
        cobb_angle=np.array(cobb_angle_line(y_c)) # 基于预测的中心点计算角度
        xs=[]
        ys=[]
        p=label
        img_src_resize=cv2.resize(origin_image, (target_width, target_height))
        h,w=img_src_resize.shape[:2]
        num_p = len(p)//2
        for i in range(0,num_p,2):
            xs.append((p[i]+p[i+1])*w/2) # 原图像脊柱中心点
            ys.append((p[i+num_p]+p[i+num_p+1])*h/2)

        bs=BS_curve(9,3) # 10个控制点，3次B样条
        kp_GT = np.array([ys,xs]).T # 基于GT控制点计算B样条参数
        paras_GT = bs.estimate_parameters(kp_GT)
        knots_GT = bs.get_knots()
        if bs.check():
            cp_GT = bs.approximation(kp_GT)

 
        # heatmap_resize=cv2.resize(heatmap_padded ,(target_width//self.scale,target_height//self.scale)) #960/4=240 ,384/4=96缩小heatmap输入
        
        image_resize = torch.tensor(image_resize, dtype=torch.float32)
        heatmap = torch.tensor(heatmap, dtype=torch.float32)

        
        return origin_image.shape,image_resize,label,heatmap_padded,image_name,kp_pred,cp,knots,cp_GT,knots_GT,cobb_angle,cobb_angle_GT

        
        
    def __len__(self):
        return len(self.names)



### 模型定义

In [30]:
class KeypointBSplineNet(nn.Module):
    """
    基于关键点的B样条Cobb角测量网络
    输入：关键点坐标 (data_pred)
    输出：控制点 (cp) 和节点 (knots)
    通过B样条解析计算Cobb角
    """
    
    def __init__(self, num_keypoints=34, num_control_points=10, degree=3, num_angles=4):
        super(KeypointBSplineNet, self).__init__()
        self.num_keypoints = num_keypoints
        self.num_control_points = num_control_points
        self.degree = degree
        self.num_angles = num_angles
        
        # 输入维度：num_keypoints * 2 (x, y坐标)
        input_dim = num_keypoints * 2
        
        # 控制点输出维度：num_control_points * 2 (x, y坐标)
        cp_output_dim = num_control_points * 2
        
        # 节点输出维度：num_control_points + degree + 1
        knots_output_dim = num_control_points + degree + 1
        
        # 共享特征提取层
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # 控制点预测头
        self.cp_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, cp_output_dim)
        )
        
        # 节点预测头
        self.knots_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, knots_output_dim),
            nn.Sigmoid()  # 确保节点值在[0,1]范围内
        )
        
        # B样条处理器
        self.bspline_processor = BSplineProcessor(num_control_points, degree)
    
    def forward(self, keypoints):
        """
        前向传播
        Args:
            keypoints: 关键点坐标 [B, num_keypoints, 2] 或 [B, num_keypoints*2]
        Returns:
            cp: 预测的控制点 [B, num_control_points, 2]
            knots: 预测的节点 [B, num_control_points + degree + 1]
            cobb_angles: 计算的Cobb角 [B, num_angles]
        """
        batch_size = keypoints.shape[0]
        
        # 确保输入是2D的
        if keypoints.dim() == 3:
            keypoints = keypoints.view(batch_size, -1)  # [B, num_keypoints*2]
        
        # 特征提取
        features = self.feature_extractor(keypoints)  # [B, 256]
        
        # 预测控制点
        cp_flat = self.cp_head(features)  # [B, num_control_points*2]
        cp = cp_flat.view(batch_size, self.num_control_points, 2)  # [B, num_control_points, 2]
        
        # 预测节点
        knots = self.knots_head(features)  # [B, num_control_points + degree + 1]
        
        # 计算Cobb角
        cobb_angles = self._compute_cobb_angles(keypoints, cp, knots)
        
        return cp, knots, cobb_angles
    
    def _compute_cobb_angles(self, keypoints, cp, knots):
        """
        通过B样条计算Cobb角
        Args:
            keypoints: 原始关键点 [B, num_keypoints*2]
            cp: 控制点 [B, num_control_points, 2]
            knots: 节点 [B, num_control_points + degree + 1]
        Returns:
            cobb_angles: Cobb角 [B, num_angles]
        """
        batch_size = cp.shape[0]
        cobb_angles_list = []
        
        for b in range(batch_size):
            try:
                # 创建B样条对象
                bs = BS_curve(self.num_control_points - 1, self.degree)
                bs.cp = cp[b].detach().cpu().numpy()
                bs.u = knots[b].detach().cpu().numpy()
                bs.m = knots.shape[1] - 1
                
                # 检查B样条是否有效
                if bs.check():
                    # 采样34个点
                    uq = np.linspace(0, 1, 34)
                    curve_points = np.array(bs.bs(uq))  # [34, 2]
                    
                    # 计算Cobb角
                    cobb_angle = np.array(cobb_angle_line(curve_points))
                    cobb_angles_list.append(cobb_angle)
                else:
                    # 如果B样条无效，返回零角度
                    cobb_angles_list.append(np.zeros(self.num_angles))
                    
            except Exception as e:
                print(f"B样条计算错误: {e}")
                cobb_angles_list.append(np.zeros(self.num_angles))
        
        return torch.tensor(np.array(cobb_angles_list), dtype=torch.float32, device=cp.device)
    
    def predict_from_keypoints(self, keypoints):
        """
        从关键点预测Cobb角（推理模式）
        Args:
            keypoints: 关键点坐标 [num_keypoints, 2]
        Returns:
            cobb_angles: 预测的Cobb角
        """
        self.eval()
        with torch.no_grad():
            if keypoints.dim() == 2:
                keypoints = keypoints.unsqueeze(0)  # 添加batch维度
            
            cp, knots, cobb_angles = self.forward(keypoints)
            return cobb_angles.squeeze(0)  # 移除batch维度


class BSplineProcessor:
    """B样条处理器"""
    
    def __init__(self, num_control_points=10, degree=3):
        self.num_control_points = num_control_points
        self.degree = degree
    
    def fit_curve_from_keypoints(self, keypoints):
        """
        从关键点拟合B样条曲线
        Args:
            keypoints: 关键点坐标 [num_keypoints, 2]
        Returns:
            bs: B样条对象
            cp: 控制点
            knots: 节点
        """
        bs = BS_curve(self.num_control_points - 1, self.degree)
        data = np.array(keypoints)
        paras = bs.estimate_parameters(data)
        knots = bs.get_knots()
        
        if bs.check():
            cp = bs.approximation(data)
            return bs, cp, knots
        else:
            return None, None, None


class KeypointBSplineLoss(nn.Module):
    """KeypointBSplineNet的损失函数"""
    
    def __init__(self, cp_weight=1.0, knots_weight=1.0, angle_weight=1.0):
        super(KeypointBSplineLoss, self).__init__()
        self.cp_weight = cp_weight
        self.knots_weight = knots_weight
        self.angle_weight = angle_weight
        self.mse_loss = nn.MSELoss()
    
    def forward(self, pred_cp, pred_knots, pred_angles, 
                gt_cp, gt_knots, gt_angles):
        """
        计算总损失
        Args:
            pred_cp: 预测控制点 [B, num_control_points, 2]
            pred_knots: 预测节点 [B, num_control_points + degree + 1]
            pred_angles: 预测角度 [B, num_angles]
            gt_cp: 真实控制点 [B, num_control_points, 2]
            gt_knots: 真实节点 [B, num_control_points + degree + 1]
            gt_angles: 真实角度 [B, num_angles]
        Returns:
            total_loss: 总损失
            loss_dict: 各项损失详情
        """
        # 控制点损失
        cp_loss = self.mse_loss(pred_cp, gt_cp)
        
        # 节点损失
        knots_loss = self.mse_loss(pred_knots, gt_knots)
        
        # 角度损失
        angle_loss = self.mse_loss(pred_angles, gt_angles)
        
        # 总损失
        total_loss = (self.cp_weight * cp_loss + 
                     self.knots_weight * knots_loss + 
                     self.angle_weight * angle_loss)
        
        loss_dict = {
            'total_loss': total_loss.item(),
            'cp_loss': cp_loss.item(),
            'knots_loss': knots_loss.item(),
            'angle_loss': angle_loss.item()
        }
        
        return total_loss, loss_dict

### train

In [35]:


batch_size= 4

train_dataset = CobbNetDataset(path_image,path_heatmap,train=True)
train_loader = DataLoader(train_dataset,batch_size,shuffle=False,num_workers=0)

test_dasaset = CobbNetDataset(path_image,path_heatmap,train=False)
test_loader = DataLoader(test_dasaset,1,shuffle=True,num_workers=0)
model = KeypointBSplineNet(num_keypoints=34, num_control_points=10, degree=3)

# 创建损失函数
criterion = KeypointBSplineLoss()
    
i = 1
for origin_shape,image_resize,label,heatmap_padded,image_name,kp_pred,cp,knots,cp_GT,knots_GT,cobb_angle,cobb_angle_GT in train_loader:


    # 模拟数据
    print(origin_shape)
    # 前向传播
    pred_cp, pred_knots, pred_angles = model(kp_pred.to(torch.float32))
    
    print(f"预测控制点形状: {pred_cp.shape}")
    print(f"预测节点形状: {pred_knots.shape}")
    print(f"预测角度形状: {pred_angles.shape}")
    
    # 计算损失
    total_loss, loss_dict = criterion(pred_cp, pred_knots, pred_angles,
                                    cp_GT, knots_GT, cobb_angle_GT)
    
    print(f"总损失: {total_loss.item():.4f}")
    print(f"各项损失: {loss_dict}")
    break


    

[tensor([2125, 2364, 1485, 1572]), tensor([ 755, 1078,  505,  525]), tensor([3, 3, 3, 3])]
B样条计算错误: zero-size array to reduction operation maximum which has no identity
B样条计算错误: zero-size array to reduction operation maximum which has no identity
B样条计算错误: zero-size array to reduction operation maximum which has no identity
B样条计算错误: zero-size array to reduction operation maximum which has no identity
预测控制点形状: torch.Size([4, 10, 2])
预测节点形状: torch.Size([4, 14])
预测角度形状: torch.Size([4, 4])


RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1