In [None]:
import math
import os

import torch
import torch.nn.functional as F
from torch import nn, Tensor
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from lib.models.S2 import SR, SP
from lib.models.transformer import build_transformer
from lib.models.backbone import *
from lib.models.pose_transformer import MLP

# cuDNN settings: benchmark=True lets cuDNN autotune convolution algorithms for
# speed, so results are not bit-for-bit reproducible between runs.
cudnn.benchmark = True
# Correctly spelled: the previous `cudnn.determinstic` only created an unused attribute.
cudnn.deterministic = False
cudnn.enabled = True

# Shared width of the transformer and the position encoding; they must agree.
HIDDEN_DIM = 256

transformer = build_transformer(hidden_dim=HIDDEN_DIM, dropout=0.0, nheads=8, dim_feedforward=2048,
                                enc_layers=6, dec_layers=6, pre_norm=False,
                                num_clusters=9, use_sr=True)
resnet = ResNetBackbone('resnet50', train_backbone=True, return_interm_layers=True,
                        pretrained=True, dilation=False)
position_embedding = build_position_encoding(HIDDEN_DIM, 'sine')
# `backbone` pairs the ResNet features with their positional encodings.
backbone = Joiner(resnet, position_embedding)


In [None]:
# Smoke test: run a dummy image batch through the backbone to check that it builds
# and executes. The previous cell relied on an `x` that was never defined.
# NOTE(review): 1x3x256x256 is an assumed input size; match it to the dataset's crop size.
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():  # inspection only, so don't build an autograd graph
    src, pos = backbone(x)

