In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim as optim
import random
import os
import glob
from torch.distributions import Beta
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

from torch.autograd import Function
from torch.utils.data import SubsetRandomSampler, BatchSampler
import matplotlib.pyplot as plt

In [None]:
import gym
from IPython import display
import matplotlib.pyplot as plt
%matplotlib inline

import env
import DANN

In [None]:
# Run tensors and the model on the GPU when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Make sure the output directory exists (no error if it is already there).
os.makedirs("./output_r", exist_ok=True)

# Start from a clean slate: delete PNGs left in the output directory.
# One pass is enough -- nothing writes to the directory while this cell runs.
for f in glob.glob("./output_r/*.png"):
    os.remove(f)

# Environment Testing

In [None]:
# Build one environment per colour scheme ('g', 'c1', 'c2'), all with seed 0.
# Construction order is the same as the order of the names on the left.
source_env, unseen_1_env, unseen_2_env = (
    env.Env(color=scheme, seed=0) for scheme in ('g', 'c1', 'c2')
)

# Lookup table: discrete action id -> 3-component command array.
# Per the row labels, the components are [steering, acceleration, brake].
discrete_actions = dict(enumerate(map(np.array, [
    [0, 0, 0],       # 0: do nothing
    [-1, 0, 0],      # 1: steer sharp left
    [1, 0, 0],       # 2: steer sharp right
    [-0.5, 0, 0],    # 3: steer left
    [0.5, 0, 0],     # 4: steer right
    [0, 1, 0],       # 5: accelerate 100%
    [0, 0.5, 0],     # 6: accelerate 50%
    [0, 0.25, 0],    # 7: accelerate 25%
    [0, 0, 1],       # 8: brake 100%
    [0, 0, 0.5],     # 9: brake 50%
    [0, 0, 0.25],    # 10: brake 25%
])))

def get_obs(env):
    """
    Drive an environment with random discrete actions and return one frame for preview.

    Takes 30 steps. Each step draws an action id uniformly from the ids of
    `discrete_actions` (assumed to be the contiguous range 0..len-1, as in the
    table defined in this notebook) and passes the matching command to `env.step`.

    :param env: object whose step(action) returns a 4-tuple starting with the observation
    :return: (obs[0] + 1) / 2.0, where obs is the observation from the last step
    """
    # Bound the sampling by the table size instead of a hard-coded 11 so the
    # two cannot drift apart when actions are added or removed.
    n_actions = len(discrete_actions)

    for _ in range(30):
        action_id = int(torch.randint(low=0, high=n_actions, size=(1,))[0])
        steer, gas, brake = discrete_actions[action_id]

        # The environment is handed a plain 3-element list, as before.
        obs, _reward, _done, _info = env.step([steer, gas, brake])

    # obs[0] is the first entry of the returned observation.
    # presumably values are in [-1, 1] and this maps them to [0, 1] for imshow -- TODO confirm
    return (obs[0] + 1) / 2.0

# Collect one preview frame per environment, in source / unseen-1 / unseen-2 order.
env_preview = [
    get_obs(preview_env)
    for preview_env in (source_env, unseen_1_env, unseen_2_env)
]

# One row of three panels, one panel per environment.
f, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
axs = axs.flatten()

for ax, img in zip(axs, env_preview):
    ax.imshow(img)

plt.show()

# DANN

In [None]:
# NOTE(review): `DANN` is bound by `import DANN` in the import cell, i.e. it is a
# module object here. If the model class lives inside that module, this call needs
# the qualified name (DANN.<ClassName>(...)) -- confirm against DANN.py.
#
# `.double()` matches the float64 tensors used throughout this notebook.
# `.to(device)` places the model on the device selected in the setup cell, so the
# notebook also runs on machines without CUDA (a bare `.cuda()` would raise there).
net = DANN(num_out = 2).double().to(device)

# Generate Random Buffer

