In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import cv2
from PIL import Image
import numpy as np
import argparse
import glob
import os

# Width and height, in pixels, that every image and label mask is resized to.
IMW, IMH = 320, 240

import matplotlib.pyplot as plt

class MLP(nn.Module):
    """Two-layer convolutional network producing a per-pixel probability map.

    Despite the name there are no fully connected layers: a 5x5 convolution
    with ``n_units1`` output channels is followed by ReLU, a 3x3 convolution
    down to a single channel, and a sigmoid. Both convolutions use stride 1
    with padding chosen so the output keeps the input's height and width.
    """

    def __init__(self, n_units1):
        """Create both convolutions; ``n_units1`` is the hidden channel count."""
        super().__init__()
        hidden_channels = n_units1
        # Attribute names l1/l2 are the keys used in the saved state dict.
        self.l1 = nn.Conv2d(
            in_channels=3, out_channels=hidden_channels,
            kernel_size=5, stride=1, padding=2)
        self.l2 = nn.Conv2d(
            in_channels=hidden_channels, out_channels=1,
            kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map a 3-channel image batch to a 1-channel map with values in [0, 1]."""
        hidden = F.relu(self.l1(x))
        logits = self.l2(hidden)
        return torch.sigmoid(logits)

def load_image(fname, imw, imh):
    """Load an image file as a normalized channel-first array plus a PIL image.

    Args:
        fname: Path of the image file to open.
        imw: Target width in pixels.
        imh: Target height in pixels.

    Returns:
        Tuple ``(a, img)`` where ``a`` is a float32 array of shape
        ``(3, imh, imw)`` scaled to [0, 1] and ``img`` is the resized RGB
        ``PIL.Image`` it was built from.
    """
    # The context manager releases the file handle promptly. Forcing RGB
    # guarantees exactly three channels, so grayscale, palette, RGBA or CMYK
    # files no longer break the transpose below or the 3-channel network.
    with Image.open(fname) as src:
        img = src.convert('RGB').resize((imw, imh))
    # HWC uint8 -> CHW float32 in [0, 1].
    a = np.asarray(img).transpose(2, 0, 1).astype(np.float32) / 255.
    return a, img

import random
import numpy as np
import torch


#https://take-tech-engineer.com/pytorch-randam-seed-fix/
def torch_fix_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic kernels.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    # Python random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # This is a function and must be called. Assigning True to the attribute
    # would only replace the function object on the torch module and leave
    # deterministic mode switched off.
    torch.use_deterministic_algorithms(True)


# Seed every RNG before any dataset, loader or model is created.
torch_fix_seed(seed=42)

In [None]:
class WhitelineDataset(Dataset):
    """In-memory dataset of resized images and their binary label masks.

    Every ``*.jpg`` in ``image_dir`` is paired with a ``<stem>_label.png``
    file next to it. Images are loaded through ``load_image``; labels are
    read as grayscale, resized to ``IMW`` x ``IMH`` and binarized.

    Attributes set by ``__init__``:
        data:  float32 array of shape ``(N, 3, IMH, IMW)``.
        label: float32 array of shape ``(N, 1, IMH, IMW)`` holding 0.0 / 1.0.
    """

    def __init__(self, image_dir, transform=None):
        """Load every image/label pair found in ``image_dir``.

        Args:
            image_dir: Directory containing the ``.jpg`` images and labels.
            transform: Accepted for API compatibility; it is not applied.

        Raises:
            FileNotFoundError: If a label image is missing or unreadable.
        """
        imw = IMW
        imh = IMH
        # os.path.join works with or without a trailing separator. Sorting
        # makes the sample order independent of filesystem listing order.
        jpgfiles = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))
        print(jpgfiles)

        # Collect per-sample arrays and stack once at the end, rather than
        # re-copying the whole dataset with np.append for every image.
        images = []
        labels = []
        for f in jpgfiles:
            a, img = load_image(f, imw, imh)
            plt.figure()
            plt.imshow(img)

            lfile = os.path.splitext(f)[0] + '_label.png'
            print(lfile)
            im_gray = cv2.imread(lfile, cv2.IMREAD_GRAYSCALE)
            if im_gray is None:
                # cv2.imread signals failure by returning None, not raising.
                raise FileNotFoundError(
                    'label image missing or unreadable: ' + lfile)
            im_gray = cv2.resize(im_gray, (imw, imh))
            # Any non-zero pixel counts as foreground.
            mask = (np.asarray(im_gray).astype(np.float32) > 0.01).astype(np.float32)
            print(np.sum(mask))

            images.append(a)
            labels.append(mask[np.newaxis])  # add the channel axis

        if images:
            self.data = np.stack(images).astype(np.float32, copy=False)
            self.label = np.stack(labels).astype(np.float32, copy=False)
        else:
            # Keep the array shapes well-defined for an empty directory.
            self.data = np.zeros((0, 3, imh, imw), dtype=np.float32)
            self.label = np.zeros((0, 1, imh, imw), dtype=np.float32)

    def __len__(self):
        """Return the number of loaded image/label pairs."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` NumPy arrays for sample(s) ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        images = self.data[idx]
        labels = self.label[idx]

        return (images, labels)


# Command-line style options, parsed from an empty list so the defaults apply
# inside the notebook. In the cells shown here only --no-cuda (below) and
# --log-interval / --dry-run (inside train) are consulted.
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
                    help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                    help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                    help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--dry-run', action='store_true', default=False,
                    help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
                    help='For Saving the current Model')
