# canny 導入後の edge_loss 調整 検討  
  
  
+ [x] loss_edgeの効果を弱める  
loss = loss_main + 0.2 * loss_edge  
1.0より0.5のほうが明らかにいい  
0.2でも0.5でクセの差くらいしかないので、このパラメータは保留する  

+ [x] 線画の反転トレーニング  
なくすと薄すぎるため、継続

+ [x] 作り込んだUNetgeneratorと初期型の比較
初期型UNetgeneratorのほうがやや自然だが、どちらが良いかは決定的でない  
いくらか多めにデータを使って検証しきめる


---




### Residual Block

$$
\boxed{
\mathrm{ResBlock}(x)
= x + \mathrm{IN}\!\left(
W_2 * \mathrm{ReLU}\!\left(
\mathrm{IN}(W_1 * x + b_1)
\right) + b_2
\right)
}
$$


---




### Dilated Conv Block

$$
\boxed{
\mathrm{DilatedConvBlock}(x)
= \mathrm{ReLU}\!\left(
\mathrm{IN}(W *_{d} x + b)
\right)
}
$$




# U-Net Generator (1ch → 1ch) : Mathematical Formulation

## Encoder

### 1. First Encoder Block
入力：\(x\)

$$
e_1 = \mathrm{ResBlock}\!\left(
\mathrm{ReLU}\left(
\mathrm{IN}\left( W_{1} * x + b_{1} \right)
\right)
\right)
$$

$$
p_1 = \mathrm{MaxPool}(e_1)
$$


### 2. Second Encoder Block
$$
e_2 = \mathrm{ResBlock}\!\left(
\mathrm{ReLU}\left(
\mathrm{IN}\left( W_{2} * p_1 + b_{2} \right)
\right)
\right)
$$

$$
p_2 = \mathrm{MaxPool}(e_2)
$$


### 3. Third Encoder Block (Dilated)
$$
e_3 = \mathrm{ResBlock}\!\left(
\mathrm{ReLU}\left(
\mathrm{IN}\left( W^{(d=2)}_{3} *_{2} p_2 + b_{3} \right)
\right)
\right)
$$

$$
p_3 = \mathrm{MaxPool}(e_3)
$$


---

## Bottleneck

### Dilated → Residual → Dilated
$$
b = 
\mathrm{DilatedConv}^{(d=4)}_{512}\!\Bigg(
\mathrm{ResBlock}_{512}\!\Big(
\mathrm{DilatedConv}^{(d=4)}_{512}(p_3)
\Big)
\Bigg)
$$

---

## Decoder

### 1. Decoder Stage 3
Upsample:
$$
u_3 = \mathrm{ConvTranspose}(b)
$$

Skip connection & convolution:
$$
d_3 = \mathrm{ResBlock}\!\left(
\mathrm{ReLU}\left(
\mathrm{IN}\left(
W_{d3} * [u_3 \, ; \, e_3] + b_{d3}
\right)\right)\right)
$$


### 2. Decoder Stage 2
$$
u_2 = \mathrm{ConvTranspose}(d_3)
$$

$$
d_2 = \mathrm{ResBlock}\!\left(
\mathrm{ReLU}\left(
\mathrm{IN}\left(
W_{d2} * [u_2 \, ; \, e_2] + b_{d2}
\right)\right)\right)
$$


### 3. Decoder Stage 1
$$
u_1 = \mathrm{ConvTranspose}(d_2)
$$

$$
d_1 = \mathrm{ResBlock}\!\left(
\mathrm{DilatedConv}^{(d=1)}\big([u_1 \, ; \, e_1]\big)
\right)
$$

---

## Output Layer

$$
y = W_{\text{out}} * d_1 + b_{\text{out}}
$$

---

# Final Generator Output

$$
\boxed{
G(x) = y = W_{\text{out}} * d_1 + b_{\text{out}}
}
$$


## Sobel Edge Extraction

Sobel フィルタ \(S_x, S_y\):

$$
S_x =
\begin{bmatrix}
-1 & 0 & 1 \\
-2 & 0 & 2 \\
-1 & 0 & 1
\end{bmatrix},
\qquad
S_y =
\begin{bmatrix}
-1 & -2 & -1 \\
0 & 0 & 0 \\
1 & 2 & 1
\end{bmatrix}
$$

