Attention provides a means for a model to learn how parts of its input that are not necessarily close together can be linked (for example, words in different parts of a sentence, or one region of an image and another, distant region).

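The model below is based on attention as developed for NLP. It is similar to the diffusers `AttentionBlock` in that the input (`inp`) is added back to the attention output at the end (a residual connection). PyTorch's `nn.MultiheadAttention` does not do this, so there the residual has to be added separately.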



In [1]:
import math,torch
from torch import nn
from miniai.activations import *

In [2]:
import matplotlib.pyplot as plt

In [3]:
from diffusers.models.attention import AttentionBlock

In [4]:
set_seed(42)
x = torch.randn(64,32,16,16)

Batch size 64, 32 channels, height × width = 16 × 16.

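In the next cell the height and width dimensions (the equivalent of the sequence in NLP) are flattened into a single dimension of length 16 × 16 = 256, and the channels are moved to the end, because in NLP the sequence dimension usually comes before the channel (embedding) dimension.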

In [5]:
t = x.view(*x.shape[:2], -1).transpose(1, 2)
t.shape

torch.Size([64, 256, 32])
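To make the reshape concrete, here is a small self-contained check (random data standing in for the notebook's `x`) that pixel (i, j) ends up at sequence position i*w + j, and that reversing the two steps recovers the original tensor exactly.

```python
import torch

# Same layout as the notebook: batch of 64, 32 channels, 16x16 pixels
x = torch.randn(64, 32, 16, 16)
n, c, h, w = x.shape

# Flatten height and width into one "sequence" dimension, then move channels last
t = x.view(n, c, -1).transpose(1, 2)   # shape (n, h*w, c)

# Pixel (i, j) of the image lands at sequence position i*w + j (row-major order)
i, j = 3, 5
assert torch.equal(t[:, i * w + j, :], x[:, :, i, j])

# Reversing the steps recovers the original tensor
x_back = t.transpose(1, 2).reshape(n, c, h, w)
```

The same reverse step (`transpose(1,2).reshape(n,c,h,w)`) is what the attention classes below use to return to image shape.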

In [6]:
ni = 32

Create linear layers that generate the key (k), query (q) and value (v) vectors from the input to the attention model; each maps the ni input channels to ni outputs.

In [7]:
sk = nn.Linear(ni, ni)
sq = nn.Linear(ni, ni)
sv = nn.Linear(ni, ni)

In [8]:
k = sk(t)
q = sq(t)
v = sv(t)

In [16]:
v.shape

torch.Size([64, 256, 32])

In [11]:
q.shape, k.transpose(1,2).shape

(torch.Size([64, 256, 32]), torch.Size([64, 32, 256]))

For transformers, the queries are multiplied by the transposed keys. This gives a 256×256 matrix of dot products between every pair of pixels, effectively a score for how strongly each pixel is linked to every other pixel.

Because the k, q and v vectors are all computed from the same input x, this is called self-attention.

In [9]:
(q@k.transpose(1,2)).shape

torch.Size([64, 256, 256])

In [15]:
(q@k.transpose(1,2)).softmax(dim=-1).shape

torch.Size([64, 256, 256])
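A minimal sketch of what the two cells above compute, on random data of the same shape: each entry of `q @ k.transpose(1,2)` is the dot product of one query with one key, and the softmax over the last dimension turns each row into non-negative weights summing to 1, which are then used to average the value vectors. (Scaling by sqrt(ni), used in the classes below, is left out here to match these cells.)

```python
import torch
from torch import nn

torch.manual_seed(0)
ni = 32
t = torch.randn(64, 256, ni)          # (batch, sequence, channels), as in the notebook

sq, sk, sv = nn.Linear(ni, ni), nn.Linear(ni, ni), nn.Linear(ni, ni)
q, k, v = sq(t), sk(t), sv(t)

# scores[b, i, j] is the dot product of query i with key j
scores = q @ k.transpose(1, 2)
b, i, j = 0, 10, 200
assert torch.allclose(scores[b, i, j], (q[b, i] * k[b, j]).sum(), atol=1e-4)

# Softmax over the last dim turns each row into weights that sum to 1
weights = scores.softmax(dim=-1)
out = weights @ v                      # each position: weighted average of the value vectors
```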

In [51]:
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        # GroupNorm normalises each sample over groups of channels; with 1 group it normalises over all channels
        self.norm = nn.GroupNorm(1, ni)
        self.q = nn.Linear(ni, ni)
        self.k = nn.Linear(ni, ni)
        self.v = nn.Linear(ni, ni)
        self.proj = nn.Linear(ni, ni)
    
    def forward(self, x):
        inp = x
        n,c,h,w = x.shape
        # Apply normalisation
        x = self.norm(x)
        # Reshape to concatenate rows and to move the channels to the end
        x = x.view(n, c, -1).transpose(1, 2)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # Dot product of every query with every key. Divide by sqrt(ni) so the
        # scores do not grow with the number of channels
        s = (q@k.transpose(1,2))/self.scale
        # Take a softmax across the last dimension so each pixel's weights sum to 1,
        # then multiply by v. This gives a (pixels x channels) matrix in which each pixel
        # is a weighted average of the value vectors, favouring strongly linked pixels
        x = s.softmax(dim=-1)@v
        # Pass this through a final linear layer
        x = self.proj(x)
        # Swap the sequence (hw) and channel dimensions, then reshape back to the original shape
        x = x.transpose(1,2).reshape(n,c,h,w)
        return x+inp
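The core attention step inside `forward` (scores, scaling, softmax, multiply by v) can be cross-checked against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, assuming PyTorch 2.0 or later, whose default scale is 1/sqrt of the last dimension. This sketch covers only that step, not the normalisation, projections or residual.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, s, d = 2, 256, 32
q, k, v = (torch.randn(n, s, d) for _ in range(3))

# The attention step written out by hand, as in SelfAttention.forward
manual = ((q @ k.transpose(1, 2)) / math.sqrt(d)).softmax(dim=-1) @ v

# Built-in equivalent (PyTorch 2.0+); default scale is 1/sqrt(d)
fused = F.scaled_dot_product_attention(q, k, v)
```

The built-in version can use fused kernels, so it is usually faster and more memory-efficient than materialising the full 256×256 score matrix by hand.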

In [25]:
sa = SelfAttention(32)

In [26]:
ra = sa(x)
ra.shape

torch.Size([64, 32, 16, 16])

In [27]:
ra[0,0,0]

tensor([ 1.9104,  1.4186,  0.8385, -2.1584,  0.6318, -1.2443, -0.0789, -1.6844,
        -0.7939,  1.6117, -0.3852, -1.4307, -0.7494, -0.6010, -0.8335,  0.7477],
       grad_fn=<SelectBackward0>)

Check whether the attention class above gives the same output as the diffusers `AttentionBlock`.

In [29]:
def cp_parms(a,b):
    # Assigns a's Parameter objects to b, so the two layers share (rather than copy) weights
    b.weight = a.weight
    b.bias = a.bias

Copy the weights and biases from the class above into the diffusers `AttentionBlock` instance.

In [30]:
at = AttentionBlock(32, norm_num_groups=1)
src = sa.q,sa.k,sa.v,sa.proj,sa.norm
dst = at.query,at.key,at.value,at.proj_attn,at.group_norm
for s,d in zip(src,dst): cp_parms(s,d)

In [31]:
rb = at(x)
rb[0,0,0]

tensor([ 1.9104,  1.4186,  0.8385, -2.1584,  0.6318, -1.2443, -0.0789, -1.6844,
        -0.7939,  1.6117, -0.3852, -1.4307, -0.7494, -0.6010, -0.8335,  0.7477],
       grad_fn=<SelectBackward0>)

The printed values are identical, so the two networks give the same output given the same weights.
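Comparing one printed row is a weak test; `torch.allclose` checks every element. It is also worth knowing that `cp_parms` makes the two modules share the same `Parameter` objects rather than holding copies. A minimal sketch on plain `nn.Linear` layers (not the diffusers block), showing both the shared version and an independent copy via `copy_`:

```python
import torch
from torch import nn

def cp_parms(a, b):
    # As in the notebook: b now shares a's Parameter objects
    b.weight = a.weight
    b.bias = a.bias

a, b = nn.Linear(32, 32), nn.Linear(32, 32)
cp_parms(a, b)
inp = torch.randn(4, 32)
# allclose checks every element, not just the first row
assert torch.allclose(a(inp), b(inp))
# The parameters are the same objects, so changing one changes the other
assert b.weight is a.weight

# An independent copy instead: copy the values into c's own tensors
c = nn.Linear(32, 32)
with torch.no_grad():
    c.weight.copy_(a.weight)
    c.bias.copy_(a.bias)
assert torch.allclose(a(inp), c(inp)) and c.weight is not a.weight
```

For whole modules with matching parameter names, `c.load_state_dict(a.state_dict())` achieves the same independent copy.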

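To reduce overhead, the three projections for k, q and v can be combined into a single linear layer with 3*ni outputs, whose result is then split into three with `torch.chunk`. One larger matrix multiplication is generally more efficient than three smaller ones.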

In [32]:
sqkv = nn.Linear(ni, ni*3)
st = sqkv(t)
st.shape

torch.Size([64, 256, 96])

In [33]:
q,k,v = torch.chunk(st, 3, dim=-1)
q.shape

torch.Size([64, 256, 32])

In [34]:
(q@k.transpose(1,2)).shape

torch.Size([64, 256, 256])
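The fused layer is equivalent to the three separate ones: if its weight is the three weight matrices stacked along the output dimension (and likewise for the biases), `torch.chunk` returns exactly the separate q, k and v. A small sketch that builds the fused layer from three separate layers to show this:

```python
import torch
from torch import nn

torch.manual_seed(0)
ni = 32
t = torch.randn(64, 256, ni)
sq, sk, sv = nn.Linear(ni, ni), nn.Linear(ni, ni), nn.Linear(ni, ni)

# nn.Linear stores weight as (out, in), so stack along dim 0:
# outputs 0..ni-1 are q, ni..2ni-1 are k, 2ni..3ni-1 are v
sqkv = nn.Linear(ni, ni * 3)
with torch.no_grad():
    sqkv.weight.copy_(torch.cat([sq.weight, sk.weight, sv.weight], dim=0))
    sqkv.bias.copy_(torch.cat([sq.bias, sk.bias, sv.bias], dim=0))

# One matmul, then split the last dim into three equal chunks
q, k, v = torch.chunk(sqkv(t), 3, dim=-1)
```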

In [38]:
class SelfAttention(nn.Module):
    def __init__(self, ni):
        super().__init__()
        self.scale = math.sqrt(ni)
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)
    
    def forward(self, x):
        inp = x
        n,c,h,w = x.shape
        # Apply normalisation
        x = self.norm(x)
        # Reshape to concatenate rows and to move the channels to the end
        x = x.view(n, c, -1).transpose(1, 2)
        q, k, v = torch.chunk(self.qkv(x), 3, dim=-1)
        # Dot product of every query with every key. Divide by sqrt(ni) so the
        # scores do not grow with the number of channels
        s = (q@k.transpose(1,2))/self.scale
        # Take a softmax across the last dimension so each pixel's weights sum to 1,
        # then multiply by v. This gives a (pixels x channels) matrix in which each pixel
        # is a weighted average of the value vectors, favouring strongly linked pixels
        x = s.softmax(dim=-1)@v
        # Pass this through a final linear layer
        x = self.proj(x)
        # Swap the sequence (hw) and channel dimensions, then reshape back to the original shape
        x = x.transpose(1,2).reshape(n,c,h,w)
        return x+inp

In [39]:
sa = SelfAttention(32)
sa(x).shape

torch.Size([64, 32, 16, 16])

In [40]:
sa(x).std()

tensor(1.0085, grad_fn=<StdBackward0>)

### Multi-Head Self Attention

To allow the model to focus on several different things at once, we can use multiple heads.
The channels are split into `heads` equal groups, and each group is processed independently. To
do so, the head dimension is moved into the batch dimension, so each combination of image and
channel group appears as a separate item in the batch.

In [None]:
def heads_to_batch(x, heads):
    """ Initially splits the channels into two dimensions (heads, channels/heads).
    Then transposes sl and heads. Then reshapes so that each combination of batch item and head is separate along the first dim
    """
    n,sl,d = x.shape
    x = x.reshape(n, sl, heads, -1)
    return x.transpose(2, 1).reshape(n*heads,sl,-1)

def batch_to_heads(x, heads):
    """ Inverse of heads_to_batch: (n*heads, sl, d) -> (n, sl, d*heads) """
    n,sl,d = x.shape
    x = x.reshape(-1, heads, sl, d)
    return x.transpose(2, 1).reshape(-1,sl,d*heads)
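A quick check of the two helpers, repeated here so the snippet runs on its own without einops: `batch_to_heads` undoes `heads_to_batch`, and item `b*heads + h` of the folded tensor holds head h of batch item b, i.e. its own slice of 32/8 = 4 channels.

```python
import torch

def heads_to_batch(x, heads):
    # (n, sl, d) -> (n*heads, sl, d//heads), as in the notebook
    n, sl, d = x.shape
    x = x.reshape(n, sl, heads, -1)
    return x.transpose(2, 1).reshape(n * heads, sl, -1)

def batch_to_heads(x, heads):
    # (n*heads, sl, d) -> (n, sl, d*heads), as in the notebook
    n, sl, d = x.shape
    x = x.reshape(-1, heads, sl, d)
    return x.transpose(2, 1).reshape(-1, sl, d * heads)

t = torch.randn(64, 256, 32)
t2 = heads_to_batch(t, 8)
t3 = batch_to_heads(t2, 8)

# Item b*8 + h of t2 is head h of image b: channels h*4 .. h*4+3
b, h = 5, 2
assert torch.equal(t2[b * 8 + h], t[b, :, h * 4:(h + 1) * 4])
```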

In [41]:
from einops import rearrange

In [42]:
t2 = rearrange(t , 'n s (h d) -> (n h) s d', h=8)
t.shape, t2.shape

(torch.Size([64, 256, 32]), torch.Size([512, 256, 4]))

In [43]:
t3 = rearrange(t2, '(n h) s d -> n s (h d)', h=8)

In [44]:
t2.shape,t3.shape

(torch.Size([512, 256, 4]), torch.Size([64, 256, 32]))

In [45]:
(t==t3).all()

tensor(True)

In [46]:
class SelfAttentionMultiHead(nn.Module):
    def __init__(self, ni, nheads):
        super().__init__()
        self.nheads = nheads
        self.scale = math.sqrt(ni/nheads)
        self.norm = nn.BatchNorm2d(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)
    
    def forward(self, inp):
        n,c,h,w = inp.shape
        x = self.norm(inp).view(n, c, -1).transpose(1, 2)
        x = self.qkv(x)
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)
        q,k,v = torch.chunk(x, 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)
        x = self.proj(x).transpose(1,2).reshape(n,c,h,w)
        return x+inp
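The class above folds the heads into the batch dimension. An alternative layout, sketched here as an illustration rather than the notebook's code, keeps heads as their own dimension, shape (n, heads, s, d). `torch.matmul` broadcasts over the leading dimensions, so no folding is needed; the sketch checks that both layouts give the same result on random q, k, v.

```python
import math
import torch

torch.manual_seed(0)
n, s, nheads, dh = 2, 256, 4, 8
q, k, v = (torch.randn(n, s, nheads * dh) for _ in range(3))

def split_heads(x, nheads):
    # (n, s, nheads*dh) -> (n, nheads, s, dh): heads become an extra batch-like dim
    n, s, d = x.shape
    return x.view(n, s, nheads, d // nheads).transpose(1, 2)

qh, kh, vh = (split_heads(x, nheads) for x in (q, k, v))
# matmul broadcasts over the leading (n, nheads) dims
att = ((qh @ kh.transpose(-2, -1)) / math.sqrt(dh)).softmax(dim=-1) @ vh
out4d = att.transpose(1, 2).reshape(n, s, nheads * dh)

# Heads-in-batch version: flatten (n, nheads) into one dim first
qb, kb, vb = (x.reshape(n * nheads, s, dh) for x in (qh, kh, vh))
att_b = ((qb @ kb.transpose(1, 2)) / math.sqrt(dh)).softmax(dim=-1) @ vb
out_b = att_b.view(n, nheads, s, dh).transpose(1, 2).reshape(n, s, nheads * dh)
```

Both are the same computation; the choice is mostly about which layout the surrounding code (or a fused kernel) expects.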

In [47]:
sa = SelfAttentionMultiHead(32, 4)
sx = sa(x)
sx.shape

torch.Size([64, 32, 16, 16])

In [48]:
sx.mean(),sx.std()

(tensor(-0.0222, grad_fn=<MeanBackward0>),
 tensor(1.0069, grad_fn=<StdBackward0>))

### Note on PyTorch multi-head attention

By default, PyTorch's `nn.MultiheadAttention` expects the batch dimension second, i.e. inputs shaped (sequence, batch, embedding). Set `batch_first=True`, as below, to pass (batch, sequence, embedding) tensors instead.

Also, to get self-attention the same tensor must be passed as the query, key and value (q, k, v) arguments. If the keys and values come from a different tensor than the queries, it becomes cross-attention.
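Both points can be checked directly. In this sketch two `nn.MultiheadAttention` modules share the same weights (via `load_state_dict`), one with `batch_first=True` and one with the default layout; they agree once the input and output are transposed. The last lines show cross-attention, where the keys and values come from a different (here random, purely illustrative) context tensor with its own sequence length.

```python
import torch
from torch import nn

torch.manual_seed(0)
t = torch.randn(64, 256, 32)                         # (batch, seq, embed)

mha_bf = nn.MultiheadAttention(32, num_heads=8, batch_first=True)
mha_sf = nn.MultiheadAttention(32, num_heads=8)      # default: (seq, batch, embed)
mha_sf.load_state_dict(mha_bf.state_dict())          # identical weights in both

out_bf, _ = mha_bf(t, t, t)                          # self-attention: same tensor three times
ts = t.transpose(0, 1)                               # (seq, batch, embed)
out_sf, _ = mha_sf(ts, ts, ts)

# Cross-attention: queries from t, keys/values from another tensor
ctx = torch.randn(64, 10, 32)                        # e.g. 10 context tokens per item
cross, w = mha_bf(t, ctx, ctx)                       # output length follows the queries
```

The returned weights `w` are averaged over heads by default, giving one (256 × 10) attention map per batch item.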

In [49]:
nm = nn.MultiheadAttention(32, num_heads=8, batch_first=True)
nmx,nmw = nm(t,t,t)
nmx = nmx+t

In [50]:
nmx.mean(),nmx.std()

(tensor(-0.0015, grad_fn=<MeanBackward0>),
 tensor(1.0034, grad_fn=<StdBackward0>))

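The next cells show how scaling the input of a softmax changes its output: dividing by 0.1 makes each row much sharper, with most of the weight going to the largest values.

In [54]: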
test = torch.rand([6,6])
test

tensor([[0.4115, 0.3150, 0.3019, 0.0508, 0.6761, 0.8469],
        [0.7011, 0.2775, 0.5324, 0.3479, 0.7456, 0.9074],
        [0.4694, 0.9891, 0.9687, 0.6516, 0.6563, 0.5602],
        [0.0490, 0.9218, 0.8198, 0.7353, 0.5030, 0.9022],
        [0.1250, 0.4525, 0.6666, 0.2004, 0.3990, 0.2803],
        [0.3316, 0.7570, 0.0450, 0.0627, 0.5231, 0.9098]])

In [55]:
nn.functional.softmax(test/0.1, dim = -1)

tensor([[1.0684e-02, 4.0694e-03, 3.5706e-03, 2.8967e-04, 1.5057e-01, 8.3082e-01],
        [9.3837e-02, 1.3574e-03, 1.7362e-02, 2.7434e-03, 1.4642e-01, 7.3828e-01],
        [2.9040e-03, 5.2495e-01, 4.2817e-01, 1.7955e-02, 1.8823e-02, 7.2026e-03],
        [6.8814e-05, 4.2497e-01, 1.5320e-01, 6.5852e-02, 6.4489e-03, 3.4946e-01],
        [3.6371e-03, 9.6244e-02, 8.1883e-01, 7.7308e-03, 5.6367e-02, 1.7192e-02],
        [2.4837e-03, 1.7486e-01, 1.4142e-04, 1.6868e-04, 1.6858e-02, 8.0548e-01]])

In [56]:
nn.functional.softmax(test, dim = -1)

tensor([[0.1575, 0.1430, 0.1411, 0.1098, 0.2052, 0.2434],
        [0.1826, 0.1195, 0.1543, 0.1283, 0.1909, 0.2244],
        [0.1277, 0.2148, 0.2104, 0.1532, 0.1540, 0.1399],
        [0.0872, 0.2088, 0.1886, 0.1733, 0.1374, 0.2048],
        [0.1304, 0.1810, 0.2242, 0.1406, 0.1715, 0.1523],
        [0.1421, 0.2174, 0.1067, 0.1086, 0.1720, 0.2533]])
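This sensitivity to scale is why the attention classes above divide the scores by `self.scale = sqrt(ni)` (or sqrt of the per-head size). For random vectors with unit-variance entries, a d-dimensional dot product has standard deviation about sqrt(d), so without scaling the softmax becomes increasingly peaked as d grows. A small simulation on random data, assuming standard normal queries and keys:

```python
import torch

torch.manual_seed(0)
d = 64
q = torch.randn(10000, d)
k = torch.randn(10000, d)

dots = (q * k).sum(dim=-1)          # 10,000 dot products of random d-dim vectors
scaled = dots / d ** 0.5            # the same scores divided by sqrt(d)

# std of dots is close to sqrt(d) = 8; after scaling it is close to 1
print(dots.std(), scaled.std())

# Treat each block of 100 scores as one softmax row and compare the largest weight
peak_raw = dots.view(100, 100).softmax(dim=-1).max(dim=-1).values.mean()
peak_scaled = scaled.view(100, 100).softmax(dim=-1).max(dim=-1).values.mean()
```

Unscaled rows put far more of their weight on a single entry, which also means smaller gradients for the others during training.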