In [None]:
import sys
sys.path.append('..')
from video.dataset import MnistVideoCodeLMDBDataset
from video.dataloader import video_mnist_dataloader
from torchvision import utils
import numpy as np
from video.LSTM3 import LSTM3
from video.LSTM_PixelSnail import LSTM_PixelSnail
from image.modified.m_vqvae import VQVAE_1
from torch import nn
import torch
from matplotlib import pyplot as plt
from image.modified.m_pixelsnail import PixelSNAIL
from tqdm import tqdm


In [None]:
# --- Data / checkpoint identifiers -------------------------------------------
# `lambda_name` is the name handed to MnistVideoCodeLMDBDataset below;
# `vqvae_ckpt_path` is the state-dict file loaded into the VQ-VAE.
vqvae_ckpt_path = '../video/checkpoints/videomnist/vqvae/1/00099.pt'
lambda_name = 'vqvae_videomnist_2_00099'

# --- Channel widths -----------------------------------------------------------
# `input_channel` is reused as the VQ-VAE codebook size (n_embed) and as the
# number of one-hot classes. Both names are assigned again in the
# hyper-parameter cell further down.
input_channel = hidden_channel = 16

# --- Run settings -------------------------------------------------------------
device = 'cuda'
batch_size = 8
epoch_num, lr = 100, 0.0004
run_num = image_samples = 1

# --- Frame counts -------------------------------------------------------------
# `frame_learn` frames are fed to the LSTM as context; `frame_pred` is the
# number of frames the (currently disabled) rollout would predict.
frame_learn, frame_pred = 8, 15

In [None]:
# Open the pre-extracted code dataset identified by `lambda_name`.
# NOTE(review): the meaning of the second positional argument (20) is not
# visible from this notebook — presumably the number of frames per video;
# confirm against MnistVideoCodeLMDBDataset.
dataset = MnistVideoCodeLMDBDataset(lambda_name, 20)
# shuffle=False: the evaluation cell below picks a batch by its index, which
# presumably relies on a fixed iteration order — TODO confirm in
# video_mnist_dataloader.
loader = video_mnist_dataloader(dataset, batch_size, shuffle=False)


In [None]:
# Rebuild the VQ-VAE with the hyper-parameters used for the checkpoint and wrap
# it in DataParallel before loading the state dict (the wrapper is applied
# first, so the checkpoint is loaded through it).
vqvae_model = VQVAE_1(
    in_channel=1,
    channel=32,
    n_res_block=4,
    n_res_channel=16,
    embed_dim=16,
    n_embed=input_channel,
    decay=0.99,
)
vqvae_model = nn.DataParallel(vqvae_model)

# map_location='cpu' makes loading independent of the GPU index the checkpoint
# was saved from; the weights are moved to `device` right after.
vqvae_state = torch.load(vqvae_ckpt_path, map_location='cpu')
vqvae_model.load_state_dict(vqvae_state)

# The VQ-VAE is only used for decoding in this notebook, so switch it to
# inference mode once here instead of relying on individual helpers to do it.
vqvae_model = vqvae_model.to(device).eval()



In [None]:
# Raw Moving-MNIST test sequences stored as a .npy array.
videomnist_path = '../video/datasets/mnist/moving_mnist/mnist_test_seq.npy'

# Load, swap the first two axes and cast to float32 in one chain
# (astype returns a copy, so the in-place binarisation below is safe).
orginal_frames = np.load(videomnist_path).swapaxes(0, 1).astype(np.float32)

# Binarise: every strictly positive value becomes 1.0, the rest is left as is.
np.putmask(orginal_frames, orginal_frames > 0, 1.)


In [None]:
# --- Latent grid --------------------------------------------------------------
# `input_size` is used as the PixelSNAIL `shape`, `input_channel` as its number
# of classes and as the LSTM input width.
input_size = (16, 16)
input_channel = 16
device = 'cuda'

# --- Channel widths -----------------------------------------------------------
hidden_channel = cnn_channel = channel = cond_res_channel = 256
# NOTE(review): every other width here is 256; 246 may be a typo, but the
# checkpoint loaded below may have been trained with it — confirm before changing.
n_res_channel = 246

# --- Kernel sizes / depths ----------------------------------------------------
cnn_kernel_size = kernel_size = 5
n_block = n_res_block = n_out_res_block = n_cond_res_block = 3

