<a href="https://colab.research.google.com/github/karlmaji/pytorch_learning/blob/master/Swin_transformer%E8%AE%BA%E6%96%87%E5%A4%8D%E7%8E%B0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab-specific setup: mount Google Drive at /content/drive/.
# NOTE(review): nothing in the visible notebook reads from this mount; it can be
# skipped when running outside Colab (google.colab is a Colab-provided package).
from google.colab import drive

drive.mount('/content/drive/')

Mounted at /content/drive/


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Step1 img to patch
- 将img中相邻的patchsize*patchsize个像素点捆绑为一个patch，输出序列的shape为[bs,num_patch,patch_depth]  
其中patch_depth = patch_size * patch_size * input_channel



In [3]:
# patch_depth = patchsize *patchsize *input_channel
# patch_depth= 12
# model_dim = 8
# weight = torch.randn(12,8)


# Method 1: extract patches with F.unfold, then project them with a weight matrix.
def img2patch(img,patchsize,weight):
  """Cut an image batch into non-overlapping patches and linearly embed them.

  Args:
    img: [bs, C, H, W] image batch.
    patchsize: side length of each square patch; also used as the stride.
    weight: [patchsize*patchsize*C, model_dim] projection matrix.

  Returns:
    [bs, num_patch, model_dim] patch embeddings, one row per patch.
  """
  # unfold -> [bs, patch_depth, num_patch]; each column is one flattened patch
  columns = F.unfold(img, kernel_size=patchsize, stride=patchsize)
  # -> [bs, num_patch, patch_depth] @ [patch_depth, model_dim]
  return torch.matmul(columns.permute(0, 2, 1), weight)

# Method 2: the same patch projection expressed as a strided convolution.
def img2patch_conv(img,patchsize,model_dim):
  """Embed non-overlapping patches with a Conv2d (kernel == stride == patchsize).

  NOTE: the Conv2d is created with fresh random weights on every call, so this
  is a demonstration only; a real model should own the layer as a submodule.

  Args:
    img: [bs, C, H, W] image batch.
    patchsize: side length of each square patch; also the conv stride.
    model_dim: embedding size (number of conv output channels).

  Returns:
    [bs, num_patch, model_dim] patch embeddings.
  """
  bs, in_channels = img.shape[:2]
  # Use the image's own channel count (not a hard-coded 3) and build the layer on
  # the image's device/dtype so the call also works for GPU / half inputs.
  layer = nn.Conv2d(in_channels, model_dim, kernel_size=patchsize, stride=patchsize,
                    device=img.device, dtype=img.dtype)
  feature_map = layer(img)  # [bs, model_dim, H // patchsize, W // patchsize]
  return feature_map.reshape(bs, model_dim, -1).transpose(-1, -2)


# img=torch.randn(2,3,224,224)
# img2patch(img,2,weight)
# img2patch_conv(img,2,model_dim)


# Step2 构建MultiHeadSelfAttention

