Add DependencyViT #2062

Draft: fffffgggg54 wants to merge 78 commits into main from dependencyvit

Commits (78)

014d125
Create dependencyvit.py
fffffgggg54 Dec 18, 2023
e8cd67a
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
7858568
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
753d25a
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
7fff1fb
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
6466517
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
74a55aa
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
bff9111
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
a56729d
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
376ec76
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
f7645a2
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
304976c
Update __init__.py
fffffgggg54 Dec 18, 2023
de87c61
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
7f83ebf
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
f10d8db
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
f79f6b8
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
26955cf
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
1719974
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
152a6ed
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
c47ab88
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
8040b62
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
99fca58
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
0b1024b
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
d6cc2b3
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
6a2ffbc
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
d15f55d
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
ed47beb
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
c72b8f4
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
4aacf65
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
3f98fb7
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
88fa74f
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
15f912f
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
b6576ba
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
80d7eb6
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
df6a699
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
143a977
Update dependencyvit.py
fffffgggg54 Dec 18, 2023
ba8dd06
Update dependencyvit.py
fffffgggg54 Dec 19, 2023
3bc97f0
Update dependencyvit.py
fffffgggg54 Dec 19, 2023
4009465
Update dependencyvit.py
fffffgggg54 Dec 20, 2023
a8e4d0c
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
23a6eb2
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
c75522a
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
60dd4ad
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
2a6030d
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
8ce8310
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
6b7b1bc
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
b134f6b
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
4cfbb33
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
c151ead
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
6ca0fbc
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
5b305a8
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
73b7ef6
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
df457f2
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
7c076a7
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
5e7ecff
Update dependencyvit.py
fffffgggg54 Dec 23, 2023
8ec5ed7
Update dependencyvit.py
fffffgggg54 Dec 24, 2023
7d1c049
Update dependencyvit.py
fffffgggg54 Dec 24, 2023
2f4927a
Update dependencyvit.py
fffffgggg54 Dec 24, 2023
f628864
Update dependencyvit.py
fffffgggg54 Dec 24, 2023
1b0fb07
Update dependencyvit.py
fffffgggg54 Dec 24, 2023
ee3fa6e
Update dependencyvit.py
fffffgggg54 Dec 24, 2023
506f859
Update dependencyvit.py
fffffgggg54 Dec 24, 2023
68857cf
Update dependencyvit.py
fffffgggg54 Dec 24, 2023
1b71be9
Update dependencyvit.py
fffffgggg54 Dec 25, 2023
33c6b37
Update dependencyvit.py
fffffgggg54 Dec 25, 2023
9e2189a
Update dependencyvit.py
fffffgggg54 Dec 25, 2023
d353c23
Update dependencyvit.py
fffffgggg54 Dec 25, 2023
352eb29
Update dependencyvit.py
fffffgggg54 Dec 25, 2023
2172867
Update test_models.py
fffffgggg54 Dec 25, 2023
94e5558
Merge branch 'huggingface:main' into dependencyvit
fffffgggg54 Dec 25, 2023
e6f8765
Update dependencyvit.py
fffffgggg54 Dec 25, 2023
63f853d
Merge branch 'dependencyvit' of https://github.com/fffffgggg54/pytorc…
fffffgggg54 Dec 25, 2023
abe3abf
fix syntax, type and shape annotations
fffffgggg54 Dec 26, 2023
9bfa1ee
Update test_models.py
fffffgggg54 Dec 26, 2023
30c82f2
Update dependencyvit.py
fffffgggg54 Dec 26, 2023
5edd3a3
Update dependencyvit.py
fffffgggg54 Dec 26, 2023
aac764c
Update dependencyvit.py
fffffgggg54 Dec 27, 2023
11e8739
Update dependencyvit.py
fffffgggg54 Dec 27, 2023
4 changes: 2 additions & 2 deletions tests/test_models.py
@@ -52,7 +52,7 @@
    'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
    'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-   'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*'
+   'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'dependencyvit_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

@@ -69,7 +69,7 @@
EXCLUDE_FILTERS = ['*enormous*']
NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*']

-EXCLUDE_JIT_FILTERS = []
+EXCLUDE_JIT_FILTERS = ['dependencyvit_*']

TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
TARGET_BWD_SIZE = 128
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -11,6 +11,7 @@
from .davit import *
from .deit import *
from .densenet import *
+from .dependencyvit import *
from .dla import *
from .dpn import *
from .edgenext import *
321 changes: 321 additions & 0 deletions timm/models/dependencyvit.py
@@ -0,0 +1,321 @@
""" DependencyViT

From-scratch implementation of DependencyViT in PyTorch

'Visual Dependency Transformers: Dependency Tree Emerges from Reversed Attention'
- https://arxiv.org/abs/2304.03282

ReversedAttention is derived from timm's Vision Transformer attention implementation

Implementation for timm by / Copyright 2023, Fredo Guan
"""

import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import Final

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import DropPath, Mlp
from timm.models.vision_transformer import VisionTransformer
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model

__all__ = ['DependencyViT']

class TokenPruner(nn.Module):
    def __init__(
        self,
        prune_ratio: float,
        prune_index: int,
    ) -> None:
        super().__init__()
        self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)

    # [B, N, C], [B, 1, 1, N], [B, N] -> [B, N', C], [B, 1, 1, N']
    def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, N, C = x.shape
        topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1]  # [B, N']
        x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C))  # [B, N', C]
        m = m.gather(3, topk_indices.unsqueeze(1).unsqueeze(1))  # [B, 1, 1, N']
        return (x, m)


class ReversedAttention(nn.Module):
    dependency_mask: Optional[torch.Tensor]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.track_dependency_mask = False
        self.dependency_mask = None
        self.head_selector_temperature = 0.1  # appendix D.1

        self.head_selector = nn.Linear(dim, num_heads, bias=False)  # FIXME is there a bias term?

        self.message_controller = Mlp(
            in_features=dim,
            hidden_features=int(dim / 2),
            out_features=1,
            act_layer=nn.GELU,
            bias=False,  # FIXME is there a bias term?
        )

        #self.token_pruner = None

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    # m is cumulative over all layers (m = m_i * m_i-1 * ... * m_1)
    # [B, N, C], [B, 1, 1, N] -> [B, N, C], [B, 1, 1, N], [B, N]
    def forward(self, x: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1)
        p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N)

        m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn).transpose(-2, -1)  # this transpose prevents use of sdpa
        attn = attn * p * m  # [B, n_h, N, N]
        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)

        # FIXME absolute value?
        self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None  # [B, N, N]

        # FIXME which pruning mask?

        #prune_mask = attn.detach().sum(1).sum(-1)
        #prune_mask = attn.detach().sum(1).abs().sum(-1)
        #prune_mask = attn.detach().abs().sum((1, -1))
        #prune_mask = attn.sum(1).sum(-1)
        #prune_mask = attn.sum(1).abs().sum(-1)
        #prune_mask = attn.abs().sum((1, -1))
        #prune_mask = m.reshape(B, N)
        prune_mask = m.detach().reshape(B, N)

        x = self.proj(x)
        x = self.proj_drop(x)
        return (x, m, prune_mask)

class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class DependencyViTBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.,
        attn_drop: float = 0.,
        init_values: Optional[float] = None,
        drop_path: float = 0.,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = ReversedAttention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.token_pruner = None

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        x, m = in_tuple
        x_new, m, prune_mask = self.attn(self.norm1(x), m)
        x = x + self.drop_path1(self.ls1(x_new))
        x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m)
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return (x, m)


# FIXME verify against reference impl
# FIXME train weights that meet or exceed results from paper

class DependencyViT(VisionTransformer):
    def __init__(
        self,
        prune_layers: Optional[Union[List[int], Tuple[int]]] = None,
        prune_ratio: Optional[float] = None,
        *args,
        **kwargs
    ) -> None:
        super().__init__(
            *args,
            **kwargs,
            block_fn=DependencyViTBlock,
            class_token=False,
            global_pool='avg',
            qkv_bias=False,
            init_values=1e-6,
            fc_norm=False,
        )

        if prune_layers is not None:
            self.prune_layers = sorted(list(dict.fromkeys(prune_layers)))
            self.prune_ratio = prune_ratio

            assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices exceed model depth"
            assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too big, ensure len(prune_layers) * prune_ratio is less than 1"

            self.prune_layers = [x - 1 for x in self.prune_layers]  # convert 1-based layer numbers to nn.Sequential indices
            for prune_index, layer in enumerate(self.prune_layers, 1):
                self.blocks[layer].token_pruner = TokenPruner(self.prune_ratio, prune_index)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        B, N, _ = x.shape
        m = torch.Tensor([1]).to(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x, m = checkpoint_seq(self.blocks, (x, m))
        else:
            x, m = self.blocks((x, m))

        #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm

        x = self.norm(x)
        x = x * m.transpose(1, 3).squeeze(-1)
        return x

    def track_dependency_mask(self, track: bool = True) -> None:
        for block in self.blocks:
            if block.attn.track_dependency_mask is not track:
                block.attn.dependency_mask = None
                block.attn.track_dependency_mask = track

    def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None) -> List[torch.Tensor]:
        # L' * [B, N, N]
        # L' * [B, N', N']
        result = []
        layers = layers if layers else range(len(self.blocks))
        for layer in layers:
            result.append(self.blocks[layer].attn.dependency_mask)
        return result




def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
    return {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': 0.9,
        'interpolation': 'bicubic',
        'fixed_input_size': True,
        'mean': IMAGENET_INCEPTION_MEAN,
        'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'patch_embed.proj',
        'classifier': 'head',
        **kwargs,
    }


default_cfgs = {
    'dependencyvit_tiny_patch16_224.untrained': _cfg(url=''),
    'dependencyvit_small_patch16_224.untrained': _cfg(url=''),

    'dependencyvit_lite_tiny_patch16_224.untrained': _cfg(url=''),
}


default_cfgs = generate_default_cfgs(default_cfgs)



def _create_dependencyvit(variant: str, pretrained: bool = False, **kwargs) -> DependencyViT:
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    return build_model_with_cfg(
        DependencyViT,
        variant,
        pretrained,
        **kwargs,
    )


@register_model
def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT:
    model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12)
    model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def dependencyvit_small_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT:
    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12)
    model = _create_dependencyvit('dependencyvit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def dependencyvit_lite_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT:
    model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12, prune_layers=[2, 5, 8, 11], prune_ratio=0.16)
    model = _create_dependencyvit('dependencyvit_lite_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
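
For anyone wanting to try this branch locally, here is a minimal usage sketch. It assumes the PR branch is installed in place of the released timm package and uses the untrained 'dependencyvit_lite_tiny_patch16_224' config registered above, so the logits are random; the dependency-mask semantics are still marked FIXME in the code, so the masks should be treated as illustrative only.

import torch
import timm  # assumes this PR branch is installed as timm

# build the pruned "lite" tiny variant registered above (no pretrained weights exist yet)
model = timm.create_model('dependencyvit_lite_tiny_patch16_224', pretrained=False).eval()

# enable recording of per-block dependency masks, then run a forward pass
model.track_dependency_mask(True)
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # [1, 1000]

# one detached [B, N, N] mask per block ([B, N', N'] after each pruning stage)
masks = model.get_dependency_mask()
print(logits.shape, [tuple(m.shape) for m in masks if m is not None])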