# --- Regularisation -----------------------------------------------------------
dropout = 0.1


In [None]:
# Recurrent part: consumes one-hot code frames (input_channel wide) and emits
# hidden_channel-wide feature maps.
lstm_model = LSTM3(
    input_channel=input_channel,
    hidden_channel=hidden_channel,
    device=device,
)

# Convolution mapping the LSTM output to the PixelSNAIL condition.
# stride 1 with padding = kernel // 2 keeps the spatial size for an odd kernel.
cnn_model = nn.Conv2d(
    in_channels=hidden_channel,
    out_channels=cnn_channel,
    kernel_size=cnn_kernel_size,
    stride=1,
    padding=cnn_kernel_size // 2,
)

# Autoregressive prior over the code grid, conditioned on the CNN features.
pixel_model = PixelSNAIL(
    shape=list(input_size[:2]),
    n_class=input_channel,
    cond_channel=cnn_channel,
    channel=channel,
    kernel_size=kernel_size,
    n_block=n_block,
    n_res_block=n_res_block,
    res_channel=n_res_channel,
    dropout=dropout,
    n_out_res_block=n_out_res_block,
    n_cond_res_block=n_cond_res_block,
    cond_res_channel=cond_res_channel,
)

In [None]:
lstmpixelsnail_ckpt_path = '../video/checkpoints/videomnist/vqvae-lstm-pixelsnail/1/00013.pt'

# Combine the three sub-models and wrap them in DataParallel before loading,
# so the checkpoint is loaded through the wrapper.
lstmpixelsnail_model = LSTM_PixelSnail(lstm_model, cnn_model, pixel_model)
lstmpixelsnail_model = nn.DataParallel(lstmpixelsnail_model)

# map_location='cpu' makes loading independent of the GPU index the checkpoint
# was saved from; the weights are moved to `device` right after.
lstmpixelsnail_state = torch.load(lstmpixelsnail_ckpt_path, map_location='cpu')
lstmpixelsnail_model.load_state_dict(lstmpixelsnail_state)

lstmpixelsnail_model = lstmpixelsnail_model.to(device)



In [None]:
def get_vqvae_decode(sample):
    """Decode VQ-VAE code indices into binarised frames on the CPU.

    Args:
        sample: tensor of code indices accepted by
            ``vqvae_model.module.decode_code``.

    Returns:
        A detached CPU float tensor with values in {0., 1.} (decoder output
        thresholded at 0.5) and the channel axis removed.

    Note:
        Only axis 1 is squeezed. A bare ``squeeze()`` would also drop a size-1
        batch axis, which breaks the ``[0, :, :]`` indexing done by the callers.
        This assumes the decoder returns (batch, 1, H, W) — TODO confirm against
        VQVAE_1.decode_code.
    """
    decoded = vqvae_model.module.decode_code(sample)
    decoded = decoded.detach().cpu()
    decoded = decoded.squeeze(1)
    return (decoded > 0.5).float()

In [None]:
def show_sample(sample):
    """Decode `sample` with the VQ-VAE and display the first decoded image."""
    decoded = get_vqvae_decode(sample)
    first_image = decoded[0, :, :]
    plt.imshow(first_image)
    plt.show()


In [None]:
def model_learn(lstm_model, cnn_model, pixel_model, inputs, cells_state=None):
    """Run the LSTM over every time step of `inputs`, carrying the state along.

    Args:
        lstm_model: callable taking ``(frame, state)`` and returning
            ``(output, new_state)``.
        cnn_model: unused here; kept for signature compatibility.
        pixel_model: unused here; kept for signature compatibility.
        inputs: 5-D tensor indexed as ``inputs[:, t, :, :, :]`` per time step.
        cells_state: initial recurrent state (``None`` on the first call).

    Returns:
        ``(outs, cells_state)``: the list of per-step LSTM outputs and the
        state after the last step.
    """
    outs = []
    for step in range(inputs.size(1)):
        frame = inputs[:, step, :, :, :]
        step_out, cells_state = lstm_model(frame, cells_state)
        outs.append(step_out)
    return outs, cells_state

