In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.utils.data as data 

from omegaconf import OmegaConf
from torchvision import transforms
from torch.nn.parallel import DistributedDataParallel as DDP
# 
from contrastive_learning.tests.test_model import load_lin_model
from contrastive_learning.models.custom_models import LinearInverse
from contrastive_learning.datasets.pli_dataset import Dataset, get_dataloaders

In [2]:
# Initialise a single-process (world_size=1) default process group.
# load_lin_model returns a DistributedDataParallel-wrapped model (see the
# In [4] output), and DDP needs an initialised process group to exist.
# Existing environment values take precedence, so a busy port can be changed
# without editing the notebook.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29503")

# Guard so re-running this cell does not fail: initialising the default
# process group a second time raises an error.
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
torch.cuda.set_device(0)

In [3]:
# Set the device and locate the trained run.
# PROJECT_ROOT defaults to the original author's checkout; set the DAWGE_ROOT
# environment variable to point at another checkout instead of editing paths.
PROJECT_ROOT = os.environ.get('DAWGE_ROOT', '/home/irmak/Workspace/DAWGE')
RUN_DIR = os.path.join('contrastive_learning', 'out', '2022.07.15', '11-36_pli')
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'contrastive_learning', 'configs', 'pli_train.yaml')

device = torch.device('cuda:0')  # same GPU index as torch.cuda.set_device(0) in the setup cell
out_dir = os.path.join(PROJECT_ROOT, RUN_DIR)
cfg = OmegaConf.load(CONFIG_PATH)
model_path = os.path.join(out_dir, 'models', 'lin_model.pt')

# Fail early with a clear message rather than somewhere inside load_lin_model.
if not os.path.isfile(model_path):
    raise FileNotFoundError(f'No trained linear model found at {model_path}')

# Load the trained LinearInverse model (returned wrapped in DDP)
lin_model = load_lin_model(cfg, device, model_path)

In [4]:
# Inspect the loaded architecture. A bare last expression uses Jupyter's rich
# display (shown as the cell's Out[]) instead of print().
lin_model

DistributedDataParallel(
  (module): LinearInverse(
    (model): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=2, bias=True)
    )
  )
)


In [5]:
# Build the dataloaders from the same config object loaded above.
(
    train_loader,  # not used in the comparison below
    test_loader,   # source of the batch compared in the next cell
    dataset,       # provides denormalize_action for readable action values
) = get_dataloaders(cfg)

In [14]:
# Compare ground-truth actions with the model's predictions on one test batch.
# This is inference only: eval() + no_grad() avoid building an autograd graph
# (the printed architecture has no dropout/batch-norm, but eval() keeps this
# correct if that ever changes).
lin_model.eval()
batch = next(iter(test_loader))
curr_pos, next_pos, action = [b.to(device) for b in batch]
with torch.no_grad():
    pred_action = lin_model(curr_pos, next_pos)

# Move everything to host memory once rather than once per row.
actions_np = action.detach().cpu().numpy()
preds_np = pred_action.detach().cpu().numpy()

print('Actual Action \t Predicted Action')
for actual, predicted in zip(actions_np, preds_np):
    # [0] takes the first entry along dim 1 — assumes (batch, 1, action_dim); TODO confirm
    print('{}, \t{}'.format(np.around(dataset.denormalize_action(actual[0]), 2),
                            dataset.denormalize_action(predicted[0])))

Actual Action 	 Predicted Action
[-0.   -0.23], 	[-0.05872796 -0.34057811]
[-0.   -0.23], 	[ 0.11609759 -0.01396022]
[-0.   -0.23], 	[ 0.01402901 -0.20464988]
[0.15 0.05], 	[0.14767949 0.04504267]
[0.15 0.05], 	[0.15712533 0.06268985]
[0.15 0.05], 	[0.15374675 0.05637783]
[0.15 0.05], 	[-0.02441072 -0.27646492]
[0.15 0.05], 	[ 0.08207919 -0.0775151 ]
[0.15 0.05], 	[0.12573274 0.00404065]
[-0.   -0.23], 	[-0.01717921 -0.26295464]
[0.15 0.05], 	[0.14708777 0.04393718]
[0.15 0.05], 	[0.1514471  0.05208152]
[0.15 0.05], 	[0.1402964  0.03124922]
[0.15 0.05], 	[0.14984275 0.04908419]
[0.15 0.05], 	[0.14277234 0.03587489]
[-0.   -0.23], 	[ 0.10006954 -0.04390463]
[-0.   -0.23], 	[-0.00229181 -0.23514125]
[0.15 0.05], 	[0.14354785 0.03732373]
[0.15 0.05], 	[0.13249231 0.01666921]
[0.15 0.05], 	[0.15323466 0.05542113]
[0.15 0.05], 	[0.14871982 0.04698627]
[-0.   -0.23], 	[0.13835142 0.02761549]
[-0.   -0.23], 	[ 0.09222337 -0.05856324]
[0.15 0.05], 	[ 0.11881468 -0.00888402]
[-0.   -0.23], 	[ 0