# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalAttention(nn.Module):
    """Convolution followed by a per-channel spatial soft-attention gate.

    The input is first projected by ``conv``. A second convolution
    (``attn_conv``) produces one score per channel and spatial position;
    the scores are normalised with a softmax over all ``H * W`` positions
    of each channel and multiplied element-wise with the projected
    features.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        kernel_size: Kernel size used by both convolutions.
        stride: Stride of the feature convolution ``conv``.
        padding: Padding of the feature convolution ``conv``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )
        # The attention map is multiplied element-wise with the output of
        # ``conv``, so this layer must keep H and W unchanged. Reusing the
        # caller's stride/padding here would shrink the map (and make the
        # reshape in ``forward`` fail) whenever stride != 1 or the padding
        # is not exactly (kernel_size - 1) / 2. With the default arguments
        # (kernel 3, stride 1, padding 1) "same" padding is identical.
        self.attn_conv = nn.Conv2d(
            out_channels, out_channels, kernel_size, stride=1, padding="same"
        )
        # Applied to a (N, C, H*W) tensor, i.e. over the flattened positions.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the convolved features scaled by their attention weights.

        Args:
            x: Input accepted by ``nn.Conv2d`` with ``in_channels`` channels;
                the output of ``conv`` must be 4-D ``(N, C, H, W)``.

        Returns:
            Tensor with the same shape as the output of ``conv``.
        """
        x = self.conv(x)
        N, C, H, W = x.shape

        # Score every position, then normalise across the H*W positions of
        # each (sample, channel) pair so the weights of a channel sum to 1.
        attn = self.attn_conv(x)
        attn = attn.reshape(N, C, -1)
        attn = self.softmax(attn)
        attn = attn.reshape(N, C, H, W)

        # Gate the features with the attention weights.
        return x * attn