In [None]:
def _to_one_hot(y, num_classes):
    """One-hot encode an integer tensor and move the class axis to position 1.

    Args:
        y: 3-D integer tensor of class indices (the final permute requires the
            encoded result to be 4-D).
        num_classes: size of the class axis.

    Returns:
        Tensor of shape ``(y.shape[0], num_classes, y.shape[1], y.shape[2])``
        with the same dtype as `y`.
    """
    class_axis = y.dim()
    index = y.view(*y.shape, -1)
    canvas = torch.zeros(*y.shape, num_classes, dtype=y.dtype)
    encoded = canvas.scatter(class_axis, index, 1)
    # (N, H, W, C) -> (N, C, H, W)
    return encoded.permute(0, 3, 1, 2)


In [None]:
def one_hot_to_int(y):
    """Collapse the class axis (dim 1) of a 4-D tensor into integer indices.

    Prints the input size, then returns the argmax over the class axis with
    shape ``(N, H, W)``.
    """
    print('y {}'.format(y.size()))
    # Move classes last, then take the winning class per position.
    return y.permute(0, 2, 3, 1).argmax(dim=-1)


In [None]:
def callback(sample, frame):
    """Decode a predicted code map and its ground truth; display the prediction.

    Args:
        sample: predicted code indices, decoded and shown.
        frame: ground-truth code indices; decoded as well, but the result is
            currently not displayed.
    """
    with torch.no_grad():
        vqvae_model.eval()
        predicted = get_vqvae_decode(sample)
        # The ground truth is still decoded (as before) although nothing
        # consumes the result yet. TODO: show it next to the prediction.
        get_vqvae_decode(frame)
        plt.imshow(predicted[0, :, :])
        plt.show()


In [None]:
def get_pixel_snail_out(cnn_out, size, temperature=1.0):
    """Greedily decode a one-hot code map with the global `pixel_model`.

    Positions are filled in raster order. At each position the model is run on
    the rows generated so far and the most probable class is written back as a
    one-hot vector.

    Args:
        cnn_out: conditioning tensor passed to `pixel_model` as ``condition``.
        size: 4-element size ``(N, C, H, W)`` of the map to generate; ``C`` is
            the number of classes.
        temperature: divisor applied to the logits before the softmax.

    Returns:
        Float tensor of shape `size` holding one-hot vectors along dim 1,
        allocated on the same device as `cnn_out`.
    """
    cache = {}
    # Allocate directly on the condition's device instead of building the
    # buffer on the CPU and copying it to a hard-coded 'cuda'.
    row = torch.zeros(*size, device=cnn_out.device)
    for i in tqdm(range(size[2])):
        for j in range(size[3]):
            # Only rows 0..i are passed in; `cache` is handed back to the model
            # on every call.
            out, cache = pixel_model(row[:, :, : i + 1, :], condition=cnn_out, cache=cache)
            prob = torch.softmax(out[:, :, i, j] / temperature, 1)
            sample = torch.argmax(prob, 1, keepdim=False)
            row[:, :, i, j] = cuda_to_one_hot(sample, size[1]).float()
    return row

In [None]:


def cuda_to_one_hot(y, num_classes):
    """One-hot encode `y` along a new trailing axis, on `y`'s own device.

    The name is kept for existing callers. The buffer used to be created on a
    hard-coded 'cuda' device; it now follows `y.device`, which gives the same
    result for CUDA inputs and also works on the CPU.

    Args:
        y: integer tensor of class indices (any shape).
        num_classes: size of the new trailing class axis.

    Returns:
        Tensor of shape ``(*y.shape, num_classes)`` with `y`'s dtype and device.
    """
    class_axis = y.dim()
    index = y.unsqueeze(-1)
    zeros = torch.zeros(*y.shape, num_classes, dtype=y.dtype, device=y.device)
    return zeros.scatter(class_axis, index, 1)

def get_sample(lstm_model, cnn_model, pixel_model, input_, cells_state, temperature=1.0):
    """Advance the LSTM by one frame and decode the next code map.

    Args:
        lstm_model: callable ``(frame, state) -> (output, new_state)``.
        cnn_model: callable mapping the LSTM output to the PixelSNAIL condition.
        pixel_model: unused here — `get_pixel_snail_out` reads the global
            `pixel_model`; kept for signature compatibility.
        input_: current frame; its ``size()`` is used as the size of the
            generated map.
        cells_state: recurrent state from the previous step.
        temperature: forwarded to `get_pixel_snail_out`.

    Returns:
        ``(row, cells_state)``: the generated one-hot code map and the updated
        recurrent state.
    """
    size = input_.size()

    lstm_out, cells_state = lstm_model(input_, cells_state)
    cnn_out = cnn_model(lstm_out)

    # Previously the result was bound to `raw` while `row` was returned, which
    # raised NameError; `temperature` was also silently dropped.
    row = get_pixel_snail_out(cnn_out, size, temperature)

    return row, cells_state


