In [1]:
# Script to check each representation with the given encoders
# Will receive:
# a list of encoders to try for each task
# a list of experts to try the encoders on
import os
import hydra
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.utils import save_image
from torchvision import transforms as T
from PIL import Image
# from agent.encoder import Encoder

from tactile_learning.models import *
from tactile_learning.utils import *
from tactile_learning.tactile_data import *

In [3]:
DEVICE = 'cuda:1'
TACTILE_OUT_DIR = '/home/irmak/Workspace/tactile-learning/tactile_learning/out/2023.01.28/12-32_tactile_byol_bs_512_tactile_play_data_alexnet_pretrained_duration_120'

In [4]:
def get_tactile_repr_module(device):
    tactile_cfg, tactile_encoder, _ = init_encoder_info(
        device = device,
        out_dir = TACTILE_OUT_DIR,
        encoder_type = 'tactile',
        model_type='byol'
    )
    tactile_img = TactileImage(
        tactile_image_size = tactile_cfg.tactile_image_size, 
        shuffle_type = None
    )
    tactile_repr_module = TactileRepresentation( # This will be used when calculating the reward - not getting the observations
        encoder_out_dim = tactile_cfg.encoder.out_dim,
        tactile_encoder = tactile_encoder,
        tactile_image = tactile_img,
        representation_type = 'tdex',
        device = device
    )

    return tactile_repr_module


In [19]:
def get_expert_representations_per_encoder(encoder, task_expert_demos, device):
    # Traverse through all the experts and get the representations
    task_representations = []
    for expert_id in range(len(task_expert_demos)):
        expert_representations = encoder(task_expert_demos[expert_id]['image_obs'].to(device)) # One trajectory representations
        task_representations.append(expert_representations)
    
    return task_representations


In [20]:
def calc_traj_score(traj1, traj2):
    # traj1.shape: (80, 512), traj2.shape: (80,512)
    cost_matrix = cosine_distance(
            traj1, traj2)  # Get cost matrix for samples using critic network.
    transport_plan = optimal_transport_plan(
        traj1, traj2, cost_matrix, method='sinkhorn',
        niter=100, exponential_weight_init=False).float().detach().cpu().numpy()

    max_transport_plan = np.max(transport_plan, axis=1) # We are going to find the maximums for traj1
    print('max_transport_plan.shape: {}, traj1.shape: {}, traj2.shape: {}'.format(
        max_transport_plan.shape, traj1.shape, traj2.shape
    ))
    return np.sum(max_transport_plan)

In [21]:
def calc_encoder_score(encoder, all_expert_demos, encoder_id, device): # Will get all the representations and calculate the score of the 

    all_expert_representations = get_expert_representations_per_encoder(
        encoder = encoder,
        task_expert_demos = all_expert_demos,
        device = device
    )

    # Get combinations of the trajectories and calculate the score for them
    score_matrix = np.zeros((5,5))
    for i in range(score_matrix.shape[0]):
        for j in range(score_matrix.shape[1]):
            traj1 = all_expert_representations[i] 
            traj2 = all_expert_representations[j] 
            score_matrix[i,j] = calc_traj_score(traj1, traj2)

    print('SCORE MATRIX FOR ENCODER: {} \n{}\n-----'.format(
        encoder_id, 
        score_matrix
    ))

    return score_matrix


In [None]:
# # Method to load all the tasks
# def load_all_expert_demos(task_names, expert_demo_nums, view_nums, device):
#     # Create the tactile repr module - this is the same for all the tasks
#     tactile_repr_module = get_tactile_repr_module(device)
#     all_experts = []

#     for i, task_name in enumerate(task_names):
#         root_path = f'/home/irmak/Workspace/Holo-Bot/extracted_data/{task_name}'
#         task_expert_demo_nums = expert_demo_nums[i]
        
#         view_num = view_nums[i]
#         def viewed_crop_transform(image):
#             return crop_transform(image, camera_view=view_num)
#         image_transform =  T.Compose([
#             T.Resize((480,640)),
#             T.Lambda(viewed_crop_transform),
#             T.Resize(480),
#             T.ToTensor(),
#             T.Normalize(VISION_IMAGE_MEANS, VISION_IMAGE_STDS), 
#         ])
#         task_expert_demos = load_expert_demos_per_task(
#             data_path = root_path,
#             expert_demo_nums = task_expert_demo_nums,
#             tactile_repr_module = tactile_repr_module,
#             image_transform = image_transform,
#             view_num = view_num
#         )

