In [1]:
import os

# Restrict CUDA to GPU 0 only. This runs before torch is imported (cell 3), so the
# CUDA runtime only ever sees that one device.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [2]:
from tqdm import tqdm_notebook as tqdm
import numpy as np

In [3]:
import torch as T
import torch.nn as nn

In [4]:
import gym

env = gym.make('CartPole-v0')

# Sanity-check the spaces: observation shape and number of discrete actions.
print(env.observation_space.shape, env.action_space.n)

# One reset and one step with action 0, printed on separate lines, to eyeball the
# initial observation and the tuple returned by env.step().
print(env.reset(), env.step(0), sep='\n')

(4,) 2
[-0.02282638  0.0018035   0.04048052 -0.03002856]
(array([-0.02279031, -0.19387488,  0.03987995,  0.27514658]), 1.0, False, {})


In [5]:
class BCO(nn.Module):
    """Networks for Behavioral Cloning from Observation on a discrete-action env.

    Holds two heads:
      * ``pol`` -- policy network: observation -> logits over ``act_n`` actions.
      * ``inv`` -- inverse-dynamics network: (obs, next_obs) concatenated along
        dim 1 -> logits over the action that produced the transition.

    Args:
        env: environment exposing ``action_space.n`` and, for ``'mlp'``,
            ``observation_space.shape[0]``.
        policy: network family. Only ``'mlp'`` is implemented.
        hidden: width of the two hidden layers in each MLP (default 32, the
            original hard-coded value).

    Raises:
        NotImplementedError: for ``policy='cnn'``.
        ValueError: for any other unrecognised ``policy``.
    """

    def __init__(self, env, policy='mlp', hidden=32):
        super(BCO, self).__init__()

        self.policy = policy
        self.act_n = env.action_space.n

        if self.policy == 'mlp':
            self.obs_n = env.observation_space.shape[0]
            # Built in the same order (pol, then inv) and with the same Sequential
            # indices as before, so RNG-driven init and state_dict keys are unchanged.
            self.pol = self._mlp(self.obs_n, hidden, self.act_n)
            # The inverse model consumes obs and next_obs concatenated along dim 1.
            self.inv = self._mlp(self.obs_n * 2, hidden, self.act_n)
        elif self.policy == 'cnn':
            # Previously a silent `pass`: `pol`/`inv` were never created, and the
            # failure surfaced only later as an AttributeError in pred_act/pred_inv.
            raise NotImplementedError("policy='cnn' is not implemented yet")
        else:
            raise ValueError("unknown policy %r (expected 'mlp' or 'cnn')" % (policy,))

    @staticmethod
    def _mlp(in_n, hidden, out_n):
        """Linear -> LeakyReLU -> Linear -> LeakyReLU -> Linear, hidden width `hidden`."""
        return nn.Sequential(
            nn.Linear(in_n, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, hidden), nn.LeakyReLU(),
            nn.Linear(hidden, out_n),
        )

    def pred_act(self, obs):
        """Return policy logits for a batch of observations (unnormalised scores)."""
        out = self.pol(obs)

        return out

    def pred_inv(self, obs1, obs2):
        """Return inverse-model logits for transitions obs1 -> obs2.

        The inputs are concatenated along dim 1, so both are expected to be 2-D
        (batch, features) tensors on the model's device.
        """
        obs = T.cat([obs1, obs2], dim=1)
        out = self.inv(obs)

        return out

# Network family passed to BCO; its 'cnn' branch builds no networks, so use 'mlp'.
POLICY = 'mlp'
model = BCO(env, policy=POLICY)
model = model.cuda()  # Module.cuda() moves the parameters and returns the same module

In [6]:
from torch.utils.data import Dataset, DataLoader

class DS_Inv(Dataset):
    """Flattened transitions for training the inverse-dynamics model.

    Args:
        trajs: iterable of trajectories, each an iterable of
            ``(obs, act, new_obs)`` triples.

    Items are returned as ``(obs, new_obs, act)`` with ``act`` as an int64
    NumPy scalar array.
    """

    def __init__(self, trajs):
        # Reordered to (obs, new_obs, act) to match what __getitem__ yields.
        self.dat = [[obs, new_obs, act]
                    for traj in trajs
                    for obs, act, new_obs in traj]

    def __len__(self):
        return len(self.dat)

    def __getitem__(self, idx):
        obs, new_obs, act = self.dat[idx]

        # Force int64: nn.CrossEntropyLoss needs Long class indices, but plain
        # np.asarray(int) is int32 on Windows with NumPy < 2, and int32 actions
        # would otherwise propagate as-is.
        return obs, new_obs, np.asarray(act, dtype=np.int64)