In [None]:
def visual(outs, ins):
    """Decode and display one image per LSTM output.

    Args:
        outs: iterable of LSTM outputs, each fed through the global `cnn_model`.
        ins: 5-D input tensor; the size of ``ins[:, 0]`` is used as the size of
            every generated code map.
    """
    frame_size = ins[:, 0, :, :, :].size()
    for net in (cnn_model, pixel_model):
        net.eval()

    with torch.no_grad():
        for hidden in outs:
            condition = cnn_model(hidden)
            one_hot_codes = get_pixel_snail_out(condition, frame_size)
            decoded = get_vqvae_decode(one_hot_to_int(one_hot_codes))
            plt.imshow(decoded[0, :, :])
            plt.show()


In [None]:
# Index of the loader batch to visualise.
target_batch = 10

# Inference only: put the combined model and its parts in eval mode.
lstmpixelsnail_model.eval()
lstm_model.eval()
cnn_model.eval()
pixel_model.eval()

preds = []
torch.backends.cudnn.enabled = False

with torch.no_grad():
    for iter_, (frames, video_inds, frame_inds) in enumerate(loader):
        if iter_ != target_batch:
            continue

        # Prepend an all-zero frame so that inputs_[:, t + 1] is the one-hot
        # encoding of frames[:, t].
        f0 = torch.zeros(
            frames.shape[0], 1, input_channel, frames.shape[2], frames.shape[3],
            device=device,
        )
        inputs_ = [f0]
        for i in range(frames.shape[1]):
            input_ = _to_one_hot(frames[:, i, :, :], input_channel).float()
            input_ = input_.to(device)
            inputs_.append(input_.unsqueeze(dim=1))
        inputs_ = torch.cat(inputs_, dim=1)

        # Feed the first `frame_learn` frames to the LSTM and visualise what
        # the model decodes from each hidden output.
        outs, states = model_learn(
            lstm_model, cnn_model, pixel_model, inputs_[:, :frame_learn, :, :, :]
        )
        visual(outs, inputs_)

        # NOTE: the autoregressive rollout (get_sample fed with its own
        # predictions for `frame_pred` steps, accumulated into `samples`) is
        # disabled. Later cells that read `samples` need it re-enabled.

        # Stop here. Without the break the loop kept iterating the whole
        # loader, leaving `frames` bound to the last batch while `inputs_`
        # belonged to `target_batch`.
        break
                   

In [None]:
# `samples` is only produced by the autoregressive rollout, which is currently
# disabled in the evaluation cell; skip instead of failing with NameError so the
# notebook still runs top to bottom.
if 'samples' in globals():
    torch.backends.cudnn.enabled = False
    sample_int = one_hot_to_int(samples[0, :, :, :])
    # Ground-truth codes for the predicted range, moved to the model's device
    # so the VQ-VAE can decode them.
    truth_codes = frames[0, frame_learn:frame_pred + frame_learn, :, :].to(device)
    # callback takes exactly (sample, frame); the former third argument
    # (frame_pred) raised TypeError.
    callback(sample_int, truth_codes)
else:
    print('`samples` is not defined - enable the prediction rollout first.')

In [None]:
# `samples` only exists once the (currently disabled) prediction rollout has
# run; guard the lookup so a fresh Run All does not stop here with NameError.
if 'samples' in globals():
    print(samples[0, :, :, :].shape)
else:
    print('`samples` is not defined - enable the prediction rollout first.')

In [None]:
_input = inputs_[:,1,:,:,:]
# _input =_input.to('cuda')
# one = _to_one_hot(_input,16)
rel = one_hot_to_int(_input)
print(rel)
print(frames[:, 0, :, :])
# print(_input.shape)
# for i in range(_input.shape[1]):
#     print(_input[:,i,:,:])

In [None]:
# Scratch check of argmax semantics: with keepdim=True the reduced axis is kept
# with size 1, so the result below has shape (1, 4).
a = torch.randn((4, 4))
print(a)
a.argmax(dim=0, keepdim=True)