model.py
import numpy as np
from tqdm import tqdm, trange
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST
np.random.seed(0)
torch.manual_seed(0)


def patchify(images, n_patches: int):
    n, c, h, w = images.shape
    assert h == w, "Patchify works on square images only"

    patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
    patch_size = h // n_patches  # side length of each square patch

    for idx, image in enumerate(images):
        # Walk the image in an n_patches x n_patches grid, flattening each patch
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches
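
# Shape sanity check (illustrative example, not part of the original file):
#   >>> imgs = torch.rand(2, 1, 28, 28)      # two MNIST-sized images
#   >>> patchify(imgs, n_patches=7).shape
#   torch.Size([2, 49, 16])                  # 49 patches of 1 * 4 * 4 = 16 values each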


def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    # Standard sinusoidal positional embedding: sin on even indices, cos on odd ones
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result
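
# A vectorized equivalent of the double loop above (an alternative sketch, not
# part of the original file); handy when sequence_length * d grows large.
def get_positional_embeddings_vectorized(sequence_length, d):
    pos = torch.arange(sequence_length, dtype=torch.float32).unsqueeze(1)  # (seq, 1)
    idx = torch.arange(d, dtype=torch.float32).unsqueeze(0)                # (1, d)
    angles = pos / (10000 ** ((idx - idx % 2) / d))  # even/odd pairs share an exponent
    return torch.where(idx % 2 == 0, torch.sin(angles), torch.cos(angles))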


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads: int = 2) -> None:
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        assert d_model % n_heads == 0, f"Can't divide dimension {d_model} into {n_heads} heads"

        d_head = d_model // n_heads
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # (N, seq_length, token_dim) -> (N, seq_length, n_heads, token_dim / n_heads)
        # -> (N, seq_length, token_dim) again, by concatenating the heads
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                # Each head attends over its own slice of the token dimension
                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                # Scaled dot-product attention: softmax(QK^T / sqrt(d_head)) V
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])
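
# Shape check (illustrative): multi-head self-attention preserves the token dimension.
#   >>> msa = MultiHeadSelfAttention(d_model=8, n_heads=2)
#   >>> msa(torch.rand(4, 50, 8)).shape
#   torch.Size([4, 50, 8])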


class VisualTransformerBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super().__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MultiHeadSelfAttention(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        # Pre-norm encoder block: LayerNorm before each sub-layer,
        # with residual connections around attention and the MLP
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out
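
# Usage (illustrative): the encoder block is also shape-preserving.
#   >>> block = VisualTransformerBlock(hidden_d=8, n_heads=2)
#   >>> block(torch.rand(4, 50, 8)).shape
#   torch.Size([4, 50, 8])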


class VisualTransformer(nn.Module):
    def __init__(self, chw, n_patches: int = 7, n_blocks: int = 2, hidden_d: int = 8, n_heads: int = 2, out_d: int = 10) -> None:
        super().__init__()

        # Attributes
        self.chw = chw  # (Channels, Height, Width)
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Height and width must be evenly divisible by the number of patches
        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding (fixed, so stored as a non-persistent buffer)
        self.register_buffer('positional_embeddings', get_positional_embeddings(n_patches ** 2 + 1, hidden_d), persistent=False)

        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList([VisualTransformerBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # 5) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        # Divide images into patches
        n, c, h, w = images.shape
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # Linear tokenization: map each flattened patch to the hidden dimension
        tokens = self.linear_mapper(patches)

        # Prepend the classification token to every sequence
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Add positional embedding
        pos_embed = self.positional_embeddings.repeat(n, 1, 1)
        out = tokens + pos_embed

        # Transformer encoder blocks
        for block in self.blocks:
            out = block(out)

        # Keep only the classification token's final representation
        out = out[:, 0]

        return self.mlp(out)  # Map to output dimension; yields a category distribution
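

# A minimal training sketch on MNIST using the imports above (illustrative;
# the hyperparameters are assumptions, not part of the original file). Note
# the model already ends in a Softmax while nn.CrossEntropyLoss expects raw
# logits; the pairing below keeps the model untouched at the cost of slightly
# flatter gradients.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_set = MNIST(root="./data", train=True, download=True, transform=ToTensor())
    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)

    model = VisualTransformer((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
    optimizer = Adam(model.parameters(), lr=5e-3)
    criterion = CrossEntropyLoss()

    for epoch in trange(5, desc="Training"):
        epoch_loss = 0.0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=False):
            x, y = x.to(device), y.to(device)
            loss = criterion(model(x), y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.detach().item() / len(train_loader)
        print(f"Epoch {epoch + 1} loss: {epoch_loss:.3f}")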