In [1]:
import torch

In [2]:
from data_loader import load_datasets

In [26]:
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class CustomTrajDataset(Dataset):
    def __init__(self, traj_df, mode="append"):
        positions = torch.from_numpy(np.array(list(traj_df['position']))).type(torch.FloatTensor)
        orientations = torch.from_numpy(np.array(list(traj_df['orientation']))).type(torch.FloatTensor)
        forces = torch.from_numpy(np.array(list(traj_df['net_force']))).type(torch.FloatTensor)
        torques = torch.from_numpy(np.array(list(traj_df['net_torque']))).type(torch.FloatTensor)

        if mode == "append":
            self.input = torch.cat((positions, orientations), 2)
        else:
            orientations  = orientations[:, :, :4]
            self.input = torch.stack((positions, orientations), dim=2)

        self.in_dim = self.input.shape[-1]
        self.forces = forces
        self.torques = torques

    def __len__(self):
        return len(self.input)

    def __getitem__(self, i):
        return self.input[i], self.forces[i], self.torques[i]


def _get_data_loader(dataset, batch_size, shuffle=True):
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=5)
    return dataloader


In [27]:
data_path = "/home/marjanalbooyeh/logs/datasets/pps_two_synthesized/neighbors/"
val_df = pd.read_pickle(os.path.join(data_path, 'val.pkl'))

In [28]:
valid_dataset = CustomTrajDataset(val_df, mode="stack")

In [29]:
val_dataloader = _get_data_loader(valid_dataset, 16)



In [30]:
for (inp, force, torque) in val_dataloader:
    print(inp.shape)
    print(force.shape)
    print(torque.shape)
    break



torch.Size([16, 2, 2, 4])
torch.Size([16, 2, 3])
torch.Size([16, 2, 3])


In [31]:
import torch.nn as nn

In [47]:
m= nn.Conv2d(in_channels=2, out_channels=2, kernel_size=(2, 4))

In [48]:
out1 = m(inp)

In [49]:
out1.shape

torch.Size([16, 2, 1, 1])

In [53]:
l = nn.Linear(4,3)

In [56]:
out2 = l(inp)
out2

tensor([[[[ 2.1777e-01, -1.2484e+00, -2.6752e-01],
          [-8.6078e-01, -1.5895e-01,  5.1757e-01]],

         [[ 6.2031e-01, -1.2318e+00, -2.3583e-01],
          [ 4.4008e-01, -7.5988e-01, -6.8096e-01]]],


        [[[ 4.3609e-01, -2.9435e+00, -1.5278e+00],
          [-1.1603e+00,  5.9145e-01,  6.6783e-01]],

         [[ 7.9951e-01, -2.9683e-02,  9.1711e-01],
          [ 7.3960e-01, -1.5103e+00, -8.3122e-01]]],


        [[[ 3.5815e-01, -2.5313e+00, -1.2898e+00],
          [-2.6899e-01,  4.4421e-01,  2.9368e-02]],

         [[ 5.9351e-01, -8.9732e-02,  7.5578e-01],
          [-1.5171e-01, -1.3630e+00, -1.9276e-01]]],


        [[[ 3.3401e-01, -1.7279e+00, -5.1905e-01],
          [-1.2832e+00, -3.5133e-01,  1.1802e-01]],

         [[ 9.0158e-01, -1.2454e+00, -9.1655e-02],
          [ 8.6251e-01, -5.6750e-01, -2.8140e-01]]],


        [[[ 3.5268e-01, -1.7713e+00, -5.3404e-01],
          [ 4.9641e-02, -6.5316e-01, -1.8936e-01]],

         [[ 9.3971e-01, -1.2723e+00, -9.1997e-02],
     

In [58]:
out2.mean(dim=-2).shape

torch.Size([16, 2, 3])