In [None]:
# Mount Google Drive (Colab only); the checkpoint paths used further down live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# ***IMPORTS BLOCK***

In [None]:
import torch
import torch.optim as optim
import time
import os
import argparse
import torch.nn as nn
import torchvision.models as models
import math
import numpy as np

from utils.visualization import visualize_hand
from utils.general import set_seed, get_device, print_data_info, print_model_info, save_training_arguments, logger, save_checkpoint, load_checkpoint
from utils.data import generate_data

# ***MODEL UTILS BLOCK:***

In [None]:
def load_encoder(latent_space_dim):
    """Create the RGB image encoder for a latent space of size ``latent_space_dim``."""
    return rgb_encoder(latent_space_dim)

class rgb_encoder(nn.Module):
    """Image encoder that predicts the parameters of the latent distribution.

    The backbone is an ImageNet-pretrained ResNet-18 with its last fully
    connected layer removed. Its pooled features are concatenated with the
    handedness indicator and mapped by a new fully connected layer to
    ``2 * latent_space_dim`` values.

    Args:
        latent_space_dim: dimensionality of the latent space.
    """

    def __init__(self, latent_space_dim):
        super(rgb_encoder, self).__init__()
        self.latent_space_dim = latent_space_dim
        # (1) ResNet-18 pretrained on ImageNet, without its last fully connected layer.
        model = models.resnet18(pretrained=True)
        self.encoder = nn.Sequential(*(list(model.children())[:-1]))
        # (2) 512 image features + 2 handedness values -> 2 * latent_space_dim outputs.
        self.fc = nn.Linear(in_features=514, out_features=2 * self.latent_space_dim)

    def forward(self, image, hand_side):
        """Encode a batch of images.

        Args:
            image: batch of input images fed to the ResNet backbone.
            hand_side: per-sample handedness indicator; it is concatenated to
                the image features along dim 1, so it must be 2-D
                (presumably ``(batch, 2)`` given the 514-wide layer -- confirm
                against the data loader).

        Returns:
            A pair ``(second_half, first_half)`` of the fully connected output,
            each of width ``latent_space_dim``. ``VAE_model`` uses the first
            element as the mean and the second as the log-variance.
        """
        # Flatten everything but the batch axis. A bare ``squeeze()`` would also
        # drop the batch axis for a single-sample batch and break the ``cat`` below.
        features = torch.flatten(self.encoder(image), start_dim=1)
        z = torch.cat((features.float(), hand_side.float()), dim=1)
        x = self.fc(z)
        return x[:, self.latent_space_dim:], x[:, :self.latent_space_dim]

def load_decoder(latent_space_dim, num_of_joints):
    """Create the pose decoder for the given latent size and joint layout."""
    return pose_decoder(latent_space_dim, num_of_joints)

class pose_decoder(nn.Module):
    """Fully connected decoder that maps a latent sample to joint coordinates.

    The network is five ``Linear -> BatchNorm1d -> ReLU`` stages of width 512
    followed by a final ``Linear`` producing one value per joint coordinate.

    Args:
        latent_space_dim: size of the latent vector fed to the first layer.
        joint_num: ``[number_of_joints, coordinates_per_joint]``, e.g. ``[21, 3]``.
    """

    def __init__(self, latent_space_dim, joint_num):
        super(pose_decoder, self).__init__()
        # Plain ints: these are shape metadata, not parameters or buffers.
        self.joint_num = tuple(int(dim) for dim in joint_num)  # e.g. (21, 3)
        self.in_size = int(torch.tensor(self.joint_num).prod())  # e.g. 63

        hidden_size = 512
        layers = []
        # The first layer must match the latent size (it used to be fixed at 64).
        in_features = latent_space_dim
        for _ in range(5):
            layers.append(nn.Linear(in_features=in_features, out_features=hidden_size))
            layers.append(nn.BatchNorm1d(num_features=hidden_size, eps=1e-05, momentum=0.1,
                                         affine=True, track_running_stats=True))
            layers.append(nn.ReLU())
            in_features = hidden_size
        # Last layer: one output per joint coordinate.
        layers.append(nn.Linear(in_features=hidden_size, out_features=self.in_size))
        # Same attribute name and layer order as before, so existing checkpoints
        # keep their ``lin_lays.<index>.*`` state-dict keys.
        self.lin_lays = nn.Sequential(*layers)

    def forward(self, sample):
        """Decode latent samples into a ``(batch, joints, coordinates)`` tensor."""
        out_lays = self.lin_lays(sample)
        return out_lays.view(-1, self.joint_num[0], self.joint_num[1])