In [None]:
def get_random(state):
    """
    Recolour the background of every frame in a stacked observation, in place.

    A random subset of the three colour channels is chosen once per call, together
    with a random scale in [0.5, 1.1] for each chosen channel; the same choice is
    applied to all frames of `state`. Within each frame:

    * pixels with G - 0.6 * R - 0.4 * B == 0 are treated as road and copied unchanged;
    * pixels whose channel value equals one of the two green levels
      (204 / 128 - 1 or 230 / 128 - 1) form the background layer, and its green
      channel is the base intensity that gets rescaled;
    * a chosen channel becomes that base intensity times its scale; a channel that
      was not chosen becomes -1 on background pixels and 0 elsewhere.

    The masks are computed on the tensor itself, so the function works on whatever
    device / dtype / frame size `state` has.

    :param state: tensor indexed as [frame, row, col, channel] with 3 channels
                  (presumably RGB scaled as value / 128 - 1 -- TODO confirm)
    :return: the same `state` tensor, modified in place
    """
    # Keep drawing until at least one channel has received a scale different from 1.
    red_scale, green_scale, blue_scale = 1., 1., 1.
    while (red_scale == green_scale == blue_scale):
        add_green = random.randint(0, 1)
        add_red = random.randint(0, 1)
        add_blue = random.randint(0, 1)
        if (add_red):
            red_scale = random.uniform(0.5, 1.1)
        if (add_green):
            green_scale = random.uniform(0.5, 1.1)
        if (add_blue):
            blue_scale = random.uniform(0.5, 1.1)

    # The two background green levels, in the same scaling as the pixel values.
    light_green = 204 / 128. - 1
    dark_green = 230 / 128. - 1

    for i in range(len(state)):
        # Channels first, so s[0], s[1], s[2] are the three colour planes.
        s = torch.transpose(state[i], 0, 2)

        # Road pixels are kept as they are (mask broadcasts over the 3 channels).
        road_mask = (s[1] - s[0] * 0.6 - s[2] * 0.4) == 0
        road_layer = s * road_mask

        # Background layer: channel values that exactly match one of the green levels.
        light_green_layer = s * ((s - light_green) == 0)
        dark_green_layer = s * ((s - dark_green) == 0)
        bg_layer = light_green_layer + dark_green_layer

        # True where the green plane of the background layer is empty,
        # i.e. where the pixel is NOT background.
        not_background = bg_layer[1] == 0

        # Value for channels that are not recoloured: -1 on background, 0 elsewhere.
        fill = not_background.to(s.dtype) - 1

        red_channel = bg_layer[1] * red_scale if add_red else fill
        green_channel = bg_layer[1] * green_scale if add_green else fill
        blue_channel = bg_layer[1] * blue_scale if add_blue else fill

        new_bg_layer = torch.stack((red_channel, green_channel, blue_channel), 0)
        new_state = new_bg_layer + road_layer

        # Back to channels-last and write into the caller's tensor.
        state[i] = torch.transpose(new_state, 0, 2)

    return state

def get_random_buffer(buffer, batch_size):
    """
    Return a copy of `buffer` whose first `batch_size` entries went through `get_random`.

    The input tensor itself is left untouched; all work happens on a clone.

    :param buffer: tensor whose leading dimension indexes the samples
    :param batch_size: int, number of leading samples to transform
    :return: the transformed clone
    """
    recoloured = buffer.clone()
    sample_idx = 0
    while sample_idx < batch_size:
        sample = recoloured[sample_idx]
        recoloured[sample_idx] = get_random(sample)
        sample_idx += 1
    return recoloured

# Agent

