In [None]:
import sys
sys.path.insert(0,'/root/Downloads/metaworld-master/metaworld-master/')
import metaworld
import random

# NOTE(review): the variable is named `ml10` but it holds the MT50 benchmark object.
ml10 = metaworld.MT50() # Construct the benchmark, sampling tasks

# Build one environment per training class and assign it one randomly chosen
# task whose `env_name` matches that class.
# NOTE(review): `random` has not been seeded at this point, so the task picked
# for each environment differs between kernel runs.
training_envs = []
for name, env_cls in ml10.train_classes.items():
  env = env_cls()
  task = random.choice([task for task in ml10.train_tasks
                        if task.env_name == name])
  env.set_task(task)
  training_envs.append(env)

# Smoke test: reset every environment and take a single random step.
# The loop variable `env` stays bound to the last environment afterwards and is
# read again by the setup cell that builds the Normalizer and calls env.seed().
for env in training_envs:
  obs = env.reset()  # Reset environment
  a = env.action_space.sample()  # Sample an action
  obs, reward, done, info = env.step(a)  # Step the environment with the sampled random action

In [None]:
# `envs` is the name the remaining cells use for the list of per-task environments.
envs=training_envs

import sys
sys.path.append("./")
import metaworld
import torch

import os
import time
import os.path as osp

import numpy as np

from torchrl.utils import get_args
from torchrl.utils import get_params
from torchrl.env import get_env


from torchrl.utils import Logger
import torchrl.policies as policies
import torchrl.networks as networks
from torchrl.collector.base import BaseCollector
from torchrl.algo import SAC
from torchrl.algo import TwinSAC
from torchrl.algo import TwinSACQ
from torchrl.algo import MTSAC
from torchrl.collector.para import ParallelCollector
from torchrl.collector.para import AsyncParallelCollector
from torchrl.collector.para.mt import SingleTaskParallelCollectorBase
from torchrl.replay_buffers import BaseReplayBuffer
from torchrl.replay_buffers.shared import SharedBaseReplayBuffer
from torchrl.replay_buffers.shared import AsyncSharedReplayBuffer
import gym

import random

In [None]:
import gym
from gym import Wrapper
from gym.spaces import Box
import numpy as np
from metaworld.envs.mujoco.sawyer_xyz import *
from metaworld.core.serializable import Serializable
import sys
sys.path.append("../..")
from torchrl.env.continuous_wrapper import *
from torchrl.env.get_env import wrap_continuous_env


class SingleWrapper(Wrapper):
    """Wrap a single-task environment behind the interface the collectors use.

    Besides the usual gym calls (reset/step/seed/render/close) it exposes
    ``train``/``test``/``eval`` switches that only toggle ``train_mode`` and a
    ``reset_with_index`` method that ignores its argument.
    """

    def __init__(self, env):
        """Store ``env`` and mirror its action/observation spaces.

        Args:
            env: the environment to wrap.
        """
        # Initialise gym.Wrapper as well: it is what sets ``self.env``, which
        # the base class relies on for attribute forwarding. Skipping it left
        # the wrapper half-constructed.
        super(SingleWrapper, self).__init__(env)
        self._env = env
        self.action_space = env.action_space
        self.observation_space = env.observation_space
        self.train_mode = True

    def reset(self):
        """Reset the wrapped environment and return its initial observation."""
        return self._env.reset()

    def seed(self, se):
        """Forward the seed ``se`` to the wrapped environment."""
        self._env.seed(se)

    def reset_with_index(self, task_idx):
        """Reset the wrapped environment; ``task_idx`` is accepted but unused."""
        return self._env.reset()

    def step(self, action):
        """Step the wrapped environment and return ``(obs, reward, done, info)``."""
        obs, reward, done, info = self._env.step(action)
        return obs, reward, done, info

    def train(self):
        """Mark the wrapper as being in training mode."""
        self.train_mode = True

    def test(self):
        """Mark the wrapper as being in evaluation mode."""
        self.train_mode = False

    def eval(self):
        """Alias of :meth:`test`."""
        self.train_mode = False

    def render(self, mode='human', **kwargs):
        """Render through the wrapped environment."""
        return self._env.render(mode=mode, **kwargs)

    def close(self):
        """Close the wrapped environment."""
        self._env.close()


In [None]:
class Normalizer():
    """Running mean/variance normaliser with clipping.

    Args:
        shape: shape of a single sample (passed to ``np.zeros``/``np.ones``).
        clip: normalised values are clipped to ``[-clip, clip]``.
    """

    def __init__(self, shape, clip=10.):
        self.shape = shape
        self._mean = np.zeros(shape)
        self._var = np.ones(shape)
        # Small non-zero prior count, used as the initial sample count.
        self._count = 1e-4
        self.clip = clip
        self.should_estimate = True

    def stop_update_estimate(self):
        """Freeze the statistics; later ``update_estimate`` calls become no-ops."""
        self.should_estimate = False

    def update_estimate(self, data):
        """Fold ``data`` into the running statistics.

        Args:
            data: numpy array holding either one sample (same rank as the
                stored mean) or a batch of samples stacked along axis 0.
        """
        if not self.should_estimate:
            return
        # A single sample has the same rank as the stored mean; give it a
        # leading batch axis so the reductions below run over samples.
        # (The previous check compared an int rank to the shape tuple and
        # therefore never matched.)
        if np.ndim(data) == np.ndim(self._mean):
            data = data[np.newaxis, ...]
        # `update_mean_var_count` is not defined in this block; it must be
        # available in the notebook namespace.
        self._mean, self._var, self._count = update_mean_var_count(
            self._mean, self._var, self._count,
            np.mean(data, axis=0), np.var(data, axis=0), data.shape[0])

    def inverse(self, raw):
        """Map normalised numpy values back using the stored mean/variance.

        Note: no epsilon is added here, whereas ``filt`` adds 1e-4 to the
        standard deviation, and clipping is not undone.
        """
        return raw * np.sqrt(self._var) + self._mean

    def inverse_torch(self, raw):
        """Torch counterpart of :meth:`inverse`; runs on ``raw.device``."""
        return raw * torch.Tensor(np.sqrt(self._var)).to(raw.device) \
            + torch.Tensor(self._mean).to(raw.device)

    def filt(self, raw):
        """Normalise numpy values and clip them to ``[-clip, clip]``."""
        return np.clip(
            (raw - self._mean) / (np.sqrt(self._var) + 1e-4),
            -self.clip, self.clip)

    def filt_torch(self, raw):
        """Torch counterpart of :meth:`filt`; runs on ``raw.device``."""
        return torch.clamp(
            (raw - torch.Tensor(self._mean).to(raw.device)) / \
            (torch.Tensor(np.sqrt(self._var) + 1e-4).to(raw.device)),
            -self.clip, self.clip)