# ***CROSS MODEL BLOCK***
This block creates the VAE model using the above encoder/decoder


In [None]:
class VAE_model(nn.Module):
    """Variational auto-encoder built from an image encoder and a pose decoder.

    Args:
        rgb_encoder: module called as ``encoder(x, hand_side)`` that returns
            the mean and log-variance of the latent distribution.
        pose_decoder: module called as ``decoder(latent_sample)``.
        num_of_joints: accepted for interface compatibility; not used here.
        latent_space_dim: dimensionality of the latent space.

    After each forward pass the intermediate results are kept on the instance:
    ``mean``, ``var`` (which holds the log-variance), ``repair`` (the sampled
    latent vector) and ``decoding`` (the decoder output). The training and
    validation functions read ``mean`` and ``var`` to compute the KL loss.
    """

    def __init__(self, rgb_encoder, pose_decoder, num_of_joints, latent_space_dim):
        super(VAE_model, self).__init__()
        self.encoder = rgb_encoder
        self.decoder = pose_decoder
        self.latent_space_dim = latent_space_dim

    def forward(self, x, hand_side):
        """Encode ``x`` with its handedness, sample a latent vector and decode it."""
        # (1) Mean and log-variance from the encoder. The module is called
        #     (not ``.forward``) so that registered hooks are honoured.
        self.mean, self.var = self.encoder(x, hand_side)
        # (2) Reparameterize to sample from the predicted latent distribution.
        # NOTE(review): a sample is drawn in eval mode as well -- confirm this is intended.
        self.repair = self.reparameterize(self.mean, self.var)
        # (3) Decode the sample.
        self.decoding = self.decoder(self.repair)
        return self.decoding

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(0.5 * logvar)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)


# ***LOSSES BLOCK***
This block calculates the MSE and KL losses


In [None]:
def get_mse_loss(preds, ground_truth):
    """Sum of squared differences between ``preds`` and ``ground_truth``.

    The loss is summed over every element (``reduction='sum'``), not averaged.
    """
    return nn.functional.mse_loss(preds, ground_truth, reduction='sum')

def get_kl_loss(mean, logvar):
    """KL divergence term ``-0.5 * sum(1 + logvar - mean^2 - exp(logvar))``.

    The sum runs over every element of the inputs.
    """
    per_element = 1 + logvar - mean.pow(2) - logvar.exp()
    return -0.5 * per_element.sum()



# ***METRIC BLOCK***
This block implements mean End-Point-Error


In [None]:
def calc_mean_epe(pred_joints, gt_joints, visible_joints):
    """Calculate the mean End-Point-Error (EPE) on the visible joints.

    The EPE of a joint is the Euclidean distance between its predicted and
    ground-truth coordinates. Only joints flagged as visible contribute.

    NOTE(review): the original body of this function was missing from the
    notebook; this implementation is reconstructed from its description and
    from how the training/validation loops call it -- confirm against the
    intended metric.

    Args:
        pred_joints: predicted joint coordinates, last axis = coordinates.
        gt_joints: ground-truth joint coordinates, same shape as ``pred_joints``.
        visible_joints: visibility flags, one per joint (non-zero = visible).
            Any shape that can be reshaped to ``pred_joints.shape[:-1]``.

    Returns:
        Mean EPE over the visible joints as a Python float, in the units of
        the inputs; ``0.0`` when no joint is visible.
    """
    difference = pred_joints.detach().float() - gt_joints.detach().float()
    per_joint_epe = (difference ** 2).sum(dim=-1).sqrt()
    # The visibility flags may still live on another device than the predictions.
    visible = torch.as_tensor(visible_joints, device=per_joint_epe.device)
    visible = visible.reshape(per_joint_epe.shape).bool()
    if not visible.any():
        return 0.0
    return per_joint_epe[visible].mean().item()

