In [9]:
import argparse
import math
import random
import shutil
import sys
import os
from collections import defaultdict
from typing import List
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from UVG1 import UVG
from matplotlib import pyplot as plt
import torch.nn.functional as F
import matplotlib.patches as patches
import glob
from ssf_model import ScaleSpaceFlow
#from hific.src.model import Model
import numpy as np
def load_ssf_model(model, pre_path):
    """Load the five ScaleSpaceFlow sub-module checkpoints stored in ``pre_path``.

    The directory must contain the state-dict files ``m_enc.pth``, ``m_dec.pth``,
    ``p_enc.pth``, ``r_enc.pth`` and ``r_dec.pth``. Each one is loaded into the
    matching sub-module attribute of ``model``.

    Args:
        model: module exposing ``motion_encoder``, ``motion_decoder``,
            ``P_encoder``, ``res_encoder`` and ``res_decoder`` sub-modules.
        pre_path: checkpoint directory (a trailing slash is optional).

    Returns:
        The same ``model`` object (updated in place), so calls can be chained.
    """
    checkpoint_files = {
        'motion_encoder': 'm_enc.pth',
        'motion_decoder': 'm_dec.pth',
        'P_encoder': 'p_enc.pth',
        'res_encoder': 'r_enc.pth',
        'res_decoder': 'r_dec.pth',
    }
    for attr, filename in checkpoint_files.items():
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors onto whatever device the
        # module's parameters already live on.
        state = torch.load(os.path.join(pre_path, filename), map_location='cpu')
        getattr(model, attr).load_state_dict(state)
    return model

def hwc_tonp(x):
    """Convert a 4-D torch tensor to a channels-last NumPy array.

    Axes are reordered (0, 1, 2, 3) -> (0, 2, 3, 1), i.e. NCHW -> NHWC assuming
    the input is laid out as (batch, channels, height, width). The tensor is
    detached from autograd and moved to host memory first, so CUDA tensors and
    tensors that require grad are accepted.
    """
    host_array = x.detach().cpu().numpy()
    return np.transpose(host_array, (0, 2, 3, 1))

# Use the GPU when present, otherwise fall back to the CPU. is_available must be
# *called*: the bare function object is always truthy, which previously forced
# 'cuda' even on CPU-only machines (and the fallback named an undefined `cpu`).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from torchvision.transforms.functional import resize
!nvidia-smi

Tue Aug  6 17:29:43 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A40          On   | 00000000:D8:00.0 Off |                    0 |
|  0%   45C    P0    91W / 300W |  15720MiB / 46068MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [10]:
# Configuration: how many triplets the export loops process, and the root
# folder the fine-tuning comparison arrays are written to.
sequence_number = 1000
save_path = '/scratch/ssd004/scratch/joaodick/rebuttal_data/UVG_FT/'

In [11]:
# Per-sample preprocessing: ToTensor, then a random 256x256 spatial crop.
# NOTE(review): how UVG applies this to the frames of one item is defined in
# UVG1 and not visible here.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(256),
])