class parser:
    """Plain attribute container holding the run settings used by this notebook."""

    def __init__(self):
        # (attribute name, value) pairs, assigned in this order.
        settings = (
            ('config', 'config/sac_ant.json'),
            ('id', 'sac_ant'),
            ('worker_nums', 10),
            ('eval_worker_nums', 10),
            ('seed', 20),
            ('vec_env_nums', 1),
            ('save_dir', './save/sac_ant'),
            ('log_dir', './log/sac_ant'),
            ('no_cuda', True),
            ('overwrite', True),
            ('device', 'cpu'),
            ('cuda', False),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)
                
# Run settings and the JSON parameter file they point to.
args=parser()
params = get_params(args.config)


# NOTE(review): with cuda enabled this formats "cuda:<args.device>", so
# args.device must then be a device index rather than the string 'cpu'.
device = torch.device(
    "cuda:{}".format(args.device) if args.cuda else "cpu")

# Previously this cell relied on `env`, a loop variable leaked from the cell
# that builds the environments (i.e. the last environment), and seeded only
# that one. Reference the environment list explicitly and seed every task.
normalizer=Normalizer(envs[-1].observation_space.shape)
for task_env in envs:
    task_env.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

buffer_param = params['replay_buffer']

# Fall back to the config file's base name when no explicit id is given.
experiment_name = os.path.split(
    os.path.splitext(args.config)[0])[-1] if args.id is None \
    else args.id
logger = Logger(
    experiment_name, params['env_name'], args.seed, params, args.log_dir)



In [None]:
import copy
import time
from collections import deque
import numpy as np

import torch

import torchrl.algo.utils as atu

import gym

import os
import os.path as osp

