## Data Utilities

In [3]:
# %load data_utils.py

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
from torchvision import datasets, transforms

import numpy as np
import h5py
import re


class SequenceGenerator(data.Dataset):
    """
    Sequence Generator

    the role of SequenceGenerator is equal to ImageFolder class in pytorch.

    the X_train.h5 contains 41396 images for 57 videos.
    the  X_test.h5 contains   832 images for  3 videos.
    the   X_val.h5 contains   154 images for  1 videos.

    Args:
        - data_file:
            data path, e.g., '/media/sdb1/chenrui/kitti_data/h5/X_train.h5'
        - source_file:
            e.g., '/media/sdb1/chenrui/kitti_data/h5/sources_train.h5'
            source for each image so when creating sequences can assure that consecutive frames are from same video.
                the content is like: 'road-2011_10_03_drive_0047_sync'
        - num_timeSteps:
            number of timesteps to predict
        - seed:
            Random seeding for data shuffling.
        - shuffle:
            shuffle or not
        - output_mode:
            `error` or `prediction`
        - sequence_start_mode:
            `all` or `unique`.
            `all`: allow for any possible sequence, starting from any frame.
            `unique`: create sequences where each unique frame is in at most one sequence
        - N_seq:
            TODO
    """
    def __init__(self, data_file, source_file, num_timeSteps, shuffle = False, seed = None,
                 output_mode = 'error', sequence_start_mode = 'all', N_seq = None, data_format = 'channels_first'):
        super(SequenceGenerator, self).__init__()
        pattern = re.compile(r'.*?h5/(.+?)\.h5')
        resList = re.findall(pattern, data_file)
        print(resList)
        varName = resList[0] 
        h5f = h5py.File(data_file, 'r')
        self.X = h5f[varName][:]    # X will be like (n_images, cols, rows, channels) 
        # self.X = h5f[data_0][:]    # X will be like (n_images, cols, rows, channels) (already printed)

        resList = re.findall(pattern, source_file)
        varName = resList[0]
        source_h5f = h5py.File(source_file, 'r')
        #self.sources = source_h5f[data_0][:]   # list
        self.sources = source_h5f[varName][:]   # list

        self.num_timeSteps = num_timeSteps
        self.shuffle = shuffle
        self.seed = seed
        assert output_mode in {'error', 'prediction'}
        self.output_mode = output_mode
        assert sequence_start_mode in {'all', 'unique'}
        self.sequence_start_mode = sequence_start_mode
        self.N_seq = N_seq
        self.data_format = data_format
        if self.data_format == 'channels_first':
            self.X = np.transpose(self.X, (0, 3, 1, 2))
        self.img_shape = self.X[0].shape
        self.num_samples = self.X.shape[0]

        if self.sequence_start_mode == 'all':       # allow for any possible sequence, starting from any frame (如果视频中任意一帧都可以作为起点,只需要确定加上序列长度后的小片段终点是否还属于同一个视频即可)
            self.possible_starts = np.array([i for i in range(self.num_samples - self.num_timeSteps) if self.sources[i] == self.sources[i + self.num_timeSteps - 1]])
        elif self.sequence_start_mode == 'unique':  # create sequences where each unique frame is in at most one sequence
            curr_location = 0
            possible_starts = []
            while curr_location < self.num_samples - self.num_timeSteps + 1:
                if self.sources[curr_location] == self.sources[curr_location + self.num_timeSteps - 1]:
                    possible_starts.append(curr_location)
                    curr_location += self.num_timeSteps
                else:
                    curr_location += 1
            self.possible_starts = possible_starts

        if shuffle:
            self.possible_starts = np.random.permutation(self.possible_starts)

        if N_seq is not None and len(self.possible_starts) > N_seq:     # select a subset of sequences if want to
            self.possible_starts = self.possible_starts[:N_seq]
        self.N_sequences = len(self.possible_starts)                    # Number of all possible training segments

    def __getitem__(self, index):
        '''
        Args:
            index (int): Index

        Returns:
            tuple: (stacked images, target) where target is NOT class_index of the target class
                BUT the order of frames in sorting task.
        '''
        idx = self.possible_starts[index]
        image_group = self.preprocess(self.X[idx : (idx + self.num_timeSteps)])
        
        if self.output_mode == 'error':
            target = 0.             # model outputs errors, so y should be zeros
        elif self.output_mode == 'prediction':
            target = image_group    # output actual pixels

        return image_group, target

    def preprocess(self, X):
        return X.astype(np.float32) / 255.

    def __len__(self):
        return self.N_sequences

    def create_all(self):
        '''It is equivalent to create_all in the original code. Serves the evaluate mode and returns all test data.'''
        X_all = np.zeros((self.N_sequences, self.num_timeSteps) + self.img_shape, np.float32)
        for i, idx in enumerate(self.possible_starts):
            X_all[i] = self.preprocess(self.X[idx : (idx + self.num_timeSteps)])
        return X_all


class ZcrDataLoader(object):
    '''[DataLoader for video frame predictation]'''
    def __init__(self, data_file, source_file, output_mode, sequence_start_mode, N_seq, args):
        super(ZcrDataLoader, self).__init__()
        self.data_file = data_file
        self.source_file = source_file
        self.output_mode = output_mode
        self.sequence_start_mode = sequence_start_mode
        self.N_seq = N_seq
        self.args = args

    def dataLoader(self):
        image_dataset = SequenceGenerator(self.data_file, self.source_file, self.args.num_timeSteps, self.args.shuffle, None, self.output_mode, self.sequence_start_mode, self.N_seq, self.args.data_format)
        # NOTE: Set drop_last to True, you can delete the last incomplete batch (eg, when the data set size is not divisible by batch_size, the number of samples in the last batch is not enough for one batch_size, which may cause some to be used last time The resulting code reports an error because the old size and the new size do not match (PredNet has this problem, so drop_last is set to True here))
        dataloader = data.DataLoader(image_dataset, batch_size = self.args.batch_size, shuffle = False, num_workers = self.args.workers, drop_last = True)
        return dataloader


if __name__ == '__main__':
    pass

## Debug

In [4]:
# %load debug.py

import os
import numpy as np
import h5py

# dataDir = '../coxlab-prednet-cc76248/kitti_data/'
# trainSet_path = os.path.join(dataDir, 'X_train.hkl')
# train_sources = os.path.join(dataDir, 'sources_train.hkl')
# testSet_path = os.path.join(dataDir, 'X_test.hkl')
# test_sources = os.path.join(dataDir, 'sources_test.hkl')

# @200.121
dataDir = 'C:\\Users\\kirub\\Desktop\\Summer project 2021\\PredNet_pytorch-master\\kitti_data\\prednet_kitti_data'                 
trainSet_path = os.path.join(dataDir, 'X_train.h5')
train_sources = os.path.join(dataDir, 'sources_train.h5')
testSet_path  = os.path.join(dataDir, 'X_test.h5')
test_sources  = os.path.join(dataDir, 'sources_test.h5')

valset_path = os.path.join(dataDir, 'X_val.h5')
val_sources = os.path.join(dataDir, 'sources_val.h5')



h5f = h5py.File(testSet_path,'r')


#h5f = h5py.File(testSet_path,'r+')
#h5f['data_0']=h5file['X_test'];

#abc = h5py.File(trainSet_path,'r')#Extra clarifications for the errors obtained in train.py 
#dfe = h5py.File(train_sources,'r')#To check names for X_train, sources_train, X_val, sources_val 
#ghi = h5py.File(valset_path,'r')
#jkl = h5py.File(val_sources,'r')

#print(list(h5f.keys()))
#dset=h5f['data_0']
#print(dset.shape)
#print(dset.dtype) (already printed)

#print(list(abc.keys()))
#print(abc['data_0'].shape)
#print(list(dfe.keys()))
#print(dfe['data_0'].shape)
#print(list(ghi.keys()))
#print(ghi['data_0'].shape)
#print(list(jkl.keys()))
#print(jkl['data_0'].shape)

testSet = h5f['data_0'][:]

# print(testSet)
# print(type(testSet))    # <class 'numpy.ndarray'>
# print(testSet.shape)    # (832, 128, 160, 3)

## Evaluate model

In [5]:
# %load evaluate.py

import os
import argparse
import numpy as np

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import torch
from torch.autograd import Variable

# zcr lib
from prednet import PredNet
from data_utils import ZcrDataLoader

def arg_parse():
    desc = "Video Frames Predicting Task via PredNet."
    parser = argparse.ArgumentParser(description = desc)

    parser.add_argument('--mode', default = 'train', type = str,
                        help = 'train or evaluate (default: train)')
    parser.add_argument('--dataPath', default = '', type = str, metavar = 'PATH',
                        help = 'path to video dataset (default: none)')
    parser.add_argument('--resultsPath', default = '', type = str, metavar = 'PATH',
                        help = 'saving path to results of PredNet (default: none)')
    parser.add_argument('--checkpoint_file', default = '', type = str,
                        help = 'checkpoint file for evaluating. (default: none)')
    parser.add_argument('--batch_size', default = 32, type = int, metavar = 'N',
                        help = 'The size of batch')
    parser.add_argument('--num_plot', default = 40, type = int, metavar = 'N',
                        help = 'how many images to plot')
    parser.add_argument('--num_timeSteps', default = 10, type = int, metavar = 'N',
                        help = 'number of timesteps used for sequences in training (default: 10)')
    parser.add_argument('--workers', default = 4, type = int, metavar = 'N',
                        help = 'number of data loading workers (default: 4)')
    parser.add_argument('--shuffle', default = True, type = bool,
                        help = 'shuffle or not')
    parser.add_argument('--data_format', default = 'channels_last', type = str,
                        help = '(c, h, w) or (h, w, c)?')
    parser.add_argument('--n_channels', default = 3, type = int, metavar = 'N',
                        help = 'The number of input channels (default: 3)')
    parser.add_argument('--img_height', default = 128, type = int, metavar = 'N',
                        help = 'The height of input frame (default: 128)')
    parser.add_argument('--img_width', default = 160, type = int, metavar = 'N',
                        help = 'The width of input frame (default: 160)')
    # parser.add_argument('--stack_sizes', default = '', type = str,
    #                     help = 'Number of channels in targets (A) and predictions (Ahat) in each layer of the architecture.')
    # parser.add_argument('--R_stack_sizes', default = '', type = str,
    #                     help = 'Number of channels in the representation (R) modules.')
    # parser.add_argument('--A_filter_sizes', default = '', type = str,
    #                     help = 'Filter sizes for the target (A) modules. (except the target (A) in lowest layer (i.e., input image))')
    # parser.add_argument('--Ahat_filter_sizes', default = '', type = str,
    #                     help = 'Filter sizes for the prediction (Ahat) modules.')
    # parser.add_argument('--R_filter_sizes', default = '', type = str,
    #                     help = 'Filter sizes for the representation (R) modules.')
    
    args = parser.parse_args()
    return args

