In [None]:
import relg

import time
import numpy as np

In [None]:
def generate_dataset(sim, num_sims, length, sample_freq):
    """Sample ``num_sims`` trajectories from ``sim`` and stack them.

    Args:
        sim: object exposing ``sample_trajectory(T=..., sample_freq=...)`` that
            returns a ``(loc, vel, edges)`` triple.
        num_sims: number of trajectories to sample.
        length: forwarded to ``sample_trajectory`` as ``T``.
        sample_freq: forwarded to ``sample_trajectory`` as ``sample_freq``.

    Returns:
        ``(loc_all, vel_all, edges_all)``: the per-simulation outputs, each
        stacked with ``np.stack`` along a new leading axis.
    """
    locs, vels, edge_sets = [], [], []
    for sim_idx in range(num_sims):
        start = time.time()
        loc, vel, edges = sim.sample_trajectory(T=length, sample_freq=sample_freq)
        # Report timing only for every 100th simulation to keep output short.
        if not sim_idx % 100:
            elapsed = time.time() - start
            print("Iter: {}, Simulation time: {}".format(sim_idx, elapsed))
        locs.append(loc)
        vels.append(vel)
        edge_sets.append(edges)

    return np.stack(locs), np.stack(vels), np.stack(edge_sets)

In [None]:
simulation = "springs"
nobjects = 5

# Factory per supported simulation name. Lambdas defer the attribute lookup
# on `relg`, so only the selected simulator class is ever touched.
_sim_factories = {
    "springs": lambda: relg.SpringSim(noise_var=0.0, n_balls=nobjects),
    "charged": lambda: relg.ChargedParticlesSim(noise_var=0.0, n_balls=nobjects),
}
if simulation not in _sim_factories:
    raise ValueError(
        "Simulation {} not implemented".format(simulation),
    )
sim = _sim_factories[simulation]()

suffix = f"_{simulation}_{nobjects}"

In [None]:
train_sample_size = 10

trajectory_length = 5000
sample_freq = 100

print("Generating {} training simulations".format(train_sample_size))
loc_train, vel_train, edges_train = generate_dataset(
    sim, train_sample_size, trajectory_length, sample_freq
)

# Permute axes [0, 3, 1, 2] (axis 3 moves to position 1) for the location and
# velocity arrays; edges_train keeps its original layout.
loc_train, vel_train = (
    np.transpose(arr, [0, 3, 1, 2]) for arr in (loc_train, vel_train)
)

In [None]:
# loc_train was already transposed with [0, 3, 1, 2] in the previous cell;
# transposing it again here would show the shape of a double-permuted array,
# not of the data actually stored in loc_train.
loc_train.shape

In [None]:
# The reshape below needs edges_train's trailing axes to be
# (num_atoms, num_atoms), so read the object count from there. loc_train.shape[3]
# is wrong at this point: after the [0, 3, 1, 2] transpose, axis 3 of loc_train
# holds what was axis 2 before the transpose.
num_atoms = edges_train.shape[-1]
assert edges_train.shape[-2] == num_atoms, "expected square edge matrices"

# Flatten each simulation's edge matrix into one row of num_atoms**2 entries,
# then map values via (x + 1) / 2 truncated to int64 (-1 -> 0, 0 -> 0, 1 -> 1).
_edges_train = np.reshape(edges_train, [-1, num_atoms ** 2])
_edges_train = np.array((_edges_train + 1) / 2, dtype=np.int64)

In [None]:
# Row-major flat indices of every off-diagonal (i != j) cell of a
# num_atoms x num_atoms matrix.
off_diag_idx = np.flatnonzero(~np.eye(num_atoms, dtype=bool))

In [None]:
# Display the flat off-diagonal indices computed in the previous cell.
off_diag_idx

In [None]:
# Keep only the columns for i != j pairs, i.e. drop self-edges.
np.take(_edges_train, off_diag_idx, axis=1)

In [None]:
import torch 
from torch.utils.data import Dataset
import os