class RLAlgo():
    """
    Base RL Algorithm Framework

    Holds the environment, replay buffer, collector and logger, and drives one
    training epoch per :meth:`train` call. Subclasses provide ``update`` and
    the ``networks`` / ``snapshot_networks`` / ``target_networks`` properties.
    """
    def __init__(self,
        env = None,
        replay_buffer = None,
        collector = None,
        logger = None,
        continuous = None,
        discount=0.99,
        num_epochs = 3000,
        epoch_frames = 1000,
        max_episode_frames = 999,
        batch_size = 128,
        device = 'cpu',
        train_render = False,
        eval_episodes = 1,
        eval_render = False,
        save_interval = 100,
        save_dir = None
    ):
        """Store the configuration and create ``save_dir`` if it is given.

        The ``continuous`` argument is accepted but ignored: the flag is
        derived from the type of ``env.action_space``.
        """

        self.env = env
        self.total_frames = 0
        self.continuous = isinstance(self.env.action_space, gym.spaces.Box)

        self.replay_buffer = replay_buffer
        self.collector = collector
        # device specification
        self.device = device

        # environment relevant information
        self.discount = discount
        self.num_epochs = num_epochs
        self.epoch_frames = epoch_frames
        self.max_episode_frames = max_episode_frames

        self.train_render = train_render
        self.eval_render = eval_render

        # training information
        self.batch_size = batch_size
        self.training_update_num = 0
        self.sample_key = None

        # Logger & relevant setting
        self.logger = logger

        # Rolling windows over the most recent 30 episode rewards.
        self.episode_rewards = deque(maxlen=30)
        self.training_episode_rewards = deque(maxlen=30)
        self.eval_episodes = eval_episodes

        self.save_interval = save_interval
        self.save_dir = save_dir
        # makedirs creates missing parent directories and tolerates an
        # existing directory; a missing save_dir (the default) is skipped
        # instead of crashing in the existence check.
        if self.save_dir is not None:
            os.makedirs(self.save_dir, exist_ok=True)

        self.best_eval = None

    def start_epoch(self):
        """Hook called at the start of every epoch; no-op by default."""
        pass

    def finish_epoch(self):
        """Hook called after the updates of an epoch; returns extra log entries."""
        return {}

    def pretrain(self):
        """Hook run once before the first epoch; no-op by default."""
        pass

    def update_per_epoch(self):
        """Hook performing the parameter updates of one epoch; no-op by default."""
        pass

    def snapshot(self, prefix, epoch):
        """Save every network in ``snapshot_networks`` under ``prefix``.

        Files are named ``model_<name>_<epoch>.pth``.
        """
        for name, network in self.snapshot_networks:
            model_file_name="model_{}_{}.pth".format(name, epoch)
            model_path=osp.join(prefix, model_file_name)
            torch.save(network.state_dict(), model_path)

    def train(self,epoch):
        """Run one epoch: collect, update, evaluate, log and snapshot.

        Args:
            epoch: epoch number; ``epoch == 1`` additionally loads
                checkpoints and runs the pretraining hook.

        Returns:
            Whatever ``update_per_epoch`` returns.
        """
        if epoch==1:
            # Warm start from saved checkpoints. This reads the notebook
            # global `index` and expects the subclass to define pf/qf/vf.
            self.pf.load_state_dict(torch.load('/root/metaworld-master/newsoftmodule_24/model50_'+str(index)+'/model_pf_best.pth'))
            self.qf.load_state_dict(torch.load('/root/metaworld-master/newsoftmodule_24/model50_'+str(index)+'/model_qf_best.pth'))
            self.vf.load_state_dict(torch.load('/root/metaworld-master/newsoftmodule_24/model50_'+str(index)+'/model_vf_best.pth'))

            self.pretrain()
            self.total_frames = 0
            if hasattr(self, "pretrain_frames"):
                self.total_frames = self.pretrain_frames

            self.start_epoch()

        self.current_epoch = epoch
        start = time.time()

        self.start_epoch()

        # --- exploration / data collection ---
        explore_start_time = time.time()
        training_epoch_info =  self.collector.train_one_epoch()
        for reward in training_epoch_info["train_rewards"]:
            self.training_episode_rewards.append(reward)
        explore_time = time.time() - explore_start_time

        # --- parameter updates ---
        train_start_time = time.time()
        loss=self.update_per_epoch()
        train_time = time.time() - train_start_time

        finish_epoch_info = self.finish_epoch()

        # --- evaluation ---
        eval_start_time = time.time()
        eval_infos = self.collector.eval_one_epoch()
        eval_time = time.time() - eval_start_time

        self.total_frames += self.collector.active_worker_nums * self.epoch_frames

        infos = {}

        for reward in eval_infos["eval_rewards"]:
            self.episode_rewards.append(reward)

        # Keep a 'best' snapshot whenever the mean evaluation reward improves.
        if self.best_eval is None or \
            np.mean(eval_infos["eval_rewards"]) > self.best_eval:
            self.best_eval = np.mean(eval_infos["eval_rewards"])
            self.snapshot(self.save_dir, 'best')
        # The raw reward list is dropped so it is not merged into `infos` below.
        del eval_infos["eval_rewards"]
        infos["eval_avg_success_rate"] =eval_infos["success"]
        infos["Running_Average_Rewards"] = np.mean(self.episode_rewards)
        infos["Running_success_rate"] =training_epoch_info["train_success_rate"]
        infos["Train_Epoch_Reward"] = training_epoch_info["train_epoch_reward"]
        infos["Running_Training_Average_Rewards"] = np.mean(
            self.training_episode_rewards)
        infos["Explore_Time"] = explore_time
        infos["Train___Time"] = train_time
        infos["Eval____Time"] = eval_time
        infos.update(eval_infos)
        infos.update(finish_epoch_info)

        self.logger.add_epoch_info(epoch, self.total_frames,
            time.time() - start, infos )

        if epoch % self.save_interval == 0:
            self.snapshot(self.save_dir, epoch)
        if epoch==self.num_epochs-1:
            self.snapshot(self.save_dir, "finish")
            self.collector.terminate()
        return loss

    def update(self, batch):
        """Perform one gradient update on ``batch``; must be overridden."""
        raise NotImplementedError

    def _update_target_networks(self):
        """Sync target networks: soft update every call, or a hard copy every
        ``target_hard_update_period`` updates."""
        if self.use_soft_update:
            for net, target_net in self.target_networks:
                atu.soft_update_from_to(net, target_net, self.tau)
        else:
            if self.training_update_num % self.target_hard_update_period == 0:
                for net, target_net in self.target_networks:
                    atu.copy_model_params_from_to(net, target_net)

    @property
    def networks(self):
        """All networks that ``to(device)`` should move."""
        return [
        ]

    @property
    def snapshot_networks(self):
        """``[name, network]`` pairs written by :meth:`snapshot`."""
        return [
        ]

    @property
    def target_networks(self):
        """``(network, target_network)`` pairs kept in sync."""
        return [
        ]

    def to(self, device):
        """Move every network in ``networks`` to ``device``."""
        for net in self.networks:
            net.to(device)


In [None]:
import time
import numpy as np
import math

import torch


class OffRLAlgo(RLAlgo):
    """
    Base class for off-policy algorithms that learn from a replay buffer.
    """
    def __init__(self,

        pretrain_epochs=0,

        min_pool = 0,

        target_hard_update_period = 1000,
        use_soft_update = True,
        tau = 0.001,
        opt_times = 1,

        **kwargs
    ):
        """
        Args:
            pretrain_epochs: number of collection-only epochs run by ``pretrain``.
            min_pool: minimum buffer size checked by ``update_per_timestep``.
            target_hard_update_period: update interval for hard target copies.
            use_soft_update: choose soft instead of hard target updates.
            tau: coefficient passed to the soft target update.
            opt_times: number of gradient updates per call.
            **kwargs: forwarded to ``RLAlgo.__init__``.
        """
        super(OffRLAlgo, self).__init__(**kwargs)

        # environment relevant information
        self.pretrain_epochs = pretrain_epochs

        # target_network update information
        self.target_hard_update_period = target_hard_update_period
        self.use_soft_update = use_soft_update
        self.tau = tau

        # training information
        self.opt_times = opt_times
        self.min_pool = min_pool

        # Fields requested from the replay buffer for every batch.
        self.sample_key = [ "obs", "next_obs", "acts", "rewards", "terminals" ]

    def update_per_timestep(self):
        """Run ``opt_times`` updates once the buffer holds enough samples."""
        if self.replay_buffer.num_steps_can_sample() > max( self.min_pool, self.batch_size ):
            for _ in range( self.opt_times ):
                batch = self.replay_buffer.random_batch(self.batch_size, self.sample_key)
                infos = self.update( batch )
                self.logger.add_update_info( infos )

    def update_per_epoch(self):
        """Run ``opt_times`` updates and return the mean policy loss.

        Unlike ``update_per_timestep`` this does not check the buffer size.
        """
        loss=[]
        for _ in range( self.opt_times ):
            batch = self.replay_buffer.random_batch(self.batch_size, self.sample_key)
            infos = self.update( batch )
            loss.append(infos['Training/policy_loss'])
            self.logger.add_update_info( infos )
        return np.mean(loss)

    def pretrain(self):
        """Collect data for ``pretrain_epochs`` epochs without updating.

        Sets ``self.pretrain_frames`` to the number of frames collected.
        """
        # A bare product expression that used to follow this line computed a
        # value and discarded it; it has been removed.
        total_frames = 0

        for pretrain_epoch in range( self.pretrain_epochs ):

            start = time.time()

            self.start_epoch()

            training_epoch_info =  self.collector.train_one_epoch()
            for reward in training_epoch_info["train_rewards"]:
                self.training_episode_rewards.append(reward)

            finish_epoch_info = self.finish_epoch()

            total_frames += self.collector.active_worker_nums * self.epoch_frames

            infos = {}

            infos["Train_Epoch_Reward"] = training_epoch_info["train_epoch_reward"]
            infos["Running_Training_Average_Rewards"] = np.mean(self.training_episode_rewards)
            infos.update(finish_epoch_info)

            self.logger.add_epoch_info(pretrain_epoch, total_frames, time.time() - start, infos, csv_write=False )

        self.pretrain_frames = total_frames

        self.logger.log("Finished Pretrain")