def print_args(args):
    print('-' * 50)
    for arg, content in args.__dict__.items():
        print("{}: {}".format(arg, content))
    print('-' * 50)


def evaluate(model, args):
    '''Evaluate PredNet on KITTI sequences'''
    prednet = model     # Now prednet is the testing model (to output predictions)

    DATA_DIR = args.dataPath
    RESULTS_SAVE_DIR = args.resultsPath
    test_file = os.path.join(DATA_DIR, 'X_test.h5')
    test_sources = os.path.join(DATA_DIR, 'sources_test.h5')

    output_mode = 'prediction'
    sequence_start_mode = 'unique'
    N_seq = None
    dataLoader = ZcrDataLoader(test_file, test_sources, output_mode, sequence_start_mode, N_seq, args).dataLoader()
    X_test = dataLoader.dataset.create_all()
    # print('X_test.shape', X_test.shape)       # (83, 10, 3, 128, 160)
    X_test = X_test[:8, ...]                    # to overcome `cuda runtime error: out of memory`
    batch_size = X_test.shape[0]
    X_groundTruth = np.transpose(X_test, (1, 0, 2, 3, 4))      # (timesteps, batch_size, 3, 128, 160)
    X_groundTruth_list = []
    for t in range(X_groundTruth.shape[0]):
        X_groundTruth_list.append(np.squeeze(X_groundTruth[t, ...]))    # (batch_size, 3, 128, 160)

    X_test = Variable(torch.from_numpy(X_test).float().cuda())

    if prednet.data_format == 'channels_first':
        input_shape = (batch_size, args.num_timeSteps, n_channels, img_height, img_width)
    else:
        input_shape = (batch_size, args.num_timeSteps, img_height, img_width, n_channels)
    initial_states = prednet.get_initial_states(input_shape)
    predictions = prednet(X_test, initial_states)
    # print(predictions)
    # print(predictions[0].size())    # torch.Size([8, 3, 128, 160])

    X_predict_list = [pred.data.cpu().numpy() for pred in predictions]  # length of X_predict_list is timesteps. 每个元素shape是(batch_size, 3, H, W)

    # Compare MSE of PredNet predictions vs. using last frame. Write results to prediction_scores.txt
    # MSE_PredNet  = np.mean((real_X[:, 1:  ] - pred_X[:, 1:])**2)    # look at all timesteps except the first
    # MSE_previous = np.mean((real_X[:,  :-1] - real_X[:, 1:])**2)
    # if not os.path.exists(RESULTS_SAVE_DIR):
    #     os.mkdir(RESULTS_SAVE_DIR)
    # score_file = os.path.join(RESULTS_SAVE_DIR, 'prediction_scores.txt')
    # with open(score_file, 'w') as f:
    #     f.write("PredNet MSE: %f\n" % MSE_PredNet)
    #     f.write("Previous Frame MSE: %f" % MSE_previous)

    # Plot some predictions
    if prednet.data_format == 'channels_first':
        X_groundTruth_list = [np.transpose(batch_img, (0, 2, 3, 1)) for batch_img in X_groundTruth_list]
        X_predict_list     = [np.transpose(batch_img, (0, 2, 3, 1)) for batch_img in X_predict_list]
    assert len(X_groundTruth_list) == len(X_predict_list) == args.num_timeSteps
    timesteps = args.num_timeSteps
    total_num = X_groundTruth_list[0].shape[0]
    height = X_predict_list[0].shape[1]
    width  = X_predict_list[0].shape[2]

    n_plot = args.num_plot
    if n_plot > total_num:
        n_plot = total_num
    aspect_ratio = float(height) / width
    plt.figure(figsize = (timesteps, (2 * aspect_ratio)))
    gs = gridspec.GridSpec(2, timesteps)
    gs.update(wspace = 0., hspace = 0.)
    plot_save_dir = os.path.join(RESULTS_SAVE_DIR, 'prediction_plots/')
    if not os.path.exists(plot_save_dir):
        os.mkdir(plot_save_dir)
    plot_idx = np.random.permutation(total_num)[:n_plot]
    for i in plot_idx:
        for t in range(timesteps):
            ## plot the ground truth.
            plt.subplot(gs[t])
            plt.imshow(X_groundTruth_list[t][i, ...], interpolation = 'none')
            plt.tick_params(axis = 'both', which = 'both', bottom = 'off', top = 'off', left = 'off', right = 'off', labelbottom = 'off', labelleft = 'off')
            if t == 0:
                plt.ylabel('Actual', fontsize = 10)

            ## plot the predictions.
            plt.subplot(gs[t + timesteps])
            plt.imshow(X_predict_list[t][i, ...], interpolation = 'none')
            plt.tick_params(axis = 'both', which = 'both', bottom = 'off', top = 'off', left = 'off', right = 'off', labelbottom = 'off', labelleft = 'off')
            if t == 0:
                plt.ylabel('Predicted', fontsize = 10)

        plt.savefig(plot_save_dir +  'plot_' + str(i) + '.png')
        plt.clf()
    print('The plots are saved in "%s"! Have a nice day!' % plot_save_dir)


def checkpoint_loader(checkpoint_file):
    '''load the checkpoint for weights of PredNet.'''
    print('Loading...', end = '')
    checkpoint = torch.load(checkpoint_file)
    print('Done.')
    return checkpoint

def load_pretrained_weights(model, state_dict_file):
    '''Directly use the parameters taken from the pre-trained PredNet model of the Keras version provided by the original author'''
    model = model.load_state_dict(torch.load(state_dict_file))
    print('weights loaded!')
    return model

if __name__ == '__main__':
    args = arg_parse()
    print_args(args)

    n_channels = args.n_channels
    img_height = args.img_height
    img_width  = args.img_width

    # stack_sizes       = eval(args.stack_sizes)
    # R_stack_sizes     = eval(args.R_stack_sizes)
    # A_filter_sizes    = eval(args.A_filter_sizes)
    # Ahat_filter_sizes = eval(args.Ahat_filter_sizes)
    # R_filter_sizes    = eval(args.R_filter_sizes)

    stack_sizes       = (n_channels, 48, 96, 192)
    R_stack_sizes     = stack_sizes
    A_filter_sizes    = (3, 3, 3)
    Ahat_filter_sizes = (3, 3, 3, 3)
    R_filter_sizes    = (3, 3, 3, 3)

    prednet = PredNet(stack_sizes, R_stack_sizes, A_filter_sizes, Ahat_filter_sizes, R_filter_sizes,
                      output_mode = 'prediction', data_format = args.data_format, return_sequences = True)
    print(prednet)
    prednet.cuda()

    # print('\n'.join(['%s:%s' % item for item in prednet.__dict__.items()]))
    # print(type(prednet.state_dict()))   # <class 'collections.OrderedDict'>
    # for k, v in prednet.state_dict().items():
    #     print(k, v.size())

    ## Use self-trained parameters
    checkpoint_file = args.checkpoint_file
    try:
        checkpoint = checkpoint_loader(checkpoint_file)
    except Exception:
        raise(RuntimeError('Cannot load the checkpoint file named %s!' % checkpoint_file))
    state_dict = checkpoint['state_dict']
    prednet.load_state_dict(state_dict)

    ## 直接使用作者提供的预训练参数
    # state_dict_file = './model_data_keras2/preTrained_weights_forPyTorch.pkl'
    # # prednet = load_pretrained_weights(prednet, state_dict_file)   # 这种不work... why?
    # prednet.load_state_dict(torch.load(state_dict_file))

    assert args.mode == 'evaluate'
    evaluate(prednet, args)


usage: ipykernel_launcher.py [-h] [--mode MODE] [--dataPath PATH]
                             [--resultsPath PATH]
                             [--checkpoint_file CHECKPOINT_FILE]
                             [--batch_size N] [--num_plot N]
                             [--num_timeSteps N] [--workers N]
                             [--shuffle SHUFFLE] [--data_format DATA_FORMAT]
                             [--n_channels N] [--img_height N] [--img_width N]
ipykernel_launcher.py: error: unrecognized arguments: -f C:\Users\kirub\AppData\Roaming\jupyter\runtime\kernel-6fb2f742-3858-417b-8683-b6694075ca9b.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


## Evaluate model (Shell)

In [None]:
# %load evaluate.sh
#!/bin/bash

# usage:
# 	./evaluate.sh

echo "Evaluate..."
mode='evaluate'

# @200.121
DATA_DIR='/media/sdb1/chenrui/kitti_data/h5/'
# Where results (prediction plots and evaluation file) will be saved.
RESULTS_SAVE_DIR='./kitti_results/'
checkpoint_file='./checkpoint/checkpoint_epoch1_trLoss1342.3278.pkl'	# load weights from checkpoint file for evaluating.

batch_size=10
num_plot=40		# how many images to plot.

