# In [0]:
"""
以下に、提示方針（回転拡張の無効化、正則化強化、角度損失の重み増、線形ラベルのlog1p、分割シャッフル）を反映した「そのまま実行可能な」フルコードを示します。既存のPointNet＋STN＋OneCycleLR＋EMA＋TTAを維持し、前処理（SOR・正規化・PCAアラインメント）や評価保存も含みます。
重要
- params_all.txt の角度が「度（deg）」なら angle_unit_labels='deg' に必ず変更してください。ラジアン前提なら 'rad' のままでOKです。
- 角度が「グローバル座標基準」なら、回転拡張は無効化しています（augment_rotate=False）。主軸基準の角度なら canonical_align=True を試す価値があります。
コード（Train_pipe8-2.py）
"""
# Train_pipe8-2.py (汎化改善版)
# - weight_decay を 1e-3 に上げる、dropout_p を 0.5 に上げる、ft_reg_weight を 0.02 に上げる
# - num_points を 8192 に増やす、batch_size を増やす
import os
import json
import csv
from pathlib import Path
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
# ----------------
# 乱数シード固定
# ----------------
def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    This fixes the RNG streams only; cuDNN may still pick non-deterministic
    kernels unless the cudnn flags mentioned below are also set.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    # For fully deterministic cuDNN behaviour, additionally consider:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
# ----------------
# params_all.txt を 7次元ラベルに復元
# ----------------
def load_params_all_to_vec7(filename):
    """Read params_all.txt and rebuild one 7-dim label vector per case.

    The file is CSV with a one-line header ``case_id,seg_idx,val0,val1``;
    each case contributes one row per segment:

      seg_idx=0: L1 (val0)
      seg_idx=1: theta1 (val0), R1 (val1)
      seg_idx=2: L2 (val0)
      seg_idx=3: theta2 (val0), R2 (val1)
      seg_idx=4: L3 (val0)

    Args:
        filename: path to params_all.txt.

    Returns:
        labels: float32 array of shape (n_cases, 7) ordered
            [L1, L2, L3, R1, R2, theta1, theta2]; angles keep the unit used
            in the file. ``n_cases = max(case_id) + 1``.
        case_ids: the list ``[0, ..., n_cases - 1]``.

    Raises:
        ValueError: if some case id in ``0..max(case_id)`` lacks one of the
            segment rows 0-4.
    """
    # ndmin=2 keeps a single data row as shape (1, 4) instead of (4,).
    data = np.loadtxt(filename, delimiter=",", skiprows=1, ndmin=2)
    if data.size == 0:
        return np.zeros((0, 7), dtype=np.float32), []
    case_ids = data[:, 0].astype(int)
    seg_idx = data[:, 1].astype(int)
    val0 = data[:, 2].astype(float)
    val1 = data[:, 3].astype(float)
    n_cases = int(case_ids.max()) + 1
    labels = np.zeros((n_cases, 7), dtype=np.float32)

    def _first(values, case_mask, cid, seg):
        """Return the first value of segment ``seg`` for the masked case, or raise."""
        hits = values[case_mask & (seg_idx == seg)]
        if hits.size == 0:
            raise ValueError(f"{filename}: case_id={cid} has no row with seg_idx={seg}")
        return hits[0]

    for cid in range(n_cases):
        case_mask = case_ids == cid
        labels[cid, 0] = _first(val0, case_mask, cid, 0)  # L1
        labels[cid, 5] = _first(val0, case_mask, cid, 1)  # theta1 (unit as in file)
        labels[cid, 3] = _first(val1, case_mask, cid, 1)  # R1
        labels[cid, 1] = _first(val0, case_mask, cid, 2)  # L2
        labels[cid, 6] = _first(val0, case_mask, cid, 3)  # theta2 (unit as in file)
        labels[cid, 4] = _first(val1, case_mask, cid, 3)  # R2
        labels[cid, 2] = _first(val0, case_mask, cid, 4)  # L3
    return labels, list(range(n_cases))
# ----------------
# ラベル単位変換・整形ユーティリティ
# ----------------
def adjust_label_units(labels, angle_unit='rad', length_scale=1.0, wrap_angles=True):
    """Return a float32 copy of ``labels`` with unit conversions applied.

    labels: (n_cases, 7) array ordered [L1, L2, L3, R1, R2, theta1, theta2].
    angle_unit: 'rad' leaves the angle columns alone; 'deg' (any case)
        converts them from degrees to radians.
    length_scale: factor for the five length columns (e.g. 0.001 for mm -> m).
    wrap_angles: when True, both angle columns are mapped into [-pi, pi).
    The input array is not modified.
    """
    converted = labels.copy().astype(np.float32)
    length_cols, angle_cols = slice(0, 5), slice(5, None)
    converted[:, length_cols] *= float(length_scale)
    if angle_unit.lower() == 'deg':
        converted[:, angle_cols] = converted[:, angle_cols] * (np.pi / 180.0)
    if wrap_angles:
        full_turn = 2.0 * np.pi
        for col in (5, 6):
            converted[:, col] = (converted[:, col] + np.pi) % full_turn - np.pi
    return converted
# ----------------
# Farthest Point Sampling（簡易版、numpy）
# ----------------
def farthest_point_sampling(pts, num_samples):
    """Select ``num_samples`` points from ``pts`` by greedy farthest-point sampling.

    pts: (N, 3) numpy array.
    num_samples: number of points to return.

    If N <= num_samples, every point is kept once (in order) and the rest is
    filled with random repeats drawn with replacement. Otherwise the first
    point is chosen at random and each further point is the one with the
    largest squared distance to its nearest already-chosen point.
    Returns an array of shape (num_samples, 3).
    """
    n_pts = pts.shape[0]
    if n_pts <= num_samples:
        padding = np.random.choice(n_pts, num_samples - n_pts, replace=True)
        return pts[np.concatenate([np.arange(n_pts), padding], axis=0), :]

    chosen = np.zeros((num_samples,), dtype=np.int64)
    chosen[0] = np.random.randint(0, n_pts)
    # Squared distance from every point to the closest point chosen so far.
    nearest_sq = np.minimum(np.full((n_pts,), np.inf, dtype=np.float32),
                            np.sum((pts - pts[chosen[0]]) ** 2, axis=1))
    for slot in range(1, num_samples):
        farthest = int(np.argmax(nearest_sq))
        chosen[slot] = farthest
        nearest_sq = np.minimum(nearest_sq, np.sum((pts - pts[farthest]) ** 2, axis=1))
    return pts[chosen, :]
# ----------------
# 前処理ユーティリティ（重複除去・SOR・PCAアライン・正規化）
# ----------------
def remove_duplicates(pts, round_decimals=6):
    """Drop points that coincide after rounding to ``round_decimals`` decimals.

    The first occurrence of each rounded coordinate is kept; surviving points
    keep their original (unrounded) values and their original order. Inputs
    with fewer than two points are returned as-is.
    """
    if pts.shape[0] < 2:
        return pts
    _, first_idx = np.unique(np.round(pts, decimals=round_decimals), axis=0, return_index=True)
    first_idx.sort()
    return pts[first_idx]
def statistical_outlier_removal(pts, k=16, std_ratio=2.0):
    """Statistical outlier removal (SOR).

    For every point, the mean Euclidean distance to its ``k`` nearest
    neighbours (excluding itself) is computed. Points whose mean distance is
    greater than ``mu + std_ratio * sigma`` of that distribution are dropped.

    The input is returned unchanged when it has fewer than ``k + 2`` points,
    when the mean distances have (near) zero spread, or when filtering would
    keep fewer than ``max(32, 30% of N)`` points.

    Neighbour search is brute force, O(N^2) time. Rows are processed in chunks
    of 512, so each chunk materialises a (512, N, 3) difference array.
    """
    n_pts = pts.shape[0]
    if n_pts < (k + 2):
        return pts
    mean_dists = np.zeros((n_pts,), dtype=np.float32)
    chunk = 512
    k1 = k + 1  # k neighbours plus the point itself (distance 0)
    for start in range(0, n_pts, chunk):
        stop = min(start + chunk, n_pts)
        diff = pts[start:stop, None, :] - pts[None, :, :]   # (C, N, 3)
        d2 = np.sum(diff ** 2, axis=2)                      # (C, N)
        # k+1 smallest squared distances per row, sorted; drop the self match.
        nearest_d2 = np.sort(np.partition(d2, k1, axis=1)[:, :k1], axis=1)[:, 1:]  # (C, k)
        # Take sqrt before averaging: SOR uses the mean distance, whereas
        # sqrt(mean(d^2)) would be an RMS distance.
        mean_dists[start:stop] = np.mean(np.sqrt(nearest_d2), axis=1)
    mu = float(np.mean(mean_dists))
    sigma = float(np.std(mean_dists))
    if sigma < 1e-9:
        return pts
    keep = mean_dists <= (mu + std_ratio * sigma)
    # Safety net: never strip away most of the cloud.
    if keep.sum() < max(32, int(0.3 * n_pts)):
        return pts
    return pts[keep]
def pca_align_to_z(pts):
    """Centre the cloud and rotate it so its first principal axis lies on Z.

    The rotation columns are (second principal axis, cross(first, second),
    first principal axis). This gives a right-handed frame (det = +1).
    Eigenvector signs are whatever ``np.linalg.eigh`` returns and are not
    disambiguated.

    Returns:
        aligned: float32 (N, 3) centred, rotated points (``centred @ R``).
        R: float32 (3, 3) matrix whose columns are the new X, Y, Z axes.
        centroid: float32 (1, 3) mean of the input points.
    """
    centroid = pts.mean(axis=0, keepdims=True)
    centred = pts - centroid
    eigvals, eigvecs = np.linalg.eigh(np.cov(centred.T))
    by_variance = np.argsort(eigvals)[::-1]
    major = eigvecs[:, by_variance[0]]
    middle = eigvecs[:, by_variance[1]]
    third = np.cross(major, middle)

    def _unit(v):
        return v / (np.linalg.norm(v) + 1e-9)

    rot = np.stack([_unit(middle), _unit(third), _unit(major)], axis=1).astype(np.float32)
    return (centred @ rot).astype(np.float32), rot, centroid.astype(np.float32)
