In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.utils.data as data 

from omegaconf import OmegaConf
from torchvision import transforms
from torch.nn.parallel import DistributedDataParallel as DDP
# 
from contrastive_learning.tests.test_model import load_lin_model
from contrastive_learning.models.custom_models import LinearInverse
from contrastive_learning.datasets.pli_dataset import Dataset, get_dataloaders

In [2]:
# Start a single-process process group so the saved model can be loaded.
# The loaded model prints as a DistributedDataParallel wrapper (see the model
# repr printed further down), and DDP needs an initialised default group.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29503"

# Guard the init: calling init_process_group a second time (e.g. when this
# cell is re-run in the same kernel) raises because the default group exists.
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
torch.cuda.set_device(0)

In [4]:
# Device, experiment paths and config for the run being evaluated.
# NOTE(review): PROJECT_ROOT is an absolute, machine-specific path — change it
# when running this notebook on another machine.
PROJECT_ROOT = '/home/irmak/Workspace/DAWGE/contrastive_learning'

device = torch.device('cuda', 0)
out_dir = f'{PROJECT_ROOT}/out/2022.07.15/14-42_pli'
config_path = f'{PROJECT_ROOT}/configs/pli_train.yaml'
cfg = OmegaConf.load(config_path)

# Restore the trained inverse model from the run's models/ directory.
model_path = os.path.join(out_dir, 'models/lin_model.pt')
lin_model = load_lin_model(cfg, device, model_path)

In [5]:
# Show the architecture of the loaded model (nn.Module's repr lists its layers).
print(repr(lin_model))

DistributedDataParallel(
  (module): LinearInverse(
    (model): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=2, bias=True)
    )
  )
)


In [6]:
# Build the data loaders. The evaluation cell uses `test_loader` for inputs and
# `dataset` to denormalise actions; `train_loader` is unpacked but not used here.
loaders = get_dataloaders(cfg)
train_loader, test_loader, dataset = loaders

In [7]:
# Run the inverse model over the whole test set, print actual vs. predicted
# actions side by side, and collect every denormalised prediction.
# Predictions are accumulated in a list rather than a pre-sized
# len(test_loader) * batch_size array: with a smaller final batch the old
# buffer kept trailing all-zero rows, which were then saved as "predictions".
lin_model.eval()
predicted_rows = []
with torch.no_grad():  # inference only — don't build autograd graphs
    for batch in test_loader:
        curr_pos, next_pos, action = [b.to(device) for b in batch]
        pred_action = lin_model(curr_pos, next_pos)

        print('Actual Action \t Predicted Action')
        for actual, predicted in zip(action, pred_action):
            # NOTE(review): `[0]` picks the first entry along dim 1; the printed
            # output shows it is a 2-element action vector — confirm layout.
            actual_denorm = dataset.denormalize_action(actual[0].cpu().numpy())
            pred_denorm = dataset.denormalize_action(predicted[0].cpu().numpy())
            print('{}, \t{}'.format(np.around(actual_denorm, 2), pred_denorm))
            predicted_rows.append(pred_denorm)

# One row per test sample, two action components per row (float64, like the
# original np.zeros buffer); an empty test set yields shape (0, 2).
all_predicted_actions = np.asarray(predicted_rows, dtype=float).reshape(-1, 2)

Actual Action 	 Predicted Action
[-0.   -0.23], 	[ 0.0197001  -0.19405486]
[-0.   -0.23], 	[ 0.05963099 -0.11945397]
[-0.   -0.23], 	[ 0.04347027 -0.14964623]
[0.15 0.05], 	[0.13332138 0.01821812]
[0.15 0.05], 	[0.14905155 0.04760601]
[0.15 0.05], 	[0.15013819 0.04963614]
[0.15 0.05], 	[-0.01637866 -0.26145901]
[0.15 0.05], 	[0.13610744 0.02342318]
[0.15 0.05], 	[0.13165336 0.01510185]
[-0.   -0.23], 	[-0.00751384 -0.24489732]
[0.15 0.05], 	[0.14642642 0.04270161]
[0.15 0.05], 	[0.15013819 0.04963614]
[0.15 0.05], 	[0.13858149 0.02804533]
[0.15 0.05], 	[0.14654584 0.04292473]
[0.15 0.05], 	[0.13479185 0.02096534]
[-0.   -0.23], 	[0.14702245 0.04381516]
[-0.   -0.23], 	[ 0.00623614 -0.2192089 ]
[0.15 0.05], 	[0.14687402 0.04353786]
[0.15 0.05], 	[0.12757606 0.00748444]
[0.15 0.05], 	[0.15013819 0.04963614]
[0.15 0.05], 	[0.15013819 0.04963614]
[-0.   -0.23], 	[ 0.05730327 -0.12380273]
[-0.   -0.23], 	[ 0.09520067 -0.05300089]
[0.15 0.05], 	[ 0.11248257 -0.02071399]
[-0.   -0.23], 	[ 0.0

[0.15 0.05], 	[0.15013819 0.04963614]
[-0.   -0.23], 	[ 0.04288381 -0.15074188]
[-0.   -0.23], 	[-0.05492865 -0.33348004]
[0.15 0.05], 	[0.15013819 0.04963614]
[0.15 0.05], 	[0.15013819 0.04963614]
[0.15 0.05], 	[0.14973961 0.04889149]
[0.15 0.05], 	[0.1482049  0.04602427]
[-0.   -0.23], 	[ 0.01302729 -0.20652134]
[0.15 0.05], 	[ 0.08468345 -0.07264969]
[-0.   -0.23], 	[-0.01161824 -0.25256535]
[0.15 0.05], 	[0.14492458 0.0398958 ]
[0.15 0.05], 	[0.15013819 0.04963614]
[0.15 0.05], 	[ 0.04528861 -0.14624911]
[0.15 0.05], 	[0.15013819 0.04963614]
[-0.   -0.23], 	[-0.01060347 -0.25066951]
[-0.   -0.23], 	[-0.03091268 -0.2886122 ]
[0.15 0.05], 	[0.15742668 0.06325286]
[0.15 0.05], 	[ 0.09178519 -0.05938187]
[0.15 0.05], 	[0.15013819 0.04963614]
[0.15 0.05], 	[0.13364479 0.01882233]
[0.15 0.05], 	[0.13490308 0.02117313]
[0.15 0.05], 	[0.15013819 0.04963614]
[0.15 0.05], 	[0.15013819 0.04963614]
[0.15 0.05], 	[0.14954376 0.0485256 ]
[-0.   -0.23], 	[ 0.00094047 -0.22910254]
[0.15 0.05], 	[0

In [8]:
# Persist the denormalised test-set predictions into the experiment's out_dir.
# np.save opens/closes the file itself and keeps the existing `.npy` suffix.
predictions_path = os.path.join(out_dir, 'predicted_actions.npy')
np.save(predictions_path, all_predicted_actions)