入力画像（1ch）を \(x\) とすると，

$$
g_x = S_x * x, 
\qquad
g_y = S_y * x
$$

出力するエッジ強度（勾配の大きさ）は：

$$
\mathrm{SobelEdges}(x)
=
\sqrt{
g_x^2 + g_y^2 + 10^{-6}
}
$$

## Edge Loss

予測画像：\(\hat{y}\)  
教師画像：\(y\)

Sobel エッジマップ：

$$
E(\hat{y}) = \mathrm{SobelEdges}(\hat{y}),
\qquad
E(y) = \mathrm{SobelEdges}(y)
$$

L1 損失によるエッジ損失：

$$
\mathrm{EdgeLoss}(\hat{y}, y)
=
\|\, E(\hat{y}) - E(y) \,\|_{1}
$$


In [7]:
import numpy
import cv2
import torch
import os
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.optim as optim
import torchvision.transforms.functional as TF

In [8]:
# -----------------------
# Residual Block
# -----------------------
class ResBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.in1 = nn.InstanceNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1)
        self.in2 = nn.InstanceNorm2d(ch)

    def forward(self, x):
        h = F.relu(self.in1(self.conv1(x)))
        h = self.in2(self.conv2(h))
        return x + h


# -----------------------
# Dilated Conv Block
# -----------------------
class DilatedConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dilation=2):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, dilation=dilation, padding=dilation)
        self.inorm = nn.InstanceNorm2d(out_ch)

    def forward(self, x):
        return F.relu(self.inorm(self.conv(x)))


In [9]:
# smple 初期版
# -----------------------
# U-Net Generator (1ch Input / 1ch Output)
# -----------------------
class UNetGenerator(nn.Module):
    # 修正点 1: in_channels のデフォルト値を 3 から 1 に変更
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1), 
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            ResBlock(64)
        )
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            ResBlock(128)
        )
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = nn.Sequential(
            DilatedConvBlock(128, 256, dilation=2),
            ResBlock(256)
        )
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = nn.Sequential(
            DilatedConvBlock(256, 512, dilation=4),
            ResBlock(512),
            DilatedConvBlock(512, 512, dilation=4),
        )

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = nn.Sequential(
            DilatedConvBlock(512, 256, dilation=2),
            ResBlock(256)
        )

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = nn.Sequential(
            DilatedConvBlock(256, 128, dilation=2),
            ResBlock(128)
        )

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = nn.Sequential(
            DilatedConvBlock(128, 64, dilation=1),
            ResBlock(64)
        )

        self.out_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)

        e2 = self.enc2(p1)
        p2 = self.pool2(e2)

        e3 = self.enc3(p2)
        p3 = self.pool3(e3)

        b = self.bottleneck(p3)

        u3 = self.up3(b)
        d3 = self.dec3(torch.cat([u3, e3], dim=1))

        u2 = self.up2(d3)
        d2 = self.dec2(torch.cat([u2, e2], dim=1))

        u1 = self.up1(d2)
        d1 = self.dec1(torch.cat([u1, e1], dim=1))

        return self.out_conv(d1)

In [5]:

# -----------------------
# U-Net Generator (1ch Input / 1ch Output)
# -----------------------
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1), 
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            ResBlock(64)
        )
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            ResBlock(128)
        )
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = nn.Sequential(
            DilatedConvBlock(128, 256, dilation=2),
            ResBlock(256)
        )
        self.pool3 = nn.MaxPool2d(2)


        self.bottleneck = nn.Sequential(
        DilatedConvBlock(256, 512, dilation=4), 
        ResBlock(512),
        DilatedConvBlock(512, 512, dilation=4), 
        )

        self.up3 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1)
        self.dec3 = nn.Sequential(
        nn.Conv2d(512, 256, 3, padding=1),
        nn.InstanceNorm2d(256),
        nn.ReLU(True),
        ResBlock(256)
        )

        self.up2 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)
        self.dec2 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            ResBlock(128)
        )
        
        self.up1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.dec1 = nn.Sequential(
            DilatedConvBlock(128, 64, dilation=1),
            ResBlock(64)
        )

        self.out_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)

        e2 = self.enc2(p1)
        p2 = self.pool2(e2)

        e3 = self.enc3(p2)
        p3 = self.pool3(e3)

        b = self.bottleneck(p3)

        u3 = self.up3(b)
        d3 = self.dec3(torch.cat([u3, e3], dim=1))

        u2 = self.up2(d3)
        d2 = self.dec2(torch.cat([u2, e2], dim=1))

        u1 = self.up1(d2)
        d1 = self.dec1(torch.cat([u1, e1], dim=1))

        return self.out_conv(d1)

