
Extracting attention maps #5

Open

vibrant-galaxy opened this issue Jan 28, 2021 · 3 comments

@vibrant-galaxy
Hi there,

Excellent project!

I'm using axial-attention with video (1, 5, 128, 256, 256) and sum_axial_out=True, and I wish to visualise the attention maps.

Essentially, given my video and two frame indices frame_a_idx and frame_b_idx, I need to extract the attention map over frame_b for a chosen query pixel (x, y) in frame_a (after the axial sum).

My understanding is that I should be able to reshape the dots (after softmax) according to the permutations in calculate_permutations, then sum these permuted dots together to form a final attention score tensor of an accessible shape, thus ready for visualisation.

I am slightly stuck due to the numerous axial permutations and shape mismatches. What I am doing is as follows:

In SelfAttention.forward():

dots_reshaped = dots.reshape(b, h, t, t)  # (merged batch, heads, t, t), post-softmax
return out, dots_reshaped

In PermuteToFrom.forward():

# attention
axial, dots = self.fn(axial, **kwargs)

# restore to original shape and permutation
axial = axial.reshape(*shape)
axial = axial.permute(*self.inv_permutation).contiguous()

# un-merge the batch dims of the dots; for 5-D video input,
# shape[:3] are the three non-axial dims of the permuted tensor
dots = dots.reshape(*shape[:3], *dots.shape[1:])

However, I am unsure how to un-permute the dots appropriately such that the resulting per-axis maps (whose attention “axes” have different sizes) can all be summed. If you have suggestions or code for doing so, it would be very much appreciated, thanks!
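For a single axis, the reshaped dots can already be inspected before tackling the cross-axis sum. Below is a minimal sketch, assuming 5-D video input of shape (b, f, c, h, w) with the embedding at index 2, so that the frame-axis branch has non-axial dims (b, h, w) and its reshaped dots have shape (b, h, w, heads, f, f); the stand-in tensor, indexing convention, and variable names are illustrative:

import torch

# stand-in for the frame-axis dots after the reshape in PermuteToFrom.forward()
b, f, heads, h, w = 1, 5, 8, 256, 256
dots = torch.rand(b, h, w, heads, f, f).softmax(dim = -1)
x, y, frame_a_idx, frame_b_idx = 10, 20, 0, 3

frame_dots = dots[0, y, x]              # (heads, f, f) at query pixel (x, y), batch 0
frame_map = frame_dots.mean(dim = 0)    # average over heads -> (f, f)

# entry [frame_a_idx, frame_b_idx]: attention of frame_a's query at (x, y) on frame_b
attn_a_to_b = frame_map[frame_a_idx, frame_b_idx]

This only yields the frame-axis component, though: the height- and width-axis branches produce maps on differently shaped grids ((b, f, w, heads, h, h) and (b, f, h, heads, w, w) respectively), which is exactly why the three cannot be summed directly without first broadcasting each onto the full joint grid.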

@lucidrains
Owner

@vibrant-galaxy I'm not actually sure it will be very interpretable as it is, since attention is done along each axis separately, and information can take up to two steps to be routed (a query at (x1, y1) can only reach (x2, y2) via an intermediate position such as (x2, y1) or (x1, y2)).

However, I think what may be worth trying (and I haven't built it into this repo yet) is to do axial attention but expand the attention map of each axis along the other axes, then sum, softmax, and aggregate the values. Perhaps it could lead to something more interpretable, as you would have the full attention map. Would you be interested in trying this if I were to build it?

@vibrant-galaxy
Author

That sounds like a good approach to get the full map. Yes, I am very much interested in trying that!

@lucidrains
Owner

I tried to do something like the below, but it actually goes out of memory when you try to expand and sum the pre-softmax attention maps.

So basically I don't think it's feasible, unless you see a way to make it work.

import torch
from torch import einsum, nn
from einops import rearrange

class AxialAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias = False)

        # a separate key projection per axis (frame, height, width)
        self.to_height_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_width_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_frame_k = nn.Linear(dim, inner_dim, bias = False)

        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        heads, b, f, c, h, w = self.heads, *x.shape

        # channels last for the linear projections
        x = rearrange(x, 'b f c h w -> b f h w c')

        q = self.to_q(x)

        k_height = self.to_height_k(x)
        k_width = self.to_width_k(x)
        k_frame = self.to_frame_k(x)

        v = self.to_v(x)

        # fold the heads into the batch dimension
        q, k_height, k_width, k_frame, v = map(lambda t: rearrange(t, 'b f x y (h d) -> (b h) f x y d', h = heads), (q, k_height, k_width, k_frame, v))

        # scale queries
        q *= q.shape[-1] ** -0.5

        # per-axis pre-softmax similarities, each broadcast (as a view) over the
        # axes it does not attend along, so all three share the shape
        # (b * heads, f, h, w, f, h, w)

        sim_frame = einsum('b f h w d, b j h w d -> b f h w j', q, k_frame)
        sim_frame = sim_frame[..., :, None, None].expand(-1, -1, -1, -1, -1, h, w)

        sim_height = einsum('b f h w d, b f k w d -> b f h w k', q, k_height)
        sim_height = sim_height[..., None, :, None].expand(-1, -1, -1, -1, f, -1, w)

        sim_width = einsum('b f h w d, b f h l d -> b f h w l', q, k_width)
        sim_width = sim_width[..., None, None, :].expand(-1, -1, -1, -1, f, h, -1)

        # summing the expanded views materializes the full
        # (b * heads, f, h, w, f * h * w) tensor - this is what runs out of memory
        sim = rearrange(sim_frame + sim_height + sim_width, 'b f h w j k l -> b f h w (j k l)')

        # one softmax over all key positions at once, i.e. the full attention map
        attn = sim.softmax(dim = -1)

        attn = rearrange(attn, 'b f h w (j k l) -> b f h w j k l', j = f, k = h, l = w)
        out = einsum('b f h w j k l, b j k l d -> b f h w d', attn, v)

        # unfold heads and project back out
        out = rearrange(out, '(b h) f x y d -> b f x y (h d)', h = heads)
        out = self.to_out(out)
        out = rearrange(out, 'b f x y d -> b f d x y')

        return out, attn

layer = AxialAttention(dim = 16)
video = torch.randn(1, 5, 16, 32, 32)
out, attn = layer(video)
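For a sense of the scale involved (a back-of-envelope estimate, assuming fp32): the sum sim_frame + sim_height + sim_width materializes a (b * heads, f, h, w, f, h, w) tensor, and the softmax and backward pass keep several tensors of that size alive at once.

# back-of-envelope size of one materialized sim / attn tensor (fp32 = 4 bytes)
bh = 1 * 8                                # batch * heads

toy = bh * (5 * 32 * 32) ** 2 * 4         # test shapes above: ~8.4e8 bytes (~0.8 GB)
full = bh * (5 * 256 * 256) ** 2 * 4      # the (1, 5, 128, 256, 256) video: ~3.4e12 bytes (~3.4 TB)

print(f'{toy / 1e9:.1f} GB, {full / 1e12:.1f} TB')

Even the toy test therefore allocates close to a gigabyte per such tensor, and the video shape from the original question is far out of reach, consistent with the out-of-memory behaviour described above.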
