In [1]:
# referenced from https://qiita.com/takubb/items/7d45ae701390912c7629
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms

from torch.utils.data import DataLoader, ConcatDataset

import os
import random
import glob
import numpy as np
import matplotlib.pyplot as plt
import cv2
import pickle

from PIL import Image

# from tqdm import tqdm  #コマンドラインで実行するとき
from tqdm.notebook import tqdm  # jupyter で実行するとき
from models import RESNETLIKE, MyModel, MyModel_shallow

In [2]:
class FocusPatchDataset(torch.utils.data.Dataset):
    # FILEはデータが入っていたファイル、imsize = 画像サイズ、psize = パッチサイズ
    def __init__(self, DIR, imsize, psize, channels, DEPTH_GAP, n_sample, transforms):
        # ground truth depthを読み込む
        with open(os.path.join(DIR,'depth.pkl'),"rb") as f:
            self.gt = pickle.load(f)
        
        if self.gt is None:
            print('cannot find ', os.path.join(DIR,'depth.pkl'))
            return None

        self.DIR = DIR
        self.DEPTH_GAP = DEPTH_GAP
        self.psize = psize
        self.channels = channels
        self.transforms = transforms
        self.n_sample = n_sample
        
        self.files = glob.glob(os.path.join(DIR,'[0-9]*.bmp'))
        # print(self.files)
        fvalues = []
        for fn in self.files:
            fvalues.append(int(os.path.splitext(os.path.basename(fn))[0]))

        # self fvaluesを作る. self.fvaluesにはフォーカスのペアが入る
        self.fvalues = []
        for f in fvalues:
            if f+DEPTH_GAP*50 in fvalues:
                self.fvalues.append([f,f+DEPTH_GAP*50])
                print("focus pair", f, f+DEPTH_GAP*50)

        self.imsize = imsize
        self.psize = psize
        self.locs = []

        # サンプル対象となる点を設定する
        w = imsize[0]
        h = imsize[1]        
        for i in range(self.n_sample):
            xx = random.randint(0,w-self.psize-1)
            yy = random.randint(0,h-self.psize-1)
            self.locs.append([xx,yy])

        print('locations', self.locs)

        # 学習高速化のためバッファを用意する
        self.buffer = {}
    
    def __len__(self):
        length = len(self.fvalues)*self.n_sample
        return length

    def __getitem__(self, index):
        imgs = []

        # ii: image index, pi: point index
        ii = index // self.n_sample
        pi = index % self.n_sample
        
        x = self.locs[pi][0]
        y = self.locs[pi][1]
        val = float(self.gt[y,x]) - float(self.fvalues[ii][0])

        # print(index, ii, pi, f"{self.fvalues[ii][0]:04d}.bmp", f"{self.fvalues[ii][1]:04d}.bmp", x, y, self.gt[y,x], self.fvalues[ii][0], val)
        
        # もしバッファにデータがあるならそれを返す
        if index in self.buffer.keys():
            return self.transforms(self.buffer[index]), val

        # バッファにない場合はファイルから読み出す
        FILE1 = os.path.join(self.DIR,f'{self.fvalues[ii][0]:04d}.bmp')
        FILE2 = os.path.join(self.DIR,f'{self.fvalues[ii][1]:04d}.bmp')
        
        img = Image.open(FILE1)
        img = img.crop((x, y, x+self.psize, y+self.psize))
        img = img.convert('L')
        img = np.array(img)
        img = np.array(img).astype('float32')
        imgs.append(img)

        img = Image.open(FILE2)
        img = img.crop((x, y, x+self.psize, y+self.psize))
        img = img.convert('L')
        img = np.array(img)
        img = np.array(img).astype('float32')
        imgs.append(img)
        
        out = np.stack(imgs,axis=2)

        # バッファーに保存する
        self.buffer[index] = out
        
        return self.transforms(out), val

In [3]:
# モデル訓練関数
def train_model(model, train_loader, test_loader):
    # Train loop ----------------------------
    model.train()  # 学習モードをオン
    train_batch_loss = []
    for data, val in train_loader:
        # GPUへの転送
        data, val = data.to(device), val.to(device)
        # 1. 勾配リセット
        optimizer.zero_grad()
        # 2. 推論
        output = model(data)
        val = val.unsqueeze(1).to(torch.float32)
        # 3. 誤差計算
        loss = criterion(output, val)
        # 4. 誤差逆伝播
        loss.backward()
        # 5. パラメータ更新
        optimizer.step()
        # train_lossの取得
        train_batch_loss.append(loss.item())

    # Test(val) loop ----------------------------
    model.eval()  # 学習モードをオフ
    test_batch_loss = []
    with torch.no_grad():  # 勾配を計算なし
        for data, val in test_loader:
            data, val = data.to(device), val.to(device)
            output = model(data)
            val = val.unsqueeze(1).to(torch.float32)
            loss = criterion(output, val)
            test_batch_loss.append(loss.item())

    return model, np.mean(train_batch_loss), np.mean(test_batch_loss)