def normalize_coords(pts, method='unit_sphere'):
    """Normalise point coordinates; the result is always float32.

    method (case-insensitive; None behaves like 'none'):
      'unit_sphere': subtract the mean, then divide by the largest distance
                     from it (division skipped if that distance is < 1e-9);
      'zscore':      per-axis (x - mean) / std, using 1 for axes with std < 1e-9;
      anything else: values unchanged, only cast to float32.
    """
    kind = (method or 'none').lower()
    if kind == 'zscore':
        mu = pts.mean(axis=0, keepdims=True)
        sigma = pts.std(axis=0, keepdims=True)
        sigma[sigma < 1e-9] = 1.0
        return ((pts - mu) / sigma).astype(np.float32)
    if kind != 'unit_sphere':
        return pts.astype(np.float32)
    centred = pts - pts.mean(axis=0, keepdims=True)
    radius = np.max(np.linalg.norm(centred, axis=1))
    if radius < 1e-9:
        radius = 1.0
    return (centred / radius).astype(np.float32)
def preprocess_points_np(
    pts,
    coord_scale=1.0,
    dedup_round_decimals=6,
    use_sor=True, sor_k=16, sor_std_ratio=2.0,
    canonical_align=False,
    normalize_method='unit_sphere'
):
    """Deterministic preprocessing shared by training and inference.

    Steps, in order:
      1) cast to float32 and multiply by ``coord_scale`` (unit conversion,
         e.g. 0.001 for mm -> m);
      2) drop duplicate points (after rounding to ``dedup_round_decimals``);
      3) statistical outlier removal, when ``use_sor`` is True;
      4) PCA principal-axis alignment, when ``canonical_align`` is True;
      5) coordinate normalisation with ``normalize_method``.
    None of these steps draws random numbers, so equal inputs give equal outputs.
    """
    out = pts.astype(np.float32) * float(coord_scale)
    out = remove_duplicates(out, round_decimals=dedup_round_decimals)
    if use_sor:
        out = statistical_outlier_removal(out, k=int(sor_k), std_ratio=float(sor_std_ratio))
    if canonical_align:
        out = pca_align_to_z(out)[0]
    return normalize_coords(out, method=normalize_method)
# ----------------
# Dataset（FPS + 拡張 + 前処理キャッシュ）
# ----------------
class PipePointParamDataset(torch.utils.data.Dataset):
    """Point-cloud dataset for pipe-parameter regression.

    - Input: point cloud read from ``{root_dir}/pipe_case_{case_id:02d}.xyz``.
    - Label: ``labels[case_id]``, the 7-dim vector
      [L1, L2, L3, R1, R2, theta1, theta2].
    - ``augment_times``: dataset length is ``len(indices) * augment_times``;
      item ``i`` maps to case ``indices[i % len(indices)]`` and is re-augmented
      on every access.
    - ``use_fps``: subsample with farthest point sampling instead of uniform
      random choice.
    - ``pre_cache``: run the deterministic preprocessing (dedup / SOR / PCA /
      normalisation, configured by ``preprocess_cfg``) once per case in
      ``__init__`` and reuse it.
    - Rotation augmentation is off by default (``augment_rotate=False``)
      because the angle labels are taken to be in global coordinates.
    - ``normalize`` is stored for API compatibility but not used; coordinate
      normalisation is controlled by ``preprocess_cfg["normalize_method"]``.
    """
    def __init__(self, indices, labels, root_dir=".", num_points=4096, 
                 preprocess_cfg=None,
                 normalize=True,
                 augment_times=1, augment_rotate=False, rotate_axis=None,
                 augment_scale=True, scale_low=0.98, scale_high=1.02,
                 augment_jitter=True, jitter_std=0.002, jitter_clip=0.01,
                 augment_dropout=True, dropout_rate=0.03,
                 use_fps=True,
                 pre_cache=True):
        super().__init__()
        self.indices = list(indices)
        self.labels = labels
        self.root_dir = root_dir
        self.num_points = num_points
        self.normalize = normalize
        self.preprocess_cfg = preprocess_cfg or {}
        self.coord_scale = float(self.preprocess_cfg.get("coord_scale", 1.0))
        self.dedup_round_decimals = int(self.preprocess_cfg.get("dedup_round_decimals", 6))
        self.use_sor = bool(self.preprocess_cfg.get("use_sor", True))
        self.sor_k = int(self.preprocess_cfg.get("sor_k", 16))
        self.sor_std_ratio = float(self.preprocess_cfg.get("sor_std_ratio", 2.0))
        self.canonical_align = bool(self.preprocess_cfg.get("canonical_align", False))
        self.normalize_method = str(self.preprocess_cfg.get("normalize_method", "unit_sphere")).lower()
        self.augment_times = int(augment_times)
        self.augment_rotate = bool(augment_rotate)
        self.rotate_axis = rotate_axis
        self.augment_scale = bool(augment_scale)
        self.scale_low = float(scale_low)
        self.scale_high = float(scale_high)
        self.augment_jitter = bool(augment_jitter)
        self.jitter_std = float(jitter_std)
        self.jitter_clip = float(jitter_clip)
        self.augment_dropout = bool(augment_dropout)
        self.dropout_rate = float(dropout_rate)
        self.use_fps = bool(use_fps)
        self.base_len = len(self.indices)
        self.pre_cache = bool(pre_cache)
        self._cache = {}
        if self.pre_cache:
            self._build_cache()

    def __len__(self):
        return self.base_len * self.augment_times

    def _load_points(self, cid):
        """Load ``pipe_case_{cid:02d}.xyz`` from ``root_dir`` as a float32 (N, 3) array.

        The file is first parsed as pure numeric data. Only if that raises
        ValueError (e.g. a text header or a point-count line) is the first line
        skipped. Trying ``skiprows=1`` first would silently drop the first real
        point of a header-less file, because that call does not fail.

        Raises:
            ValueError: if the parsed data is not of shape (N, 3).
        """
        path = os.path.join(self.root_dir, f"pipe_case_{cid:02d}.xyz")
        try:
            # ndmin=2 keeps a single-point file as shape (1, 3).
            pts = np.loadtxt(path, ndmin=2)
        except ValueError:
            pts = np.loadtxt(path, skiprows=1, ndmin=2)
        if pts.ndim != 2 or pts.shape[1] != 3:
            raise ValueError(f"Invalid point file shape: {path}, got {pts.shape}")
        return pts.astype(np.float32)

    def _build_cache(self):
        """Preprocess every case in ``indices`` once and store it in ``_cache``."""
        print("[Info] Building preprocessed cache...")
        for cid in self.indices:
            pts_raw = self._load_points(cid)
            pts_pp = preprocess_points_np(
                pts_raw,
                coord_scale=self.coord_scale,
                dedup_round_decimals=self.dedup_round_decimals,
                use_sor=self.use_sor, sor_k=self.sor_k, sor_std_ratio=self.sor_std_ratio,
                canonical_align=self.canonical_align,
                normalize_method=self.normalize_method
            )
            self._cache[cid] = pts_pp
        print(f"[Info] Preprocessed cache ready for {len(self._cache)} cases.")

    def _subsample_or_tile(self, pts):
        """Return exactly ``num_points`` points as a new array.

        With enough points: FPS, or random choice without replacement when
        ``use_fps`` is False. With too few: all points plus random repeats.
        """
        N = pts.shape[0]
        if N >= self.num_points:
            if self.use_fps:
                return farthest_point_sampling(pts, self.num_points)
            else:
                idx = np.random.choice(N, self.num_points, replace=False)
                return pts[idx, :]
        else:
            reps = self.num_points - N
            extra_idx = np.random.choice(N, reps, replace=True)
            idx = np.concatenate([np.arange(N), extra_idx], axis=0)
            return pts[idx, :]

    def _rotation_matrix_z(self, theta):
        """float32 3x3 rotation by ``theta`` radians about the Z axis."""
        c, s = np.cos(theta), np.sin(theta)
        R = np.array([[c, -s, 0.0],
                      [s,  c, 0.0],
                      [0.0, 0.0, 1.0]], dtype=np.float32)
        return R

    def _random_rotation_matrix_any(self):
        """Random rotation via Rodrigues' formula: normalised Gaussian axis, angle in [0, 2*pi)."""
        axis = np.random.randn(3).astype(np.float32)
        norm = np.linalg.norm(axis)
        axis = axis / (norm + 1e-9)
        theta = np.random.uniform(0.0, 2.0 * np.pi)
        K = np.array([[0, -axis[2], axis[1]],
                      [axis[2], 0, -axis[0]],
                      [-axis[1], axis[0], 0]], dtype=np.float32)
        I = np.eye(3, dtype=np.float32)
        R = I + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)
        return R

    def _apply_random_rotation(self, pts):
        """Rotate randomly about Z (``rotate_axis='z'``) or any axis (``'random'``).

        A no-op unless ``augment_rotate`` is True; keeping it off is recommended
        because the angle labels are taken to be in global coordinates.
        """
        if not self.augment_rotate:
            return pts
        if self.rotate_axis == 'z':
            theta = np.random.uniform(0.0, 2.0*np.pi)
            R = self._rotation_matrix_z(theta)
        elif self.rotate_axis == 'random':
            R = self._random_rotation_matrix_any()
        else:
            return pts
        return pts @ R.T

    def _apply_random_scaling(self, pts):
        """Scale all coordinates by one factor drawn uniformly from [scale_low, scale_high)."""
        s = np.random.uniform(self.scale_low, self.scale_high)
        return pts * s

    def _apply_jitter(self, pts):
        """Add Gaussian noise (std ``jitter_std``) clipped to +-``jitter_clip``; returns float32."""
        noise = np.clip(self.jitter_std * np.random.randn(*pts.shape), -self.jitter_clip, self.jitter_clip)
        return (pts + noise).astype(np.float32)

    def _apply_point_dropout(self, pts):
        """Overwrite a random ~``dropout_rate`` fraction of points with point 0, in place.

        The point count stays fixed, so batches keep a constant shape.
        """
        N = pts.shape[0]
        drop_idx = np.random.rand(N) < self.dropout_rate
        if np.any(drop_idx):
            pts[drop_idx] = pts[0]
        return pts

    def __getitem__(self, i):
        """Return ``(points, label)``: a float32 (3, num_points) tensor and ``labels[cid]`` as a tensor."""
        base_i = i % self.base_len
        cid = self.indices[base_i]
        pts = self._cache[cid] if self.pre_cache else self._load_points(cid)
        if not self.pre_cache:
            pts = preprocess_points_np(
                pts,
                coord_scale=self.coord_scale,
                dedup_round_decimals=self.dedup_round_decimals,
                use_sor=self.use_sor, sor_k=self.sor_k, sor_std_ratio=self.sor_std_ratio,
                canonical_align=self.canonical_align,
                normalize_method=self.normalize_method
            )
        # Subsampling always returns a new array (fancy indexing), so the
        # in-place point dropout below cannot corrupt the cached cloud.
        pts = self._subsample_or_tile(pts)
        # Random augmentation (rotation is a no-op unless augment_rotate=True)
        pts = self._apply_random_rotation(pts)
        if self.augment_scale:
            pts = self._apply_random_scaling(pts)
        if self.augment_jitter:
            pts = self._apply_jitter(pts)
        if self.augment_dropout:
            pts = self._apply_point_dropout(pts)
        pts = torch.from_numpy(pts.astype(np.float32)).transpose(0,1)  # (3, num_points)
        label = torch.from_numpy(self.labels[cid])                      # shares memory with labels
        return pts, label
