In [53]:
import torch
import torch.nn as nn
import timm
from PIL import Image
from torchvision import transforms
from torch.cuda.amp import autocast
from torchvision.ops import RoIAlign

def get_roi_align_layer(output_size=(4, 4), spatial_scale=0.0625, sampling_ratio=0):
    """Build a torchvision ``RoIAlign`` module from the given settings.

    Args:
        output_size: Pooled output size forwarded to ``RoIAlign``.
        spatial_scale: Scale factor forwarded to ``RoIAlign``.
        sampling_ratio: Sampling setting forwarded to ``RoIAlign``.

    Returns:
        The constructed ``RoIAlign`` layer.
    """
    # NOTE(review): the default 0.0625 equals 1/16; confirm it matches the
    # stride of the feature map this layer is actually applied to.
    return RoIAlign(
        output_size=output_size,
        spatial_scale=spatial_scale,
        sampling_ratio=sampling_ratio,
    )

def preprocess_image(image_path):
    """Load an image file and turn it into a normalized single-image batch.

    The image is forced to 3-channel RGB before the transforms run because
    ``Normalize`` below is configured with three per-channel statistics,
    while PNG files may be stored as RGBA, palette or grayscale.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        Tensor of shape ``(1, 3, 224, 224)``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # The context manager closes the underlying file once the pixels have
    # been decoded by convert().
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)
    
class JigsawViT(nn.Module):
    """PiT-S (distilled) feature backbone, RoIAlign pooling and a linear head.

    The last feature map returned by the timm backbone is pooled per region
    with RoIAlign, flattened, and passed through three linear layers.

    NOTE(review): fc1, fc2 and fc3 are applied back-to-back with no
    activation in between; confirm that this is intended.
    NOTE(review): the RoIAlign layer is built with its default spatial scale
    (1/16); confirm it matches the stride of the last backbone feature map.
    """

    def __init__(self, pretrained_cfg_file, num_labels=2):
        """Create the backbone, the RoIAlign layer and the head.

        Args:
            pretrained_cfg_file: Checkpoint path stored under the ``'file'``
                key of the pretrained config handed to timm.
            num_labels: Output width of the final linear layer.
        """
        super(JigsawViT, self).__init__()
        # A first model instance is created only to read its default config.
        pretrained_cfg = timm.models.create_model(
            'pit_s_distilled_224.in1k').default_cfg
        print(pretrained_cfg)
        pretrained_cfg['file'] = pretrained_cfg_file
        self.backbone = timm.models.create_model(
            'pit_s_distilled_224.in1k', pretrained=True, features_only=True, pretrained_cfg=pretrained_cfg)
        self.roi_align = get_roi_align_layer()
        # 9216 = 576 * 4 * 4; presumably 576 channels of the last feature map
        # times the default 4x4 RoIAlign output -- confirm against the backbone.
        self.fc1 = nn.Linear(9216, 2048)
        self.fc2 = nn.Linear(2048, 512)
        self.fc3 = nn.Linear(512, num_labels)

    def forward(self, x, roi):
        """Classify the regions ``roi`` of the image batch ``x``.

        Args:
            x: Image batch fed to the backbone.
            roi: Boxes handed to RoIAlign unchanged.
                NOTE(review): presumably a float tensor with rows
                ``(batch_index, x1, y1, x2, y2)`` -- confirm.

        Returns:
            Output of the last linear layer, one row per pooled region.
        """
        features = self.backbone(x)[-1]
        # Use the boxes passed by the caller; the previous code read a
        # notebook-level variable named ``rois`` and ignored this argument.
        aligned_features = self.roi_align(features, roi)
        flat = aligned_features.view(aligned_features.size(0), -1)
        hidden = self.fc2(self.fc1(flat))
        return self.fc3(hidden)

In [54]:
# Smoke test: push one preprocessed image and one box through JigsawViT.
image = preprocess_image("dataset/MIT_ex/fragment_0001.png")
# Checkpoint path; JigsawViT stores it under the 'file' key of the timm
# pretrained config. NOTE(review): absolute, machine-specific path.
pretrained_cfg_file='/data/csl/code/piece/models/pit_s-distilled_224/model.safetensors'
jit = JigsawViT(pretrained_cfg_file, 2)
# One box described by five floats -- presumably
# (batch_index, x1, y1, x2, y2); confirm against RoIAlign's expected format.
rois = torch.tensor([[0, 60, 30, 180, 150]], dtype=torch.float32)
output = jit(image, rois)
print(output.shape)

{'url': '', 'hf_hub_id': 'timm/pit_s_distilled_224.in1k', 'architecture': 'pit_s_distilled_224', 'tag': 'in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'fixed_input_size': True, 'interpolation': 'bicubic', 'crop_pct': 0.9, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': None, 'first_conv': 'patch_embed.conv', 'classifier': ('head', 'head_dist')}
torch.Size([1, 2])


In [45]:
import torch
import timm
from torchvision.ops import RoIAlign



def load_pit_backbone(cfg_file=None):
    """Create the PiT-S (distilled) timm model in ``features_only`` mode.

    Args:
        cfg_file: Checkpoint path stored under the ``'file'`` key of the
            pretrained config handed to timm. When ``None`` (the default)
            the notebook-level ``pretrained_cfg_file`` variable is used,
            which is what the function always did before this parameter
            existed.

    Returns:
        The timm feature-extraction model.
    """
    if cfg_file is None:
        # Backward-compatible fallback to the notebook-level variable.
        cfg_file = pretrained_cfg_file
    # A first model instance is created only to read its default config; the
    # dict is copied so that instance's own config is left untouched.
    pretrained_cfg = dict(timm.models.create_model(
        'pit_s_distilled_224.in1k').default_cfg)
    pretrained_cfg['file'] = cfg_file
    return timm.models.create_model(
        'pit_s_distilled_224.in1k', pretrained=True, features_only=True, pretrained_cfg=pretrained_cfg)

def get_roi_align_layer(output_size=(4, 4), spatial_scale=0.0625, sampling_ratio=0):
    """Build a torchvision ``RoIAlign`` module from the given settings.

    Args:
        output_size: Pooled output size forwarded to ``RoIAlign``.
        spatial_scale: Scale factor forwarded to ``RoIAlign``.
        sampling_ratio: Sampling setting forwarded to ``RoIAlign``.

    Returns:
        The constructed ``RoIAlign`` layer.
    """
    # NOTE(review): the default 0.0625 equals 1/16; confirm it matches the
    # stride of the feature map this layer is actually applied to.
    return RoIAlign(
        output_size=output_size,
        spatial_scale=spatial_scale,
        sampling_ratio=sampling_ratio,
    )

def preprocess_image(image_path):
    """Load an image file and turn it into a normalized single-image batch.

    The image is forced to 3-channel RGB before the transforms run because
    ``Normalize`` below is configured with three per-channel statistics,
    while PNG files may be stored as RGBA, palette or grayscale.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        Tensor of shape ``(1, 3, 224, 224)``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # The context manager closes the underlying file once the pixels have
    # been decoded by convert().
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)
def extract_features(backbone, image):
    """Call ``backbone`` on ``image`` and return the last item of its output."""
    feature_maps = backbone(image)
    return feature_maps[-1]