# One item per batch, shuffled; num_workers=0 keeps loading in this process.
uvg_dataset = UVG("./data/uvg/", train_transforms)
uvg_dataloader = DataLoader(
    dataset=uvg_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

In [18]:
# Value embedded in the AR checkpoint folder name (AR_0.08); presumably the
# rate-distortion lambda that model was trained with — TODO confirm.
l_AR = 0.08

# Jointly trained model: a single state dict for the whole network.
# map_location lets a GPU-saved checkpoint load on a CPU-only host.
ssf_JD = ScaleSpaceFlow().to(device)
ssf_JD.load_state_dict(
    torch.load('./saved_models/vimeo-90k/JD/ssf_uvg_JD.pth', map_location=device)
)
# These models are used only for inference below, so put them in eval mode:
# any train/eval-dependent layers (e.g. noisy vs. rounded quantization in
# learned codecs) must behave as at test time. nn.Module.eval() returns self.
ssf_JD.eval()

# AR model and its fine-tuned variants: stored as per-sub-module state dicts.
ssf_AR = load_ssf_model(ScaleSpaceFlow().to(device), f'./saved_models/vimeo-90k/AR_{l_AR}/').eval()
ssf_AR_FT_1 = load_ssf_model(ScaleSpaceFlow().to(device), './saved_models/vimeo-90k/AR_FT_1_epoch_MSE/').eval()
ssf_AR_FT_2 = load_ssf_model(ScaleSpaceFlow().to(device), './saved_models/vimeo-90k/AR_FT_2_epochs_MSE/').eval()

In [20]:
def _to_img(x):
    """Turn a model-domain batch into a 128x128 channels-last image array.

    Resizes the spatial dims to 128x128, maps [-1, 1] back to [0, 1] (no
    clipping, so reconstructions that overshoot stay out of range) and returns
    the first batch item as an (H, W, C) NumPy array via hwc_tonp.
    """
    return hwc_tonp((resize(x, (128, 128)) + 1) * 0.5)[0]


def _save_frames(folder, i, frames):
    """Write ``frames`` as x1.npy, x2.npy, ... into ``save_path/folder/i/``."""
    out_dir = os.path.join(save_path, folder, str(i))
    os.makedirs(out_dir, exist_ok=True)
    for k, frame in enumerate(frames, start=1):
        np.save(os.path.join(out_dir, f'x{k}.npy'), frame)


# Codecs to compare; each key is also the output sub-folder name.
codecs = {
    'JD': ssf_JD,
    'AR': ssf_AR,
    'AR_FT_1': ssf_AR_FT_1,
    'AR_FT_2': ssf_AR_FT_2,
}

with torch.no_grad():
    for i in range(sequence_number):
        # A fresh iterator per step draws one shuffled (and randomly cropped)
        # item, i.e. sampling WITH replacement across steps.
        # NOTE(review): presumably intentional, so that sequence_number crops
        # can be drawn even if the dataset has fewer items — confirm.
        data = next(iter(uvg_dataloader))
        # Frames 0-2 are the triplet to code; frame 3 stands in for the decoded
        # first frame x1_hat. 2*(v-0.5) maps [0, 1] to [-1, 1], assuming the
        # loader yields [0, 1] values (ToTensor output).
        x1, x2, x3, x1_hat = ((2 * (data[:, t, ...] - 0.5)).to(device) for t in range(4))

        x1_hat_img = _to_img(x1_hat)
        _save_frames('original', i, [_to_img(x1), _to_img(x2), _to_img(x3)])

        for name, codec in codecs.items():
            # Roll each codec out on its OWN reconstructions: x3 is predicted
            # from this codec's x2_hat. (The fine-tuned models previously used
            # x2_hat_AR here while saving their own x2_hat next to it.)
            x2_hat = codec([x1_hat, x2])
            x3_hat = codec([x2_hat, x3])
            _save_frames(name, i, [x1_hat_img, _to_img(x2_hat), _to_img(x3_hat)])

        if i % 100 == 0:
            print(i)
        

0
100
200
300
400
500
600
700
800
900


In [14]:
# Export every stored original triplet (original/<i>/x{1,2,3}.npy) as PNG files
# named original_png/x{k}_<i>.png under the same root.
path = '/scratch/ssd004/scratch/joaodick/rebuttal_data/UVG/'
os.makedirs(path + 'original_png/', exist_ok=True)
for i in range(sequence_number):
    src_dir = path + f'original/{i}/'
    x1, x2, x3 = (np.load(src_dir + f'x{k}.npy') for k in (1, 2, 3))
    for k, frame in zip((1, 2, 3), (x1, x2, x3)):
        plt.imsave(path + f'original_png/x{k}_{i}.png', frame)


In [24]:
# Convert the compressed PNGs in path/out/ (produced outside this notebook,
# presumably by the FMD codec) into .npy triplets: path/FMD/<i>/x{1,2,3}.npy.
# plt.imread returns PNGs as float arrays in [0, 1]; an alpha channel may be
# present depending on how the PNGs were written — TODO confirm.
for i in range(sequence_number):
    # Create the same directory the arrays are written to. Previously the folder
    # was created under save_path while the files were saved under path.
    # NOTE(review): if FMD results belong with the UVG_FT dump, switch this to save_path.
    fmd_dir = path + f'FMD/{i}/'
    os.makedirs(fmd_dir, exist_ok=True)
    for k in (1, 2, 3):
        frame = plt.imread(path + f'out/x{k}_{i}_compressed.png')
        np.save(fmd_dir + f'x{k}.npy', frame)


In [8]:
# Inspect the layout of the stored originals. The captured output is
# (128, 128, 3): channels-last arrays, presumably written through
# hwc_tonp/(x+1)*0.5 like the export cell, i.e. not the [-1, 1] NCHW tensors the
# SSF models consume. They would need permuting and rescaling before inference.
# The former loop broke on its first iteration. The draft after it was
# unreachable, referenced an undefined x1_hat and would have overwritten
# save_path/original, so it has been removed.
path = '/scratch/ssd004/scratch/joaodick/rebuttal_data/UVG/'
x1 = torch.from_numpy(np.load(path + 'original/0/x1.npy')).to(device)
print(x1.shape)

torch.Size([128, 128, 3])