In [None]:
import time
import numpy as np
import copy

import torch
import torch.optim as optim
from torch import nn as nn


class SAC(OffRLAlgo):
    """
    SAC with a policy (pf), a Q-function (qf), a value function (vf) and a
    target copy of vf.

    The policy step in ``update`` is coupled to notebook globals (``rloss``,
    ``index``, ``pfs``, ``envs``, ``i_episode``, ``currindex``, ``B``,
    ``lossw``), so this class only works inside this notebook's training loop.
    """
    def __init__(
            self,
            pf, vf, qf,
            plr,vlr,qlr,
            optimizer_class=optim.Adam,

            policy_std_reg_weight=1e-3,
            policy_mean_reg_weight=1e-3,

            reparameterization = True,
            automatic_entropy_tuning = True,
            target_entropy = None,
            **kwargs
    ):
        """
        Args:
            pf, vf, qf: policy, value and Q networks.
            plr, vlr, qlr: learning rates for pf, vf and qf (plr is also used
                for the entropy coefficient).
            optimizer_class: optimiser constructor used for all parameters.
            policy_std_reg_weight: weight of the squared log-std penalty.
            policy_mean_reg_weight: weight of the squared mean penalty.
            reparameterization: selects which policy loss form is used.
            automatic_entropy_tuning: learn ``log_alpha`` when true.
            target_entropy: entropy target; defaults to minus the action
                dimensionality when falsy.
            **kwargs: forwarded to ``OffRLAlgo.__init__``.
        """
        super(SAC, self).__init__(**kwargs)
        self.pf = pf
        self.qf = qf
        self.vf = vf
        self.target_vf = copy.deepcopy(vf)
        self.to(self.device)

        self.plr = plr
        self.vlr = vlr
        self.qlr = qlr

        self.qf_optimizer = optimizer_class(
            self.qf.parameters(),
            lr=self.qlr,
        )

        self.vf_optimizer = optimizer_class(
            self.vf.parameters(),
            lr=self.vlr,
        )

        self.pf_optimizer = optimizer_class(
            self.pf.parameters(),
            lr=self.plr,
        )

        self.automatic_entropy_tuning = automatic_entropy_tuning
        if self.automatic_entropy_tuning:
            if target_entropy:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.env.action_space.shape).item()  # from rlkit
            self.log_alpha = torch.zeros(1).to(self.device)
            self.log_alpha.requires_grad_()
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=self.plr,
            )

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()

        self.policy_std_reg_weight = policy_std_reg_weight
        self.policy_mean_reg_weight = policy_mean_reg_weight

        self.reparameterization = reparameterization

    def update(self, batch):
        """Run one SAC update on ``batch`` and return a dict of log values.

        Args:
            batch: mapping with the keys 'obs', 'acts', 'next_obs', 'rewards'
                and 'terminals'.

        Side effects: steps all optimisers, updates the target network and
        writes the current policy loss into the global ``rloss[index]``.
        """
        self.training_update_num += 1

        obs = batch['obs']
        actions = batch['acts']
        next_obs = batch['next_obs']
        rewards = batch['rewards']
        terminals = batch['terminals']

        rewards = torch.Tensor(rewards).to( self.device )
        terminals = torch.Tensor(terminals).to( self.device )
        obs = torch.Tensor(obs).to( self.device )
        actions = torch.Tensor(actions).to( self.device )
        next_obs = torch.Tensor(next_obs).to( self.device )

        """
        Policy operations.
        """
        sample_info = self.pf.explore(obs, return_log_probs=True )

        mean        = sample_info["mean"]
        log_std     = sample_info["log_std"]
        new_actions = sample_info["action"]
        log_probs   = sample_info["log_prob"]
        ent         = sample_info["ent"]

        q_pred = self.qf([obs, actions])
        v_pred = self.vf(obs)

        if self.automatic_entropy_tuning:
            """
            Alpha Loss
            """
            alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha = 1
            alpha_loss = 0

        """
        QF Loss
        """
        target_v_values = self.target_vf(next_obs)
        q_target = rewards + (1. - terminals) * self.discount * target_v_values
        qf_loss = self.qf_criterion( q_pred, q_target.detach())

        """
        VF Loss
        """
        q_new_actions = self.qf([obs, new_actions])
        v_target = q_new_actions - alpha * log_probs
        vf_loss = self.vf_criterion( v_pred, v_target.detach())

        """
        Policy Loss
        """
        if not self.reparameterization:
            log_policy_target = q_new_actions - v_pred
            policy_loss = (
                log_probs * ( alpha * log_probs - log_policy_target).detach()
            ).mean()
        else:
            policy_loss = ( alpha * log_probs - q_new_actions).mean()

        std_reg_loss = self.policy_std_reg_weight * (log_std**2).mean()
        mean_reg_loss = self.policy_mean_reg_weight * (mean**2).mean()

        policy_loss += std_reg_loss + mean_reg_loss

        """
        Update Networks
        """
        self.pf_optimizer.zero_grad()

        # Publish this task's policy loss in the shared per-task loss list.
        # `rlosscopy` is a plain-float copy used for the statistics below.
        rloss[index] = policy_loss.clone()
        rlosscopy=rloss.copy()
        rlosscopy[index] =rlosscopy[index].detach().item()
        # 3-sigma band around the mean of the per-task losses.
        low=np.array(rlosscopy).mean()-3*np.array(rlosscopy).std()
        high=np.array(rlosscopy).mean()+3*np.array(rlosscopy).std()
        # Chained comparison `r < n_out > thr`, i.e. (r < n_out) and
        # (n_out > thr), where n_out counts losses outside the band.
        # NOTE(review): the intent of this gate is not evident from the code.
        if np.random.random()<sum(np.array(rlosscopy)<low)+sum(np.array(rlosscopy)>high)>len(envs)/len(envs)+np.exp(-i_episode/1000)+np.exp(-np.array(rlosscopy).mean()*40):
            # Stack every task's policy parameters, one tensor per state_dict
            # key. Only this branch uses the stack, so it is built here rather
            # than on every update.
            w=[]
            for key in pfs[0].state_dict().keys():
                w.append(torch.cat([pfs[j].state_dict()[key].unsqueeze(0) for j in range(len(envs))]))
            pre=rloss[index]+lossw(currindex,index,rloss,w,B)/10
        else:
            pre=rloss[index]
        # compute gradients

        pre.backward()

        # train the NN

        self.pf_optimizer.step()
        # Store the loss back as a float so the shared list holds no graph.
        rloss[index]=rloss[index].detach().item()

        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()

        self._update_target_networks()

        # Information For Logger
        info = {}
        info['Reward_Mean'] = rewards.mean().item()

        if self.automatic_entropy_tuning:
            info["Alpha"] = alpha.item()
            info["Alpha_loss"] = alpha_loss.item()
        info['Training/policy_loss'] = policy_loss.item()
        info['Training/vf_loss'] = vf_loss.item()
        info['Training/qf_loss'] = qf_loss.item()

        info['log_std/mean'] = log_std.mean().item()
        info['log_std/std'] = log_std.std().item()
        info['log_std/max'] = log_std.max().item()
        info['log_std/min'] = log_std.min().item()

        # These entries previously repeated the log_std statistics.
        info['log_probs/mean'] = log_probs.mean().item()
        info['log_probs/std'] = log_probs.std().item()
        info['log_probs/max'] = log_probs.max().item()
        info['log_probs/min'] = log_probs.min().item()

        info['mean/mean'] = mean.mean().item()
        info['mean/std'] = mean.std().item()
        info['mean/max'] = mean.max().item()
        info['mean/min'] = mean.min().item()

        return info

    @property
    def networks(self):
        """All networks moved by ``to(device)``."""
        return [
            self.pf,
            self.qf,
            self.vf,
            self.target_vf
        ]

    @property
    def snapshot_networks(self):
        """``[name, network]`` pairs written to disk by ``snapshot``."""
        return [
            ["pf", self.pf],
            ["qf", self.qf],
            ["vf", self.vf]
        ]

    @property
    def target_networks(self):
        """``(network, target_network)`` pairs kept in sync."""
        return [
            ( self.vf, self.target_vf )
        ]


