In [1]:
import sys
from pathlib import Path
# Put the parent of the current working directory on sys.path (once) so the
# project imports below can be resolved -- presumably that is where the
# graph_irl package lives; this must run before those imports.
p = Path('.').absolute().parent
if str(p) not in sys.path:
    sys.path.append(str(p))

from graph_irl.buffer_v2 import GraphBuffer
from graph_irl.policy import GaussPolicy, TwoStageGaussPolicy, GCN, Qfunc
from graph_irl.graph_rl_utils import GraphEnv
from graph_irl.sac import SACAgentGraph, TEST_OUTPUTS_PATH
from graph_irl.reward import GraphReward
from graph_irl.examples.circle_graph import create_circle_graph
from graph_irl.irl_trainer import IRLGraphTrainer

import random
import numpy as np
import torch
# Show tensors with 8 decimal places in printed output.
torch.set_printoptions(precision=8)
# Seed the Python, NumPy and torch random number generators with 0.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)


def get_params(compute_rewards_online, n_nodes=11, node_dim=5):
    """Assemble the constructor arguments for the circle-graph IRL experiment.

    Args:
        compute_rewards_online: Forwarded unchanged as
            ``buffer_kwargs['compute_rewards_online']``.
        n_nodes: Number of nodes passed to ``create_circle_graph``; also used
            for the environment step/repeat limits and the number of
            evaluation steps. Defaults to 11, the value that used to be
            hard-coded.
        node_dim: Node feature width passed to ``create_circle_graph`` and
            used as the input width of every GCN encoder. Defaults to 5,
            the value that used to be hard-coded.

    Returns:
        Tuple ``(agent_kwargs, config, nodes, expert_edge_index)``.
        ``agent_kwargs`` and ``config`` are dicts meant to be unpacked
        together into ``SACAgentGraph``; ``nodes`` and ``expert_edge_index``
        are the two values returned by ``create_circle_graph``.
    """
    # Node features are drawn with torch.randn, so this consumes torch RNG
    # state before any network below is constructed.
    nodes, expert_edge_index = create_circle_graph(n_nodes, node_dim, torch.randn)
    # Author's note carried over from the original code: replacing the node
    # features with all ones (torch.ones_like(nodes)) doesn't seem to work.

    encoder_hiddens = [8, 8, 8]
    reward_fn_hiddens = [16, 16]
    gauss_policy_hiddens = [16, 16]
    tsg_policy_hiddens1 = [16, 16]
    tsg_policy_hiddens2 = [16]
    qfunc_hiddens = [16, 16]

    # Width of the last encoder layer; several sizes below are derived from it.
    embed_dim = encoder_hiddens[-1]

    def make_encoder():
        # Every component gets its own, identically configured GCN instance.
        return GCN(node_dim, encoder_hiddens, with_layer_norm=True, final_tanh=True)

    # Construction order is kept exactly as before so that, under a fixed
    # torch seed, each encoder is created at the same point in the RNG stream.
    encoder_names = (
        'encoder',
        'encoderq1',
        'encoderq2',
        'encoderq1t',
        'encoderq2t',
        'encoder_reward',
    )
    encoder_dict = {name: make_encoder() for name in encoder_names}

    reward_fn = GraphReward(
        encoder_dict['encoder_reward'],
        embed_dim,
        hiddens=reward_fn_hiddens,
        with_layer_norm=True,
        with_batch_norm=False,
    )

    gauss_policy_kwargs = dict(
        obs_dim=embed_dim,
        action_dim=embed_dim,
        hiddens=gauss_policy_hiddens,
        with_layer_norm=True,
        encoder=encoder_dict['encoder'],
        two_action_vectors=True,
    )

    # Not referenced by the returned config, which uses GaussPolicy with
    # gauss_policy_kwargs. Presumably kept as ready-made settings for
    # TwoStageGaussPolicy -- TODO confirm or remove.
    tsg_policy_kwargs = dict(
        obs_dim=embed_dim,
        action_dim=embed_dim,
        hiddens1=tsg_policy_hiddens1,
        hiddens2=tsg_policy_hiddens2,
        encoder=encoder_dict['encoder'],
        with_layer_norm=True,
    )

    # Settings shared by all four Q-functions; only the encoder differs.
    qfunc_kwargs = dict(
        obs_action_dim=embed_dim * 3,
        hiddens=qfunc_hiddens,
        with_layer_norm=True,
        with_batch_norm=False,
        encoder=None,
    )
    Q1_kwargs = dict(qfunc_kwargs, encoder=encoder_dict['encoderq1'])
    Q2_kwargs = dict(qfunc_kwargs, encoder=encoder_dict['encoderq2'])
    Q1t_kwargs = dict(qfunc_kwargs, encoder=encoder_dict['encoderq1t'])
    Q2t_kwargs = dict(qfunc_kwargs, encoder=encoder_dict['encoderq2t'])

    agent_kwargs = dict(
        name='SACAgentGraph',
        policy_constructor=GaussPolicy,
        qfunc_constructor=Qfunc,
        env_constructor=GraphEnv,
        buffer_constructor=GraphBuffer,
        optimiser_constructors=dict(
            policy_optim=torch.optim.Adam,
            temperature_optim=torch.optim.Adam,
            Q1_optim=torch.optim.Adam,
            Q2_optim=torch.optim.Adam,
        ),
        entropy_lb=embed_dim,
        policy_lr=3e-4,
        temperature_lr=3e-4,
        qfunc_lr=3e-4,
        tau=0.01,
        discount=1.,
        save_to=TEST_OUTPUTS_PATH,
        cache_best_policy=False,
        clip_grads=False,
        zero_temperature=False,
        UT_trick=False,
        with_entropy=False,
    )

    config = dict(
        training_kwargs=dict(
            seed=0,
            num_iters=31,
            num_steps_to_sample=100,
            num_grad_steps=1,
            # NOTE(review): literal 11 kept as-is; it equals the default
            # n_nodes -- confirm whether it is meant to track n_nodes.
            batch_size=11,
            num_eval_steps_to_sample=n_nodes,
            min_steps_to_presample=0,
        ),
        Q1_kwargs=Q1_kwargs,
        Q2_kwargs=Q2_kwargs,
        Q1t_kwargs=Q1t_kwargs,
        Q2t_kwargs=Q2t_kwargs,
        policy_kwargs=gauss_policy_kwargs,
        buffer_kwargs=dict(
            max_size=10_000,
            nodes=nodes,
            seed=0,
            drop_repeats_or_self_loops=True,
            get_batch_reward=True,
            compute_rewards_online=compute_rewards_online,
            graphs_per_batch=100,
            action_is_index=True,
            per_decision_imp_sample=True,
            reward_scale=embed_dim * 2,
            log_offset=0.,
        ),
        env_kwargs=dict(
            x=nodes,
            reward_fn=reward_fn,
            max_episode_steps=n_nodes,
            num_expert_steps=n_nodes,
            max_repeats=n_nodes,
            max_self_loops=n_nodes,
            drop_repeats_or_self_loops=True,
            id=None,
            reward_fn_termination=False,
            calculate_reward=False,
            min_steps_to_do=3,
        )
    )
    return agent_kwargs, config, nodes, expert_edge_index


