In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device: {device}")

device: cuda


# ResNet-101 Test

In [3]:
from resnet_wrapper import ResNetWrapper

# Build the ResNet wrapper; leaving it as the cell's last expression renders its module tree.
res = ResNetWrapper()
res


  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
    (layer2): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=Fal

In [4]:
in_t = torch.randn(1, 3, 960, 720)

# Shape smoke test only. eval() stops BatchNorm layers from folding this random
# noise into their running statistics, and no_grad() skips building the autograd
# graph, which is large for a 960x720 input through ResNet-101.
res.eval()
with torch.no_grad():
    out, low_level_features = res(in_t)
print(out.shape, low_level_features.shape)

# Both feature maps must keep the input's batch dimension.
assert out.shape[0] == low_level_features.shape[0] == in_t.shape[0]

torch.Size([1, 1280, 30, 22]) torch.Size([1, 256, 240, 180])


# WASP Module Test



In [5]:
from wasp import WASP

# Synthetic feature map shaped (batch, channels, H, W); the 1280 channels match the
# in_channels of the first WASP conv shown in the module tree printed below.
# NOTE(review): the ResNet test above printed a 30x22 feature map for a 960x720 input,
# not 120x90 — confirm which spatial size WASP is meant to receive.
in_t = torch.randn(16, 1280, 120, 90)

w = WASP(device)
w

WASP(
  (convolutions): ModuleList(
    (0): ModuleDict(
      (3x3): Conv2d(1280, 256, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
      (1x1): Sequential(
        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (1): ModuleDict(
      (3x3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12))
      (1x1): Sequential(
        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (2): ModuleDict(
      (3x3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(18, 18), dilation=(18, 18))
      (1x1): Sequential(
        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (3): ModuleDict(
      (3x3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=

In [6]:
# Inference-only shape check. The (16, 1280, 120, 90) float32 input alone is ~0.9 GB,
# so no_grad() avoids additionally keeping every intermediate activation for backprop.
w.eval()
with torch.no_grad():
    out = w(in_t)
print(out.shape)

# WASP must preserve the batch dimension.
assert out.shape[0] == in_t.shape[0]

torch.Size([16, 256, 120, 90])


# Decoder Test

In [7]:
from decoder import Decoder

# Construct the decoder under test; its repr is not displayed in this notebook.
d = Decoder()

In [8]:
in_wasp = torch.randn(1, 256, 120, 90)       # 256 channels, matching the WASP output above
in_low_level = torch.randn(1, 256, 240, 180)

# Shape check only: run in eval mode and without autograd bookkeeping.
d.eval()
with torch.no_grad():
    score_maps = d(in_wasp, in_low_level)
print(score_maps.shape)

# NOTE(review): the recorded output is (1, 16, 1280, 720), which is not a consistent
# multiple of either input's spatial size (120x90 or 240x180). The decoder may resize
# to a hard-coded resolution — confirm this is intended.
assert score_maps.shape[0] == in_wasp.shape[0]

torch.Size([1, 16, 1280, 720])


# UniPose Test

In [9]:
from unipose import UniPose

# Build the full model, then move it onto the selected device before displaying it.
u = UniPose()
u = u.to(device)
u


ernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stat

In [10]:
in_t = torch.randn(1, 3, 960, 720).to(device)

# End-to-end shape check. Nothing here needs gradients, so no_grad() avoids holding
# every activation of a full-model 960x720 pass, and eval() keeps BatchNorm running
# statistics from being updated with random input.
u.eval()
with torch.no_grad():
    out = u(in_t)

print(out.shape)

# NOTE(review): the recorded output is (1, 16, 1280, 720) for a 960x720 input, so the
# output height does not match the input height — confirm the decoder's target size.
assert out.shape[0] == in_t.shape[0]

torch.Size([1, 16, 1280, 720])