class DS_Policy(Dataset):
    """(observation, action) pairs for behavioural cloning of the policy.

    Args:
        traj: iterable of ``(obs, act)`` pairs.

    Items are returned as ``(obs, act)`` with ``act`` as an int64 NumPy scalar array.
    """

    def __init__(self, traj):
        self.dat = [[obs, act] for obs, act in traj]

    def __len__(self):
        return len(self.dat)

    def __getitem__(self, idx):
        obs, act = self.dat[idx]

        # Force int64 so the collated targets are Long, as nn.CrossEntropyLoss
        # requires, regardless of the platform's default NumPy integer width.
        return obs, np.asarray(act, dtype=np.int64)

In [7]:
import pickle

# Expert demonstrations: a list of trajectories of (obs, act, new_obs) triples,
# the layout DS_Inv unpacks. The training loop uses only (obs, new_obs) from
# them; the recorded action is ignored.
# SECURITY: pickle.load can execute arbitrary code -- only load trusted demo files.
with open('Demo/demo_cart-pole.pkl', 'rb') as fh:  # `with` closes the handle (was leaked)
    trajs_demo = pickle.load(fh)
ld_demo = DataLoader(DS_Inv(trajs_demo), batch_size=100)

# Number of batches, then the shapes of the first batch as a sanity check.
print(len(ld_demo))
for obs1, obs2, _ in ld_demo:
    print(obs1.shape, obs2.shape)

    break

10
torch.Size([100, 4]) torch.Size([100, 4])


In [8]:
# One classification loss and one Adam optimizer over all of `model`'s
# parameters, used for both the inverse-model and the policy updates.
loss_func = nn.CrossEntropyLoss()
loss_func = loss_func.cuda()
optim = T.optim.Adam(params=model.parameters(), lr=5e-3)

# Training schedule.
EPOCHS = 15   # outer BCO iterations
M = 1000      # minimum number of env steps collected per iteration

# Epsilon-greedy exploration: random-action probability starts at EPS and is
# multiplied by DECAY after every iteration.
EPS, DECAY = 0.9, 0.5

In [9]:
# BCO training loop. Each outer iteration:
#   1. roll out the current policy (epsilon-greedy) and append the resulting
#      (obs, act, new_obs) transitions to the growing pool `trajs_inv`;
#   2. fit the inverse-dynamics model on the whole pool;
#   3. label each demo transition (obs, new_obs) with the inverse model's action
#      (the action stored in the demo file is ignored);
#   4. fit the policy on the labelled demo (behavioural cloning);
#   5. checkpoint the model and decay exploration.


def _clear_grads(net):
    """Set every parameter's ``.grad`` to ``None`` before a backward pass.

    ``model.pol`` and ``model.inv`` share a single Adam optimizer. A zero-filled
    grad (the ``zero_grad()`` default before PyTorch 2.0) still lets Adam step a
    parameter using its running momentum, so training one head would keep
    nudging the other. Optimizers skip parameters whose grad is ``None``.
    """
    for p in net.parameters():
        p.grad = None


def _collect_rollouts(eps, min_steps):
    """Play epsilon-greedy episodes until at least ``min_steps`` env steps are taken.

    At least one full episode is always played. Returns ``(trajs, mean_reward)``:
    ``trajs`` is a list of trajectories of ``[obs, act, new_obs]`` and
    ``mean_reward`` is the average summed reward per episode.
    """
    trajs, returns, steps = [], [], 0
    while True:
        traj, ep_ret, done = [], 0.0, False
        obs = env.reset()
        while not done:
            if np.random.rand() >= eps:
                # Exploit: greedy action from the policy head; inference only.
                with T.no_grad():
                    inp = T.from_numpy(obs).float().unsqueeze(0).cuda()
                    act = int(model.pred_act(inp).argmax(dim=1)[0])
            else:
                act = env.action_space.sample()

            new_obs, r, done, _ = env.step(act)
            traj.append([obs, act, new_obs])
            obs = new_obs
            ep_ret += r
            steps += 1

        trajs.append(traj)
        returns.append(ep_ret)
        if steps >= min_steps:
            return trajs, sum(returns) / len(returns)


def _fit(loader, batch_loss, tag, epoch):
    """Run one optimisation pass of ``optim`` over ``loader``; return the mean batch loss.

    ``batch_loss(batch)`` must return a scalar loss tensor. ``tag`` names the
    loss in the progress-bar postfix and in the printed ``Ep <n>: <tag>=...`` line.
    """
    total = 0.0
    with tqdm(loader) as bar:
        for batch in bar:
            loss = batch_loss(batch)

            _clear_grads(model)
            loss.backward()
            optim.step()

            val = loss.item()
            bar.set_postfix(**{tag: '%.3f' % val})
            total += val

        mean = total / max(len(loader), 1)
        print('Ep %d: %s=%.3f' % (epoch, tag, mean))
    return mean


