![External Attention](./figs/External_Attention.png)

In [15]:
import time

import torch
import torch.nn as nn
import torch.nn.init as init

class ExternalAttention(nn.Module):
    """External-attention block built from two linear layers.

    ``mk`` projects the last dimension from ``input_dim`` to ``output_dim``;
    the result is normalized twice (softmax over dim 1, then an L1
    normalization over dim 2) and ``mv`` projects it back to ``input_dim``.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        output_dim: Size of the intermediate attention dimension.
    """

    def __init__(self, input_dim=64,output_dim=32):
        super().__init__()
        self.mk = nn.Linear(input_dim,output_dim)
        self.mv = nn.Linear(output_dim,input_dim)
        self.softmax = nn.Softmax(dim=1)
        self.init_weights()

    def init_weights(self):
        """Re-initialize every Conv2d / BatchNorm2d / Linear submodule.

        * Conv2d: Xavier-uniform weights, zero bias.
        * BatchNorm2d: weight (scale) set to 1, bias (shift) set to 0.
        * Linear: weights drawn from N(0, 1e-3^2), zero bias.
        """
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias,0)
            elif isinstance(m,nn.BatchNorm2d):
                # weight/bias are None when the layer was built with affine=False.
                if m.weight is not None:
                    init.constant_(m.weight,1)
                # The shift term is the one that must be zeroed; zeroing the
                # weight here would undo the line above and kill the scale.
                if m.bias is not None:
                    init.constant_(m.bias,0)
            elif isinstance(m,nn.Linear):
                init.normal_(m.weight,std=1e-3)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self,x):
        """Apply external attention to ``x``.

        Args:
            x: Tensor with at least three dimensions whose last dimension is
                ``input_dim`` (presumably ``(batch, tokens, input_dim)`` --
                TODO confirm against callers).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``input_dim``.
        """
        attn = self.mk(x)
        # Double normalization: softmax along dim 1 first ...
        attn = self.softmax(attn)
        # ... then rescale so the entries along dim 2 sum to one.
        attn = attn / torch.sum(attn,dim=2,keepdim=True)
        out = self.mv(attn)
        return out
    

0 Linear(in_features=64, out_features=32, bias=True)
0 Linear(in_features=32, out_features=64, bias=True)
0 Softmax(dim=1)
0 ExternalAttention(
  (mk): Linear(in_features=64, out_features=32, bias=True)
  (mv): Linear(in_features=32, out_features=64, bias=True)
  (softmax): Softmax(dim=1)
)


ExternalAttention(
  (mk): Linear(in_features=64, out_features=32, bias=True)
  (mv): Linear(in_features=32, out_features=64, bias=True)
  (softmax): Softmax(dim=1)
)

In [8]:
from tqdm import tqdm
import time
# Progress-bar demo: ten one-second steps, showing the current index both as a
# printed line and as the bar's postfix.
# The context manager closes the bar even if the loop body raises.
with tqdm(range(10),ncols=0) as pbar:
    for i in pbar:
        time.sleep(1)
        # tqdm.write prints without corrupting the active bar, unlike a bare
        # print() which interleaves with the bar's own redraws.
        tqdm.write(str(i))
        pbar.set_postfix(index=i)

 10% 1/10 [00:01<00:09,  1.00s/it, index=0]

0


 20% 2/10 [00:02<00:08,  1.00s/it, index=1]

1


 30% 3/10 [00:03<00:07,  1.00s/it, index=2]

2


 40% 4/10 [00:04<00:06,  1.00s/it, index=3]

3


 50% 5/10 [00:05<00:05,  1.00s/it, index=4]

4


 60% 6/10 [00:06<00:04,  1.00s/it, index=5]

5


 70% 7/10 [00:07<00:03,  1.00s/it, index=6]

6


 80% 8/10 [00:08<00:02,  1.00s/it, index=7]

7


 90% 9/10 [00:09<00:01,  1.00s/it, index=8]

8


100% 10/10 [00:10<00:00,  1.00s/it, index=9]

9





In [23]:
import torch.nn as nn
import torch
from torch.optim import lr_scheduler
# Toy model/optimizer pair used only to drive the learning-rate schedule.
model = nn.Conv2d(10,20,3)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3,weight_decay=1e-4)
print(optimizer.state_dict())

# step_size/gamma are StepLR arguments (CosineAnnealingLR takes T_max/eta_min
# and rejects them with a TypeError). The instance is bound to `scheduler`
# so it does not shadow the imported `lr_scheduler` module.
scheduler = lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.5)
print(scheduler.state_dict())
for epoch in range(10):
    optimizer.zero_grad()
    print(optimizer.param_groups[0]["lr"])
    optimizer.step()
    # Jumps the schedule to epoch 9+epoch, so the lr drops from 1e-3 straight
    # to 1.25e-4 after the first iteration.
    # NOTE(review): passing an explicit epoch to step() is deprecated in
    # PyTorch; prefer a plain scheduler.step() unless this jump is intended.
    scheduler.step(9+epoch)
print(scheduler.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.0001, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}
{'step_size': 3, 'gamma': 0.5, 'base_lrs': [0.001], 'last_epoch': 0, 'verbose': False, '_step_count': 1, '_get_lr_called_within_step': False, '_last_lr': [0.001]}
0.001
0.000125
0.000125
0.000125
6.25e-05
6.25e-05
6.25e-05
3.125e-05
3.125e-05
3.125e-05
{'step_size': 3, 'gamma': 0.5, 'base_lrs': [0.001], 'last_epoch': 18, 'verbose': False, '_step_count': 11, '_get_lr_called_within_step': False, '_last_lr': [1.5625e-05]}