In [4]:
class MultiHeadSelfAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Args:
    model_dim: feature size of each token; must be divisible by num_head.
    num_head: number of attention heads.

  Raises:
    ValueError: if model_dim is not divisible by num_head.
  """
  def __init__(self,model_dim,num_head):
    super(MultiHeadSelfAttention, self).__init__()
    if model_dim % num_head != 0:
      raise ValueError(
          f"model_dim ({model_dim}) must be divisible by num_head ({num_head})")
    self.headnum = num_head
    self.proj_Linear = nn.Linear(model_dim,3*model_dim)  # fused Q/K/V projection
    self.final_Linear = nn.Linear(model_dim,model_dim)   # output projection

  def forward(self, input_, with_mask=None):
    """Attend over each sequence in the batch.

    Args:
      input_: [bs, seq_len, model_dim] tokens.
      with_mask: optional additive mask [bs, seq_len, seq_len] shared by all
        heads (0 keeps a pair, a large negative value blocks it); None disables
        masking.

    Returns:
      [bs, seq_len, model_dim] tensor.
    """
    bs, seq_len, model_dim = input_.shape
    num_head = self.headnum
    head_dim = model_dim // num_head

    qkv = self.proj_Linear(input_)  # [bs, seq_len, 3*model_dim]
    qkv = qkv.reshape(bs, seq_len, 3, num_head, head_dim).permute(2, 0, 3, 1, 4)
    # [3, bs, num_head, seq_len, head_dim] -> 3 x [bs*num_head, seq_len, head_dim].
    # The flattened batch index is b * num_head + h (batch-major).
    q, k, v = qkv.reshape(3, bs * num_head, seq_len, head_dim)

    scores = torch.bmm(q, k.transpose(-1, -2)) / head_dim ** 0.5
    if with_mask is not None:
      # Expand the per-sample mask to the batch-major (b, h) layout of q/k:
      # repeat_interleave yields [m0, m0, ..., m1, m1, ...]. torch.tile would
      # yield [m0, m1, ..., m0, m1, ...] and pair heads with another sample's mask.
      mask = with_mask.to(dtype=scores.dtype, device=scores.device)
      scores = scores + mask.repeat_interleave(num_head, dim=0)
    atten_prob = F.softmax(scores, dim=-1)

    output = torch.bmm(atten_prob, v)  # [bs*num_head, seq_len, head_dim]
    output = output.reshape(bs, num_head, seq_len, head_dim).transpose(1, 2)  # [bs, seq_len, num_head, head_dim]
    output = output.reshape(bs, seq_len, model_dim)
    return self.final_Linear(output)




    


# Step3 构建W-MHSA

In [63]:
def Window_MultHeadSelfAttention(patch_embedding,mhsa,window_size=4,num_head=2):
  """Run self-attention independently inside each non-overlapping window (W-MSA).

  Args:
    patch_embedding: [bs, num_patch, patch_depth] tokens in row-major order over
      a square patch grid (the grid side is recovered as sqrt(num_patch)).
    mhsa: attention module, called as mhsa(windows, with_mask=None) on
      [bs*num_windows, window_size*window_size, patch_depth].
    window_size: side length of each window, in patches.
    num_head: accepted for API compatibility; not used here.

  Returns:
    [bs, num_windows, window_size*window_size, patch_depth] — kept in window
    format so the shifted-window stage can consume it directly.
  """
  batch, n_patches, depth = patch_embedding.shape
  side = int(torch.sqrt(torch.tensor(n_patches)))
  tokens_per_window = window_size * window_size

  # tokens -> feature map [bs, depth, side, side]
  grid = patch_embedding.permute(0, 2, 1).reshape(batch, depth, side, side)

  # unfold -> [bs, depth*ws*ws, num_windows]; each column is one window,
  # laid out channel-major (depth blocks of ws*ws values)
  cols = F.unfold(grid, kernel_size=window_size, stride=window_size)
  n_windows = cols.shape[-1]
  windows = cols.transpose(1, 2).reshape(batch * n_windows, depth, tokens_per_window)
  windows = windows.transpose(1, 2)  # [bs*num_windows, ws*ws, depth]

  attended = mhsa(windows, with_mask=None)
  return attended.reshape(batch, n_windows, tokens_per_window, depth)


# Step4 构建SW-MHSA


In [53]:
def window2img(patch_window):
  """Reassemble window-format tokens into a channel-first feature map.

  Args:
    patch_window: [bs, num_windows, ws*ws, patch_depth]; windows are in
      row-major order over a square window grid, tokens row-major inside a window.

  Returns:
    [bs, patch_depth, H, W] with H = W = sqrt(num_windows) * ws.
  """
  batch, n_windows, n_tokens, depth = patch_window.shape
  ws = int(torch.sqrt(torch.tensor(n_tokens)))
  grid = int(torch.sqrt(torch.tensor(n_windows)))
  side = grid * ws

  blocks = patch_window.reshape(batch, grid, grid, ws, ws, depth)
  # [bs, win_row, win_col, r, c, C] -> [bs, win_row, r, win_col, c, C]
  blocks = blocks.permute(0, 1, 3, 2, 4, 5)
  return blocks.reshape(batch, side, side, depth).permute(0, 3, 1, 2)

def get_shift_window_mask(bs,window_size,img_height,img_width):
  """Build the additive attention mask for shifted-window self-attention.

  The feature map is assumed to be cyclically rolled by -(window_size // 2)
  along both spatial dims. Windows on the bottom/right seams then mix tokens
  from regions that were not adjacent before the roll. Every token is labelled
  with its region, and two tokens may attend to each other only if the labels match.

  Args:
    bs: batch size; the per-image mask is repeated bs times.
    window_size: window side length, in patches.
    img_height, img_width: patch-grid size (multiples of window_size).

  Returns:
    float32 tensor [bs*num_windows, ws*ws, ws*ws]: 0 for allowed pairs, -10000
    for blocked pairs. The window index is batch-major (b*num_windows + w).
  """
  shift = window_size // 2

  # In unshifted coordinates, region boundaries sit at shift + k*window_size,
  # so (i + window_size - shift) // window_size labels rows 0..H//ws (H//ws + 1
  # distinct values). Same for columns. For shift == ws/2 this equals
  # (i + ws//2) // ws; the general form also stays correct for odd window sizes.
  row_labels = (torch.arange(img_height) + window_size - shift) // window_size
  col_labels = (torch.arange(img_width) + window_size - shift) // window_size
  # Column labels take W//ws + 1 values, so that is the row stride needed to
  # keep (row, col) labels unique. A stride of W//ws would give (0, W//ws) and
  # (1, 0) the same label.
  row_stride = img_width // window_size + 1
  index_matrix = (row_labels[:, None] * row_stride + col_labels[None, :] + 1).to(torch.float32)
  index_matrix = torch.roll(index_matrix, shifts=(-shift, -shift), dims=(0, 1))

  # [1, 1, H, W] -> [1, num_windows, ws*ws] region labels per window
  labels = F.unfold(index_matrix[None, None], kernel_size=window_size,
                    stride=window_size).transpose(-1, -2)
  labels = torch.tile(labels, dims=(bs, 1, 1))

  # [bs, num_windows, ws*ws, ws*ws]: True where both tokens share a region
  same_region = labels.unsqueeze(-1) == labels.unsqueeze(-2)
  # The mask is added to the attention scores, so allowed -> 0, blocked -> -10000
  mask = (~same_region).to(torch.float32) * -10000

  tokens = window_size * window_size
  return mask.reshape(-1, tokens, tokens)
  





def Shift_Window_func(patch_window,shift_size,generate_mask=False):
  """Cyclically shift a window-partitioned feature map and re-partition it.

  Args:
    patch_window: [bs, num_windows, ws*ws, patch_depth] window-format tokens.
    shift_size: roll amount; the map is rolled by -shift_size on both spatial
      dims (pass a negative value to undo an earlier shift).
    generate_mask: when True, also build the shifted-window attention mask.

  Returns:
    (windows, mask): windows has the same shape as patch_window. mask comes
    from get_shift_window_mask, which derives its own shift from the window
    size; otherwise it is None.
  """
  batch, n_windows, n_tokens, depth = patch_window.shape
  ws = int(torch.sqrt(torch.tensor(n_tokens)))

  image = window2img(patch_window)  # [bs, C, H, W]
  _, _, height, width = image.shape
  rolled = torch.roll(image, shifts=(-shift_size, -shift_size), dims=(-1, -2))

  # feature map -> window format [bs, num_windows, ws*ws, C]
  rolled = rolled.permute(0, 2, 3, 1)
  rolled = rolled.reshape(batch, height // ws, ws, width // ws, ws, depth)
  windows = rolled.permute(0, 1, 3, 2, 4, 5).reshape(batch, n_windows, n_tokens, depth)

  mask = get_shift_window_mask(batch, ws, height, width) if generate_mask else None
  return windows, mask
 

In [54]:
def Shift_Window_MultHeadSelfAttention(patch_window,mhsa,num_head=2,window_size=4):
  """Shifted-window self-attention (SW-MSA) on window-format tokens.

  Rolls the feature map by window_size // 2, attends inside each window with the
  cross-region mask, then rolls back by the same amount.

  Args:
    patch_window: [bs, num_windows, ws*ws, patch_depth] window-format tokens.
    mhsa: attention module, called as mhsa(tokens, with_mask=mask).
    num_head: accepted for API compatibility; not used here.
    window_size: window side length; should match sqrt(ws*ws) of patch_window,
      because the mask is built from the tokens' actual window size.

  Returns:
    [bs, num_windows, ws*ws, patch_depth] in the original (unshifted) layout.
  """
  shift = window_size // 2
  shifted_window, mask = Shift_Window_func(patch_window, shift_size=shift, generate_mask=True)

  bs, num_windows, num_patch_in_window, patch_depth = shifted_window.shape
  tokens = shifted_window.reshape(bs * num_windows, num_patch_in_window, patch_depth)

  mhsa_out = mhsa(tokens, with_mask=mask)
  mhsa_out = mhsa_out.reshape(bs, num_windows, num_patch_in_window, patch_depth)

  # Undo the roll with -(ws // 2). Writing `-window_size // 2` parses as
  # (-window_size) // 2, which is off by one for odd window sizes.
  restored, _ = Shift_Window_func(mhsa_out, shift_size=-shift, generate_mask=False)
  return restored


# Step5 Patch Merging

In [75]:
class PatchMerging(nn.Module):
    """Merge each merge_size x merge_size neighbourhood of patches into one token.

    Args:
        model_dim: input token size C.
        merge_size: side length of the neighbourhood merged into one token.
        output_depth_scale: output size is int(C * merge_size**2 * scale).

    forward input:  [bs, num_windows, ws*ws, C] window-format tokens.
    forward output: [bs, num_merged_patches, new_depth] — one row per
                    neighbourhood, in F.unfold's row-major order.
    """

    def __init__(self, model_dim, merge_size, output_depth_scale=0.5):
        super(PatchMerging, self).__init__()
        self.merge_size = merge_size
        in_features = model_dim * merge_size * merge_size
        self.proj_layer = nn.Linear(in_features, int(in_features * output_depth_scale))

    def forward(self, input):
        image = window2img(input)  # [bs, C, H, W]
        kernel = (self.merge_size, self.merge_size)
        groups = F.unfold(image, kernel_size=kernel, stride=kernel)  # [bs, C*m*m, num_merged]
        return self.proj_layer(groups.transpose(-1, -2))
    

# Step6 构建swin transformer block

In [76]:
class SwinTransformerBlock(nn.Module):
  """A pair of Swin sub-blocks: W-MSA + MLP, then SW-MSA + MLP.

  Each sub-block is pre-norm with residual connections:
    x = x + (S)W-MSA(LN(x));  x = x + MLP(LN(x))

  Args:
    model_dim: token feature size.
    window_size: attention window side length, in patches; the patch-grid side
      must be divisible by it.
    num_head: number of attention heads.

  forward input:  [bs, num_patch, model_dim] tokens in row-major (raster) order
                  over a square patch grid.
  forward output: [bs, num_windows, window_size*window_size, model_dim] in
                  window format (consumed by PatchMerging / SwinTransformerModel).
  """
  def __init__(self,model_dim,window_size,num_head):
    super(SwinTransformerBlock,self).__init__()
    self.window_size = window_size
    self.num_head = num_head

    self.LN1 = nn.LayerNorm(model_dim)
    self.LN2 = nn.LayerNorm(model_dim)
    self.LN3 = nn.LayerNorm(model_dim)
    self.LN4 = nn.LayerNorm(model_dim)

    # Two-layer MLPs (expansion 4x); F.gelu is applied between the Linears in
    # forward. Without a nonlinearity the two Linears collapse into one.
    self.wsma_mlp1 = nn.Linear(model_dim, 4*model_dim)
    self.wsma_mlp2 = nn.Linear(4*model_dim, model_dim)

    self.swsma_mlp1 = nn.Linear(model_dim, 4*model_dim)
    self.swsma_mlp2 = nn.Linear(4*model_dim, model_dim)

    self.wmhsa = MultiHeadSelfAttention(model_dim,num_head)

    self.swmhsa = MultiHeadSelfAttention(model_dim,num_head)

  @staticmethod
  def _seq_to_windows(seq, window_size):
    """[bs, H*W, C] raster-order tokens -> [bs, num_windows, ws*ws, C] windows."""
    bs, num_patch, depth = seq.shape
    side = int(round(num_patch ** 0.5))
    grid = side // window_size
    # [bs, win_row, r, win_col, c, C] -> [bs, win_row, win_col, r, c, C]
    x = seq.reshape(bs, grid, window_size, grid, window_size, depth).transpose(2, 3)
    return x.reshape(bs, grid * grid, window_size * window_size, depth)

  @staticmethod
  def _windows_to_seq(windows):
    """[bs, num_windows, ws*ws, C] windows -> [bs, H*W, C] raster-order tokens."""
    bs, num_windows, tokens, depth = windows.shape
    ws = int(round(tokens ** 0.5))
    grid = int(round(num_windows ** 0.5))
    # [bs, win_row, win_col, r, c, C] -> [bs, win_row, r, win_col, c, C]
    x = windows.reshape(bs, grid, grid, ws, ws, depth).transpose(2, 3)
    return x.reshape(bs, num_windows * tokens, depth)

  def forward(self,patch_embedding):
    ws = self.window_size

    # --- W-MSA sub-block (all residual sums in raster order) ---
    attn = Window_MultHeadSelfAttention(self.LN1(patch_embedding), self.wmhsa,
                                        window_size=ws, num_head=self.num_head)
    out1 = patch_embedding + self._windows_to_seq(attn)
    out2 = out1 + self.wsma_mlp2(F.gelu(self.wsma_mlp1(self.LN2(out1))))

    # --- SW-MSA sub-block ---
    windows = self._seq_to_windows(self.LN3(out2), ws)
    attn = Shift_Window_MultHeadSelfAttention(windows, self.swmhsa,
                                              num_head=self.num_head, window_size=ws)
    out3 = out2 + self._windows_to_seq(attn)
    out4 = out3 + self.swsma_mlp2(F.gelu(self.swsma_mlp1(self.LN4(out3))))

    # Return window format, as downstream PatchMerging expects.
    return self._seq_to_windows(out4, ws)


In [72]:

# Smoke test: 16 tokens of size 16 form a 4x4 patch grid, i.e. a single 4x4
# window. The block returns window format [bs, num_windows, ws*ws, model_dim].
st = SwinTransformerBlock(16,4,2)

x = torch.randn(3,16,16)

st(x).shape

torch.Size([3, 1, 16, 16])

# Step7 构建Swin transformer model

In [82]:
class SwinTransformerModel(nn.Module):
    """Four-stage Swin Transformer classifier.

    Pipeline: patch embedding -> block1 -> (merge -> block) x 3 -> mean pool ->
    linear head. Stage k works at width model_dim_C * 2**(k-1).

    Args:
        input_image_channel: channels of the input image.
        patch_size: side length of the initial patches.
        model_dim_C: embedding width of the first stage.
        num_classes: number of output logits.
        window_size: attention window side length, in patches.
        num_head: attention heads per block.
        merge_size: neighbourhood side merged by each PatchMerging.
    """

    def __init__(self, input_image_channel=3, patch_size=4, model_dim_C=8, num_classes=10,
                 window_size=4, num_head=2, merge_size=2):
        super(SwinTransformerModel, self).__init__()

        self.patch_size = patch_size
        self.model_dim_C = model_dim_C
        self.num_classes = num_classes

        flat_patch = patch_size * patch_size * input_image_channel
        # Linear patch-embedding matrix [patch_depth, C]
        self.patch_embedding_weight = nn.Parameter(torch.randn(flat_patch, model_dim_C))

        widths = [model_dim_C * 2 ** stage for stage in range(4)]  # C, 2C, 4C, 8C

        self.block1 = SwinTransformerBlock(widths[0], window_size, num_head)
        self.block2 = SwinTransformerBlock(widths[1], window_size, num_head)
        self.block3 = SwinTransformerBlock(widths[2], window_size, num_head)
        self.block4 = SwinTransformerBlock(widths[3], window_size, num_head)

        self.patch_merging1 = PatchMerging(widths[0], merge_size)
        self.patch_merging2 = PatchMerging(widths[1], merge_size)
        self.patch_merging3 = PatchMerging(widths[2], merge_size)

        self.final_layer = nn.Linear(widths[3], num_classes)

    def forward(self, input_):
        x = img2patch(input_, self.patch_size, self.patch_embedding_weight)
        x = self.block1(x)
        x = self.block2(self.patch_merging1(x))
        x = self.block3(self.patch_merging2(x))
        x = self.block4(self.patch_merging3(x))

        # window format [bs, nw, ws*ws, C] -> [bs, tokens, C]; token order does
        # not matter for mean pooling
        tokens = x.flatten(1, 2)
        return self.final_layer(tokens.mean(dim=1))

In [83]:
# Classification demo: run the full model on random images and print the
# num_classes logits produced for each image.
if __name__ == "__main__":
    bs, ic, image_h, image_w = 4, 3, 256, 256
    patch_size = 4
    model_dim_C = 8  # width of the initial patch embedding
    num_classes = 10
    window_size = 4
    num_head = 2
    merge_size = 2

    image = torch.randn(bs, ic, image_h, image_w)

    model = SwinTransformerModel(
        input_image_channel=ic,
        patch_size=patch_size,
        model_dim_C=model_dim_C,
        num_classes=num_classes,
        window_size=window_size,
        num_head=num_head,
        merge_size=merge_size,
    )
    logits = model(image)
    print(logits)

tensor([[ 1.2577e-01,  1.7471e-01,  1.1693e-01, -3.8581e-01, -1.2382e-01,
         -1.0274e-01,  4.9425e-01,  4.9765e-01, -7.3770e-01,  4.9017e-01],
        [-1.4078e-01,  1.4214e-01, -3.2334e-02,  3.1662e-01, -4.6753e-01,
         -9.2532e-02,  1.0285e-01,  2.7346e-01, -1.0611e-01,  1.4083e-01],
        [-6.6732e-01,  5.5884e-01,  1.1990e-01, -2.3074e-05, -3.3403e-01,
         -3.6162e-01,  3.9909e-01,  1.8915e-01, -4.3181e-01,  4.4186e-01],
        [-3.2469e-03, -1.4577e-01, -1.4872e-01, -6.9833e-02, -8.1082e-01,
          1.6133e-01,  8.5165e-01,  3.5073e-01, -2.9628e-01, -1.2382e-01]],
       grad_fn=<AddmmBackward0>)