# ----------------
# STN（PointNet）
# ----------------
class STN3d(nn.Module):
    """PointNet input-transform network: predicts a per-sample 3x3 matrix.

    forward(x) takes x of shape (B, 3, N). Point-wise 1x1 convolutions
    (3 -> 64 -> 128 -> 1024, each BatchNorm + ReLU) are max-pooled over the N
    points. An FC head (1024 -> 512 -> 256 -> 9) then predicts an offset that
    is added to the flattened 3x3 identity. Returns a (B, 3, 3) tensor.
    """
    def __init__(self):
        super().__init__()
        # Layers are created in this fixed order; each constructor draws from
        # the global RNG, so the order determines seeded initial weights.
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        n_batch = x.size(0)
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            x = F.relu(norm(conv(x)))
        pooled = x.max(dim=2).values                          # (B, 1024)
        hidden = F.relu(self.bn4(self.fc1(pooled)))
        hidden = F.relu(self.bn5(self.fc2(hidden)))
        offset = self.fc3(hidden)                             # (B, 9)
        flat_eye = torch.eye(3, dtype=torch.float32, device=offset.device).reshape(1, 9)
        return (offset + flat_eye.expand(n_batch, 9)).view(-1, 3, 3)
class STNkd(nn.Module):
    """PointNet feature-transform network: predicts a per-sample k x k matrix.

    forward(x) takes x of shape (B, k, N): 1x1 convolutions k -> 64 -> 128 ->
    1024 (BatchNorm + ReLU each), max over N, FC 1024 -> 512 -> 256 -> k*k,
    plus the flattened k x k identity. Returns a (B, k, k) tensor.
    """
    def __init__(self, k=64):
        super().__init__()
        self.k = k
        # Fixed construction order keeps seeded initialisation reproducible.
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        n_batch = x.size(0)
        pointwise = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in pointwise:
            x = F.relu(norm(conv(x)))
        pooled = torch.max(x, dim=2).values
        hidden = F.relu(self.bn5(self.fc2(F.relu(self.bn4(self.fc1(pooled))))))
        offset = self.fc3(hidden)
        flat_eye = torch.eye(self.k, dtype=torch.float32, device=offset.device).reshape(1, -1)
        return (offset + flat_eye.expand(n_batch, -1)).view(-1, self.k, self.k)
def feature_transform_regularizer(trans):
    """Orthogonality penalty for a batch of k x k transforms.

    ``trans`` has shape (B, k, k). Returns the batch mean of the Frobenius
    norm ||A A^T - I||_F, which is zero when every A is orthogonal.
    """
    k = trans.size(1)
    gram = torch.bmm(trans, trans.transpose(2, 1))
    eye = torch.eye(k, device=trans.device).unsqueeze(0)
    return torch.linalg.matrix_norm(gram - eye, ord='fro').mean()
# ----------------
# Model (PointNet + STN + BN + Dropout)
# ----------------
class PointNetBackbone(nn.Module):
    """PointNet encoder producing one 1024-d global feature per point cloud.

    forward(x) with x of shape (B, 3, N) returns ``(feat, trans, trans_feat)``:
      feat:       (B, 1024) max-pooled point features;
      trans:      (B, 3, 3) input transform predicted by ``STN3d``;
      trans_feat: (B, 64, 64) feature transform from ``STNkd(k=64)``, or None
                  when ``use_feature_stn`` is False.
    ``dropout_p`` is stored but not used inside this module.
    """
    def __init__(self, use_feature_stn=True, dropout_p=0.5):
        super().__init__()
        # Construction order is kept fixed so seeded initialisation is unchanged.
        self.stn = STN3d()
        self.use_feature_stn = use_feature_stn
        if use_feature_stn:
            self.fstn = STNkd(k=64)
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.dropout_p = dropout_p

    @staticmethod
    def _transform(feats, matrix):
        """Multiply each point's row vector by its sample's matrix: (B, C, N) x (B, C, C) -> (B, C, N)."""
        return torch.bmm(feats.transpose(2, 1), matrix).transpose(2, 1)

    def forward(self, x):
        trans = self.stn(x)                                              # (B, 3, 3)
        x = F.relu(self.bn1(self.conv1(self._transform(x, trans))))     # (B, 64, N)
        trans_feat = None
        if self.use_feature_stn:
            trans_feat = self.fstn(x)                                    # (B, 64, 64)
            x = self._transform(x, trans_feat)
        x = F.relu(self.bn2(self.conv2(x)))                              # (B, 128, N)
        x = F.relu(self.bn3(self.conv3(x)))                              # (B, 1024, N)
        return x.max(dim=2).values, trans, trans_feat
class PipeDimensionRegressor(nn.Module):
    """PointNet regressor for pipe dimensions.

    forward(x) with x of shape (B, 3, N) returns ``(out, trans, trans_feat)``:
      out:        (B, out_dim); default layout
                  [L1, L2, L3, R1, R2, s1, c1, s2, c2];
      trans, trans_feat: the backbone's input / feature transforms
                  (trans_feat is None when ``use_feature_stn`` is False).
    The head is FC 1024 -> 512 -> 256 (BatchNorm + ReLU), dropout, then a
    linear output layer.
    """
    def __init__(self, out_dim=9, dropout_p=0.5, use_feature_stn=True):
        super().__init__()
        # Submodules are created in this exact order so seeded init is reproducible.
        self.backbone = PointNetBackbone(use_feature_stn=use_feature_stn, dropout_p=dropout_p)
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(dropout_p)
        self.fc_out = nn.Linear(256, out_dim)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (ReLU gain) weights and zero biases for the head's linear layers."""
        for layer in (self.fc1, self.fc2, self.fc_out):
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        feat, trans, trans_feat = self.backbone(x)
        hidden = feat
        for fc, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = F.relu(norm(fc(hidden)))
        return self.fc_out(self.dropout(hidden)), trans, trans_feat
# ----------------
# ラベル標準化（線形5次元のみ + 任意でlog1p）
# ----------------
class LabelScaler:
    """Standardisation statistics for the five linear labels [L1, L2, L3, R1, R2].

    With ``use_log_linear=True`` the statistics are computed on
    ``log1p(max(x, 0))``, i.e. in the same compressed space the loss uses.
    ``mean5`` / ``std5`` are None until ``fit`` is called.
    """
    def __init__(self, use_log_linear=True):
        self.use_log_linear = bool(use_log_linear)
        self.mean5 = None
        self.std5 = None

    def _prep(self, arr5):
        """Apply the optional log1p compression (negatives clamped to 0 first)."""
        return np.log1p(np.maximum(arr5, 0.0)) if self.use_log_linear else arr5

    def fit(self, y_train: np.ndarray):
        """Compute float32 per-column mean/std over ``y_train[:, :5]``; std < 1e-8 -> 1. Returns self."""
        prepared = self._prep(y_train[:, :5])
        spread = prepared.std(axis=0)
        spread[spread < 1e-8] = 1.0
        self.mean5 = prepared.mean(axis=0).astype(np.float32)
        self.std5 = spread.astype(np.float32)
        return self

    def to_torch(self, device):
        """Return ``(mean, std)`` as (5,) tensors on ``device``."""
        return tuple(torch.from_numpy(stat).to(device) for stat in (self.mean5, self.std5))
# ----------------
# 角度ユーティリティ
# ----------------
def angles_to_sincos(thetas: torch.Tensor):
    """Map angles (B, 2) [theta1, theta2] to (B, 4) [sin1, cos1, sin2, cos2].

    Only the first two columns of ``thetas`` are used.
    """
    pair = thetas[:, :2]
    interleaved = torch.stack([torch.sin(pair), torch.cos(pair)], dim=2)  # (B, 2, 2)
    return interleaved.reshape(pair.size(0), -1)