# number of timesteps used for sequences in evaluating
num_timeSteps=10

workers=4
shuffle=false

data_format='channels_first'
n_channels=3
img_height=128
img_width=160

CUDA_VISIBLE_DEVICES=2 python evaluate.py \
	--mode ${mode} \
	--dataPath ${DATA_DIR} \
	--resultsPath ${RESULTS_SAVE_DIR} \
	--checkpoint_file ${checkpoint_file} \
	--batch_size ${batch_size} \
	--num_plot ${num_plot} \
	--num_timeSteps ${num_timeSteps} \
	--workers ${workers} \
	--shuffle ${shuffle} \
	--data_format ${data_format} \
	--n_channels ${n_channels} \
	--img_height ${img_height} \
	--img_width ${img_width}
	# --stack_sizes ${stack_sizes} \
	# --R_stack_sizes ${R_stack_sizes} \
	# --A_filter_sizes ${A_filter_sizes} \
	# --Ahat_filter_sizes ${Ahat_filter_sizes} \
	# --R_filter_sizes ${R_filter_sizes} \

## Load weights

In [6]:
# %load load_weights.py

'''Load the parameters of the original Keras version of the PredNet model saved in hdf5 into the pytorch version of the model reproduced by zcr.'''

import os
import numpy as np
import h5py

import torch
# from torch.autograd import Variable


weights_file = 'C:\\Users\\kirub\\Desktop\\Summer project 2021\\PredNet_pytorch-master\\model_data_keras2\\prednet_kitti_weights.hdf5'
weights_f = h5py.File(weights_file, 'r')

#print(list(weights_f.items()))
#G1=weights_f.get('prednet_1')
#print(list(G1.items()))
#G2=G1.get('/prednet_1/prednet_1')
#print(list(G2.items()))
#G3=G1.get('/prednet_1/prednet_1/layer_a_0')
#print(list(G3.items()))


pred_weights = weights_f['prednet_1']['prednet_1']	# contains 23 item: 4x4(i,f,c,o for 4 layers) + 4(Ahat for 4 layers) + 3(A for 4 layers)
#pred_weights = weights_f['model_weights']['pred_net_1']['pred_net_1']

keras_items = ['bias', 'kernel']
#keras_items = ['bias:0', 'kernel:0']
pytorch_items = ['weight', 'bias']


keras_modules = ['a', 'ahat', 'c', 'f', 'i', 'o']
keras_modules = ['layer_' + m + '_' + str(i) for m in keras_modules for i in range(4)]
keras_modules.remove('layer_a_3')
assert len(keras_modules) == 4 * 4 + 4 + 3

pytorch_modules_1 = ['A', 'Ahat']
pytorch_modules_2 = ['c', 'f', 'i', 'o']
pytorch_modules_1 = [m + '.' + str(2 * i) + '.' + item for m in pytorch_modules_1 for i in range(4) for item in pytorch_items]
pytorch_modules_1.remove('A.6.weight')
pytorch_modules_1.remove('A.6.bias')
pytorch_modules_2 = [m + '.' + str(i) + '.' + item for m in pytorch_modules_2 for i in range(4) for item in pytorch_items]
pytorch_modules = pytorch_modules_1 + pytorch_modules_2
assert len(pytorch_modules) == (4 * 4 + 4 + 3) * 2

weight_dict = dict()


# Loaded from the h5 file is the weight of type <class'numpy.ndarray'>, which needs to be converted to cuda.Tensor
for i in range(len(keras_modules)):
	weight_dict[pytorch_modules[i * 2 + 1]] = pred_weights[keras_modules[i]]['bias'][:]
	# weight_dict[pytorch_modules[i * 2 + 1]] = pred_weights[keras_modules[i]]['bias:0']
	weight_dict[pytorch_modules[i * 2]] = np.transpose(pred_weights[keras_modules[i]]['kernel'][:], (3, 2, 1, 0))
	# weight_dict[pytorch_modules[i * 2]] = pred_weights[keras_modules[i]]['kernel:0']

for k, v in weight_dict.items():
	# print(k, v)
	# weight_dict[k] = Variable(torch.from_numpy(v).float().cuda())
	weight_dict[k] = torch.from_numpy(v).float().cuda()

fileName = 'C:\\Users\\kirub\\Desktop\\Summer project 2021\\PredNet_pytorch-master\\model_data_keras2\\preTrained_weights_forPyTorch.pkl'
weights_gift_from_keras = torch.save(weight_dict, fileName)

## Prednet

In [7]:
# %load prednet.py

'''
PredNet in PyTorch.
'''

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def hard_sigmoid(x):
    '''
    - hard sigmoid function by zcr.
    - Computes element-wise hard sigmoid of x.
    - what is hard sigmoid?
        Segment-wise linear approximation of sigmoid. Faster than sigmoid.
        Returns 0. if x < -2.5, 1. if x > 2.5. In -2.5 <= x <= 2.5, returns 0.2 * x + 0.5.
    - See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
    '''
    slope = 0.2
    shift = 0.5
    x = (slope * x) + shift
    x = F.threshold(-x, -1, -1)
    x = F.threshold(-x, 0, 0)
    return x

def get_activationFunc(act_str):
    act = act_str.lower()
    if act == 'relu':
        # return nn.ReLU(True)
        return nn.ReLU()
    elif act == 'tanh':
        # return F.tanh
        return nn.Tanh()
    # elif act == 'hard_sigmoid':
    #     return hard_sigmoid
    else:
        raise(RuntimeError('cannot obtain the activation function named %s' % act_str))

def batch_flatten(x):
    '''
    equal to the `batch_flatten` in keras.
    x is a Variable in pytorch
    '''
    shape = [*x.size()]
    dim = np.prod(shape[1:])
    dim = int(dim)      # Without this step, dim is of type <class'numpy.int64'> and cannot be used in the view. Add this step to convert to type <class'int'>.
    return x.view(-1, dim)