In [4]:
if __name__ == "__main__":
    # リソースの指定（CPU/GPU）
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # データセットの作成
    trans = transforms.Compose([
        transforms.ToTensor(), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()
    ])
    
    # パラメータの設定
    #PSIZE = 65
    PSIZE = 111
    DEPTH_GAP = 100
    N_SAMPLE = 5000
    
    # merge several dataset together
    dataset1 = FocusPatchDataset(DIR = 'data/11000-16000/202410041703/', imsize = [1600,1200], 
                                psize=PSIZE, channels=2, n_sample=N_SAMPLE, 
                                DEPTH_GAP=DEPTH_GAP, transforms = trans)
    dataset2 = FocusPatchDataset(DIR = 'data/11000-16000/202410041747/', imsize = [1600,1200], 
                                psize=PSIZE, channels=2, n_sample=N_SAMPLE, 
                                DEPTH_GAP=DEPTH_GAP, transforms = trans)
    dataset3 = FocusPatchDataset(DIR = 'data/11000-16000/202410080808/', imsize = [1600,1200], 
                                psize=PSIZE, channels=2, n_sample=N_SAMPLE, 
                                DEPTH_GAP=DEPTH_GAP, transforms = trans)
    dataset4= FocusPatchDataset(DIR = 'data/11000-16000/202410080819/', imsize = [1600,1200], 
                                psize=PSIZE, channels=2, n_sample=N_SAMPLE, 
                                DEPTH_GAP=DEPTH_GAP, transforms = trans)
    
    dataset = torch.utils.data.ConcatDataset([dataset1,dataset2,dataset3,dataset4])


    # test dataset, train datasetに分割する
    train_set, test_set = torch.utils.data.random_split(dataset, [0.9,0.1])
    print("train dataset", len(train_set), "test dataset", len(test_set))

    # データローダーの作成
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=128,  # バッチサイズ
                                               shuffle=True,  # データシャッフル
                                               num_workers=0,  # 高速化
                                               pin_memory=True  # 高速化                                      
                                               )
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=128,
                                              shuffle=True,
                                              num_workers=0,  # 高速化
                                              pin_memory=True  # 高速化                                                   
                                              )

    # モデル・損失関数・最適化アルゴリスムの設定
    # model = RESNETLIKE(channels=2).to(device)
    model = MyModel(channels=2)
    #model = MyModel_shallow(channels=2).to(device)
    weight_file = f"weights/weight_11000_16000_{PSIZE}_{DEPTH_GAP}.pth"

    if os.path.exists(weight_file):
        # 学習済みモデルのロード
        print("load weight", weight_file)
        model.load_state_dict(torch.load(weight_file))
        model.eval()
    
    model = model.to(device)
    
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters())
    #optimizer = optim.SGD(model.parameters())

    # 訓練の実行
    epoch = 50
    train_loss = []
    test_loss = []
    
    for epoch in tqdm(range(epoch)):
        model, train_l, test_l = train_model(model,train_loader,test_loader)
        train_loss.append(train_l)
        test_loss.append(test_l)    
        # 10エポックごとにロスを表示
        #if epoch % 5 == 0:
        print(f"{epoch}: train loss: {train_loss[-1]:.3f}, test loss: {test_loss[-1]:.3f}")
    
    # モデルの保存
    print("save weight as: ", weight_file)
    torch.save(model.state_dict(), weight_file)

    # 学習状況（ロス）の確認
    plt.plot(train_loss, label='train_loss')
    plt.plot(test_loss, label='test_loss')
    plt.legend()

focus pair 11000 16000
locations [[405, 386], [1304, 560], [1219, 753], [867, 754], [661, 1012], [1050, 228], [1384, 276], [206, 322], [889, 519], [42, 682], [106, 563], [415, 401], [1156, 22], [889, 421], [522, 482], [946, 417], [350, 193], [301, 711], [123, 142], [231, 360], [1332, 285], [828, 652], [1459, 219], [354, 469], [318, 448], [1148, 201], [299, 586], [845, 223], [339, 18], [120, 933], [1297, 1046], [1283, 500], [1152, 932], [1202, 346], [283, 558], [813, 508], [161, 184], [223, 428], [1230, 601], [1102, 930], [765, 468], [842, 183], [17, 478], [774, 1012], [1167, 844], [765, 945], [1059, 670], [621, 421], [513, 900], [1432, 669], [809, 648], [1095, 834], [1449, 1044], [118, 994], [433, 461], [1295, 99], [103, 399], [850, 76], [963, 456], [152, 523], [752, 473], [1268, 810], [484, 157], [483, 803], [1410, 274], [961, 679], [1247, 696], [1243, 46], [1405, 359], [367, 2], [745, 640], [493, 724], [682, 3], [774, 394], [1098, 531], [945, 644], [2, 1039], [784, 122], [64, 373], [