In [3]:
import torch
import torch.nn as nn
# import torch.nn.functional as F
import time
from net.utils import get_model_memory_nolog
tic = time.time()

#------------------------------参数设置
cudadevice = 'cuda:0'
device = torch.device(cudadevice if torch.cuda.is_available() else "cpu")
tokenlength = 25000
hiddendim = 576
input_matrix = torch.randn(1, hiddendim, tokenlength).to(device)  # batchsize channel 长

#------------------------------模型初始化
transformer_model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=hiddendim, nhead=4, dim_feedforward=256),num_layers=6).to(device)
get_model_memory_nolog(transformer_model)

#------------------------------数据流正文
print(input_matrix.shape, hiddendim*tokenlength)
print(f'初始化耗时{time.time() - tic:.4f}s')
tic = time.time()

x = input_matrix.reshape(tokenlength,1,-1) # Reshape to (seq_len, batch_size, input_channel)
x = transformer_model(x)
print(x.shape, x.shape[0] * x.shape[1] * x.shape[2])
print(f'耗时{time.time() - tic:.4f}s')


模型占用0.0364GB
torch.Size([1, 576, 25000]) 14400000
初始化耗时0.2574s
torch.Size([25000, 1, 576]) 14400000
耗时0.0152s


In [11]:
import torch
import torch.nn as nn
# import torch.nn.functional as F
import time
from net.utils import get_model_memory_nolog
tic = time.time()

#------------------------------参数设置
cudadevice = 'cuda:0'
device = torch.device(cudadevice if torch.cuda.is_available() else "cpu")
tokenlength = 2500
hiddendim = 576
input_matrix = torch.randn(1, tokenlength, hiddendim).to(device)  # batchsize length dim
# input_matrix = torch.randn(1, hiddendim, tokenlength).to(device)  # batchsize dim length

#------------------------------模型初始化
# transformer_model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=hiddendim, nhead=4, dim_feedforward=256, batch_first=True),num_layers=6).to(device)
transformer_model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=hiddendim, nhead=4, dim_feedforward=256,activation='silu'),num_layers=6).to(device)
get_model_memory_nolog(transformer_model)

#------------------------------数据流正文
print(input_matrix.shape, hiddendim*tokenlength)
print(f'初始化耗时{time.time() - tic:.4f}s')
tic = time.time()

# x = input_matrix.reshape(1,tokenlength,-1) # Reshape to ( batch_size, seq_len, input_channel) (1,2500,576) batch length dim
x = input_matrix.reshape(tokenlength,1,-1) # Reshape to (seq_len, batch_size, input_channel) (2500,1,576) length batch dim
print(x.shape, x.shape[0] * x.shape[1] * x.shape[2])

x = transformer_model(x) #真正使用的时候
print(x.shape, x.shape[0] * x.shape[1] * x.shape[2])
print(f'耗时{time.time() - tic:.4f}s')


模型占用0.0364GB
torch.Size([1, 2500, 576]) 1440000
初始化耗时0.0927s
torch.Size([2500, 1, 576]) 1440000
torch.Size([2500, 1, 576]) 1440000
耗时0.0088s


加入Positional Encoding

In [4]:
import torch
import torch.nn as nn
# import torch.nn.functional as F
import time
from net.utils import get_model_memory_nolog
tic = time.time()

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=22500): 
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    def forward(self, x):
        return x + self.pe[:x.size(0), :]

#------------------------------参数设置
cudadevice = 'cuda:1'
device = torch.device(cudadevice if torch.cuda.is_available() else "cpu")
tokenlength = 22500
hiddendim = 576
input_matrix = torch.randn(1, tokenlength, hiddendim).to(device)  # batchsize length dim
# input_matrix = torch.randn(1, hiddendim, tokenlength).to(device)  # batchsize dim length

#------------------------------模型初始化
# transformer_model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=hiddendim, nhead=4, dim_feedforward=256, batch_first=True),num_layers=6).to(device)
transformer_model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=hiddendim, nhead=4, dim_feedforward=256,activation='silu'),num_layers=6).to(device)
pe = PositionalEncoding(d_model=hiddendim).to(device)
get_model_memory_nolog(transformer_model)

#------------------------------数据流正文
print(input_matrix.shape, hiddendim*tokenlength)
print(f'初始化耗时{time.time() - tic:.4f}s')
tic = time.time()

# x = input_matrix.reshape(1,tokenlength,-1) # Reshape to ( batch_size, seq_len, input_channel) (1,2500,576) batch length dim
x = input_matrix.reshape(tokenlength,1,-1) # Reshape to (seq_len, batch_size, input_channel) (2500,1,576) length batch dim
print(x.shape, x.shape[0] * x.shape[1] * x.shape[2])

x = pe(x)
x = transformer_model(x) #真正使用的时候
print(x.shape, x.shape[0] * x.shape[1] * x.shape[2])
print(f'耗时{time.time() - tic:.4f}s')


