<a href="https://colab.research.google.com/github/iiyama-lab/semi_tutorial/blob/main/20220614.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# U-Net

# 0. Google ドライブのマウント

In [None]:
# Mount Google Drive so the dataset stored there is reachable from Colab.
from google.colab import drive
drive.mount("/content/drive")
# IPython shell escape: print the current working directory
# (checkpoints saved later with relative filenames end up here).
!pwd

# Training / test image directories. The dataset class globs "*/*.png",
# i.e. PNG files one sub-directory level below these paths.
datadir = "/content/drive/MyDrive/iiyama-lab2022/data/face/train"
val_datadir = "/content/drive/MyDrive/iiyama-lab2022/data/face/test"
# Alternative paths, presumably for running outside Colab — confirm before use.
#datadir = "/root/data/share/face/train"
#val_datadir = "/root/data/share/face/test"

# 1. いろいろインポート

In [None]:
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.io import read_image, ImageReadMode
from torchvision import transforms
import os
import glob

import matplotlib.pyplot as plt
import numpy as np


# 2. データローダの作成

In [None]:
class ColorizationImageDataset(Dataset):
    """Dataset of (grayscale, RGB) image pairs for colorization.

    Every PNG is read twice: once as a 1-channel grayscale image (network
    input) and once as a 3-channel RGB image (target).

    Attributes:
        filenames (list): sorted list of PNG paths found one sub-directory
            level below ``img_dir``.
        transform_rgb (callable): transform applied to the RGB image.
        transform_gray (callable): transform applied to the grayscale image.
    """

    def __init__(self, img_dir, transform_rgb, transform_gray):
        """
        Args:
            img_dir: directory whose sub-directories (e.g. ``nonsmile``)
                contain the PNG images.
            transform_rgb: transform for the RGB (target) image.
            transform_gray: transform for the grayscale (input) image.
        """
        self.transform_rgb = transform_rgb
        self.transform_gray = transform_gray
        # glob returns files in an arbitrary, filesystem-dependent order; sort so
        # that indices (and therefore the unshuffled validation loader) are
        # reproducible across runs and machines.
        self.filenames = sorted(glob.glob(os.path.join(img_dir, "*/*.png")))
        # This dataset is used for both training and validation, so report the
        # directory instead of claiming the images are "for training".
        print(f"{len(self)} images found in {img_dir}")

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(gray_image, rgb_image)`` for the ``idx``-th file, both transformed."""
        img_path = self.filenames[idx]
        image = read_image(img_path, mode=ImageReadMode.RGB)
        image = self.transform_rgb(image)
        gray_image = read_image(img_path, mode=ImageReadMode.GRAY)
        gray_image = self.transform_gray(gray_image)
        return gray_image, image

In [None]:
class ImageTransform:
    """
    Preprocessing for one image tensor (as returned by ``read_image``).

    Pipeline: tensor -> PIL image -> resize (shorter side to ``size``) ->
    center crop to ``size`` x ``size`` -> float tensor in [0, 1] ->
    ``Normalize(mean, std)``.

    Args:
        mean, std: per-channel normalization values. A scalar also works
            (it is used that way for the 1-channel grayscale image).
        size: output height and width. The default of 256 is what the rest
            of the notebook assumes.
    """
    def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), size=256):
        # Tuples rather than lists: default argument objects are shared between
        # calls, so they should be immutable.
        self.data_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def __call__(self, img):
        return self.data_transform(img)

def tensor2RGBimage(image, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo ``Normalize(mean, std)`` and turn a CHW tensor into an HWC numpy image.

    Args:
        image: tensor of shape (C, H, W). It may live on any device and may
            require grad; it is detached and copied to the CPU first.
        mean, std: the values that were used for normalization.

    Returns:
        numpy.ndarray of shape (H, W, C), clipped to [0, 1], ready for ``plt.imshow``.
    """
    # .numpy() refuses tensors that require grad or live on a GPU.
    inp = image.detach().cpu().numpy().transpose((1, 2, 0))
    inp = np.asarray(std) * inp + np.asarray(mean)
    return np.clip(inp, 0, 1)

## 2.1 データローダテスト

