In [1]:
# train_unet_res_dilated.py
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.optim as optim

[0;93m2025-11-16 16:31:48.892023521 [W:onnxruntime:Default, device_discovery.cc:164 DiscoverDevicesForPlatform] GPU device discovery failed: device_discovery.cc:89 ReadFileContents Failed to open file: "/sys/class/drm/card0/device/vendor"[m


In [2]:
# -----------------------
# Residual Block
# -----------------------
class ResBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.in1 = nn.InstanceNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1)
        self.in2 = nn.InstanceNorm2d(ch)

    def forward(self, x):
        h = F.relu(self.in1(self.conv1(x)))
        h = self.in2(self.conv2(h))
        return x + h


# -----------------------
# Dilated Conv Block
# -----------------------
class DilatedConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dilation=2):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, dilation=dilation, padding=dilation)
        self.inorm = nn.InstanceNorm2d(out_ch)

    def forward(self, x):
        return F.relu(self.inorm(self.conv(x)))


In [3]:
import torch
import torch.nn as nn

# ResBlockとDilatedConvBlockの定義が必要です。ここでは省略します。
# -----------------------
# U-Net Generator (1ch Input / 1ch Output)
# -----------------------
class UNetGenerator(nn.Module):
    # 修正点 1: in_channels のデフォルト値を 3 から 1 に変更
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        self.enc1 = nn.Sequential(
            # 修正点 2: nn.Conv2d の入力チャンネル数を in_channels に設定
            nn.Conv2d(in_channels, 64, 3, padding=1), 
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            ResBlock(64)
        )
        self.pool1 = nn.MaxPool2d(2)

        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            ResBlock(128)
        )
        self.pool2 = nn.MaxPool2d(2)

        self.enc3 = nn.Sequential(
            DilatedConvBlock(128, 256, dilation=2),
            ResBlock(256)
        )
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = nn.Sequential(
            DilatedConvBlock(256, 512, dilation=4),
            ResBlock(512),
            DilatedConvBlock(512, 512, dilation=4),
        )

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = nn.Sequential(
            DilatedConvBlock(512, 256, dilation=2),
            ResBlock(256)
        )

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = nn.Sequential(
            DilatedConvBlock(256, 128, dilation=2),
            ResBlock(128)
        )

        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = nn.Sequential(
            DilatedConvBlock(128, 64, dilation=1),
            ResBlock(64)
        )

        self.out_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)

        e2 = self.enc2(p1)
        p2 = self.pool2(e2)

        e3 = self.enc3(p2)
        p3 = self.pool3(e3)

        b = self.bottleneck(p3)

        u3 = self.up3(b)
        d3 = self.dec3(torch.cat([u3, e3], dim=1))

        u2 = self.up2(d3)
        d2 = self.dec2(torch.cat([u2, e2], dim=1))

        u1 = self.up1(d2)
        d1 = self.dec1(torch.cat([u1, e1], dim=1))

        return self.out_conv(d1)

In [4]:
def sobel_edges(x):
    sobel_x = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]],
                           dtype=torch.float32, device=x.device).view(1,1,3,3)
    sobel_y = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]],
                           dtype=torch.float32, device=x.device).view(1,1,3,3)

    g_x = F.conv2d(x, sobel_x, padding=1)
    g_y = F.conv2d(x, sobel_y, padding=1)

    return torch.sqrt(g_x**2 + g_y**2 + 1e-6)
    
def edge_loss(pred, target):
    return F.l1_loss(sobel_edges(pred), sobel_edges(target))

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os



# Dataset
class SketchDataset(torch.utils.data.Dataset):
    def __init__(self, rough_dir, line_dir, transform=None):
        self.rough_files = sorted(os.listdir(rough_dir))
        self.line_files = sorted(os.listdir(line_dir))
        self.rough_dir = rough_dir
        self.line_dir = line_dir
        self.transform = transform

    def __len__(self):
        return len(self.rough_files)

    def __getitem__(self, idx):
        rough = Image.open(os.path.join(self.rough_dir, self.rough_files[idx])).convert("L")
        line  = Image.open(os.path.join(self.line_dir,  self.line_files[idx])).convert("L")

        if self.transform:
            rough = self.transform(rough)
            line  = self.transform(line)

        return rough, line

# Transform
transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor(),
])

dataset = SketchDataset("dataset/train/rough", "dataset/train/line", transform)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNetGenerator(in_channels=1, out_channels=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0002)

os.makedirs("checkpoints", exist_ok=True)

# Training
for epoch in range(50):
    for rough, line in loader:
        rough, line = rough.to(device), line.to(device)

        optimizer.zero_grad()

        pred = model(rough)

        loss_main = criterion(pred, line)
        loss_edge = edge_loss(torch.sigmoid(pred), line)

        loss = loss_main + 0.2 * loss_edge

        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}: loss={loss.item():.4f}")

    torch.save(model.state_dict(), f"checkpoints/unet_1ch_epoch{epoch+1}.pth")


Epoch 1: loss=0.4146
Epoch 2: loss=0.3349
Epoch 3: loss=0.2849
Epoch 4: loss=0.3207
Epoch 5: loss=0.3052
Epoch 6: loss=0.2934
Epoch 7: loss=0.2422
Epoch 8: loss=0.2644
Epoch 9: loss=0.2596
Epoch 10: loss=0.2083
Epoch 11: loss=0.2544
Epoch 12: loss=0.2189
Epoch 13: loss=0.2262
Epoch 14: loss=0.2655
Epoch 15: loss=0.1703
Epoch 16: loss=0.2395
Epoch 17: loss=0.2405
Epoch 18: loss=0.2028
Epoch 19: loss=0.2444
Epoch 20: loss=0.1921
Epoch 21: loss=0.1851
Epoch 22: loss=0.2483
Epoch 23: loss=0.2662
Epoch 24: loss=0.2499
Epoch 25: loss=0.3225
Epoch 26: loss=0.2274
Epoch 27: loss=0.2516
Epoch 28: loss=0.2153
Epoch 29: loss=0.2198
Epoch 30: loss=0.2077
Epoch 31: loss=0.1409
Epoch 32: loss=0.1782
Epoch 33: loss=0.1973
Epoch 34: loss=0.1345
Epoch 35: loss=0.2050
Epoch 36: loss=0.1323
Epoch 37: loss=0.1777
Epoch 38: loss=0.1542
Epoch 39: loss=0.1517
Epoch 40: loss=0.1861
Epoch 41: loss=0.1959
Epoch 42: loss=0.2353
Epoch 43: loss=0.1374
Epoch 44: loss=0.1550
Epoch 45: loss=0.1233
Epoch 46: loss=0.17

In [16]:
import torch
import torchvision.transforms.functional as TF
from PIL import Image


device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNetGenerator(in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load("checkpoints/unet_1ch_epoch50.pth", map_location=device))
model.eval()

img = Image.open("test/rough/sample.jpg").convert("L")
img = TF.resize(img, (128,128))
img = TF.to_tensor(img).unsqueeze(0).to(device)

with torch.no_grad():
    out = torch.sigmoid(model(img))  # 0〜1

out = out.clamp(0,1)
out_img = TF.to_pil_image(out[0].cpu())
out_img.save("results/res_dilated_loss_line.png")

print("done")


done


result

![result](./results/result_line.png)


sample  
![sample](./test/rough/sample.jpg)