In [3]:
'''
Transformer输入：   (seq_len, batch_size, C)
1DConv输入：        (batch_size, C, seq_len)
2DConv输入：        (batch_size, C, H, W)
Linear输入：        仅对最后一个维度从输入变成输出
'''
import torch
import torch.nn as nn
from net.utils import checksize,toc
from net.mytransformer import PositionalEncoding,TransformerWithPooling
from net.myswinunet_s import SwinTransformerSys
# import torch.nn.functional as F
import time
from net.utils import get_model_memory_nolog
tic = time.time()

#------------------------------参数设置
cudadevice = 'cuda:0'
device = torch.device(cudadevice if torch.cuda.is_available() else "cpu")
tokenlength = 22500
hiddendim = 576
input_matrix = torch.randn(1, tokenlength, hiddendim).to(device)  # batchsize length dim
out_dim = 24
#------------------------------模型初始化
num_layers = 6
pool_size = 2  # 每次减少一半的序列长度

transformer_model = TransformerWithPooling(d_model=hiddendim, nhead=4, dim_feedforward=256, num_layers=num_layers, pool_size=pool_size, activation='silu').to(device)
get_model_memory_nolog(transformer_model)
pe = PositionalEncoding(d_model=hiddendim).to(device)
conv1d1 = nn.Conv1d(576, 1, kernel_size=1, stride=1, dilation=1 ,padding=0).to(device)
get_model_memory_nolog(conv1d1)
fc1d1 = nn.Sequential(
        nn.Linear(351, 351),
        nn.SiLU(),
        nn.Linear(351, out_dim*8*45*90),
        nn.LayerNorm(out_dim*8*45*90)).to(device)
# fc1d1 = nn.Linear(351, 8*45*90)
get_model_memory_nolog(fc1d1)
swinunet = SwinTransformerSys(embed_dim=out_dim,window_size=9).to(device)
get_model_memory_nolog(swinunet)

#------------------------------数据流正文----------------------------
checksize(input_matrix)
tic=toc(tic)
x = input_matrix.reshape(tokenlength, input_matrix.shape[0], -1)  # Transformer输入：Reshape to (seq_len, batch_size, input_channel) (L B C)
checksize(x)

#---------------Transformer Encoder----------------
print('进入Encoder')
x = pe(x)
x = transformer_model(x)  # 传入自定义的 Transformer 模型
tic=toc(tic)

#---------------conv1d+fc bottleneck---------------
print("进入bottleneck")
x = x.reshape(x.shape[1], x.shape[2], -1)  # 1DConv输入：Reshape to (batch_size, input_channel, seq_len)
checksize(x)
x = conv1d1(x)
checksize(x)
x = fc1d1(x)
checksize(x)
tic=toc(tic)

#-------------SwinTransformer Decoder--------------
print("进入Decoder")
x = x.reshape(x.shape[0],45*90,-1)
checksize(x)
x = swinunet(x)
checksize(x)
tic=toc(tic)


模型占用0.0364GB
模型占用0.0000GB
模型占用1.0259GB
SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.1;num_classes:1000
---final upsample expand_first---
模型占用0.0016GB
torch.Size([1, 22500, 576]) 12960000
耗时3.1033s
torch.Size([22500, 1, 576]) 12960000
进入Encoder
torch.Size([11250, 1, 576]) 6480000
torch.Size([5625, 1, 576]) 3240000
torch.Size([2812, 1, 576]) 1619712
torch.Size([1406, 1, 576]) 809856
torch.Size([703, 1, 576]) 404928
torch.Size([351, 1, 576]) 202176
耗时0.0069s
进入bottleneck
torch.Size([1, 576, 351]) 202176
torch.Size([1, 1, 351]) 351
torch.Size([1, 1, 777600]) 777600
耗时0.0009s
进入Decoder
torch.Size([1, 4050, 192]) 777600
torch.Size([1, 16200, 96]) 1555200
torch.Size([1, 64800, 48]) 3110400
torch.Size([1, 259200, 24]) 6220800
torch.Size([1, 259200, 24]) 6220800
torch.Size([1, 1, 360, 720]) 259200
耗时0.0144s