In [None]:
class PoseTransformer(nn.Module):
    """DETR-style landmark detector: backbone -> transformer -> per-query heads.

    Each learned query is mapped to ``num_classes + 1`` class logits and a
    2-value coordinate passed through a sigmoid.

    Args:
        backbone: module returning ``(src, pos)`` lists when called on an image
            batch; must expose an indexable ``num_channels`` whose entry ``i``
            is the channel count of level ``i``.
        transformer: module exposing ``d_model`` and ``decoder.num_layers``,
            called as ``transformer(src, mask, query_embed, pos)``.
        **kwargs:
            extra: optional config node providing ``AUX_LOSS`` and
                ``NUM_FEATURE_LEVELS``; missing values default to ``False``/``1``.
            num_queries: number of learned queries (default 68).
            num_classes: number of classes (default 68).
    """

    def __init__(self, backbone, transformer, **kwargs):
        super(PoseTransformer, self).__init__()
        # Config used to come from an undefined global `extra`; now passed explicitly.
        extra = kwargs.get('extra')
        self.num_queries = kwargs.get('num_queries', 68)
        self.num_classes = kwargs.get('num_classes', 68)
        self.transformer = transformer
        self.backbone = backbone
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, self.num_classes + 1)
        # MLP head with output size 2 (presumably an (x, y) pair per query).
        self.kpt_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        self.query_embed = nn.Embedding(self.num_queries, hidden_dim)
        self.aux_loss = getattr(extra, 'AUX_LOSS', False)

        self.num_feature_levels = getattr(extra, 'NUM_FEATURE_LEVELS', 1)
        if self.num_feature_levels > 1:
            self.num_backbone_outs = len(backbone.num_channels)
            input_proj_list = []
            # One 1x1 projection per backbone level.
            for in_channels in backbone.num_channels:
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            # Extra levels are stride-2 convolutions stacked on the deepest level.
            for _ in range(self.num_feature_levels - self.num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
            # Initialization: classification bias set from a 1% prior,
            # keypoint head's last layer zeroed.
            prior_prob = 0.01
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            nn.init.constant_(self.class_embed.bias, bias_value)
            nn.init.constant_(self.kpt_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.kpt_embed.layers[-1].bias.data, 0)
            for proj in self.input_proj:
                nn.init.xavier_uniform_(proj[0].weight, gain=1)
                nn.init.constant_(proj[0].bias, 0)
            num_pred = transformer.decoder.num_layers  # number of decoder layers
            # One prediction head per decoder layer. Every entry is the same
            # module object, so the weights are shared across layers.
            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
            self.kpt_embed = nn.ModuleList([self.kpt_embed for _ in range(num_pred)])
        else:
            # Project the level that forward() actually consumes (src[-1]);
            # num_channels[0] mismatched it whenever several levels are returned.
            self.input_proj = nn.Conv2d(self.backbone.num_channels[-1], hidden_dim, 1)

    def forward(self, x):
        """Predict class logits and coordinates for every query.

        Args:
            x: image batch passed straight to the backbone
                (presumably (B, 3, H, W); TODO confirm).

        Returns:
            dict with ``'pred_logits'`` and ``'pred_coords'`` taken from the last
            entry of the transformer output, plus ``'aux_outputs'`` (one dict per
            earlier entry) when ``self.aux_loss`` is set.
        """
        src, pos = self.backbone(x)
        # Only the deepest backbone level is fed to the transformer.
        hs = self.transformer(self._project(src[-1]), None,
                              self.query_embed.weight, pos[-1])

        outputs_class, outputs_coord = self._predict(hs)
        out = {
            'pred_logits': outputs_class[-1],
            'pred_coords': outputs_coord[-1],
        }
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    def _project(self, feat):
        """Project the deepest backbone feature map to ``hidden_dim`` channels."""
        if isinstance(self.input_proj, nn.ModuleList):
            # Entry i was built for backbone level i, so the deepest backbone
            # level uses the last backbone-level projection, not the appended
            # downsampling ones.
            return self.input_proj[self.num_backbone_outs - 1](feat)
        return self.input_proj(feat)

    def _predict(self, hs):
        """Apply the class and keypoint heads to every slice of ``hs`` along dim 0."""
        if isinstance(self.class_embed, nn.ModuleList):
            # Per-layer heads: pair head i with hs[i] (presumably decoder layer i).
            classes = torch.stack([head(h) for head, h in zip(self.class_embed, hs)])
            coords = torch.stack([head(h) for head, h in zip(self.kpt_embed, hs)]).sigmoid()
            return classes, coords
        return self.class_embed(hs), self.kpt_embed(hs).sigmoid()

    @torch.jit.unused
    def _set_aux_loss(self,
                      outputs_class,
                      outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_coords': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

In [None]:
# Build the model from the `backbone` / `transformer` created in the first cell.
# `extra` is the same config node the loss weights below are read from, so the
# model and the criterion agree on whether auxiliary outputs exist.
model = PoseTransformer(backbone, transformer, extra=config.MODEL.EXTRA)

matcher = build_matcher(config.MODEL.NUM_JOINTS)
weight_dict = {'loss_ce': 1, 'loss_kpts': config.MODEL.EXTRA.KPT_LOSS_COEF}
if config.MODEL.EXTRA.AUX_LOSS:
    # One suffixed copy of every loss weight per intermediate decoder layer
    # (all but the last), e.g. 'loss_ce_0'.
    aux_weight_dict = {
        f'{k}_{i}': v
        for i in range(config.MODEL.EXTRA.DEC_LAYERS - 1)
        for k, v in weight_dict.items()
    }
    weight_dict.update(aux_weight_dict)

# Read `num_classes` from the bare module, before wrapping: nn.DataParallel
# does not forward arbitrary attributes of the wrapped model.
criterion = SetCriterion(model.num_classes, matcher, weight_dict,
                         config.MODEL.EXTRA.EOS_COEF,
                         ['labels', 'kpts', 'cardinality']).cuda()

# Wrap exactly once. Nesting DataParallel would make `model.module` another
# wrapper, and the checkpoint load below would target the wrong object.
gpus = list(config.GPUS)
model = nn.DataParallel(model, device_ids=gpus).cuda()

optimizer = utils.get_optimizer(config, model)
best_nme = 100
last_epoch = config.TRAIN.BEGIN_EPOCH
if config.TRAIN.RESUME:
    model_state_file = os.path.join(final_output_dir, 'latest.pth')
    if os.path.isfile(model_state_file):
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        checkpoint = torch.load(model_state_file)
        last_epoch = checkpoint['epoch']
        best_nme = checkpoint['best_nme']
        # The 'state_dict' entry is used as a wrapped model object (it has
        # `.module`), not a plain state dict; presumably that is how it was
        # saved. Confirm against the save code.
        model.module.load_state_dict(checkpoint['state_dict'].module.state_dict())
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint (epoch {})"
              .format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found")

# `last_epoch - 1` keeps the schedule in step with the resumed epoch; when
# BEGIN_EPOCH is 0 this is the scheduler's default of -1.
if isinstance(config.TRAIN.LR_STEP, list):
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=config.TRAIN.LR_STEP,
        gamma=config.TRAIN.LR_FACTOR, last_epoch=last_epoch - 1)
else:
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=config.TRAIN.LR_STEP,
        gamma=config.TRAIN.LR_FACTOR, last_epoch=last_epoch - 1)

dataset_type = get_dataset(config)

train_loader = DataLoader(
    dataset=dataset_type(config, is_train=True),
    batch_size=config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus),
    shuffle=config.TRAIN.SHUFFLE,
    num_workers=config.WORKERS,
    pin_memory=config.PIN_MEMORY)

val_loader = DataLoader(
    dataset=dataset_type(config, is_train=False),
    batch_size=config.TEST.BATCH_SIZE_PER_GPU * len(gpus),
    shuffle=False,
    num_workers=config.WORKERS,
    pin_memory=config.PIN_MEMORY)