In [None]:

# Per-task containers: one policy, Q-network, value network and SAC agent per
# environment, plus the next epoch number to pass to each agent's train().
pfs=[]
qf1s=[]
vfs=[]
agents=[]
epochs=[1 for i in range(len(envs))]
for index in range(len(envs)):
    print(index)
    env=SingleWrapper(envs[index])
    # Reload the config for every task so each agent gets its own params dict.
    params = get_params(args.config)
    params['general_setting']['logger'] =  Logger(
            'mt50', str(index), args.seed, params, './log/mt50_'+str(index)+'/')
    params['env_name']=str(index)
    params['general_setting']['env'] = env

    # `buffer_param` was read from the config in the earlier setup cell.
    replay_buffer = BaseReplayBuffer(
        max_replay_buffer_size=int(buffer_param['size'])#,
    #    time_limit_filter=buffer_param['time_limit_filter']
    )
    params['general_setting']['replay_buffer'] = replay_buffer


    params['general_setting']['device'] = device

    params['net']['base_type'] = networks.MLPBase
    params['net']['activation_func'] = torch.nn.ReLU
    
    
    

    # The policy outputs 2 * action_dim values.
    pf = policies.GuassianContPolicy(
        input_shape=env.observation_space.shape[0], 
        output_shape=2 * env.action_space.shape[0],
        **params['net'],
        **params['sac'])

    qf1 = networks.QNet(
        input_shape=env.observation_space.shape[0] + env.action_space.shape[0],
        output_shape=1,
        **params['net'])

    # NOTE(review): this passes the full shape tuple, while pf/qf1 above pass
    # shape[0] — confirm which form networks.Net expects.
    vf = networks.Net(
            input_shape=env.observation_space.shape,
            output_shape=1,
            **params['net']
        )
    pfs.append(pf)
    qf1s.append(qf1)
    vfs.append(vf)
    params['general_setting']['collector'] = BaseCollector(
        env=env, pf=pf,
        replay_buffer=replay_buffer, device=device,
        train_render=False,
        **params["collector"]
    )
    params['general_setting']['save_dir'] = osp.join(
        './log/', "model50_"+str(index))
    # params["sac"] is passed both to the policy above and to the agent here.
    agent = SAC(
            pf=pf,
            qf=qf1,plr=3e-4,vlr=3e-4,qlr=3e-4,
            vf=vf,
            **params["sac"],
            **params["general_setting"]
        )
    agents.append(agent)