def normalize_pairwise_sincos(sincos: torch.Tensor):
    """Rescale each (sin, cos) pair of a (B, 4) tensor to unit length.

    Pairs are columns (0, 1) and (2, 3). Norms are clamped to at least 1e-6
    so an all-zero pair does not divide by zero. Returns a new (B, 4) tensor.
    """
    n_batch = sincos.size(0)
    pairs = sincos[:, :4].reshape(n_batch, 2, 2)
    lengths = pairs.norm(dim=2, keepdim=True).clamp(min=1e-6)
    return (pairs / lengths).reshape(n_batch, 4)
def wrap_angle_diff(a, b):
    """Return ``a - b`` wrapped into [-pi, pi); works on scalars and arrays."""
    half_turn = np.pi
    shifted = (a - b) + half_turn
    return shifted % (2.0 * half_turn) - half_turn
# ----------------
# EarlyStopping
# ----------------
class EarlyStopping:
    """Track the lowest metric seen and signal when training should stop.

    ``step`` returns True once the metric has failed to improve by more than
    ``min_delta`` for more than ``patience`` consecutive calls. On every
    improvement a detached CPU copy of ``model.state_dict()`` is stored in
    ``best_state``, and the epoch in ``best_epoch``.
    """
    def __init__(self, patience=300, min_delta=0.0):
        self.patience, self.min_delta = patience, min_delta
        self.best = float("inf")
        self.wait = 0
        self.best_state = None
        self.best_epoch = 0

    def step(self, metric, model, epoch):
        if not (self.best - metric) > self.min_delta:
            self.wait += 1
            return self.wait > self.patience
        self.best, self.wait, self.best_epoch = metric, 0, epoch
        self.best_state = {name: tensor.detach().cpu().clone()
                           for name, tensor in model.state_dict().items()}
        return self.wait > self.patience
# ----------------
# EMA（指数移動平均）
# ----------------
class ModelEMA:
    """Exponential moving average (EMA) of a model's weights.

    On every ``update(model)``:
        ema_param <- decay * ema_param + (1 - decay) * model_param
    while buffers (e.g. BatchNorm running statistics) are copied verbatim.
    The averaged copy lives in ``self.ema`` on the device of the model's first
    parameter. It has gradients disabled and is put in eval mode at construction.
    """
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.device = next(model.parameters()).device
        self.ema = self._clone_model(model).to(self.device)
        self._hard_sync(model)

    @torch.no_grad()
    def _clone_model(self, model):
        """Return a frozen deep copy of ``model``.

        deepcopy reproduces the exact architecture and options, including
        ``use_feature_stn=False``, for any nn.Module. Rebuilding with
        ``type(model)(...)`` and hard-coded kwargs failed the strict state-dict
        load for other configurations.
        """
        import copy
        ema = copy.deepcopy(model)
        ema.eval()
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def _hard_sync(self, model):
        """Copy all parameters and buffers from ``model`` into the EMA copy."""
        for ema_p, src_p in zip(self.ema.parameters(), model.parameters()):
            ema_p.data.copy_(src_p.data)
        for ema_b, src_b in zip(self.ema.buffers(), model.buffers()):
            ema_b.data.copy_(src_b.data)

    @torch.no_grad()
    def update(self, model):
        """Blend ``model``'s parameters into the EMA and copy its buffers."""
        for ema_p, src_p in zip(self.ema.parameters(), model.parameters()):
            ema_p.data.mul_(self.decay).add_(src_p.data, alpha=(1.0 - self.decay))
        for ema_b, src_b in zip(self.ema.buffers(), model.buffers()):
            ema_b.data.copy_(src_b.data)

    def state_dict(self):
        """State dict of the averaged model."""
        return self.ema.state_dict()
# ----------------
# Train / Eval（カスタム損失: 線形5次元 + 角度sincos + STN正則化）
# ----------------
def compute_losses(outputs, targets, label_mu5, label_sigma5, criterion_linear, criterion_angle,
                   trans_feat=None, ft_reg_weight=0.01, angle_weight=2.5, use_log_linear=True):
    """Combined loss: standardised linear dims + weighted angle (sin/cos) + STN regulariser.

    Args:
        outputs: (B, 9) predictions [L1, L2, L3, R1, R2, s1, c1, s2, c2].
        targets: (B, 7) labels [L1, L2, L3, R1, R2, theta1, theta2]; the
            angles go through sin/cos, so they are expected in radians.
        label_mu5, label_sigma5: (5,) standardisation statistics for the
            linear dims (in log1p space when ``use_log_linear`` is True).
        criterion_linear, criterion_angle: loss callables ``f(pred, target)``.
        trans_feat: optional (B, k, k) feature transform to regularise.
        ft_reg_weight: weight of the orthogonality penalty on ``trans_feat``.
        angle_weight: weight of the angle loss.
        use_log_linear: compare log1p(max(x, 0)) of predictions and targets.

    Returns:
        ``(loss, loss_lin.detach(), loss_ang.detach(), reg)``. ``reg`` is the
        weighted regulariser, or a zero tensor on ``outputs``' device when
        ``trans_feat`` is None.
    """
    # Linear dimensions (first five columns)
    out_lin = outputs[:, :5]
    tgt_lin = targets[:, :5]
    if use_log_linear:
        # Straight-through clamp: the forward value equals clamp(x, min=0) but
        # the gradient passes through unchanged. A plain clamp has zero
        # gradient for negative predictions, so those outputs would get no
        # signal from the linear loss and could stay negative.
        out_pos = out_lin + (torch.clamp(out_lin, min=0.0) - out_lin).detach()
        out_lin_t = torch.log1p(out_pos)
        tgt_lin_t = torch.log1p(torch.clamp(tgt_lin, min=0.0))
    else:
        out_lin_t = out_lin
        tgt_lin_t = tgt_lin
    out_lin_norm = (out_lin_t - label_mu5) / label_sigma5
    tgt_lin_norm = (tgt_lin_t - label_mu5) / label_sigma5
    loss_lin = criterion_linear(out_lin_norm, tgt_lin_norm)
    # Angles: compare unit (sin, cos) pairs, which avoids the +-pi wrap-around jump.
    out_sc = normalize_pairwise_sincos(outputs[:, 5:9])
    tgt_sc = angles_to_sincos(targets[:, 5:7])
    loss_ang = criterion_angle(out_sc, tgt_sc)
    loss = loss_lin + angle_weight * loss_ang
    # STN feature-transform regulariser
    reg = outputs.new_zeros(())
    if trans_feat is not None:
        reg = feature_transform_regularizer(trans_feat) * ft_reg_weight
        loss = loss + reg
    return loss, loss_lin.detach(), loss_ang.detach(), reg