args = parser.parse_args(args=[])

# Pick the device only after parsing so that --no-cuda is actually honored.
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Alternative dataset directory: './white_image/'
dataset = WhitelineDataset('./real_data_train/')
train_loader = DataLoader(dataset, batch_size=5, shuffle=True, num_workers=0)


In [None]:
# Run one training epoch.
def train(args, model, device, dataloader, optimizer, epoch):
    """Train ``model`` for one pass over ``dataloader`` using binary cross-entropy.

    Args:
        args: Object providing ``log_interval`` (batches between log lines)
            and ``dry_run`` (stop after the first logged batch).
        model: Network whose output is compared to the targets with BCELoss.
        device: Device the batches are moved to before the forward pass.
        dataloader: Iterable of ``(data, target)`` batches; must expose
            ``dataset`` and ``__len__`` for the progress message.
        optimizer: Optimizer stepping the model's parameters.
        epoch: Epoch number, used only in the log message.
    """
    model.train()
    criterion = nn.BCELoss()
    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Everything below is logging; skip it on non-logging steps.
        if step % args.log_interval != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(dataloader.dataset),
            100. * step / len(dataloader), loss.item()))
        if args.dry_run:
            break

# Build the network on the selected device together with its optimizer.
model = MLP(16).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)

niter = 1000  # number of passes over the training data (needs tuning)

for epoch in range(niter):
    train(args, model, device, train_loader, optimizer, epoch)

# Move the weights to the CPU, then save the state dict to file.
cpu = torch.device("cpu")
model = model.to(cpu)
torch.save(model.state_dict(), "wl_model.pt")

In [None]:
# Rebuild the network and restore the trained weights from file.
model_path = "wl_model.pt"

model = MLP(16)
saved_weights = torch.load(model_path)
model.load_state_dict(saved_weights)
# Switch to evaluation mode for inference.
model.eval()

def eval_image(fname, thre):
    """Run the global ``model`` on one image, then save and show the mask.

    The image is resized to ``IMW`` x ``IMH`` and passed through the network.
    The output is binarized at ``thre`` and written to ``fname + '_out.png'``.
    Input and mask are then displayed side by side with matplotlib.

    Args:
        fname: Path of the image to evaluate.
        thre: Threshold on the network output; pixels above it become 255,
            all others 0.
    """
    import time

    imw = IMW
    imh = IMH

    a, img = load_image(fname, imw, imh)
    # Add the batch axis directly instead of appending to an empty array.
    # The array is made contiguous first so the tensor has a plain NCHW layout.
    testx = torch.from_numpy(np.ascontiguousarray(a)).unsqueeze(0)

    t0 = time.time()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        testy = model(testx)
    print('forward time [s]: ' + str(time.time()-t0))

    # .cpu() avoids relying on a module-level `cpu` device object that only
    # exists if the training cell has been executed in this kernel.
    thimg = (testy.cpu().numpy()[0][0] > thre) * 255
    print(thimg.shape)
    print(f'max {np.max(thimg)}')
    print(f'min {np.min(thimg)}')
    thimg = Image.fromarray(thimg.astype(np.uint8))
    thimg.save(fname+'_out.png')

    # Compose input (left) and mask (right) on one canvas and show it once.
    imd = Image.new('RGB', (imw*2, imh))
    imd.paste(img)
    imd.paste(thimg, (imw, 0))
    plt.imshow(imd)

In [None]:
# Evaluate a frame from real_data_test/ at threshold 0.5.
eval_image('real_data_test/20190607T203503_000026.jpg', thre=0.5)

In [None]:
# Evaluate a second frame from real_data_test/ at threshold 0.5.
eval_image('real_data_test/20190607T203649_000013.jpg', thre=0.5)

In [None]:
# Export the model as TorchScript so it can be used from C++.
a, img = load_image('real_data_test/20190607T203649_000013.jpg', IMW, IMH)
ex = a[np.newaxis, ...]  # example input with a leading batch axis for tracing
traced_script_module = torch.jit.trace(model, torch.FloatTensor(ex))
traced_script_module.save("traced_wl_model.pt")  # this file is what the C++ side calls