Commit b2f033a

add ViT search space

marsggbo committed Dec 12, 2022
1 parent 3b4d6a4 commit b2f033a

Showing 3 changed files with 357 additions and 2 deletions.
5 changes: 3 additions & 2 deletions hyperbox/networks/ofa/ofa_mbv3.py
@@ -21,7 +21,7 @@ def __init__(
         kernel_size_list: List[int] = [3, 5, 7],
         expand_ratio_list: List[float] = [3, 4, 6],
         depth_list: List[int] = [2, 3, 4],
-        base_stage_width: List[int] = [16, 16, 24, 40, 80, 112, 160, 960, 1280],
+        base_stage_width: List[int] = [16, 16, 24, 40, 80, 112, 160, 960, 1280], # indices in [1,6] are searchable
         stride_stages: List[int] = [1, 2, 2, 2, 1, 2],
         act_stages: List[str] = ['relu', 'relu', 'relu', 'h_swish', 'h_swish', 'h_swish'],
         se_stages: List[bool] = [False, False, True, False, True, True],
@@ -39,7 +39,8 @@ def __init__(
         final_expand_width = make_divisible(base_stage_width[-2] * self.width_mult, self.CHANNEL_DIVISIBLE)
         last_channel = make_divisible(base_stage_width[-1] * self.width_mult, self.CHANNEL_DIVISIBLE)
 
-        n_block_list = [1] + [max(self.depth_list)] * 5
+        num_searchable_stages = len(stride_stages) - 1
+        n_block_list = [1] + [max(self.depth_list)] * num_searchable_stages
         width_list = []
         for base_width in base_stage_width[:-2]:
             width = make_divisible(base_width * self.width_mult, self.CHANNEL_DIVISIBLE)
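
The second hunk replaces the hard-coded 5 with a count derived from stride_stages, so n_block_list tracks however many stages are configured instead of silently assuming six. A minimal sketch of the resulting computation, using the defaults from the signature above:

stride_stages = [1, 2, 2, 2, 1, 2]  # first stage is fixed; the rest are searchable
depth_list = [2, 3, 4]

num_searchable_stages = len(stride_stages) - 1   # 5 with the defaults, so behavior is unchanged
n_block_list = [1] + [max(depth_list)] * num_searchable_stages
print(n_block_list)                              # [1, 4, 4, 4, 4, 4]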
Empty file.
354 changes: 354 additions & 0 deletions hyperbox/networks/vit/vit.py
@@ -0,0 +1,354 @@
from typing import Dict, List, Tuple, Union, Optional, Callable
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from hyperbox.networks.base_nas_network import BaseNASNetwork
from hyperbox.mutables import spaces, ops


# helpers
def pair(t):
return t if isinstance(t, tuple) else (t, t)

def keepPositiveList(l):
    # drop non-positive candidates (a small search ratio can round a width down to 0)
    return [x for x in l if x > 0]


# model classes
class PreNorm(nn.Module):
def __init__(self, dim: int, fn: Callable):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
def __init__(
self,
dim: int, # input dimension
hidden_dim: int, # hidden dimension
search_ratio: List[float] = [0.5, 1.], # search ratio for hidden dimension (hidden_dim)
dropout: float = 0., # dropout rate
suffix: str = '', # suffix for the name of the module
        mask: dict = None, # mask for the search space (mutables)
):
super().__init__()
self.mask = mask
hidden_dim_list = keepPositiveList([int(r*hidden_dim) for r in search_ratio])
hidden_dim_list = spaces.ValueSpace(hidden_dim_list, key=f"{suffix}_hidden_dim", mask=self.mask)
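        # both ops.Linear layers below share this ValueSpace, so the same sampled
        # hidden width is applied on the way in and on the way out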
self.net = nn.Sequential(
ops.Linear(dim, hidden_dim_list),
nn.GELU(),
nn.Dropout(dropout),
ops.Linear(hidden_dim_list, dim),
nn.Dropout(dropout)
)

def forward(self, x):
return self.net(x)


class Attention(nn.Module):
def __init__(
self,
dim: int, # input dimension
search_ratio: List[float], # search ratio for hidden dimension
heads: int = 8, # number of attention heads
dim_head: int = 64, # dimension of each attention head
dropout: float = 0., # dropout
        suffix: str = '', # suffix for naming
mask: dict = None, # mask for the search space (mutables)
):
super().__init__()
self.mask = mask
self.heads_list = keepPositiveList([int(r*heads) for r in search_ratio])
self.dim_head_list = keepPositiveList([int(r*dim_head) for r in search_ratio])
self.scale_list = [dh**-0.5 for dh in self.dim_head_list]

        # flatten the (heads, dim_head) grid into a single candidate list so one
        # ValueSpace choice picks both factors; the index maps recover each factor
        self.inner_dim_list = []
        self.heads_idx_map = {}
        self.dim_head_idx_map = {}
count = 0
for h_idx, h in enumerate(self.heads_list):
for d_idx, d in enumerate(self.dim_head_list):
inner_dim = h * d
self.inner_dim_list.append(inner_dim)
self.heads_idx_map[count] = h_idx
self.dim_head_idx_map[count] = d_idx
count += 1
self.inner_dim_list = spaces.ValueSpace(self.inner_dim_list, key=f"{suffix}_inner_dim", mask=self.mask)

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

qkv_dim_list = [x*3 for x in self.inner_dim_list]
        qkv_dim_list = spaces.ValueSpace(qkv_dim_list, key=f"{suffix}_inner_dim", mask=self.mask) # shares its key with self.inner_dim_list, so both spaces always sample the same index
self.to_qkv = ops.Linear(dim, qkv_dim_list, bias = False)

self.to_out = nn.Sequential(
ops.Linear(self.inner_dim_list, dim),
nn.Dropout(dropout)
)

def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
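
    # The properties below resolve the currently sampled candidate: `index` is set
    # once a mutator samples; otherwise the choice is recovered from the mask
    # (e.g., when the network is built from a fixed mask).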

@property
def heads(self):
idx = self.inner_dim_list.index
if idx is None:
idx = self.inner_dim_list.mask.float().argmax().item()
idx = self.heads_idx_map[idx]
return self.heads_list[idx]

@property
def scale(self):
idx = self.inner_dim_list.index
if idx is None:
idx = self.inner_dim_list.mask.float().argmax().item()
idx = self.dim_head_idx_map[idx]
return self.scale_list[idx]

@property
def dim_head(self):
idx = self.inner_dim_list.index
if idx is None:
idx = self.inner_dim_list.mask.float().argmax().item()
idx = self.dim_head_idx_map[idx]
return self.dim_head_list[idx]


class Transformer(nn.Module):
def __init__(
self,
dim: int, # dimension of the model
depth: int, # depth of the model
heads: int, # number of heads
dim_head: int, # dimension of each head
mlp_dim: int, # dimension of the feedforward layer
search_ratio: list, # search ratio for hidden_dim and inner_dim
dropout: float= 0., # dropout rate
mask: dict = None, # mask for the search space (mutables)
):
super().__init__()
self.search_ratio = search_ratio
self.mask = mask
self.layers = nn.ModuleList([])
        for idx in range(depth):
attKey = f"att_{idx}"
ffKey = f"ff_{idx}"
self.layers.append(nn.ModuleList([
PreNorm(
dim, Attention(
dim, heads = heads, dim_head = dim_head, search_ratio=self.search_ratio,
dropout = dropout, suffix = attKey, mask = self.mask)
),
PreNorm(
dim, FeedForward(
dim, mlp_dim, search_ratio=self.search_ratio, dropout = dropout,
suffix = ffKey, mask = self.mask))
]))

        # elastic depth: sample how many of the stacked blocks actually run
        runtime_depth = list(range(1, depth + 1))
        self.run_depth = spaces.ValueSpace(runtime_depth, key='run_depth', mask=self.mask)

def forward(self, x):
for attn, ff in self.layers[:self.run_depth.value]:
x = attn(x) + x
x = ff(x) + x
return x


class VisionTransformer(BaseNASNetwork):
def __init__(
self, *,
image_size: Union[int, Tuple[int, int]], # image size
patch_size: Union[int, Tuple[int, int]], # patch size
num_classes: int, # number of classes
dim: int, # dim of patch embedding
depth: int, # depth of transformer
heads: int, # number of heads
mlp_dim: int, # hidden dim of mlp
        search_ratio: Optional[List[float]] = None, # search ratios for hidden/inner dims (defaults to [0.5, 0.75, 1])
pool: str = 'cls', # 'cls' or 'mean'
channels: int = 3, # number of input channels
dim_head: int = 64, # dimension of each attention head
dropout: float = 0., # dropout rate
emb_dropout: float = 0., # embedding dropout rate
mask: dict = None, # mask for the search space (mutables)
):
super().__init__(mask = mask)
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
if search_ratio is None:
search_ratio = [0.5, 0.75, 1]
self.search_ratio = search_ratio

assert image_height % patch_height == 0 and image_width % patch_width == 0, \
'Image dimensions must be divisible by the patch size.'

num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)

self.transformer = Transformer(
dim=dim, depth=depth, heads=heads, dim_head=dim_head, mlp_dim=mlp_dim,
dropout=dropout, search_ratio=search_ratio, mask=self.mask)

self.pool = pool
self.to_latent = nn.Identity()

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)

def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)

x = self.transformer(x)

x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

x = self.to_latent(x)
return self.mlp_head(x)


_vit_s = dict(
image_size=224,
patch_size=16,
num_classes=1000,
dim=768,
depth=4,
heads=6,
    dim_head=64, # attention inner dim: 6*64=384
mlp_dim=1024,
search_ratio=[0.5, 0.75, 1],
pool='cls',
channels=3,
dropout=0.1,
emb_dropout=0.1,
)

_vit_b = dict(
image_size=224,
patch_size=16,
num_classes=1000,
dim=768,
depth=12,
heads=12,
    dim_head=64, # attention inner dim: 12*64=768
mlp_dim=3072,
search_ratio=[0.5, 0.75, 1],
pool='cls',
channels=3,
dropout=0.1,
emb_dropout=0.1,
)

_vit_h = dict(
image_size=224,
patch_size=16,
num_classes=1000,
dim=1280,
depth=32,
heads=16,
dim_head=80,
mlp_dim=5120,
search_ratio=[0.5, 0.75, 1],
pool='cls',
channels=3,
dropout=0.1,
emb_dropout=0.1,
)

_vit_g = dict(
image_size=224,
patch_size=16,
num_classes=1000,
dim=1664,
depth=48,
heads=16,
dim_head=104,
mlp_dim=8192,
search_ratio=[0.5, 0.75, 1],
pool='cls',
channels=3,
dropout=0.1,
emb_dropout=0.1,
)

_vit_10b = dict(
image_size=224,
patch_size=16,
num_classes=1000,
dim=4096,
depth=50,
heads=16,
dim_head=256,
mlp_dim=16384,
search_ratio=[0.5, 0.75, 1],
pool='cls',
channels=3,
dropout=0.1,
emb_dropout=0.1,
)

ViT = partial(VisionTransformer, **_vit_b) # default alias: ViT-Base
ViT_S = partial(VisionTransformer, **_vit_s)
ViT_B = partial(VisionTransformer, **_vit_b)
ViT_H = partial(VisionTransformer, **_vit_h)
ViT_G = partial(VisionTransformer, **_vit_g)
ViT_10B = partial(VisionTransformer, **_vit_10b)
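
Since the presets are plain keyword arguments bound with functools.partial, any field can be overridden per call (the __main__ demo below overrides several). A hypothetical CIFAR-sized variant, for illustration:

net = ViT_S(image_size=32, patch_size=4, num_classes=10)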


if __name__ == '__main__':
from hyperbox.mutator import RandomMutator
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = ViT(image_size = 224, patch_size = 16, num_classes = 1000, dim = 1024,
depth = 6, heads = 16, dim_head=1024, mlp_dim = 2048)
x = torch.rand(2,3,224,224).to(device)
net = net.to(device)
rm = RandomMutator(net)
for i in range(10):
rm.reset()
# print(rm._cache)
print(net.arch_size((2,3,224,224), convert=False, verbose=True))
y = net(x)
print(y.shape, y.device)
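
A minimal sketch of how the coupled spaces in Attention resolve together, assuming RandomMutator can drive any module that contains hyperbox mutables (as the demo above suggests):

import torch
from hyperbox.mutator import RandomMutator

attn = Attention(dim=256, search_ratio=[0.5, 1.], heads=8, dim_head=64, suffix='demo')
rm = RandomMutator(attn)
rm.reset()  # samples one index of demo_inner_dim; heads, dim_head, and scale follow from it
print(attn.heads, attn.dim_head, attn.scale)
x = torch.rand(2, 17, 256)
print(attn(x).shape)  # expected: torch.Size([2, 17, 256])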
