In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
# Check device availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device: %s" % device)

device: cpu


# RESNet-101 Test

In [19]:
from resnet_wrapper import ResNetWrapper

res = ResNetWrapper()
res


Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /home/jupyter/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth


ResNetWrapper(
  (net): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
        

In [20]:
in_t = torch.randn(1, 3, 960, 720)

out, low_level_features = res(in_t)
print(out.shape, low_level_features.shape)

torch.Size([1, 1280, 30, 22]) torch.Size([1, 256, 240, 180])


# WASP Module Test



In [21]:
from wasp import WASP
in_t = torch.randn(16, 1280, 120, 90)

w = WASP(device)
w

WASP(
  (convolutions): ModuleList(
    (0): ModuleDict(
      (3x3): Conv2d(1280, 256, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
      (1x1): Sequential(
        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (1): ModuleDict(
      (3x3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12))
      (1x1): Sequential(
        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (2): ModuleDict(
      (3x3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(18, 18), dilation=(18, 18))
      (1x1): Sequential(
        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (3): ModuleDict(
      (3x3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=

In [22]:
out = w(in_t)
print(out.shape)

torch.Size([16, 256, 120, 90])


# Decoder Test

In [23]:
from decoder import Decoder

d = Decoder()

In [24]:
in_wasp = torch.randn(1, 256, 120, 90)
in_low_level = torch.randn(1, 256, 240, 180)

score_maps = d(in_wasp, in_low_level)
print(score_maps.shape)

torch.Size([1, 16, 960, 720])


# UniPose Test

In [25]:
from unipose import UniPose
u = UniPose().to(device)
u


UniPose(
  (resnet): ResNetWrapper(
    (net): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace

In [10]:
# in_t = torch.randn(1, 3, 960, 720).to(device)

# out = u(in_t)

# print(out.shape)

torch.Size([1, 16, 960, 720])


In [41]:
import json
import numpy as np
import cv2
from google.cloud import storage

imagelist = []
kptlist = []

with open('../annotations/valid.json') as f:
    data = json.load(f)
    
storage_client = storage.Client("pose_estimation")
bucket = storage_client.get_bucket('pose_estimation_datasets')

for i in range(len(data)):
    img_name = data[i]['image']
    
    blob = bucket.blob('MPII/images/' +  img_name)
    blob.content_type = 'image/jpeg'
    image = np.asarray(bytearray(blob.download_as_string()))
    img = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
    
    kpt = np.asarray(data[i]['joints'], dtype=np.int32)

    if img.shape[0] != 960 or img.shape[1] != 720:
        kpt[:,0] = kpt[:,0] * (960/img.shape[1])
        kpt[:,1] = kpt[:,1] * (720/img.shape[0])
        img = cv2.resize(img,(960,720))
        img = np.array(img)
#     height, width, _ = img.shape
    
    imagelist.append(img)
    kptlist.append(kpt)

In [None]:
import torch
torch_image = torch.Tensor(imagelist[0:5])
torch_image = torch_image.permute(0, 3, 2, 1).to(device)
print(torch_image.shape)
out = u(torch_image).to(device)

torch.Size([5, 3, 960, 720])


In [None]:
kpt_tensor = torch.FloatTensor(kptlist[0]).to(device)
print(kpt_tensor.shape)

# Naive MSE Loss Test

In [30]:
from criterion.joint_max_mse_loss import JointMaxMSELoss

in_t = torch.randn(1, 16, 960, 720).to(device)

correct_coords = torch.Tensor([[
    (40, 40),
    (40, 40),
    (40, 40),
    (40, 40),

    (40, 40),
    (40, 40),
    (40, 40),
    (40, 40),

    (40, 40),
    (40, 40),
    (40, 40),
    (40, 40),

    (40, 40),
    (40, 40),
    (40, 40),
    (40, 40),
]]).to(device)

loss = JointMaxMSELoss()

output = loss(out, kpt_tensor)
# output.backward()


tensor(150491.8750)

In [32]:
print(out.shape)

torch.Size([1, 16, 960, 720])
