In [None]:
import os
import time
import copy
import torch
import numpy as np

from vpt.agent import resize_image

from src.policy.vpt import load
from src.policy.ppo import BatchEpisode, PPO
from src.policy.policy import VPTPolicy
from src.policy.mt_ppo import MTBatchEpisode, MTPPO
from src.policy.mt_policy import MTVPTPolicy

from src.collector.conc import MultiThreadCollector, Worker
from src.collector.mt_collect import MTCollector

from src.environment.env import eval_task_ids, eval_task_specs, make_eval_env

from src.reward.mineclip import MineCLIP, soften, min_clip, zero_out_decreased, hidden_dim as clip_embed_dim

In [None]:
# Set the MINEDOJO_HEADLESS flag for this process (presumably makes MineDojo
# run without opening a render window — confirm against MineDojo docs).
os.environ.update({'MINEDOJO_HEADLESS': '1'})

In [None]:
# Experiment hyper-parameters. Sub-dicts are splatted as keyword arguments
# further down the notebook, so the key names must stay exactly as they are.
config = {
    # Keyword arguments for the PPO trainer constructor.
    'ppo': {
        'clip': 0.2,
        'vf_loss_scale': 1,
        'gamma': 0.999,
        'gae_lambda': 0.95,
        'normalize_gae': True,
        'policy_reg_scale': 0.3,
        'lr': 1e-4,
    },
    # Keyword arguments for each trainer.step(...) call.
    'ppo_step': {
        'n_max_epoch': 5,
        'context_l': 100,
        'max_context': 4,
        'batch_size': 80,
        'target_kl': 0.1,
    },
    # Multiplicative decay applied to the trainer's policy_reg_scale every iteration.
    'policy_reg_scale_decay': 0.999,
    'adapter': {
        'dim_factor': 8,
    },
    'n_episode_in_batch': 10,
    'time_limit': 200,
    # Reward shaping parameters applied to the MineCLIP reward.
    'reward': {
        'soften_window': 50,
        'min_clip': 21,
        'scale': 0.005,
    },
    'n_iter': 1000,
}

# Runtime / hardware settings that do not define the experiment itself.
args = {
    'gpu': 0,
    'n_worker': 10,
    'eval_period': 20,
    'n_eval_episode': 10,
    'process_batch_size': 100,
}

## MineDojo environment

In [None]:
# Load the VPT model from the model-definition file and the weights file,
# using the same context length that the PPO step is configured with.
vpt = load(
    './asset/vpt/foundation-model-3x.model',
    './asset/vpt/bc-house-3x.weights',
    context_l=config['ppo_step']['context_l']
)
# `vpt_policy` wraps a deep copy taken *before* `vpt` is handed to MTVPTPolicy,
# so the two policies do not share a model object. Keep this statement order.
# NOTE(review): presumably `vpt_policy` is the fixed reference policy used for
# regularisation by the trainer — confirm against MTPPO.
vpt_policy = VPTPolicy(copy.deepcopy(vpt), event_level_control=False).to(args['gpu'])
# Multi-task policy built on the original `vpt` object, sized by the MineCLIP
# embedding dimension and the adapter dim factor. It is not moved to the GPU here.
policy = MTVPTPolicy(vpt, clip_embed_dim, config['adapter']['dim_factor'], event_level_control=False)

In [None]:
# Each worker owns an MTCollector created around an environment built from the
# first eval task spec; the environments actually rolled out are passed in
# later through conc_collector.submit(...).
dummy_env_spec = eval_task_specs[eval_task_ids[0]]
workers = [
    Worker(MTCollector(
        make_eval_env(dummy_env_spec), time_limit=config['time_limit']
    ))
    for _ in range(args['n_worker'])
]

# Pool that dispatches submitted rollouts across the workers.
conc_collector = MultiThreadCollector(workers)

In [None]:
# MineCLIP is only used for inference here (text/video embeddings and rewards),
# so freeze every parameter and switch to eval mode before moving it to the GPU.
mineclip = MineCLIP('./asset/mineclip/attn.pth')
for frozen_param in mineclip.parameters():
    frozen_param.requires_grad_(False)
mineclip.eval()
mineclip.to(args['gpu'])

### Training loop