# Experiment settings, with compute_rewards_online switched on in the
# buffer configuration.
agent_kwargs, config, nodes, expert_edge_index = get_params(compute_rewards_online=True)

# Both dicts are unpacked together into the agent constructor.
agent = SACAgentGraph(**agent_kwargs, **config)

# Trainer options; the three buffer-related entries are copied from the
# agent's buffer settings so the two stay in agreement.
irl_trainer_config = {
    'num_expert_traj': 10,
    'graphs_per_batch': config['buffer_kwargs']['graphs_per_batch'],
    'num_extra_paths_gen': 5,
    'reward_optim_lr_scheduler': None,
    'reward_grad_clip': False,
    'reward_scale': config['buffer_kwargs']['reward_scale'],
    'per_decision_imp_sample': config['buffer_kwargs']['per_decision_imp_sample'],
    'unnorm_policy': True,
    'add_expert_to_generated': False,
    'lcr_regularisation_coef': None,
    # Reciprocal of half the number of columns of expert_edge_index.
    'mono_regularisation_on_demo_coef': 1 / (expert_edge_index.shape[-1] // 2),
    'verbose': False,
}

# The trainer optimises the reward function held by the agent's environment,
# using Adam with learning rate 1e-2.
irl_trainer = IRLGraphTrainer(
    reward_fn=agent.env.reward_fn,
    reward_optim=torch.optim.Adam(agent.env.reward_fn.parameters(), lr=1e-2),
    agent=agent,
    nodes=nodes,
    expert_edge_index=expert_edge_index,
    **irl_trainer_config,
)

# First reward update: reward function trainable and in train mode, policy
# gradients disabled, then one reward gradient step.
irl_trainer.reward_fn.requires_grad_(True)
irl_trainer.reward_fn.train()
irl_trainer.agent.policy.requires_grad_(False)
irl_trainer.do_reward_grad_step()

pygame 2.1.3 (SDL 2.0.22, Python 3.8.10)
Hello from the pygame community. https://www.pygame.org/contribute.html
/home/focal/coding/urban-nets-style-transfer/tests
expert return: -122.74497985839844
expert return: -149.66355895996094
expert return: -124.43367004394531
expert return: -146.4972686767578
expert return: -138.7230682373047
expert return: -153.1378173828125
expert return: -139.2919158935547
expert return: -134.63589477539062
expert return: -135.19525146484375
expert return: -153.7252960205078
expert_rewards shape:  torch.Size([10, 11])
sampled undiscounted return: -137.8103790283203
sampled undiscounted return: -143.25941467285156
sampled undiscounted return: -142.55328369140625
sampled undiscounted return: -150.7525177001953
sampled undiscounted return: -119.78134155273438
sampled undiscounted return: -147.7942657470703
sampled undiscounted return: -134.21124267578125
sampled undiscounted return: -147.94541931152344
sampled undiscounted return: -154.61767578125
sampled undi

In [2]:
# Policy-training phase: disable gradients on the reward function and switch
# it to eval mode.
irl_trainer.reward_fn.requires_grad_(False)
irl_trainer.reward_fn.eval()
# Re-enable gradients on the policy parameters.
irl_trainer.agent.policy.requires_grad_(True)
# NOTE(review): the policy is switched to eval() mode immediately before it is
# trained -- confirm this is intended, or that train_policy_k_epochs sets the
# mode itself.
irl_trainer.agent.policy.eval()
# Called with 1; presumably the number of policy-training epochs -- confirm
# against IRLGraphTrainer.
irl_trainer.train_policy_k_epochs(1)

  0%|          | 0/1 [00:00<?, ?it/s]

bob is: 110
tensor([[10, 10,  4,  0,  9,  2,  1,  3,  6,  8,  4],
        [ 1,  7,  9,  8,  8,  0,  4,  0,  7,  9,  3]])
tensor([ -5.01699448, -10.65846252, -10.90252399,  -9.80887985,  -9.80794048,
        -20.05060196,  -6.02020884, -12.78817081,  -9.87467861, -18.93912888,
         -8.24656963])





bob is: 110
tensor([[0, 3, 7, 2, 4, 2, 0, 1, 2, 2, 9],
        [6, 0, 0, 8, 5, 0, 3, 2, 1, 8, 0]])
tensor([-13.58244228,  -8.29182720, -16.62692833, -20.31767845, -14.35860252,
        -21.08855629, -11.96698666,  -8.06600475,  -7.78110695, -20.34567451,
        -12.27365112])





bob is: 110
tensor([[ 1,  5,  2, 10,  0,  3,  0,  9,  9,  1,  5],
        [10, 10, 10,  6,  7, 10, 10,  1,  4,  7,  2]])
tensor([-13.74940014, -17.23501205, -19.97656631, -10.68769932, -13.47595119,
        -10.56976509,  -7.67417145,  -7.25369453,  -9.85009289, -21.47320938,
         -6.64267731])





bob is: 110
tensor([[ 5,  1,  6,  1,  1,  9,  5,  8,  9,  1,  2],
        [ 3,  4,  7,  0,  0,  7,  7, 10,  5,  7, 10]])
tensor([-10.35365200,  -6.02020884, -13.11557293, -14.72109318, -13.87147427,
        -10.26946354, -10.41771317, -14.30883026,  -8.26774216, -15.71107006,
        -19.73834038])





bob is: 110
tensor([[ 8,  4,  3,  5,  3,  9,  6,  5, 10,  3,  2],
        [ 4, 10, 10,  9, 10,  1,  8,  1,  1,  1,  9]])
tensor([ -5.99037743, -14.67303085,  -7.35635662, -16.82743454, -17.20202065,
         -7.44753456, -11.57715893, -16.43317413,  -9.01863194,  -8.73711586,
        -12.50947857])





bob is: 110
tensor([[ 6,  1,  5,  9,  3,  2,  1,  8,  9,  1,  1],
        [ 0,  6,  9,  7, 10,  4,  7,  1,  1,  9,  9]])
tensor([ -8.89008522, -12.44372654, -16.79738045, -12.42251396, -10.87466049,
        -10.56945992, -15.71107006,  -7.33597851, -10.31991673, -12.73385334,
        -13.14181042])





bob is: 110
tensor([[ 8,  7,  2,  0,  2,  0,  2,  2,  0, 10,  9],
        [10,  2,  8,  7,  0, 10,  8, 10, 10,  9,  2]])
tensor([-19.96620941,  -7.94796753, -21.98263168, -12.50961304, -20.71328163,
         -7.96949005, -20.34567451,  -9.90551662, -13.34237289, -12.03491592,
        -12.06001282])





bob is: 110
tensor([[ 7,  4,  8,  1,  7,  9,  8,  1,  1,  4,  0],
        [ 8,  6,  9,  5,  8, 10,  9,  9,  6, 10, 10]])
tensor([-16.91120911, -14.73565578,  -8.85842609,  -7.05359793, -16.71891212,
         -7.16250134, -18.93912888, -16.38153648, -14.70255756,  -7.32860851,
         -7.96949005])





bob is: 110
tensor([[ 3,  2,  8,  3,  2,  3,  0, 10,  8,  6,  7],
        [ 6,  4,  9,  8,  4,  2,  8,  9,  5,  1,  0]])
tensor([-15.67747688,  -9.45127773, -21.18833351, -15.34461594, -17.40869141,
        -15.65111732, -14.94575214, -14.98478508, -13.68291378,  -9.49394035,
        -19.43034744])





bob is: 110
tensor([[ 6,  3, 10,  1,  5,  7,  8,  0,  0,  9,  2],
        [ 2, 10,  1,  0,  2,  3,  1,  8,  1,  1,  1]])
tensor([ -9.70431137, -10.55486774,  -5.87328625, -14.55493450, -10.91096306,
        -22.29721832,  -7.41619778, -14.94575214,  -7.49139071,  -9.06084824,
        -20.08072281])





bob is: 110
tensor([[ 8,  0,  3,  6, 10,  7,  8,  4,  0,  5,  5],
        [ 4,  7,  6,  7,  4,  6,  6,  6,  3,  9,  8]])
tensor([ -9.19682980, -13.12626457, -14.70588875, -10.70311737, -10.44157124,
         -6.88864899, -10.10399055, -14.73565578,  -5.98628473, -16.82743454,
        -19.93266678])





bob is: 110
tensor([[ 5, 10,  0,  7,  1, 10,  6,  2,  5,  8,  2],
        [ 6,  9,  7, 10,  7,  6,  4,  0,  8,  5,  4]])
tensor([-20.38142776, -12.75823593, -13.47595119, -16.01479721, -21.47320938,
        -16.95775986, -10.01554394, -19.93632507, -19.93266678,  -8.49283409,
        -11.85583591])





bob is: 110
tensor([[ 4,  5,  0,  1,  8,  9,  7,  8,  2,  7,  6],
        [10,  9,  9,  5,  1,  4,  1,  6,  4,  1,  7]])
tensor([-12.39560604, -21.04291344, -15.27743149,  -6.00073242,  -7.06168842,
         -9.85009289,  -9.22720242, -11.90094757,  -9.45127773, -13.41860104,
         -9.42696857])





bob is: 110
tensor([[10, 10,  8,  3,  5,  9,  1,  2,  8,  8,  3],
        [ 0,  6,  6,  7,  8,  7,  4,  5,  7, 10,  8]])
tensor([-11.78224087,  -5.66121864, -10.10399055, -10.62113762, -20.86382866,
         -9.99642086, -10.90974426,  -4.60030460, -10.21136379,  -7.88013601,
         -9.15755367])





bob is: 110
tensor([[ 1,  2, 10,  5,  6,  1,  3,  1,  8,  3,  4],
        [ 9, 10,  2,  9,  0, 10,  0, 10, 10,  8, 10]])
tensor([-14.45618057, -19.91670418,  -9.24831200, -17.22081947, -11.60579395,
         -7.39944458, -11.27902794, -13.16800308, -12.04510498, -15.34461594,
        -10.77372837])





bob is: 109
tensor([[ 1, 10,  6,  0,  5,  1,  7, 10,  4,  9,  1],
        [ 6,  9,  0,  4,  1, 10,  6,  0,  9, 10, 10]])
tensor([-14.14079189, -14.74208832, -14.15732384, -10.81862926, -11.30941010,
         -7.67038536, -13.63494301, -13.03579521, -14.97972679,  -9.98200703,
         -4.72071075])





bob is: 110
tensor([[ 2,  9,  9, 10,  9,  6,  4,  1,  2, 10, 10],
        [ 1,  1,  1,  1, 10,  0,  2,  9,  9,  0,  4]])
tensor([ -7.78110695,  -4.96468639,  -7.95093727,  -8.16125298, -13.99039841,
         -7.48609257,  -9.29767418, -16.03378868, -17.59910393, -13.51410484,
         -6.18498182])





bob is: 110
tensor([[ 8,  2,  5,  7,  0,  5,  5,  3,  5,  9,  4],
        [ 7,  9, 10,  6,  9, 10,  8,  8,  0, 10, 10]])
tensor([-13.99958038, -19.02750969,  -9.54340935, -16.56809235, -10.94678879,
        -12.19553280, -20.79857635,  -5.64748240, -15.86811256, -10.99304104,
        -14.67303085])





bob is: 110
tensor([[1, 5, 3, 3, 7, 2, 5, 5, 0, 1, 6],
        [5, 8, 8, 7, 8, 1, 6, 9, 5, 9, 8]])
tensor([-11.11308575, -20.54598427, -19.07320023,  -7.96651077, -16.71891212,
         -5.47315359, -20.86198997, -18.73563766, -11.57705688, -16.03378868,
         -6.69357538])





bob is: 110
tensor([[ 3,  3,  4,  3,  3,  0,  6,  7,  2,  9,  2],
        [ 6,  9,  9,  9,  9,  1, 10,  3,  8,  1, 10]])
tensor([-14.02655125, -14.13656521, -15.70571327, -17.11959076, -16.05929947,
         -4.94739294, -10.14953327,  -9.33137894, -20.28973961,  -9.02924252,
        -19.18160248])





bob is: 110
tensor([[ 0,  9,  8,  0,  2,  2,  9,  7,  3,  6,  4],
        [ 1,  8,  4,  8,  0, 10,  4, 10,  9,  7,  7]])
tensor([-10.09340477, -14.19982624, -11.94560909,  -7.41849327, -21.11357689,
        -14.37753963,  -9.92524147,  -7.36595249, -10.58334923, -10.09396839,
         -8.14468956])





bob is: 110
tensor([[ 8,  2,  0,  0,  0,  4,  7,  2,  2,  2,  7],
        [10, 10,  9,  1,  8,  9,  8,  1,  7,  9,  3]])
tensor([ -8.02362251,  -7.66612530, -14.69013882,  -8.92664623,  -8.07176781,
        -11.07705212, -18.00357819,  -9.66172981,  -8.99657631, -19.37692642,
         -7.47600222])





bob is: 110
tensor([[ 2,  9,  7,  8, 10,  3,  3,  4,  5,  3,  4],
        [ 9,  4, 10,  0,  4,  0,  5,  9,  4,  1, 10]])
tensor([-19.96331215,  -8.11276054, -11.44062424, -12.41375446,  -6.18498230,
        -15.74818897, -15.57771587,  -9.02276039, -10.60958481,  -7.82642365,
         -7.32860851])





bob is: 110
tensor([[ 4,  0,  7,  8,  1,  4,  6,  3, 10,  1,  4],
        [10,  5,  1,  1,  7, 10,  2, 10,  8,  9,  9]])
tensor([ -9.72420216, -12.52479267,  -6.43289948, -10.99028015, -17.78861809,
        -15.34132481,  -9.60834980, -10.72504139, -13.76848507,  -6.09430599,
        -13.33307266])





bob is: 110
tensor([[ 3,  0,  2,  3,  3,  4,  5,  0,  1,  2,  7],
        [ 0,  9,  7,  7,  6,  9, 10,  7,  0, 10,  8]])
tensor([-12.20814991, -15.41094685, -17.10140228,  -7.96609688,  -7.14280176,
        -13.16420937, -15.11346245, -12.74480247, -14.38974762, -19.87360573,
        -16.85740471])





bob is: 110
tensor([[ 0,  7,  7,  5, 10,  7,  9, 10,  2,  9,  8],
        [ 2,  3,  8,  2,  1,  6,  5,  7, 10,  1,  9]])
tensor([-16.29795456,  -4.42839670, -16.01977348,  -7.47103977,  -6.53693247,
         -7.03399563,  -8.26774216, -12.41486359, -19.25117683, -11.44691086,
        -10.40319157])





bob is: 110
tensor([[ 9,  5,  2,  9,  8,  4,  2,  7,  1,  7, 10],
        [ 0,  9, 10, 10,  1,  3, 10,  3,  0,  8,  8]])
tensor([-12.02992916, -18.65868950, -19.55510330, -10.49174118,  -5.30645084,
         -6.78024530,  -7.66612530,  -9.97227001, -14.22922039, -15.83741760,
        -13.76848507])





bob is: 110
tensor([[ 1,  0,  5,  4,  9,  0,  1,  3,  8,  1,  3],
        [10,  4,  9,  9, 10, 10,  6, 10,  1,  9,  4]])
tensor([ -7.89188004, -10.40701675, -19.24053574, -11.18931770,  -5.96048832,
         -8.52951908, -14.19349098, -12.71349812,  -8.84794044, -19.65242386,
        -12.97301865])





bob is: 110
tensor([[ 7,  7,  2,  8,  8,  0,  2,  2,  3,  3,  1],
        [ 9,  8,  7,  4,  7, 10,  8,  0,  4, 10,  9]])
tensor([-15.11336517, -16.19939423, -16.53502846, -10.45444489, -10.15429306,
         -8.54043674, -20.53178024, -18.16492653,  -9.01167202,  -7.35635662,
        -14.69590473])





bob is: 110
tensor([[ 1,  2, 10,  0,  3,  2,  2,  6,  3,  9,  7],
        [ 8,  6,  4,  5,  6, 10,  0,  0,  8,  1,  2]])
tensor([-15.14312744, -19.86323547,  -9.75464344, -11.57705688, -14.14843082,
        -19.07562637, -20.66304398, -12.06444740, -14.62690163,  -9.18458271,
        -10.55954170])



100%|██████████| 31/31 [00:13<00:00,  2.31it/s]
100%|██████████| 1/1 [00:13<00:00, 13.44s/it]

bob is: 110
tensor([[ 5,  9, 10,  6,  7,  6,  1,  7,  2,  6, 10],
        [ 7, 10,  8,  7,  9,  0,  7, 10,  7,  8,  8]])
tensor([-20.30950546,  -8.11925697, -13.98352432, -10.59698582, -14.32473755,
        -11.60579395, -15.71107006, -10.25322914, -16.59001732, -11.71584034,
        -15.57777596])

3409



  ax.set_ylim(ylow, yhigh)


3409 [[10  8]
 [ 0  8]
 [10  6]
 [ 2  6]
 [ 5  0]] [[ 0  9]
 [ 0  6]
 [ 4 10]
 [ 5  0]
 [ 8  7]] [Data(x=[11, 5], edge_index=[2, 0]), Data(x=[11, 5], edge_index=[2, 2]), Data(x=[11, 5], edge_index=[2, 4]), Data(x=[11, 5], edge_index=[2, 6]), Data(x=[11, 5], edge_index=[2, 8])] [Data(x=[11, 5], edge_index=[2, 12]), Data(x=[11, 5], edge_index=[2, 14]), Data(x=[11, 5], edge_index=[2, 16]), Data(x=[11, 5], edge_index=[2, 18]), Data(x=[11, 5], edge_index=[2, 20])]


In [3]:
# Reward-update phase: make the reward function trainable and put it in train
# mode, disable gradients on the policy, then run the trainer's reward
# gradient step. These are the same four statements that close the first cell.
irl_trainer.reward_fn.requires_grad_(True)
irl_trainer.reward_fn.train()
irl_trainer.agent.policy.requires_grad_(False)
irl_trainer.do_reward_grad_step()

expert return: -114.7294921875
expert return: -116.73078155517578
expert return: -108.88362121582031
expert return: -126.7417221069336
expert return: -112.21012878417969
expert return: -117.59815979003906
expert return: -123.91873931884766
expert return: -108.78541564941406
expert return: -120.0748519897461
expert return: -130.11383056640625
expert_rewards shape:  torch.Size([10, 11])
sampled undiscounted return: -139.0196075439453
sampled undiscounted return: -149.67205810546875
sampled undiscounted return: -145.4356231689453
sampled undiscounted return: -141.41696166992188
sampled undiscounted return: -122.41036224365234
sampled undiscounted return: -135.726318359375
sampled undiscounted return: -137.29229736328125
sampled undiscounted return: -146.5854034423828
sampled undiscounted return: -133.09445190429688
sampled undiscounted return: -125.06513214111328
sampled undiscounted return: -147.83804321289062
sampled undiscounted return: -127.80394744873047
sampled undiscounted return: 

In [4]:
# Flatten all parameters of the reward function into a single 1-D tensor and
# display it as the cell output.
torch.nn.utils.parameters_to_vector(irl_trainer.reward_fn.parameters())

tensor([-1.13003738e-02,  1.24404430e-02, -1.24041177e-02, -1.72416717e-02,
        -1.87604986e-02,  1.32125989e-02,  1.66871380e-02,  1.68949608e-02,
         5.25503457e-01,  5.49489796e-01,  3.66495341e-01,  2.65446931e-01,
        -1.83788076e-01,  3.75868529e-01,  3.17790210e-01,  2.16946438e-01,
        -1.22462017e-02,  1.07374601e-01, -1.60819143e-02, -4.78162974e-01,
         1.71968445e-01, -5.08480251e-01,  3.20569426e-01, -4.21342403e-01,
        -2.06193551e-01, -3.05994183e-01,  5.46577871e-01, -4.95204031e-01,
         3.64286751e-01,  3.60076994e-01,  4.97049958e-01,  1.51331976e-01,
         3.51813316e-01,  5.52480936e-01,  1.22800104e-01,  2.19568774e-01,
        -5.11437356e-01, -1.65760592e-01,  5.66379614e-02,  6.58650845e-02,
        -2.56528825e-01, -6.80525661e-01, -1.64700508e-01,  2.76522189e-01,
         5.68592191e-01, -6.29678607e-01,  3.43454599e-01, -6.29876316e-01,
         1.71520747e-02,  1.89763866e-02, -1.84109267e-02,  1.78483278e-02,
         1.6

In [3]:
import numpy as np
import torch
# Select torch's "highest" float32 matmul precision setting.
torch.set_float32_matmul_precision("highest")

# Fixed-seed (1) sample of shape (1000, 20), uniform between -0.01 and 0.01,
# cast to float32.
rng = np.random.RandomState(1)
dataset1 = rng.uniform(low=-0.01, high=0.01, size=(1000, 20)).astype(np.float32)
# The same values wrapped as a torch tensor.
dataset2 = torch.from_numpy(dataset1)
# Print mean and standard deviation as computed by NumPy and by torch.
# NOTE(review): the two std() calls use different default estimators --
# ndarray.std() defaults to ddof=0 while Tensor.std() applies Bessel's
# correction by default -- so they are not expected to agree exactly. Pass
# matching arguments if an exact comparison is intended.
print(dataset1.mean(), dataset1.std())
print(dataset2.mean(), dataset2.std())

2.5537092e-05 0.005769608
tensor(2.5537e-05) tensor(0.0058)
