## Libraries

In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import seaborn as sns
import torch
import torchvision
from PIL import Image
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.core.lightning import LightningModule
from skimage import io, transform
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

## Dataset class

In [None]:
class F_DB(Dataset):
    """Image dataset pairing files in ``root_dir`` with labels from a JSON file.

    Samples are matched by position: sample ``idx`` is the ``idx``-th image
    file in sorted filename order together with ``labels[idx]``. The dataset
    length is the number of labels, so images past that count are never
    returned.

    Args:
        root_dir: Directory holding the image files. Only its top level is
            listed; subdirectories are ignored.
        labels_path: Path to a JSON file whose top-level value holds one label
            entry per sample.
        gray: If True, images are read as single-channel grayscale.
        transform: Optional callable applied to the image; skipped when falsy.

    Raises:
        FileNotFoundError: If ``root_dir`` or ``labels_path`` does not exist.
        ValueError: If ``root_dir`` holds fewer images than there are labels.
    """

    def __init__(self, root_dir: str, labels_path: str, gray: bool, transform):
        self.root_dir = root_dir
        self.labels = self.get_labels(labels_path)
        self.img_names = self.get_image_names()
        self.transform = transform
        self.gray = gray
        # Fail at construction time rather than with an IndexError from inside
        # a DataLoader loop when the labels outnumber the images.
        if len(self.img_names) < len(self.labels):
            raise ValueError(
                f"{root_dir!r} contains {len(self.img_names)} files but "
                f"{labels_path!r} has {len(self.labels)} labels"
            )

    def get_image_names(self):
        """Return the sorted names of the non-directory entries of ``root_dir``."""
        # os.scandir raises FileNotFoundError for a missing directory, whereas
        # next(os.walk(...)) raised an uninformative StopIteration.
        with os.scandir(self.root_dir) as entries:
            return sorted(entry.name for entry in entries if not entry.is_dir())

    def get_labels(self, lables_path):
        """Load and return the parsed JSON content of ``lables_path``."""
        with open(lables_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __len__(self):
        """Return the number of samples (the number of label entries)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``{"image": ..., "joints": float32 tensor}``.

        ``image`` is what ``skimage.io.imread`` returns, passed through
        ``transform`` when one is set. ``joints`` is ``labels[idx]`` converted
        to a float32 tensor.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = os.path.join(self.root_dir, self.img_names[idx])
        img = io.imread(img_path, as_gray=self.gray)
        joints = torch.from_numpy(np.array(self.labels[idx])).float()

        sample = {"image": img, "joints": joints}
        if self.transform:
            sample["image"] = self.transform(sample["image"])
        return sample

In [None]:
# Training split of FreiHAND: RGB images paired with their xyz joint labels.
# Images are converted to tensors and normalized with the ImageNet mean/std
# used by torchvision's pretrained models.
f_db = F_DB(
    root_dir='../data/raw/FreiHAND_pub_v2/training/rgb',
    labels_path='../data/raw/FreiHAND_pub_v2/training_xyz.json',
    transform=transforms.Compose([
        # Optional, currently disabled:
        #     transforms.Resize(256), transforms.CenterCrop(224)
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]),
    gray=False,
)

In [None]:
class BasicHandPose(LightningModule):
    """Regress 3D joint coordinates from an image with a ResNet-18 backbone.

    The pretrained ResNet-18 classifier (``fc``) is replaced by a 128-unit
    linear layer. A hidden ReLU layer and a linear output layer follow,
    producing ``num_joints * 3`` values reshaped to ``(batch, num_joints, 3)``.

    Args:
        freeze_resnet: If True, freeze the pretrained backbone weights. The
            replacement ``fc`` layer and the head layers stay trainable.
        num_joints: Number of joints to predict per image.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, freeze_resnet=True, num_joints=21, lr=1e-3):
        super().__init__()
        self.num_joints = num_joints
        self.lr = lr

        self.resnet18 = torchvision.models.resnet18(pretrained=True)
        if freeze_resnet:
            # Freeze the backbone itself (not some unrelated global model).
            # This runs before ``fc`` is replaced, so the new head is trainable.
            for param in self.resnet18.parameters():
                param.requires_grad = False
        self.resnet18.fc = torch.nn.Linear(self.resnet18.fc.in_features, 128)

        self.layer_1 = torch.nn.Linear(128, 128)
        self.output_layer = torch.nn.Linear(128, num_joints * 3)

    def forward(self, x):
        """Return predicted joints shaped ``(batch, num_joints, 3)``.

        ``x`` is assumed to be an image batch ``(batch, 3, H, W)`` as expected
        by ResNet-18. TODO confirm for non-RGB inputs.
        """
        batch_size = x.size(0)
        x = self.resnet18(x)
        x = F.relu(self.layer_1(x))
        x = self.output_layer(x)
        return x.view(batch_size, self.num_joints, 3)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE between predicted and target joints."""
        x, y = batch["image"], batch["joints"]
        prediction = self(x)
        loss = F.mse_loss(prediction, y)
        # pl.TrainResult was removed from PyTorch Lightning; log via self.log
        # and return the loss tensor directly.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)

In [None]:
# model.fc = torch.nn.Linear(model.fc.in_features,128) 

In [None]:
# Smoke test: push one dataset sample through the model to check shapes.
model = BasicHandPose()
# unsqueeze(0) adds the batch dimension without hard-coding the image size
# or channel count (view(1, 3, 224, 224) breaks for any other input shape).
out = model(f_db[0]['image'].unsqueeze(0).float())

In [None]:
# Training loader: reshuffle every epoch so minibatches are not drawn in the
# dataset's fixed sorted-filename order.
dta = DataLoader(f_db, batch_size=16, shuffle=True)

In [None]:
from pytorch_lightning import Trainer

# Fit the hand-pose model on the training loader for a fixed three epochs.
trainer = Trainer(max_epochs=3)
trainer.fit(model, dta)

#### Unsupervised landmark detection

In [None]:
# Coordinate indices 0..HEATMAP_SIZE-1. Dotting them with a normalized
# heatmap marginal gives the expected landmark coordinate.
# HEATMAP_SIZE is the spatial size Net's five convolutions produce for a
# 224x224 input: 224 -> 110 -> 106 -> 53 -> 27 -> 29.
HEATMAP_SIZE = 29
# float32 so the weights can be matmul'd with float32 activations; a float64
# vector raises a dtype-mismatch error in torch.matmul.
WEIGHTS = torch.arange(HEATMAP_SIZE, dtype=torch.float32)
# Presumably the Gaussian std (in heatmap pixels) for landmark rendering;
# not used in this cell. TODO confirm.
SIGMA = 5
class Net(torch.nn.Module):
    """Fully-convolutional network producing one heatmap per landmark.

    Takes a single-channel image batch ``(batch, 1, H, W)`` and returns 21
    heatmaps. For a 224x224 input the output is ``(batch, 21, 29, 29)``.
    """

    def __init__(self):
        super().__init__()

        self.block1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(32, 32, kernel_size=7, stride=1, padding=1),
            torch.nn.ReLU(inplace=True),
        )
        self.block2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            torch.nn.ReLU(inplace=True),
        )
        # NOTE(review): padding=1 on a 1x1 conv adds a one-pixel border whose
        # values are just the bias. It is kept so the output shape is
        # unchanged, but padding=0 is probably what was intended.
        self.block3 = torch.nn.Sequential(
            torch.nn.Conv2d(128, 21, kernel_size=1, stride=1, padding=1),
        )

    def spatial_expectations(self, x):
        """Return the expected (x, y) position of each heatmap channel.

        Each channel is turned into a probability distribution over pixels
        with a spatial softmax. Its mean column and row index are returned.

        Args:
            x: Heatmaps shaped ``(batch, channels, height, width)``.

        Returns:
            ``(mean_x, mean_y)``, each ``(batch, channels)``, in heatmap pixel
            units (column index for x, row index for y).
        """
        height, width = x.shape[-2:]
        # Softmax rather than dividing by the channel sum: raw conv outputs
        # can be negative, so their sum does not define a valid distribution.
        probs = torch.softmax(x.flatten(2), dim=-1).view_as(x)
        x_distribution = probs.sum(dim=2)  # marginal over columns: (B, C, W)
        y_distribution = probs.sum(dim=3)  # marginal over rows:    (B, C, H)
        # Build index grids per call so they match the input's size, dtype
        # and device.
        cols = torch.arange(width, dtype=x.dtype, device=x.device)
        rows = torch.arange(height, dtype=x.dtype, device=x.device)
        mean_x = torch.matmul(x_distribution, cols)
        mean_y = torch.matmul(y_distribution, rows)
        return mean_x, mean_y

    def gaussian_heatmaps(self, mean_x, mean_y, height=29, width=29, sigma=5.0):
        """Render an isotropic Gaussian map centred on each landmark.

        Args:
            mean_x: Column coordinates, shaped ``(batch, channels)``.
            mean_y: Row coordinates, shaped like ``mean_x``.
            height: Height of the rendered maps. The default matches Net's
                output grid for a 224x224 input.
            width: Width of the rendered maps. Same default rationale as
                ``height``.
            sigma: Standard deviation of the Gaussian, in pixels.

        Returns:
            Tensor ``(batch, channels, height, width)`` holding
            ``exp(-((col - mean_x)**2 + (row - mean_y)**2) / (2 * sigma**2))``,
            which equals 1 at the centre.
        """
        mean_x = torch.as_tensor(mean_x)
        if not torch.is_floating_point(mean_x):
            mean_x = mean_x.float()
        mean_y = torch.as_tensor(mean_y, dtype=mean_x.dtype, device=mean_x.device)
        cols = torch.arange(width, dtype=mean_x.dtype, device=mean_x.device)
        rows = torch.arange(height, dtype=mean_x.dtype, device=mean_x.device)
        dx2 = (cols - mean_x[..., None]) ** 2  # (..., W)
        dy2 = (rows - mean_y[..., None]) ** 2  # (..., H)
        # Broadcast to (..., H, W).
        return torch.exp(-(dy2[..., :, None] + dx2[..., None, :]) / (2 * sigma ** 2))

    def forward(self, x):
        """Return per-landmark heatmaps ``(batch, 21, H', W')``.

        Landmark coordinates can be obtained by passing the result to
        ``spatial_expectations``.
        """
        x = self.block1(x)
        x = self.block2(x)
        return self.block3(x)

In [None]:
# 5000 draws from a 2-D normal centred at (1, 1) with identity covariance,
# transposed so the two coordinates unpack into separate 1-D arrays.
x, y = np.random.multivariate_normal(mean=[1, 1], cov=np.eye(2), size=5000).T

In [None]:
# Display the shape of the sampled x coordinates (expected: (5000,)).
np.shape(x)

In [None]:
# Instantiate the landmark network and print its layer summary.
model = Net()
print(repr(model))

In [None]:
# Shape check for Net. The dataset sample is already a (3, H, W) torch tensor
# (ToTensor in the transform), so torch.from_numpy cannot be applied to it,
# and it cannot be reshaped to one channel. Average the three normalized
# channels (a rough grayscale) and add a batch dimension.
# Expected result for a 224x224 image: (1, 21, 29, 29).
model(f_db[2]['image'].mean(dim=0, keepdim=True).unsqueeze(0).float()).shape