In [10]:
import cv2
import numpy as np

def canny_edges(x, low_th=50, high_th=150):
    """
    x: (B, 1, H, W) tensor, range [0,1]
    return: (B, 1, H, W) tensor
    """
    x_np = (x.detach().cpu().numpy() * 255).astype(np.uint8)  # B,1,H,W → uint8
    
    edges = []
    for i in range(x_np.shape[0]):
        img = x_np[i,0]
        edge = cv2.Canny(img, low_th, high_th)

        # 正規化してTensorへ
        edge = edge.astype(np.float32) / 255.0
        edges.append(edge)

    edges = np.stack(edges, axis=0)  # (B,H,W)
    edges = torch.from_numpy(edges).unsqueeze(1).to(x.device)  # (B,1,H,W)
    return edges


In [11]:
def sobel_edges(x):
    sobel_x = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]],
                           dtype=torch.float32, device=x.device).view(1,1,3,3)
    sobel_y = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]],
                           dtype=torch.float32, device=x.device).view(1,1,3,3)

    g_x = F.conv2d(x, sobel_x, padding=1)
    g_y = F.conv2d(x, sobel_y, padding=1)

    return torch.sqrt(g_x**2 + g_y**2 + 1e-6)
    
def edge_loss(pred, target):
    pred_edge = canny_edges(pred)       # pred は Sigmoid 前の値
    target_edge = canny_edges(target)   # target は ground truth
    return F.l1_loss(pred_edge, target_edge)


In [13]:
# Dataset
class SketchDataset(torch.utils.data.Dataset):
    def __init__(self, rough_dir, line_dir, transform=None):
        self.rough_files = sorted(os.listdir(rough_dir))
        self.line_files = sorted(os.listdir(line_dir))
        self.rough_dir = rough_dir
        self.line_dir = line_dir
        self.transform = transform

    def __len__(self):
        return len(self.rough_files)

    def __getitem__(self, idx):
        rough = Image.open(os.path.join(self.rough_dir, self.rough_files[idx])).convert("L")
        line  = Image.open(os.path.join(self.line_dir,  self.line_files[idx])).convert("L")

        if self.transform:
            rough = self.transform(rough)
            line  = self.transform(line)

        return rough, line

In [14]:

# ---------------------------------------------------------------------------
# II. データローディングと設定
# ---------------------------------------------------------------------------
# Dataset クラスはそのまま使用 (paired_transform 引数は不要)

# Transform の修正: Resize を 256x256 に変更
transform = transforms.Compose([
    transforms.Resize((256, 256)), # ★ 128x128 から 256x256 に変更 ★
    transforms.ToTensor(),
])

dataset = SketchDataset("dataset/train/rough", "dataset/train/line", transform)

# DataLoader: 解像度アップに伴い、VRAM節約のためバッチサイズを 2 に下げることを推奨
loader = DataLoader(dataset, batch_size=2, shuffle=True) 

device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNetGenerator(in_channels=1, out_channels=1).to(device)


pos_weight_value = 3.0 
pos_weight_tensor = torch.tensor(pos_weight_value, dtype=torch.float).to(device)

# 2. criterion に Tensor 型の pos_weight を渡す
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
# ★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★
# オプティマイザ: 学習率を 0.0001 に下げる (暴走防止)
optimizer = optim.Adam(model.parameters(), lr=0.0001) 

os.makedirs("checkpoints", exist_ok=True)


