In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from htools import *

In [5]:
# From my img_wang project.
class SmoothSoftmaxBase(nn.Module):
    """Parent class of SmoothSoftmax and SmoothLogSoftmax (softmax or log
    softmax with temperature baked in). There shouldn't be a need to
    instantiate this class directly.
    """

    def __init__(self, log=False, temperature='auto', dim=-1):
        """
        Parameters
        ----------
        log: bool
            If True, use log softmax (if this is the last activation in a
            network, it can be followed by nn.NLLLoss). If False, use softmax
            (this is more useful if you're doing something attention-related:
            no standard torch loss functions expect softmax outputs). This
            argument is usually passed implicitly by the higher level interface
            provided by the child classes.
        temperature: float or str
            If a float, this is the temperature to divide activations by before
            applying the softmax. Values larger than 1 soften the distribution
            while values between 0 and 1 sharpen it. If str ('auto'), this will
            compute the square root of the last dimension of x's shape the
            first time the forward method is called and use that for subsequent
            calls (even if later inputs have a different last dimension).
            Passing None behaves the same as 'auto'.
        dim: int
            The dimension to compute the softmax over.

        Raises
        ------
        ValueError: If temperature is a str other than 'auto'.
        """
        super().__init__()
        if isinstance(temperature, str) and temperature != 'auto':
            raise ValueError(
                "temperature must be a number or 'auto', got "
                f"{temperature!r}."
            )
        # None is the sentinel meaning "not resolved yet": forward() replaces
        # it with sqrt(x.shape[-1]) on the first call.
        self.temperature = None if isinstance(temperature, str) \
            else temperature
        self.act = nn.LogSoftmax(dim=dim) if log else nn.Softmax(dim=dim)

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.float

        Returns
        -------
        torch.float: Same shape as x.
        """
        # Resolve the 'auto' temperature explicitly rather than by catching
        # TypeError from the division: that approach silently overwrote any
        # invalid temperature and could loop forever if the TypeError came
        # from somewhere else.
        if self.temperature is None:
            self.temperature = x.shape[-1] ** 0.5
        return self.act(x.div(self.temperature))


class SmoothSoftmax(SmoothSoftmaxBase):
    """Softmax with a temperature baked in. Thin wrapper that fixes
    log=False on SmoothSoftmaxBase; see the parent class for a description
    of `temperature` and `dim`.
    """

    def __init__(self, temperature='auto', dim=-1):
        super().__init__(False, temperature, dim)


class SmoothLogSoftmax(SmoothSoftmaxBase):
    """Log softmax with a temperature baked in. Thin wrapper that fixes
    log=True on SmoothSoftmaxBase; see the parent class for a description
    of `temperature` and `dim`.
    """

    def __init__(self, temperature='auto', dim=-1):
        super().__init__(True, temperature, dim)

In [387]:
class SpatialSoftmax(nn.Module):
    """Softmax (or log softmax) computed jointly over the last two
    dimensions of a tensor, so that each (h, w) map sums to 1 (in the
    non-log case).
    """

    def __init__(self, temperature='auto', log=False):
        """
        Parameters
        ----------
        temperature: float or str
            Passed through to SmoothSoftmax/SmoothLogSoftmax. Note that the
            activation is applied to the flattened (h*w) dimension, so 'auto'
            resolves to sqrt(h * w) of the first input seen.
        log: bool
            If True, use SmoothLogSoftmax instead of SmoothSoftmax.
        """
        super().__init__()
        cls = SmoothLogSoftmax if log else SmoothSoftmax
        self.act = cls(temperature)

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.float
            Should work on any tensor with shape (bs, ..., h, w).

        Returns
        -------
        torch.float: Same shape as x.
        """
        # reshape (rather than view) so non-contiguous inputs, e.g. the
        # result of a transpose, don't raise a RuntimeError.
        flattened = self.act(x.reshape(*x.shape[:-2], -1))
        return flattened.reshape(x.shape)

In [389]:
bs = 2
c = 3
h = 4
w = 4

x = torch.randn(bs, c, h, w)
x.shape

torch.Size([2, 3, 4, 4])

In [390]:
x