class PredNet(nn.Module):
    """
    PredNet realized by zcr.
    
    Args:
        stack_sizes:
            - Number of channels in targets (A) and predictions (Ahat) in each layer of the architecture.
            - Length of stack_size (i.e. len(stack_size) and we use `num_layers` to denote it) is the number of layers in the architecture.
            - First element is the number of channels in the input.
            - e.g., (3, 16, 32) would correspond to a 3 layer architecture that takes in RGB images and
              has 16 and 32 channels in the second and third layers, respectively.
            - The value of the subscript (lay + 1) is the out_channels parameter of the lay-th convolutional layer in pytorch. For example, the above 16 corresponds to the out_channels of A and Ahat of the lay 0 layer (that is, the input layer) is 16.
        R_stack_sizes:
            - Number of channels in the representation (R) modules.
            - Length must equal length of stack_sizes, but the number of channels per layer can be different.
            - That is, the out_channels parameter of the convolutional layer in pytorch.
        A_filter_sizes:
            - Filter sizes for the target (A) modules. (except the target (A) in lowest layer (i.e., input image))
            - Has length of len(stack_sizes) - 1.
            - e.g., (3, 3) would mean that targets for layers 2 and 3 are computed by a 3x3 convolution of
              the errors (E) from the layer below (followed by max-pooling)
            - That is, the kernel_size of the convolutional layer in pytorch.
        Ahat_filter_sizes:
            - Filter sizes for the prediction (Ahat) modules.
            - Has length equal to length of stack_sizes.
            - e.g., (3, 3, 3) would mean that the predictions for each layer are computed by a 3x3 convolution
              of the representation (R) modules at each layer.
            - That is, the kernel_size of the convolutional layer in pytorch.
        R_filter_sizes:
            - Filter sizes for the representation (R) modules.
            - Has length equal to length of stack_sizes.
            - Corresponds to the filter sizes for all convolutions in the LSTM.
            - That is, the kernel_size of the convolutional layer in pytorch.

        pixel_max:
            - The maximum pixel value.
            - Used to clip the pixel-layer prediction.
        error_activation:
            - Activation function for the error (E) units.
        A_activation:
            - Activation function for the target (A) and prediction (A_hat) units.
        LSTM_activation:
            - Activation function for the cell and hidden states of the LSTM.
        LSTM_inner_activation:
            - Activation function for the gates in the LSTM.
        output_mode:
            - Either 'error', 'prediction', 'all' or layer specification (e.g., R2, see below).
            - Controls what is outputted by the PredNet.
                - if 'error':
                    The mean response of the error (E) units of each layer will be outputted.
                    That is, the output shape will be (batch_size, num_layers).
                - if 'prediction':
                    The frame prediction will be outputted.
                - if 'all':
                    The output will be the frame prediction concatenated with the mean layer errors.
                    The frame prediction is flattened before concatenation.
                    Note that nomenclature of 'all' means all TYPE of the output (i.e., `error` and `prediction`), but should not be confused with returning all of the layers of the model.
                - For returning the features of a particular layer, output_mode should be of the form unit_type + layer_number.
                    e.g., to return the features of the LSTM "representational" units in the lowest layer, output_mode should be specificied as 'R0'.
                    The possible unit types are 'R', 'Ahat', 'A', and 'E' corresponding to the 'representation', 'prediction', 'target', and 'error' units respectively.
        extrap_start_time:
            - Time step for which model will start extrapolating.
            - Starting at this time step, the prediction from the previous time step will be treated as the "actual"
        data_format:
            - 'channels_first': (channel, Height, Width)
            - 'channels_last' : (Height, Width, channel)

    """
    def __init__(self, stack_sizes, R_stack_sizes, A_filter_sizes, Ahat_filter_sizes, R_filter_sizes,
                 pixel_max = 1.0, error_activation = 'relu', A_activation = 'relu', LSTM_activation = 'tanh',
                 LSTM_inner_activation = 'hard_sigmoid', output_mode = 'error',
                 extrap_start_time = None, data_format = 'channels_last', return_sequences = False):
        super(PredNet, self).__init__()
        self.stack_sizes = stack_sizes
        self.num_layers  = len(stack_sizes)
        assert len(R_stack_sizes) == self.num_layers
        self.R_stack_sizes = R_stack_sizes
        assert len(A_filter_sizes) == self.num_layers - 1
        self.A_filter_sizes = A_filter_sizes
        assert len(Ahat_filter_sizes) == self.num_layers
        self.Ahat_filter_sizes = Ahat_filter_sizes
        assert len(R_filter_sizes) == self.num_layers
        self.R_filter_sizes = R_filter_sizes

        self.pixel_max = pixel_max
        self.error_activation = error_activation
        self.A_activation = A_activation
        self.LSTM_activation = LSTM_activation
        self.LSTM_inner_activation = LSTM_inner_activation

        default_output_modes = ['prediction', 'error', 'all']
        layer_output_modes = [layer + str(n) for n in range(self.num_layers) for layer in ['R', 'E', 'A', 'Ahat']]
        assert output_mode in default_output_modes + layer_output_modes
        self.output_mode = output_mode
        if self.output_mode in layer_output_modes:
            self.output_layer_type = self.output_mode[:-1]
            self.output_layer_NO = int(self.output_mode[-1])    # suppose the number of layers is < 10
        else:
            self.output_layer_type = None
            self.output_layer_NO = None

        self.extrap_start_time = extrap_start_time
        assert data_format in ['channels_first', 'channels_last']
        self.data_format = data_format
        if self.data_format == 'channels_first':
            self.channel_axis = -3
            self.row_axis = -2
            self.col_axis = -1
        else:
            self.channel_axis = -1
            self.row_axis = -3
            self.col_axis = -2

        self.return_sequences = return_sequences

        self.make_layers()


    def get_initial_states(self, input_shape):
        '''
        input_shape is like: (batch_size, timeSteps, Height, Width, 3)
                         or: (batch_size, timeSteps, 3, Height, Width)
        '''
        init_height = input_shape[self.row_axis]     # equal to `init_nb_rows` in original version
        init_width  = input_shape[self.col_axis]     # equal to `init_nb_cols` in original version

        base_initial_state = np.zeros(input_shape)
        non_channel_axis = -1 if self.data_format == 'channels_first' else -2
        for _ in range(2):
            base_initial_state = np.sum(base_initial_state, axis = non_channel_axis)
        base_initial_state = np.sum(base_initial_state, axis = 1)   # (batch_size, 3)

        initial_states = []
        states_to_pass = ['R', 'c', 'E']    # R is `representation`, c is Cell state in LSTM, E is `error`.
        layerNum_to_pass = {sta: self.num_layers for sta in states_to_pass}
        if self.extrap_start_time is not None:
            states_to_pass.append('Ahat')   # pass prediction in states so can use as actual for t+1 when extrapolating
            layerNum_to_pass['Ahat'] = 1

        for sta in states_to_pass:
            for lay in range(layerNum_to_pass[sta]):
                downSample_factor = 2 ** lay            # Downsampling scaling factor
                row = init_height // downSample_factor
                col = init_width  // downSample_factor
                if sta in ['R', 'c']:
                    stack_size = self.R_stack_sizes[lay]
                elif sta == 'E':
                    stack_size = self.stack_sizes[lay] * 2
                elif sta == 'Ahat':
                    stack_size = self.stack_sizes[lay]
                output_size = stack_size * row * col    # flattened size
                reducer = np.zeros((input_shape[self.channel_axis], output_size))   # (3, output_size)
                initial_state = np.dot(base_initial_state, reducer)                 # (batch_size, output_size)

                if self.data_format == 'channels_first':
                    output_shape = (-1, stack_size, row, col)
                else:
                    output_shape = (-1, row, col, stack_size)
                # initial_state = torch.from_numpy(np.reshape(initial_state, output_shape)).float().cuda()
                initial_state = Variable(torch.from_numpy(np.reshape(initial_state, output_shape)).float().cuda(), requires_grad = True)
                initial_states += [initial_state]

        if self.extrap_start_time is not None:
            # initial_states += [torch.IntTensor(1).zero_().cuda()]   # the last state will correspond to the current timestep
            initial_states += [Variable(torch.IntTensor(1).zero_().cuda())]   # the last state will correspond to the current timestep
        return initial_states


    # def compute_output_shape(self, input_shape):
    #     if self.output_mode == 'prediction':
    #         out_shape = input_shape[2:]
    #     elif self.output_mode == 'error':   # The error mode output is the error of each layer, with a scalar for each layer
    #         out_shape = (self.num_layers,)
    #     elif self.output_mode == 'all':
    #         out_shape = (np.prod(input_shape[2:]) + self.num_layers,)   # np.prod multiply the elements one by one
    #     else:
    #         if self.output_layer_type == 'R':
    #             stack_str = 'R_stack_sizes'
    #         else:
    #             stack_str = 'stack_sizes'

    #         if self.output_layer_type == 'E':
    #             stack_multi = 2
    #         else:
    #             stack_multi = 1

    #         out_stack_size = stack_multi * getattr(self, stack_str)[self.output_layer_NO]
    #         layer_out_row = input_shape[self.row_axis] / (2 ** self.output_layer_NO)
    #         layer_out_col = input_shape[self.col_axis] / (2 ** self.output_layer_NO)
    #         if self.data_format == 'channels_first':
    #             out_shape = (out_stack_size, layer_out_row, layer_out_col)
    #         else:
    #             out_shape = (layer_out_row, layer_out_col, out_stack_size)

    #         if self.return_sequences:
    #             return (input_shape[0], input_shape[1]) + out_shape    # input_shape[1] is the timesteps
    #         else:
    #             return (input_shape[0],) + out_shape


    def isNotTopestLayer(self, layerIndex):
        '''judge if the layerIndex is not the topest layer.'''
        if layerIndex < self.num_layers - 1:
            return True
        else:
            return False


    def make_layers(self):
        '''
        equal to the `build` method in original version.
        '''
        # i: input, f: forget, c: cell, o: output
        self.conv_layers = {item: [] for item in ['i', 'f', 'c', 'o', 'A', 'Ahat']}
        lstm_list = ['i', 'f', 'c', 'o']

        for item in sorted(self.conv_layers.keys()):
            for lay in range(self.num_layers):
                downSample_factor = 2 ** lay        # Downsampling scaling factor
                if item == 'Ahat':
                    in_channels = self.R_stack_sizes[lay]   # Because Ahat convolves the output of R, the number of channels input to Ahat is the number of output channels of R in the same layer.
                    self.conv_layers['Ahat'].append(nn.Conv2d(in_channels = in_channels,
                                                              out_channels = self.stack_sizes[lay],
                                                              kernel_size = self.Ahat_filter_sizes[lay],
                                                              stride = (1, 1),
                                                              padding = int((self.Ahat_filter_sizes[lay] - 1) / 2)    # the `SAME` mode (i.e.,(kernel_size - 1) / 2)
                                                              ))
                    act = 'relu' if lay == 0 else self.A_activation
                    self.conv_layers['Ahat'].append(get_activationFunc(act))

                elif item == 'A':
                    if self.isNotTopestLayer(lay):   # Here is just to control the number of layers (one less than other such as Ahat)
                        # NOTE: Here is the construction of A from the second layer (lay = 1) (because the A of the lowest layer (layer0) of the entire network is the original image (the A of layer0 can be regarded as an identity layer, that is, the input image, Output the image as it is))
                        in_channels = self.R_stack_sizes[lay] * 2   # The number of input features of the A convolutional layer (in_channels) is the number of features of the corresponding layer E. E contains two parts (Ahat-A) and (A-Ahat), so x2. [From the left picture of Fig.1 of the paper, E The output of Ahat is subtracted from A, and then spliced.)
                        self.conv_layers['A'].append(nn.Conv2d(in_channels = in_channels,
                                                               out_channels = self.stack_sizes[lay + 1],
                                                               kernel_size = self.A_filter_sizes[lay],
                                                               stride = (1, 1),
                                                               padding = int((self.A_filter_sizes[lay] - 1) / 2)    # the `SAME` mode
                                                               ))
                        self.conv_layers['A'].append(get_activationFunc(self.A_activation))

                elif item in lstm_list:     # Build the R module
                    # The number of input features of R (in_channels): the sum of the number of features of E in the same layer, R at the same time at the same time (i.e. R_t-1), and R at the same time engraved on the upper layer (i.e. R_l+1).
                    # If the R module is on the top layer, there is no R from the upper layer. Among them:
                    # -stack_sizes[lay] * 2 represents the number of channels in the same layer E (because E is obtained by splicing A and Ahat in the same layer in the channel dimension, so x2)
                    # -R_stack_sizes[lay] represents the number of R channels on the same layer at a time
                    # -R_stack_sizes[lay + 1] represents the number of channels that engrave the upper layer of R at the same time
                    in_channels = self.stack_sizes[lay] * 2 + self.R_stack_sizes[lay]
                    if self.isNotTopestLayer(lay):
                        in_channels += self.R_stack_sizes[lay + 1]
                    # for j in lstm_list:     # Serious bug! Quickly comment out... the following forward indentation 4 spaces...
                    # The non-linear activation function layer of i, f, c, and o in LSTM is implemented in forward. (Because i, f, o use hard_sigmoid function here, LSTM in Keras is hard_sigmoid by default, but you need to implement it yourself in pytorch)
                    # act = self.LSTM_activation if j == 'c' else self.LSTM_inner_activation
                    # act = get_activationFunc(act)
                    self.conv_layers[item].append(nn.Conv2d(in_channels = in_channels,
                                                         out_channels = self.R_stack_sizes[lay],
                                                         kernel_size = self.R_filter_sizes[lay],
                                                         stride = (1, 1),
                                                         padding = int((self.R_filter_sizes[lay] - 1) / 2)    # the `SAME` mode
                                                         ))

        for name, layerList in self.conv_layers.items():
            self.conv_layers[name] = nn.ModuleList(layerList)
            setattr(self, name, self.conv_layers[name])

        # see the source code in:
        #     [PyTorch]: http://pytorch.org/docs/master/_modules/torch/nn/modules/upsampling.html
        #     [Keras  ]: keras-master/keras/layers/convolution.py/`class UpSampling2D(Layer)`
        # self.upSample = nn.Upsample(size = (2, 2), mode = 'nearest')  # It's wrong! The scale_factor parameter in pytorch corresponds to the size parameter in keras.
        self.upSample = nn.Upsample(scale_factor = 2, mode = 'nearest')
        # see the source code in:
        #     [PyTorch]: http://pytorch.org/docs/master/_modules/torch/nn/modules/pooling.html#MaxPool2d
        #     [Keras  ]: keras-master/keras/layers/pooling.py/``
        # `pool_size` in Keras is equal to `kernel_size` in pytorch.
        # [TODO] padding here is not very clear. Is `0` here is the `SAME` mode in Keras?
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2, padding = 0)


    def step(self, A, states):
        '''
        This step function is equivalent to the `step` function in the original code. It is the core logic of PredNet.
        Analogous to the implementation of standard LSTM, the role of this step function is equivalent to LSTMCell, and the following forward function is equivalent to the LSTM class.

        Args:
            A: 4D tensor with the shape of (batch_size, 3, Height, Width). It is the data extracted from A_withTimeStep according to the time step.
            The form of the `initial_states` of the states and `forward` functions is exactly the same, except that the latter is the initialized PredNet state, and the states here are the PredNet parameters when calculating in timesteps.
        '''
        n = self.num_layers
        R_current = states[       :    (n)]
        c_current = states[    (n):(2 * n)]
        E_current = states[(2 * n):(3 * n)]

        if self.extrap_start_time is not None:
            timestep = states[-1]
            if timestep >= self.t_extrap:   # if past self.extrap_start_time, the previous prediction will be treated as the actual.
                A = states[-2]
            else:
                A = A

        R_list = []
        c_list = []
        E_list = []

        # Update R units starting from the top.
        for lay in reversed(range(self.num_layers)):
            inputs = [R_current[lay], E_current[lay]]   # If it is the top level, there are only two inputs for R_l: E_l^t, R_l^(t-1). That is, there are no input items for the high-level R module.
            if self.isNotTopestLayer(lay):              # If it is not the top level, R_l has three inputs: E_l^t, R_l^(t-1), R_(l+1)^t. R_up is R_(l+1)^t
                inputs.append(R_up)
            
            inputs = torch.cat(inputs, dim = self.channel_axis)
            if not isinstance(inputs, Variable):        # In the first time step, the inputs are still of Tensor type, but after going through the network, they will be of Variable type in subsequent time steps.
                inputs = Variable(inputs, requires_grad = True)

            # print(lay, type(inputs), inputs.size())   # In the right case, an example is as follows:
            # lay3: torch.Size([8, 576, 16, 20])  [576 = 384(E_l^t) + 192(R_l^(t-1))]
            # lay2: torch.Size([8, 480, 32, 40])  [480 = 192(E_l^t) +  96(R_l^(t-1)) + 192(R_(l+1)^t)]
            # lay1: torch.Size([8, 240, 64, 80])  [240 =  96(E_l^t) +  48(R_l^(t-1)) +  96(R_(l+1)^t)]
            # lay0: torch.Size([8, 57, 160, 128]) [ 57 =   6(E_l^t) +   3(R_l^(t-1)) +  48(R_(l+1)^t)]

            # see https://github.com/huggingface/torchMoji/blob/master/torchmoji/lstm.py
            in_gate     = hard_sigmoid(self.conv_layers['i'][lay](inputs))
            forget_gate = hard_sigmoid(self.conv_layers['f'][lay](inputs))
            cell_gate   = F.tanh(self.conv_layers['c'][lay](inputs))
            out_gate    = hard_sigmoid(self.conv_layers['o'][lay](inputs))

            # print(forget_gate.size())       # torch.Size([8, 192, 16, 20])
            # print(c_current[lay].size())    # torch.Size([8, 192, 16, 20])
            # print(in_gate.size())           # torch.Size([8, 192, 16, 20])
            # print(cell_gate.size())         # torch.Size([8, 192, 16, 20])
            # print(type(forget_gate))        # <class 'torch.autograd.variable.Variable'>
            # print(type(c_current[lay]))     # <class 'torch.cuda.FloatTensor'>
            # print(type(Variable(c_current[lay])))     # <class 'torch.autograd.variable.Variable'>
            # print(type(in_gate))            # <class 'torch.autograd.variable.Variable'>
            # print(type(cell_gate))          # <class 'torch.autograd.variable.Variable'>
            if not isinstance(c_current[lay], Variable):
                c_current[lay] = Variable(c_current[lay], requires_grad = True)
            c_next = (forget_gate * c_current[lay]) + (in_gate * cell_gate)     # Multiply corresponding elements
            R_next = out_gate * F.tanh(c_next)      # `R_next` here is equivalent to the hidden state in the standard LSTM.This is the representation of the video.

            c_list.insert(0, c_next)
            R_list.insert(0, R_next)

            if lay > 0:
                # R_up = self.upSample(R_next).data     # Note: What comes out here is Variable, the ones that need to be appended to the input list above are all FloatTensor, so it needs to be changed into Tensor form here, that is, add a `.data`
                R_up = self.upSample(R_next)            # NOTE:This is the reason why the loss.backward() error is reported for a long time: the error caused by mixing Tensor and Variable in torch.cat()!
                # print(R_up.size())  # lay3: torch.Size([8, 192, 32, 40])


        # Update feedforward path starting from the bottom.
        for lay in range(self.num_layers):
            Ahat = self.conv_layers['Ahat'][2 * lay](R_list[lay])   # Ahat is the convolution of R, so input the R in the same layer at the same time. Please pay attention here: each `lay` actually corresponds to two components: convolutional layer + nonlinear activation layer, so you need to use (2 * lay) to index the convolutional layer corresponding to `lay`, and (2 * lay + 1) to index the nonlinear activation function layer corresponding to `lay`. The same is true for A below.
            Ahat = self.conv_layers['Ahat'][2 * lay + 1](Ahat)      # Don't forget the nonlinear activation. The following is the same for A.
            if lay == 0:
                # Ahat = torch.min(Ahat, self.pixel_max)            # Error (representation in keras)
                Ahat[Ahat > self.pixel_max] = self.pixel_max        # passed through a saturating non-linearity set at the maximum pixel value
                frame_prediction = Ahat                             # The Ahat of the lowest layer is the predicted output frame image
                # if self.output_mode == 'prediction':
                #     break
            
            # print('&' * 10, lay)
            # print('Ahat', Ahat.size())  # torch.Size([batch_size, 3, 128, 160])
            # print('A', A.size())        # It turns out that A0 directly uses the data loaded from the dataloader, so torch.Size([batch_size, 10, 3, 128, 160]) is printed, and this is the problem: the data returned by the dataloader is (batch_size, timesteps, (image_shape)), but actually used in RNN is to separate each time step. Now decouple the core logic to form a `step` function, A0 becomes torch.Size([batch_size, 3, 128, 160]) this dimension.
            # print('&' * 20)
            
            # compute errors
            if self.error_activation.lower() == 'relu':
                E_up   = F.relu(Ahat - A)
                E_down = F.relu(A - Ahat)
            elif self.error_activation.lower() == 'tanh':
                E_up   = F.tanh(Ahat - A)
                E_down = F.tanh(A - Ahat)
            else:
                raise(RuntimeError('cannot obtain the activation function named %s' % self.error_activation))
            
            E_list.append(torch.cat((E_up, E_down), dim = self.channel_axis))

            # If you want to get the output of a specific module in a specific layer:
            if self.output_layer_NO == lay:
                if   self.output_layer_type == 'A':
                    output = A
                elif self.output_layer_type == 'Ahat':
                    output = Ahat
                elif self.output_layer_type == 'R':
                    output = R_list[lay]
                elif self.output_layer_type == 'E':
                    output = E_list[lay]

            if self.isNotTopestLayer(lay):
                A = self.conv_layers['A'][2 * lay](E_list[lay])     # After convolution + pooling on E, you get A that is engraved with a layer at the same time. If the layer is already the topmost layer, you don't need it
                A = self.conv_layers['A'][2 * lay + 1](A)           # Don't forget the nonlinear activation.
                A = self.pool(A)    # target for next layer
        

        if self.output_layer_type is None:
            if self.output_mode == 'prediction':
                output = frame_prediction
            else:
                for lay in range(self.num_layers):
                    layer_error = torch.mean(batch_flatten(E_list[lay]), dim = -1, keepdim = True)     # The batch_flatten function is implemented by zcr in accordance with the function of the same name in Kears. The 0th dimension is the batch_size dimension, and the dimensions other than this dimension are flattened
                    all_error = layer_error if lay == 0 else torch.cat((all_error, layer_error), dim = -1)
                if self.output_mode == 'error':
                    output = all_error
                else:
                    output = torch.cat((batch_flatten(frame_prediction), all_error), dim = -1)

        states = R_list + c_list + E_list
        if self.extrap_start_time is not None:
            states += [frame_prediction, (timestep + 1)]
        return output, states


    def forward(self, A0_withTimeStep, initial_states):
        '''
        A0_withTimeStep is the input from dataloader. Its shape is: (batch_size, timesteps, 3, Height, Width).
            To put it bluntly, this A0_withTimeStep is the original image loaded by the dataloader, that is, the A of the lowest layer (layer 0), but it is expanded in the two dimensions of batch_size and timestep.
        initial_states  is a list of pytorch-tensors. The states parameter is actually the initial state, because the forword function itself is not executed in a loop.

        NOTE: The purpose of this forward function is to implement the step function of the original Keras version, but it is not the same as the latter. Because the PredNet class of the original code is
              Inherited the `Recurrent` class in Keras, so it seems that the parent class implements the loading of dataloader (ie SequenceGenerator in the original code)
              The data (batch_size, timesteps, 3, H, W) is decomposed into (batch_size, 3, H, W), and then the solution is looped for timesteps.
              And the forward here needs to implement the loop timesteps by yourself. The shape of A here is the 5D tensor (batch_size, timesteps, 3, Height, Width) from the dataloader,
              The shape of the input `x` of the step function in the original code is 4D tensor (batch_size, 3, Height, Width).
        '''

        # The default is batch_fist == True, that is, the first dimension is batch_size, and the second dimension is timesteps.
        A0_withTimeStep = A0_withTimeStep.transpose(0, 1)   # (b, t, c, h, w) -> (t, b, c, h, w)

        num_timesteps = A0_withTimeStep.size()[0]

        hidden_states = initial_states    # Assigned to hidden_states is to be used painlessly in the following loop
        output_list = []                  # The output needs to be retained: In the error mode, it needs to be weighted according to the layer and timestep to get the final loss; in the prediction mode, it is necessary to output the predicted image at each time step (for example, if the timestep is 10, output 10 images)
        for t in range(num_timesteps):
            '''
                The original LSTM (or ordinary RNN) requires two cycles:
                for lay in range(num_layers):
                    for t in range(num_timesteps):
                        pass
                
But as the footnote part of the original Keras version of the code said: Although PredNet sets the number of layers, it is actually implemented using
A super layer (`super layer`) is implemented, that is, it is a layer itself. So there is no for lay loop here.
            '''
            A0 = A0_withTimeStep[t, ...]
            output, hidden_states = self.step(A0, hidden_states)
            output_list.append(output)
            # hidden_states There is no need to keep it, just let it carry out the iteration of the'Yangtze River Back Wave Push Forward Wave' type within the time step.

        if self.output_mode == 'error':
            '''Perform weighting according to layer and timestep. Different from the way of adding Dense layer in the original code, the weighting operation can be directly written in the PredNet model (in this if statement), or the error of each layer in all time steps can be returned, Calculate in the main function. zcr chooses the latter (consistent with the original code)'''
            # print(len(output_list))             # 10, The number of timesteps
            # print('output: ', output_list)      # The `error` of each time step is a matrix of (batch_size, num_layer), and the type is Variable. [torch.cuda.FloatTensor of size 8x4 (GPU 0)] According to this, weighting according to layer and timestep can be achieved. Calculation! (Two types of weighting according to the layer, you can get the so-called two types of loss of `L_0` and `L_all`)
            # print('Got the `error` list with the length of len(timeSteps) and shape of each element in this list is: (batch_size, num_layer).')
            return output_list
        elif self.output_mode == 'prediction':
            return output_list  # The output_list at this time is the timestep prediction frame image
        elif self.output_mode == 'all':
            pass
        else:
            raise(RuntimeError('Kidding? Unknown output mode!'))