def train_one_epoch(model, optimizer, dataloader, device,
                    label_mu5, label_sigma5,
                    criterion_linear, criterion_angle,
                    scaler=None,
                    max_grad_norm=1.0, ft_reg_weight=0.01, angle_weight=2.5,
                    ema_updater=None,
                    use_log_linear=True):
    """Run one training epoch and return the sample-weighted mean loss.

    Mixed precision (autocast + GradScaler) is used when ``scaler`` is given.
    Gradients are clipped to ``max_grad_norm`` unless it is None. The EMA, if
    provided, is updated after every optimiser step.

    The average is taken over the samples actually processed, so it stays
    correct when the DataLoader uses ``drop_last=True``. It is 0.0 if no batch
    was processed.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)    # (B,3,N)
        targets = targets.to(device)  # (B,7)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=(scaler is not None)):
            outputs, _, trans_feat = model(inputs)  # outputs: (B,9)
            loss, _, _, _ = compute_losses(outputs, targets, label_mu5, label_sigma5,
                                           criterion_linear, criterion_angle,
                                           trans_feat=trans_feat, ft_reg_weight=ft_reg_weight, angle_weight=angle_weight,
                                           use_log_linear=use_log_linear)
        if scaler is not None:
            scaler.scale(loss).backward()
            if max_grad_norm is not None:
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()
        if ema_updater is not None:
            ema_updater.update(model)
        batch_n = inputs.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n
    return total_loss / max(n_seen, 1)
@torch.no_grad()
def evaluate(eval_model, dataloader, device,
             label_mu5, label_sigma5,
             criterion_linear, criterion_angle,
             ft_reg_weight=0.01, angle_weight=2.5,
             use_log_linear=True):
    """Compute the sample-weighted mean loss of ``eval_model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled. The average
    is taken over the samples actually processed (correct with
    ``drop_last=True``), and is 0.0 when no batch was processed.
    """
    eval_model.eval()
    total_loss = 0.0
    n_seen = 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs, _, trans_feat = eval_model(inputs)
        loss, _, _, _ = compute_losses(outputs, targets, label_mu5, label_sigma5,
                                       criterion_linear, criterion_angle,
                                       trans_feat=trans_feat, ft_reg_weight=ft_reg_weight, angle_weight=angle_weight,
                                       use_log_linear=use_log_linear)
        batch_n = inputs.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n
    return total_loss / max(n_seen, 1)
# ----------------
# 評価保存ユーティリティ（角度を元に戻して評価、TTA＋前処理対応）
# ----------------
def compute_metrics(preds_np, trues_np, angle_indices=(5,6), deg=False):
    """MAE / MSE / RMSE per label and pooled over all labels.

    Args:
        preds_np, trues_np: (n, 7) arrays ordered
            [L1, L2, L3, R1, R2, theta1, theta2].
        angle_indices: columns holding angles in radians; their errors are
            wrapped into [-pi, pi) with ``wrap_angle_diff``. None means there
            are no angle columns.
        deg: if True, angle-column errors are converted to degrees before the
            statistics are computed. This has no effect without angle columns.

    Returns:
        overall: dict with MAE_mean / MSE_mean / RMSE_mean over every entry
            (lengths and angles pooled together).
        by_dim: dict keyed by label name with per-column MAE / MSE / RMSE.
    """
    err = preds_np - trues_np
    angle_cols = () if angle_indices is None else tuple(angle_indices)
    for ai in angle_cols:
        err[:, ai] = wrap_angle_diff(preds_np[:, ai], trues_np[:, ai])
    # Guarded: iterating angle_indices=None here used to raise TypeError.
    if deg and angle_cols:
        scale = np.ones(err.shape[1], dtype=np.float32)
        for ai in angle_cols:
            scale[ai] = 180.0 / np.pi
        err = err * scale
    mae = np.mean(np.abs(err), axis=0)
    mse = np.mean(err**2, axis=0)
    rmse = np.sqrt(mse)
    overall = {
        "MAE_mean": float(np.mean(np.abs(err))),
        "MSE_mean": float(np.mean(err**2)),
        "RMSE_mean": float(np.sqrt(np.mean(err**2)))
    }
    keys = ["L1","L2","L3","R1","R2","theta1","theta2"]
    by_dim = {k: {"MAE": float(mae[i]), "MSE": float(mse[i]), "RMSE": float(rmse[i])} for i, k in enumerate(keys)}
    return overall, by_dim
@torch.no_grad()
def _predict7_from_points(model, pts_np, device, num_points=4096,
                          preprocess_cfg=None, tta_times=8, rotate_axis=None):
    """Predict [L1, L2, L3, R1, R2, theta1, theta2] for one raw cloud with TTA.

    The raw points are preprocessed once, since ``preprocess_points_np`` draws
    no random numbers. Each of the ``tta_times`` passes then takes a fresh
    farthest-point subsample (random start point). When ``rotate_axis == 'z'``,
    each pass also applies a random rotation about Z.

    Aggregation across passes:
      - the five linear outputs are averaged arithmetically;
      - each angle is averaged on the unit circle (mean of its normalised
        (sin, cos) vector, then ``arctan2``). A plain mean of angles is wrong
        near the +-pi seam; e.g. 3.1 and -3.1 would average to 0.

    Args:
        model: returns ``(out9, _, _)`` with out9 of shape (1, 9) ordered
            [L1, L2, L3, R1, R2, s1, c1, s2, c2]. It should already be in
            eval mode.
        pts_np: (N, 3) raw points; not modified.
        device: device for the model input.
        num_points: points per TTA pass.
        preprocess_cfg: ``preprocess_points_np`` options; missing keys use defaults.
        tta_times: number of passes; values below 1 are treated as 1.
        rotate_axis: None or 'z'.

    Returns:
        float32 numpy array of shape (7,); angles in radians in [-pi, pi].
    """
    cfg = preprocess_cfg or {}
    base_pts = preprocess_points_np(
        np.asarray(pts_np, dtype=np.float32),
        coord_scale=float(cfg.get("coord_scale", 1.0)),
        dedup_round_decimals=int(cfg.get("dedup_round_decimals", 6)),
        use_sor=bool(cfg.get("use_sor", True)),
        sor_k=int(cfg.get("sor_k", 16)),
        sor_std_ratio=float(cfg.get("sor_std_ratio", 2.0)),
        canonical_align=bool(cfg.get("canonical_align", False)),
        normalize_method=str(cfg.get("normalize_method", "unit_sphere")).lower()
    )
    n_passes = max(1, int(tta_times))
    lin_sum = np.zeros(5, dtype=np.float64)
    unit_sum = np.zeros(4, dtype=np.float64)  # accumulated unit (s1, c1, s2, c2)
    for _ in range(n_passes):
        pts = farthest_point_sampling(base_pts, num_points)
        # Optional rotation TTA (normally None for validation)
        if rotate_axis == 'z':
            theta = np.random.uniform(0.0, 2.0*np.pi)
            c, s = np.cos(theta), np.sin(theta)
            R = np.array([[c, -s, 0.0],[s, c, 0.0],[0.0,0.0,1.0]], dtype=np.float32)
            pts = pts @ R.T
        x = torch.from_numpy(pts.astype(np.float32)).transpose(0,1).unsqueeze(0).to(device)  # (1,3,N)
        out9, _, _ = model(x)
        out9 = out9.cpu().numpy().reshape(-1).astype(np.float64)
        lin_sum += out9[:5]
        for j in (5, 7):
            norm = np.hypot(out9[j], out9[j + 1])
            norm = 1.0 if norm < 1e-6 else norm
            unit_sum[j - 5] += out9[j] / norm
            unit_sum[j - 4] += out9[j + 1] / norm
    lin_mean = lin_sum / n_passes
    theta1 = np.arctan2(unit_sum[0], unit_sum[1])
    theta2 = np.arctan2(unit_sum[2], unit_sum[3])
    return np.array([*lin_mean, theta1, theta2], dtype=np.float32)
def evaluate_and_save_casewise(model, val_case_ids, labels, device,
                               root_dir=".", num_points=4096,
                               out_csv="results8-2/val_predictions.csv",
                               out_json="results8-2/val_metrics.json",
                               tta_times=8, rotate_axis=None,
                               preprocess_cfg=None, save_degree_metrics=True):
    """Predict each validation case with TTA; write a per-case CSV and a metrics JSON.

    For every case id, the point file ``{root_dir}/pipe_case_{cid:02d}.xyz`` is
    loaded and predicted with ``_predict7_from_points``. The case is then
    compared with ``labels[cid]``; angle errors are wrapped into [-pi, pi).
    The CSV holds predictions, ground truth and absolute errors per case. The
    JSON holds metrics in radians and, if ``save_degree_metrics``, in degrees.
    Parent directories of both output files are created when missing.

    Raises:
        ValueError: if ``val_case_ids`` is empty or a point file is not (N, 3).
    """
    case_list = list(val_case_ids)
    if not case_list:
        raise ValueError("val_case_ids is empty; nothing to evaluate")
    Path(out_csv).parent.mkdir(parents=True, exist_ok=True)
    Path(out_json).parent.mkdir(parents=True, exist_ok=True)
    preds_all, trues_all = [], []
    header = ["case_id",
              "pred_L1","pred_L2","pred_L3","pred_R1","pred_R2","pred_theta1","pred_theta2",
              "true_L1","true_L2","true_L3","true_R1","true_R2","true_theta1","true_theta2",
              "abserr_L1","abserr_L2","abserr_L3","abserr_R1","abserr_R2","abserr_theta1","abserr_theta2"]
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        model.eval()
        with torch.no_grad():
            for cid in case_list:
                path = os.path.join(root_dir, f"pipe_case_{cid:02d}.xyz")
                # Parse as pure numeric data first and skip a header line only
                # if that fails. Trying skiprows=1 first would silently drop the
                # first point of a header-less file. ndmin=2 keeps a
                # single-point file 2-D.
                try:
                    pts = np.loadtxt(path, ndmin=2)
                except ValueError:
                    pts = np.loadtxt(path, skiprows=1, ndmin=2)
                if pts.ndim != 2 or pts.shape[1] != 3:
                    raise ValueError(f"Invalid point file shape: {path}, got {pts.shape}")
                pred7 = _predict7_from_points(model, pts.astype(np.float32),
                                              device=device, num_points=num_points,
                                              preprocess_cfg=preprocess_cfg, tta_times=tta_times, rotate_axis=rotate_axis)
                true7 = labels[cid].astype(np.float32)
                abs_err = np.abs(pred7 - true7)
                abs_err[5] = np.abs(wrap_angle_diff(pred7[5], true7[5]))
                abs_err[6] = np.abs(wrap_angle_diff(pred7[6], true7[6]))
                preds_all.append(pred7)
                trues_all.append(true7)
                row = [cid] + pred7.tolist() + true7.tolist() + abs_err.tolist()
                writer.writerow(row)
    preds_all = np.vstack(preds_all)
    trues_all = np.vstack(trues_all)
    overall_rad, by_dim_rad = compute_metrics(preds_all, trues_all, angle_indices=(5,6), deg=False)
    metrics = {"overall_rad": overall_rad, "by_dim_rad": by_dim_rad, "num_cases": int(len(case_list))}
    if save_degree_metrics:
        overall_deg, by_dim_deg = compute_metrics(preds_all, trues_all, angle_indices=(5,6), deg=True)
        metrics["overall_deg"] = overall_deg
        metrics["by_dim_deg"] = by_dim_deg
    with open(out_json, "w") as jf:
        json.dump(metrics, jf, indent=2)
    print(f"Saved case-wise predictions to {out_csv}")
    print(f"Saved metrics summary to {out_json}")
# ----------------
# Main（汎化改善セット）
# ----------------
def main():
    """Train the pipe-dimension regressor end to end and save every artifact.

    Steps: fix RNG seeds; load ``params_all.txt`` and convert label units;
    make a seeded, shuffled 8/2 train/val split of the case ids; build the
    augmented training set and a plain validation set; fit the label scaler
    on the training cases; train ``PipeDimensionRegressor(out_dim=9)`` with
    AdamW + OneCycleLR (stepped once per epoch) + EMA, early-stopping on the
    EMA validation loss; finally write into ``results8-2/``:
    ``best_model.pt`` (best EMA state_dict), ``last_checkpoint.pth``,
    ``metrics_history.csv``, ``val_predictions.csv``, ``val_metrics.json``
    and ``loss_curve.png``.

    Raises:
        ValueError: if there are not enough cases for a non-empty train and
            validation split.
    """
    seed_everything(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results_dir = "results8-2"
    Path(results_dir).mkdir(parents=True, exist_ok=True)

    # ====== Units / preprocessing ======
    # Angle unit of the labels: 'rad' or 'deg' (must match params_all.txt).
    angle_unit_labels = 'rad'  # change to 'deg' if params_all.txt stores degrees
    # Length label scale (e.g. mm -> m is 0.001).
    label_length_scale = 1.0
    # Point-cloud coordinate scale (e.g. mm -> m is 0.001).
    coord_scale = 1.0
    # PCA canonical alignment: only recommended when angles are defined w.r.t. principal axes.
    canonical_align = False
    # Coordinate normalisation: 'unit_sphere' / 'zscore' / 'none'.
    normalize_method = 'unit_sphere'
    # Statistical outlier removal (SOR).
    use_sor = True
    sor_k = 16
    sor_std_ratio = 2.0
    preprocess_cfg = {
        "coord_scale": coord_scale,
        "dedup_round_decimals": 6,
        "use_sor": use_sor,
        "sor_k": sor_k,
        "sor_std_ratio": sor_std_ratio,
        "canonical_align": canonical_align,
        "normalize_method": normalize_method,
    }

    # ====== Labels + unit conversion ======
    labels_raw, case_ids = load_params_all_to_vec7("params_all.txt")
    labels = adjust_label_units(labels_raw, angle_unit=angle_unit_labels,
                                length_scale=label_length_scale, wrap_angles=True)

    # ====== Shuffled split (fixed seed for reproducibility) ======
    rng = np.random.default_rng(42)
    case_ids_shuffled = np.array(case_ids)
    rng.shuffle(case_ids_shuffled)
    # Data is scarce: keep validation at ~2 cases and train on as many as possible.
    train_indices = case_ids_shuffled[:8]   # 8 training cases (increase if possible)
    val_indices = case_ids_shuffled[8:10]   # 2 validation cases
    if len(train_indices) == 0 or len(val_indices) == 0:
        raise ValueError(
            f"Not enough cases for the 8/2 train/val split: got {len(case_ids_shuffled)} (need at least 9)"
        )

    # ====== Hyper-parameters (regularisation up, angles weighted more) ======
    batch_size = 32
    num_points = 8192
    num_epochs = 2000            # OneCycleLR converges well within ~2000 epochs here
    max_lr = 1.0e-3
    weight_decay = 1.0e-3        # stronger weight decay
    max_grad_norm = 1.0
    patience_es = 300            # stop early when val stalls
    ft_reg_weight = 0.02         # stronger feature-STN orthogonality regulariser
    angle_weight = 2.5           # larger weight on the angle loss
    use_log_linear = True        # log1p on the linear labels

    # ====== Datasets (no rotation augmentation, mild others) ======
    train_dataset = PipePointParamDataset(
        train_indices, labels, root_dir=".", num_points=num_points,
        preprocess_cfg=preprocess_cfg,
        augment_times=30,                        # few cases -> many augmented copies
        augment_rotate=False, rotate_axis=None,  # angles are global: rotation would corrupt labels
        augment_scale=True, scale_low=0.98, scale_high=1.02,
        augment_jitter=True, jitter_std=0.002, jitter_clip=0.01,
        augment_dropout=True, dropout_rate=0.03,
        use_fps=True,
        pre_cache=True
    )
    val_dataset = PipePointParamDataset(
        val_indices, labels, root_dir=".", num_points=num_points,
        preprocess_cfg=preprocess_cfg,
        augment_times=1, augment_rotate=False, rotate_axis=None,
        augment_scale=False, augment_jitter=False, augment_dropout=False,
        use_fps=True,
        pre_cache=True
    )
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=0)

    # ====== Label standardisation (5 linear dims only, fitted on train cases) ======
    scaler = LabelScaler(use_log_linear=use_log_linear).fit(labels[np.array(train_indices)])
    label_mu5_t, label_sigma5_t = scaler.to_torch(device)

    # ====== Model / losses / optimiser / scheduler / EMA ======
    model = PipeDimensionRegressor(out_dim=9, dropout_p=0.5, use_feature_stn=True).to(device)
    criterion_linear = nn.SmoothL1Loss(beta=0.5)
    criterion_angle = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
    # scheduler.step() is called once per epoch in the loop below (it is not
    # handed to train_one_epoch), so the one-cycle schedule must be sized in
    # epochs. Using len(train_loader) here would stretch the cycle by that
    # factor and keep the LR in the warm-up phase for most of training.
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr, epochs=num_epochs, steps_per_epoch=1,
        pct_start=0.1, anneal_strategy='cos', div_factor=10.0, final_div_factor=10.0
    )
    ema = ModelEMA(model, decay=0.999)
    scaler_amp = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))
    early_stopping = EarlyStopping(patience=patience_es, min_delta=0.0)

    # ====== Logging ======
    train_losses, val_losses = [], []
    best_model_path = os.path.join(results_dir, "best_model.pt")
    last_ckpt_path = os.path.join(results_dir, "last_checkpoint.pth")

    # ====== Training loop ======
    for epoch in range(1, num_epochs + 1):
        train_loss = train_one_epoch(
            model, optimizer, train_loader, device,
            label_mu5=label_mu5_t, label_sigma5=label_sigma5_t,
            criterion_linear=criterion_linear, criterion_angle=criterion_angle,
            scaler=scaler_amp,
            max_grad_norm=max_grad_norm, ft_reg_weight=ft_reg_weight, angle_weight=angle_weight,
            ema_updater=ema,
            use_log_linear=use_log_linear
        )
        val_loss = evaluate(
            ema.ema, val_loader, device,
            label_mu5=label_mu5_t, label_sigma5=label_sigma5_t,
            criterion_linear=criterion_linear, criterion_angle=criterion_angle,
            ft_reg_weight=ft_reg_weight, angle_weight=angle_weight,
            use_log_linear=use_log_linear
        )
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        scheduler.step()
        stop = early_stopping.step(val_loss, ema.ema, epoch)  # track the EMA weights
        if epoch % 20 == 0 or epoch == 1:
            current_lr = optimizer.param_groups[0]["lr"]
            print(f"Epoch {epoch:04d} | LR {current_lr:.2e} | Train {train_loss:.6f} | Val {val_loss:.6f} | Best {early_stopping.best:.6f}")
        if epoch % 100 == 0 or epoch == num_epochs:
            torch.save({
                "epoch": epoch,
                # Save the EMA *model's* state_dict so load_model_from_weights
                # can load "model_state_dict" straight into PipeDimensionRegressor.
                "model_state_dict": ema.ema.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "best_val_loss": early_stopping.best,
            }, last_ckpt_path)
        if stop:
            print(f"Early stopping at epoch {epoch}. Best epoch: {early_stopping.best_epoch}, Best val: {early_stopping.best:.6f}")
            break

    # Restore and save the best EMA state.
    if early_stopping.best_state is not None:
        ema.ema.load_state_dict(early_stopping.best_state)
    torch.save(ema.ema.state_dict(), best_model_path)
    print(f"Saved best EMA model to {best_model_path}")

    # Loss history CSV.
    hist_csv = os.path.join(results_dir, "metrics_history.csv")
    with open(hist_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "train_loss", "val_loss"])
        for e, (tr, va) in enumerate(zip(train_losses, val_losses), start=1):
            writer.writerow([e, tr, va])
    print(f"Saved loss history to {hist_csv}")

    # Case-wise validation with EMA weights + TTA, restored to 7 dims (rad and deg metrics).
    evaluate_and_save_casewise(
        ema.ema,
        val_case_ids=val_indices,
        labels=labels,
        device=device,
        root_dir=".",
        num_points=num_points,
        out_csv=os.path.join(results_dir, "val_predictions.csv"),
        out_json=os.path.join(results_dir, "val_metrics.json"),
        tta_times=8,
        rotate_axis=None,      # no rotation at validation time
        preprocess_cfg=preprocess_cfg,
        save_degree_metrics=True
    )

    # Loss curve.
    epochs = np.arange(1, len(train_losses) + 1)
    fig = plt.figure(figsize=(7, 5))
    plt.plot(epochs, train_losses, label="Train Loss", color="tab:blue")
    plt.plot(epochs, val_losses, label="Val Loss (EMA)", color="tab:orange")
    plt.xlabel("Epoch")
    plt.ylabel("Composite Loss (Std linear + angle MSE + STN reg)")
    plt.title("Training and Validation Loss (EMA)")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, "loss_curve.png"), dpi=150)
    # Release the figure; it is only written to disk, never shown.
    plt.close(fig)
# RUN_SCATTER=1 selects the "scatter plot only" mode handled by the CLI guard
# at the end of the file; skip the long training run in that case.
if __name__ == "__main__" and os.environ.get("RUN_SCATTER") != "1":
    main()
"""
使い方
- 同ディレクトリに params_all.txt と pipe_case_XX.xyz 群を配置してください。
- params_all.txt の角度単位に合わせて、angle_unit_labels を 'rad' または 'deg' に設定してください。
- 実行: python Train_pipe8-2.py
- 出力（results8-2 ディレクトリ）
  - best_model.pt: EMAベストモデル（state_dict）
  - last_checkpoint.pth: 直近チェックポイント（EMA重み）
  - metrics_history.csv: 各エポックの Train/Val loss
  - val_predictions.csv: Val 各ケースの予測・真値・絶対誤差（角度は最小差）
  - val_metrics.json: MAE/MSE/RMSE（ラジアン・度の両方の要約）
  - loss_curve.png: 学習・検証損失曲線
