In [1]:
import torch
import pickle

In [2]:
def deserialize_matrix_params(filepath, device='cpu', dtype=torch.float32):
    """
    Deserialize matrix parameters from a pickle file into torch tensors.

    The pickle is expected to hold an iterable of ``(data, shape)`` pairs;
    each ``data`` is converted to a tensor and reshaped to its ``shape``.

    Args:
        filepath: Path to the pickle file.
        device: Target device for the tensors (default ``'cpu'``).
        dtype: Target dtype for the tensors (default ``torch.float32``,
            which preserves the previous hard-coded behaviour).

    Returns:
        List of torch tensors, one per stored pair, each with the stored
        ``shape``. The rank is whatever the file records; it is not forced
        to be 2-D.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        RuntimeError: If a pair's element count does not match its shape
            (raised by ``Tensor.reshape``).

    Security:
        ``pickle.load`` can execute arbitrary code embedded in the file.
        Only use this on files you produced yourself (e.g. your own training
        logs), never on untrusted input.
    """
    # Pickle is a binary protocol, so the file must be opened in 'rb' mode.
    with open(filepath, 'rb') as f:
        serialized_data = pickle.load(f)

    return [torch.tensor(data, dtype=dtype, device=device).reshape(shape)
            for data, shape in serialized_data]

In [4]:
# load serialized matrix parameters
path = "logs/666be39c-f5ba-480c-b278-621d5dec9e94/matrix_params_step0.pkl"
matrix_params = deserialize_matrix_params(path)
print(matrix_params)

[tensor([[ 2.3682e-02,  2.1851e-02, -2.5879e-02,  ...,  1.4465e-02,
         -1.3245e-02, -2.1118e-02],
        [-3.5400e-03,  9.5215e-03, -2.2531e-05,  ..., -3.1494e-02,
          3.5400e-02,  2.5269e-02],
        [-4.0894e-03, -1.4160e-02,  3.1494e-02,  ...,  2.2339e-02,
          3.2959e-02,  1.5198e-02],
        ...,
        [-3.1250e-02, -8.2397e-03, -2.7313e-03,  ..., -3.3691e-02,
         -1.9653e-02,  1.4282e-02],
        [ 1.0223e-03, -8.9722e-03,  1.7944e-02,  ...,  5.8594e-03,
         -4.7913e-03, -5.0049e-03],
        [-3.1982e-02,  3.4912e-02, -5.6152e-03,  ...,  2.9053e-02,
          1.5030e-03, -8.5449e-03]]), tensor([[-0.0140, -0.0004,  0.0022,  ...,  0.0176, -0.0222, -0.0177],
        [-0.0342,  0.0245,  0.0145,  ...,  0.0298,  0.0205, -0.0352],
        [ 0.0189, -0.0045,  0.0294,  ..., -0.0053, -0.0085,  0.0066],
        ...,
        [-0.0339, -0.0240,  0.0129,  ...,  0.0178,  0.0017, -0.0265],
        [-0.0332, -0.0260, -0.0167,  ...,  0.0206, -0.0258,  0.0244],
   