add a 3d version of cct, addressing #238
0.38.1
lucidrains committed Oct 29, 2022
1 parent 6ec8fda commit ad1e6df
Showing 4 changed files with 506 additions and 95 deletions.
30 changes: 30 additions & 0 deletions README.md
@@ -1023,6 +1023,36 @@ video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)
preds = v(video) # (4, 1000)
```

3D version of <a href="https://github.com/lucidrains/vit-pytorch#cct">CCT</a>

```python
import torch
from vit_pytorch.cct_3d import CCT

cct = CCT(
    img_size = 224,
    num_frames = 8,
    embedding_dim = 384,
    n_conv_layers = 2,
    frame_kernel_size = 3,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000,
    positional_embedding = 'learnable'
)

video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
pred = cct(video)
print(pred.shape)
```
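
The forward pass returns a `(batch, num_classes)` tensor, so the printed shape above is `torch.Size([1, 1000])`. In the `cct.py` diff further down, this commit also factors the fixed positional table out into a module-level `sinusoidal_embedding` helper (used when `positional_embedding = 'sine'`). Below is a minimal sketch of calling it directly; the import path assumes the helper remains public at module level, and the shapes are inferred from the function body rather than stated in the commit:

```python
from vit_pytorch.cct import sinusoidal_embedding  # module-level helper after this commit

# builds a fixed (1, sequence_length, dim) table of interleaved sin/cos positional encodings
pe = sinusoidal_embedding(10, 384)
print(pe.shape)  # torch.Size([1, 10, 384])
```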

## ViViT

<img src="./images/vivit.png" width="350px"></img>
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.37.1',
version = '0.38.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
193 changes: 99 additions & 94 deletions vit_pytorch/cct.py
@@ -1,9 +1,17 @@
import torch
import torch.nn as nn
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

# helpers

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def pair(t):
return t if isinstance(t, tuple) else (t, t)

@@ -50,8 +58,9 @@ def cct_16(*args, **kwargs):
def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
kernel_size=3, stride=None, padding=None,
*args, **kwargs):
stride = stride if stride is not None else max(1, (kernel_size // 2) - 1)
padding = padding if padding is not None else max(1, (kernel_size // 2))
stride = default(stride, max(1, (kernel_size // 2) - 1))
padding = default(padding, max(1, (kernel_size // 2)))

return CCT(num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
@@ -61,13 +70,22 @@ def _cct(num_layers, num_heads, mlp_ratio, embedding_dim,
padding=padding,
*args, **kwargs)

# positional

def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return rearrange(pe, '... -> 1 ...')

# modules

class Attention(nn.Module):
def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):
super().__init__()
self.num_heads = num_heads
head_dim = dim // self.num_heads
self.heads = num_heads
head_dim = dim // self.heads
self.scale = head_dim ** -0.5

self.qkv = nn.Linear(dim, dim * 3, bias=False)
@@ -77,17 +95,20 @@ def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1):

def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

attn = (q @ k.transpose(-2, -1)) * self.scale
qkv = self.qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

q = q * self.scale

attn = einsum('b h i d, b h j d -> b h i j', q, k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
x = einsum('b h i j, b h j d -> b h i d', attn, v)
x = rearrange(x, 'b h n d -> b n (h d)')

return self.proj_drop(self.proj(x))


class TransformerEncoderLayer(nn.Module):
@@ -97,7 +118,8 @@ class TransformerEncoderLayer(nn.Module):
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
attention_dropout=0.1, drop_path_rate=0.1):
super(TransformerEncoderLayer, self).__init__()
super().__init__()

self.pre_norm = nn.LayerNorm(d_model)
self.self_attn = Attention(dim=d_model, num_heads=nhead,
attention_dropout=attention_dropout, projection_dropout=dropout)
@@ -108,50 +130,34 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout2 = nn.Dropout(dropout)

self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.drop_path = DropPath(drop_path_rate)

self.activation = F.gelu

def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
def forward(self, src, *args, **kwargs):
src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
src = self.norm1(src)
src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))
src = src + self.drop_path(self.dropout2(src2))
return src


def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output


class DropPath(nn.Module):
"""
Obtained from: github.com:rwightman/pytorch-image-models
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
super().__init__()
self.drop_prob = float(drop_prob)

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
batch, drop_prob, device, dtype = x.shape[0], self.drop_prob, x.device, x.dtype

if drop_prob <= 0. or not self.training:
return x

keep_prob = 1 - self.drop_prob
shape = (batch, *((1,) * (x.ndim - 1)))

keep_mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < keep_prob
output = x.div(keep_prob) * keep_mask.float()
return output

class Tokenizer(nn.Module):
def __init__(self,
@@ -164,34 +170,35 @@ def __init__(self,
activation=None,
max_pool=True,
conv_bias=False):
super(Tokenizer, self).__init__()
super().__init__()

n_filter_list = [n_input_channels] + \
[in_planes for _ in range(n_conv_layers - 1)] + \
[n_output_channels]

n_filter_list_pairs = zip(n_filter_list[:-1], n_filter_list[1:])

self.conv_layers = nn.Sequential(
*[nn.Sequential(
nn.Conv2d(n_filter_list[i], n_filter_list[i + 1],
nn.Conv2d(chan_in, chan_out,
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=(padding, padding), bias=conv_bias),
nn.Identity() if activation is None else activation(),
nn.Identity() if not exists(activation) else activation(),
nn.MaxPool2d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding) if max_pool else nn.Identity()
)
for i in range(n_conv_layers)
for chan_in, chan_out in n_filter_list_pairs
])

self.flattener = nn.Flatten(2, 3)
self.apply(self.init_weight)

def sequence_length(self, n_channels=3, height=224, width=224):
return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]

def forward(self, x):
return self.flattener(self.conv_layers(x)).transpose(-2, -1)
return rearrange(self.conv_layers(x), 'b c h w -> b (h w) c')

@staticmethod
def init_weight(m):
@@ -214,106 +221,104 @@ def __init__(self,
sequence_length=None,
*args, **kwargs):
super().__init__()
positional_embedding = positional_embedding if \
positional_embedding in ['sine', 'learnable', 'none'] else 'sine'
assert positional_embedding in {'sine', 'learnable', 'none'}

dim_feedforward = int(embedding_dim * mlp_ratio)
self.embedding_dim = embedding_dim
self.sequence_length = sequence_length
self.seq_pool = seq_pool

assert sequence_length is not None or positional_embedding == 'none', \
assert exists(sequence_length) or positional_embedding == 'none', \
f"Positional embedding is set to {positional_embedding} and" \
f" the sequence length was not specified."

if not seq_pool:
sequence_length += 1
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim),
requires_grad=True)
self.class_emb = nn.Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad=True)
else:
self.attention_pool = nn.Linear(self.embedding_dim, 1)

if positional_embedding != 'none':
if positional_embedding == 'learnable':
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
requires_grad=True)
nn.init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = nn.Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim),
requires_grad=False)
else:
if positional_embedding == 'none':
self.positional_emb = None
elif positional_embedding == 'learnable':
self.positional_emb = nn.Parameter(torch.zeros(1, sequence_length, embedding_dim),
requires_grad=True)
nn.init.trunc_normal_(self.positional_emb, std=0.2)
else:
self.positional_emb = nn.Parameter(sinusoidal_embedding(sequence_length, embedding_dim),
requires_grad=False)

self.dropout = nn.Dropout(p=dropout_rate)

dpr = [x.item() for x in torch.linspace(0, stochastic_depth_rate, num_layers)]

self.blocks = nn.ModuleList([
TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads,
dim_feedforward=dim_feedforward, dropout=dropout_rate,
attention_dropout=attention_dropout, drop_path_rate=dpr[i])
for i in range(num_layers)])
attention_dropout=attention_dropout, drop_path_rate=layer_dpr)
for layer_dpr in dpr])

self.norm = nn.LayerNorm(embedding_dim)

self.fc = nn.Linear(embedding_dim, num_classes)
self.apply(self.init_weight)

def forward(self, x):
if self.positional_emb is None and x.size(1) < self.sequence_length:
b = x.shape[0]

if not exists(self.positional_emb) and x.size(1) < self.sequence_length:
x = F.pad(x, (0, 0, 0, self.sequence_length - x.size(1)), mode='constant', value=0)

if not self.seq_pool:
cls_token = self.class_emb.expand(x.shape[0], -1, -1)
cls_token = repeat(self.class_emb, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_token, x), dim=1)

if self.positional_emb is not None:
if exists(self.positional_emb):
x += self.positional_emb

x = self.dropout(x)

for blk in self.blocks:
x = blk(x)

x = self.norm(x)

if self.seq_pool:
x = torch.matmul(F.softmax(self.attention_pool(x), dim=1).transpose(-1, -2), x).squeeze(-2)
attn_weights = rearrange(self.attention_pool(x), 'b n 1 -> b n')
x = einsum('b n, b n d -> b d', attn_weights.softmax(dim = 1), x)
else:
x = x[:, 0]

x = self.fc(x)
return x
return self.fc(x)

@staticmethod
def init_weight(m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
if isinstance(m, nn.Linear) and exists(m.bias):
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

@staticmethod
def sinusoidal_embedding(n_channels, dim):
pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
for p in range(n_channels)])
pe[:, 0::2] = torch.sin(pe[:, 0::2])
pe[:, 1::2] = torch.cos(pe[:, 1::2])
return pe.unsqueeze(0)


# CCT Main model

class CCT(nn.Module):
def __init__(self,
img_size=224,
embedding_dim=768,
n_input_channels=3,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
*args, **kwargs):
super(CCT, self).__init__()
def __init__(
self,
img_size=224,
embedding_dim=768,
n_input_channels=3,
n_conv_layers=1,
kernel_size=7,
stride=2,
padding=3,
pooling_kernel_size=3,
pooling_stride=2,
pooling_padding=1,
*args, **kwargs
):
super().__init__()
img_height, img_width = pair(img_size)

self.tokenizer = Tokenizer(n_input_channels=n_input_channels,
