# In [1]:
from __future__ import print_function
import torch
import torch.utils.data as data
import linecache
import os
import numpy as np
import time
import torchvision
import torch.distributed as dist 
import argparse, pickle, torch, time, os,sys
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg 
import matplotlib
import numpy as np

from PIL import Image, ImageDraw, ImageFont


class ESImagenet_Dataset(data.Dataset):
    """ES-ImageNet event samples paired with their raw source images.

    ``__getitem__`` returns ``(sample, label)`` where

    * ``sample['dvs']`` is a float tensor of shape ``[8, 2, 224, 224]``.
      Axis 1 is 0 for the ``'pos'`` events and 1 for the ``'neg'`` events
      stored in the sample's ``.npz`` file. Axis 0 is indexed by the events'
      third column minus one (presumably a 1-based time step).
    * ``sample['jpg']`` is a ``[224, 224, 3]`` float tensor holding the raw
      image (channel order as loaded by ``cv2.imread``). The image is resized
      so its longer side is 224 and centred on a zero background.
    * ``label`` is a one-element tensor holding the integer class index.

    Args:
        mode: ``'train'`` selects the train paths; any other value selects
            the "test" paths.
        data_set_path: dataset root. It is concatenated directly with file
            names, so it must end with a path separator.
    """

    def __init__(self, mode, data_set_path='/data/dvsimagenet/'):
        super().__init__()
        self.mode = mode
        self.filenames = []
        self.trainpath = data_set_path + 'train'
        # NOTE(review): the "test" split points at the train directory and
        # label file; presumably deliberate for a train-only (mini) dataset.
        self.testpath = data_set_path + 'train'
        self.traininfotxt = data_set_path + 'trainlabel.txt'
        self.testinfotxt = data_set_path + 'trainlabel.txt'
        self.formats = '.npz'
        self.find_raw = data_set_path + 'find_raw.txt'
        self.find_raw_dict = {}

        if mode == 'train':
            self.path, infotxt = self.trainpath, self.traininfotxt
        else:
            self.path, infotxt = self.testpath, self.testinfotxt

        # Each label line is "<name>.<ext> <class> <a> <b>"; only the name is
        # needed here, the rest is re-read per item in __getitem__.
        with open(infotxt, 'r') as infofile:
            for line in infofile:
                filename, _classnum, _a, _b = line.split()
                realname, _ = filename.split('.')
                self.filenames.append(realname + self.formats)

        # __getitem__ looks every sample up in find_raw_dict, so the mapping
        # must be loaded in every mode, not only for 'train'.
        with open(self.find_raw, 'r') as rawfile:
            for line in rawfile:
                labelname, jpegfile = line.split()
                self.find_raw_dict[labelname] = jpegfile

    @staticmethod
    def _load_raw_image(jpegfilename):
        """Return the image at *jpegfilename* fitted and centred in a zero [224, 224, 3] tensor."""
        image = cv2.imread(jpegfilename)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('cannot read image: ' + jpegfilename)
        h, w = image.shape[0], image.shape[1]
        scale = max(w, h) / float(224)
        # max(1, ...) keeps cv2.resize valid for extreme aspect ratios.
        new_w, new_h = max(1, int(w / scale)), max(1, int(h / scale))
        resized = torch.Tensor(np.array(cv2.resize(image, (new_w, new_h))))
        jpgtensor = torch.zeros([224, 224, 3])
        top, left = 112 - new_h // 2, 112 - new_w // 2
        jpgtensor[top:top + new_h, left:left + new_w] = resized
        return jpgtensor

    @staticmethod
    def _load_events(npzfilename, a, b):
        """Rasterise the 'pos'/'neg' events of *npzfilename* into an [8, 2, 224, 224] tensor.

        Event columns are read as (x, y, t). x is shifted by (254 - a) // 2
        and y by (254 - b) // 2 on a 256x256 grid, which is then cropped to
        its central 224x224 window.
        NOTE(review): a and b are presumably the sample's x/y extents, so the
        offsets centre it; confirm against the dataset description.
        """
        # Load the archive once and close it afterwards.
        with np.load(npzfilename) as npz:
            # int64 yields explicit integer index tensors and, if the file
            # stores a narrow integer dtype, avoids wraparound when adding
            # the offsets below.
            datapos = npz['pos'].astype(np.int64)
            dataneg = npz['neg'].astype(np.int64)

        dy = (254 - b) // 2
        dx = (254 - a) // 2
        frames = torch.zeros([8, 2, 256, 256])
        for polarity, events in enumerate((datapos, dataneg)):
            x = torch.from_numpy(events[:, 0] + dx)
            y = torch.from_numpy(events[:, 1] + dy)
            t = torch.from_numpy(events[:, 2] - 1)
            frames[t, polarity, x, y] = 1
        return frames[:, :, 16:240, 16:240]

    def __getitem__(self, index):
        infotxt = self.traininfotxt if self.mode == 'train' else self.testinfotxt
        # linecache is 1-based and caches the file after the first read.
        info = linecache.getline(infotxt, index + 1)
        filename, classnum, a, b = info.split()

        jpegfile = self.find_raw_dict[filename]
        realname, _ = filename.split('.')
        npzfilename = self.path + '/' + realname + self.formats
        jpegfilename = self.path + '_raw/' + jpegfile

        sample = {
            'dvs': self._load_events(npzfilename, int(a), int(b)),
            'jpg': self._load_raw_image(jpegfilename),
        }
        label = torch.tensor([int(classnum)])
        return sample, label

    def __len__(self):
        return len(self.filenames)
    
import matplotlib.animation as animation
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import cv2