tensor([[[[-2.3441, -0.7909,  1.0298,  0.4039],
          [-1.6188,  0.2625,  0.3535, -0.9783],
          [-0.7640, -1.2406, -1.1361, -2.3369],
          [-2.7585,  0.0651,  1.5677,  0.1620]],

         [[-0.5695, -1.0143,  0.2400, -1.4993],
          [ 0.8825, -0.3705, -0.1127, -1.2692],
          [ 0.2704, -1.9788, -0.6400,  0.8381],
          [-1.4564, -0.3531, -1.5069, -2.1211]],

         [[ 0.0988, -0.4415,  0.7668,  1.9533],
          [ 0.4844, -0.5616, -0.7259,  1.7374],
          [-0.2316, -0.3654, -1.1872,  0.2421],
          [ 0.3454,  0.5115,  0.1213, -0.5517]]],


        [[[-0.1965,  1.2759,  1.5981,  0.7768],
          [ 1.5828,  1.6654,  1.1962,  2.7620],
          [-0.9039, -0.5277,  0.6310, -0.4405],
          [-0.9645, -0.3029,  0.5913,  0.2449]],

         [[ 1.0945,  0.4701, -0.0326, -1.1616],
          [ 3.1092, -1.2894, -0.6499,  2.6702],
          [ 1.2273,  0.5223, -0.3593,  1.1957],
          [ 0.5903, -0.3326,  1.0712,  0.8777]],

         [[-0.8878, -0.0735,

In [391]:
x.ndim

4

In [392]:
x.view(*x.shape[:2], -1)

tensor([[[-2.3441, -0.7909,  1.0298,  0.4039, -1.6188,  0.2625,  0.3535,
          -0.9783, -0.7640, -1.2406, -1.1361, -2.3369, -2.7585,  0.0651,
           1.5677,  0.1620],
         [-0.5695, -1.0143,  0.2400, -1.4993,  0.8825, -0.3705, -0.1127,
          -1.2692,  0.2704, -1.9788, -0.6400,  0.8381, -1.4564, -0.3531,
          -1.5069, -2.1211],
         [ 0.0988, -0.4415,  0.7668,  1.9533,  0.4844, -0.5616, -0.7259,
           1.7374, -0.2316, -0.3654, -1.1872,  0.2421,  0.3454,  0.5115,
           0.1213, -0.5517]],

        [[-0.1965,  1.2759,  1.5981,  0.7768,  1.5828,  1.6654,  1.1962,
           2.7620, -0.9039, -0.5277,  0.6310, -0.4405, -0.9645, -0.3029,
           0.5913,  0.2449],
         [ 1.0945,  0.4701, -0.0326, -1.1616,  3.1092, -1.2894, -0.6499,
           2.6702,  1.2273,  0.5223, -0.3593,  1.1957,  0.5903, -0.3326,
           1.0712,  0.8777],
         [-0.8878, -0.0735, -0.4978, -0.3571,  0.8847, -0.4557,  0.5479,
          -0.8211, -1.5795,  1.4599, -1.3009, -0.2

In [393]:
sm = SpatialSoftmax(1.0)
sm

SpatialSoftmax(
  (act): SmoothSoftmax(
    (act): Softmax(dim=-1)
  )
)

In [394]:
sm(x)

tensor([[[[0.0058, 0.0276, 0.1705, 0.0912],
          [0.0121, 0.0792, 0.0867, 0.0229],
          [0.0284, 0.0176, 0.0196, 0.0059],
          [0.0039, 0.0650, 0.2921, 0.0716]],

         [[0.0461, 0.0296, 0.1036, 0.0182],
          [0.1970, 0.0563, 0.0728, 0.0229],
          [0.1068, 0.0113, 0.0430, 0.1884],
          [0.0190, 0.0573, 0.0181, 0.0098]],

         [[0.0406, 0.0237, 0.0793, 0.2596],
          [0.0598, 0.0210, 0.0178, 0.2092],
          [0.0292, 0.0255, 0.0112, 0.0469],
          [0.0520, 0.0614, 0.0416, 0.0212]]],


        [[[0.0169, 0.0738, 0.1018, 0.0448],
          [0.1003, 0.1089, 0.0681, 0.3262],
          [0.0083, 0.0122, 0.0387, 0.0133],
          [0.0079, 0.0152, 0.0372, 0.0263]],

         [[0.0494, 0.0265, 0.0160, 0.0052],
          [0.3706, 0.0046, 0.0086, 0.2389],
          [0.0564, 0.0279, 0.0115, 0.0547],
          [0.0298, 0.0119, 0.0483, 0.0398]],

         [[0.0145, 0.0327, 0.0214, 0.0246],
          [0.0852, 0.0223, 0.0608, 0.0155],
          [0.0072, 0

In [395]:
sm = SpatialSoftmax('auto')
sm(x)

tensor([[[[0.0389, 0.0574, 0.0905, 0.0774],
          [0.0467, 0.0747, 0.0764, 0.0548],
          [0.0578, 0.0513, 0.0527, 0.0390],
          [0.0351, 0.0711, 0.1035, 0.0728]],

         [[0.0624, 0.0558, 0.0764, 0.0495],
          [0.0897, 0.0656, 0.0699, 0.0524],
          [0.0770, 0.0439, 0.0613, 0.0887],
          [0.0500, 0.0659, 0.0494, 0.0423]],

         [[0.0606, 0.0529, 0.0716, 0.0963],
          [0.0667, 0.0513, 0.0493, 0.0912],
          [0.0558, 0.0539, 0.0439, 0.0628],
          [0.0644, 0.0671, 0.0609, 0.0515]]],


        [[[0.0500, 0.0722, 0.0783, 0.0637],
          [0.0780, 0.0796, 0.0708, 0.1047],
          [0.0419, 0.0460, 0.0614, 0.0470],
          [0.0412, 0.0487, 0.0608, 0.0558]],

         [[0.0682, 0.0584, 0.0515, 0.0388],
          [0.1129, 0.0376, 0.0441, 0.1012],
          [0.0705, 0.0591, 0.0474, 0.0700],
          [0.0601, 0.0478, 0.0678, 0.0646]],

         [[0.0494, 0.0606, 0.0545, 0.0565],
          [0.0770, 0.0551, 0.0708, 0.0503],
          [0.0416, 0

In [396]:
# Adapted from my annotated GPT notebook.
class ConvolutionalProjector(nn.Module):
    """Project a (bs, c_in, h, w) tensor into `spaces` tensors of the same
    shape using a single bias-free 1x1 convolution.
    """

    def __init__(self, c_in, spaces=3):
        """
        Parameters
        ----------
        c_in: int
            Number of input channels.
        spaces: int
            Number of subspaces to project into. The default of 3 yields
            query, key, and value tensors.
        """
        super().__init__()
        self.c_in = c_in
        self.spaces = spaces
        self.conv = nn.Conv2d(c_in, c_in * spaces, 1, groups=1, bias=False)

    def forward(self, x):
        """Project input tensor into n subspaces. We use 3 by default to
        create query, key, and value vectors.

        Parameters
        ----------
        x: torch.float
            Shape (bs, c_in, h, w).

        Returns
        -------
        torch.float: Shape (spaces, bs, c_in, h, w). Because the subspace
        dimension comes first, the result can be unpacked directly, e.g.
        `q, k, v = projector(x)`. Subspace i is made up of conv output
        channels [i*c_in, (i+1)*c_in).
        """
        bs, c, h, w = x.shape
        z = self.conv(x)
        # The subspace dim must be split off BEFORE the channel dim so that
        # transposing puts `spaces` (not `c`) first. The previous
        # view(bs, c, spaces, h, w) only unpacked correctly when
        # c_in == spaces (in which case both views are identical).
        return z.view(bs, self.spaces, c, h, w).transpose(0, 1)

    # Note: using z.shape[0] and *z.shape[-2:] instead of named dimensions
    # might be more useful if we try to generalize to nD inputs, but for now
    # we're just using conv2d so it's probably better to be explicit and name
    # the dimensions as we do above. That also avoids calling shape multiple
    # times.

In [397]:
proj = ConvolutionalProjector(c)
proj

ConvolutionalProjector(
  (conv): Conv2d(3, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
)

In [398]:
proj.conv.weight.shape

torch.Size([9, 3, 1, 1])

In [399]:
# Returns a single tensor but we can easily assign it to 3 different vars.
res = proj(x)
q, k, v = proj(x)

print(res.shape)
smap(q, k, v)

torch.Size([3, 2, 3, 4, 4])


[torch.Size([2, 3, 4, 4]), torch.Size([2, 3, 4, 4]), torch.Size([2, 3, 4, 4])]

In [400]:
(q[0, 0, 0, :] * k[0, 0, :, 0]).sum()

tensor(-1.6037, grad_fn=<SumBackward0>)

In [401]:
(q[-1, -1, -1, :] * k[-1, -1, :, -1]).sum()

tensor(0.6749, grad_fn=<SumBackward0>)

In [402]:
a1 = q @ k
assert torch.allclose(a1, torch.matmul(q, k))
print(a1.shape)
a1

torch.Size([2, 3, 4, 4])


tensor([[[[-1.6037, -0.0939, -0.2546, -2.0012],
          [-1.3959, -0.5026, -0.5371, -0.1470],
          [-1.2164, -0.0037,  0.5243, -1.4530],
          [-1.1913,  0.1414, -0.5223, -2.5689]],

         [[ 0.4544, -0.3331,  0.0215,  0.0728],
          [ 0.0412,  0.0757, -0.4914, -0.6715],
          [ 0.8307, -0.9926, -1.0531,  0.9694],
          [ 0.1037, -0.7941, -0.3201, -0.0301]],

         [[-0.4950, -0.0913, -0.9914, -0.9311],
          [ 0.0516,  0.3831, -0.6121, -0.3686],
          [ 1.0549,  0.4504,  0.5480,  1.0619],
          [-0.3888, -1.1192, -1.3545, -0.5644]]],


        [[[-1.4427,  0.6771, -1.0180, -1.0829],
          [-3.3068,  3.1099, -1.5567,  0.6423],
          [-1.3335,  1.1716, -1.2045,  0.7099],
          [-0.7767,  0.6920, -1.0570,  1.0162]],

         [[-0.7798,  0.5073,  0.5074,  0.4309],
          [ 1.7568, -2.3141,  0.5051, -0.5665],
          [ 0.0459, -0.2779, -0.0075, -0.2615],
          [ 0.7350, -1.3004,  0.4445, -0.6756]],

         [[ 0.0742,  0.2712,

In [460]:
class SpatialAttention2d(nn.Module):
    """Experimental attention layer for 2D feature maps. Queries, keys, and
    values are produced by a 1x1 convolutional projection, and attention
    weights are computed for every (key channel, query channel) pair.
    """

    def __init__(self, c_in, temperature='auto', output_attentions=False):
        """
        Parameters
        ----------
        c_in: int
            Number of input channels.
        temperature: float or str
            Passed to SpatialSoftmax, which normalizes each attention map.
        output_attentions: bool
            If True, forward also returns the attention weights.
        """
        super().__init__()
        # TODO: not sure if projector kwargs work w/ other values.
        self.projector = ConvolutionalProjector(c_in=c_in, spaces=3)
        self.softmax = SpatialSoftmax(temperature)
        self.output_attentions = output_attentions

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.float
            Shape (bs, c_in, h, w). The q @ k matrix product below multiplies
            an (h, w) map by an (h, w) map, so this requires h == w.

        Returns
        -------
        tuple: (res,) or (res, attn) if output_attentions is True.
            res has shape (bs, c_in*c_in, h, w) and attn has shape
            (bs, c_in, c_in, h, w), where each (h, w) map of attn sums to 1.
        """
        # TODO: Noticed my annotated gpt2 notebook uses matrix multiply for
        # v and attn weights, while here I'm doing element-wise mult.
        # Make sure this is okay. I think it's because with text, we project
        # into a larger dimension (spaces*hidden_dim) whereas here we don't
        # (but maybe we should? Trying to think what that would look like).
        q, k, v = self.projector(x)

        # An earlier version computed softmax(q @ k), which only compares
        # each channel to itself. Broadcasting over two channel dims uses
        # between-channel info too:
        # scores[b, c1, c2] = q[b, c2] @ k[b, c1]
        scores = q.unsqueeze(1) @ k.unsqueeze(2)
        # Normalize each (h, w) score map. Previously the softmax (and
        # therefore the temperature) was constructed but never applied.
        attn = self.softmax(scores)
        res = attn * v.unsqueeze(1)

        # TODO: still need to check this a little more carefully. We merge
        # the two channel dims so the output is (bs, c_in*c_in, h, w). It
        # seems like a nice feature to keep h and w the same (potential for
        # skip connections), which is why the channel dimension was chosen.
        res = res.flatten(1, 2)

        return (res, attn) if self.output_attentions else (res,)

In [458]:
attn = SpatialAttention2d(3)
attn

SpatialAttention2d(
  (projector): ConvolutionalProjector(
    (conv): Conv2d(3, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (softmax): SpatialSoftmax(
    (act): SmoothSoftmax(
      (act): Softmax(dim=-1)
    )
  )
)

In [459]:
z, = attn(x)
z.shape

torch.Size([2, 9, 4, 4])

In [451]:
z[0, ..., :2, :2]

tensor([[[[ 0.2445,  0.0338],
          [ 0.0727,  0.0566]],

         [[-0.1141, -0.0486],
          [ 0.1168,  0.0052]],

         [[-0.2818, -0.0043],
          [ 0.0176,  0.0643]]],


        [[[-0.2578, -0.0271],
          [-0.1115, -0.0273]],

         [[ 0.1289,  0.0392],
          [-0.1250, -0.0031]],

         [[ 0.3163,  0.0034],
          [-0.0190, -0.0385]]],


        [[[-0.1025,  0.0117],
          [-0.0574, -0.0657]],

         [[-0.0786, -0.0354],
          [ 0.0895, -0.0013]],

         [[-0.1683, -0.0053],
          [ 0.0125, -0.0072]]]], grad_fn=<SliceBackward>)

In [453]:
shape = torch.tensor(z.shape)
z.view(shape[0], torch.prod(shape[1:3]), *shape[-2:]).shape#[0, ..., :2, :2]

torch.Size([2, 9, 4, 4])

In [406]:
attn = SpatialAttention2d(3, output_attentions=True)
attn

SpatialAttention2d(
  (projector): ConvolutionalProjector(
    (conv): Conv2d(3, 9, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (softmax): SpatialSoftmax(
    (act): SmoothSoftmax(
      (act): Softmax(dim=-1)
    )
  )
)

In [407]:
z, a = attn(x)
smap(z, a)

[torch.Size([2, 3, 4, 4]), torch.Size([2, 3, 4, 4])]

In [408]:
a.sum((-1, -2))

tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000]], grad_fn=<SumBackward1>)

In [409]:
z

tensor([[[[-0.0453, -0.0281,  0.0070, -0.0649],
          [-0.0225,  0.0104,  0.0217, -0.0999],
          [-0.0037, -0.0718, -0.0089, -0.0288],
          [-0.0573, -0.0138, -0.0091, -0.0232]],

         [[ 0.0548,  0.0309, -0.0473, -0.0782],
          [ 0.0188,  0.0110,  0.0153, -0.0352],
          [ 0.0191,  0.0373,  0.0621,  0.0375],
          [ 0.0569, -0.0190, -0.0397,  0.0178]],

         [[-0.0624, -0.0512,  0.0618,  0.0449],
          [-0.0097, -0.0162, -0.0149,  0.0023],
          [-0.0170, -0.0686, -0.0712, -0.0326],
          [-0.0687,  0.0095,  0.0184, -0.0554]]],


        [[[ 0.0473,  0.0253,  0.0443, -0.0045],
          [ 0.0948,  0.0115, -0.0045,  0.2593],
          [ 0.0362, -0.0316,  0.0271,  0.0221],
          [ 0.0320, -0.0754,  0.0348, -0.0065]],

         [[ 0.0342, -0.0257, -0.0199, -0.0040],
          [-0.0549, -0.0222, -0.0459, -0.0360],
          [ 0.0607, -0.0432,  0.0308,  0.0184],
          [ 0.0645, -0.0754, -0.0037, -0.0435]],

         [[-0.0100,  0.0388,

In [410]:
a

tensor([[[[0.0481, 0.0631, 0.0666, 0.0554],
          [0.0677, 0.0689, 0.0647, 0.0707],
          [0.0981, 0.0744, 0.0699, 0.0614],
          [0.0386, 0.0454, 0.0635, 0.0435]],

         [[0.0699, 0.0590, 0.0622, 0.0700],
          [0.0616, 0.0545, 0.0611, 0.0671],
          [0.0499, 0.0565, 0.0609, 0.0550],
          [0.0695, 0.0680, 0.0677, 0.0672]],

         [[0.0592, 0.0630, 0.0744, 0.0728],
          [0.0588, 0.0620, 0.0662, 0.0639],
          [0.0513, 0.0553, 0.0602, 0.0569],
          [0.0503, 0.0604, 0.0768, 0.0686]]],


        [[[0.0693, 0.0438, 0.0685, 0.0660],
          [0.0735, 0.0626, 0.0477, 0.1225],
          [0.0462, 0.0620, 0.0501, 0.0527],
          [0.0521, 0.0691, 0.0501, 0.0638]],

         [[0.0678, 0.0612, 0.0650, 0.0578],
          [0.0515, 0.0676, 0.0665, 0.0602],
          [0.0549, 0.0753, 0.0680, 0.0667],
          [0.0513, 0.0694, 0.0615, 0.0554]],

         [[0.0605, 0.0635, 0.0645, 0.0594],
          [0.0770, 0.0516, 0.0558, 0.0556],
          [0.0614, 0

### TODO

- [ ] Test what kinds of values auto temperature gives for different size images. Not sure what the expected input size is: depends how early in the network we apply this. Could plausibly have multiple layers of both regular conv and spatial attention, maybe with different temperatures as we get deeper into the net and height and width decrease. Or if we only use spatial attention, maybe height and width don't need to change and we could therefore keep the same temperature.
- [ ] Think about whether this is doing something useful. Found something else called spatial attention which does something different and probably more useful: https://paperswithcode.com/method/spatial-attention-module . That's not to say that I should match what they did necessarily, but it does highlight a weakness in my version. If we're deep in a network, a single channel probably contains a couple "activated" regions, representing features like "has shape of dog ear" or "is furry". Right now, my attention module is learning about how one "is furry" area compares to another "is furry" area. But ideally, we probably want to learn how one "is furry" area compares to one "has shape of dog ear" area.

Fuzzy idea: maybe we can multiply each channel's K by every OTHER channel's Q to get attention weights rather than by its own Q. Options:
- Maybe I can just flatten each channel and use matmul.
- torch.einsum
See spiral notebook: right now attn is (bs, c, h, w) but I want (bs, c, c, h, w) or (bs, c^2, h, w).

UPDATES:
I updated class to compute attention weights over all combinations of channels, but I'm still not sure if this is useful. In text, we know it can be useful to represent words as a combination of other words near it (see: training word embeddings). In images, I'm not sure if representing a feature as a combination of other features is as useful. On the other hand, images can also be viewed as graphs and graph neural networks definitely do this, so maybe it's still reasonable. 

More thoughts:
- Could we concat the input X along the channel dimension so we have both the feature maps and the "weighted combo of features" maps?
- On 2nd thought, convolutions already share information between channels. Maybe all this is completely unnecessary. BUT I don't think they account for relationships between $P_{b0, c0, h0, w0}$ and $P_{b0, c64, h128, w128}$ (i.e. different channels AND different spatial locations). I could imagine this would be useful: e.g. "this image has a shiny thing (eye) in the middle and a hairy thing (hair) at the top".

In [411]:
q, k, v = attn.projector(x)
(q @ k).shape

torch.Size([2, 3, 4, 4])

In [412]:
(q @ k)[0]

tensor([[[-1.2648, -0.1779,  0.0391, -0.7008],
         [ 0.1021,  0.1710, -0.0809,  0.2738],
         [ 1.5865,  0.4821,  0.2279, -0.2885],
         [-2.1455, -1.5005, -0.1557, -1.6630]],

        [[ 0.6571, -0.0211,  0.1897,  0.6626],
         [ 0.1537, -0.3381,  0.1181,  0.4916],
         [-0.6913, -0.1907,  0.1066, -0.3036],
         [ 0.6338,  0.5477,  0.5314,  0.4989]],

        [[-0.3232, -0.0751,  0.5865,  0.5024],
         [-0.3542, -0.1384,  0.1200, -0.0216],
         [-0.9016, -0.5958, -0.2570, -0.4870],
         [-0.9752, -0.2457,  0.7185,  0.2630]]], grad_fn=<SelectBackward>)

In [413]:
q.unsqueeze(1).shape

torch.Size([2, 1, 3, 4, 4])

In [414]:
q.unsqueeze(2).shape

torch.Size([2, 3, 1, 4, 4])

In [415]:
# Right shape but need to check if this is actually doing what I want.
res = torch.matmul(q.unsqueeze(1), k.unsqueeze(2))
res.shape

torch.Size([2, 3, 3, 4, 4])

In [416]:
q[0][0] @ k[0][1]

tensor([[ 0.3940, -0.0655,  0.0942,  0.3738],
        [-0.1032, -0.2228,  0.3847,  0.4522],
        [-0.6221,  0.1176,  0.3759, -0.2803],
        [ 0.8638,  0.5567,  0.4673,  0.6127]], grad_fn=<MmBackward>)

In [417]:
# For b, c, h, w: 

# When c1 == c2 (the two channel indices match), this is the same as q@k (i.e. self attention).
# res[0][0][0] = qk[0][0]
# res[0][1][1] = qk[0][1]
# res[0][2][2] = qk[0][2]

# res[1][0][0] = qk[1][0]
# res[1][1][1] = qk[1][1]
# res[1][2][2] = qk[1][2]

##################################
# res[0][0][1] = q[0][1] @ k[0][0]
# res[0][0][2] = q[0][2] @ k[0][0]
# res[0][1][0] = q[0][0] @ k[0][1]

# res[b, c1, c2, ...] = q[b, c2, ...] @ k[b, c1, ...]
res[0][1][0]

tensor([[ 0.3940, -0.0655,  0.0942,  0.3738],
        [-0.1032, -0.2228,  0.3847,  0.4522],
        [-0.6221,  0.1176,  0.3759, -0.2803],
        [ 0.8638,  0.5567,  0.4673,  0.6127]], grad_fn=<SelectBackward>)

In [419]:
a = attn.softmax(res)
a.sum((-1, -2))

tensor([[[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]]], grad_fn=<SumBackward1>)

In [423]:
(a * v.unsqueeze(1)).shape

torch.Size([2, 3, 3, 4, 4])

## Different kind of Spatial Attention from paperswithcode

Still trying to figure out why this is useful. Seems like the maxpool
should be high in many places because most areas of the image 
surely activate SOME filter (or maybe not - maybe if we choose a 
reasonable number of filters, some areas of the image are largely ignored).

Maybe it's using these as a sort of hack to learn attention weights? I.E.
a high value doesn't tell us what's in that area but it does say "there's
something interesting here".

UPDATES: Noticed they apply this directly to the inputs (no high level features yet). They use this in a "SpatialGate" where they pool first, then pass that through a conv-bn-relu layer, then pass that output through a sigmoid. The sigmoid outputs are then used to scale the original inputs (hence the Gate name. Seems reminiscent of lstm/gru). There's actually no softmax at all here.

Note: paper says their CBAM module is a dropin for a conv block, not a whole model, and is meant to take in intermediate feature maps. So the fact that the channelpool operates directly on inputs isn't as extreme as it sounded initially. 

The core of the CBAM block is actually something called ChannelGate which DOESN'T use channel pooling. Need to spend more time understanding what that does. It uses the SpatialGate as an optional second step on the ChannelGate's outputs.

In [261]:
class ChannelPool(nn.Module):
    """Pool over the channel dimension, keeping both the max and the mean.

    From
    https://github.com/Jongchan/attention-module/blob/5d3a54af0f6688bedca3f179593dff8da63e8274/MODELS/cbam.py#L72
    """

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.tensor
            Channel dimension is expected at index 1, e.g. (bs, c, h, w).

        Returns
        -------
        torch.tensor: Shape (bs, 2, h, w) where all dimensions except the
        channel dimension are unchanged from the input dimensions. Index 0
        of the channel dimension holds the max over channels, index 1 holds
        the mean.
        """
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        return torch.cat((channel_max, channel_mean), dim=1)

In [249]:
x2 = torch.randn(bs, c, 4, 4)
x2.shape

torch.Size([2, 3, 4, 4])

In [287]:
pool = ChannelPool()
pool(x2).shape

torch.Size([2, 2, 4, 4])

In [255]:
x2.max(1)[0].unsqueeze(1)

tensor([[[[ 0.9192,  0.9164,  0.9607,  2.4313],
          [ 1.7686,  0.9094,  1.0064,  1.5090],
          [ 0.9665,  0.8555,  1.1865,  1.9622],
          [ 1.9539,  0.0046,  1.6981,  1.0057]]],


        [[[ 0.8377,  0.3892,  0.9254,  1.0246],
          [-0.9388,  1.1602,  1.6493,  0.6775],
          [ 0.3136,  1.5351,  1.2207,  1.2812],
          [ 0.4229,  1.1303,  2.7650,  0.6992]]]])

In [253]:
x2.mean(1).unsqueeze(1)

tensor([[[[ 0.2255,  0.2981,  0.3671,  0.6192],
          [ 0.9503,  0.2977, -0.4940,  1.0133],
          [ 0.2477, -0.4888,  0.3105,  0.9771],
          [ 0.0368, -0.1245,  0.2561, -0.4880]]],


        [[[-0.0326, -0.2258,  0.0792, -0.1351],
          [-1.0789,  0.9497,  0.6635,  0.2583],
          [-0.6534,  0.5303,  0.0463,  0.3939],
          [-0.2778,  0.6453,  1.0496, -0.1302]]]])

In [239]:
nn.MaxPool2d(kernel_size=x2.shape[-1], stride=1)(x2).squeeze()

tensor([[1.5338, 2.8634, 2.2469],
        [2.4798, 2.9073, 2.0179]])

## Scratch

In [63]:
conv1 = nn.Conv2d(c, c*3, 1, groups=1, bias=False)
conv2 = nn.Conv2d(c, c*3, 1, groups=3, bias=False)

In [65]:
conv1(x)

tensor([[[[-1.8149e-01, -9.8219e-01,  7.0230e-02, -8.1756e-01],
          [-5.5817e-01, -2.1600e+00,  7.2470e-01,  6.8560e-01],
          [ 5.4090e-01, -6.9522e-01, -3.8275e-01,  6.7900e-01],
          [ 2.3532e-01, -8.1782e-01,  2.3031e+00,  7.5006e-02]],

         [[-1.2837e-01, -4.3930e-01,  1.5289e-01, -5.9028e-01],
          [-4.9126e-01, -1.5850e+00,  3.7685e-01,  2.3783e-01],
          [ 3.8888e-01, -2.8695e-01,  4.1176e-01,  2.7987e-01],
          [-1.5640e-01, -4.6738e-01,  2.1840e+00,  5.4696e-01]],

         [[-1.2380e-01, -1.7929e-01, -4.8263e-02, -2.5773e-01],
          [-4.8813e-02, -5.5613e-01,  2.5324e-01,  1.8768e-01],
          [ 1.4011e-01, -2.2262e-01, -3.7227e-01,  2.0429e-01],
          [ 1.2380e-01, -1.6675e-01,  4.0662e-01, -1.0374e-01]],

         [[-2.0736e-01,  7.7723e-01, -2.2901e-01,  1.5783e-01],
          [ 5.3646e-01,  8.6285e-01, -1.0946e-01, -3.4214e-01],
          [-2.1630e-01,  2.0996e-01, -5.6337e-01, -2.5327e-01],
          [ 1.8572e-03,  5.4152e-0

In [121]:
# Overwrite conv2's weights so that output filter i is filled with the
# constant value i: arange over the number of output channels, reshaped to
# (out_channels, 1, ..., 1) to match the rank of the weight tensor.
# Presumably this is to make it easy to trace which input channel feeds each
# output channel of the grouped conv -- TODO confirm.
shape = conv2.weight.shape
with torch.no_grad():
    conv2.weight.data = torch.arange(shape[0])\
                             .view(shape[0], *[1]*(len(shape)-1)).float()

In [122]:
x[0]

tensor([[[ 1.1026e+00, -1.3052e+00,  3.2674e-01,  1.2588e+00],
         [-3.4807e-01,  1.8726e+00, -7.8483e-01,  4.2860e-01],
         [-4.4428e-01,  1.7124e-01,  7.2784e-01, -7.9687e-04],
         [ 4.4602e-01, -4.2235e-01, -1.8255e+00, -6.7862e-01]],

        [[-6.5279e-01,  1.4068e+00, -5.6674e-01, -8.4416e-02],
         [ 1.0725e+00,  9.4541e-01,  1.6971e-01, -4.5089e-01],
         [-2.3649e-01,  1.0384e-01, -1.7775e+00, -2.2743e-01],
         [ 1.8197e-01,  9.0580e-01, -2.3105e+00, -7.3189e-01]],

        [[-1.8626e-01,  1.9660e+00,  4.4586e-02,  3.6956e-01],
         [ 4.6333e-01,  1.4049e+00, -7.5455e-01, -1.3606e+00],
         [-3.7741e-01,  1.0789e+00,  1.6159e+00, -1.1100e+00],
         [-1.0816e+00,  1.1959e+00, -4.8594e-01,  1.2200e+00]]])

In [128]:
z = conv2(x)
z.shape

torch.Size([2, 9, 4, 4])

In [132]:
q, k, v = z.view(z.shape[0], c, c, *z.shape[-2:]).transpose(0, 1)

In [133]:
smap(q, k, v)

[torch.Size([2, 3, 4, 4]), torch.Size([2, 3, 4, 4]), torch.Size([2, 3, 4, 4])]

Was experimenting with einsum as a way to do my q k broadcasted matrix multiplication, but fortunately I found a different (simpler) way.

In [316]:
x3 = torch.randn(2, 3, 4, 4)
x4 = torch.randn(2, 3, 4, 4)

In [323]:
# Right shape but no clue if this is actually doing what I want.
res = torch.einsum('abij,acjk->abcik', x3, x4)
res

tensor([[[[[ 1.2403e+00, -1.0589e+00, -2.6218e+00, -2.1962e-01],
           [ 1.5456e+00, -1.3448e+00, -1.5296e+00, -1.4196e+00],
           [-7.1985e-01,  1.1622e+00, -8.5005e-01,  1.0664e+00],
           [ 3.0225e+00, -2.6262e+00, -5.9555e+00, -8.4185e-01]],

          [[ 1.2202e+00,  1.9351e-01, -1.0199e+00, -4.8780e-01],
           [ 2.0950e-01, -8.8629e-01, -7.4623e-03,  2.1819e+00],
           [-4.0034e-01,  1.5097e-01, -7.1156e-02, -4.0540e-01],
           [ 2.8856e+00, -8.0512e-01, -1.6272e+00,  1.5047e+00]],

          [[ 6.4828e-01, -1.3674e-01,  5.0591e-01, -1.0983e+00],
           [-1.0043e+00, -4.7152e-03, -4.1933e+00,  1.7472e+00],
           [ 1.4408e+00, -1.5071e-01,  4.6706e+00, -2.3493e+00],
           [ 1.6439e+00, -3.7970e-01,  4.8559e-01, -1.8974e+00]]],


         [[[-1.0202e+00,  5.2712e-02,  1.8747e+00,  1.2337e+00],
           [ 2.9561e+00, -1.4705e+00, -6.5982e+00, -1.5241e+00],
           [-1.2448e+00,  1.0389e+00,  3.0559e+00, -4.5393e-02],
           [-1.62

In [324]:
res.shape

torch.Size([2, 3, 3, 4, 4])