In [16]:
import os
import datetime
import shutil
import logging
import yaml
import importlib
import time
from path import Path
from abc import ABC, abstractmethod
from PIL import Image as Im
import numpy as np
import torch.nn.functional as F
import tqdm

from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import cm

import torch
from tensorboardX import SummaryWriter

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

import dataloader
from dataloader import aachen_loader

from feature_descriptors import backbone
from feature_descriptors import detection_net
# from feature_descriptors import my_model

from tqdm import tqdm
import cv2
import copy
import matplotlib
import matplotlib.pyplot as plt

In [17]:
dataloader = aachen_loader.Aachen_Day_Night()

In [18]:
dataloader

<dataloader.aachen_loader.Aachen_Day_Night at 0x7fabe076d2e0>

In [19]:
len(dataloader)

7712

In [20]:
net1 = backbone.ResUNet_F2R().to("cuda")
net2 = detection_net.DetNet(net1, 128).to("cuda")

In [21]:
def opencv_rainbow(resolution=1000):
    # Construct the opencv equivalent of Rainbow
    opencv_rainbow_data = (
        (0.000, (1.00, 0.00, 0.00)),
        (0.400, (1.00, 1.00, 0.00)),
        (0.600, (0.00, 1.00, 0.00)),
        (0.800, (0.00, 0.00, 1.00)),
        (1.000, (0.60, 0.00, 1.00))
    )

    return LinearSegmentedColormap.from_list('opencv_rainbow', opencv_rainbow_data, resolution)

In [22]:
def high_res_colormap(low_res_cmap, resolution=1000, max_value=1):
    # Construct the list colormap, with interpolated values for higer resolution
    # For a linear segmented colormap, you can just specify the number of point in
    # cm.get_cmap(name, lutsize) with the parameter lutsize
    x = np.linspace(0, 1, low_res_cmap.N)
    low_res = low_res_cmap(x)
    new_x = np.linspace(0, max_value, resolution)
    high_res = np.stack([np.interp(new_x, x, low_res[:, i]) for i in range(low_res.shape[1])], axis=1)
    return ListedColormap(high_res)

In [23]:
COLORMAPS = {'rainbow': opencv_rainbow(),
             'magma': high_res_colormap(cm.get_cmap('magma')),
             'bone': cm.get_cmap('bone', 10000)}

  'magma': high_res_colormap(cm.get_cmap('magma')),
  'bone': cm.get_cmap('bone', 10000)}


In [24]:
def tensor2array(tensor, max_value=None, colormap='rainbow'):
    tensor = tensor.detach().cpu()
    if max_value is None:
        max_value = tensor[tensor < np.inf].max().item()
    if tensor.ndimension() == 2 or tensor.size(0) == 1:
        norm_array = tensor.squeeze().numpy()/max_value
        norm_array[norm_array == np.inf] = np.nan
        array = COLORMAPS[colormap](norm_array).astype(np.float32)
        array = array.transpose(2, 0, 1)[:3]

    elif tensor.ndimension() == 3:
        assert(tensor.size(0) == 3)
        array = 0.5 + tensor.numpy()*0.5
    return array

In [27]:
def extractor():
    bar = tqdm(dataloader, total=int(len(dataloader)), ncols=80)
    color = np.array(range(256)).astype(np.float64)[None,:].repeat(30, axis=0)
    color = np.concatenate([np.zeros((30,20)),255*np.ones((30,20)),color], axis=1)
    color = tensor2array(torch.tensor(color))[:3,:,:].transpose(1,2,0)
    color = Im.fromarray((255*color).astype(np.uint8))
    color.save('img_file/0_colorbar.jpg')
    name_list = ''
    
    for idx, inputs in enumerate(bar):
        for key, val in inputs.items():
            if key == 'name1' or key == 'pad1':
                continue
            inputs[key] = val.to("cuda")
        message = inputs['name1'][0]
        print(inputs['im1'].shape)
        print(inputs['im1'].dim())

        # batch_size가 10인 경우
        # batch_size = 10
        # batch 차원 추가하여 4D로 변환
        batch_size = 1
        input_data_4d = torch.unsqueeze(inputs['im1'], dim=0).expand(batch_size, -1, -1, -1)
        
        
        outputs = net1(input_data_4d)

        
        
        # processed = self.process(inputs, outputs)
        # if self.config['output_desc']:
        #     message += self.save_desc(inputs, outputs, processed)
        # if self.config['output_img']:
        #     message += self.save_imgs(inputs, outputs, processed, idx)
        # self.logger.info(message)
        # name_list += '{} {}\n'.format(idx, inputs['name1'][0])
        torch.cuda.empty_cache()
    # with open(self.img_root/'name_list.txt', 'w') as f:
    #     f.write(name_list)

In [28]:
extractor()

  0%|                                                  | 0/7712 [00:00<?, ?it/s]

torch.Size([3, 1056, 1600])
3


  0%|                                        | 2/7712 [00:01<1:26:00,  1.49it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                        | 3/7712 [00:01<1:19:22,  1.62it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                        | 4/7712 [00:02<1:16:16,  1.68it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                        | 5/7712 [00:03<1:14:20,  1.73it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                        | 6/7712 [00:03<1:12:24,  1.77it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                        | 7/7712 [00:04<1:11:26,  1.80it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                        | 8/7712 [00:04<1:09:27,  1.85it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                        | 9/7712 [00:05<1:12:24,  1.77it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 10/7712 [00:05<1:10:35,  1.82it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 11/7712 [00:06<1:08:52,  1.86it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 12/7712 [00:06<1:08:08,  1.88it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 13/7712 [00:07<1:07:31,  1.90it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 14/7712 [00:07<1:06:07,  1.94it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 15/7712 [00:08<1:05:26,  1.96it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 16/7712 [00:08<1:04:48,  1.98it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 17/7712 [00:09<1:04:35,  1.99it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 18/7712 [00:09<1:04:32,  1.99it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 19/7712 [00:10<1:10:42,  1.81it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 20/7712 [00:10<1:10:02,  1.83it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 21/7712 [00:11<1:08:58,  1.86it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 22/7712 [00:12<1:08:30,  1.87it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 23/7712 [00:12<1:07:41,  1.89it/s]

torch.Size([3, 1600, 1056])
3


  0%|                                       | 24/7712 [00:13<1:06:25,  1.93it/s]

torch.Size([3, 1600, 1056])
3


  0%|▏                                      | 25/7712 [00:13<1:05:46,  1.95it/s]

torch.Size([3, 1600, 1056])
3


  0%|▏                                      | 26/7712 [00:14<1:05:33,  1.95it/s]

torch.Size([3, 1600, 1056])
3


  0%|▏                                      | 27/7712 [00:14<1:05:56,  1.94it/s]

torch.Size([3, 1600, 1056])
3


  0%|▏                                      | 28/7712 [00:15<1:06:06,  1.94it/s]

torch.Size([3, 1600, 1056])
3


  0%|▏                                      | 29/7712 [00:15<1:11:08,  1.80it/s]

torch.Size([3, 1600, 1056])
3


  0%|▏                                      | 30/7712 [00:16<1:11:18,  1.80it/s]

torch.Size([3, 1600, 1056])
3


  0%|▏                                      | 31/7712 [00:16<1:09:44,  1.84it/s]

torch.Size([3, 1600, 1056])
3


  0%|▏                                      | 31/7712 [00:17<1:11:03,  1.80it/s]


KeyboardInterrupt: 