## unet

In [None]:
#@title unet模型
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt


class UNet(nn.Module):
    """U-Net style network: four encoder stages, a bottleneck, four decoder stages.

    Every encoder stage is followed by a 2x2 max-pool; every decoder stage is
    preceded by a 2x2 transposed convolution and receives the centre-cropped
    output of the matching encoder stage, concatenated along the channel axis.
    The final 1x1 convolution is followed by a sigmoid.

    Spatial-size contract (per spatial dimension, input side ``n``):
        * ``decoder2`` and ``decoder1`` use 2x2 convolutions with paddings 1
          and 2, so each of them grows its input by 4 pixels. The output side
          is therefore ``16 * (n // 16) + 12`` and matches the input only when
          ``n % 16 == 12`` (e.g. 28 or 44).
        * The skip connections can only be cropped, never padded, which
          requires ``n % 16 >= 8``; other sizes fail in ``center_crop``.

    Attribute names and the layout of each ``nn.Sequential`` are part of the
    ``state_dict`` key format and must stay stable.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=32):
        """Build all layers.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels produced by the final 1x1 conv.
            init_features: Width of the first encoder stage; it doubles at
                every deeper stage (up to 16x in the bottleneck).
        """
        super().__init__()

        features = init_features
        # Keep this creation order: parameter initialisation draws from the
        # global RNG, so reordering would change the weights for a given seed.
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        # Decoder inputs are twice as wide as their outputs because the
        # upsampled tensor is concatenated with the encoder skip tensor.
        self.decoder4 = UNet._block(features * 16, features * 8, name="dec4")
        self.decoder3 = UNet._block(features * 8, features * 4, name="dec3")
        self.decoder2 = UNet._block_specail(features * 4, features * 2, name="dec2")
        self.decoder1 = UNet._block_specail(features * 2, features, name="dec1")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid of the last conv.

        Raises:
            ValueError: If a skip tensor is smaller than the upsampled tensor
                it must be merged with (see the class-level size contract).
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self._pool(enc1))
        enc3 = self.encoder3(self._pool(enc2))
        enc4 = self.encoder4(self._pool(enc3))

        bottleneck = self.bottleneck(self._pool(enc4))

        dec4 = self._decode(self.upconv4, self.decoder4, bottleneck, enc4)
        dec3 = self._decode(self.upconv3, self.decoder3, dec4, enc3)
        dec2 = self._decode(self.upconv2, self.decoder2, dec3, enc2)
        dec1 = self._decode(self.upconv1, self.decoder1, dec2, enc1)

        return torch.sigmoid(self.conv(dec1))

    def _decode(self, upconv, decoder, x, skip):
        """Upsample ``x``, merge it with the cropped ``skip`` and decode."""
        upsampled = upconv(x)
        # The encoder tensor may be larger than the upsampled one (odd sizes
        # lose a pixel when pooled), so crop it to the upsampled size first.
        skip = self.center_crop(skip, upsampled.shape[2:])
        return decoder(torch.cat((upsampled, skip), dim=1))

    def center_crop(self, layer, target_size):
        """Return the central ``target_size`` window of a 4-D tensor.

        Args:
            layer: Tensor whose last two dimensions are cropped.
            target_size: Indexable ``(height, width)`` of the desired window.

        Raises:
            ValueError: If ``layer`` is smaller than ``target_size`` in either
                dimension. Without this check the negative offset would wrap
                around and silently produce a wrongly sized slice.
        """
        _, _, layer_height, layer_width = layer.size()
        target_height, target_width = target_size[0], target_size[1]
        if layer_height < target_height or layer_width < target_width:
            raise ValueError(
                f"center_crop: cannot crop a {layer_height}x{layer_width} tensor "
                f"to the larger size {target_height}x{target_width}"
            )
        diff_y = (layer_height - target_height) // 2
        diff_x = (layer_width - target_width) // 2
        return layer[
            :, :, diff_y:diff_y + target_height, diff_x:diff_x + target_width
        ]

    @staticmethod
    def _pool(x):
        """Halve the spatial size with a 2x2, stride-2 max-pool."""
        # Functional form: avoids building a throwaway nn.MaxPool2d per call.
        return nn.functional.max_pool2d(x, kernel_size=2, stride=2)

    @staticmethod
    def _double_conv(in_channels, features, kernel_size, paddings):
        """Build conv-BN-ReLU twice; ``paddings`` holds one padding per conv."""
        first_padding, second_padding = paddings
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=features,
                kernel_size=kernel_size,
                padding=first_padding,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=features,
                out_channels=features,
                kernel_size=kernel_size,
                padding=second_padding,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _block_specail(in_channels, features, name):
        """Size-growing double conv (2x2 kernels, paddings 1 then 2).

        The first conv adds 1 pixel per spatial dimension and the second adds
        3, so the block output is 4 pixels larger than its input. ``name`` is
        accepted for call-site symmetry with ``_block`` but is not used.
        """
        return UNet._double_conv(in_channels, features, kernel_size=2, paddings=(1, 2))

    @staticmethod
    def _block(in_channels, features, name):
        """Size-preserving double conv (3x3 kernels, padding 1).

        ``name`` is accepted but not used; layers are addressed by their
        position in the returned ``nn.Sequential``.
        """
        return UNet._double_conv(in_channels, features, kernel_size=3, paddings=(1, 1))


# Data pipeline: the MNIST training split, passed through ToTensor and served
# in shuffled mini-batches of 32.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

# Model, loss and optimiser; run on the GPU whenever one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = UNet(in_channels=1, out_channels=1)
model = model.to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training: the network is asked to reproduce its own input, so each batch of
# images is both the model input and the BCE target; the labels are ignored.
epochs = 5
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for images, _ in train_loader:
        images = images.to(device)

        # Drop the gradients of the previous step before the new pass.
        optimizer.zero_grad()

        # Forward pass and reconstruction loss.
        outputs = model(images)
        loss = criterion(outputs, images)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean per-batch loss of this epoch.
    print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}')



# Qualitative check: run one mini-batch through the model in eval mode and
# show its first image next to the corresponding model output.
model.eval()
with torch.no_grad():
    sample_data, _ = next(iter(train_loader))
    sample_data = sample_data.to(device)
    output = model(sample_data)

    # Left panel: original image; right panel: model output.
    fig, ax = plt.subplots(1, 2)
    for axis, batch, title in zip(ax, (sample_data, output), ("Original", "Output")):
        axis.imshow(batch[0].cpu().squeeze(), cmap="gray")
        axis.set_title(title)
    plt.show()




In [None]:
#@title 验证Unet


# Visual spot check of the model: draw one mini-batch, run it through the
# network without gradient tracking, and plot input versus output.
# NOTE(review): the sample is drawn from train_loader, i.e. training data.
model.eval()
with torch.no_grad():
    sample_data, _ = next(iter(train_loader))
    sample_data = sample_data.to(device)
    output = model(sample_data)

    # Show the first image of the batch and the first model output side by side.
    fig, ax = plt.subplots(1, 2)
    panels = (
        (ax[0], sample_data, "Original"),
        (ax[1], output, "Output"),
    )
    for axis, batch, title in panels:
        axis.imshow(batch[0].cpu().squeeze(), cmap="gray")
        axis.set_title(title)
    plt.show()
