<a href="https://colab.research.google.com/github/iiyama-lab/semi_tutorial/blob/main/20220614.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# U-Net

# 0.　ドライブのマウント

In [None]:
# Mount Google Drive at /content/drive so the dataset directories below
# can be read from this notebook.
from google.colab import drive
drive.mount("/content/drive")
# Notebook shell escape: print the current working directory.
!pwd

# Image directories on the mounted Drive.
datadir = "/content/drive/MyDrive/iiyama-lab2022/data/face/train"
val_datadir = "/content/drive/MyDrive/iiyama-lab2022/data/face/test"
# Alternative locations (presumably for running outside Colab -- confirm):
#datadir = "/root/data/share/face/train"
#val_datadir = "/root/data/share/face/test"

# 1. いろいろインポート

In [None]:
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision import transforms
import os
import glob

import matplotlib.pyplot as plt
import numpy as np


# 2. データローダの作成

In [None]:
class ColorizationImageDataset(Dataset):
    """Dataset of (grayscale, RGB) image pairs for colorization.

    Attributes:
        filenames (list): Sorted paths of the PNG images that were found.
        transform (obj): Callable applied to each decoded image tensor.
    """

    def __init__(self, img_dir, transform):
        """
        Args:
            img_dir: Directory searched for images. Files matching
                ``<img_dir>/*/*.png`` (exactly one sub-directory level)
                are used.
            transform: Callable applied to the decoded image tensor.
        """
        self.transform = transform
        # glob returns entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> file mapping stable across runs.
        self.filenames = sorted(glob.glob(os.path.join(img_dir, "*/*.png")))
        # This class is used for training and validation data alike, so the
        # message reports the directory instead of claiming "training".
        print(f"{len(self)} images found in {img_dir}")

    def __len__(self):
        """Return the number of images found."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(gray_image, image)`` for the file at index ``idx``.

        ``image`` is the transformed RGB image and ``gray_image`` is its
        grayscale conversion (computed after the transform).
        """
        # Local import so this class does not depend on an edit to the
        # file-level import list.
        from torchvision.io import ImageReadMode

        img_path = self.filenames[idx]
        # Force 3-channel decoding: grayscale or RGBA PNGs would otherwise
        # produce 1 or 4 channels and break a 3-channel Normalize.
        image = read_image(img_path, mode=ImageReadMode.RGB)
        image = self.transform(image)
        gray_image = transforms.functional.rgb_to_grayscale(image)
        return gray_image, image

In [None]:
class ImageTransform():
    """Preprocessing pipeline applied to each decoded image tensor.

    The pipeline converts the tensor to a PIL image, resizes it with
    ``Resize(256)``, center-crops to 256x256, converts back to a tensor
    and normalizes each channel with ``mean`` and ``std``.

    Args:
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel standard deviation passed to
            ``transforms.Normalize``.
    """

    # Tuples instead of lists: default arguments are shared between calls,
    # so they must not be mutable.
    def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        self.data_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    def __call__(self, img):
        """Apply the pipeline to ``img`` and return the result."""
        return self.data_transform(img)

