In [1]:
import matplotlib.pyplot as plt
from matplotlib import patches
import torch
import torchvision
from pol.datasets.objdetect import COCODataset, collate_padding_fn
from torch.utils.data import DataLoader, Dataset

In [2]:
class MiniDETR(torch.nn.Module):
    '''
    Image encoder that maps a batch of images to one latent vector per image.

    Pipeline: ResNet-50 trunk (classification head removed) -> 1x1 conv that
    projects the 2048 backbone channels to ``hidden_dim`` -> either

    * a transformer encoder/decoder over the flattened feature map, whose
      per-query decoder outputs are collapsed to a single vector by a learned
      linear layer (``use_transformer=True``), or
    * global average pooling (``use_transformer=False``).

    Args:
        hidden_dim: size D of the returned latent code. Must be even when
            ``use_transformer`` is True, because the positional encoding is
            the concatenation of two ``hidden_dim // 2`` wide embeddings.
        use_transformer: select the transformer head instead of pooling.
        nheads: number of attention heads of the transformer.
        num_encoder_layers: number of transformer encoder layers.
        num_decoder_layers: number of transformer decoder layers.
        num_queries: number of learned decoder queries (previously fixed
            to 100).
        max_pos: largest feature-map height/width for which a learned
            positional embedding exists (previously fixed to 50).

    Raises:
        ValueError: if ``num_queries`` or ``max_pos`` is not positive, or if
            ``hidden_dim`` is odd while ``use_transformer`` is True.
    '''

    def __init__(self, hidden_dim=256,
                 use_transformer=True, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 num_queries=100, max_pos=50):
        super().__init__()

        if num_queries < 1 or max_pos < 1:
            raise ValueError(
                'num_queries and max_pos must be positive, got '
                f'num_queries={num_queries}, max_pos={max_pos}')
        if use_transformer and hidden_dim % 2 != 0:
            # Two (hidden_dim // 2)-wide embeddings are concatenated in
            # forward(); with an odd hidden_dim the result would be one
            # channel short and could not be added to the features.
            raise ValueError(
                'hidden_dim must be even when use_transformer=True, '
                f'got {hidden_dim}')

        self.use_transformer = use_transformer
        # NOTE(review): recent torchvision releases deprecate `pretrained=`
        # in favour of `weights=`; kept unchanged so the notebook still runs
        # on the torchvision version it was written for.
        self.backbone = torchvision.models.resnet50(pretrained=True)
        # The classification head is never used; only the conv trunk is.
        del self.backbone.fc

        # 1x1 conv: 2048 backbone channels -> hidden_dim.
        self.conv = torch.nn.Conv2d(2048, hidden_dim, 1)

        if self.use_transformer:
            self.transformer = torch.nn.Transformer(
                hidden_dim, nheads, num_encoder_layers, num_decoder_layers
            )
            # Learned decoder queries, one row per query.
            self.query_pos = torch.nn.Parameter(
                torch.rand(num_queries, hidden_dim))

            # Learned row / column positional embeddings; each supplies half
            # of the channels of the final positional encoding.
            self.row_embed = torch.nn.Parameter(
                torch.rand(max_pos, hidden_dim // 2))
            self.col_embed = torch.nn.Parameter(
                torch.rand(max_pos, hidden_dim // 2))
            # Learned weighted sum over the query axis -> one vector per image.
            self.final_fc = torch.nn.Linear(num_queries, 1)
        else:
            self.pooling = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, inputs):
        '''
        Args:
            inputs: Bx3xHxW, batched images

        Returns:
            BxD, embedded latent codes

        Raises:
            RuntimeError: if the transformer head is used and the backbone
                feature map is taller or wider than the positional
                embedding tables.
        '''
        # Run the ResNet trunk stage by stage (the fc head was deleted, so
        # the backbone's own forward() cannot be called).
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # x is Bx2048xH'xW'
        h = self.conv(x)  # BxDxH'xW'

        if not self.use_transformer:
            h = self.pooling(h)  # BxDx1x1
            return h.flatten(1)  # BxD

        B, _, H, W = h.shape
        max_h, max_w = self.row_embed.shape[0], self.col_embed.shape[0]
        if H > max_h or W > max_w:
            # Slicing the tables below would silently truncate and then fail
            # inside torch.cat with an unhelpful size-mismatch message.
            raise RuntimeError(
                f'feature map of size {H}x{W} exceeds the positional '
                f'embedding tables ({max_h}x{max_w}); use smaller inputs '
                'or construct the model with a larger max_pos')

        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        # pos: H'W'x1xD, broadcast over the batch axis when added below

        # Features are scaled by 0.1 before the positional term is added.
        h = 0.1 * h.flatten(2).permute(2, 0, 1)  # H'W'xBxD
        queries = self.query_pos.unsqueeze(1).repeat(1, B, 1)  # QxBxD
        h = self.transformer(pos + h, queries)  # QxBxD
        h = self.final_fc(h.permute(1, 2, 0))  # BxDx1
        return h.squeeze(-1)

In [3]:
# Smoke test: push one COCO validation batch through the pooling-only encoder
# and print the shape of the resulting latent codes.
img_encoder = MiniDETR(use_transformer=False)
dataset = COCODataset(max_num_detection=10, split='validation')
dataloader = DataLoader(
    dataset,
    collate_fn=collate_padding_fn,
    batch_size=4,
    drop_last=True,
    shuffle=True,
)

# A single batch is enough to check the output shape, so stop after the first.
for batch_idx, data in enumerate(dataloader):
    batch_img = data['img']
    output = img_encoder(batch_img)
    print(output.shape)
    break

Downloading split 'validation' to '/home/lingxiao/fiftyone/coco-2017/validation' if necessary
Found annotations at '/home/lingxiao/fiftyone/coco-2017/raw/instances_val2017.json'
Sufficient images already downloaded
Existing download of split 'validation' is sufficient
Loading 'coco-2017' split 'validation'
 100% |█████████████████| 349/349 [873.6ms elapsed, 0s remaining, 399.5 samples/s]      
Dataset 'coco-2017-validation' created
torch.Size([4, 256])