In [None]:
# Build the multi-task PPO trainer from the trainable policy and the VPT
# policy copy, then move it to the configured device.
trainer = MTPPO(policy, vpt_policy, **config['ppo'])
trainer = trainer.to(args['gpu'])

In [None]:
# Load the pre-generated task proposals (prompt strings and world seeds).
# The result is bound to `task_proposals`, not `load`: the name `load` is the
# VPT loader imported from src.policy.vpt, and rebinding it would break the
# model-loading cell on any re-run within the same kernel.
# NOTE: torch.load unpickles the file — only use it on trusted assets.
task_proposals = torch.load('./asset/task_proposals.pt')
polished = task_proposals['polished']
world_seeds = task_proposals['world_seeds']

In [None]:
# Build one training environment per (eval task, world seed) pair.
tr_task_envs = []
for eval_task_id in eval_task_ids:
    for world_seed in world_seeds:
        # progress indicator: index of the env about to be built
        print(len(tr_task_envs), end=' ')

        # building env
        # Work on a copy of the spec: the overrides below (world seed and the
        # disabled success criterion in particular) must not leak into the
        # shared `eval_task_specs`, which other cells read as well.
        env_spec = copy.copy(eval_task_specs[eval_task_id])

        env_spec['world_seed'] = world_seed.item()
        env_spec['fast_reset'] = False
        env_spec['event_level_control'] = False

        env_spec['target_quantities'] = 999999 # disable ground-truth success criteria

        env = make_eval_env(env_spec)
        # Drop the bridge handle on the unwrapped env.
        # NOTE(review): presumably this makes the env deep-copyable before a
        # rollout — confirm.
        env.unwrapped._bridge_env = None
        tr_task_envs.append(env)

In [None]:
# Clean up the proposed task prompts and echo them in groups of five.
tr_task_prompts = []
for ii, p in enumerate(polished):
    # Strip a leading "<label>: " prefix when present. The check and the split
    # use the same separator, and maxsplit=1 keeps any later ": " inside the
    # prompt intact instead of truncating it.
    if ': ' in p:
        p = p.split(': ', 1)[1]
    print(p)
    # blank line after every fifth prompt, for readability
    if ii % 5 == 4: print()

    tr_task_prompts.append(p)

In [None]:
# Embed every training prompt with MineCLIP and keep an L2-normalised copy on the CPU.
tr_task_embeds = []
for embed_idx, prompt in enumerate(tr_task_prompts):
    # progress indicator: index of the prompt being embedded
    print(embed_idx, end=' ')

    raw_embed = mineclip.embed_text(prompt).cpu()
    unit_embed = raw_embed / raw_embed.norm(p=2, dim=-1, keepdim=True)

    # keep only the first entry along the leading axis
    tr_task_embeds.append(unit_embed[0])

In [None]:
# Build one evaluation environment per eval task.
eval_envs = []
for eval_task_id in eval_task_ids:
    # Copy the spec so the overrides below do not modify the shared
    # `eval_task_specs` entries in place.
    env_spec = copy.copy(eval_task_specs[eval_task_id])

    env_spec['fast_reset'] = False
    env_spec['event_level_control'] = False

    eval_env = make_eval_env(env_spec)
    # Drop the bridge handle on the unwrapped env.
    # NOTE(review): presumably this makes the env deep-copyable before a
    # rollout — confirm.
    eval_env.unwrapped._bridge_env = None
    eval_envs.append(eval_env)

In [None]:
# Hand-written prompts, one per eval task; they are zipped with `eval_envs`
# in the training loop, so the order here has to match the eval task order.
oracle_prompt_texts = [
    'get milk from a cow',
    'shear a sheep and get some wool',
    'hunt a chicken and get its meat',
    'collect logs',
    'kill a cow',
    'kill a sheep',
    'kill a spider',
    'kill a zombie',
]
# NOTE(review): the [0] assumes embed_text returns the batch of prompt
# embeddings nested under a leading axis of size 1 — confirm against MineCLIP.
oracle_embeds = mineclip.embed_text(oracle_prompt_texts).cpu()[0]
# Despite its name, `oracle_task_prompts` holds the L2-normalised embeddings.
oracle_task_prompts = oracle_embeds / oracle_embeds.norm(p=2, dim=-1, keepdim=True)