In [None]:
# Warm-start every agent's policy, Q and value networks from the saved
# "best" checkpoints of the corresponding task directory.
for index in range(len(envs)):
    checkpoint_dir = '/root/metaworld-master/newsoftmodule_24/model50_'+str(index)
    loaded_agent = agents[index]
    for checkpoint_net, checkpoint_tag in (
            (loaded_agent.pf, 'pf'),
            (loaded_agent.qf, 'qf'),
            (loaded_agent.vf, 'vf')):
        checkpoint_net.load_state_dict(
            torch.load(checkpoint_dir+'/model_'+checkpoint_tag+'_best.pth'))

In [None]:
import torch
def pss(x,points):
    """Return ``len(points)`` minus the sum of steep tanh steps of ``x - p``.

    Each step is ``tanh(200 * (x - p)) / 2 + 0.5`` for ``p`` in ``points``,
    which is close to 1 when ``x`` is above ``p`` and close to 0 when below.
    """
    def soft_step(value, threshold):
        # torch.tensor(...) builds a new tensor from the difference.
        return torch.tanh(200*torch.tensor(value-threshold))/2+0.5

    step_total = 0
    for threshold in points:
        step_total = step_total + soft_step(x, threshold)
    return len(points)-step_total
import matplotlib.pyplot as plt
import sys
sys.path.insert(0,r'/root/Downloads/metaworld-master/metaworld-master/constopt-pytorch/')
import constopt
from constopt.constraints import LinfBall
from constopt.stochastic import PGD, PGDMadry, FrankWolfe, MomentumFrankWolfe
import torch
from torch.autograd import Variable
import torch.nn.utils as utils
from scipy.stats import rankdata
def loss(rloss,w,B,mu=0.2,lamb=[0.01,0.01,0.01]):
    """Joint objective over all tasks: scaled task losses plus a weight
    reconstruction penalty.

    Reads the global ``envs`` for the number of tasks and uses ``B.T``.
    NOTE(review): ``B.T`` needs an array-like ``B`` while the notebook keeps
    ``B`` as a list of lists, and no visible cell calls this function — confirm
    it is still needed.
    """
    num_tasks = len(envs)

    # Per-task scale: 1 + mu * (L1 norm of row t minus L1 norm of B[t][t]).
    task_scale = torch.tensor([
        1+mu*(np.linalg.norm(B[t],ord=1)-np.linalg.norm(B[t][t],ord=1))
        for t in range(num_tasks)
    ])
    weighted_task_loss = task_scale.dot(rloss)

    # Squared L2 distance between each task's first two weight tensors and
    # their B-weighted combination over all tasks.
    reconstruction = sum([
        sum([
            sum([torch.norm(
                w[i][t]-sum([B.T[t][j]*w[i][j] for j in range(num_tasks)]),
                p=2)**2])
            for i in range(2)
        ])
        for t in range(num_tasks)
    ])
    return weighted_task_loss+lamb[0]*reconstruction
def losst(currindex,t,rloss,w,B,mu=0.2,lamb=[0.01,0.01,0.01],U=[13],pi=list(range(len(envs)))):
    """Score for placing task ``t`` at position ``currindex`` of the ordering.

    Sum of three terms: the task loss scaled by the off-diagonal L1 mass of
    row ``B[t]``, a reconstruction penalty over the tasks in ``U``, and a
    penalty comparing loss ranks with ``pss`` ranks of row ``B[t]``.
    Reads the global ``envs`` for the number of tasks.
    """
    num_tasks = len(envs)

    # Rank of each loss (largest first) after adding 1 to task t's loss.
    bumped_rloss = list(rloss)
    bumped_rloss[t] = bumped_rloss[t]+1
    rlossRank = 1+num_tasks-rankdata(bumped_rloss, method='min')
    points = B[t]

    other_tasks = set(list(range(num_tasks)))-set([t])
    row_l1 = sum([torch.norm(torch.tensor(B[t][i]),p=1) for i in other_tasks])
    scaled_task_loss = (1+mu*row_l1)*rloss[t]

    reconstruction = sum([
        sum([
            sum([torch.norm(
                w[i][s]
                -sum([B[pi[j]][s]*w[i][pi[j]] for j in range(currindex-1)])
                -B[t][s]*w[i][t],
                p=2)**2])
            for i in range(2)
        ])
        for s in U
    ])

    soft_ranks = torch.tensor([pss(torch.tensor(i-0.01),points) for i in points])
    rank_penalty = torch.norm(torch.tensor(rlossRank)-soft_ranks)**2

    return scaled_task_loss+lamb[0]*reconstruction+lamb[2]*rank_penalty
def lossb(currindex,t,rloss,w,B,mu=0.2,lamb=[0.01,0.01,0.01],U=[13],pi=list(range(len(envs)))):
    """Objective used when optimising row ``B[t]``.

    Same three terms as ``losst``, except that the entries ``B[t][i]`` are
    passed to ``torch.norm`` directly (not re-wrapped in ``torch.tensor``).
    Reads the global ``envs`` for the number of tasks.
    """
    num_tasks = len(envs)

    # Rank of each loss (largest first) after adding 1 to task t's loss.
    bumped_rloss = list(rloss)
    bumped_rloss[t] = bumped_rloss[t]+1
    rlossRank = 1+num_tasks-rankdata(bumped_rloss, method='min')
    points = B[t]

    other_tasks = set(list(range(num_tasks)))-set([t])
    row_l1 = sum([torch.norm(B[t][i],p=1) for i in other_tasks])
    scaled_task_loss = (1+mu*row_l1)*rloss[t]

    reconstruction = sum([
        sum([
            sum([torch.norm(
                w[i][s]
                -sum([B[pi[j]][s]*w[i][pi[j]] for j in range(currindex-1)])
                -B[t][s]*w[i][t],
                p=2)**2])
            for i in range(2)
        ])
        for s in U
    ])

    soft_ranks = torch.tensor([pss(torch.tensor(i-0.01),points) for i in points])
    rank_penalty = torch.norm(torch.tensor(rlossRank)-soft_ranks)**2

    return scaled_task_loss+lamb[0]*reconstruction+lamb[2]*rank_penalty
