In [7]:
import h5py
import torch
import numpy as np

In [3]:
class GlobalNorm1d(torch.nn.Module):
    """Per-feature normalization using statistics accumulated over every training sample.

    In training mode each batch is merged into an exact running mean and unbiased
    variance (pairwise/parallel update of count, mean and sum of squared deviations),
    and the input is normalized with the updated statistics. In eval mode the stored
    statistics are used unchanged.

    Args:
        input_features: number of features; the input is expected to be (batch, input_features).
        eps: added to the variance before the square root. Defaults to float32 machine epsilon.
        affine: if True, a learnable per-feature scale (``weight``) and shift (``bias``) are applied.
    """
    def __init__(self, input_features, eps=None, affine=True):
        super(GlobalNorm1d, self).__init__()

        if eps is None:
            eps = torch.finfo(torch.float32).eps
        self.register_buffer('eps', torch.tensor(eps, dtype=torch.float32))

        # The sample count is float64: float32 represents integers exactly only up to 2**24,
        # beyond which adding small batch sizes no longer changes a float32 count.
        self.register_buffer('count', torch.tensor(0., dtype=torch.float64))
        self.register_buffer('mean', torch.zeros(input_features, dtype=torch.float32))
        # Unbiased variance. The initial 1 is only a placeholder (it makes an untrained module
        # approximately the identity); it is never treated as observed data.
        self.register_buffer('var', torch.ones(input_features, dtype=torch.float32))

        self.affine = affine
        self.weight = torch.nn.Parameter(torch.ones(input_features, dtype=torch.float32)) if affine else None
        self.bias = torch.nn.Parameter(torch.zeros(input_features, dtype=torch.float32)) if affine else None

    def update_statistics(self, x):
        """Merge batch ``x`` (rows are samples) into the running count, mean and variance.

        The arithmetic is done in float64 and written back into the existing buffers
        (``copy_`` preserves each buffer's dtype, so models loaded from older checkpoints
        keep working). An empty batch is ignored.
        """
        with torch.no_grad():
            x = x.to(device=self.mean.device, dtype=torch.float64)
            batch_count = x.size(0)
            if batch_count == 0:
                return

            old_count = self.count.to(torch.float64)
            old_mean = self.mean.to(torch.float64)
            # Sum of squared deviations of the data seen so far. With count <= 1 it is 0;
            # the clamp stops the placeholder var from contributing var * (0 - 1) = -var.
            old_m2 = self.var.to(torch.float64) * (old_count - 1).clamp(min=0)

            batch_mean = x.mean(dim=0)
            batch_m2 = (x - batch_mean).pow(2).sum(dim=0)

            total_count = old_count + batch_count
            delta = batch_mean - old_mean
            new_mean = old_mean + delta * (batch_count / total_count)
            new_m2 = old_m2 + batch_m2 + delta.pow(2) * (old_count * batch_count / total_count)

            self.mean.copy_(new_mean)
            # The unbiased variance is undefined for a single sample; keep the current value
            # (the running M2 is still correctly 0 at that point).
            if total_count > 1:
                self.var.copy_(new_m2 / (total_count - 1))
            self.count.copy_(total_count)

    def forward(self, x):
        """Normalize ``x`` per feature; in training mode, first fold ``x`` into the statistics."""
        if self.training:
            self.update_statistics(x)
        normalized_x = (x - self.mean) / (self.var + self.eps).sqrt()
        if self.affine:
            return normalized_x * self.weight + self.bias
        return normalized_x
 

class TwoPort(torch.nn.Module):
    """Small MLP with globally normalized inputs and one residual hidden layer.

    Pipeline: GlobalNorm1d -> Linear -> ReLU -> (Linear + skip) -> ReLU -> Linear.
    The submodule names (gn1, linear1, linear2, linear3) define the state_dict keys,
    so they are kept stable for checkpoint and HDF5-export compatibility.

    Args:
        in_features: width of the input (default 5).
        hidden_features: width of both hidden layers (default 16).
        out_features: width of the output (default 2).
    """
    def __init__(self, in_features=5, hidden_features=16, out_features=2):
        super(TwoPort, self).__init__()
        self.gn1 = GlobalNorm1d(in_features)
        self.linear1 = torch.nn.Linear(in_features, hidden_features)
        self.linear2 = torch.nn.Linear(hidden_features, hidden_features)
        self.linear3 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.gn1(x)

        x1 = self.linear1(x)
        x1 = torch.nn.functional.relu(x1)

        # Residual connection around the second hidden layer.
        x2 = self.linear2(x1) + x1
        x2 = torch.nn.functional.relu(x2)

        x3 = self.linear3(x2)
        return x3

In [35]:
# Load the full pickled TwoPort model (architecture + weights) onto the CPU.
# NOTE: weights_only=False unpickles arbitrary objects, so only use it on trusted files.
# It is required because this checkpoint is a whole nn.Module; torch >= 2.6 defaults to
# weights_only=True and would reject it.
model = torch.load('../models/2X16.pth', map_location='cpu', weights_only=False)

# Create (or overwrite) the HDF5 file and store every parameter and buffer,
# one group per submodule: e.g. "linear1.weight" -> group "linear1", dataset "weight".
with h5py.File('../models/2X16.h5', 'w') as h5_file:
    for name, tensor in [*model.named_parameters(), *model.named_buffers()]:
        # rsplit keeps nested module paths (e.g. "block.0.linear") whole as the group name.
        layer_name, tensor_name = name.rsplit('.', 1)
        grp = h5_file.require_group(layer_name)
        grp.create_dataset(tensor_name, data=tensor.detach().cpu().numpy())

gn1.eps 1.1920929e-07
gn1.count 268435460.0
gn1.mean [ 1.5251255e-05  4.5039244e+00  1.4293042e+04 -1.4291543e+04
  2.8557406e+04]
gn1.var [6.4009476e+01 1.6386786e+01 3.2561897e+01 3.2193089e+01 1.0485760e+06]
tensor([6.4009e+01, 1.6387e+01, 3.2562e+01, 3.2193e+01, 1.0486e+06])


In [37]:
# Reproducible (x, y) reference pair for checking another implementation of the model.
np.random.seed(1001)
torch.manual_seed(0)

# The model runs in float32, so build the input in float32 and save exactly that array;
# the stored x is then precisely what produced the stored y.
x = np.random.rand(100, 5).astype(np.float32)
torch_in = torch.from_numpy(x)

model.eval()  # GlobalNorm1d only updates its statistics in training mode
with torch.no_grad():
    # Call the module (not .forward) so any registered hooks run as usual.
    y = model(torch_in).numpy()

with h5py.File('../models/2x16_test.h5', 'w') as file:
    file.create_dataset('x', data=x)
    file.create_dataset('y', data=y)

KeyboardInterrupt: 