def extract_roi_features(roi_align, features, rois):
    """Apply ``roi_align`` to ``features`` and ``rois`` and return its result."""
    return roi_align(features, rois)

# Step-by-step walkthrough: backbone features -> RoIAlign -> linear layer,
# printing the tensor shape after each stage.
backbone = load_pit_backbone()
roi_align = get_roi_align_layer()

image = preprocess_image("dataset/MIT_ex/fragment_0001.png")
features = extract_features(backbone, image)
print(features.shape)

# Hypothetical ROI
# NOTE(review): presumably (batch_index, x1, y1, x2, y2) -- confirm.
rois = torch.tensor([[0, 60, 30, 180, 150]], dtype=torch.float32)

aligned_features = extract_roi_features(roi_align, features, rois)
# Flatten each pooled region into a single vector per row.
aligned_features = aligned_features.view(aligned_features.size(0), -1)
print(aligned_features.shape)
# Freshly initialised 2-output linear layer, used only to check shapes.
# NOTE(review): `nn` is not imported among this cell's import lines; it has
# to be in the kernel namespace already.
fc = nn.Linear(aligned_features.shape[1], 2)
f = fc(aligned_features)
print(f.shape)




torch.Size([1, 576, 7, 7])
torch.Size([1, 9216])
torch.Size([1, 2])


In [58]:
import torch
import torch.nn as nn
import timm
from torch.cuda.amp import autocast
from torchvision.ops import RoIAlign