def lossw(currindex,t,rloss,w,B,mu=0.2,lamb=[0.01,0.01,0.01],U=[13],pi=list(range(len(envs)))):
    """Regularised loss for task ``t`` that is added to its policy loss.

    Sum of the scaled task loss, the reconstruction penalty of task ``t``'s
    own first two weight tensors, and the rank penalty. ``U`` is accepted but
    not used. Reads the global ``envs`` for the number of tasks.
    """
    num_tasks = len(envs)

    # Rank of each loss (largest first) after adding 1 to task t's loss.
    bumped_rloss = list(rloss)
    bumped_rloss[t] = bumped_rloss[t]+1
    rlossRank = 1+num_tasks-rankdata(bumped_rloss, method='min')
    points = B[t]

    other_tasks = set(list(range(num_tasks)))-set([t])
    row_l1 = sum([torch.norm(torch.tensor(B[t][i]),p=1) for i in other_tasks])
    scaled_task_loss = (1+mu*row_l1)*rloss[t]

    reconstruction = sum([
        sum([torch.norm(
            w[i][t]-sum([B[pi[j]][t]*w[i][pi[j]] for j in range(currindex-1)]),
            p=2)**2])
        for i in range(2)
    ])

    soft_ranks = torch.tensor([pss(torch.tensor(i-0.01),points) for i in points])
    rank_penalty = torch.norm(torch.tensor(rlossRank)-soft_ranks)**2

    return scaled_task_loss+lamb[0]*reconstruction+lamb[2]*rank_penalty
def lossw2(currindex,t,rloss,w,B,mu=0.2,lamb=[0.01,0.01,0.01],U=[13],pi=list(range(len(envs)))):
    """Variant of ``lossw`` without the rank penalty.

    Here ``rloss`` is used directly as task ``t``'s loss value (it is not
    indexed). ``U`` is accepted but not used. Reads the global ``envs``.
    """
    other_tasks = set(list(range(len(envs))))-set([t])
    row_l1 = sum([torch.norm(torch.tensor(B[t][i]),p=1) for i in other_tasks])
    scaled_task_loss = (1+mu*row_l1)*rloss

    reconstruction = sum([
        sum([torch.norm(
            w[i][t]-sum([B[pi[j]][t]*w[i][pi[j]] for j in range(currindex-1)]),
            p=2)**2])
        for i in range(2)
    ])
    return scaled_task_loss+lamb[0]*reconstruction
#+0*torch.norm(torch.tensor(priors[current])-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2
import torch.optim as optim
# Re-seed torch's RNG with 0 for the constrained-optimisation part.
torch.random.manual_seed(0)

# Optimiser classes tried for each row of B; only Frank-Wolfe is enabled.
OPTIMIZER_CLASSES = [FrankWolfe]# [PGD, PGDMadry, FrankWolfe, MomentumFrankWolfe]
# Radius passed to the LinfBall constraint in setup_problem().
radius=0.05

def setup_problem(make_nonconvex=False):
    """Return the objective and constraint used to optimise a row of ``B``.

    Args:
        make_nonconvex: accepted but not used.

    Returns:
        ``(lossb, LinfBall(radius))`` where ``radius`` is the module-level
        constant.
    """
    return lossb, LinfBall(radius)


def optimize(loss_func, constraint, optimizer_class, iterations=100):
    """Optimise the off-diagonal entries of row ``B[t]`` in place.

    Besides its arguments this function reads the notebook globals ``t``,
    ``B``, ``envs``, ``currindex``, ``rloss``, ``w`` and ``U``.

    Args:
        loss_func: objective called as
            ``loss_func(currindex, t, rloss, w, B, U=...)``.
        constraint: constraint object handed to every optimiser; its
            ``alpha`` attribute is used for the PGD step sizes.
        optimizer_class: optimiser class instantiated once per entry.
        iterations: number of optimisation steps.

    Returns:
        ``(losses, iterates)``: the loss recorded at every step plus a final
        detached evaluation, and the row values before the first step and
        after every step.
    """
    # Turn every off-diagonal entry of the row into a tensor that tracks gradients.
    for i in range(len(envs)):
        if i!=t:
            B[t][i] =torch.tensor(B[t][i],requires_grad=True)
    # One optimiser per off-diagonal entry (len(envs) - 1 in total).
    optimizer = [optimizer_class([B[t][i]], constraint) for i in set(list(range(len(envs))))-set([t])]
    iterates = [[B[t][i].data if i!=t else B[t][i] for i in range(len(envs))]]
    losses = []
    # Use Madry's heuristic for step size
    step_size = {
        FrankWolfe.name: None,
        MomentumFrankWolfe.name: None,
        PGD.name: 2.5 * constraint.alpha / iterations * 2.,
        PGDMadry.name: 2.5 * constraint.alpha / iterations
    }

    for _ in range(iterations):
        for i in range(len(envs)-1):
            optimizer[i].zero_grad()
        # Task t itself is excluded from U for the objective.
        loss = loss_func(currindex,t,rloss,w,B,U=list(set(U)-set(list([t]))))
        loss.backward(retain_graph=True)
        for i in  range(len(envs)-1):
            optimizer[i].step(step_size[optimizer[i].name])
        # Keep every optimised entry inside [0, 100].
        for i in set(list(range(len(envs))))-set([t]):
            B[t][i].data.clamp_(0,100)
        # NOTE(review): the appended loss is not detached, so the list keeps
        # a reference to each step's autograd graph.
        losses.append(loss)
        iterates.append([B[t][i].data if i!=t else B[t][i] for i in range(len(envs))])
    loss = loss_func(currindex,t,rloss,w,B,U=list(set(U)-set(list([t])))).detach()
    losses.append(loss)
    # Store the optimised row back using each entry's `.data` tensor.
    B[t]=[B[t][i].data if i!=t else B[t][i] for i in range(len(envs))]
    return losses, iterates
