# 第9章 : 拡散モデルの実装

## U-Netの実装

「畳み込み → バッチ正規化 → ReLU」の組を2回繰り返すConvBlockを作成する

In [2]:
import torch
from torch import nn

class ConvBlock(nn.Module):
    """Two stacked (3x3 convolution -> BatchNorm -> ReLU) stages.

    Args:
        in_ch: number of channels of the incoming tensor.
        out_ch: number of channels produced by both convolutions.

    A 3x3 kernel with padding=1 (stride 1) leaves the spatial size unchanged,
    so only the channel dimension differs between input and output.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(src_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ])
        # A single Sequential named `convs` keeps parameter names
        # (convs.0.weight, convs.1.running_mean, ...) stable for checkpoints.
        self.convs = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to x and return the result."""
        features = self.convs(x)
        return features

UNetを実装する

In [6]:
class UNet(nn.Module):
    """Two-level U-Net: encoder -> bottleneck -> decoder with skip connections.

    Args:
        in_ch: channels of the input tensor; the final 1x1 convolution maps
            back to the same number of channels (default 1).

    forward(x) returns a tensor with in_ch channels. Each decoder stage
    upsamples to the exact spatial size of its skip connection before
    concatenating, so H and W do not need to be divisible by 4.
    """

    def __init__(self, in_ch=1):
        super().__init__()
        self.down1 = ConvBlock(in_ch, 64)
        self.down2 = ConvBlock(64, 128)
        self.bot1 = ConvBlock(128, 256)
        # Decoder inputs are upsampled features concatenated with the matching
        # encoder (skip-connection) features, hence the summed channel counts.
        self.up2 = ConvBlock(128 + 256, 128)
        self.up1 = ConvBlock(128 + 64, 64)
        self.out = nn.Conv2d(64, in_ch, 1)
        self.maxpool = nn.MaxPool2d(2)
        # Its mode / align_corners settings drive _upsample_to below.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def _upsample_to(self, x, skip):
        """Resize x to skip's spatial size using self.upsample's interpolation settings.

        MaxPool2d(2) floors odd sizes, so a fixed x2 upsample can come back one
        pixel short of the skip tensor and make torch.cat fail. Targeting the
        skip's size avoids that; when the size is exactly double, the result
        equals a plain scale_factor=2 upsample.
        """
        return nn.functional.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )

    def forward(self, x):
        # Encoder: keep each level's features for the skip connections.
        x1 = self.down1(x)
        x = self.maxpool(x1)
        x2 = self.down2(x)
        x = self.maxpool(x2)
        x = self.bot1(x)
        # Decoder: upsample, concatenate skip features along the channel
        # dimension (dim=1 of an (N, C, H, W) tensor), then convolve.
        x = self._upsample_to(x, x2)
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x)
        x = self._upsample_to(x, x1)
        x = torch.cat([x, x1], dim=1)
        x = self.up1(x)
        x = self.out(x)
        return x

UNetの出力の形状を確認する（入力と同じ形状になるはず）

In [7]:
# Sanity check: the U-Net should return a tensor with the same shape as its input.
model = UNet()
x = torch.randn(10, 1, 28, 28)
y = model(x)
assert y.shape == x.shape, f"unexpected output shape {tuple(y.shape)}"
print(y.shape)

torch.Size([10, 1, 28, 28])