def get_roi_align_layer(output_size=(4, 4), spatial_scale=0.0625, sampling_ratio=0):
    """Build a torchvision ``RoIAlign`` module from the given settings.

    Args:
        output_size: Pooled output size forwarded to ``RoIAlign``.
        spatial_scale: Scale factor forwarded to ``RoIAlign``.
        sampling_ratio: Sampling setting forwarded to ``RoIAlign``.

    Returns:
        The constructed ``RoIAlign`` layer.
    """
    # NOTE(review): the default 0.0625 equals 1/16; confirm it matches the
    # stride of the feature map this layer is actually applied to.
    return RoIAlign(
        output_size=output_size,
        spatial_scale=spatial_scale,
        sampling_ratio=sampling_ratio,
    )

def preprocess_image(image_path):
    """Load an image file and turn it into a normalized single-image batch.

    The image is forced to 3-channel RGB before the transforms run because
    ``Normalize`` below is configured with three per-channel statistics,
    while PNG files may be stored as RGBA, palette or grayscale.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        Tensor of shape ``(1, 3, 224, 224)``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # The context manager closes the underlying file once the pixels have
    # been decoded by convert().
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)


import torch
from einops import rearrange
from torch import nn
import math

from functools import partial
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block as transformer_block
from timm.models.registry import register_model

class Transformer(nn.Module):
    """A stack of timm ViT blocks run on a feature map plus extra tokens.

    The ``(b, c, h, w)`` feature map is flattened into a token sequence, the
    class tokens are prepended, every block is applied in order, and the
    sequence is split back into a feature map and class tokens.
    """

    def __init__(self, base_dim, depth, heads, mlp_ratio,
                 drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
        """Build ``depth`` blocks of width ``base_dim * heads``.

        Args:
            base_dim: Multiplied by ``heads`` to get the block width.
            depth: Number of blocks.
            heads: Number of attention heads per block.
            mlp_ratio: Passed to each block as ``mlp_ratio``.
            drop_rate: Passed to each block as ``proj_drop``.
            attn_drop_rate: Passed to each block as ``attn_drop``.
            drop_path_prob: Per-block ``drop_path`` values, indexed by block
                position; ``None`` means 0.0 for every block.
        """
        super().__init__()
        # Empty container kept so the module's attributes stay the same.
        self.layers = nn.ModuleList([])
        width = base_dim * heads

        if drop_path_prob is not None:
            rates = drop_path_prob
        else:
            rates = [0.0 for _ in range(depth)]

        stack = []
        for idx in range(depth):
            stack.append(
                transformer_block(
                    dim=width,
                    num_heads=heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    proj_drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=rates[idx],
                    norm_layer=partial(nn.LayerNorm, eps=1e-6)
                )
            )
        self.blocks = nn.ModuleList(stack)

    def forward(self, x, cls_tokens):
        """Return ``(feature_map, cls_tokens)`` after all blocks have run."""
        height, width = x.shape[2:4]
        n_cls = cls_tokens.shape[1]

        # Flatten the spatial grid and put the class tokens in front.
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        for block in self.blocks:
            tokens = block(tokens)

        # Split the sequence again and restore the spatial layout.
        out_cls = tokens[:, :n_cls]
        out_map = rearrange(tokens[:, n_cls:], 'b (h w) c -> b c h w',
                            h=height, w=width)

        return out_map, out_cls


class conv_head_pooling(nn.Module):
    """Strided grouped conv for the feature map, linear layer for the tokens.

    Both branches map ``in_feature`` channels to ``out_feature`` channels.
    """

    def __init__(self, in_feature, out_feature, stride,
                 padding_mode='zeros'):
        """Create the conv (kernel ``stride + 1``, padding ``stride // 2``,
        one group per input channel) and the linear projection."""
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=in_feature,
            out_channels=out_feature,
            kernel_size=stride + 1,
            stride=stride,
            padding=stride // 2,
            groups=in_feature,
            padding_mode=padding_mode,
        )
        self.fc = nn.Linear(in_features=in_feature, out_features=out_feature)

    def forward(self, x, cls_token):
        """Return ``(conv(x), fc(cls_token))``."""
        pooled_map = self.conv(x)
        projected_tokens = self.fc(cls_token)
        return pooled_map, projected_tokens


class conv_embedding(nn.Module):
    """Patch embedding implemented as a single biased 2-D convolution."""

    def __init__(self, in_channels, out_channels, patch_size,
                 stride, padding):
        """Create the conv with kernel size ``patch_size``."""
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, patch_size,
                              stride, padding, bias=True)

    def forward(self, x):
        """Return the convolution of ``x``."""
        return self.conv(x)


class PoolingTransformer(nn.Module):
    """Multi-stage transformer with conv pooling between stages.

    A conv patch embedding plus a learned positional map feeds a sequence of
    ``Transformer`` stages; every stage except the last is followed by a
    ``conv_head_pooling`` layer. The layer-normalised class token is passed
    to the classification head.
    """

    def __init__(self, image_size, patch_size, stride, base_dims, depth, heads,
                 mlp_ratio, num_classes=1000, in_chans=3,
                 attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
        """Build all stages.

        Args:
            image_size: Input side length used to size the positional map.
            patch_size: Kernel size of the patch-embedding conv.
            stride: Stride of the patch-embedding conv.
            base_dims: Per-stage base dims; stage width is
                ``base_dims[i] * heads[i]``.
            depth: Number of blocks in each stage.
            heads: Number of attention heads in each stage.
            mlp_ratio: Forwarded to every ``Transformer`` stage.
            num_classes: Head output size; ``<= 0`` gives an identity head.
            in_chans: Input channels of the patch-embedding conv.
            attn_drop_rate: Forwarded to every stage.
            drop_rate: Positional dropout rate, also forwarded to every stage.
            drop_path_rate: Scaled by block index over total block count to
                get each block's drop-path value.
        """
        super().__init__()

        padding = 0
        total_block = sum(depth)

        # Side length of the grid produced by the patch-embedding conv.
        grid = math.floor(
            (image_size + 2 * padding - patch_size) / stride + 1)
        stem_dim = base_dims[0] * heads[0]

        self.base_dims = base_dims
        self.heads = heads
        self.num_classes = num_classes
        self.patch_size = patch_size

        self.pos_embed = nn.Parameter(
            torch.randn(1, stem_dim, grid, grid),
            requires_grad=True
        )
        self.patch_embed = conv_embedding(in_chans, stem_dim,
                                          patch_size, stride, padding)
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, stem_dim),
            requires_grad=True
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.transformers = nn.ModuleList([])
        self.pools = nn.ModuleList([])

        first_block = 0
        last_stage = len(heads) - 1
        for stage, n_blocks in enumerate(depth):
            # Drop-path value grows with the global block index.
            stage_drop_path = [
                drop_path_rate * i / total_block
                for i in range(first_block, first_block + n_blocks)
            ]
            first_block += n_blocks

            self.transformers.append(
                Transformer(base_dims[stage], n_blocks, heads[stage],
                            mlp_ratio,
                            drop_rate, attn_drop_rate, stage_drop_path)
            )
            # Every stage but the last is followed by a pooling layer.
            if stage < last_stage:
                self.pools.append(
                    conv_head_pooling(
                        base_dims[stage] * heads[stage],
                        base_dims[stage + 1] * heads[stage + 1],
                        stride=2
                    )
                )

        final_dim = base_dims[-1] * heads[-1]
        self.norm = nn.LayerNorm(final_dim, eps=1e-6)
        self.embed_dim = final_dim

        # Classifier head
        self.head = (nn.Linear(final_dim, num_classes)
                     if num_classes > 0 else nn.Identity())

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Reset LayerNorm modules to weight 1 / bias 0; ignore the rest."""
        if not isinstance(m, nn.LayerNorm):
            return
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names ``'pos_embed'`` and ``'cls_token'``."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a new one for ``num_classes`` outputs.

        ``global_pool`` is accepted but not used.
        """
        self.num_classes = num_classes
        self.head = (nn.Linear(self.embed_dim, num_classes)
                     if num_classes > 0 else nn.Identity())

    def forward_features(self, x):
        """Return the layer-normalised class tokens for the batch ``x``."""
        x = self.patch_embed(x)
        x = self.pos_drop(x + self.pos_embed)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)

        # Each pooled stage: transformer first, then its pooling layer.
        for transformer, pool in zip(self.transformers, self.pools):
            x, cls_tokens = transformer(x, cls_tokens)
            x, cls_tokens = pool(x, cls_tokens)
        x, cls_tokens = self.transformers[-1](x, cls_tokens)

        return self.norm(cls_tokens)

    def forward(self, x):
        """Apply the head to the first class token."""
        tokens = self.forward_features(x)
        return self.head(tokens[:, 0])


class DistilledPoolingTransformer(PoolingTransformer):
    """``PoolingTransformer`` with two class tokens and a second head.

    Token 0 goes to ``head`` and token 1 to ``head_dist``.
    """

    def __init__(self, *args, **kwargs):
        """Build the base model, then swap in a 2-token ``cls_token`` and
        add ``head_dist``."""
        super().__init__(*args, **kwargs)
        stem_dim = self.base_dims[0] * self.heads[0]
        final_dim = self.base_dims[-1] * self.heads[-1]

        # Replaces the single-token parameter created by the base class.
        self.cls_token = nn.Parameter(
            torch.randn(1, 2, stem_dim),
            requires_grad=True)
        self.head_dist = (nn.Linear(final_dim, self.num_classes)
                          if self.num_classes > 0 else nn.Identity())

        trunc_normal_(self.cls_token, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward(self, x):
        """Return both head outputs in training mode, their mean otherwise."""
        tokens = self.forward_features(x)
        x_cls = self.head(tokens[:, 0])
        x_dist = self.head_dist(tokens[:, 1])
        if not self.training:
            return (x_cls + x_dist) / 2
        return x_cls, x_dist

def _create_pit(model_cls, checkpoint, pretrained, **model_kwargs):
    """Instantiate a PiT variant for 224x224 input and optionally load weights.

    Args:
        model_cls: ``PoolingTransformer`` or ``DistilledPoolingTransformer``.
        checkpoint: Path of the state-dict file read when ``pretrained`` is
            truthy.
        pretrained: Whether to load ``checkpoint`` into the new model.
        **model_kwargs: Architecture settings and any extra constructor
            arguments from the caller, forwarded to ``model_cls``.

    Returns:
        The constructed model.
    """
    model = model_cls(image_size=224, mlp_ratio=4, **model_kwargs)
    if pretrained:
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(checkpoint, map_location='cpu')
        model.load_state_dict(state_dict)
    return model


@register_model
def pit_b(pretrained=False, **kwargs):
    """PiT-B: patch 14, stride 7, base dims 64, depth 3-6-4, heads 4-8-16."""
    return _create_pit(
        PoolingTransformer, 'weights/pit_b_820.pth', pretrained,
        patch_size=14, stride=7, base_dims=[64, 64, 64],
        depth=[3, 6, 4], heads=[4, 8, 16], **kwargs)


@register_model
def pit_s(pretrained=False, **kwargs):
    """PiT-S: patch 16, stride 8, base dims 48, depth 2-6-4, heads 3-6-12."""
    return _create_pit(
        PoolingTransformer, 'weights/pit_s_809.pth', pretrained,
        patch_size=16, stride=8, base_dims=[48, 48, 48],
        depth=[2, 6, 4], heads=[3, 6, 12], **kwargs)


@register_model
def pit_xs(pretrained=False, **kwargs):
    """PiT-XS: patch 16, stride 8, base dims 48, depth 2-6-4, heads 2-4-8."""
    return _create_pit(
        PoolingTransformer, 'weights/pit_xs_781.pth', pretrained,
        patch_size=16, stride=8, base_dims=[48, 48, 48],
        depth=[2, 6, 4], heads=[2, 4, 8], **kwargs)


@register_model
def pit_ti(pretrained=False, **kwargs):
    """PiT-Ti: patch 16, stride 8, base dims 32, depth 2-6-4, heads 2-4-8."""
    return _create_pit(
        PoolingTransformer, 'weights/pit_ti_730.pth', pretrained,
        patch_size=16, stride=8, base_dims=[32, 32, 32],
        depth=[2, 6, 4], heads=[2, 4, 8], **kwargs)


@register_model
def pit_b_distilled(pretrained=False, **kwargs):
    """Distilled PiT-B: same architecture settings as ``pit_b``."""
    return _create_pit(
        DistilledPoolingTransformer, 'weights/pit_b_distill_840.pth', pretrained,
        patch_size=14, stride=7, base_dims=[64, 64, 64],
        depth=[3, 6, 4], heads=[4, 8, 16], **kwargs)


@register_model
def pit_s_distilled(pretrained=False, **kwargs):
    """Distilled PiT-S: same architecture settings as ``pit_s``."""
    return _create_pit(
        DistilledPoolingTransformer, 'weights/pit_s_distill_819.pth', pretrained,
        patch_size=16, stride=8, base_dims=[48, 48, 48],
        depth=[2, 6, 4], heads=[3, 6, 12], **kwargs)


@register_model
def pit_xs_distilled(pretrained=False, **kwargs):
    """Distilled PiT-XS: same architecture settings as ``pit_xs``."""
    return _create_pit(
        DistilledPoolingTransformer, 'weights/pit_xs_distill_791.pth', pretrained,
        patch_size=16, stride=8, base_dims=[48, 48, 48],
        depth=[2, 6, 4], heads=[2, 4, 8], **kwargs)


@register_model
def pit_ti_distilled(pretrained=False, **kwargs):
    """Distilled PiT-Ti: same architecture settings as ``pit_ti``."""
    return _create_pit(
        DistilledPoolingTransformer, 'weights/pit_ti_distill_746.pth', pretrained,
        patch_size=16, stride=8, base_dims=[32, 32, 32],
        depth=[2, 6, 4], heads=[2, 4, 8], **kwargs)
    
class JigsawViT(nn.Module):
    """PiT-S (distilled) feature backbone, RoIAlign pooling and a linear head.

    The last feature map returned by the timm backbone is pooled per region
    with RoIAlign, flattened, and passed through three linear layers.

    NOTE(review): fc1, fc2 and fc3 are applied back-to-back with no
    activation in between; confirm that this is intended.
    NOTE(review): the RoIAlign layer is built with its default spatial scale
    (1/16); confirm it matches the stride of the last backbone feature map.
    """

    def __init__(self, pretrained_cfg_file, num_labels=2):
        """Create the backbone, the RoIAlign layer and the head.

        Args:
            pretrained_cfg_file: Checkpoint path stored under the ``'file'``
                key of the pretrained config handed to timm.
            num_labels: Output width of the final linear layer.
        """
        super(JigsawViT, self).__init__()
        # A first model instance is created only to read its default config.
        pretrained_cfg = timm.models.create_model(
            'pit_s_distilled_224.in1k').default_cfg
        print(pretrained_cfg)
        pretrained_cfg['file'] = pretrained_cfg_file
        self.backbone = timm.models.create_model(
            'pit_s_distilled_224.in1k', pretrained=True, features_only=True, pretrained_cfg=pretrained_cfg)
        self.roi_align = get_roi_align_layer()
        # 9216 = 576 * 4 * 4; presumably 576 channels of the last feature map
        # times the default 4x4 RoIAlign output -- confirm against the backbone.
        self.fc1 = nn.Linear(9216, 2048)
        self.fc2 = nn.Linear(2048, 512)
        self.fc3 = nn.Linear(512, num_labels)

    def forward(self, x, roi):
        """Classify the regions ``roi`` of the image batch ``x``.

        Args:
            x: Image batch fed to the backbone.
            roi: Boxes handed to RoIAlign unchanged.
                NOTE(review): presumably a float tensor with rows
                ``(batch_index, x1, y1, x2, y2)`` -- confirm.

        Returns:
            Output of the last linear layer, one row per pooled region.
        """
        features = self.backbone(x)[-1]
        # Use the boxes passed by the caller; the previous code read a
        # notebook-level variable named ``rois`` and ignored this argument.
        aligned_features = self.roi_align(features, roi)
        flat = aligned_features.view(aligned_features.size(0), -1)
        hidden = self.fc2(self.fc1(flat))
        return self.fc3(hidden)

import torch


# Smoke test of the pit_s factory defined in this cell: no checkpoint is
# loaded, and one random batch of shape (1, 3, 224, 224) is pushed through.
model = pit_s(pretrained=False)
# model.load_state_dict(torch.load('./weights/pit_s_809.pth'))
print(model(torch.randn(1, 3, 224, 224)).shape)

  def pit_b(pretrained, **kwargs):
  def pit_s(pretrained, **kwargs):
  def pit_xs(pretrained, **kwargs):
  def pit_ti(pretrained, **kwargs):
  def pit_b_distilled(pretrained, **kwargs):
  def pit_s_distilled(pretrained, **kwargs):
  def pit_xs_distilled(pretrained, **kwargs):
  def pit_ti_distilled(pretrained, **kwargs):


torch.Size([1, 1000])