In [None]:
import matplotlib.pyplot as plt
from pylab import cm
# Build a small loader just to eyeball one (grayscale, RGB) training pair.
dataset = ColorizationImageDataset(
    datadir,
    transform_rgb=ImageTransform(),
    transform_gray=ImageTransform(mean=0.5, std=0.5),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)
batch = next(iter(dataloader))
gray_images, rgb_images = batch

# Left: network input (gray), right: target colors (de-normalized).
fig, axes = plt.subplots(1, 2)
left_ax, right_ax = axes
left_ax.imshow(gray_images[0, 0], cmap=cm.gray)
right_ax.imshow(tensor2RGBimage(rgb_images[0]))
plt.show()

# 3. モデルの作成

U-Netの元の論文の構造です
![model](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

- 元論文ではpaddingを行っていないので、3x3の畳み込みを行うたびに画像の縦横サイズが2ピクセルずつ小さくなっています。
- このままでもいいのですが、今回は（面倒なので）paddingを行って画像サイズが変わらないようにします
- 元論文は2クラスのセグメンテーションなので出力は2チャンネルですが、今回はカラー化なのでRGBの3チャンネルを出力します

## 3.1 部品をひとつずつ作っていきましょう

In [None]:
class Conv2d_twice(nn.Module):
    """
    Two successive (3x3 conv -> BatchNorm -> ReLU) stages.

    padding=1 keeps the spatial size unchanged.

    Args:
        in_channels: channels of the input.
        mid_channels: channels produced by the first convolution.
        out_channels: channels produced by the second convolution.
    """
    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        stages = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # Same attribute name and module order as before, so state_dict keys
        # ("layer.0.weight", ...) and saved checkpoints stay compatible.
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)

In [None]:
class EncoderParts(nn.Module):
    """
    U-Net encoder stage: 2x2 max pooling (halves height and width),
    followed by Conv2d_twice.

    Args:
        in_channels, mid_channels, out_channels: forwarded to Conv2d_twice.
    """
    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        convs = Conv2d_twice(in_channels, mid_channels, out_channels)
        # Keep pool at index 0 and convs at index 1 so state_dict keys are unchanged.
        self.layer = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.layer(x)


class DecoderParts(nn.Module):
    """
    U-Net decoder stage.

    Upsamples the previous decoder output by 2 (nearest neighbour; a
    ConvTranspose2d would also be an option). It then concatenates the
    encoder feature map of the same level in front of it (skip connection)
    and applies Conv2d_twice.

    Args:
        in_channels: channels of the decoder input ``x_dec``. ``x_enc`` is
            expected to have the same number, because the first convolution
            takes ``in_channels * 2`` channels.
        mid_channels: channels after the first convolution.
        out_channels: channels of the output.
    """
    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv = Conv2d_twice(in_channels*2, mid_channels, out_channels)

    def forward(self, x_dec, x_enc):
        out = self.upsample(x_dec)
        # If an encoder map had an odd height/width, max pooling floored it, and
        # 2x upsampling comes back one pixel short. Resize to the skip
        # connection's size so torch.cat does not fail on such inputs.
        # This is a no-op for the 256x256 inputs used here.
        if out.shape[-2:] != x_enc.shape[-2:]:
            out = nn.functional.interpolate(out, size=x_enc.shape[-2:])
        out = torch.cat([x_enc, out], dim=1)
        return self.conv(out)