if __name__ == '__main__':
    n_channels = 3
    img_height = 128
    img_width  = 160

    stack_sizes       = (n_channels, 48, 96, 192)
    R_stack_sizes     = stack_sizes
    A_filter_sizes    = (3, 3, 3)
    Ahat_filter_sizes = (3, 3, 3, 3)
    R_filter_sizes    = (3, 3, 3, 3)

    prednet = PredNet(stack_sizes, R_stack_sizes, A_filter_sizes, Ahat_filter_sizes, R_filter_sizes,
                      output_mode = 'error', return_sequences = True)



## Train model

In [8]:
# %load train.py
import traceback

import os
import numpy as np
import argparse
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import lr_scheduler

# zcr lib
from prednet import PredNet
from data_utils import ZcrDataLoader

# os.environ['CUDA_LAUNCH_BLOCKING'] = 1
# torch.backends.cudnn.benchmark = True

def arg_parse():
    desc = "Video Frames Predicting Task via PredNet."
    parser = argparse.ArgumentParser(description = desc)

    parser.add_argument('--mode', default = 'train', type = str,
                        help = 'train or evaluate (default: train)')
    parser.add_argument('--dataPath', default = '', type = str, metavar = 'PATH',
                        help = 'path to video dataset (default: none)')
    parser.add_argument('--checkpoint_savePath', default = '', type = str, metavar = 'PATH',
                        help = 'path for saving checkpoint file (default: none)')
    parser.add_argument('--epochs', default = 20, type = int, metavar='N',
                        help = 'number of total epochs to run')
    parser.add_argument('--batch_size', default = 32, type = int, metavar = 'N',
                        help = 'The size of batch')
    parser.add_argument('--optimizer', default = 'SGD', type = str,
                        help = 'which optimizer to use')
    parser.add_argument('--lr', default = 0.01, type = float,
                        metavar = 'LR', help = 'initial learning rate')
    parser.add_argument('--momentum', default = 0.9, type = float,
                        help = 'momentum for SGD')
    parser.add_argument('--beta1', default = 0.9, type = float,
                        help = 'beta1 in Adam optimizer')
    parser.add_argument('--beta2', default = 0.99, type = float,
                        help = 'beta2 in Adam optimizer')
    parser.add_argument('--workers', default = 4, type = int, metavar = 'N',
                        help = 'number of data loading workers (default: 4)')
    parser.add_argument('--checkpoint_file', default = '', type = str,
                        help = 'path to checkpoint file for restrating (default: none)')
    parser.add_argument('--printCircle', default = 100, type = int, metavar = 'N',
                        help = 'how many steps to print the loss information')
    parser.add_argument('--data_format', default = 'channels_last', type = str,
                        help = '(c, h, w) or (h, w, c)?')
    parser.add_argument('--n_channels', default = 3, type = int, metavar = 'N',
                        help = 'The number of input channels (default: 3)')
    parser.add_argument('--img_height', default = 128, type = int, metavar = 'N',
                        help = 'The height of input frame (default: 128)')
    parser.add_argument('--img_width', default = 160, type = int, metavar = 'N',
                        help = 'The width of input frame (default: 160)')
    # parser.add_argument('--stack_sizes', default = '', type = str,
    #                     help = 'Number of channels in targets (A) and predictions (Ahat) in each layer of the architecture.')
    # parser.add_argument('--R_stack_sizes', default = '', type = str,
    #                     help = 'Number of channels in the representation (R) modules.')
    # parser.add_argument('--A_filter_sizes', default = '', type = str,
    #                     help = 'Filter sizes for the target (A) modules. (except the target (A) in lowest layer (i.e., input image))')
    # parser.add_argument('--Ahat_filter_sizes', default = '', type = str,
    #                     help = 'Filter sizes for the prediction (Ahat) modules.')
    # parser.add_argument('--R_filter_sizes', default = '', type = str,
    #                     help = 'Filter sizes for the representation (R) modules.')
    parser.add_argument('--layer_loss_weightsMode', default = 'L_0', type = str,
                        help = 'L_0 or L_all for loss weights in PredNet')
    parser.add_argument('--num_timeSteps', default = 10, type = int, metavar = 'N',
                        help = 'number of timesteps used for sequences in training (default: 10)')
    parser.add_argument('--shuffle', default = True, type = bool,
                        help = 'shuffle or not')
    
    #args = parser.parse_args()
    args, unknown = parser.parse_known_args()
    return args