#         all_experts.append(
#             task_expert_demos
#         )

#     return all_experts

In [15]:


# This image transform will have everything
def load_expert_demos_per_task(task_name, expert_demo_nums, view_num, device):
    data_path = f'/home/irmak/Workspace/Holo-Bot/extracted_data/{task_name}'
    roots = sorted(glob.glob(f'{data_path}/demonstration_*'))
    data = load_data(roots, demos_to_use=expert_demo_nums) # NOTE: This could be fucked up

    # Get the tactile module and the image transform
    def viewed_crop_transform(image):
        return crop_transform(image, camera_view=view_num)
    image_transform =  T.Compose([
        T.Resize((480,640)),
        T.Lambda(viewed_crop_transform),
        T.Resize(480),
        T.ToTensor(),
        T.Normalize(VISION_IMAGE_MEANS, VISION_IMAGE_STDS), 
    ])
    
    expert_demos = []
    image_obs = [] 
    old_demo_id = -1
    for step_id in range(len(data['image']['indices'])): 
        demo_id, image_id = data['image']['indices'][step_id]
        if (demo_id != old_demo_id and step_id > 0) or (step_id == len(data['image']['indices'])-1): # NOTE: We are losing the last frame of the last expert

            expert_demos.append(dict(
                image_obs = torch.stack(image_obs, 0), 
            ))
            image_obs = [] 

        image = load_dataset_image(
            data_path = data_path, 
            demo_id = demo_id, 
            image_id = image_id,
            view_num = view_num,
            transform = image_transform
        )
        image_obs.append(image)
        # tactile_reprs.append(tactile_repr)


        old_demo_id = demo_id

    return expert_demos

In [10]:
# # Load all the expert demos
# ALL_EXPERT_DEMOS = load_all_expert_demos(
#     task_names = [
#         # 'plier_picking',
#         'bowl_picking',
#         # 'card_flipping',
#         # 'card_turning',
#         # 'peg_insertion'
#     ],
#     expert_demo_nums=[
#         # [3,10,15,16,20,25],
#         [],
#         # [24,26,27,31,32,33],
#         # [],
#         # [] 
#     ],
#     view_nums = [
#         # 0, 1, 0, 0, 0
#         1
#     ]
# )

Using cache found in /home/irmak/.cache/torch/hub/pytorch_vision_v0.10.0


mod_name: collections, name: OrderedDict
mod_name: torch._utils, name: _rebuild_parameter
mod_name: torch._utils, name: _rebuild_tensor_v2


In [16]:
bowl_unstacking_info = dict(
    encoders = [
        dict(
            model_path = '/home/irmak/Workspace/tactile-learning/tactile_learning/out/2023.05.11/13-21_bc_bs_32_epochs_500_lr_1e-05_bowl_picking_after_rss',
            model_type = 'bc',
            view_num = 1,
            encoder_fn = None,
            device = 0
        ),
        dict(
            model_path = '/home/irmak/Workspace/tactile-learning/tactile_learning/out/2023.06.06/18-27_temporal_ssl_bs_32_epochs_1000_view_1_bowl_picking_frame_diff_5_resnet',
            model_type = 'temporal',
            view_num = 1,
            encoder_fn = None,
            device = 1,
        ),
        dict(
            model_path = '/home/irmak/Workspace/tactile-learning/tactile_learning/out/2023.05.06/10-50_image_byol_bs_32_epochs_500_lr_1e-05_bowl_picking_after_rss',
            model_type = 'byol',
            view_num = 1,
            encoder_fn = None,
            device = 2, 
        ),
        dict(
            model_path = None,
            model_type = 'pretrained',
            encoder_fn = resnet18,
            view_num = 1,
            device = 3
        )
    ],
    demos = dict(
        task_name = 'bowl_picking',
        expert_demo_nums = [],
        view_num = 1  
    )
)