In [None]:
class UNet(nn.Module):
    """
    U-Net for colorization: grayscale image in, RGB image out.

    Structure:
    - Encoder: four stages with 64/128/256/512 channels; every stage after
      the first halves the resolution. A bottleneck follows.
    - Decoder: four stages. Each upsamples its input and concatenates the
      matching encoder output (skip connection).

    All 3x3 convolutions are padded. For input sizes divisible by 16 the
    output therefore has the same height and width as the input.

    Args:
        in_channels: channels of the input image (default 1: grayscale).
        out_channels: channels of the output image (default 3: RGB).

    The shape comments below assume the notebook's default 1x256x256 input.
    """
    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()
        self.encoder1 = Conv2d_twice(in_channels, 64, 64) # 1x256x256 --> 64x256x256
        self.encoder2 = EncoderParts(64, 128, 128) # 64x256x256 --> 128x128x128
        self.encoder3 = EncoderParts(128, 256, 256) # 128x128x128 --> 256x64x64
        self.encoder4 = EncoderParts(256, 512, 512) # 256x64x64 --> 512x32x32
        self.bottle_neck = EncoderParts(512, 1024, 512) # 512x32x32 --> 512x16x16
        self.decoder1 = DecoderParts(512, 512, 256) # 512x16x16 & 512x32x32 --> 256x32x32
        self.decoder2 = DecoderParts(256, 256, 128) # 256x32x32 & 256x64x64 --> 128x64x64
        self.decoder3 = DecoderParts(128, 128, 64) # 128x64x64 & 128x128x128 --> 64x128x128
        self.decoder4 = DecoderParts(64, 64, 64) # 64x128x128 & 64x256x256 --> 64x256x256
        # 1x1 convolution mapping the 64 feature channels to the output channels.
        # Kept inside a Sequential so state_dict keys ("last.0.*") are unchanged.
        self.last = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)
        )

    def forward(self, x):
        out1 = self.encoder1(x)
        out2 = self.encoder2(out1)
        out3 = self.encoder3(out2)
        out4 = self.encoder4(out3)
        out = self.bottle_neck(out4)
        # Each decoder stage gets the matching encoder output as skip connection.
        out = self.decoder1(out, out4)
        out = self.decoder2(out, out3)
        out = self.decoder3(out, out2)
        out = self.decoder4(out, out1)
        out = self.last(out)
        return out


## 3.2 試しに動かしてみましょう

In [None]:
# Smoke test: run the untrained network on the batch loaded above and show its
# (still meaningless) output next to the input and the ground truth.
model = UNet()
# Pure inference: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
    out = model(gray_images)

fig, axes = plt.subplots(1, 3)
axes[0].imshow(gray_images[0,0], cmap=cm.gray)
axes[1].imshow(tensor2RGBimage(rgb_images[0]))
axes[2].imshow(tensor2RGBimage(out[0]))
plt.show()

# 4. 訓練してみましょう

## 4.1 まずはDataLoaderの設定

In [None]:
# Train on the GPU when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4

# Training batches are shuffled; validation uses a doubled batch in fixed order.
train_loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=2)
val_loader_kwargs = dict(batch_size=batch_size*2, shuffle=False, num_workers=2)

dataset = ColorizationImageDataset(
    datadir,
    transform_rgb=ImageTransform(),
    transform_gray=ImageTransform(mean=0.5, std=0.5),
)
dataloader = torch.utils.data.DataLoader(dataset, **train_loader_kwargs)

val_dataset = ColorizationImageDataset(
    val_datadir,
    transform_rgb=ImageTransform(),
    transform_gray=ImageTransform(mean=0.5, std=0.5),
)
val_dataloader = torch.utils.data.DataLoader(val_dataset, **val_loader_kwargs)

## 4.2 モデルとoptimizerと損失関数の準備
今回はoptimizerにAdam、損失関数にL1損失を用います。

In [None]:
# Fresh network, Adam over all of its parameters, and an L1 (mean absolute error) loss.
model = UNet()
adam_hparams = {"lr": 1.0e-5, "betas": [0.9, 0.999]}
optimizer = torch.optim.Adam(model.parameters(), **adam_hparams)
criterion = nn.L1Loss()

## 4.3 Train, Validation, あとはモデルの保存
ちょっと長くなるので、訓練部分・検証部分・モデルの保存部分をそれぞれ別の関数にしておきます。

In [None]:
def train(dataloader):
    """
    Train ``model`` for one epoch over ``dataloader``.

    Uses the notebook-level globals ``model``, ``optimizer``, ``criterion``
    and ``device``.

    Args:
        dataloader: DataLoader yielding ``(gray_images, rgb_images)`` batches.

    Returns:
        float: mean of the per-batch losses over the epoch (0.0 if the loader
        yields no batches).
    """
    epoch_loss = 0
    num_batches = 0
    model.train()

    # Progress bar; comment out if not needed.
    pbar = tqdm(total=len(dataloader.dataset), leave=False)

    for gray_images, rgb_images in dataloader:
        _batch_size = len(gray_images)
        gray_images = gray_images.to(device)
        rgb_images = rgb_images.to(device)

        # PyTorch accumulates gradients across backward() calls; clear the
        # previous step's gradients before computing new ones.
        optimizer.zero_grad()
        out = model(gray_images)
        loss = criterion(out, rgb_images)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

        # Progress-bar bookkeeping. Advance by the actual batch size, because
        # the last batch may be smaller than the configured one.
        pbar.set_postfix({"train loss:": epoch_loss / num_batches})
        pbar.update(_batch_size)

    pbar.close()

    return epoch_loss / max(num_batches, 1)