def print_args(args):
    print('-' * 50)
    for arg, content in args.__dict__.items():
        print("{}: {}".format(arg, content))
    print('-' * 50)

def train(model, args):
    '''Train PredNet on KITTI sequences'''
    
    # print('layer_loss_weightsMode: ', args.layer_loss_weightsMode)
    prednet = model
    # frame data files
    DATA_DIR = args.dataPath
    train_file = os.path.join(DATA_DIR, 'X_train.h5')
    train_sources = os.path.join(DATA_DIR, 'sources_train.h5')
    val_file = os.path.join(DATA_DIR, 'X_val.h5')
    val_sources = os.path.join(DATA_DIR, 'sources_val.h5')

    output_mode = 'error'
    sequence_start_mode = 'all'
    N_seq = None
    dataLoader = ZcrDataLoader(train_file, train_sources, output_mode, sequence_start_mode, N_seq, args).dataLoader()
    
    if prednet.data_format == 'channels_first':
        input_shape = (args.batch_size, args.num_timeSteps, n_channels, img_height, img_width)
    else:
        input_shape = (args.batch_size, args.num_timeSteps, img_height, img_width, n_channels)

    optimizer = torch.optim.Adam(prednet.parameters(), lr = args.lr)
    lr_maker  = lr_scheduler.StepLR(optimizer = optimizer, step_size = 75, gamma = 0.1)  # decay the lr every 50 epochs by a factor of 0.1

    printCircle = args.printCircle
    for e in range(args.epochs):
        tr_loss = 0.0
        sum_trainLoss_in_epoch = 0.0
        min_trainLoss_in_epoch = float('inf')
        startTime_epoch = time.time()
        lr_maker.step()

        initial_states = prednet.get_initial_states(input_shape)    # 原网络貌似不是stateful的, 故这里再每个epoch开始时重新初始化(如果是stateful的, 则只在全部的epoch开始时初始化一次)
        states = initial_states
        for step, (frameGroup, target) in enumerate(dataLoader):
            # print(frameGroup)   # [torch.FloatTensor of size 16x12x80x80]
            batch_frames = Variable(frameGroup.cuda())
            batch_y = Variable(target.cuda())
            output = prednet(batch_frames, states)

            # '''进行按照timestep和layer对error进行加权.'''
            ## 1. 按layer加权(巧妙利用广播. NOTE: 这里的error列表里的每个元素是Variable类型的矩阵, 需要转成numpy矩阵类型才可以用切片.)
            num_layer = len(stack_sizes)
            # weighting for each layer in final loss
            if args.layer_loss_weightsMode == 'L_0':        # e.g., [1., 0., 0., 0.]
                layer_weights = np.array([0. for _ in range(num_layer)])
                layer_weights[0] = 1.
                layer_weights = torch.from_numpy(layer_weights)
                # layer_weights = torch.from_numpy(np.array([1., 0., 0., 0.]))
            elif args.layer_loss_weightsMode == 'L_all':    # e.g., [1., 1., 1., 1.]
                layer_weights = np.array([0.1 for _ in range(num_layer)])
                layer_weights[0] = 1.
                layer_weights = torch.from_numpy(layer_weights)
                # layer_weights = torch.from_numpy(np.array([1., 0.1, 0.1, 0.1]))
            else:
                raise(RuntimeError('Unknown loss weighting mode! Please use `L_0` or `L_all`.'))
            # layer_weights = Variable(layer_weights.float().cuda(), requires_grad = False)  # NOTE: layer_weights默认是DoubleTensor, 而下面的error是FloatTensor的Variable, 如果直接相乘会报错!
            layer_weights = Variable(layer_weights.float().cuda())  # NOTE: layer_weights默认是DoubleTensor, 而下面的error是FloatTensor的Variable, 如果直接相乘会报错!
            error_list = [batch_x_numLayer__error * layer_weights for batch_x_numLayer__error in output]    # 利用广播实现加权

            ## 2. 按timestep进行加权. (paper: equally weight all timesteps except the first)
            num_timeSteps = args.num_timeSteps
            time_loss_weight  = (1. / (num_timeSteps - 1))
            time_loss_weight  = Variable(torch.from_numpy(np.array([time_loss_weight])).float().cuda())
            time_loss_weights = [time_loss_weight for _ in range(num_timeSteps - 1)]
            time_loss_weights.insert(0, Variable(torch.from_numpy(np.array([0.])).float().cuda()))

            error_list = [error_at_t.sum() for error_at_t in error_list]   # 是一个Variable的列表
            total_error = error_list[0] * time_loss_weights[0]
            for err, time_weight in zip(error_list[1:], time_loss_weights[1:]):
                total_error = total_error + err * time_weight

            loss = total_error
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # if (step + 1) == 2500:
            #     zcr_state_dict = {
            #         'epoch'     : (e + 1),
            #         'tr_loss'   : 0,
            #         'state_dict': prednet.state_dict(),
            #         'optimizer' : optimizer.state_dict()
            #     }
            #     saveCheckpoint(zcr_state_dict)

            # print('epoch: [%3d/%3d] | step: [%4d/%4d]  loss: %.4f' % ((e + 1), args.epochs, (step + 1), len(dataLoader), loss.data[0]))

            tr_loss += loss.data[0]
            sum_trainLoss_in_epoch += loss.data[0]
            if step % printCircle == (printCircle - 1):
                print('epoch: [%3d/%3d] | [%4d/%4d]  loss: %.4f  lr: %.5lf' % ((e + 1), args.epochs, (step + 1), len(dataLoader), tr_loss / printCircle, optimizer.param_groups[0]['lr']))
                tr_loss = 0.0

        endTime_epoch = time.time()
        print('Time Consumed within an epoch: %.2f (s)' % (endTime_epoch - startTime_epoch))

        if sum_trainLoss_in_epoch < min_trainLoss_in_epoch:
            min_trainLoss_in_epoch = sum_trainLoss_in_epoch
            zcr_state_dict = {
                'epoch'     : (e + 1),
                'tr_loss'   : min_trainLoss_in_epoch,
                'state_dict': prednet.state_dict(),
                'optimizer' : optimizer.state_dict()
            }
            saveCheckpoint(zcr_state_dict)