# Candidate hyper-parameter settings; per the commented-out unpacking below,
# each row is [mu, lamb[0], lamb[1], lamb[2]].
paras=[[0.01,0.01,0.01,0.01],
[1,0.01,0.01,0.01],
[1,0.01,0.01,0.02],
[0.01,0.01,0,0.1],
[0.01,0.01,0,0.05],
[0.5,0.01,0.01,0.01],
[1,0.01,0.01,0.1],
[0.01,0.01,0,1],
[0.01,0.01,0,0],
[0.1,0.01,0,0.01],
[0.2,0.01,0,0.05],
[0.1,0.01,0,0.05],
[0.01,0.01,0.01,0.5],
[1,0.01,0.02,0.02],
[1,0.01,0.02,0.01],
[0.1,0.01,0,0.2],
[1,0.01,0.02,0.05]]
# `para` is selected here, but mu/lamb are set to fixed values on the next line.
para=paras[5]
mu,lamb=0.01,[0.01,0.01,0.01]#para[0],[para[1],para[2],para[3]]
# Shared per-task policy losses; SAC.update reads and writes rloss[index].
rloss=[0.0 for i in range(len(envs))]
# Per-task record lists; nothing in the visible cells appends to them.
rewardsRec=[[] for i in range(len(envs))]
rewardsRec_nor=[[0] for i in range(len(envs))]
succeessRec=[[] for i in range(len(envs))]
TotalRewardRec=[]
# B starts as the identity matrix, stored as a list of row lists.
B=[list(i) for i in np.diag(np.ones(len(envs)))]


In [None]:
# Main training loop. Each round either runs the ordered multi-task step
# (choose a task order, optimise the matching row of B, train that task) or
# trains every task once in index order.
for i_episode in range(10000):
    rlosscopy=rloss.copy()
    # 3-sigma band around the mean of the per-task losses.
    low=np.array(rlosscopy).mean()-3*np.array(rlosscopy).std()
    high=np.array(rlosscopy).mean()+3*np.array(rlosscopy).std()
    # Chained comparison `r < n_out > thr`, i.e. (r < n_out) and (n_out > thr),
    # where n_out counts losses outside the band.
    # NOTE(review): the intent of this gate is not evident from the code.
    if np.random.random()<sum(np.array(rlosscopy)<low)+sum(np.array(rlosscopy)>high)>len(envs)/len(envs)+np.exp(-i_episode/1000)+np.exp(-np.array(rlosscopy).mean()*40):
        p = np.random.random()
        # roll = np.random.randint(2)
        length = 0
        # Stack every task's policy parameters, one tensor per state_dict key.
        w=[]
        for key in pfs[0].state_dict().keys():
            w.append(torch.cat([pfs[j].state_dict()[key].unsqueeze(0) for j in range(len(envs))]))            
        
        # if np.random.random()<sum(np.array(rloss)<low)+sum(np.array(rloss)>high)>len(envs)/len(envs)+np.exp(-rnd/1000)+np.exp(-np.array(rloss).mean()*40):
        #    multitask=True
        #else:
         #   multitask=False
        #rloss=torch.tensor([0 for i in range(len(envs))])
        #update pi
        # U holds the tasks not yet placed; pi records the chosen order.
        U=list(range(len(envs)))
        pi=[0 for i in range(len(envs))]
        for currindex in range(len(envs)):
            # Map each candidate's losst value to its task and take the minimum.
            # Candidates with identical values overwrite each other in the dict.
            indexdict={}
            for t in U:
                indexdict[losst(currindex,t,torch.tensor(rloss),w,B,mu=mu,lamb=lamb,U=list(set(U)-set(list([t])))).item()]=t
            t=indexdict[min(indexdict.keys())]
            pi[currindex]=t
            U=list(set(U)-set(list([pi[currindex]])))  
            #update b
            # optimize() reads t, currindex, U, rloss, w and B as globals.
            loss_func, constraint = setup_problem(make_nonconvex=True)
            iterations = 10
            for opt_class in OPTIMIZER_CLASSES:
                losses_, iterates_ = optimize(loss_func,
                                              constraint,
                                              opt_class,
                                              iterations)
      #      B[t]=torch.tensor(B[t],requires_grad=True)
       #     optimizer=torch.optim.Adam([B[t]],lr=1e-2)
        #    for step in range(5):
         #       pre=lossb(currindex,t,torch.tensor(rloss),w,B,U=list(set(U)-set(list([t]))))
          #      optimizer.zero_grad()
           #     pre.backward(retain_graph=True)
            #    optimizer.step()
            #B[t]=B[t].detach().numpy()


            # Train the selected task for one epoch; SAC.update reads the
            # global `index` set here.
            env_id=t
            index=t
            env=envs[index]
            print('-------------------------------------------------------------------------------',i_episode,index)
            agents[index].train(epochs[index])
            epochs[index]+=1
            np.save('B_oursall0.01_mt50.npy',B)
    else:
        
        # Plain round: train every task once, in index order.
        for index, env in enumerate(envs):
            print('-------------------------------------------------------------------------------',i_episode,index)
            agents[index].train(epochs[index])
            epochs[index]+=1
        np.save('B_oursall0.01_mt50.npy',B)
        
# NOTE(review): the loop above has no `break`, so i_episode is 9999 here and
# this message can never be printed — confirm the intended check.
if i_episode == 5999:
    print('Well that failed')