def images_to_video(img_array, name, fps=20):
    """Write a sequence of frames to ``./<name>.avi`` using the DIVX codec.

    Args:
        img_array: non-empty sequence of ``(height, width, channels)`` frames.
            The writer is sized from the first frame, so all frames are
            expected to share its size.
        name: output file stem; the file goes to the current directory.
        fps: frame rate of the output video (default 20).

    Raises:
        ValueError: if ``img_array`` is empty.
    """
    # len() rather than truthiness so an ndarray of frames is accepted too.
    if len(img_array) == 0:
        raise ValueError('images_to_video needs at least one frame')
    height, width, _ = img_array[0].shape
    out = cv2.VideoWriter('./' + name + '.avi', cv2.VideoWriter_fourcc(*'DIVX'), fps, (width, height))
    try:
        for frame in img_array:
            out.write(frame)
    finally:
        # Release even if a write fails, so the output file is finalised.
        out.release()

batch_size = 3
data_set_path = './ES-imagenet-mini/'
train_dataset = ESImagenet_Dataset(mode='train', data_set_path=data_set_path)
# NOTE(review): num_workers > 0 starts worker processes; under the "spawn"
# start method (Windows/macOS default) this top-level code would need an
# `if __name__ == '__main__':` guard when run as a plain script.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
img_array = []     # annotated frames, 8 per processed batch
count = 0          # number of batches processed so far
max_batches = 31   # stop after this many batches (31 * 8 = 248 frames)

# IM2ES.txt lines are "<name> <label>": map the label (as a string, matching
# str(int(target)) below) to its name.
label_dir_dict = {}
# ENlabel.txt: column 1 -> column 2. NOTE(review): only the third
# whitespace-separated token is kept, so multi-word labels are truncated.
label_cn_dict = {}

with open(data_set_path + 'IM2ES.txt') as label_file:
    for line in label_file:
        n, l = line.split()
        label_dir_dict[l] = n

with open(data_set_path + 'ENlabel.txt') as label_file:
    for line in label_file:
        list_line = line.split()
        label_cn_dict[list_line[1]] = list_line[2]

imgdx = 5  # horizontal gap (px) between the panels of one row
imgdy = 5  # vertical gap (px) between rows
# Per-frame (row, column) shifts used when accumulating the 8 event frames
# into the "reconstructed" panel; only the first 8 entries are used.
tracex = [1, 0, 2, 1, 0, 2, 1, 1, 2]
tracey = [0, 2, 1, 0, 1, 2, 0, 1, 1]
font = cv2.FONT_HERSHEY_COMPLEX
# (caption, x position in tile widths) printed under each panel column.
panel_titles = (('+ channel', 0.4), ('- channel', 1.45), ('reconstructed', 2.3), ('raw image', 3.45))
# y position (in tile heights) of the class name printed beside each row.
label_rows = (0.6, 1.7, 2.8)

for inputs, targets in train_loader:
    dvsdata = inputs['dvs']
    rawdata = inputs['jpg']
    imsize = dvsdata[0].size()
    h, w = imsize[-2], imsize[-1]
    # The final batch can be smaller than batch_size when the loader runs out.
    n_rows = dvsdata.size(0)
    # One output canvas per event frame j, holding up to 3 rows of 4 panels.
    canve_bigger = torch.zeros(8, int(h * 3.6), int(h * 5), 3)
    # Column slice of each of the 4 panels inside a row.
    cols = [slice(k * (w + imgdx), k * (w + imgdx) + w) for k in range(4)]
    for i in range(n_rows):
        jpg = rawdata[i]
        img = dvsdata[i]
        # Start at 8, add +1 per positive / -1 per negative event over the 8
        # shifted frames; the 2-px border leaves room for shifts of up to 2.
        SUM = torch.full((h + 4, w + 4), 8.0)
        for j in range(8):
            dx, dy = tracex[j], tracey[j]
            region = SUM[2 - dx:2 - dx + h, 2 - dy:2 - dy + w]
            region += img[j, 0]
            region -= img[j, 1]
        SUM = SUM[2:h + 2, 2:w + 2] / 17.0 * 255

        top = h // 8 + i * (h + imgdy)
        left = w // 8
        for j in range(8):
            canve = torch.zeros(h, w * 4 + imgdx * 3, 3)
            canve[:, cols[0], 1] = img[j, 0] * 238
            canve[:, cols[0], 2] = img[j, 0] * 238
            canve[:, cols[1], 1] = img[j, 1] * 191
            canve[:, cols[1], 2] = img[j, 1] * 255
            canve[:, cols[2], :] = SUM.unsqueeze(-1)
            # jpg comes from cv2.imread (BGR); reverse channels so the canvas is RGB.
            canve[:, cols[3], :] = jpg.flip(-1)
            canve_bigger[j, top:top + h, left:left + canve.size(1), :] = canve

    # Class names depend only on the batch, so look them up once per batch.
    names = [label_cn_dict[label_dir_dict[str(int(targets[k, 0]))]] for k in range(n_rows)]

    for j in range(8):
        ime_cv = cv2.cvtColor(canve_bigger[j].numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)
        for title, x in panel_titles:
            cv2.putText(ime_cv, title, (int(h * x), int(h * 3.3)), font, 0.6, (255, 255, 255), 1)
        for name, y in zip(names, label_rows):
            cv2.putText(ime_cv, name, (int(h * 4.2), int(h * y)), font, 0.5, (255, 255, 255), 1)
        img_array.append(ime_cv)

    count += 1
    if count >= max_batches:
        break

# Write whatever was collected, including when the loader ran out before
# max_batches batches were seen.
if img_array:
    images_to_video(img_array, name='full90')

     