# ***TRAINING BLOCK:***
This block contains the training function that you will use for the training process of the VAE model. 



In [None]:
def train(model, train_loader, device, optimizer, epoch, beta_kl, total_epochs, log_interval):
    """Train the VAE for one epoch and return its summary statistics.

    Args:
        model: VAE whose forward pass stores the predicted ``mean`` and
            log-variance (``var``) as attributes.
        train_loader: loader yielding ``(img, keypoint_xyz21, keypoint_vis21,
            keypoint_scale, hand_side)`` batches.
        device: device the batch tensors are moved to.
        optimizer: optimizer updating the model parameters.
        epoch: current epoch number (used for logging only).
        beta_kl: weight of the KL term in the total loss.
        total_epochs: total number of epochs (used for logging only).
        log_interval: print running statistics every ``log_interval`` batches.

    Returns:
        ``[epoch, 1000 * mean EPE per batch, MSE per sample, KL per sample]``,
        every entry converted to ``str``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    train_mse_loss = 0.0
    train_kl_loss = 0.0
    mean_epe = 0.0
    num_batches = 0
    samples_seen = 0

    for batch_idx, (img, keypoint_xyz21, keypoint_vis21, keypoint_scale, hand_side) in enumerate(train_loader):
        img = img.float().to(device)
        keypoint_xyz21 = keypoint_xyz21.float().to(device)
        hand_side = hand_side.float().to(device)
        # Two trailing singleton axes so the scale broadcasts over joints and coordinates.
        keypoint_scale = keypoint_scale.unsqueeze(1).unsqueeze(2).to(device)

        # Start every step from clean gradients.
        optimizer.zero_grad()

        # (1) Forward propagation of the input through the model.
        logits = model(img, hand_side)
        # (2) Multiply the predictions by the bone scale (scale invariance).
        #     Out-of-place so the model output itself is never modified.
        logits = logits * keypoint_scale.to(logits.dtype)

        # (3) Loss calculation (both losses use reduction "sum").
        mse_loss = get_mse_loss(logits, keypoint_xyz21)
        kl_loss = get_kl_loss(model.mean, model.var)
        train_mse_loss += mse_loss.item()
        train_kl_loss += kl_loss.item()
        loss = mse_loss + beta_kl * kl_loss

        # (4) Backpropagation and optimization.
        loss.backward()
        optimizer.step()

        # (5) Mean EPE of the visible keypoints, accumulated so that dividing
        #     by the number of batches below yields a true running mean.
        mean_epe += float(calc_mean_epe(logits.detach(), keypoint_xyz21, keypoint_vis21))

        num_batches += 1
        samples_seen += len(img)

        # Print stats every log_interval batches: mean EPE * 1000 to express it
        # in mm, losses divided by the number of samples seen so far.
        if batch_idx % log_interval == 0:
            print('Train Epoch: [{}/{}] \t Image [{}/{}] \t Mean_EPE: {:.6f} \t Loss_MSE: {:.6f} \tLoss_KL: {:.6f}'.format(
                epoch, total_epochs, samples_seen, len(train_loader.dataset),
                1000 * mean_epe / num_batches,
                train_mse_loss / samples_seen,
                train_kl_loss / samples_seen))

    if num_batches == 0:
        raise ValueError("train_loader produced no batches")

    return [str(epoch), str(1000 * mean_epe / num_batches),
            str(float(train_mse_loss / len(train_loader.dataset))),
            str(float(train_kl_loss / len(train_loader.dataset)))]


# ***EVALUATION BLOCK:***
This block contains the validation function that you will use to evaluate the VAE model on the test set after each training epoch.





In [None]:
def validate(model, test_loader, device, optimizer, epoch, ckpt_dir):
    """Evaluate the VAE on the test set and visualize its predictions.

    Args:
        model: VAE whose forward pass stores the predicted ``mean`` and
            log-variance (``var``) as attributes.
        test_loader: loader yielding ``(img, keypoint_xyz21, keypoint_vis21,
            keypoint_scale, hand_side)`` batches.
        device: device the batch tensors are moved to.
        optimizer: unused; kept so existing callers keep working.
        epoch: current epoch number (reported in the returned statistics).
        ckpt_dir: directory handed to ``visualize_hand`` together with the
            batch index.

    Returns:
        A pair ``(mean_epe_mm, stats)`` where ``mean_epe_mm`` is
        ``1000 * mean EPE per batch`` and ``stats`` is
        ``[epoch, mean_epe_mm, MSE per sample, KL per sample]`` as strings.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    test_mse_loss = 0.0
    test_kl_loss = 0.0
    mean_epe = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch_idx, (img, keypoint_xyz21, keypoint_vis21, keypoint_scale, hand_side) in enumerate(test_loader):
            img = img.float().to(device)
            keypoint_xyz21 = keypoint_xyz21.float().to(device)
            hand_side = hand_side.float().to(device)
            # Two trailing singleton axes so the scale broadcasts over joints and coordinates.
            keypoint_scale = keypoint_scale.unsqueeze(1).unsqueeze(2).to(device)

            # (1) Forward propagation of the input through the model.
            logits = model(img, hand_side)
            # (2) Multiply the predictions by the bone scale (scale invariance).
            logits = logits * keypoint_scale.to(logits.dtype)
            xyz_prediction = logits

            # (3) Loss calculation (both losses use reduction "sum").
            test_mse_loss += get_mse_loss(logits, keypoint_xyz21).item()
            test_kl_loss += get_kl_loss(model.mean, model.var).item()

            # (4) Mean EPE of the visible keypoints, accumulated so that dividing
            #     by the number of batches below yields the mean over the test set.
            mean_epe += float(calc_mean_epe(logits, keypoint_xyz21, keypoint_vis21))
            num_batches += 1

            # (5) Visualize predicted and ground-truth keypoints alongside each image.
            visualize_hand(img.clone().detach(),
                           xyz_prediction.clone().detach().cpu().numpy(),
                           keypoint_xyz21.clone().detach().cpu().numpy(),
                           ckpt_dir, batch_idx)

    if num_batches == 0:
        raise ValueError("test_loader produced no batches")

    mean_epe_mm = 1000 * mean_epe / num_batches
    mse_per_sample = test_mse_loss / len(test_loader.dataset)
    kl_per_sample = test_kl_loss / len(test_loader.dataset)

    print('Evaluation : Mean_EPE {:.5f}\t Loss_MSE: {:.6f} \tLoss_KL: {:.6f}'.format(
        mean_epe_mm, mse_per_sample, kl_per_sample))

    return mean_epe_mm, [str(epoch), str(mean_epe_mm),
                         str(float(mse_per_sample)), str(float(kl_per_sample))]


