In [1]:
import torch
from torch import nn, optim
import d2lzh_pytorch as d2l
import time

In [15]:
class MySequential(nn.Module):
    from collections import OrderedDict
    def __init__(self, *args):
        super(MySequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个OrderedDict
            for key, module in args[0].items():
                self.add_module(key, module)  # add_module方法会将module添加进self._modules(一个OrderedDict)
        else:  # 传入的是一些Module
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
    def forward(self, input):
        # self._modules返回一个 OrderedDict，保证会按照成员添加时的顺序遍历成
        for module in self._modules.values():
            if type(module) is torch.nn.modules.rnn.LSTM:
#                 input = input.view(-1, 30, 84*4)
                input, (h_n, c_n) = module(input)
                input = input[:, -1, :]
                print('lstm', input.size())
            else:
                input = module(input)
                print('other', input.size())
        return input

net = MySequential(
            nn.Conv3d(3, 16, (5, 7, 7), stride=1, padding=0), # in_channels, out_channels, kernel_size
            nn.BatchNorm3d(16),
            nn.Sigmoid(),
            nn.MaxPool3d(2, 2), # kernel_size, stride
#             nn.Conv3d(16, 64, 5),
# #             nn.BatchNorm2d(64),
#             nn.Sigmoid(),
#             nn.MaxPool3d(2, 2),
            d2l.MyFlattenLayer(),
            nn.Linear(16*25*25, 480),
            nn.BatchNorm1d(13, 480),
            nn.Sigmoid(),
            nn.Linear(480, 84*4),
            nn.BatchNorm1d(13, 84*4),
            nn.Sigmoid(),
            #nn.Linear(84, 10)
            nn.LSTM(84*4, 1024, num_layers=1, batch_first=True),
            nn.Linear(1024, 64),
            nn.BatchNorm1d(64),
            nn.Sigmoid(),
            nn.Linear(64, 1)
        )
print(net)

# X = torch.rand(10, 1, 57, 57)
# print(net(X))
# Y = net(X)
# print(Y.size())



MySequential(
  (0): Conv3d(3, 16, kernel_size=(5, 7, 7), stride=(1, 1, 1))
  (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): Sigmoid()
  (3): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): MyFlattenLayer()
  (5): Linear(in_features=10000, out_features=480, bias=True)
  (6): BatchNorm1d(13, eps=480, momentum=0.1, affine=True, track_running_stats=True)
  (7): Sigmoid()
  (8): Linear(in_features=480, out_features=336, bias=True)
  (9): BatchNorm1d(13, eps=336, momentum=0.1, affine=True, track_running_stats=True)
  (10): Sigmoid()
  (11): LSTM(336, 1024, batch_first=True)
  (12): Linear(in_features=1024, out_features=64, bias=True)
  (13): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): Sigmoid()
  (15): Linear(in_features=64, out_features=1, bias=True)
)


In [3]:
from torch.utils.data import DataLoader, Dataset
import random
global X, Y
X = torch.load("./X.pt").float()
Y = torch.load("./Y.pt").float()
X = X.view(2516, 3, 57, 57)
print(Y.size())
# train_loader = DataLoader(X, batch_size=128, shuffle=False)
# test_loader = DataLoader(Y, batch_size=128, shuffle=False)

def data_iter_random(X, Y, batch_size, num_steps, device=None):
#     print(X.size())
    num_examples = (len(Y) - num_steps)
#     print('examples', num_examples)
    epoch_size = num_examples // batch_size
#     print('epoch', epoch_size)
    example_indices = list(range(num_examples))
#     print(example_indices)
#     random.shuffle(example_indices)
    
    def _data(pos, data):
        if data is X:
#             print(pos, pos+num_steps)
#             print(data[pos:pos + num_steps, :, :, :].size())
            return data[pos:pos + num_steps, :, :, :]
        if data is Y:
#             print(pos)
            return data[pos + num_steps]
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for i in range(epoch_size):
        # 每次读取batch_size个随机样本
        i = i * batch_size
        batch_indices = example_indices[i: i + batch_size]
#         print(batch_indices)
        XX = [_data(j, X) for j in batch_indices]
        YY = [_data(j, Y) for j in batch_indices]
        XX = torch.stack(XX)
        YY = torch.stack(YY)
        XX = XX.transpose(1, 2)
        yield XX, YY
        
# for xx, yy in data_iter_random(X, Y, 2, 2):
#     print('X: ', xx[0].size(), '\nY:', yy, '\n')

torch.Size([2517])


In [16]:
# net = MySequential()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
checkpoint = torch.load('./model1.pt')
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
l = checkpoint['loss']
e = checkpoint['epoch']
net.eval()


print(X.size())
x_test1 = X[0:30, :, :, :]
x_test2 = X[1400:1430, :, :, :]
x_test2 = torch.rand(30, 3, 57, 57)
# print(x_test1[:, 1, :, :])
x_test1 = x_test1.transpose(0, 1)
x_test1 = x_test1.view(1, 3, 30, 57, 57)
y_test1 = net(x_test1)
y_test1 = y_test1.view(y_test.shape[0], -1)
print(y_test1)
x_test2 = x_test2.transpose(0, 1)
x_test2 = x_test2.view(1, 3, 30, 57, 57)
y_test2 = net(x_test2)
y_test2 = y_test2.view(y_test.shape[0], -1)
print(y_test2)
print(y_test2.size())

print('l', l)
print('e', e)

torch.Size([2516, 3, 57, 57])
other torch.Size([1, 16, 26, 51, 51])
other torch.Size([1, 16, 26, 51, 51])
other torch.Size([1, 16, 26, 51, 51])
other torch.Size([1, 16, 13, 25, 25])
other torch.Size([1, 13, 10000])
other torch.Size([1, 13, 480])
other torch.Size([1, 13, 480])
other torch.Size([1, 13, 480])
other torch.Size([1, 13, 336])
other torch.Size([1, 13, 336])
other torch.Size([1, 13, 336])
lstm torch.Size([1, 1024])
other torch.Size([1, 64])
other torch.Size([1, 64])
other torch.Size([1, 64])
other torch.Size([1, 1])
tensor([[1.6774]], grad_fn=<ViewBackward0>)
other torch.Size([1, 16, 26, 51, 51])
other torch.Size([1, 16, 26, 51, 51])
other torch.Size([1, 16, 26, 51, 51])
other torch.Size([1, 16, 13, 25, 25])
other torch.Size([1, 13, 10000])
other torch.Size([1, 13, 480])
other torch.Size([1, 13, 480])
other torch.Size([1, 13, 480])
other torch.Size([1, 13, 336])
other torch.Size([1, 13, 336])
other torch.Size([1, 13, 336])
lstm torch.Size([1, 1024])
other torch.Size([1, 64])
o