In [2]:
import os
import time

import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from torch import nn
from PIL import Image

from tensorboardX import SummaryWriter

import numpy as np
from lib import RedNet_model
from lib import RedNet_data_test as RedNet_data
from lib.utils import utils
from lib.utils.utils import save_ckpt
from lib.utils.utils import print_log
from lib.utils.utils import load_ckpt
import imageio
from torch.optim.lr_scheduler import LambdaLR

In [3]:
checkpoint_path = "lib/models/model1/ckpt_epoch_410.00.pth"
image_path = "out/image_0.png"
depth_path = "out/image_0.npy"
device = torch.device("cuda:0")
image_w = 640
image_h = 480
weight_decay = 1e-4
momentum = 0.9
lr = 2e-3

In [13]:
transform=transforms.Compose([RedNet_data.scaleNorm(),
                              RedNet_data.RandomHSV((0.9, 1.1),
                                                    (0.9, 1.1),
                                                    (25, 25)),
                              RedNet_data.ToTensor(),
                              RedNet_data.Normalize()])
model = RedNet_model.RedNet(pretrained=False)
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                            momentum=momentum, weight_decay=weight_decay)
global_step, start_epoch = load_ckpt(model, optimizer, checkpoint_path, device)

=> loading checkpoint 'lib/models/model1/ckpt_epoch_410.00.pth'
=> loaded checkpoint 'lib/models/model1/ckpt_epoch_410.00.pth' (epoch 410.0)


In [15]:
depth_raw = np.load(depth_path)
image_raw = imageio.v2.imread(image_path, pilmode='RGB')
sample = {'image':image_raw, 'depth': depth_raw}
sample = transform(sample)
image = sample['image'].reshape((1,3,480,640))
depth = sample['depth'].reshape((1,1,480,640))
image = image.to(device)
depth = depth.to(device)
print(image.shape, depth.shape)
pred_scales = model(image, depth, False)
pred_scales

torch.Size([1, 3, 480, 640]) torch.Size([1, 1, 480, 640])


(tensor([[[[ 1.0233e+01,  1.5092e+01,  1.3619e+01,  ...,  1.4025e+01,
             4.1557e+00,  9.3533e+00],
           [ 6.2937e+00,  1.0560e+01,  8.9087e+00,  ...,  9.5228e+00,
             3.5957e+00,  5.8160e+00],
           [ 5.0701e+00,  6.0989e+00,  1.1211e+01,  ...,  8.4574e+00,
             1.5749e-01,  4.5789e+00],
           ...,
           [ 1.7223e+00,  1.6987e+00,  3.8527e+00,  ...,  1.9944e+00,
             1.5214e+00,  2.5575e+00],
           [ 2.5427e+00,  4.8746e+00,  2.7340e+00,  ...,  3.4493e+00,
             2.1133e+00,  6.8259e+00],
           [ 1.1783e+00,  3.3891e+00,  1.1062e+00,  ...,  3.0241e+00,
             2.3961e+00,  4.3854e+00]],
 
          [[-9.1069e+00, -4.6165e+00, -4.5849e+00,  ..., -5.0374e+00,
            -3.8168e+00, -2.4436e+00],
           [-9.0016e+00, -5.9510e+00, -4.5201e+00,  ..., -3.3469e+00,
            -4.5130e+00,  1.8411e-01],
           [-1.6714e+00,  1.2031e+00, -4.6064e+00,  ..., -5.9614e+00,
            -4.5452e+00, -1.7874e+00],


In [28]:
label_colours = [[0, 0, 0],[148, 65, 137], [255, 116, 69], [86, 156, 137]]
colors = torch.max(pred_scales[0][:3], 1)[1] + 1
colors = colors.cpu().detach().numpy()
colors = colors.reshape(colors.shape[1:])
colors.shape

(480, 640)

In [48]:
from PIL import Image
import cv2
image_raw = imageio.v2.imread(image_path, pilmode='RGB')
image_raw = cv2.resize(image_raw,(640,480))
image_raw.shape

(480, 640, 3)

In [44]:
np.unique(colors)

array([1, 2, 3], dtype=int64)

In [49]:
image_raw = imageio.v2.imread(image_path, pilmode='RGB')
image_raw = cv2.resize(image_raw,(640,480))
new_image = image_raw
#colors.reshape(colors.shape[1:])
for i in range(480):
    for j in range(640):
        if colors[i,j] != 1:
            new_image[i,j] = label_colours[colors[i,j]]
new_image.shape

(480, 640, 3)

In [50]:
import cv2
cv2.imwrite("out/test_0.png",new_image)

True