補足（検証の着眼点）
- augment_rotate=False と angle_unit_labels の整合を最優先で確認してください。
- まだValが不安定な場合は、以下を段階的に試してください。
  - weight_decay=1e-3、dropout_p=0.5、ft_reg_weight=0.02、num_points=8192、batch_size=32 は本版で既に適用済みです。さらに強める場合は、これらの値から段階的に調整してください。
  - canonical_align=True（角度定義が主軸基準の場合のみ）
- 交差検証（K-Fold）でスプリットの偏りをならして、by_dim_deg/by_dim_rad と best_val_loss の傾向を確認してください。
"""

In [0]:
"""
結論
- ほぼできますが、提示の追記コードだけでは load_model_from_weights が未定義のため動きません。
- 下記のように「9次元モデル重みを読み込み、推論時に7次元（[L1..R2, theta1, theta2]）を返すラッパ」を備えた load_model_from_weights を実装して一緒に追記すれば、散布図を作成できます。
- 併せて RUN_SCATTER=1 でCLIを有効化する if ブロックも追記します。
ポイント
- 学習済み重みは本編コードの通り results8-2/best_model.pt（state_dictのみ）または results8-2/last_checkpoint.pth（辞書形式）に対応します。
- 9次元出力モデル（[L1..R2, s1, c1, s2, c2]）を内部で呼び、sincosを正規化→atan2で角度2次元に復元して7次元を返すラッパ _SevenOutWrapper を用意します。
- 散布図は角度をdegに変換して描画します。params_all.txt の角度単位に合わせて --angle-unit を設定してください。
追記コード（Train_pipe8-2.py の末尾に貼り付け）
- 先にお示しの「散布図モジュール」の前に、互換ローダー load_model_from_weights と7次元ラッパを定義します。
- その後、いただいた散布図モジュールを置き、最後に RUN_SCATTER=1 のときだけCLIが起動する if ブロックを有効化します。
"""
# ========= ここから下を Train_pipe8-2.py の末尾に追記 =========
# 9→7次元互換ローダーと7次元ラッパ
import os as _os
import torch as _torch
import torch.nn.functional as _F
class _SevenOutWrapper(_torch.nn.Module):
    """Expose a 9-output regressor as a 7-output one.

    The wrapped ``base9`` must return a 3-tuple whose first element is a
    (B, 9) tensor laid out as [L1, L2, L3, R1, R2, s1, c1, s2, c2].
    ``forward`` returns (B, 7) = [L1, L2, L3, R1, R2, theta1, theta2], where
    each angle is ``atan2(s, c)`` of its unit-normalised (s, c) pair.
    """

    def __init__(self, base9):
        super().__init__()
        self.base9 = base9

    def forward(self, x):
        head, _, _ = self.base9(x)
        pieces = [head[:, :5]]  # the five linear outputs pass through unchanged
        # Columns 5:7 hold (s1, c1) and 7:9 hold (s2, c2).
        for start in (5, 7):
            unit = _F.normalize(head[:, start:start + 2], dim=1, eps=1e-6)
            pieces.append(_torch.atan2(unit[:, 0], unit[:, 1]).unsqueeze(1))
        return _torch.cat(pieces, dim=1)
def load_model_from_weights(weights_path, device=None, out_dim=7):
    """Build ``PipeDimensionRegressor(out_dim=9)`` and load trained weights into it.

    Accepted checkpoint layouts:
      * a bare state_dict (``results8-2/best_model.pt`` as saved by ``main``);
      * a dict holding ``"model_state_dict"`` (``results8-2/last_checkpoint.pth``)
        or ``"state_dict"``.
    When every key carries a wrapper prefix (``"module."`` from DataParallel or
    ``"ema."``), that prefix is stripped so the weights map onto the bare model.

    Args:
        weights_path: path to the checkpoint file.
        device: target device; defaults to CUDA when available, else CPU.
        out_dim: 7 -> wrap the model in ``_SevenOutWrapper`` so ``forward``
            returns (B, 7) [L1..R2, theta1, theta2]; any other value returns
            the 9-output base model unchanged.

    Returns:
        The model, in eval mode, on ``device``.

    Raises:
        RuntimeError: if none of the model's state_dict keys are present in
            the checkpoint. Loading with ``strict=False`` would otherwise
            silently return an untrained network.
    """
    device = device or _torch.device("cuda" if _torch.cuda.is_available() else "cpu")
    # Always build the 9-output network used for training; the 7-output view
    # is a thin wrapper on top of it.
    base = PipeDimensionRegressor(out_dim=9, dropout_p=0.5, use_feature_stn=True).to(device)

    ckpt = _torch.load(weights_path, map_location=device)
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        state = ckpt["state_dict"]
    else:
        # best_model.pt stores the state_dict itself.
        state = ckpt

    if isinstance(state, dict):
        # Undo wrapper prefixes only when *every* key has them, so genuine
        # sub-module names are never mangled.
        stripped = True
        while stripped and state:
            stripped = False
            for prefix in ("module.", "ema."):
                if all(str(k).startswith(prefix) for k in state):
                    state = {k[len(prefix):]: v for k, v in state.items()}
                    stripped = True
        expected = list(base.state_dict().keys())
        if expected and not any(k in state for k in expected):
            raise RuntimeError(
                f"No PipeDimensionRegressor weights found in {weights_path}; "
                "checkpoint keys do not match the model."
            )

    missing, unexpected = base.load_state_dict(state, strict=False)
    if missing or unexpected:
        print(f"[Warn] load_state_dict strict=False | missing={len(missing)} unexpected={len(unexpected)}")
    base.eval()
    model = _SevenOutWrapper(base).to(device) if out_dim == 7 else base
    model.eval()
    return model
# ========= ここから：散布図ユーティリティ（ご提示案を反映） =========
import re as _re
import argparse as _argparse
import numpy as _np
import matplotlib.pyplot as _plt
# すでにTrain_pipe8-2.py内に以下が定義済みである前提：
# - load_params_all_to_vec7
# - adjust_label_units
# - preprocess_points_np
# - farthest_point_sampling
def _extract_case_id(path):
    """Return the integer case id embedded in a ``pipe_case_<digits>.xyz`` file name.

    Only the base name of ``path`` is searched. Raises ``ValueError`` when the
    pattern does not occur in it.
    """
    file_name = _os.path.basename(path)
    found = _re.search(r"pipe_case_(\d+)\.xyz", file_name)
    if found is not None:
        return int(found.group(1))
    raise ValueError(f"Failed to parse case_id from filename: {path}")
@_torch.no_grad()
def _predict7_single_file(model, path, device, num_points=4096,
                          preprocess_cfg=None, tta_times=8, rotate_axis=None):
    """Predict the 7-dim label vector for one point-cloud file with TTA.

    Loads ``path`` and then, for each TTA pass:
      1. runs ``preprocess_points_np`` and ``farthest_point_sampling``;
      2. optionally applies a random rotation about z;
      3. feeds the sample through ``model``.

    Args:
        model: module whose forward returns a (B, 7) tensor
            [L1, L2, L3, R1, R2, theta1, theta2], e.g. the wrapper returned by
            ``load_model_from_weights(..., out_dim=7)``.
        path: point file readable by ``np.loadtxt`` with 3 columns.
        device: torch device the input tensor is moved to.
        num_points: number of points kept by farthest point sampling.
        preprocess_cfg: preprocessing options. Missing keys fall back to the
            defaults used below. These should match the training configuration.
        tta_times: number of TTA passes; values below 1 are treated as 1.
        rotate_axis: ``'z'`` for a random z-rotation per pass, otherwise none.

    Returns:
        float32 array of shape (7,). The linear entries are the arithmetic mean
        over passes. The two angles (indices 5 and 6) use the circular mean, so
        predictions straddling +/-pi do not cancel each other out.

    Raises:
        ValueError: if the file does not hold an (N, 3) array.
    """
    preprocess_cfg = preprocess_cfg or {}
    # Same loading fallback as evaluate_and_save_casewise: first try skipping
    # one header row, and read everything if that row is not parseable.
    try:
        pts = _np.loadtxt(path, skiprows=1)
    except ValueError:
        pts = _np.loadtxt(path)
    if pts.ndim != 2 or pts.shape[1] != 3:
        raise ValueError(f"Invalid point file shape: {path}, got {pts.shape}")
    pts = pts.astype(_np.float32)

    # The preprocessing options do not change between passes; resolve them once.
    pp_kwargs = dict(
        coord_scale=float(preprocess_cfg.get("coord_scale", 1.0)),
        dedup_round_decimals=int(preprocess_cfg.get("dedup_round_decimals", 6)),
        use_sor=bool(preprocess_cfg.get("use_sor", True)),
        sor_k=int(preprocess_cfg.get("sor_k", 16)),
        sor_std_ratio=float(preprocess_cfg.get("sor_std_ratio", 2.0)),
        canonical_align=bool(preprocess_cfg.get("canonical_align", False)),
        normalize_method=str(preprocess_cfg.get("normalize_method", "unit_sphere")).lower(),
    )

    preds = []
    for _ in range(max(1, int(tta_times))):
        pts_pp = preprocess_points_np(pts, **pp_kwargs)
        pts_s = farthest_point_sampling(pts_pp, num_points)
        if rotate_axis == 'z':
            theta = _np.random.uniform(0.0, 2.0 * _np.pi)
            c, s = _np.cos(theta), _np.sin(theta)
            R = _np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=_np.float32)
            pts_s = pts_s @ R.T
        x = _torch.from_numpy(pts_s.astype(_np.float32)).transpose(0, 1).unsqueeze(0).to(device)  # (1,3,N)
        pred7 = model(x)  # (1, 7)
        preds.append(pred7.squeeze(0).cpu().numpy().reshape(-1).astype(_np.float32))

    stacked = _np.stack(preds, axis=0)  # (T, 7)
    result = stacked.mean(axis=0)
    # Angles live on a circle: average their unit vectors, not the raw values.
    angles = stacked[:, 5:7]
    result[5:7] = _np.arctan2(_np.sin(angles).mean(axis=0), _np.cos(angles).mean(axis=0))
    return result.astype(_np.float32)
def scatter_plot(files, weights_path="results8-2/best_model.pt",
                 angle_unit_labels='rad', label_length_scale=1.0,
                 preprocess_cfg=None, num_points=4096, tta_times=8,
                 rotate_axis=None, out_path="results8-2/scatter_pred_vs_true.png",
                 params_path="params_all.txt"):
    """Draw and save true (x) vs predicted (y) scatter plots for the given files.

    One panel is drawn for each of L1, L2, L3, R1, R2, theta1 and theta2.
    Angles are plotted in degrees. Each predicted angle is placed on the
    branch closest to its true value, i.e. ``true + wrap(pred - true)``, which
    is the same minimal-difference convention as the validation CSV. A small
    error across the +/-180 deg seam therefore stays near the diagonal.

    Args:
        files: point files named ``pipe_case_XX.xyz``; the case id in the name
            selects the ground-truth row.
        weights_path: trained weights (e.g. ``results8-2/best_model.pt``).
        angle_unit_labels: 'rad' or 'deg', the angle unit in the label file.
        label_length_scale: unit scale for length labels (mm -> m is 0.001).
        preprocess_cfg: preprocessing settings. These should equal the training
            ones. Defaults mirror the values used in ``main``.
        num_points: points kept per cloud by farthest point sampling.
        tta_times: TTA passes per file.
        rotate_axis: TTA rotation axis (normally None).
        out_path: output image path. Its parent directory is created if needed.
        params_path: label file to read (default ``params_all.txt``).

    Raises:
        ValueError: if ``files`` is empty or a file name has no case id.
    """
    if not files:
        raise ValueError("files must contain at least one pipe_case_XX.xyz path")
    preprocess_cfg = preprocess_cfg or {
        "coord_scale": 1.0,
        "dedup_round_decimals": 6,
        "use_sor": True,
        "sor_k": 16,
        "sor_std_ratio": 2.0,
        "canonical_align": False,
        "normalize_method": "unit_sphere",
    }
    # Labels + unit conversion.
    labels_raw, _ = load_params_all_to_vec7(params_path)
    labels = adjust_label_units(labels_raw,
                                angle_unit=angle_unit_labels,
                                length_scale=label_length_scale,
                                wrap_angles=True)
    # Ground truth for the selected cases.
    sel_ids = [_extract_case_id(p) for p in files]
    trues = _np.stack([labels[cid] for cid in sel_ids], axis=0)  # (B,7)
    # Model with a unified 7-dim output.
    device = _torch.device("cuda" if _torch.cuda.is_available() else "cpu")
    model = load_model_from_weights(weights_path, device=device, out_dim=7).to(device)
    model.eval()
    # Predictions.
    preds = _np.stack([
        _predict7_single_file(model, p, device=device, num_points=num_points,
                              preprocess_cfg=preprocess_cfg, tta_times=tta_times,
                              rotate_axis=rotate_axis)
        for p in files
    ], axis=0)  # (B,7)

    keys = ["L1", "L2", "L3", "R1", "R2", "theta1", "theta2"]
    B = preds.shape[0]
    fig, axes = _plt.subplots(2, 4, figsize=(12, 6))
    axes = axes.flatten()
    for i, k in enumerate(keys):
        ax = axes[i]
        x = trues[:, i].copy()
        y = preds[:, i].copy()
        if i >= 5:
            # Unwrap the prediction to the branch nearest the truth (radians),
            # then convert both to degrees for display.
            y = x + (_np.mod(y - x + _np.pi, 2.0 * _np.pi) - _np.pi)
            x = _np.degrees(x)
            y = _np.degrees(y)
            ax.set_xlabel(f"True {k} (deg)")
            ax.set_ylabel(f"Pred {k} (deg)")
        else:
            ax.set_xlabel(f"True {k}")
            ax.set_ylabel(f"Pred {k}")
        ax.scatter(x, y, c="tab:blue", s=50, alpha=0.85, edgecolors="white", linewidths=0.6)
        minv = min(_np.min(x), _np.min(y))
        maxv = max(_np.max(x), _np.max(y))
        ax.plot([minv, maxv], [minv, maxv], color="tab:gray", linestyle="--", linewidth=1.0)
        if B >= 2:
            r = _np.corrcoef(x, y)[0, 1]
            ax.set_title(f"{k} (r={r:.3f})")
        else:
            ax.set_title(k)
        ax.grid(True, alpha=0.3)
    axes[-1].axis("off")  # 7 panels in a 2x4 grid: hide the spare one
    fig.tight_layout()
    out_dir = _os.path.dirname(out_path)
    if out_dir:  # os.makedirs("") raises, e.g. for out_path="scatter.png"
        _os.makedirs(out_dir, exist_ok=True)
    fig.savefig(out_path, dpi=150)
    _plt.show()
    _plt.close(fig)
    print(f"Saved scatter figure to {out_path}")
def _scatter_cli(argv=None):
    """Command-line entry point for ``scatter_plot``.

    It is only launched from the module-level guard when the environment
    variable RUN_SCATTER is "1", so it does not interfere with training.

    Args:
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as before. Passing a list allows programmatic use.
    """
    parser = _argparse.ArgumentParser(description="True vs Pred scatter plot")
    parser.add_argument("--files", nargs="+", required=True,
                        help="Target xyz files (e.g., pipe_case_07.xyz pipe_case_08.xyz ...)")
    parser.add_argument("--weights", default="results8-2/best_model.pt",
                        help="Path to trained weights")
    parser.add_argument("--angle-unit", choices=["rad","deg"], default="rad",
                        help="Angle unit in params_all.txt")
    parser.add_argument("--label-length-scale", type=float, default=1.0,
                        help="Length scale for labels (e.g., mm->m is 0.001)")
    parser.add_argument("--coord-scale", type=float, default=1.0,
                        help="Coordinate scale for point cloud (e.g., mm->m is 0.001)")
    parser.add_argument("--canonical-align", action="store_true",
                        help="Enable PCA canonical alignment")
    parser.add_argument("--normalize-method", choices=["unit_sphere","zscore","none"], default="unit_sphere",
                        help="Coordinate normalization method")
    parser.add_argument("--use-sor", action="store_true",
                        help="Enable SOR outlier removal (default: enabled)")
    parser.add_argument("--no-sor", action="store_true",
                        help="Disable SOR outlier removal")
    parser.add_argument("--sor-k", type=int, default=16, help="SOR k-neighbors")
    parser.add_argument("--sor-std-ratio", type=float, default=2.0, help="SOR std ratio")
    # NOTE: main() trains with 8192 points; pass --num-points 8192 to match it.
    parser.add_argument("--num-points", type=int, default=4096, help="Subsample points per cloud")
    parser.add_argument("--tta-times", type=int, default=8, help="TTA samples")
    parser.add_argument("--rotate-axis", choices=["none","z"], default="none", help="Optional TTA rotation axis")
    parser.add_argument("--out", default="results8-2/scatter_pred_vs_true.png", help="Output path for figure")
    args = parser.parse_args(argv)

    preprocess_cfg = {
        "coord_scale": args.coord_scale,
        "dedup_round_decimals": 6,
        # SOR is on by default. --use-sor is accepted for explicitness only,
        # and --no-sor always wins.
        "use_sor": not args.no_sor,
        "sor_k": args.sor_k,
        "sor_std_ratio": args.sor_std_ratio,
        "canonical_align": args.canonical_align,
        "normalize_method": args.normalize_method,
    }
    rotate_axis = None if args.rotate_axis == "none" else args.rotate_axis
    scatter_plot(files=args.files,
                 weights_path=args.weights,
                 angle_unit_labels=args.angle_unit,
                 label_length_scale=args.label_length_scale,
                 preprocess_cfg=preprocess_cfg,
                 num_points=args.num_points,
                 tta_times=args.tta_times,
                 rotate_axis=rotate_axis,
                 out_path=args.out)
# Launch the scatter-plot CLI only when executed as a script with RUN_SCATTER=1.
if __name__ == "__main__":
    if _os.environ.get("RUN_SCATTER") == "1":
        _scatter_cli()
# ========= 追記ここまで =========
"""
使い方
- まずは学習を実行して best_model.pt を作ります。
  - python Train_pipe8-2.py