def tensor2RGBimage(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo normalization and return a channels-last array for plotting.

    Args:
        image: PyTorch tensor laid out as (channels, height, width). It may
            live on any device and may be attached to an autograd graph.
        mean: Per-channel mean that was used for normalization.
        std: Per-channel standard deviation that was used for normalization.

    Returns:
        NumPy array laid out as (height, width, channels), clipped to [0, 1].
    """
    # Tensor.numpy() raises for CUDA tensors and for tensors that require
    # grad, so detach and move to the CPU first.
    inp = image.detach().cpu().numpy().transpose((1, 2, 0))
    mean = np.array(mean)
    std = np.array(std)
    # Inverse of Normalize: x * std + mean, broadcast over the last axis.
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp

## 2.1 データローダテスト

In [None]:
import matplotlib.pyplot as plt
from pylab import cm

# Smoke test of the data pipeline: draw one shuffled batch of 10 images.
dataset = ColorizationImageDataset(datadir, transform=ImageTransform())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)
gray_images, rgb_images = next(iter(dataloader))

# Show the first sample: grayscale input on the left, RGB target on the right.
fig, axes = plt.subplots(1, 2)
panels = (
    (gray_images[0, 0], {"cmap": cm.gray}),
    (tensor2RGBimage(rgb_images[0]), {}),
)
for ax, (picture, imshow_kwargs) in zip(axes, panels):
    ax.imshow(picture, **imshow_kwargs)
plt.show()

# 3. モデルの作成

U-Netの元の論文の構造です
![model](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

- paddingを行っていないので、3x3の畳み込みを行うたびに画像サイズが2ずつ小さくなっています。
- このままでもいいのですが、今回は（面倒なので）paddingを行って画像サイズが変わらないようにします。
- 今回はカラー化なので、出力は（2チャンネルではなく）3チャンネルです。

## 3.1 部品をひとつずつ作っていきましょう

In [None]:
class Conv2d_twice(nn.Module):
    """Two consecutive 3x3 convolution stages.

    Each stage is Conv2d (kernel 3, padding 1) -> BatchNorm2d -> ReLU.
    With a 3x3 kernel and padding 1 the spatial size is unchanged.

    Args:
        in_channels: Number of input channels.
        mid_channels: Number of output channels of the first convolution.
        out_channels: Number of output channels of the second convolution.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        channel_pairs = (
            (in_channels, mid_channels),
            (mid_channels, out_channels),
        )
        stages = []
        for c_in, c_out in channel_pairs:
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run both convolution stages on ``x``."""
        features = self.layer(x)
        return features

In [None]:
class EncoderParts(nn.Module):
    """Encoder stage of the U-Net.

    A 2x2 max pooling halves the image size, then ``Conv2d_twice`` is
    applied to the pooled result.

    Args:
        in_channels: Number of input channels.
        mid_channels: Output channels of the first convolution.
        out_channels: Output channels of the second convolution.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        downsample = nn.MaxPool2d(kernel_size=2)
        double_conv = Conv2d_twice(in_channels, mid_channels, out_channels)
        self.layer = nn.Sequential(downsample, double_conv)

    def forward(self, x):
        """Pool ``x`` and run the two convolution stages."""
        encoded = self.layer(x)
        return encoded


class DecoderParts(nn.Module):
    """Decoder stage of the U-Net.

    The output of the previous decoder stage is upsampled by a factor of 2
    (``nn.Upsample``; ``ConvTranspose2d`` might work as well), concatenated
    with the matching encoder output along the channel axis, and passed
    through ``Conv2d_twice``.

    Args:
        in_channels: Channel count of each of the two inputs; the
            convolution receives ``in_channels * 2`` channels after the
            concatenation.
        mid_channels: Output channels of the first convolution.
        out_channels: Output channels of the second convolution.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        concat_channels = in_channels * 2
        self.conv = Conv2d_twice(concat_channels, mid_channels, out_channels)

    def forward(self, x_dec, x_enc):
        """Merge the upsampled ``x_dec`` with the skip input ``x_enc``."""
        upsampled = self.upsample(x_dec)
        # Encoder features come first along the channel axis.
        merged = torch.cat((x_enc, upsampled), dim=1)
        return self.conv(merged)


In [None]:
class UNet(nn.Module):
    """U-Net mapping a 1-channel image to a 3-channel image.

    Four encoder stages, a bottleneck and four decoder stages with skip
    connections, followed by a 1x1 convolution down to 3 channels. The
    shape comments below assume a 1x256x256 input.
    """

    def __init__(self):
        super().__init__()
        # Encoder path.
        # 1x256x256 --> 64x256x256
        self.encoder1 = Conv2d_twice(in_channels=1, mid_channels=64, out_channels=64)
        # 64x256x256 --> 128x128x128
        self.encoder2 = EncoderParts(in_channels=64, mid_channels=128, out_channels=128)
        # 128x128x128 --> 256x64x64
        self.encoder3 = EncoderParts(in_channels=128, mid_channels=256, out_channels=256)
        # 256x64x64 --> 512x32x32
        self.encoder4 = EncoderParts(in_channels=256, mid_channels=512, out_channels=512)
        # Bottleneck: 512x32x32 --> 512x16x16
        self.buttle_neck = EncoderParts(in_channels=512, mid_channels=1024, out_channels=512)
        # Decoder path; each stage also receives the matching encoder output.
        # 512x16x16 & 512x32x32 --> 256x32x32
        self.decoder1 = DecoderParts(in_channels=512, mid_channels=512, out_channels=256)
        # 256x32x32 & 256x64x64 --> 128x64x64
        self.decoder2 = DecoderParts(in_channels=256, mid_channels=256, out_channels=128)
        # 128x64x64 & 128x128x128 --> 64x128x128
        self.decoder3 = DecoderParts(in_channels=128, mid_channels=128, out_channels=64)
        # 64x128x128 & 64x256x256 --> 64x256x256
        self.decoder4 = DecoderParts(in_channels=64, mid_channels=64, out_channels=64)
        # 1x1 convolution producing the 3 output channels.
        self.last = nn.Sequential(nn.Conv2d(64, 3, kernel_size=1))

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the output."""
        encoders = (self.encoder1, self.encoder2, self.encoder3, self.encoder4)
        decoders = (self.decoder1, self.decoder2, self.decoder3, self.decoder4)

        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)

        features = self.buttle_neck(features)

        # Decoders consume the skip connections deepest-first.
        for decode in decoders:
            features = decode(features, skips.pop())

        return self.last(features)


