In [84]:
import torch
import torch.nn as nn
from models.common import ReOrg, Conv, Concat, SPPCSPC
from models.yolo import IKeypoint

anchors = [
    [19, 27, 44, 40, 38, 94],  # P3/8
    [96, 68, 86, 152, 180, 137],  # P4/16
    [140, 301, 303, 264, 238, 542],  # P5/32
    [436, 615, 739, 380, 925, 792]  # P6/64
]

class YOLOv7(nn.Module):
    """Hand-built YOLOv7-W6-Pose network: ELAN backbone, FPN/PAN head, IKeypoint detector.

    Layer numbering follows yolov7's w6-pose config (the same numbering as the
    ``# 47`` / "Connect from 37" style comments), so checkpoint layer ``i`` maps
    onto this module: layers 0-46 are ``self.backbone``, 47-117 are ``self.head``
    and layer 118 is ``self.detect``.

    The graph is not a plain chain (ELAN branches, lateral routes, a four-input
    detection head), so ``self._layer_from[i]`` records which earlier output(s)
    layer ``i`` consumes, using yolov7's convention: ``-1`` is the previous layer,
    other negative values are offsets relative to layer ``i``, non-negative values
    are absolute layer indices, and a list collects several outputs (for
    ``Concat`` and the detection head).

    Args:
        nc: number of object classes.
        nkpt: number of keypoints per object.
        anchors: one row of (w, h) anchor pairs per output level P3/8, P4/16,
            P5/32, P6/64, in input-image pixels; ``None`` selects
            ``DEFAULT_ANCHORS``.
        depth_multiple: stored only; every block in this architecture has depth 1.
        width_multiple: factor applied to every nominal channel count.
        dw_conv_kpt: forwarded to ``IKeypoint``. Defaults to True to match the
            w6-pose config (NOTE(review): confirm against the checkpoint's
            ``model.118.m_kpt.*`` keys; ``load_model_weights`` warns on mismatch).
        verbose: print every layer's output shape in ``forward``.

    Raises:
        ValueError: if ``anchors`` does not have one row per output level.
    """

    # Anchors (w, h pairs in pixels) for P3/8, P4/16, P5/32, P6/64.
    DEFAULT_ANCHORS = anchors
    # Total downsampling of each detection input, P3 ... P6.
    STRIDES = (8, 16, 32, 64)
    # Layers 0-46 live in self.backbone, 47-117 in self.head.
    N_BACKBONE = 47
    N_HEAD = 71

    def __init__(self, nc=1, nkpt=17, anchors=anchors, depth_multiple=1.0, width_multiple=1.0,
                 dw_conv_kpt=True, verbose=False):
        super(YOLOv7, self).__init__()
        self.nc = nc
        self.nkpt = nkpt
        # None selects the default anchors; the head needs one anchor row per level.
        self.anchors = anchors if anchors is not None else self.DEFAULT_ANCHORS
        if len(self.anchors) != len(self.STRIDES):
            raise ValueError(
                f"expected {len(self.STRIDES)} anchor rows (P3-P6), got {len(self.anchors)}")
        self.depth_multiple = depth_multiple
        self.width_multiple = width_multiple
        self.verbose = verbose

        def c(channels):
            # Scale a nominal channel count by the width multiple.
            return int(channels * width_multiple)

        # Each entry is (from, module); see the class docstring for `from`.
        spec = [
            # 0: ReOrg (space-to-depth; assumed to give 4x the 3 input channels at
            # half resolution, as in yolov7's models/common.py).
            (-1, ReOrg()),
            (-1, Conv(12, c(64), 3, 1)),  # 1-P1/2
        ]

        # Backbone: five stride-2 stages (P2/4 ... P6/64), each a downsampling conv
        # followed by an ELAN block; stage outputs are layers 10, 19, 28, 37, 46.
        c_prev = c(64)
        for c_stage in (128, 256, 512, 768, 1024):
            spec.append((-1, Conv(c_prev, c(c_stage), 3, 2)))
            spec += self._elan(c(c_stage), c(c_stage // 2), c(c_stage))
            c_prev = c(c_stage)

        # Head, top-down path: reduce, upsample x2 and fuse with backbone layers
        # 37 (P5), 28 (P4), 19 (P3); the fused outputs are layers 59, 71, 83.
        spec.append((-1, SPPCSPC(c(1024), c(512), k=(5, 9, 13))))  # 47
        c_prev = c(512)
        for route, c_route, c_lat in ((37, 768, 384), (28, 512, 256), (19, 256, 128)):
            spec += [
                (-1, Conv(c_prev, c(c_lat), 1, 1)),
                (-1, nn.Upsample(scale_factor=2, mode='nearest')),
                (route, Conv(c(c_route), c(c_lat), 1, 1)),
                ([-1, -2], Concat(1)),
            ]
            spec += self._elan_h(2 * c(c_lat), c(c_lat), c(c_lat // 2), c(c_lat))
            c_prev = c(c_lat)

        # Head, bottom-up path: downsample x2 and fuse with head layers 71, 59, 47
        # (each has c_stage channels); the fused outputs are layers 93, 103, 113.
        for route, c_stage in ((71, 256), (59, 384), (47, 512)):
            spec += [
                (-1, Conv(c_prev, c(c_stage), 3, 2)),
                ([-1, route], Concat(1)),
            ]
            spec += self._elan_h(2 * c(c_stage), c(c_stage), c(c_stage // 2), c(c_stage))
            c_prev = c(c_stage)

        # Per-level output convs (layers 114-117) feeding the detection head.
        detect_ch = []
        for route, c_in in ((83, 128), (93, 256), (103, 384), (113, 512)):
            spec.append((route, Conv(c(c_in), c(2 * c_in), 3, 1)))
            detect_ch.append(c(2 * c_in))

        n_layers = len(spec)
        assert n_layers == self.N_BACKBONE + self.N_HEAD, n_layers
        self.backbone = nn.ModuleList([m for _, m in spec[:self.N_BACKBONE]])
        self.head = nn.ModuleList([m for _, m in spec[self.N_BACKBONE:]])
        self.detect = IKeypoint(nc=nc, anchors=self.anchors, nkpt=nkpt, ch=detect_ch,
                                dw_conv_kpt=dw_conv_kpt)
        # The detection head (layer n_layers) consumes the four output convs.
        self._layer_from = [src for src, _ in spec] + [list(range(n_layers - 4, n_layers))]

        # Layers whose outputs are read by something other than the next layer;
        # only these outputs are kept alive during forward().
        self._keep = {i + j if j < 0 else j
                      for i, src in enumerate(self._layer_from)
                      for j in (src if isinstance(src, list) else [src])
                      if j != -1}

        # yolov7's Model sets the head's `stride` after building and rescales the
        # anchors to grid units; IKeypoint presumably needs both to decode in eval
        # mode (mirrors models/yolo.py -- confirm if that file changes). Loading a
        # yolov7 checkpoint afterwards overwrites the anchors with its stored values.
        stride = torch.tensor(self.STRIDES, dtype=torch.float32)
        self.detect.stride = stride
        self.detect.anchors /= stride.view(-1, 1, 1)
        self.stride = stride
        # NOTE(review): yolov7's Model also runs _initialize_biases_kpt(); that only
        # matters when training from scratch, not when loading a checkpoint.

    @staticmethod
    def _elan(c_in, c_mid, c_out):
        """Backbone ELAN block as 8 (from, module) entries; the last one is its output.

        Two 1x1 convs read the block input (the second via from=-2), four 3x3
        convs extend the second one, and the 1x1 outputs plus the 2nd and 4th
        3x3 outputs (4 * c_mid channels) are concatenated and projected to c_out.
        """
        return [
            (-1, Conv(c_in, c_mid, 1, 1)),
            (-2, Conv(c_in, c_mid, 1, 1)),
            (-1, Conv(c_mid, c_mid, 3, 1)),
            (-1, Conv(c_mid, c_mid, 3, 1)),
            (-1, Conv(c_mid, c_mid, 3, 1)),
            (-1, Conv(c_mid, c_mid, 3, 1)),
            ([-1, -3, -5, -6], Concat(1)),
            (-1, Conv(4 * c_mid, c_out, 1, 1)),
        ]

    @staticmethod
    def _elan_h(c_in, c_mid, c_hid, c_out):
        """Head ELAN-H block as 8 (from, module) entries; the last one is its output.

        Like ``_elan`` but the 3x3 chain narrows to c_hid and all six branch
        outputs (2 * c_mid + 4 * c_hid channels) are concatenated.
        """
        return [
            (-1, Conv(c_in, c_mid, 1, 1)),
            (-2, Conv(c_in, c_mid, 1, 1)),
            (-1, Conv(c_mid, c_hid, 3, 1)),
            (-1, Conv(c_hid, c_hid, 3, 1)),
            (-1, Conv(c_hid, c_hid, 3, 1)),
            (-1, Conv(c_hid, c_hid, 3, 1)),
            ([-1, -2, -3, -4, -5, -6], Concat(1)),
            (-1, Conv(2 * c_mid + 4 * c_hid, c_out, 1, 1)),
        ]

    def forward(self, x):
        """Run the network, routing each layer's inputs per ``self._layer_from``.

        Args:
            x: image batch, assumed (B, 3, H, W). H and W should be multiples of
               64: ReOrg halves the resolution and five stride-2 convs follow, and
               the upsampled maps must line up with the routed ones for ``Concat``.

        Returns:
            Whatever ``self.detect`` returns. With yolov7's IKeypoint this is
            presumably a ``(predictions, feature_maps)`` tuple in eval mode and the
            list of feature maps in training mode -- confirm against models/yolo.py.
        """
        if self.verbose:
            print(f"Input shape: {x.shape}")

        layers = [*self.backbone, *self.head, self.detect]
        outputs = []
        for idx, (src, layer) in enumerate(zip(self._layer_from, layers)):
            if isinstance(src, list):
                x = [x if j == -1 else outputs[j] for j in src]
            elif src != -1:
                x = outputs[src]
            x = layer(x)
            # Keep only outputs that a later, non-adjacent layer reads.
            outputs.append(x if idx in self._keep else None)
            if self.verbose:
                shape = tuple(x.shape) if torch.is_tensor(x) else type(x).__name__
                print(f"After layer {idx} ({type(layer).__name__}): {shape}")

        return x

In [None]:
import torch
import torch.nn as nn
from collections import OrderedDict
import torch.onnx

def _extract_state_dict(checkpoint):
    """Return the state dict held by a checkpoint in any supported layout."""
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    if 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    if 'model' in checkpoint:
        inner = checkpoint['model']
        return inner.state_dict() if isinstance(inner, nn.Module) else inner
    return checkpoint


def _split_layout(model):
    """Return (n_backbone, n_head) if `model` has backbone/head ModuleLists and a detect module, else None."""
    backbone = getattr(model, 'backbone', None)
    head = getattr(model, 'head', None)
    detect = getattr(model, 'detect', None)
    if isinstance(backbone, nn.ModuleList) and isinstance(head, nn.ModuleList) \
            and isinstance(detect, nn.Module):
        return len(backbone), len(head)
    return None


def _remap_sequential_key(name, n_backbone, n_head):
    """Map a `model.<i>.<rest>` key onto backbone/head/detect keys; None for any other form."""
    root, _, rest = name.partition('.')
    index, _, tail = rest.partition('.')
    if root != 'model' or not index.isdigit() or not tail:
        return None
    i = int(index)
    if i < n_backbone:
        return f'backbone.{i}.{tail}'
    if i < n_backbone + n_head:
        return f'head.{i - n_backbone}.{tail}'
    if i == n_backbone + n_head:
        return f'detect.{tail}'
    return None


def load_model_weights(model_path, model, strict=False):
    """Load a checkpoint file into `model` and report any keys that did not match.

    Supported checkpoint layouts: ``{'state_dict': sd}``, ``{'model': module_or_sd}``
    (yolov7 training checkpoints pickle the whole model there), a bare state dict,
    or a pickled ``nn.Module``. A leading ``module.`` (added by ``nn.DataParallel``)
    is stripped from every key.

    If `model` has ``backbone``/``head`` ``nn.ModuleList``s and a ``detect`` module,
    keys of the form ``model.<i>.<rest>`` that are not already keys of `model` are
    remapped: the first ``len(backbone)`` layers go to ``backbone``, the next
    ``len(head)`` to ``head`` and the following one to ``detect``. This is the
    layout produced by yolov7's single ``nn.Sequential`` model.

    Args:
        model_path: path to the checkpoint file.
        model: module to load the weights into (modified in place).
        strict: forwarded to ``nn.Module.load_state_dict``.

    Returns:
        The ``(missing_keys, unexpected_keys)`` named tuple from ``load_state_dict``.

    Warns:
        UserWarning if, with ``strict=False``, some keys were missing or unexpected.
    """
    import warnings

    # SECURITY: yolov7 checkpoints pickle the full model object, which
    # weights_only=True rejects. Unpickling can execute arbitrary code, so only
    # load checkpoint files you trust.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    state_dict = _extract_state_dict(checkpoint)

    target_keys = set(model.state_dict())
    layout = _split_layout(model)

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        name = key[len('module.'):] if key.startswith('module.') else key
        if layout is not None and name not in target_keys:
            name = _remap_sequential_key(name, *layout) or name
        new_state_dict[name] = value

    result = model.load_state_dict(new_state_dict, strict=strict)
    # With strict=False mismatches are otherwise silent, which hides a load that
    # matched nothing at all.
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"load_model_weights({model_path!r}): {len(result.missing_keys)} missing keys "
            f"(e.g. {result.missing_keys[:3]}), {len(result.unexpected_keys)} unexpected keys "
            f"(e.g. {result.unexpected_keys[:3]})")
    return result

# Export configuration.
WEIGHTS_PATH = 'yolov7-w6-pose.pt'
ONNX_PATH = 'yolov7-w6-pose.onnx'
IMG_SIZE = 640  # multiple of 64: the network downsamples by 64 overall (P6/64)


class _PoseOutputSplit(nn.Module):
    """Wrap a YOLOv7-pose model so the ONNX graph has separate detection and keypoint outputs.

    Assumes the wrapped model's eval-mode output is either a predictions tensor or
    a tuple/list whose first item is one, with the last dim laid out as
    ``[x, y, w, h, obj, cls_1..cls_nc, kpt_1_x, kpt_1_y, kpt_1_conf, ...]`` --
    presumably yolov7 IKeypoint's inference output; confirm against models/yolo.py.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.n_det = 5 + net.nc  # box (4) + objectness (1) + class scores (nc)

    def forward(self, x):
        out = self.net(x)
        # Keep only the decoded predictions; drop any raw per-level feature maps.
        pred = out[0] if isinstance(out, (tuple, list)) else out
        return pred[..., :self.n_det], pred[..., self.n_det:]


model = YOLOv7()

load_model_weights(WEIGHTS_PATH, model)

model.eval()

export_model = _PoseOutputSplit(model).eval()
dummy_input = torch.zeros(1, 3, IMG_SIZE, IMG_SIZE)

torch.onnx.export(export_model,
                  dummy_input,
                  ONNX_PATH,
                  export_params=True,
                  opset_version=12,
                  do_constant_folding=True,
                  input_names=['input'],
                  # One name per tensor returned by _PoseOutputSplit.forward.
                  output_names=['detections', 'keypoints'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'detections': {0: 'batch_size'},
                                'keypoints': {0: 'batch_size'}})