class SmallSynthData(Dataset):
    """Dataset over pre-saved feature and edge tensors for one data split.

    Loads ``<mode>_feats`` and ``<mode>_edges`` from ``data_path`` with
    ``torch.load``. Unless ``params['no_data_norm']`` is truthy, features are
    rescaled to [-1, 1] with min/max statistics taken from the *train* split.

    Args:
        data_path: directory containing ``train_feats``/``train_edges``,
            ``val_feats``/``val_edges`` and ``test_feats``/``test_edges``.
        mode: one of ``'train'``, ``'val'``, ``'test'``.
        params: dict with keys ``'same_data_norm'`` (use one min/max over all
            features) and ``'no_data_norm'`` (skip normalization entirely).

    Raises:
        ValueError: if ``mode`` is not one of the known splits.
    """

    _MODES = ('train', 'val', 'test')

    def __init__(self, data_path, mode, params):
        # Validate up front: an unknown mode previously left the file path
        # unbound and surfaced as an UnboundLocalError.
        if mode not in self._MODES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._MODES, mode)
            )
        self.mode = mode
        self.data_path = data_path
        # NOTE(review): torch.load unpickles data; only load trusted files.
        self.feats = torch.load(os.path.join(data_path, mode + '_feats'))
        self.edges = torch.load(os.path.join(data_path, mode + '_edges'))
        self.same_norm = params['same_data_norm']
        self.no_norm = params['no_data_norm']
        if not self.no_norm:
            self._normalize_data()

    def _normalize_data(self):
        """Rescale ``self.feats`` to [-1, 1] using train-split min/max."""
        # In train mode the train features are already in memory. All
        # statistics below are read before self.feats is modified, so reusing
        # the tensor avoids loading the same file a second time.
        if self.mode == 'train':
            train_data = self.feats
        else:
            train_data = torch.load(os.path.join(self.data_path, 'train_feats'))
        if self.same_norm:
            self.feat_max = train_data.max()
            self.feat_min = train_data.min()
            self.feats = (self.feats - self.feat_min)*2/(self.feat_max-self.feat_min) - 1
        else:
            # Last axis: channels [:2] and [2:] get separate min/max ranges
            # (named loc / vel). Assumes 4-D features, as the indexing implies.
            self.loc_max = train_data[:, :, :, :2].max()
            self.loc_min = train_data[:, :, :, :2].min()
            self.vel_max = train_data[:, :, :, 2:].max()
            self.vel_min = train_data[:, :, :, 2:].min()
            self.feats[:,:,:, :2] = (self.feats[:,:,:,:2]-self.loc_min)*2/(self.loc_max - self.loc_min) - 1
            self.feats[:,:,:,2:] = (self.feats[:,:,:,2:]-self.vel_min)*2/(self.vel_max-self.vel_min)-1

    def unnormalize(self, data):
        """Invert the scaling applied by ``_normalize_data``."""
        if self.no_norm:
            return data
        elif self.same_norm:
            return (data + 1) * (self.feat_max - self.feat_min) / 2. + self.feat_min
        else:
            result1 = (data[:, :, :, :2] + 1) * (self.loc_max - self.loc_min) / 2. + self.loc_min
            result2 = (data[:, :, :, 2:] + 1) * (self.vel_max - self.vel_min) / 2. + self.vel_min
            # NOTE(review): np.concatenate returns a NumPy array here even for
            # tensor input, unlike the other branches. Confirm that is intended.
            return np.concatenate([result1, result2], axis=-1)

    def __getitem__(self, idx):
        return {'inputs': self.feats[idx], 'edges': self.edges[idx]}

    def __len__(self):
        return len(self.feats)

In [None]:
# NOTE(review): this instantiates relg.SmallSynthData, not the SmallSynthData
# class defined in the previous cell. Confirm which one is intended.
dataset_params = {"same_data_norm": False, "no_data_norm": None}
dataset = relg.SmallSynthData("datasets", 'train', dataset_params)

In [None]:
# Pull the first sample through the iterator protocol and show its features.
first_sample = next(iter(dataset))
first_sample["inputs"]

In [None]:
# Fully connected 3-node graph without self-loops; np.where yields the
# (row, col) coordinates of every nonzero entry, which are unpacked as
# sender and receiver indices.
edges = 1.0 - np.eye(3)
send_edges, recv_edges = np.where(edges)

