# In [None]:
# curve_extraction_model.ipynb

# Notebook形式完整项目：输入矩阵图像，输出多条曲线及其类型
# 作者：ChatGPT（OpenAI）

### 第 1 部分：依赖导入
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
import json
import os
import numpy as np
from PIL import Image

### 第 2 部分：数据集类
def load_json(json_path):
    """Load and return the parsed contents of a JSON file.

    Args:
        json_path: Path to the JSON file.

    Returns:
        The decoded JSON value (for the label files read by ``CurveDataset``,
        a dict).
    """
    # JSON text is UTF-8 by spec (RFC 8259); don't depend on the platform's
    # locale encoding (e.g. GBK on Chinese Windows), which breaks non-ASCII labels.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

class CurveDataset(Dataset):
    """Dataset of images annotated with curves, one JSON label file per sample.

    Every ``*.json`` file in ``data_dir`` is one sample. Judging from
    ``__getitem__``, a label file is expected to look like::

        {"image_path": "<path relative to data_dir>",
         "lines": [{"type": <class id>, "coords": [[x, y], ...]}, ...]}

    (``type`` is presumably an integer class id -- it is used as a
    cross-entropy target; confirm against the data generator.)

    Args:
        data_dir: Directory holding the JSON label files; ``image_path``
            entries are resolved relative to it.
        max_seq_len: Number of points each curve is truncated / padded to.
        transform: Callable applied to the grayscale PIL image. Defaults to
            resizing to 128x128 and converting to a tensor.
    """

    def __init__(self, data_dir, max_seq_len=64, transform=None):
        self.data_dir = data_dir
        # Sort so that sample indices are reproducible; os.listdir order is arbitrary.
        self.samples = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".json")
        )
        self.max_seq_len = max_seq_len
        self.transform = transform or T.Compose([
            T.Resize((128, 128)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, types, coords, masks)`` for sample ``idx``.

        Returns:
            image: ``self.transform`` applied to the grayscale image
                (a ``(1, 128, 128)`` float tensor with the default transform).
            types: ``(num_lines,)`` int64 tensor of curve class ids.
            coords: ``(num_lines, max_seq_len, 2)`` float32 tensor of points,
                zero-filled past each curve's length.
            masks: ``(num_lines, max_seq_len)`` float tensor; 1 marks padding,
                0 marks a real point.
        """
        sample = load_json(self.samples[idx])
        image_path = os.path.join(self.data_dir, sample['image_path'])
        # Context manager closes the underlying file handle; convert() returns
        # a new, fully loaded image, so it is safe to use after the block.
        with Image.open(image_path) as img:
            image = img.convert('L')
        image = self.transform(image)

        # Build per-line labels.
        lines = sample['lines']
        num_lines = len(lines)
        coords = torch.zeros((num_lines, self.max_seq_len, 2), dtype=torch.float32)
        masks = torch.ones((num_lines, self.max_seq_len))  # 1 = padding
        # Explicit long dtype: cross_entropy needs integer targets, and
        # torch.tensor([]) would otherwise default to float32 for an empty sample.
        types = torch.tensor([line['type'] for line in lines], dtype=torch.long)

        for i, line in enumerate(lines):
            pts = line['coords'][:self.max_seq_len]
            if not pts:
                # Empty curve: a (0,) tensor cannot be assigned into a (0, 2)
                # slice, so keep zero coords and a fully padded mask.
                continue
            coords[i, :len(pts)] = torch.tensor(pts, dtype=torch.float32)
            masks[i, :len(pts)] = 0

        return image, types, coords, masks

### 第 3 部分：模型结构
class Encoder(nn.Module):
    """ResNet-18 convolutional backbone that returns its last feature map.

    Args:
        pretrained: Load ImageNet weights (default, as before). Set to False to
            build a randomly initialised backbone without downloading weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        try:
            # torchvision >= 0.13: the ``weights`` API (``pretrained=`` is deprecated).
            weights = models.ResNet18_Weights.DEFAULT if pretrained else None
            base = models.resnet18(weights=weights)
        except AttributeError:
            # Older torchvision without ResNet18_Weights.
            base = models.resnet18(pretrained=pretrained)
        # Drop the final avgpool and fc layers to keep the spatial feature map.
        self.backbone = nn.Sequential(*list(base.children())[:-2])

    def forward(self, x):
        """Encode images into a ``(B, 512, H', W')`` feature map.

        Args:
            x: ``(B, C, H, W)`` image batch with C == 1 or 3. Single-channel
                input (what ``CurveDataset`` produces via ``convert('L')``) is
                replicated to three channels, because ResNet's first conv
                layer expects 3 input channels.
        """
        if x.dim() == 4 and x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)
        return self.backbone(x)

class CurveDecoder(nn.Module):
    """DETR-style transformer head that predicts a fixed set of curves.

    ``num_queries`` learned queries attend over the flattened feature map;
    each query is decoded into class logits and a ``max_seq_len``-point curve.

    Args:
        hidden_dim: Transformer width (must be divisible by ``nhead``).
        num_queries: Number of curve slots predicted per image.
        num_classes: Size of the class-logit output.
        max_seq_len: Number of (x, y) points predicted per curve.
        nhead, num_encoder_layers, num_decoder_layers: ``nn.Transformer``
            settings; defaults match the previously hard-coded values.
        in_channels: Channel count of the incoming feature map (512 for the
            ResNet-18 encoder).
    """

    def __init__(self, hidden_dim=256, num_queries=10, num_classes=4, max_seq_len=64,
                 nhead=8, num_encoder_layers=6, num_decoder_layers=6, in_channels=512):
        super().__init__()
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.transformer = nn.Transformer(hidden_dim, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers)
        self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)

        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.coord_embed = nn.Linear(hidden_dim, max_seq_len * 2)

    @staticmethod
    def _position_encoding(h, w, dim, device):
        """Return a fixed 2-D sine/cosine encoding of shape ``(h * w, dim)``.

        Rows follow the same row-major order as ``flatten(2)`` on a
        ``(B, C, h, w)`` tensor. The first ``dim // 2`` channels encode the row
        index and the rest encode the column index. It has no parameters, so
        existing checkpoints still load.
        """
        rows = torch.arange(h, device=device, dtype=torch.float32).repeat_interleave(w)
        cols = torch.arange(w, device=device, dtype=torch.float32).repeat(h)

        def encode(pos, n):
            idx = torch.arange(n, device=device, dtype=torch.float32)
            # Geometric frequencies as in the original Transformer; even
            # channels use sin, odd channels use cos.
            exponent = torch.floor(idx / 2) * 2 / max(n, 1)
            angles = pos.unsqueeze(1) * (10000.0 ** (-exponent))
            return torch.where(idx.remainder(2) == 0, angles.sin(), angles.cos())

        half = dim // 2
        return torch.cat([encode(rows, half), encode(cols, dim - half)], dim=1)

    def forward(self, src):
        """Decode a feature map into per-query class logits and coordinates.

        Args:
            src: ``(B, in_channels, H', W')`` feature map from the encoder.

        Returns:
            pred_classes: ``(B, num_queries, num_classes)`` logits.
            pred_coords: ``(B, num_queries, max_seq_len, 2)`` raw (unbounded)
                coordinate regressions.
        """
        batch = src.shape[0]
        feats = self.input_proj(src)                                    # (B, C, H', W')
        _, channels, h, w = feats.shape
        tokens = feats.flatten(2).permute(2, 0, 1)                      # (H'W', B, C)
        # Without a positional signal the transformer's output does not depend
        # on where each token sits in the grid, so it cannot localise curves.
        # Add a fixed 2-D encoding (as in DETR).
        pos = self._position_encoding(h, w, channels, tokens.device).to(tokens.dtype)
        tokens = tokens + pos.unsqueeze(1)
        tgt = self.query_embed.weight.unsqueeze(1).repeat(1, batch, 1)  # (N, B, C)
        hs = self.transformer(tokens, tgt)                              # (N, B, C)
        hs = hs.permute(1, 0, 2)                                        # (B, N, C)

        pred_classes = self.class_embed(hs)
        pred_coords = self.coord_embed(hs).reshape(batch, hs.size(1), -1, 2)
        return pred_classes, pred_coords

