In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import utils
from torchvision import transforms

from PIL import Image 

import os
import numpy as np

In [2]:
def decode(x):
    num_filters = 256
    for i in range(3):
        num_filters = num_filters // pow(2, i)
        x = nn.ConvTranspose2d(in_channels=64,out_channels=num_filters,kernel_size=(4,4),bias=False)(x)
        x = nn.BatchNorm2d(256)(x)
        x = nn.ReLU()(x)
        return x

In [4]:
# Shape probe: push one image through freshly initialised layers laid out roughly
# like ConvCenterNet (backbone -> decoder -> heatmap head). Weights are random,
# so only the resulting shapes are meaningful.
with Image.open("./images/street_small.jpg") as raw_img:
    # Force 3 channels: the first Conv2d expects RGB, but e.g. PNG/grayscale files load as RGBA/L.
    sample_img = raw_img.convert("RGB")
sample_tensor = transforms.ToTensor()(sample_img)  # (C, H, W), floats in [0, 1]
sample_tensor = sample_tensor.unsqueeze(0)  # add batch dim -> (1, C, H, W)
num_classes = 10
print(sample_tensor.shape)

# Only shapes are inspected here, so don't build an autograd graph.
with torch.no_grad():
    x = sample_tensor
    x = nn.Conv2d(3, 32, 5)(x)
    x = nn.Conv2d(32, 64, 5)(x)
    x = decode(x)
    # heatmap head: 3x3 conv -> BN -> ReLU -> 1x1 conv to one channel per class
    x = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=3, bias=False)(x)
    x = nn.BatchNorm2d(64)(x)
    x = nn.ReLU()(x)  # matches the ReLU in ConvCenterNet's hm head
    x = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)(x)
x.shape

torch.Size([1, 3, 480, 480])


TypeError: gather() received an invalid combination of arguments - got (int, int), but expected one of:
 * (name dim, Tensor index, bool sparse_grad)
 * (int dim, Tensor index, bool sparse_grad)


In [None]:
class ConvCenterNet(nn.Module):
    """Small CenterNet-style network: conv backbone, up-conv decoder, three heads.

    All heads read the 256-channel decoder output and share the layout
    3x3 conv (no bias) -> BatchNorm -> ReLU -> 1x1 conv:
      * hm  -> ``num_classes`` channels (heatmap; no sigmoid is applied here)
      * wh  -> 2 channels
      * reg -> 2 channels

    Args:
        num_classes: output channels of the heatmap head.
        input_size: stored as ``self.input_size``; not used by the layers.
        max_objects: stored as ``self.max_objects``; not used by the layers.
    """

    def __init__(self, num_classes=10, input_size=480, max_objects=100):
        super().__init__()
        self.num_classes = num_classes
        self.input_size = input_size
        self.max_objects = max_objects

        self.conv1 = nn.Conv2d(3, 32, 5)
        self.conv2 = nn.Conv2d(32, 64, 5)
        # Registered submodule (not a tensor computed at construction time), so its
        # weights are trained, saved in state_dict and moved by .to()/.cuda().
        self.decoder = self._build_decoder(in_channels=64, num_filters=(256, 256, 256))

        # heatmap header
        self.hm_conv2d_1 = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=3, bias=False)
        self.hm_bn = nn.BatchNorm2d(64)
        self.hm_relu = nn.ReLU()
        self.hm_conv2d_2 = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

        # wh header
        self.wh_conv2d_1 = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=3, bias=False)
        self.wh_bn = nn.BatchNorm2d(64)
        self.wh_relu = nn.ReLU()
        self.wh_conv2d_2 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

        # reg header has the same layout as the wh header
        self.reg_conv2d_1 = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=3, bias=False)
        self.reg_bn = nn.BatchNorm2d(64)
        self.reg_relu = nn.ReLU()
        self.reg_conv2d_2 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

    @staticmethod
    def _build_decoder(in_channels, num_filters, kernel_size=4):
        """Stack one ConvTranspose2d -> BatchNorm2d -> ReLU stage per entry of num_filters."""
        layers = []
        channels = in_channels
        for out_channels in num_filters:
            layers += [
                nn.ConvTranspose2d(channels, out_channels, kernel_size=kernel_size, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
            channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``[hm, wh, reg]`` head outputs for an image batch ``x`` (assumed (N, 3, H, W))."""
        # Nonlinearity between the backbone convs; without it the two stacked
        # convolutions collapse into a single linear filter.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.decoder(x)

        # hm header
        y1 = self.hm_conv2d_1(x)
        y1 = self.hm_bn(y1)
        y1 = self.hm_relu(y1)
        y1 = self.hm_conv2d_2(y1)

        # wh header (reads the shared decoder features)
        y2 = self.wh_conv2d_1(x)
        y2 = self.wh_bn(y2)
        y2 = self.wh_relu(y2)
        y2 = self.wh_conv2d_2(y2)

        # reg header
        y3 = self.reg_conv2d_1(x)
        y3 = self.reg_bn(y3)
        y3 = self.reg_relu(y3)
        y3 = self.reg_conv2d_2(y3)

        return [y1, y2, y3]