def saveCheckpoint(zcr_state_dict, fileName = 'C:\\Users\\kirub\\Desktop\\Summer project 2021\\PredNet_pytorch-master\\checkpoint\\checkpoint_newest.pkl'):
    '''save the checkpoint for both restarting and evaluating.'''
    tr_loss  = '%.4f' % zcr_state_dict['tr_loss']
    # val_loss = '%.4f' % zcr_state_dict['val_loss']
    epoch = zcr_state_dict['epoch']
    # fileName = './checkpoint/checkpoint_epoch' + str(epoch) + '_trLoss' + tr_loss + '_valLoss' + val_loss + '.pkl'
    fileName = 'C:\\Users\\kirub\\Desktop\\Summer project 2021\\PredNet_pytorch-master\\checkpoint\\PredNet\\checkpoint_epoch' + str(epoch) + '_trLoss' + tr_loss + '.pkl'
    torch.save(zcr_state_dict, fileName)



if __name__ == '__main__':
    args = arg_parse()
    print_args(args)

    # DATA_DIR = args.dataPath
    # data_file = os.path.join(DATA_DIR, 'X_test.h5')
    # source_file = os.path.join(DATA_DIR, 'sources_test.h5')
    # output_mode = 'error'
    # sequence_start_mode = 'all'
    # N_seq = None
    # dataLoader = ZcrDataLoader(data_file, source_file, output_mode, sequence_start_mode, N_seq, args).dataLoader()

    # images, target = next(iter(dataLoader))
    # print(images)
    # print(target)

    n_channels = args.n_channels
    img_height = args.img_height
    img_width  = args.img_width

    # stack_sizes       = eval(args.stack_sizes)
    # R_stack_sizes     = eval(args.R_stack_sizes)
    # A_filter_sizes    = eval(args.A_filter_sizes)
    # Ahat_filter_sizes = eval(args.Ahat_filter_sizes)
    # R_filter_sizes    = eval(args.R_filter_sizes)

    stack_sizes       = (n_channels, 48, 96, 192)
    R_stack_sizes     = stack_sizes
    A_filter_sizes    = (3, 3, 3)
    Ahat_filter_sizes = (3, 3, 3, 3)
    R_filter_sizes    = (3, 3, 3, 3)

    prednet = PredNet(stack_sizes, R_stack_sizes, A_filter_sizes, Ahat_filter_sizes, R_filter_sizes,
                      output_mode = 'error', data_format = args.data_format, return_sequences = True)
    print(prednet)
    prednet.cuda()

    assert args.mode == 'train'
    train(prednet, args)
    

    
    
    

    

    

