In [1]:
import argparse
import os
import numpy as np
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torch.nn.utils.spectral_norm as spectralnorm
import torchvision
import matplotlib.pyplot as plt

from models import *
from utils import *
from helper import *

In [2]:
!nvidia-smi

Tue Jan 30 14:18:18 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            On   | 00000000:07:00.0 Off |                    0 |
| N/A   69C    P8    19W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla T4            On   | 00000000:87:00.0 Off |                    0 |
| N/A   33C    P8    14W /  70W |      0MiB / 15109MiB |      0%      Default |
|       

In [3]:
# Use the GPU when CUDA is actually usable, otherwise fall back to the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# and the fallback has to be the string 'cpu', not an undefined name.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
def set_models_state(list_models, state):
    """Put every model in ``list_models`` into train or eval mode.

    ``state == 'train'`` calls ``.train()`` on each model; any other value
    calls ``.eval()`` on each model.
    """
    want_training = (state == 'train')
    for net in list_models:
        if want_training:
            net.train()
        else:
            net.eval()

def set_opt_zero(opts):
    """Call ``zero_grad()`` on every optimizer in ``opts``, in order."""
    for optimizer in opts:
        optimizer.zero_grad()
        
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP.

    Args:
        D: Callable critic; receives a batch of interpolated samples and
            returns one score (or a tensor of scores) per sample.
        real_samples: Tensor whose first axis is the batch axis.
        fake_samples: Tensor with the same shape as ``real_samples``.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``, where
        ``grad`` is the gradient of ``D`` w.r.t. a random interpolation between
        the real and the fake sample. The result stays attached to the graph
        (``create_graph=True``) so it can be back-propagated into ``D``.
    """
    batch_size = real_samples.size(0)

    # One random interpolation weight per sample, broadcast over every
    # non-batch axis. Created directly on the samples' device/dtype so no
    # numpy round-trip or globally aliased tensor type is needed.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, dtype=real_samples.dtype, device=real_samples.device)

    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)

    # Get gradient w.r.t. interpolates. ones_like matches whatever shape the
    # critic returns, e.g. (B,) or (B, 1).
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # Flatten everything except the batch axis before taking the per-sample norm.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

def cal_W1(ssf, encoder, decoder, decoder_hat, discriminator, discriminator_M, test_loader, list_models):
    """Evaluate critic score gaps and reconstruction error on ``test_loader``.

    All models in ``list_models`` are switched to eval mode first and are left
    in eval mode on return. Every batch is moved to the GPU with ``.cuda()``.

    Returns:
        A 3-tuple of CPU tensors, in this order:
        1. summed (real - fake) score gap of ``discriminator_M`` / number of samples,
        2. summed (real - fake) score gap of ``discriminator`` / number of samples,
        3. summed squared error between the target frame and the ``ssf``
           output / (64 * 64 * number of samples).
    """
    set_models_state(list_models, 'eval')
    sum_squared_error = nn.MSELoss(reduction='sum')

    joint_gaps = []       # per-batch score gap of `discriminator`
    marginal_gaps = []    # per-batch score gap of `discriminator_M`
    squared_errors = []   # per-batch summed squared error
    n_samples = 0

    for batch in test_loader:
        with torch.no_grad():
            # Move the last axis to position 1, then push to GPU as float32.
            # presumably (batch, time, H, W, channel) -> (batch, channel, time, H, W) — TODO confirm
            frames = batch.permute(0, 4, 1, 2, 3).cuda().float()
            target = frames[:, :, 1, ...]

            # Both I-frame decoders share one latent of the frame at index 0.
            latent = encoder(frames[:, :, 0, ...])[0]
            reference = decoder(latent)
            first_hat = decoder_hat(latent)

            predicted = ssf(target, reference, first_hat)

            # Critic on the pair (first_hat, predicted) vs. the first two true frames.
            generated_pair = torch.cat((first_hat, predicted), dim=1)
            true_pair = frames[:, 0, :2, ...]
            score_fake_pair = discriminator(generated_pair)
            score_real_pair = discriminator(true_pair)

            # Critic on single frames.
            # NOTE(review): the real frame is taken at index 6 here, while the
            # training loop in main() uses index 0 for this critic — confirm intended.
            true_single = frames[:, 0, 6:7, ...]
            score_fake_single = discriminator_M(predicted)
            score_real_single = discriminator_M(true_single)

            joint_gaps.append(torch.sum(score_real_pair) - torch.sum(score_fake_pair))
            marginal_gaps.append(torch.sum(score_real_single) - torch.sum(score_fake_single))
            squared_errors.append(sum_squared_error(target, predicted))
            n_samples += len(frames)

    joint_total = torch.Tensor(joint_gaps).sum()
    marginal_total = torch.Tensor(marginal_gaps).sum()
    error_total = torch.Tensor(squared_errors).sum()

    # 64*64 is a hard-coded per-sample element count — assumes 64x64 single-channel frames, TODO confirm.
    return marginal_total / n_samples, joint_total / n_samples, error_total / (64 * 64 * n_samples)

def cal_W1_MMSE(ssf, encoder, decoder, discriminator, discriminator_M, test_loader, list_models):
    """Evaluate only the reconstruction error of ``ssf`` on ``test_loader``.

    ``discriminator`` and ``discriminator_M`` are accepted but not used by
    this function. All models in ``list_models`` are switched to eval mode and
    left that way. Here ``ssf`` is called with two arguments (target frame and
    decoded reference frame).

    Returns:
        CPU tensor: summed squared error between the target frame and the
        ``ssf`` output, divided by (64 * 64 * number of samples).
    """
    set_models_state(list_models, 'eval')
    sum_squared_error = nn.MSELoss(reduction='sum')

    squared_errors = []
    n_samples = 0

    for batch in test_loader:
        with torch.no_grad():
            # Move the last axis to position 1, then push to GPU as float32.
            frames = batch.permute(0, 4, 1, 2, 3).cuda().float()
            target = frames[:, :, 1, ...]
            reference = decoder(encoder(frames[:, :, 0, ...])[0])
            predicted = ssf(target, reference)
            squared_errors.append(sum_squared_error(target, predicted))
            n_samples += len(frames)

    error_total = torch.Tensor(squared_errors).sum()

    # 64*64 is a hard-coded per-sample element count — assumes 64x64 single-channel frames, TODO confirm.
    return error_total / (64 * 64 * n_samples)

In [13]:
def main(dim = 128,
        z_dim = 1,
        lambda_gp = 10,
        bs = 64,
        d_penalty = 0,
        skip_fq = 5,
        total_epochs = 10,
        lambda_P = 0,
        lambda_PM = 0,
        lambda_MSE = 0,
        L = 2,
        path = './data/',
        pre_path = 'None'):
    """Train ``ScaleSpaceFlow_R1eps`` against two critics and save the results.

    Args:
        dim: Only used as a tag in the output folder name.
        z_dim: Passed to ``ScaleSpaceFlow_R1eps`` as ``dim``.
        lambda_gp: Weight of the gradient penalty in both critic losses.
        bs: Batch size handed to ``get_dataloader``.
        d_penalty: Only used as a tag in the folder name and in the log line.
        skip_fq: ``ssf`` is updated on iterations where ``i % skip_fq == 0``;
            both critics are updated on every iteration.
        total_epochs: Number of passes over the training loader.
        lambda_P: Weight of the generator loss coming from ``discriminator``
            (scores the two-frame stack).
        lambda_PM: Weight of the generator loss coming from ``discriminator_M``
            (scores a single frame).
        lambda_MSE: Weight of the MSE between the generated and the true frame.
        L: Passed to ``ScaleSpaceFlow_R1eps``; ``-1`` turns off both the
            ``stochastic`` and ``quantize_latents`` flags.
        path: Data root handed to ``get_dataloader``.
        pre_path: Folder holding checkpoints to warm-start from. The string
            ``'None'`` (or ``None``) skips loading.

    Side effects:
        Creates ``./saved_models/<settings>/``, appends evaluation lines to
        ``performance.txt`` in it, writes seven ``.pth`` state dicts, one
        ``x.npz`` batch and five ``x_hat<i>.npz`` samples. Requires CUDA
        (models and batches are moved with ``.cuda()``).
    """
    # L == -1 means "no quantization": both flags are switched off together.
    stochastic = L != -1
    quantize_latents = L != -1
    print ('Stochastic: ', stochastic)
    print ('Quantize: ', quantize_latents)

    #Create folder:
    folder_name='New_R1eps_dim_'+str(dim)+'|z_dim_'+str(z_dim)+'|L_'+str(L)+'|lambda_gp_'+str(lambda_gp) \
        +'|bs_'+str(bs)+'|dpenalty_'+str(d_penalty)+'|lambdaP_'+str(lambda_P)+'|lambdaPM_'+str(lambda_PM)+'|lambdaMSE_' + str(lambda_MSE)
    print ("Settings: ", folder_name)

    save_dir = './saved_models/' + folder_name
    os.makedirs(save_dir, exist_ok=True)

    #Define Models
    discriminator = Discriminator_v3(out_ch=2) #Generator Side
    discriminator_M = Discriminator_v3(out_ch=1) #Marginal Discriminator
    ssf = ScaleSpaceFlow_R1eps(num_levels=1, dim=z_dim, stochastic=stochastic, quantize_latents=quantize_latents, L=L)

    list_models = [discriminator, discriminator_M, ssf]

    ssf.cuda()
    discriminator.cuda()
    discriminator_M.cuda()

    #Load models (warm start). Accept both the 'None' sentinel string and a real None.
    if pre_path is not None and pre_path != 'None':
        ssf.motion_encoder.load_state_dict(torch.load(pre_path+'/m_enc.pth'))
        ssf.motion_decoder.load_state_dict(torch.load(pre_path+'/m_dec.pth'))
        ssf.P_encoder.load_state_dict(torch.load(pre_path+'/p_enc.pth'))
        ssf.res_encoder.load_state_dict(torch.load(pre_path+'/r_enc.pth'))
        ssf.res_decoder.load_state_dict(torch.load(pre_path+'/r_dec.pth'))
        discriminator.load_state_dict(torch.load(pre_path+'/discriminator.pth'))
        discriminator_M.load_state_dict(torch.load(pre_path+'/discriminator_M.pth'))

    #Define fixed model: these three are loaded from disk, kept in eval mode
    #and never handed to an optimizer.
    I_dim = 12 #12 #8
    I_L = 2
    encoder = Encoder(dim=I_dim, nc=1, stochastic=True, quantize_latents=True, L=I_L) #Generator Side
    decoder = Decoder_Iframe(dim=I_dim) #Generator Side
    decoder_hat = Decoder_Iframe(dim=I_dim)

    encoder.cuda()
    decoder.cuda()
    decoder_hat.cuda()

    encoder.eval()
    decoder.eval()
    decoder_hat.eval()
    encoder.load_state_dict(torch.load('./I3/I_frame_encoder_zdim_12_L_2.pth'))
    decoder.load_state_dict(torch.load('./I3/I_frame_decoderMMSE_zdim_12_L_2.pth'))
    decoder_hat.load_state_dict(torch.load('./I3/I_frame_decoder_zdim_12_L_2.pth'))

    #Define Data Loader
    train_loader, test_loader = get_dataloader(data_root=path, seq_len=8, batch_size=bs, num_digits=1)
    loader_l = len(train_loader)
    mse = torch.nn.MSELoss()

    opt_ssf= torch.optim.RMSprop(ssf.parameters(), lr=1e-5)
    opt_d = torch.optim.RMSprop(discriminator.parameters(), lr=5e-5)
    opt_dm = torch.optim.RMSprop(discriminator_M.parameters(), lr=5e-5)

    list_opt = [opt_ssf, opt_d, opt_dm]

    # The context manager guarantees the log file is closed even if training
    # raises (e.g. CUDA OOM or a keyboard interrupt).
    with open(save_dir + "/performance.txt", "a") as f:
        for epoch in range(total_epochs):
            set_models_state(list_models, 'train')
            a1 = time.time()
            for i,x in enumerate(train_loader):
                print(f'{i}/{loader_l}')
                #Set 0 gradient
                set_opt_zero(list_opt)

                #Get the data
                x = x.permute(0, 4, 1, 2, 3)
                x = x.cuda().float()
                x_cur = x[:,:,1,...]
                with torch.no_grad():
                    hx = encoder(x[:,:,0,...])[0]
                    x_ref = decoder(hx)
                    x_1_hat = decoder_hat(hx)
                    # This forward pass only feeds the critics (always detached),
                    # so no autograd graph is needed for it.
                    x_hat = ssf(x_cur, x_ref, x_1_hat)

                #Optimize discriminator (two-frame stack)
                fake_vid = torch.cat((x_1_hat, x_hat), dim = 1)
                real_vid = x[:,0,:2,...].detach()
                fake_validity = discriminator(fake_vid.detach())
                real_validity = discriminator(real_vid)
                gradient_penalty = compute_gradient_penalty(discriminator, real_vid.data, fake_vid.data)
                errVD =  -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
                errVD.backward()
                opt_d.step()

                #Optimize discriminator M (single frame)
                fake_img = x_hat.detach()
                real_img = x[:,0,:1,...].detach()
                fake_valid_m = discriminator_M(fake_img)
                real_valid_m = discriminator_M(real_img)
                # Argument order is (real, fake), matching the signature and the call above.
                gradient_penalty_m = compute_gradient_penalty(discriminator_M, real_img.data, fake_img.data)
                errID =  -torch.mean(real_valid_m) + torch.mean(fake_valid_m) + lambda_gp * gradient_penalty_m
                errID.backward()
                opt_dm.step()

                #Optimize the generator only every skip_fq-th iteration
                if i%skip_fq == 0:
                    # Fresh forward pass with gradients enabled for the ssf update.
                    x_hat = ssf(x_cur, x_ref, x_1_hat)

                    fake_vid = torch.cat((x_1_hat, x_hat), dim = 1)
                    errVG = -torch.mean(discriminator(fake_vid))
                    errIG = -torch.mean(discriminator_M(x_hat))

                    # Gradients that land on the critics here are cleared by
                    # set_opt_zero at the start of the next iteration.
                    loss = lambda_MSE*mse(x_hat, x_cur) + lambda_P*errVG + lambda_PM*errIG
                    loss.backward()

                    opt_ssf.step()

            a = time.time()
            if epoch %10 == 0:
                # cal_W1 leaves the models in eval mode; the next epoch
                # switches them back to train mode at its top.
                metrics = cal_W1(ssf, encoder, decoder, decoder_hat, discriminator, discriminator_M, test_loader, list_models)
                show_str = "Epoch: " + str(epoch) + " l_PM, l_P, l_MSE, d_penalty " + str(lambda_PM) + " " + str(lambda_P) + " " \
                    + str(lambda_MSE) + " " + str(d_penalty) + " P loss: " + str(metrics)
                print (show_str)
                f.write(show_str+"\n")
                f.flush()  # keep the log usable if a long run is interrupted
            b = time.time()
            print('test time:', b-a)

            b1 = time.time()
            print(f'epoch: {b1-a1}')

    set_models_state(list_models, 'eval')
    a = time.time()
    torch.save(ssf.motion_encoder.state_dict(), os.path.join(save_dir, 'm_enc.pth'))
    torch.save(ssf.motion_decoder.state_dict(), os.path.join(save_dir, 'm_dec.pth'))
    torch.save(ssf.P_encoder.state_dict(), os.path.join(save_dir, 'p_enc.pth'))
    torch.save(ssf.res_encoder.state_dict(), os.path.join(save_dir, 'r_enc.pth'))
    torch.save(ssf.res_decoder.state_dict(), os.path.join(save_dir, 'r_dec.pth' ))
    torch.save(discriminator.state_dict(), os.path.join(save_dir, 'discriminator.pth'))
    torch.save(discriminator_M.state_dict(), os.path.join(save_dir, 'discriminator_M.pth'))
    b = time.time()
    print(f'save models: {b-a}')

    #save some figures: take the first training batch
    for x in train_loader:
        x = x.permute(0, 4, 1, 2, 3)
        x = x.cuda().float()
        break
    np.savez_compressed(save_dir + "/x", a=x.detach().cpu().numpy())

    # NOTE(review): the original called ssf(x_cur, x_ref) here with the raw
    # frame as reference. Every other call in this function passes three
    # arguments built from the I-frame decoders, so sampling now does the
    # same — confirm against ScaleSpaceFlow_R1eps.forward.
    x_cur = x[:,:,1,...]
    with torch.no_grad():
        hx = encoder(x[:,:,0,...])[0]
        x_ref = decoder(hx)
        x_1_hat = decoder_hat(hx)
        for i in range(5): #generate same figure 5 times
            x_hat = ssf(x_cur, x_ref, x_1_hat)
            np.savez_compressed(save_dir + "/x_hat" + str(i), a=x_hat.detach().cpu().numpy())


In [9]:
# MSE-only run: 100 epochs with lambda_MSE = 1; lambda_P and lambda_PM keep
# their defaults of 0, so the critics contribute nothing to the generator loss.
main(lambda_MSE=1, total_epochs=100)

Stochastic:  True
Quantize:  True
Settings:  New_R1eps_dim_128|z_dim_1|L_2|lambda_gp_10|bs_64|dpenalty_0|lambdaP_0|lambdaPM_0|lambdaMSE_1
[-1.0, 1.0]
[-1.0, 1.0]
Finished Loading MNIST!
Epoch: 0l_PM, l_P, l_MSE, d_penalty 00 1 0 P loss: (tensor(21.6516), tensor(22.7973), tensor(0.0546))
Epoch: 10l_PM, l_P, l_MSE, d_penalty 00 1 0 P loss: (tensor(9.1763), tensor(9.2369), tensor(0.0165))
Epoch: 20l_PM, l_P, l_MSE, d_penalty 00 1 0 P loss: (tensor(8.7785), tensor(8.9075), tensor(0.0160))
Epoch: 30l_PM, l_P, l_MSE, d_penalty 00 1 0 P loss: (tensor(8.8860), tensor(8.7995), tensor(0.0157))
Epoch: 40l_PM, l_P, l_MSE, d_penalty 00 1 0 P loss: (tensor(8.6874), tensor(9.0857), tensor(0.0154))
Epoch: 50l_PM, l_P, l_MSE, d_penalty 00 1 0 P loss: (tensor(8.5272), tensor(8.7196), tensor(0.0153))
Epoch: 60l_PM, l_P, l_MSE, d_penalty 00 1 0 P loss: (tensor(8.7462), tensor(8.4982), tensor(0.0152))
Epoch: 70l_PM, l_P, l_MSE, d_penalty 00 1 0 P loss: (tensor(8.4312), tensor(8.3757), tensor(0.0151))
Epoch

In [14]:
# Warm-start from the folder written by the MSE-only configuration and train
# one more epoch with the two-frame critic term switched on (lambda_P = 0.01).
pre_path = './saved_models/New_R1eps_dim_128|z_dim_1|L_2|lambda_gp_10|bs_64|dpenalty_0|lambdaP_0|lambdaPM_0|lambdaMSE_1'
main(
    pre_path=pre_path,
    lambda_MSE=1,
    lambda_P=10e-3,
    total_epochs=1,
)

Stochastic:  True
Quantize:  True
Settings:  New_R1eps_dim_128|z_dim_1|L_2|lambda_gp_10|bs_64|dpenalty_0|lambdaP_0.01|lambdaPM_0|lambdaMSE_1
[-1.0, 1.0]
[-1.0, 1.0]
Finished Loading MNIST!
0/937
1/937
2/937
3/937
4/937
5/937
6/937
7/937
8/937
9/937
10/937
11/937
12/937
13/937
14/937
15/937
16/937
17/937
18/937
19/937
20/937
21/937
22/937
23/937
24/937
25/937
26/937
27/937
28/937
29/937
30/937
31/937
32/937
33/937
34/937
35/937
36/937
37/937
38/937
39/937
40/937
41/937
42/937
43/937
44/937
45/937
46/937
47/937
48/937
49/937
50/937
51/937
52/937
53/937
54/937
55/937
56/937
57/937
58/937
59/937
60/937
61/937
62/937
63/937
64/937
65/937
66/937
67/937
68/937
69/937
70/937
71/937
72/937
73/937
74/937
75/937
76/937
77/937
78/937
79/937
80/937
81/937
82/937
83/937
84/937
85/937
86/937
87/937
88/937
89/937
90/937
91/937
92/937
93/937
94/937
95/937
96/937
97/937
98/937
99/937
100/937
101/937
102/937
103/937
104/937
105/937
106/937
107/937
108/937
109/937
110/937
111/937
112/937
113/937
114/937
1