In [14]:
def load_encoder(view_num, model_type, model_path, encoder_fn, device):
    # print(kwargs)
    # view_num = kwargs['view_num']
    # model_type = kwargs['model_type']
    # model_path=kwargs['model_path']
    # encoder_fn=kwargs['encoder_fn']
    # devive
    
    # device = torch.device(DEVICE)
    if model_type == 'pretrained' and not (encoder_fn is None):
        # It means that this is pretrained
        image_encoder = encoder_fn(pretrained=True, out_dim=512, remove_last_layer=True).to(device)

    else:
        _, image_encoder, _ = init_encoder_info(
            device = device,
            out_dir = model_path,
            encoder_type = 'image',
            view_num = view_num,
            model_type = model_type
        )

    return image_encoder

In [11]:
bowl_unstacking_encoders = [
    load_encoder(**encoder_args) for encoder_args in bowl_unstacking_info['encoders']
]



mod_name: collections, name: OrderedDict
mod_name: torch._utils, name: _rebuild_parameter
mod_name: torch._utils, name: _rebuild_tensor_v2
mod_name: collections, name: OrderedDict
mod_name: torch._utils, name: _rebuild_parameter
mod_name: torch._utils, name: _rebuild_tensor_v2
mod_name: collections, name: OrderedDict
mod_name: torch._utils, name: _rebuild_parameter
mod_name: torch._utils, name: _rebuild_tensor_v2


In [18]:
bowl_unstacking_demos = [load_expert_demos_per_task(**bowl_unstacking_info['demos'], device=i) for i in range(4)]

Using cache found in /home/irmak/.cache/torch/hub/pytorch_vision_v0.10.0


mod_name: collections, name: OrderedDict
mod_name: torch._utils, name: _rebuild_parameter
mod_name: torch._utils, name: _rebuild_tensor_v2


Using cache found in /home/irmak/.cache/torch/hub/pytorch_vision_v0.10.0


mod_name: collections, name: OrderedDict
mod_name: torch._utils, name: _rebuild_parameter
mod_name: torch._utils, name: _rebuild_tensor_v2


Using cache found in /home/irmak/.cache/torch/hub/pytorch_vision_v0.10.0


mod_name: collections, name: OrderedDict
mod_name: torch._utils, name: _rebuild_parameter
mod_name: torch._utils, name: _rebuild_tensor_v2


Using cache found in /home/irmak/.cache/torch/hub/pytorch_vision_v0.10.0


mod_name: collections, name: OrderedDict
mod_name: torch._utils, name: _rebuild_parameter
mod_name: torch._utils, name: _rebuild_tensor_v2


In [24]:
for encoder_id, encoder in enumerate(bowl_unstacking_encoders):
    encoder_score = calc_encoder_score(
        encoder = encoder,
        all_expert_demos = bowl_unstacking_demos[encoder_id], 
        encoder_id = encoder_id,
        device = encoder_id
    )
    print('id: {} encoder_score: {}'.format(encoder_id, encoder_score))
    print('----')



max_transport_plan.shape: (73,), traj1.shape: torch.Size([73, 512]), traj2.shape: torch.Size([73, 512])
max_transport_plan.shape: (73,), traj1.shape: torch.Size([73, 512]), traj2.shape: torch.Size([69, 512])
max_transport_plan.shape: (73,), traj1.shape: torch.Size([73, 512]), traj2.shape: torch.Size([72, 512])
max_transport_plan.shape: (73,), traj1.shape: torch.Size([73, 512]), traj2.shape: torch.Size([85, 512])
max_transport_plan.shape: (73,), traj1.shape: torch.Size([73, 512]), traj2.shape: torch.Size([80, 512])
max_transport_plan.shape: (69,), traj1.shape: torch.Size([69, 512]), traj2.shape: torch.Size([73, 512])
max_transport_plan.shape: (69,), traj1.shape: torch.Size([69, 512]), traj2.shape: torch.Size([69, 512])
max_transport_plan.shape: (69,), traj1.shape: torch.Size([69, 512]), traj2.shape: torch.Size([72, 512])
max_transport_plan.shape: (69,), traj1.shape: torch.Size([69, 512]), traj2.shape: torch.Size([85, 512])
max_transport_plan.shape: (69,), traj1.shape: torch.Size([69, 51

RuntimeError: CUDA out of memory. Tried to allocate 122.00 MiB (GPU 1; 15.74 GiB total capacity; 14.03 GiB already allocated; 67.06 MiB free; 14.17 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF