In [18]:
import pickle
import numpy as np
from torchvision.transforms import Resize
from PIL import Image
import h5py
from tqdm import tqdm

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from collections import deque

# load filenames
with open('filenames.pickle', 'rb') as f:
    filenames = pickle.load(f)

In [14]:
# Functions for preprocessing images

def _preprocess_images(images):
    assert images.dtype == np.uint8, 'image need to be uint8!'
    images = resize_video(images, (32, 32))
    images = np.transpose(images, [0, 3, 1, 2])  # convert to channel-first
    images = images.astype(np.float32) / 255 * 2 - 1
    assert images.dtype == np.float32, 'image need to be float32!'
    return images

def resize_video(video, size):
    if video.shape[1] == 3:
        video = np.transpose(video, (0,2,3,1))
    transformed_video = np.stack([np.asarray(Resize(size)(Image.fromarray(im))) for im in video], axis=0)
    return transformed_video


class AttrDict(dict):
    __setattr__ = dict.__setitem__

    def __getattr__(self, attr):
        # Take care that getattr() raises AttributeError, not KeyError.
        # Required e.g. for hasattr(), deepcopy and OrderedDict.
        try:
            return self.__getitem__(attr)
        except KeyError:
            raise AttributeError("Attribute %r not found" % attr)

    def __getstate__(self):
        return self

    def __setstate__(self, d):
        self = d
        

def sample_batch_function() :
    # sample 128 state-action pairs from 79131 trajectories
    one_batch_indexes = list(np.random.randint(0, len(filenames)-1, size=128))

    prep_data = AttrDict()

    st_i = 0

    for i in one_batch_indexes :

        st_i += 1
        index = i
        data = AttrDict()
        samples_per_file = 1
        file_index = index // samples_per_file
        path = filenames[file_index]


        try:
            with h5py.File(path, 'r') as F:
                ex_index = index % samples_per_file  # get the index
                key = 'traj{}'.format(ex_index)

                # Fetch data into a dict
                for name in F[key].keys():
                    if name in ['states', 'actions', 'pad_mask']:
                        data[name] = F[key + '/' + name][()].astype(np.float32)

                if key + '/images' in F:
                    data.images = F[key + '/images'][()]
                else:
                    data.images = np.zeros((data.states.shape[0], 2, 2, 3), dtype=np.uint8)
                    
                # terminals
                data.terminals = np.full((data.states.shape[0],), False, dtype=bool)
                data.terminals[-1] = True
                
        except:
            raise ValueError("Could not load from file {}".format(path))


        data.images = _preprocess_images(data.images)
        one_transition = np.random.randint(1, data.states.shape[0]-2, size=1)
        data.conc_images = np.concatenate((data.images[one_transition-1], data.images[one_transition]),axis=1)
        data.next_conc_images = np.concatenate((data.images[one_transition], data.images[one_transition+1]),axis=1)
        

        
        if st_i == 1 :

            prep_data.actions = data.actions[one_transition]
            prep_data.images = data.images[one_transition]
            prep_data.conc_images = data.conc_images 
            prep_data.next_conc_images = data.next_conc_images
        
        else :
            prep_data.actions = np.concatenate((prep_data.actions, data.actions[one_transition]),axis=0)
            prep_data.images = np.concatenate((prep_data.images, data.images[one_transition]),axis=0)
            prep_data.conc_images = np.concatenate((prep_data.conc_images, data.conc_images),axis=0)
            prep_data.next_conc_images = np.concatenate((prep_data.next_conc_images, data.next_conc_images),axis=0)
    
    return prep_data

In [30]:
# Unsupervised NN for extracting features from multi-goal datasets

class Phi(nn.Module): #A
    def __init__(self):
        super(Phi, self).__init__()
        # state : (32, 32, 3)
        # state : (128, 60) -> phi(st) : (128, 288)
        self.conv1 = nn.Conv2d(6, 8, kernel_size=(4,4), stride=2, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=(4,4), stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=(4,4), stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=(4,4), stride=1)

    def forward(self,x):
        
        # torch.Size([128, 2, 3, 32, 32])
        #x = F.normalize(x)

        # torch.Size([128, 32, 2, 16, 16])
        y = F.leaky_relu(self.conv1(x), negative_slope=0.2, inplace=True)
        y = F.leaky_relu(self.bn2(self.conv2(y)), negative_slope=0.2, inplace=True)
        y = F.leaky_relu(self.bn3(self.conv3(y)), negative_slope=0.2, inplace=True)
        y = self.conv4(y)
        
        # torch.Size([128, 34])
        y = y.flatten(start_dim=1) 
        return y

class Gnet(nn.Module): #B
    def __init__(self):
        super(Gnet, self).__init__()
        
        #
        self.linear1 = nn.Linear(64,32)
        self.linear2 = nn.Linear(32,2)

    def forward(self, state1,state2):
        
        # phi(st) : (128, 32) + phi(st+1) : (128, 32) = (128, 64)
        x = torch.cat( (state1, state2) ,dim=1)
        y = F.relu(self.linear1(x))
        y = self.linear2(y)

        return y

class Fnet(nn.Module): #C
    def __init__(self):
        super(Fnet, self).__init__()
        
        # phi(st) : (128, 32) + action : (128, 2) = (128, 34)
        # (128, 297) -> (128, 1024)
        self.linear1 = nn.Linear(34,52)
        self.linear2 = nn.Linear(52,32)

    def forward(self,state,action):
        x = torch.cat((state,action) ,dim=1)
        y = F.relu(self.linear1(x))
        y = self.linear2(y)
        return y
    
params = {
    'batch_size':150,
    'beta':0.2,
    'lambda':0.1,
    'eta': 1.0,
    'gamma':0.2,
    'max_episode_len':100,
    'min_progress':15,
    'action_repeats':6,
    'frames_per_state':3
}


def loss_fn(inverse_loss, forward_loss):
    # loss_  = torch.Size([128, 1])
    loss_ = (1 - params['beta']) * inverse_loss
    loss_ += params['beta'] * forward_loss
    
    # loss_.flatten() : torch.Size([128])
    loss = loss_.sum() / loss_.flatten().shape[0]
    return loss


def ICM(state1, action, state2, forward_scale=1., inverse_scale=1e4):
    state1_hat = encoder(state1) #A
    state2_hat = encoder(state2)
    state2_hat_pred = forward_model(state1_hat.detach(), action.detach()) #B
    
    # forward_loss : torch.Size([128, 288])
    # forward_loss.sum(dim=1) : torch.Size([128])
    # forward_loss.sum(dim=1).unsqueeze(dim=1) : torch.Size([128, 1])
    forward_pred_err = forward_scale * forward_loss(state2_hat_pred, \
                        state2_hat.detach()).sum(dim=1).unsqueeze(dim=1)
    
    # torch.Size([128, 9])
    pred_action = inverse_model(state1_hat, state2_hat) #C
    
    # inverse_loss : torch.Size([128, 9])
    # inverse_loss.sum(dim=1) : torch.Size([128])
    # inverse_loss.sum(dim=1).unsqueeze(dim=1) : torch.Size([128, 1])
    inverse_pred_err = inverse_scale * inverse_loss(pred_action, \
                                        action.detach()).sum(dim=1).unsqueeze(dim=1)
    return forward_pred_err, inverse_pred_err


def minibatch_train(use_extrinsic=True):

    batch_loaded = sample_batch_function()

    state1_batch = torch.from_numpy(batch_loaded.conc_images).to("cuda")
    state2_batch = torch.from_numpy(batch_loaded.next_conc_images).to("cuda")
    action_batch = torch.from_numpy(batch_loaded.actions).to("cuda")

    
    forward_pred_err, inverse_pred_err = ICM(state1_batch, action_batch, state2_batch) #B
    i_reward = (1. / params['eta']) * forward_pred_err #C
    i_reward = i_reward.detach() #D
 
    return forward_pred_err, inverse_pred_err, i_reward


In [27]:
encoder = Phi().to("cuda")
forward_model = Fnet().to("cuda")
inverse_model = Gnet().to("cuda")
forward_loss = nn.MSELoss(reduction='none')
inverse_loss = nn.MSELoss(reduction='none')

all_model_params = list(encoder.parameters()) #A
all_model_params += list(forward_model.parameters()) + list(inverse_model.parameters())
opt = optim.Adam(lr=0.001, params=all_model_params)

In [31]:
# Train unsupervised NN for extracting features from multi-goal datasets


epochs = 1000

eps=0.15
losses = []
episode_length = 0

ep_lengths = []
use_explicit = False

i_reward_list = []


for i in tqdm(range(epochs)):
    opt.zero_grad()

    forward_pred_err, inverse_pred_err, i_reward = minibatch_train(use_extrinsic=False) #H
    loss = loss_fn(forward_pred_err, inverse_pred_err) #I
    loss_list = (forward_pred_err.flatten().mean(), inverse_pred_err.flatten().mean())
    losses.append(loss_list)
    loss.backward()
    opt.step()
    
    print
    if i % 10 == 0:
        loss = loss.item()
        print(f"loss: {loss:>7f}")

  0%|          | 0/100 [00:00<?, ?it/s]

  1%|          | 1/100 [00:07<12:22,  7.50s/it]

loss: 2354.513184


 11%|█         | 11/100 [00:46<05:44,  3.87s/it]

loss: 1863.161255


 11%|█         | 11/100 [00:48<06:35,  4.45s/it]


KeyboardInterrupt: 