In [None]:
# Main loop: collect training episodes, label them with MineCLIP rewards,
# periodically evaluate on the eval envs, then run one PPO update.
for collect_i in range(config['n_iter']):
    # Per-iteration metrics.
    # NOTE(review): `log` is filled in below but never printed or stored in
    # this cell — confirm whether a logging call is missing.
    log = {'collect_i': collect_i, 'episode_i': collect_i*config['n_episode_in_batch']}

    # train env
    with torch.cuda.amp.autocast(), policy.expl():
        time1 = time.perf_counter()

        # collector: one rollout per (training env, task embedding) pair; each
        # rollout runs on a fresh deep copy of the env
        for tr_env, task_embed in zip(tr_task_envs, tr_task_embeds):
            conc_collector.submit(policy, copy.deepcopy(tr_env), None, task_embed)
        episodes = conc_collector.wait()
        log['time_collect'] = time.perf_counter()-time1

        time1 = time.perf_counter()

        # reward: MineCLIP only lives on the GPU while rewards are computed
        mineclip.to(args['gpu'])
        for ep_idx, episode in enumerate(episodes):
            # Slice 32:608 along the second axis of each frame, resize to
            # (256, 160), stack, and move the last axis to position 1.
            episode.rewarded_states = torch.as_tensor(
                np.stack([
                    resize_image(s['rgb'][:, 32:608], (256, 160))
                    for s in episode.states
                ])
            ).permute(0, 3, 1, 2).to(args['gpu'])

            # Relies on wait() returning episodes in submission order.
            text_embed = tr_task_embeds[ep_idx]

            video_embed = mineclip.embed_video(
                episode.rewarded_states, process_batch_size=args['process_batch_size']
            )

            # raw MineCLIP reward -> soften -> min_clip -> scale -> zero_out_decreased
            episode.rewards = config['reward']['scale'] * min_clip(
                soften(
                    mineclip.compute_reward(
                            text_embed.to(args['gpu']), video_embed.to(args['gpu'])
                    ),
                    config['reward']['soften_window']
                ), config['reward']['min_clip']
            )
            episode.rewards = zero_out_decreased(episode.rewards)

            # release GPU memory held by this episode before the next one
            del video_embed
            episode.rewarded_states = episode.rewarded_states.to('cpu')
        mineclip.to('cpu')
        log['time_reward'] = time.perf_counter()-time1

    log['tr_return'] = np.mean([sum(episode.rewards) for episode in episodes])
    log['tr_epi_len'] = np.mean([len(episode.rewards) for episode in episodes])

    # eval env
    if collect_i % args['eval_period'] == 0:
        time1 = time.perf_counter()

        # collector: n_eval_episode rollouts per eval task, submitted task by task
        with torch.cuda.amp.autocast(), policy.expl():
            for eval_env, eval_task_embed in zip(eval_envs, oracle_task_prompts):
                for _ in range(args['n_eval_episode']):
                    conc_collector.submit(policy, copy.deepcopy(eval_env), None, eval_task_embed)
            eval_episodes = conc_collector.wait()

            # Group the flat episode list as (task, episode). The builtin
            # `object` dtype is used because the `np.object` alias was removed
            # in NumPy 1.24.
            taskw_eval_episodes = np.array(eval_episodes, dtype=object).reshape(len(eval_envs), args['n_eval_episode'])
        log['time_eval_collect'] = time.perf_counter()-time1

        # success check: an episode counts as a success if any of its rewards is truthy
        taskw_success_rate = [
            np.mean([any(eval_episode.rewards) for eval_episode in task_episodes])
            for task_episodes in taskw_eval_episodes
        ]
        log['eval_success_rate'] = {
            task_idx: success_rate
            for task_idx, success_rate in enumerate(taskw_success_rate)
        }
        log['eval_avg_success_rate'] = np.mean(taskw_success_rate)

    # PPO update on the freshly collected training episodes
    time1 = time.perf_counter()
    stat = trainer.step(
        episodes, **config['ppo_step'],
        process_batch_size=args['process_batch_size']
    )
    log['time_ppo'] = time.perf_counter()-time1
    log.update(stat)

    # record the regularisation scale used this iteration, then decay it
    log['policy_reg_scale'] = trainer.policy_reg_scale
    trainer.policy_reg_scale *= config['policy_reg_scale_decay']