--------------------------------------------------
mode: train
dataPath: 
checkpoint_savePath: 
epochs: 20
batch_size: 32
optimizer: SGD
lr: 0.01
momentum: 0.9
beta1: 0.9
beta2: 0.99
workers: 4
checkpoint_file: 
printCircle: 100
data_format: channels_last
n_channels: 3
img_height: 128
img_width: 160
layer_loss_weightsMode: L_0
num_timeSteps: 10
shuffle: True
--------------------------------------------------
PredNet(
  (i): ModuleList(
    (0): Conv2d(57, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(240, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(480, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Conv2d(576, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (f): ModuleList(
    (0): Conv2d(57, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(240, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(480, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3):

IndexError: list index out of range

## Train Model (Shell)

In [None]:
# %load train.sh
#!/bin/bash

# usage:
# 	./train.sh

echo "Train..."
mode='train'

# @200.121
DATA_DIR='/media/sdb1/chenrui/kitti_data/h5/'
checkpoint_savePath='./checkpoint/'
checkpoint_file='./checkpoint/'	# checkpoint file name for restarting.

epochs=1
batch_size=8
optimizer='Adam'
learning_rate=0.001
momentum=0.9
beta1=0.9
beta2=0.99

workers=4

# it is vital for restarting
checkpoint_file='./checkpoint/'
printCircle=100

data_format='channels_first'
n_channels=3
img_height=128
img_width=160

# stack_sizes="($n_channels, 48, 96, 192)"
# R_stack_sizes=$stack_sizes
# A_filter_sizes="(3, 3, 3)"
# Ahat_filter_sizes="(3, 3, 3, 3)"
# R_filter_sizes="(3, 3, 3, 3)"

layer_loss_weightsMode='L_0'
# layer_loss='L_all'

# number of timesteps used for sequences in training
num_timeSteps=10

shuffle=true

CUDA_VISIBLE_DEVICES=0 python train.py \
	--mode ${mode} \
	--dataPath ${DATA_DIR} \
	--checkpoint_savePath ${checkpoint_savePath} \
	--epochs ${epochs} \
	--batch_size ${batch_size} \
	--optimizer ${optimizer} \
	--lr ${learning_rate} \
	--momentum ${momentum} \
	--beta1 ${beta1} \
	--beta2 ${beta2} \
	--workers ${workers} \
	--checkpoint_file ${checkpoint_file} \
	--printCircle ${printCircle} \
	--data_format ${data_format} \
	--n_channels ${n_channels} \
	--img_height ${img_height} \
	--img_width ${img_width} \
	--layer_loss_weightsMode ${layer_loss_weightsMode} \
	--num_timeSteps ${num_timeSteps} \
	--shuffle ${shuffle}
	# --stack_sizes ${stack_sizes} \
	# --R_stack_sizes ${R_stack_sizes} \
	# --A_filter_sizes ${A_filter_sizes} \
	# --Ahat_filter_sizes ${Ahat_filter_sizes} \
	# --R_filter_sizes ${R_filter_sizes} \


## Viusalization

In [None]:
# %load visualization.py


'''
Usage:
    python visualization.py
'''
import sys
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')

import torch
import torch.nn.functional as F
from torch.autograd import Variable

def sortByVariance(filtersData):
    '''resort the filters by variance.'''
    sumedData = np.sum(filtersData, axis = 3)
    flat = sumedData.reshape(sumedData.shape[0], sumedData.shape[1] * sumedData.shape[2])
    std = np.std(flat, axis = 1)
    order = np.argsort(std)
    filterNum = int(order.shape[0] - (order.shape[0] % 10))     # e.g., 57——>50
    sortedData = np.zeros((filterNum,) + filtersData.shape[1:])
    for i in range(filterNum):
        sortedData[i, :, :, :] = filtersData[order[i], :, :, :]
    return sortedData

def visualize(filtersData, output_figName):
    '''
    visualize the conv1 filters
    filtersData: (filters_num, height, width, 3)
    '''
    print(output_figName)
    filtersData = np.squeeze(filtersData)
    print('after squeeze: ', filtersData.shape)     # (96, 11, 11, 3)

    # normalize filtersData for display
    filtersData = (filtersData - filtersData.min()) / (filtersData.max() - filtersData.min())
    filtersData = sortByVariance(filtersData)
    print('after sorting: ', filtersData.shape)     # (96, 11, 11, 3)

    filters_num = filtersData.shape[0]
    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(filters_num)))
    # add some space between filters
    padding = (((0, 0), (0, 1), (0, 1)) + ((0, 0),) * (filtersData.ndim - 3))   # don't pad the last dimension (if there is one)
    # padding = (((0, 64 - filters_num), (0, 1), (0, 1)) + ((0, 0),) * (filtersData.ndim - 3))   # don't pad the last dimension (if there is one)
    print(padding)  # ((0, 0), (0, 1), (0, 1), (0, 0))
    filtersData = np.pad(filtersData, padding, mode = 'constant', constant_values = 1)  # pad with ones (white)
    print('after padding: ', filtersData.shape)     # (96, 12, 12, 3)
    # tile the filters into an image
    filtersData = filtersData.reshape((5, 10) + filtersData.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, filtersData.ndim + 1)))
    print('after reshape1: ', filtersData.shape)    # (6, 12, 16, 12, 3)
    # filtersData = filtersData.reshape((8, 8) + filtersData.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, filtersData.ndim + 1)))
    filtersData = filtersData.reshape((5 * filtersData.shape[1], 10 * filtersData.shape[3]) + filtersData.shape[4:])
    print('after reshape2: ', filtersData.shape)    # (72, 192, 3)
    # filtersData = filtersData.reshape((8 * filtersData.shape[1], 8 * filtersData.shape[3]) + filtersData.shape[4:])
    
    plt.imshow(filtersData)
    plt.axis('off')
    plt.savefig(output_figName, bbox_inches = 'tight')

def get_filtersData(checkpoint_file):
    '''get the filters data from checkpoint file.'''
    checkpoint = torch.load(checkpoint_file)
    stateDict = checkpoint['state_dict']
    ## debug
    # for k, v in stateDict.items():
    #     print(k)
    conv1_filters = stateDict['feature.0.weight']
    conv1_filters = conv1_filters.cpu().numpy() # if no `.cpu()`: RuntimeError: can't convert CUDA tensor to numpy (it doesn't support GPU arrays). Use .cpu() to move the tensor to host memory first.
    conv1_filters = conv1_filters.transpose(0, 2, 3, 1)
    # print(conv1_filters.shape)  # (96, 11, 11, 12)
    return conv1_filters

def visualize_layer2(filtersData, output_figName):
    '''A.2.weight'''
    filtersData = np.squeeze(filtersData)
    print('after squeeze: ', filtersData.shape)

    # normalize filtersData for display
    filtersData = (filtersData - filtersData.min()) / (filtersData.max() - filtersData.min())

    sumedData = np.sum(filtersData, axis = 3)
    flat = sumedData.reshape(sumedData.shape[0], sumedData.shape[1] * sumedData.shape[2])
    std = np.std(flat, axis = 1)
    order = np.argsort(std)
    # filterNum = int(order.shape[0] - (order.shape[0] % 10))
    sortedData = np.zeros(filtersData.shape)
    for i in range(filtersData.shape[0]):
        sortedData[i, :, :, :] = filtersData[order[i], :, :, :]
    filtersData = sortedData
    print('after sorting: ', filtersData.shape)

    filters_num = filtersData.shape[0]
    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(filters_num)))
    # add some space between filters
    padding = (((0, 0), (0, 1), (0, 1)) + ((0, 0),) * (filtersData.ndim - 3))   # don't pad the last dimension (if there is one)
    # padding = (((0, 64 - filters_num), (0, 1), (0, 1)) + ((0, 0),) * (filtersData.ndim - 3))   # don't pad the last dimension (if there is one)
    print(padding)  # ((0, 0), (0, 1), (0, 1), (0, 0))
    filtersData = np.pad(filtersData, padding, mode = 'constant', constant_values = 1)  # pad with ones (white)
    print('after padding: ', filtersData.shape)     # (96, 12, 12, 3)
    # tile the filters into an image
    filtersData = filtersData.reshape((3, 16) + filtersData.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, filtersData.ndim + 1)))
    print('after reshape1: ', filtersData.shape)    # (6, 12, 16, 12, 3)
    # filtersData = filtersData.reshape((8, 8) + filtersData.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, filtersData.ndim + 1)))
    filtersData = filtersData.reshape((3 * filtersData.shape[1], 16 * filtersData.shape[3]) + filtersData.shape[4:])
    print('after reshape2: ', filtersData.shape)    # (72, 192, 3)
    # filtersData = filtersData.reshape((8 * filtersData.shape[1], 8 * filtersData.shape[3]) + filtersData.shape[4:])
    
    plt.imshow(filtersData)
    plt.axis('off')
    plt.savefig(output_figName, bbox_inches = 'tight')



if __name__ == '__main__':
    state_dict_file = 'C:\\Users\\kirub\\Desktop\\Summer project 2021\\PredNet_pytorch-master\\model_data_keras2\\preTrained_weights_forPyTorch.pkl'
    stateDict = torch.load(state_dict_file)
    modules = ['A', 'Ahat', 'c', 'f', 'i', 'o']
    # for m in modules:
    #     # kernel = stateDict[m + '.0.weight'].cpu().numpy()
    #     kernel = stateDict[m + '.0.weight'].cpu()
    #     # print(kernel.shape)
    #     # A: (48, 6, 3, 3)
    #     # Ahat: (3, 3, 3, 3)
    #     # c、f、i、o: (3, 57, 3, 3)
    #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 2, mode = 'nearest')
    #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 4, mode = 'nearest')
    #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 2, mode = 'bilinear')
    #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 4, mode = 'bilinear')
    #     # kernel = F.upsample(input = Variable(kernel), scale_factor = 2, mode = 'linear')  # 不行, linear只接受3D输入
    #     print(kernel.data.size())
    #     kernel = kernel.data.numpy()
    #     kernel = np.transpose(kernel, (1, 2, 3, 0))
    #     if m in ['c', 'f', 'i', 'o']:
    #         visualize(kernel, './conv1_filters/' + m + '.png')

    # kernel = stateDict['A.2.weight'].cpu()  # (96, 96, 3, 3)
    kernel = stateDict['Ahat.2.weight'].cpu()  # (48, 48, 3, 3)
    kernel = F.upsample(input = Variable(kernel), scale_factor = 4, mode = 'bilinear')
    kernel = kernel.data.numpy()
    kernel = np.transpose(kernel, (1, 2, 3, 0))[..., :3]    # orz...原来有96个'RGB通道', 无法显示成图像, 人为截取前三维
    print('before calling visualization func: ', kernel.shape)
    visualize_layer2(kernel, './conv1_filters/Ahat.2.kernel.png')


## Preprocessed KITTI data

In [56]:
#!/bin/bash
savedir="kitti_data"
!mkdir-p--"$savedir"

In [39]:
!wget https://www.dropbox.com/s/rpwlnn6j39jjme4/kitti_data.zip?dl=0 -O $savedir/prednet_kitti_data.zip --no-check-certificate

--2021-06-30 21:17:38--  https://www.dropbox.com/s/rpwlnn6j39jjme4/kitti_data.zip?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 2620:100:6031:18::a27d:5112, 162.125.81.18
Connecting to www.dropbox.com (www.dropbox.com)|2620:100:6031:18::a27d:5112|:443... connected.
  Unable to locally verify the issuer's authority.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /s/raw/rpwlnn6j39jjme4/kitti_data.zip [following]
--2021-06-30 21:17:39--  https://www.dropbox.com/s/raw/rpwlnn6j39jjme4/kitti_data.zip
Reusing existing connection to [www.dropbox.com]:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc1a2becd9268443431c961cbed1.dl.dropboxusercontent.com/cd/0/inline/BRbRuRFMIokZwa2-nmkwfM9zVIgTNr12FzrWz_JNbD3IGXTQbGNxtD9ZmS37DGhWnSznFpId5DIRcjKf7olCQNhx-YFetYl80vMUGSy3uibwe5RoWb-JMDzcJdss_oA96S6dfuXrog3EfMOvsDmqM4tU/file# [following]
--2021-06-30 21:17:39--  https://uc1a2becd9268443431c961cbed1.dl.dropboxusercontent.com/cd/0/inline/BRbRuRFMI

 41600K .......... .......... .......... .......... ..........  2% 28.3M 14m50s
 41650K .......... .......... .......... .......... ..........  2% 23.0M 14m49s
 41700K .......... .......... .......... .......... ..........  2% 1.17M 14m50s
 41750K .......... .......... .......... .......... ..........  2% 2.25M 14m50s
 41800K .......... .......... .......... .......... ..........  2% 3.69M 14m50s
 41850K .......... .......... .......... .......... ..........  2% 3.02M 14m50s
 41900K .......... .......... .......... .......... ..........  2% 4.94M 14m49s
 41950K .......... .......... .......... .......... ..........  2% 3.30M 14m49s
 42000K .......... .......... .......... .......... ..........  2% 40.7M 14m48s
 42050K .......... .......... .......... .......... ..........  2% 45.6M 14m46s
 42100K .......... .......... .......... .......... ..........  2% 40.3M 14m45s
 42150K .......... .......... .......... .......... ..........  2% 38.1M 14m44s
 42200K .......... .......... ..........