### 第 4 部分：训练示例（伪代码）
def train(model, dataloader, optimizer, device=None):
    """Run one training epoch and return the mean per-batch loss.

    Simplified objective: the t-th ground-truth curve of each image is matched
    to the t-th query (index matching). This is a placeholder for proper
    Hungarian (bipartite) matching as in DETR. It requires every image in a
    batch to have the same number of curves, which the default collate
    function requires anyway, and at most ``num_queries`` curves. Queries
    beyond the ground-truth count are not supervised.

    Args:
        model: Object with ``encoder`` and ``decoder`` attributes (e.g.
            ``FullCurveModel``); the decoder must return
            ``(pred_cls (B, N, K), pred_coords (B, N, L, 2))``.
        dataloader: Iterable of ``(image, types, coords, masks)`` batches as
            produced by ``CurveDataset`` (mask value 1 = padding point).
        optimizer: Optimizer over ``model``'s parameters.
        device: Target device; defaults to the device of the model's first
            parameter (CPU if it has none).

    Returns:
        Mean of the per-batch losses, or 0.0 if ``dataloader`` is empty.

    Raises:
        ValueError: If a batch has more curves per image than the model has queries.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    total_loss, num_batches = 0.0, 0
    for batch in dataloader:
        image, types, coords, masks = [b.to(device) for b in batch]
        features = model.encoder(image)
        pred_cls, pred_coords = model.decoder(features)

        num_targets = types.shape[1]
        num_queries = pred_cls.shape[1]
        if num_targets > num_queries:
            raise ValueError(
                f"batch has {num_targets} curves per image but the model "
                f"predicts only {num_queries}"
            )
        # Index matching: only the first num_targets queries are supervised.
        pred_cls = pred_cls[:, :num_targets]
        pred_coords = pred_coords[:, :num_targets]
        # Zero that stays attached to the graph, used when a term has no targets.
        zero = pred_cls.sum() * 0.0

        if types.numel() > 0:
            cls_loss = F.cross_entropy(
                pred_cls.reshape(-1, pred_cls.shape[-1]), types.reshape(-1).long()
            )
        else:
            cls_loss = zero
        # Only real points contribute; padded positions (mask == 1) are ignored.
        valid = masks == 0
        coord_loss = F.mse_loss(pred_coords[valid], coords[valid]) if valid.any() else zero
        loss = cls_loss + coord_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

### 第 5 部分：完整模型封装
class FullCurveModel(nn.Module):
    """End-to-end curve extractor: ``Encoder`` backbone followed by a ``CurveDecoder`` head.

    Args:
        num_classes: Number of curve classes predicted per query.
        max_seq_len: Number of points predicted per curve.
        num_queries: Maximum number of curves predicted per image (default 10,
            the value previously inherited from ``CurveDecoder``).
        hidden_dim: Transformer width of the decoder (default 256, likewise).
    """

    def __init__(self, num_classes=4, max_seq_len=64, num_queries=10, hidden_dim=256):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = CurveDecoder(hidden_dim=hidden_dim, num_queries=num_queries,
                                    num_classes=num_classes, max_seq_len=max_seq_len)

    def forward(self, x):
        """Return ``(pred_classes, pred_coords)`` for the image batch ``x``."""
        features = self.encoder(x)
        return self.decoder(features)

### 结尾：推理函数
@torch.no_grad()
def predict(model, image_tensor, threshold=0.5):
    """Run inference on a single image and return the confident curves.

    Args:
        model: Object with ``encoder``/``decoder`` attributes (e.g.
            ``FullCurveModel``). It is switched to eval mode and left there.
        image_tensor: One unbatched image, e.g. ``(1, H, W)`` as produced by
            ``CurveDataset``; it is moved to the model's device.
        threshold: Minimum top-class softmax probability for a query to be
            reported. Pass 0.0 to return every query.

    Returns:
        A list of ``{"type": int, "coords": [[x, y], ...], "score": float}``
        dicts, one per kept query, in query order.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    features = model.encoder(image_tensor.unsqueeze(0))
    pred_cls, pred_coords = model.decoder(features)

    probs = F.softmax(pred_cls, dim=-1).squeeze(0)  # (N, num_classes)
    scores, labels = probs.max(dim=-1)              # top-class probability and id per query
    coords = pred_coords.squeeze(0).cpu()

    return [
        {"type": int(label), "coords": coords[i].tolist(), "score": float(score)}
        for i, (score, label) in enumerate(zip(scores.tolist(), labels.tolist()))
        if score >= threshold
    ]
