In [1]:
import os

import cv2
import joblib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# NOTE(review): presumably the source of display_single_hand_skleton - confirm, and prefer an explicit import.
from data_prep.renderMPpose import *

In [2]:
# Pick the compute device once: use the GPU when CUDA is usable, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Load the serialized left-hand keypoints.
# NOTE(review): joblib.load unpickles arbitrary objects - only point this at trusted files.
output = joblib.load("../data/islrtc_hand_train/lhpts.pkl")
handpts = output["joints"]

# Number of values in one flattened pose sample; used below to size both networks.
dims = len(handpts[0].flatten())

In [4]:
class Generator(nn.Module):
    """Fully connected generator: noise vector -> pose vector.

    Three hidden blocks of ``Linear(240) -> ReLU -> Dropout(0.5)`` followed by
    a linear projection to ``pose_dim`` and a sigmoid, so every output value
    lies in (0, 1).

    Args:
        noise_dim: size of the input noise vector.
        pose_dim: size of the generated (flattened) pose vector.
    """

    def __init__(self, noise_dim, pose_dim):
        super().__init__()

        hidden_units = 240
        drop_prob = 0.5

        # Build the three hidden blocks in order; only the first one sees the
        # raw noise width, the others are hidden -> hidden.
        layers = []
        in_width = noise_dim
        for _ in range(3):
            layers.extend([
                nn.Linear(in_width, hidden_units, bias=True),
                nn.ReLU(),
                nn.Dropout(drop_prob),
            ])
            in_width = hidden_units

        # Output head: project to the pose size and squash into (0, 1).
        layers.extend([
            nn.Linear(hidden_units, pose_dim, bias=True),
            nn.Sigmoid(),
        ])

        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` (last dimension ``noise_dim``) through the network."""
        return self.gen(x)


In [5]:
class Maxout(nn.Module):
    """Maxout activation over groups of consecutive features.

    Dimension 1 of the input is split into groups of ``num_pieces``
    consecutive entries and each group is reduced to its maximum, so that
    dimension shrinks by a factor of ``num_pieces``.

    Args:
        num_pieces: number of consecutive features competing in each group.
    """

    def __init__(self, num_pieces):
        super().__init__()
        self.num_pieces = num_pieces

    def forward(self, x):
        """Return the group-wise maximum of ``x`` along dimension 1.

        ``x`` has shape ``(batch, features, *rest)``; the result has shape
        ``(batch, features // num_pieces, *rest)``. An ``AssertionError`` is
        raised when ``features`` is not divisible by ``num_pieces``.
        """
        pieces = self.num_pieces
        features = x.shape[1]

        assert features % pieces == 0

        # (batch, features, *rest) -> (batch, groups, pieces, *rest)
        grouped_shape = (x.shape[0], features // pieces, pieces) + tuple(x.shape[2:])
        grouped = x.view(grouped_shape)

        # Keep only the winning value of every group (the indices are dropped).
        return torch.max(grouped, dim=2).values

In [6]:
class Discriminator(nn.Module):
    """Fully connected discriminator: pose vector -> single score in (0, 1).

    Two hidden ``Linear(240) -> ReLU`` layers followed by a one-unit linear
    layer and a sigmoid.

    Args:
        pose_dim: size of the (flattened) pose vector being scored.
    """

    def __init__(self, pose_dim):
        super().__init__()

        hidden_units = 240

        self.disc = nn.Sequential(
            # Hidden layer 1: pose -> hidden
            nn.Linear(pose_dim, hidden_units, bias=True),
            nn.ReLU(),
            # Hidden layer 2: hidden -> hidden
            nn.Linear(hidden_units, hidden_units, bias=True),
            nn.ReLU(),
            # Output head: one sigmoid-squashed score per sample
            nn.Linear(hidden_units, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x`` (last dimension ``pose_dim``); output's last dimension is 1."""
        return self.disc(x)


In [7]:
# Training schedule: number of passes over the data, and how many samples of
# `handpts` are consumed per pass (all of them by default).
num_epochs = 512
num_steps = len(handpts)

# The generator's noise vector has the same width as a flattened pose.
generator = Generator(noise_dim=dims, pose_dim=dims).to(device)
discriminator = Discriminator(pose_dim=dims).to(device)

# Both networks are optimised with identical SGD-with-momentum settings.
sgd_settings = dict(lr=0.0002, momentum=0.5)
generator_optimizer = optim.SGD(generator.parameters(), **sgd_settings)
discriminator_optimizer = optim.SGD(discriminator.parameters(), **sgd_settings)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

In [None]:
# Adversarial training loop.
#
# For every real pose: `disc_steps` discriminator updates (real vs. freshly
# generated fake), then one generator update. After each epoch the losses of
# the last step are printed and one generated pose is rendered to
# tmp/out_e<epoch>.png.

canvas_size = 128  # poses are divided by this for training and rendered on a canvas_size x canvas_size image
disc_steps = 8     # discriminator updates per generator update
vis_every = 1      # log / render a sample every `vis_every` epochs

# Make sure the folder the preview images are written into exists.
os.makedirs("tmp", exist_ok=True)

for epoch in range(num_epochs):

    for i, real_pose in enumerate(handpts):
        if i == num_steps:
            break

        # Scale the pose by the canvas size and flatten it to one vector.
        real_pose_tensor = torch.tensor(real_pose / canvas_size, dtype=torch.float)
        real_pose_tensor = real_pose_tensor.flatten().to(device)

        # Train Discriminator
        for _ in range(disc_steps):
            # The fake sample is only an input here: generate it without
            # building a graph so the discriminator update does not
            # backpropagate through (or keep alive) the generator.
            with torch.no_grad():
                fake_pose_tensor = generator(torch.randn(1, dims, device=device))

            real_outputs = discriminator(real_pose_tensor).view(-1)
            lossD_real = criterion(real_outputs, torch.ones_like(real_outputs))
            fake_outputs = discriminator(fake_pose_tensor).view(-1)
            lossD_fake = criterion(fake_outputs, torch.zeros_like(fake_outputs))

            lossD = (lossD_real + lossD_fake) / 2
            discriminator.zero_grad()
            lossD.backward()
            discriminator_optimizer.step()

        # Train Generator: a fresh sample, scored by the updated discriminator,
        # with the "real" label as target.
        fake_pose_tensor = generator(torch.randn(1, dims, device=device))
        outputs = discriminator(fake_pose_tensor)
        lossG = criterion(outputs, torch.ones_like(outputs))
        generator.zero_grad()
        lossG.backward()
        generator_optimizer.step()

    # Visualize Results
    if epoch % vis_every == 0:
        print(f"Epoch [{epoch}/{num_epochs}] , Loss D: {lossD:.4f}, loss G: {lossG:.4f}")

        # Sample with dropout disabled, then restore training mode.
        generator.eval()
        with torch.no_grad():
            # The reshape assumes dims == 42, i.e. 21 keypoints x 2 coordinates.
            fake = generator(torch.randn(1, dims, device=device)).reshape(-1, 1, 21, 2)
        generator.train()

        # .cpu() is required before .numpy() when the model lives on the GPU.
        fake_pts = fake[0][0].cpu().numpy() * canvas_size
        fake_img = np.zeros((canvas_size, canvas_size, 3), dtype=np.uint8)
        display_single_hand_skleton(fake_img, fake_pts.astype(int))
        # To compare against ground truth, render real_pose_tensor the same
        # way and np.concatenate the two images along axis=1 before saving.
        cv2.imwrite(f'tmp/out_e{epoch}.png', fake_img)
        

In [None]:
import torch

In [None]:
# Scratch cell: draws a 42x42 tensor of standard-normal samples; the value is not assigned to anything.
torch.randn(42, 42)