# ***MAIN BLOCK:***

In [None]:
def _str2bool(value):
    """Parse a command-line boolean (``type=bool`` would treat "False" as True)."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main(argv=None):
    """Build the data loaders and the VAE, then run the training/validation loop.

    Args:
        argv: optional list of command-line style arguments, e.g.
            ``['--eval-only']`` or ``['--pretrained', 'false']``. ``None``
            (the default) uses the built-in defaults only, which is what a
            notebook needs because the kernel's own ``sys.argv`` is unrelated.

    Side effects:
        Writes checkpoints, per-epoch logs and the argument dump under the
        checkpoint folder through the project helpers.
    """
    parser = argparse.ArgumentParser(description='3D Hand Pose Estimation')

    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')

    parser.add_argument('--beta-kl', type=float, default=0.00001, metavar='b',
                        help='weight of KL loss(default: 0.00001)')

    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')

    parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                        help='learning rate (default: 0.0001)')

    parser.add_argument('--z-dim', type=int, default=64, metavar='N',
                        help='latent space dimensionality (default: 64)')

    parser.add_argument('--pretrained', type=_str2bool, default=True, metavar='p',
                        help='load pretrained model')

    parser.add_argument('--cuda', action='store_true', default=True,
                        help='enables CUDA training')
    # ``--cuda`` defaults to True and could never be switched off on its own.
    parser.add_argument('--no-cuda', dest='cuda', action='store_false',
                        help='disables CUDA training')

    parser.add_argument('--mode', type=str, default='rgb23d', metavar='m',
                        help='VAE mode')

    parser.add_argument('--seed', type=int, default=1234, metavar='S',
                        help='random seed (default: 1234)')

    parser.add_argument('--log-interval', type=int, default=100, metavar='l',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--gpu', type=str, default='0')

    parser.add_argument('--eval-only', action='store_true', default=False,
                        help='run a single validation pass and skip training')

    args = parser.parse_args(args=[] if argv is None else argv)

    timestamp = str(time.asctime(time.localtime(time.time())))
    set_seed(args.seed)
    device = get_device(args.cuda, args.gpu)
    ###Print data settings###
    print_data_info()
    ###Generate training and test data###
    training_generator, test_generator = generate_data(args.batch_size)

    num_of_joints = [21, 3]
    ###Create VAE model###
    rgb_encoder = load_encoder(args.z_dim)
    pose_decoder = load_decoder(args.z_dim, num_of_joints)
    model = VAE_model(rgb_encoder, pose_decoder, num_of_joints, args.z_dim).to(device)
    ###Create optimizer###
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    ###Create checkpoint folder###
    ckpt_fol_name = '/content/drive/My Drive/hand_checkpoints/test_day' + timestamp

    # PRETRAINED SECTION
    if args.pretrained:
        ###Choose pretrained path to load###
        path = "/content/drive/MyDrive/hand_checkpoints/test_daySun Jan  3 23_22_41 2021/best_mean_epe.pth"
        print("Resuming pretrained model")
        start_epoch = load_checkpoint(path, model)
        start_epoch = start_epoch + 1
    else:
        start_epoch = 1
    print(timestamp)
    print("Checkpoint Directory= ", ckpt_fol_name)

    if args.eval_only:
        # Evaluation only: one pass over the test set, nothing is trained or saved here.
        print("!!!!!!!!   VALIDATION   !!!!!!!!")
        validate(model, test_generator, device, optimizer, start_epoch, ckpt_fol_name)
        return

    ##Training-test loop###
    best_mean_epe = 1000000
    # ``+ 1`` so that the final epoch (``args.epochs``) is trained as well.
    for epoch in range(start_epoch, args.epochs + 1):

        # TRAINING MODE
        print("!!!!!!!!   TRAINING   !!!!!!!!")
        train_stats = train(model, training_generator, device, optimizer, epoch,
                            args.beta_kl, args.epochs, args.log_interval)

        # EVALUATION MODE
        print("!!!!!!!!   VALIDATION   !!!!!!!!")
        mean_epe, test_stats = validate(model, test_generator, device, optimizer, epoch, ckpt_fol_name)

        # CHECKPOINT
        is_best = mean_epe < best_mean_epe
        if is_best:
            print("Best mean_EPE {}".format(mean_epe))
            best_mean_epe = mean_epe
            save_checkpoint(model, epoch, mean_epe, ckpt_fol_name, 'best_mean_epe')
        else:
            save_checkpoint(model, epoch, mean_epe, ckpt_fol_name, 'last')

        ###save training measurements###
        logger(ckpt_fol_name + '/training.txt', train_stats)
        ###save test measurements###
        logger(ckpt_fol_name + '/validation.txt', test_stats)
        ###save arguments###
        # Kept inside the loop: it runs after save_checkpoint has been called for this folder.
        save_training_arguments(args, ckpt_fol_name)
  
# Run the whole pipeline when this cell/script is executed directly.
if __name__ == '__main__':
    main()