In [None]:
class Agent:
    """
    Actor-critic agent trained with a clipped PPO objective plus a domain-classification loss.

    Transitions are stored in a fixed-size NumPy structured array. Once the array
    is full, `update` runs `ppo_epoch` passes over it. For every mini-batch it also
    builds a recoloured copy of the states (via `get_random_buffer`), which is fed
    through the network as the "target" domain for its domain output.

    The network is expected to return `((alpha, beta), value, domain_out, sketch)`
    when called, and to expose the sub-modules used in `select_action`.
    """

    # Gradient-norm ceiling applied before each optimizer step.
    max_grad_norm = 0.5
    # PPO clipping range: the probability ratio is clamped to [1 - clip_param, 1 + clip_param].
    clip_param = 0.1
    # Number of passes over the buffer in one call to update().
    ppo_epoch = 10

    # Layout of one stored transition: state, action, log-probability of the
    # action, reward and next state.
    transition = np.dtype([
        ('s', np.float64, (4, 96, 96, 3)),
        ('a', np.float64, (3,)),
        ('a_logp', np.float64),
        ('r', np.float64),
        ('s_', np.float64, (4, 96, 96, 3))
    ])

    def __init__(self, net, criterion, optimizer, buffer_capacity=2000, batch_size=128):
        """
        Initialize the Agent.
        
        :param net: neural network for the agent
        :param criterion: loss function applied to the network's domain output
        :param optimizer: optimizer for the network
        :param buffer_capacity: int, capacity of the buffer
        :param batch_size: int, batch size for training
        """
        self.net = net
        self.buffer_capacity = buffer_capacity
        self.batch_size = batch_size
        self.criterion = criterion
        self.optimizer = optimizer

        # Uninitialised storage; every slot is overwritten by store() before
        # update() reads it (store() only signals an update once the buffer is full).
        self.source_buffer = np.empty(self.buffer_capacity, dtype=self.transition)
        self.counter = 0

    def select_action(self, state):
        """
        Select action based on the given state.

        The state is pushed through the network's sub-modules
        (sketch -> feature -> cnn_base -> fc -> alpha/beta heads) without tracking
        gradients, and an action is sampled from the resulting Beta distribution.
        
        :param state: numpy array, state representation
        :return: tuple, (action, action_log_probability)
        """
        # Add a leading batch dimension of size 1.
        state = torch.from_numpy(state).double().to(device).unsqueeze(0)
        with torch.no_grad():
            out = self.net.sketch(state)
            out = torch.squeeze(out)
            out = self.net.feature(out)

            out = self.net.cnn_base(out)
            out = out.view(-1, 256)
            out = self.net.fc(out)
            # Both Beta parameters are the head outputs shifted by +1.
            alpha = self.net.alpha_head(out) + 1
            beta = self.net.beta_head(out) + 1

        dist = Beta(alpha, beta)
        action = dist.sample()
        # Joint log-probability: sum over the action components.
        a_logp = dist.log_prob(action).sum(dim=1)

        action = action.squeeze().cpu().numpy()
        a_logp = a_logp.item()
        return action, a_logp

    def store(self, transition):
        """
        Store the transition in the buffer.

        When the last slot has been written the write position wraps back to 0.
        
        :param transition: tuple, (state, action, action_log_probability, reward, next_state)
        :return: bool, True if the buffer is full, False otherwise
        """
        self.source_buffer[self.counter] = transition
        self.counter += 1
        if self.counter == self.buffer_capacity:
            self.counter = 0
            return True

        return False

    def _prepare_tensors(self):
        """
        Convert every field of the transition buffer to a double tensor on `device`.

        :return: tuple (s, a, r, s_, old_a_logp); r and old_a_logp are column vectors
        """
        s = torch.tensor(self.source_buffer['s'], dtype=torch.double).to(device)
        a = torch.tensor(self.source_buffer['a'], dtype=torch.double).to(device)
        r = torch.tensor(self.source_buffer['r'], dtype=torch.double).to(device).view(-1, 1)
        s_ = torch.tensor(self.source_buffer['s_'], dtype=torch.double).to(device)
        old_a_logp = torch.tensor(self.source_buffer['a_logp'], dtype=torch.double).to(device).view(-1, 1)

        return s, a, r, s_, old_a_logp
    
    def _prepare_domain_labels(self):
        """
        Build per-batch domain labels: 0 for source samples, 1 for target samples.

        :return: tuple (source_domain_label, target_domain_label), each of length batch_size
        """
        source_domain_label = torch.zeros(self.batch_size).long().to(device)
        target_domain_label = torch.ones(self.batch_size).long().to(device)
        return source_domain_label, target_domain_label

    def _compute_advantage(self, s, r, s_):
        """
        Compute one-step value targets and advantages without tracking gradients.

        :param s: states
        :param r: rewards (column vector)
        :param s_: next states
        :return: tuple (target_v, adv) with target_v = r + 0.99 * V(s_) and adv = target_v - V(s)
        """
        with torch.no_grad():
            # Index 1 of the network output is the value estimate.
            target_v = r + 0.99 * self.net(s_)[1]
            adv = target_v - self.net(s)[1]
        return target_v, adv

    def _calculate_losses(self, s, a, old_a_logp, adv, index, eta):
        """
        Compute the combined loss for one mini-batch.

        Reads `self.source_domain_label`, `self.target_domain_label`,
        `self.target_batch` and `self.target_v`, which `update` sets beforehand.

        :param s: all states in the buffer
        :param a: all actions in the buffer
        :param old_a_logp: log-probabilities recorded when the actions were taken
        :param adv: advantages for the whole buffer
        :param index: indices of the samples in this mini-batch
        :param eta: value forwarded to the network as its second argument
        :return: tuple (loss, source_domain_correct, target_domain_correct, s_sketch, t_sketch)
        """
        # Source-domain forward pass: policy parameters, value, domain logits, sketch.
        (alpha, beta), v, domain_out, s_sketch = self.net(s[index], eta)

        source_domain_loss = self.criterion(domain_out, self.source_domain_label)
        source_domain_correct = (torch.argmax(domain_out, dim=1) == self.source_domain_label).sum().item()

        # Target-domain forward pass on the recoloured batch; only the domain
        # output and the sketch are used.
        _, _, domain_out, t_sketch = self.net(self.target_batch, eta)
        target_domain_loss = self.criterion(domain_out, self.target_domain_label)
        target_domain_correct = (torch.argmax(domain_out, dim=1) == self.target_domain_label).sum().item()

        # Clipped PPO surrogate objective.
        dist = Beta(alpha, beta)
        a_logp = dist.log_prob(a[index]).sum(dim=1, keepdim=True)
        ratio = torch.exp(a_logp - old_a_logp[index])
        surr1 = ratio * adv[index]
        surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv[index]
        action_loss = -torch.min(surr1, surr2).mean()
        # NOTE(review): `v` from the forward pass above is not used; the value is
        # recomputed here with a second forward pass that omits `eta`.
        value_loss = F.smooth_l1_loss(self.net(s[index])[1], self.target_v[index])

        loss = action_loss + 2. * value_loss + source_domain_loss + target_domain_loss

        return loss, source_domain_correct, target_domain_correct, s_sketch, t_sketch

    def _update_network(self, loss):
        """
        Apply one optimisation step for the given loss.

        Clears old gradients, back-propagates `loss`, clips the total gradient norm
        of the network parameters to `max_grad_norm`, then steps the optimizer.

        :param loss: scalar tensor attached to the network's autograd graph
        """
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
        self.optimizer.step()

    def _save_images(self, epoch, image_array):
        """
        Save the collected sketch images as one PNG in ./output_r.

        The figure is a fixed 2 x 10 grid; images beyond the 20 panels are ignored
        and unused panels stay empty.

        :param epoch: int, used for the file name (zero-padded to 4 digits)
        :param image_array: iterable of 2-D arrays to display
        """
        f, axs = plt.subplots(2, 10, figsize = (16, 4))
        axs = axs.flatten()
        for img, ax in zip(image_array, axs):
            ax.imshow(img)
        f.savefig('./output_r/%04d.png' % epoch)
        # Close the figure so repeated updates do not accumulate open figures.
        plt.close(f)


    def update(self, epoch, eta=0.1):
        """
        Train the network on the current buffer contents.

        Runs `ppo_epoch` passes over the buffer in random mini-batches. For each
        mini-batch a recoloured target batch is generated, the combined loss is
        computed and one optimisation step is taken. One source/target sketch pair
        per pass is collected and written to disk at the end.

        :param epoch: int, used to name the saved sketch image
        :param eta: value forwarded to the network together with each batch
        :return: tuple (mean_source_acc, mean_target_acc), the domain-classification
                 accuracies averaged over the passes
        """
        # Prepare tensors from the source_buffer
        s, a, r, s_, old_a_logp = self._prepare_tensors()
        self.target_v, adv = self._compute_advantage(s, r, s_)
        self.source_domain_label, self.target_domain_label = self._prepare_domain_labels()

        image_array, source_acc_array, target_acc_array = [], [], []

        for _ in range(self.ppo_epoch):
            total = 0
            source_domain_correct, target_domain_correct = 0, 0
            add_image = True

            # The third argument (drop_last=True) discards a final incomplete batch,
            # so every batch matches the fixed-size domain labels.
            for index in BatchSampler(SubsetRandomSampler(range(self.buffer_capacity)), self.batch_size, True):
                total += self.batch_size
                self.target_batch = get_random_buffer(s[index], self.batch_size)

                loss, src_correct, tgt_correct, s_sketch, t_sketch = self._calculate_losses(s, a, old_a_logp, adv, index, eta)
                source_domain_correct += src_correct
                target_domain_correct += tgt_correct

                # Keep one source/target sketch pair from the first batch of each pass.
                if add_image:
                    image_array.extend([s_sketch[0][0].reshape(96, 96).cpu().detach().numpy(),
                                        t_sketch[0][0].reshape(96, 96).cpu().detach().numpy()])
                    add_image = False

                self._update_network(loss)

            source_acc_array.append(source_domain_correct / total)
            target_acc_array.append(target_domain_correct / total)

        mean_source_acc = np.mean(source_acc_array)
        mean_target_acc = np.mean(target_acc_array)

        self._save_images(epoch, image_array)

        return mean_source_acc, mean_target_acc