模型占用0.0364GB
torch.Size([1, 22500, 576]) 12960000
初始化耗时0.2944s
torch.Size([22500, 1, 576]) 12960000
torch.Size([22500, 1, 576]) 12960000
耗时0.0196s


尝试Transformer pooling 和query vector动手脚

transformer pooling

In [4]:
'''
Transformer输入：   (seq_len, batch_size, C)
1DConv输入：        (batch_size, C, seq_len)
2DConv输入：        (batch_size, C, H, W)
Linear输入：        仅对最后一个维度从输入变成输出
'''
import torch
import torch.nn as nn
from net.mytransformer import PositionalEncoding,TransformerWithPooling
from net.myswinunet import SwinTransformerSys
# import torch.nn.functional as F
import time
from net.utils import get_model_memory_nolog

tic = time.time()

def checksize(x):
    print(x.shape, x.shape[0] * x.shape[1] * x.shape[2])
    return 1

def toc(tic):
    print(f'耗时{time.time() - tic:.4f}s')
    tic = time.time()
    return tic

#------------------------------参数设置
cudadevice = 'cuda:0'
device = torch.device(cudadevice if torch.cuda.is_available() else "cpu")
tokenlength = 22500
hiddendim = 576
input_matrix = torch.randn(2, tokenlength, hiddendim).to(device)  # batchsize length dim

#------------------------------模型初始化
num_layers = 6
pool_size = 2  # 每次减少一半的序列长度

transformer_model = TransformerWithPooling(d_model=hiddendim, nhead=4, dim_feedforward=256, num_layers=num_layers, pool_size=pool_size, activation='silu').to(device)
get_model_memory_nolog(transformer_model)
pe = PositionalEncoding(d_model=hiddendim).to(device)
conv1d1 = nn.Conv1d(576, 1, kernel_size=1, stride=1, dilation=1 ,padding=0).to(device)
get_model_memory_nolog(conv1d1)
fc1d1 = nn.Sequential(
        nn.Linear(351, 351),
        nn.SiLU(),
        nn.Linear(351, 96*45*90),
        nn.LayerNorm(96*45*90)).to(device)
# fc1d1 = nn.Linear(351, 8*45*90)
get_model_memory_nolog(fc1d1)
swinunet = SwinTransformerSys(embed_dim=12,window_size=9).to(device)
get_model_memory_nolog(swinunet)

#------------------------------数据流正文----------------------------
checksize(input_matrix)
tic=toc(tic)
x = input_matrix.reshape(tokenlength, input_matrix.shape[0], -1)  # Transformer输入：Reshape to (seq_len, batch_size, input_channel) (L B C)
checksize(x)

#---------------Transformer Encoder----------------
print('进入Encoder')
x = pe(x)
x = transformer_model(x)  # 传入自定义的 Transformer 模型
tic=toc(tic)

#---------------conv1d+fc bottleneck---------------
print("进入bottleneck")
x = x.reshape(x.shape[1], x.shape[2], -1)  # 1DConv输入：Reshape to (batch_size, input_channel, seq_len)
checksize(x)
x = conv1d1(x)
checksize(x)
x = fc1d1(x)
checksize(x)
tic=toc(tic)

#-------------SwinTransformer Decoder--------------
print("进入Decoder")
x = x.reshape(x.shape[0],45*90,-1)
x = swinunet(x)
checksize(x)
tic=toc(tic)


模型占用0.0364GB
模型占用0.0000GB
模型占用0.5132GB
SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.1;num_classes:1000
---final upsample expand_first---
模型占用0.0004GB
torch.Size([2, 22500, 576]) 25920000
耗时1.8190s
torch.Size([22500, 2, 576]) 25920000
进入Encoder
1
torch.Size([22500, 2, 576]) 25920000
torch.Size([11250, 2, 576]) 12960000
1
torch.Size([11250, 2, 576]) 12960000
torch.Size([5625, 2, 576]) 6480000
1
torch.Size([5625, 2, 576]) 6480000
torch.Size([2812, 2, 576]) 3239424
1
torch.Size([2812, 2, 576]) 3239424
torch.Size([1406, 2, 576]) 1619712
1
torch.Size([1406, 2, 576]) 1619712
torch.Size([703, 2, 576]) 809856
1
torch.Size([703, 2, 576]) 809856
torch.Size([351, 2, 576]) 404352
耗时0.0109s
进入bottleneck
torch.Size([2, 576, 351]) 404352
torch.Size([2, 1, 351]) 702
torch.Size([2, 1, 388800]) 777600
耗时0.0008s
进入Decoder
torch.Size([2, 16200, 48]) 1555200
torch.Size([2, 64800, 24]) 3110400
torch.Size([2, 259200, 12]) 6220800
torch.Size([2, 259200, 