# ---------------------------------------------------------------------------
# III. トレーニングループ
# ---------------------------------------------------------------------------
for epoch in range(50):
    total_loss = 0.0
    num_batches = 0
    model.train() 
    
    for rough, line in loader:
        rough, line = rough.to(device), line.to(device)

        line = 1.0 - line
        
        optimizer.zero_grad()
        pred = model(rough)
    
        # 1. メイン損失 (BCE with pos_weight)
        loss_main_bce = criterion(pred, line)
    
        # 2. L1損失 (Sigmoid後の出力とターゲットの絶対誤差)
        # L1損失を導入することで、ピクセル値がターゲットに近づくように強制する。
        loss_main_l1 = F.l1_loss(torch.sigmoid(pred), line) 
    
        # 3. メイン損失の組み合わせとエッジ損失
        # BCEとL1をミックス (例: 80% BCE, 20% L1)
        loss_main = 0.8 * loss_main_bce + 0.2 * loss_main_l1 
    
        loss_edge = edge_loss(torch.sigmoid(pred), line) 

        # 総損失 0.2でも0.5でクセの差くらいしかないので、このパラメータの検討は保留する
        loss = loss_main + 0.2 * loss_edge

        loss.backward()
        optimizer.step()
        
        # ログ改善
        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}: avg_loss={avg_loss:.4f}")

    torch.save(model.state_dict(), f"checkpoints/unet_1ch_epoch{epoch+1}.pth")

Epoch 1: avg_loss=0.5712
Epoch 2: avg_loss=0.4557
Epoch 3: avg_loss=0.4234
Epoch 4: avg_loss=0.4019
Epoch 5: avg_loss=0.3858
Epoch 6: avg_loss=0.3728
Epoch 7: avg_loss=0.3643
Epoch 8: avg_loss=0.3572
Epoch 9: avg_loss=0.3527
Epoch 10: avg_loss=0.3474
Epoch 11: avg_loss=0.3425
Epoch 12: avg_loss=0.3399
Epoch 13: avg_loss=0.3388
Epoch 14: avg_loss=0.3354
Epoch 15: avg_loss=0.3324
Epoch 16: avg_loss=0.3301
Epoch 17: avg_loss=0.3274
Epoch 18: avg_loss=0.3242
Epoch 19: avg_loss=0.3241
Epoch 20: avg_loss=0.3197
Epoch 21: avg_loss=0.3166
Epoch 22: avg_loss=0.3121
Epoch 23: avg_loss=0.3102
Epoch 24: avg_loss=0.3088
Epoch 25: avg_loss=0.3019
Epoch 26: avg_loss=0.3008
Epoch 27: avg_loss=0.2953
Epoch 28: avg_loss=0.2912
Epoch 29: avg_loss=0.2877
Epoch 30: avg_loss=0.2836
Epoch 31: avg_loss=0.2788
Epoch 32: avg_loss=0.2750
Epoch 33: avg_loss=0.2703
Epoch 34: avg_loss=0.2661
Epoch 35: avg_loss=0.2613
Epoch 36: avg_loss=0.2564
Epoch 37: avg_loss=0.2497
Epoch 38: avg_loss=0.2453
Epoch 39: avg_loss=0.

In [15]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNetGenerator(in_channels=1, out_channels=1).to(device)
# チェックポイントのロード...
model.load_state_dict(torch.load("checkpoints/unet_1ch_epoch50.pth", map_location=device))
model.eval()

img = Image.open("test/rough/sample.jpg").convert("L")

# ★★★ リサイズ解像度を 256x256 に変更 ★★★
img = TF.resize(img, (256, 256)) 
# ★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★★

img = TF.to_tensor(img).unsqueeze(0).to(device)

with torch.no_grad():
    out = torch.sigmoid(model(img))  # 0〜1

out = 1.0 - out 

out = out.clamp(0,1)
out_img = TF.to_pil_image(out[0].cpu())
out_img.save("results/plan2.png")

print("done")

done


result
![result](./results/edge_loss0.2.png)0.2
![result](./results/edge_loss0.5.png)0.5

![result](./results/simple.png)simple Unetgenerator

sample  
![sample](./test/rough/sample.jpg)