- 散布図のみ実行（学習は走らせず、RUN_SCATTER=1 を付与）
  - Windows (cmd): set "RUN_SCATTER=1" && python Train_pipe8-2.py --files pipe_case_07.xyz pipe_case_08.xyz pipe_case_09.xyz --weights results8-2/best_model.pt --angle-unit rad
    （`set RUN_SCATTER=1 && ...` と書くと値が "1 "（末尾スペース付き）になり判定に一致しないため、引用符で囲んでください）
  - Linux/Mac: RUN_SCATTER=1 python Train_pipe8-2.py --files pipe_case_07.xyz pipe_case_08.xyz pipe_case_09.xyz --weights results8-2/best_model.pt --angle-unit rad
- よく使うオプション
  - params_all.txt の角度が度: --angle-unit deg
  - ラベル・座標がmm: --label-length-scale 0.001 --coord-scale 0.001
  - 学習時にPCAアラインメント使用: --canonical-align
  - 正規化: --normalize-method unit_sphere
  - 出力先変更: --out results8-2/my_scatter.png
  - 学習時と同じ点数で推論: --num-points 8192（既定値 4096 は学習設定 num_points=8192 と異なります）
補足
- 9→7変換は推論時にのみ行います（学習は従来通り9次元で実施）。これにより「not enough values to unpack」系の不整合を回避できます。
- 前処理（coord_scale／SOR／アラインメント／正規化）は学習時と揃えてください。揃っていないと散布図の精度が悪化します。
"""


In [0]:
# Scatter plot for three cases using the trained EMA weights.
# num_points matches the training/validation setting in main() (8192), so the
# sampled point density is the same as what the network was trained on.
# NOTE(review): whether cases 01-03 fall in the training split depends on the
# seeded shuffle in main(); confirm before labelling the figure "train".
scatter_plot(
    files=['pipe_case_01.xyz', 'pipe_case_02.xyz', 'pipe_case_03.xyz'],
    weights_path='results8-2/best_model.pt',
    angle_unit_labels='rad',
    label_length_scale=1.0,
    preprocess_cfg={'coord_scale': 1.0, 'dedup_round_decimals': 6, 'use_sor': True, 'sor_k': 16,
                    'sor_std_ratio': 2.0, 'canonical_align': False, 'normalize_method': 'unit_sphere'},
    num_points=8192, tta_times=8, rotate_axis=None,
    out_path='results8-2/scatter_train_pred_vs_true.png'
)