In [None]:
def encode_onehot(labels):
    """One-hot encode a sequence of hashable labels.

    Classes are the distinct values of ``labels`` in sorted order (falling
    back to first-appearance order if they cannot be compared), so column
    ``k`` corresponds to the ``k``-th smallest label.

    Args:
        labels: iterable of hashable labels (e.g. node indices).

    Returns:
        ``np.int32`` array of shape ``(len(labels), n_classes)`` with exactly
        one 1 per row.
    """
    # Materialize once: a one-shot iterator would otherwise be exhausted by
    # set() before the rows are built.
    labels = list(labels)
    try:
        # set() iteration order is not guaranteed, so sort explicitly to make
        # the column order deterministic.
        classes = sorted(set(labels))
    except TypeError:
        classes = list(dict.fromkeys(labels))
    class_to_col = {c: i for i, c in enumerate(classes)}
    cols = np.array([class_to_col[c] for c in labels], dtype=np.intp)
    # Build the identity once and select one row per label.
    return np.identity(len(classes), dtype=np.int32)[cols]


In [None]:
# One-hot rows per receiver index: one row per edge, one column per class.
recv_onehot = encode_onehot(recv_edges)
recv_onehot

In [None]:
from torch import nn

In [None]:
# Transposed receiver one-hot matrix wrapped as a frozen parameter
# (requires_grad=False).
recv_onehot_t = torch.FloatTensor(encode_onehot(recv_edges).T)
nn.Parameter(recv_onehot_t, requires_grad=False)

In [None]:
z = torch.randn((8, 3, 10, 2))

In [None]:
# Swapping axes 1 and 2 is symmetric; contiguous() materializes the new layout.
z.transpose(1, 2).contiguous().shape

In [None]:
from torch import nn

In [None]:
l = nn.LSTM(input_size=64, hidden_size=64, batch_first=True)

In [None]:
z = torch.randn((8, 10, 64))

In [None]:
# The module call returns a pair; per the PyTorch nn.LSTM docs it is
# (output, (h_n, c_n)).
lstm_result = l(z)
o1, o2 = lstm_result

In [None]:
o1.size()

In [1]:
import relg

In [2]:
# NOTE(review): the variable is named `encoder` but holds a relg.models.Decoder.
decoder_kwargs = dict(input_dim=4, hidden_dim=64, num_objects=3, num_edge_types=3)
encoder = relg.models.Decoder(**decoder_kwargs)

In [11]:
import torch
inputs = torch.randn(8, 3, 4)
hiddens = torch.randn(8, 3, 64)
# 6 presumably = number of directed edges among 3 objects (TODO confirm).
edgesf = torch.randn(8, 6, 3)
step_out = encoder(inputs, hiddens, edgesf)
step_out[1].shape

torch.Size([8, 3, 64])

In [4]:
import numpy as np
from torch import nn


def encode_onehot(labels):
    """One-hot encode a sequence of hashable labels.

    Classes are the distinct values of ``labels`` in sorted order (falling
    back to first-appearance order if they cannot be compared), so column
    ``k`` corresponds to the ``k``-th smallest label.

    Args:
        labels: iterable of hashable labels (e.g. node indices).

    Returns:
        ``np.int32`` array of shape ``(len(labels), n_classes)`` with exactly
        one 1 per row.
    """
    # Materialize once: a one-shot iterator would otherwise be exhausted by
    # set() before the rows are built.
    labels = list(labels)
    try:
        # set() iteration order is not guaranteed, so sort explicitly to make
        # the column order deterministic.
        classes = sorted(set(labels))
    except TypeError:
        classes = list(dict.fromkeys(labels))
    class_to_col = {c: i for i, c in enumerate(classes)}
    cols = np.array([class_to_col[c] for c in labels], dtype=np.intp)
    # Build the identity once and select one row per label.
    return np.identity(len(classes), dtype=np.int32)[cols]

# Fully connected 3-node graph without self-loops.
edges = 1.0 - np.eye(3)
send_edges, recv_edges = np.where(edges)

# Transposed receiver one-hot encoding, frozen as a non-trainable parameter;
# rows index nodes and columns index edges.
edge2node_mat = nn.Parameter(
    torch.FloatTensor(encode_onehot(recv_edges).T),
    requires_grad=False,
)

In [5]:
edge2node_mat.size()

torch.Size([3, 6])