In [None]:
def validation(dataloader):
    """
    Evaluate ``model`` on ``dataloader``: the same loop as ``train``, but
    without parameter updates.

    Runs in eval mode (BatchNorm uses its running statistics) and under
    ``torch.no_grad()``, so no autograd graph is built. Uses the notebook-level
    globals ``model``, ``criterion`` and ``device``.

    Args:
        dataloader: DataLoader yielding ``(gray_images, rgb_images)`` batches.

    Returns:
        float: mean of the per-batch losses (0.0 if the loader yields no batches).
    """
    val_loss = 0
    num_batches = 0
    model.eval()

    # Progress bar; comment out if not needed.
    pbar = tqdm(total=len(dataloader.dataset), leave=False)

    with torch.no_grad():
        for gray_images, rgb_images in dataloader:
            _batch_size = len(gray_images)
            gray_images = gray_images.to(device)
            rgb_images = rgb_images.to(device)

            out = model(gray_images)
            loss = criterion(out, rgb_images)

            val_loss += loss.item()
            num_batches += 1

            # Progress-bar bookkeeping.
            pbar.set_postfix({"val loss:": val_loss / num_batches})
            pbar.update(_batch_size)

    pbar.close()

    return val_loss / max(num_batches, 1)

In [None]:
def save_checkpoint(filename, epoch, train_loss=None, val_loss=None):
    """
    Save the global ``model`` and ``optimizer`` state plus bookkeeping values
    to ``filename`` as a single dict.

    Format follows:
    https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_loss=train_loss,
        val_loss=val_loss,
    )
    torch.save(state, filename)

def load_checkpoint(filename):
    """
    Restore the global ``model`` and ``optimizer`` from a checkpoint written by
    ``save_checkpoint``.

    The model is moved to the global ``device`` as part of loading.

    Args:
        filename: path of the checkpoint file.

    Returns:
        tuple: ``(epoch, train_loss, val_loss)`` stored in the checkpoint.
    """
    # Move the model first. Optimizer.load_state_dict puts the restored state
    # (e.g. Adam moments) on the device of the parameters it belongs to. If the
    # model were still on the CPU here and moved afterwards, that state would
    # stay on the CPU and optimizer.step() would fail with a device mismatch.
    model.to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(filename, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    train_loss = checkpoint['train_loss']
    val_loss = checkpoint['val_loss']

    return epoch, train_loss, val_loss


## 4.4 学習のメイン部分

In [None]:
model.to(device)
# Let cuDNN benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

num_epoches = 100
best_val_loss = None

for epoch in range(num_epoches):
    train_loss = train(dataloader)
    val_loss = validation(val_dataloader)
    print(f"epoch={epoch} train L1Loss={train_loss}, val L1Loss={val_loss}")

    # Keep the best model so far (lowest validation loss).
    # Filenames are relative, so checkpoints go to the current working directory.
    improved = best_val_loss is None or val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        filename = "best.checkpoint"
        save_checkpoint(filename, epoch, train_loss, val_loss)

    # Independently of that, keep a snapshot every 10 epochs.
    if not epoch % 10:
        filename = f"save{epoch:04d}.checkpoint"
        save_checkpoint(filename, epoch, train_loss, val_loss)

print(f"Done. best val loss={best_val_loss}")

# 5. テストしましょう

## 5.1 Validationデータ

In [None]:
# Qualitative check on one validation sample: input, ground truth, prediction.
gray_images, rgb_images = next(iter(val_dataloader))

idx = 1  # sample to display; must be smaller than the validation batch size
fig, axes = plt.subplots(1, 3)
axes[0].imshow(gray_images[idx,0], cmap=cm.gray)
axes[1].imshow(tensor2RGBimage(rgb_images[idx]))

model.to(device)
model.eval()
gray_images = gray_images.to(device)
rgb_images = rgb_images.to(device)
# Pure inference: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
    out = model(gray_images)
    # L1 loss of the displayed sample only.
    loss = criterion(out[idx], rgb_images[idx]).item()
out = out.to('cpu')
axes[2].imshow(tensor2RGBimage(out[idx]))
plt.show()

print(loss)


## 5.2 ネット上にある画像

In [None]:
# Download a sample photo with an IPython shell escape; the commented-out
# lines are alternative test images.
#!curl https://researchmap.jp/masaakiiiyama/avatar.jpg -o sample.png
!curl https://www.iiyama-lab.org/static/cf450c1cdd8f3a93e9ec1c816db1c39b/c58a3/MasaakiIiyama2020.jpg -o sample.jpg
#!curl https://upload.wikimedia.org/wikipedia/commons/0/06/Shiga_University_Auditorium_%28Nationally_Registered_Tangible_Cultural_Property%29_at_Headquarter_in_Hikone_and_Headquarter_Building.jpg -o sample.jpg
#!curl https://upload.wikimedia.org/wikipedia/commons/thumb/1/12/Shiga-Univ-Otsu-Entrance-2016081701.jpg/1280px-Shiga-Univ-Otsu-Entrance-2016081701.jpg -o sample.jpg
filename = 'sample.jpg'
#filename = "/root/data/share/face/test/nonsmile/004000.png"

# Same preprocessing as the grayscale training input: read as 1 channel, then
# ImageTransform with mean=std=0.5.
data_transform = ImageTransform(mean=0.5, std=0.5)
gray_image = read_image(filename,mode=ImageReadMode.GRAY)
gray_image = data_transform(gray_image)
plt.imshow(gray_image[0], cmap=cm.gray)
plt.colorbar()
plt.show()
# Add a batch dimension. Hard-codes 256x256, which assumes ImageTransform
# crops to 256. NOTE(review): unsqueeze(0) would avoid hard-coding the size.
gray_image = gray_image.reshape((1,1,256,256))
gray_image = gray_image.to(device)
model.to(device)
model.eval()
# NOTE(review): wrapping this forward pass in torch.no_grad() would avoid
# building an autograd graph for pure inference.
out = model(gray_image).to('cpu').detach()
plt.imshow(tensor2RGBimage(out[0]))
plt.show()


# 6. 中断した学習を再開

## 6.1 DataLoaderとモデルとoptimizerと損失関数を準備

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4

dataset = ColorizationImageDataset(datadir, transform_rgb=ImageTransform(), transform_gray=ImageTransform(mean=0.5, std=0.5))
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=2
)
val_dataset = ColorizationImageDataset(val_datadir, transform_rgb=ImageTransform(), transform_gray=ImageTransform(mean=0.5, std=0.5))
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size*2, shuffle=False,  num_workers=2
)

# Move the model to the device *before* the optimizer state is restored.
# Optimizer.load_state_dict places the restored Adam moments on the device of the
# parameters. Loading while the model is on the CPU and calling model.to(device)
# afterwards would leave that state on the CPU and break optimizer.step().
model = UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-5, betas=[0.9, 0.999])

criterion = nn.L1Loss()


## 6.2 保存していたモデルを読み込む

In [None]:
# Restore model and optimizer from the best checkpoint written during training.
filename = "best.checkpoint"
restored = load_checkpoint(filename)
current_epoch, train_loss, val_loss = restored

## 6.3 学習を再開。基本は4.4と同じ

In [None]:
model.to(device)
# Let cuDNN benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

num_epoches = 100
# Best-so-far starts at the validation loss stored in the restored checkpoint.
best_val_loss = val_loss
# Continue with the epoch right after the restored one.
start_epoch = current_epoch + 1

for epoch in range(start_epoch, num_epoches):
    train_loss = train(dataloader)
    val_loss = validation(val_dataloader)
    print(f"epoch={epoch} train L1Loss={train_loss}, val L1Loss={val_loss}")

    # Keep the best model so far (lowest validation loss).
    if best_val_loss is None or val_loss < best_val_loss:
        best_val_loss = val_loss
        filename = "best.checkpoint"
        save_checkpoint(filename, epoch, train_loss, val_loss)

    # Independently of that, keep a snapshot every 10 epochs.
    if not epoch % 10:
        filename = f"save{epoch:04d}.checkpoint"
        save_checkpoint(filename, epoch, train_loss, val_loss)

print(f"Done. best val loss={best_val_loss}")