# Training

In [None]:
def eval(agent, env):
    """
    Roll out the agent for one episode (at most 1000 steps) and sum the rewards.

    Each action returned by the agent is multiplied by [2, 1, 1] and shifted by
    [-1, 0, 0] before being passed to `env.step_eval`.

    :param agent: Agent, trained agent to be evaluated
    :param env: Environment, environment to evaluate the agent on
    :return: float, total reward received by the agent
    """
    scale = np.array([2., 1., 1.])
    shift = np.array([-1., 0., 0.])

    total_reward = 0
    observation = env.reset()

    for _ in range(1000):
        raw_action, _logp = agent.select_action(observation)
        observation, reward, finished, _info = env.step_eval(raw_action * scale + shift)
        total_reward += reward

        # Stop as soon as the environment reports the episode is over.
        if finished:
            return total_reward

    return total_reward

def train(source_env, target_env, agent):
    """
    Train the agent on the source environment and evaluate it on the target environments.

    Runs 3000 episodes of at most 1000 steps each. Every transition is stored in
    the agent; whenever `agent.store` reports a full buffer, `agent.update` is
    called. After each episode the agent is evaluated once on `target_env[0]` and
    `target_env[1]`. Every 10th episode the scores are printed and a plot of raw
    and moving-average scores is saved to ./output_r.

    :param source_env: Environment, source environment for training the agent
    :param target_env: list, target environments for evaluating the agent
    :param agent: Agent, reinforcement learning agent
    """
    training_records = []
    running_score_records = []
    running_score = 0

    c1_running_score = 0
    c1_training_records = []
    c1_running_score_records = []

    c2_running_score = 0
    c2_training_records = []
    c2_running_score_records = []

    # Passed to the first update only; every later update uses 0.1 (see below).
    eta = 0.2

    for i_ep in range(3000):
        score = 0
        state = source_env.reset()

        for t in range(1000):
            action, a_logp = agent.select_action(state)
            # Scale by [2, 1, 1] and shift by [-1, 0, 0] before stepping the environment.
            state_, reward, done, die = source_env.step(action * np.array([2., 1., 1.]) + np.array([-1., 0., 0.]))
            score += reward

            should_update = agent.store((state, action, a_logp, reward, state_))

            if should_update:
                # NOTE(review): eta_max is computed but never used, so this
                # episode-dependent schedule has no effect -- confirm the intent.
                eta_max = 0.5 if i_ep < 500 else (0.45 if i_ep < 1500 else 0.3)
                print("eta: {:.2f}".format(eta))
                s_acc, t_acc = agent.update(epoch=i_ep, eta=eta)
                eta = 0.1

            state = state_

            if done or die:
                break

        # Record scores and calculate moving averages
        training_records.append(score)
        running_score = running_score * 0.99 + score * 0.01
        running_score_records.append(running_score)

        c1_score = eval(agent, target_env[0])
        c2_score = eval(agent, target_env[1])
        c1_training_records.append(c1_score)
        c2_training_records.append(c2_score)

        c1_running_score = c1_running_score * 0.99 + c1_score * 0.01
        c2_running_score = c2_running_score * 0.99 + c2_score * 0.01
        c1_running_score_records.append(c1_running_score)
        c2_running_score_records.append(c2_running_score)

        # Display progress every 10 episodes.
        # The parentheses matter: `%` binds tighter than `+`, so without them the
        # test is `i_ep + 1 == 0` and this branch never runs.
        if (i_ep + 1) % 10 == 0:
            print('Ep {}\tLast score: {:.2f}\tMoving average score: {:.2f}'.format(i_ep, score, running_score))
            print('c1 score: {:.2f}\t c1 Moving average: {:.2f}'.format(c1_score, c1_running_score))
            print('c2 score: {:.2f}\t c2 Moving average: {:.2f}'.format(c2_score, c2_running_score))

            # Left panel: per-episode scores; right panel: moving averages.
            f, axs = plt.subplots(1, 2, figsize = (16, 8))
            axs[0].plot(range(len(training_records)), training_records)
            axs[0].plot(range(len(c1_training_records)), c1_training_records)
            axs[0].plot(range(len(c2_training_records)), c2_training_records)
                
            axs[1].plot(range(len(running_score_records)), running_score_records)
            axs[1].plot(range(len(c1_running_score_records)), c1_running_score_records)
            axs[1].plot(range(len(c2_running_score_records)), c2_running_score_records)
                
            axs[0].legend(['s', 'c1', 'c2'])
            axs[1].legend(['s', 'c1', 'c2'])

            f.savefig('./output_r/result_%04d.png' % i_ep)
            plt.close(f)