In [6]:
l = edge2node_mat.transpose(-2, -1)
p = torch.randn(8, 6, 64)
# (p^T @ l)^T maps the 6 per-edge rows of p onto 3 node rows; the result matches
# a direct edge2node_mat @ p product.
(p.transpose(-2, -1) @ l).transpose(-2, -1)

tensor([[[-0.7651,  0.2843, -0.0388,  ..., -0.2320,  1.4379,  0.7490],
         [-0.4536,  0.6781, -0.9703,  ..., -0.8194,  3.3747,  0.1665],
         [-0.5535, -0.6488, -0.7841,  ..., -1.7051,  2.0720, -0.8003]],

        [[ 0.6678,  1.1954,  0.6067,  ..., -1.5256, -0.0294, -0.1141],
         [-1.7200,  0.2290,  1.5642,  ..., -3.3195,  0.8546, -0.5150],
         [ 0.1222,  2.2425, -0.2941,  ...,  3.7068, -0.0525, -1.9594]],

        [[ 1.6954, -3.8929, -1.6319,  ...,  1.3852, -1.7527, -0.2604],
         [-0.6626, -1.4667, -0.9470,  ..., -3.0094,  1.5733, -1.3428],
         [ 2.0411, -0.4008,  2.2666,  ...,  2.5074,  0.8329, -0.2590]],

        ...,

        [[-0.0711,  0.6353,  2.1946,  ..., -3.3038,  1.0147, -2.0492],
         [ 0.6742, -0.1992,  1.9056,  ...,  0.2868, -1.2637, -0.3676],
         [ 0.6002,  0.6011, -0.9073,  ..., -1.6150, -4.6424,  2.3048]],

        [[-1.1163, -2.0986,  1.3911,  ..., -0.5495,  0.3949, -1.9372],
         [ 0.6155,  3.5560,  0.5999,  ...,  1.4485,  0.

In [7]:
# Same product as the previous cell, written as a single batched matmul.
edge2node_mat @ p

tensor([[[-0.7651,  0.2843, -0.0388,  ..., -0.2320,  1.4379,  0.7490],
         [-0.4536,  0.6781, -0.9703,  ..., -0.8194,  3.3747,  0.1665],
         [-0.5535, -0.6488, -0.7841,  ..., -1.7051,  2.0720, -0.8003]],

        [[ 0.6678,  1.1954,  0.6067,  ..., -1.5256, -0.0294, -0.1141],
         [-1.7200,  0.2290,  1.5642,  ..., -3.3195,  0.8546, -0.5150],
         [ 0.1222,  2.2425, -0.2941,  ...,  3.7068, -0.0525, -1.9594]],

        [[ 1.6954, -3.8929, -1.6319,  ...,  1.3852, -1.7527, -0.2604],
         [-0.6626, -1.4667, -0.9470,  ..., -3.0094,  1.5733, -1.3428],
         [ 2.0411, -0.4008,  2.2666,  ...,  2.5074,  0.8329, -0.2590]],

        ...,

        [[-0.0711,  0.6353,  2.1946,  ..., -3.3038,  1.0147, -2.0492],
         [ 0.6742, -0.1992,  1.9056,  ...,  0.2868, -1.2637, -0.3676],
         [ 0.6002,  0.6011, -0.9073,  ..., -1.6150, -4.6424,  2.3048]],

        [[-1.1163, -2.0986,  1.3911,  ..., -0.5495,  0.3949, -1.9372],
         [ 0.6155,  3.5560,  0.5999,  ...,  1.4485,  0.

In [14]:
# NOTE(review): rebinds `edges` (previously a NumPy adjacency) to a random tensor.
edges = torch.randn((8, 10, 3, 3))

In [8]:
edge2node_mat.size()

torch.Size([3, 6])

In [17]:
# Non-uniform edge-type prior: edge type 0 (presumably "no edge") gets
# `no_edge_prob`, and the remaining mass is split evenly over the other types.
num_edge_types = 3
no_edge_prob = 0.8
prior = np.full(num_edge_types, (1 - no_edge_prob) / (num_edge_types - 1))
prior[0] = no_edge_prob

In [20]:
# Uniform prior over 3 edge types, converted to a float32 log-probability tensor.
prior = np.full(3, 1.0 / 3)
log_prior = torch.FloatTensor(np.log(prior))
log_prior

tensor([-1.0986, -1.0986, -1.0986])