def _inv_loss(batch):
    """Cross-entropy of the inverse model's prediction of the action actually taken."""
    obs1, obs2, act = batch
    out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())
    # .long(): CrossEntropyLoss needs int64 class indices whatever the dataset yields.
    return loss_func(out, act.long().cuda())


def _policy_loss(batch):
    """Cross-entropy of the policy's prediction of the inferred demo action."""
    obs, act = batch
    out = model.pred_act(obs.float().cuda())
    return loss_func(out, act.long().cuda())


def _label_demo():
    """Label every demo transition with the inverse model's argmax action.

    Returns a list of ``[obs, act]`` pairs for DS_Policy. Works for any demo
    length; the final ld_demo batch may be smaller than 100 (the old
    hard-coded ``range(100)`` raised IndexError in that case).
    """
    labelled = []
    with T.no_grad():
        for obs1, obs2, _ in ld_demo:
            out = model.pred_inv(obs1.float().cuda(), obs2.float().cuda())
            acts = out.argmax(dim=1).cpu().numpy()
            labelled.extend([o, a] for o, a in zip(obs1.cpu().numpy(), acts))
    return labelled


trajs_inv = []
eps = EPS  # local copy: re-running this cell restarts exploration from the configured EPS

os.makedirs('Model', exist_ok=True)  # T.save below fails if the directory is missing

for e in tqdm(range(EPOCHS)):
    ep = e + 1

    # step1, generate inverse samples with the current policy
    new_trajs, mean_rew = _collect_rollouts(eps, M)
    trajs_inv.extend(new_trajs)
    print('Ep %d: reward=%.2f' % (ep, mean_rew))

    # step2, update inverse model on every transition collected so far
    ld_inv = DataLoader(DS_Inv(trajs_inv), batch_size=32, shuffle=True)
    _fit(ld_inv, _inv_loss, 'loss_inv', ep)

    # step3, predict inverse action for demo samples
    traj_policy = _label_demo()

    # step4, update policy via demo samples
    ld_policy = DataLoader(DS_Policy(traj_policy), batch_size=32, shuffle=True)
    _fit(ld_policy, _policy_loss, 'loss_policy', ep)

    # step5, save model
    T.save(model.state_dict(), 'Model/model_cart-pole_%d.pt' % ep)

    eps *= DECAY

HBox(children=(IntProgress(value=0, max=15), HTML(value='')))

Ep 1: reward=18.98


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 1: loss_inv=0.610


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 1: loss_policy=0.570
Ep 2: reward=13.12


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))

Ep 2: loss_inv=0.049


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 2: loss_policy=0.557
Ep 3: reward=26.71


HBox(children=(IntProgress(value=0, max=95), HTML(value='')))

Ep 3: loss_inv=0.001


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 3: loss_policy=0.528
Ep 4: reward=34.76


HBox(children=(IntProgress(value=0, max=127), HTML(value='')))

Ep 4: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 4: loss_policy=0.446
Ep 5: reward=90.42


HBox(children=(IntProgress(value=0, max=161), HTML(value='')))

Ep 5: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 5: loss_policy=0.422
Ep 6: reward=170.17


HBox(children=(IntProgress(value=0, max=193), HTML(value='')))

Ep 6: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 6: loss_policy=0.416
Ep 7: reward=200.00


HBox(children=(IntProgress(value=0, max=224), HTML(value='')))

Ep 7: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 7: loss_policy=0.391
Ep 8: reward=200.00


HBox(children=(IntProgress(value=0, max=255), HTML(value='')))

Ep 8: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 8: loss_policy=0.396
Ep 9: reward=200.00


HBox(children=(IntProgress(value=0, max=286), HTML(value='')))

Ep 9: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 9: loss_policy=0.375
Ep 10: reward=200.00


HBox(children=(IntProgress(value=0, max=318), HTML(value='')))

Ep 10: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 10: loss_policy=0.349
Ep 11: reward=200.00


HBox(children=(IntProgress(value=0, max=349), HTML(value='')))

Ep 11: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 11: loss_policy=0.372
Ep 12: reward=200.00


HBox(children=(IntProgress(value=0, max=380), HTML(value='')))

Ep 12: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 12: loss_policy=0.401
Ep 13: reward=200.00


HBox(children=(IntProgress(value=0, max=411), HTML(value='')))

Ep 13: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 13: loss_policy=0.376
Ep 14: reward=200.00


HBox(children=(IntProgress(value=0, max=443), HTML(value='')))

Ep 14: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 14: loss_policy=0.320
Ep 15: reward=200.00


HBox(children=(IntProgress(value=0, max=474), HTML(value='')))

Ep 15: loss_inv=0.000


HBox(children=(IntProgress(value=0, max=32), HTML(value='')))

Ep 15: loss_policy=0.352