## 3.2 試しに動かしてみましょう

In [None]:
# Run the sample batch through an untrained U-Net to check that the
# forward pass works end to end.
model = UNet()
out = model(gray_images).detach()

# Grayscale input, RGB target and (untrained) network output for sample 0.
fig, axes = plt.subplots(1, 3)
panels = (
    (gray_images[0, 0], {"cmap": cm.gray}),
    (tensor2RGBimage(rgb_images[0]), {}),
    (tensor2RGBimage(out[0]), {}),
)
for ax, (picture, imshow_kwargs) in zip(axes, panels):
    ax.imshow(picture, **imshow_kwargs)
plt.show()

# 4. 訓練してみましょう

In [None]:
# Select the device. "cuda:1" is kept when a second GPU exists; on
# single-GPU machines (e.g. Colab) that ordinal is invalid, so fall back
# to "cuda:0", and to the CPU when CUDA is unavailable.
if not torch.cuda.is_available():
    device = "cpu"
elif torch.cuda.device_count() > 1:
    device = "cuda:1"
else:
    device = "cuda:0"
batch_size = 32

# Training data is shuffled every epoch.
dataset = ColorizationImageDataset(datadir, transform=ImageTransform())
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=4
)
# Validation data is read in a fixed order.
val_dataset = ColorizationImageDataset(val_datadir, transform=ImageTransform())
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False,  num_workers=4
)

In [None]:
# Train a fresh U-Net with Adam, using the L1 distance between the
# predicted and the true RGB image as the loss.
model = UNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-5, betas=(0.9, 0.999))

model.to(device)
# Let cuDNN benchmark convolution algorithms; input sizes are fixed here.
torch.backends.cudnn.benchmark = True

criterion = nn.L1Loss()

num_epoches = 100
for epoch in range(num_epoches):
    epoch_loss = 0
    model.train()
    # The context manager closes the bar at the end of every epoch.
    with tqdm(total=len(dataloader.dataset)) as pbar:
        for i, (gray_images, rgb_images) in enumerate(dataloader):
            gray_images = gray_images.to(device)
            rgb_images = rgb_images.to(device)

            # Gradients accumulate across backward() calls unless they are
            # cleared, so reset them before every step.
            optimizer.zero_grad()
            out = model(gray_images)
            loss = criterion(out, rgb_images)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            # Running mean of the per-batch loss within this epoch.
            pbar.set_postfix({"train loss:": epoch_loss / (i+1)})
            # Advance by the real batch size; the last batch may be smaller.
            pbar.update(gray_images.size(0))
    

In [None]:
# Colorize the first validation batch and show one sample:
# grayscale input, RGB target and network output.
gray_images, rgb_images = next(iter(val_dataloader))

idx = 0
fig, axes = plt.subplots(1, 3)
axes[0].imshow(gray_images[idx,0], cmap=cm.gray)
axes[1].imshow(tensor2RGBimage(rgb_images[idx]))

model.to(device)
model.eval()
gray_images = gray_images.to(device)
# Inference only: disabling autograd avoids building a graph and keeping
# the activations of the whole batch in memory.
with torch.no_grad():
    out = model(gray_images)
# The result carries no graph, so only the move to the CPU is needed.
out = out.to('cpu')
axes[2].imshow(tensor2RGBimage(out[idx]))
plt.show()

