In [1]:
# Import statements.
import numpy as np
import random as rand
import torch
import math
import matplotlib.pyplot as plt
from ExperimentManager import Experiment
from torch import nn
import torch.nn.functional as F
import torch.distributions as tdist
%matplotlib inline

In [2]:
manager = Experiment.start_experiment('experimentsDiscriminator/', 'experiment', print)

Please enter a brief description of this experiment:
Hypers: (40, 1, agent_params, 1/10, 1/5, 1000)


In [3]:
# A single LSTM cell.
class LSTM_CELL(nn.Module):
    
    # Constructor.
    def __init__(self, input_size, cell_size, hidden_size, t_device):
        super().__init__()
        self.input_size = input_size
        self.cell_size = cell_size
        self.hidden_size = hidden_size
        self.t_device = t_device
        self.cell_forget_gate = nn.Linear(input_size + hidden_size, cell_size)
        self.cell_update_gate_sigmoid = nn.Linear(input_size + hidden_size, cell_size)
        self.cell_update_gate_tanh = nn.Linear(input_size + hidden_size, cell_size)
        self.hidden_dim_reduce = nn.Linear(input_size + hidden_size, hidden_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
    
    # Forward propogates the input through the cell and produces a new cell and hidden state.
    def forward(self, x, cell_state, hidden_state):
        x = torch.cat([hidden_state, x], dim=-1)
        cell_state *= self.sigmoid(self.cell_forget_gate(x))
        cell_state += self.sigmoid(self.cell_update_gate_sigmoid(x)) + self.tanh(self.cell_update_gate_tanh(x))
        hidden_state = self.sigmoid(self.hidden_dim_reduce(x)) * self.tanh(cell_state)
        return cell_state, hidden_state
        

In [4]:
# LSTM based discriminator.
class DISCRIMINATOR(nn.Module):
    
    # The input thrown into the LSTM will be the concat of the distribution and state.
    
    # Constructor.
    def __init__(self, input_size, cell_size, hidden_size, distribution_size, t_device):
        super().__init__()
        self.input_size = input_size
        self.distribution_size = distribution_size
        self.t_device = t_device
        self.lstm_cell = LSTM_CELL(input_size + distribution_size, cell_size, hidden_size, t_device)
        self.hidden_1 = nn.Linear(hidden_size, int(hidden_size * (3/4)))
        self.hidden_2 = nn.Linear(int(hidden_size * (3/4)), int(hidden_size * (1/2)))
        self.hidden_3 = nn.Linear(int(hidden_size * (1/2)), int(hidden_size * (1/4)))
        self.output = nn.Linear(int(hidden_size * (1/4)), 1)
        self.sigmoid = nn.Sigmoid()
        self.relu = F.relu
        self.act_func = self.relu
        
    # Forward progogates the provided input through the network and returns the corresponding labels and inner states.
    # Cell and hidden states should be pre-stacked to the correct sizing to match the state and distribution batch size.
    # The inner states will be a stack of tensors with sizing equal to the batch sizing.
    def forward(self, states, distributions, cell_states, hidden_states, training=False):
        outputs = []
        for index in range(len(states)):
            x = torch.cat([states[index], distributions[index]], dim=-1)
            cell_state, hidden_state = self.lstm_cell(x, cell_states, hidden_states)
            x = self.act_func(self.hidden_1(hidden_state))
            x = self.act_func(self.hidden_2(x))
            x = self.act_func(self.hidden_3(x))
            x = self.output(x)
            if not training:
                x = self.sigmoid(x)
            outputs.append(x)
        return outputs, cell_states, hidden_states

In [5]:
# Generates network weights.
def generate_weights(starting_size, ending_size, weights_needed):
    difference = (starting_size - ending_size) / (weights_needed + 1)
    weights = []
    for i in range(weights_needed):
        weights.append(int(starting_size - (difference * (i+1))))
    return weights

In [6]:
# Simple feedforward generator.
class GENERATOR(nn.Module):
    
    # Constructor.
    def __init__(self, input_size, distribution_size, t_device):
        super().__init__()
        self.input_size = input_size
        self.distribution_size = distribution_size
        self.t_device = t_device
        self.sin = torch.sin
        self.sigmoid = nn.Sigmoid()
        weights = generate_weights(self.input_size, self.distribution_size, 2)
        self.hidden_1 = nn.Linear(input_size, weights[0])
        self.hidden_2 = nn.Linear(weights[0], weights[1])
        self.output = nn.Linear(weights[1], distribution_size)
        
    # Forward propogate input states.
    def forward(self, x):
        x = self.sin(self.hidden_1(x))
        x = self.sin(self.hidden_2(x))
        return self.sigmoid(self.output(x))

In [7]:
# Agent that combines a discriminator and generator to play a game.
class AGENT:
    
    # Constructor.
    # learning_rates[0] = discriminator learning rate & learning_rate[1] = generator learning_rate.
    def __init__(self, name, learning_rates, input_size, hidden_size, action_size, t_device, s_device):
        self.name = name
        self.learning_rates = learning_rates
        self.input_size = input_size
        self.cell_size = hidden_size
        self.hidden_size = hidden_size
        self.action_size = action_size
        self.t_device = t_device
        self.s_device = s_device
        self.discriminator = DISCRIMINATOR(input_size, self.cell_size, hidden_size, action_size, t_device)
        self.generator = GENERATOR(input_size, action_size, t_device)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=learning_rates[0])
        self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=learning_rates[1])
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.age = 0
        self.threshold = rand.uniform(0.5, 1)
    # Trains the generator to maximize the discriminator's output on the provided sequence of states.
    def train_generator(self, states, cell_state, hidden_state, epochs):
        self.generator_optimizer.zero_grad()
        for e in range(epochs):
            current_cell_state = cell_state
            current_hidden_state = hidden_state
            loss = 0
            for state in states:
                distribution = self.generator(state)
                out, current_cell_state, current_hidden_state = self.discriminator([state], [distribution], current_cell_state, current_hidden_state, True)
                loss += self.bce_loss(out[0], torch.ones(1))
            loss.backward(retain_graph=True)
            self.generator_optimizer.step()
            
    # Trains the discriminator on the provided trajectories.
    # Trajectories should be of the form (label, groups)
    # Groups should just be a list of sequential gameplay generated pairings of size = unroll_size.
    def train_discriminator(self, trajectories, unroll_size, batch_size=64, epochs=50, extra_info=''):
        self.discriminator_optimizer.zero_grad()
        batches = []
        states = [[] for _ in range(unroll_size)]
        distributions_positive = [[] for _ in range(unroll_size)]
        distributions_negative = [[] for _ in range(unroll_size)]
        rand.shuffle(trajectories)
        for trajectory in trajectories:
            groups = trajectory
            for i in range(len(groups)):
                states[i].append(groups[i][0])
                distributions_positive[i].append(groups[i][1])
                distributions_negative[i].append(torch.rand(self.action_size))
            if len(states[0]) >= batch_size:
                states = [torch.stack(state) for state in states]
                distributions_positive = [torch.stack(dist) for dist in distributions_positive]
                distributions_negative = [torch.stack(dist) for dist in distributions_negative]
                batches.append((states, distributions_positive, distributions_negative, len(states[0])))
                states = [[] for _ in range(unroll_size)]
                distributions_positive = [[] for _ in range(unroll_size)]
                distributions_negative = [[] for _ in range(unroll_size)]
        if len(states[0]) > 0:
            states = [torch.stack(state) for state in states]
            distributions_positive = [torch.stack(dist) for dist in distributions_positive]
            distributions_negative = [torch.stack(dist) for dist in distributions_negative]
            batches.append((states, distributions_positive, distributions_negative, len(states[0])))
        losses = []
        for e in range(epochs):
            losses = []
            for b in range(len(batches)):
                batch = batches[b]
                states = batch[0]
                positive_distributions = batch[1]
                negative_distributions = batch[2]
                actual_size = batch[3]
                # These two assignments might need to be messed with to get everything working.
                cell_state = torch.zeros(actual_size, self.cell_size)
                hidden_state = torch.zeros(actual_size, self.hidden_size)
                positive_outputs, _, _ = self.discriminator(states, positive_distributions, cell_state, hidden_state, True)
                negative_outputs, _, _ = self.discriminator(states, negative_distributions, cell_state, hidden_state, True)
                loss = 0
                for o in range(len(positive_outputs)):
                    loss += self.bce_loss(positive_outputs[o], torch.ones(actual_size, 1)) + self.bce_loss(negative_outputs[o], torch.zeros(actual_size, 1))
                loss.backward()
                self.discriminator_optimizer.step()
                losses.append(loss.detach().cpu().numpy())
                print('\rTRAINING {} | AGE {} | BATCH {}/{} | EPOCH {}/{} | LOSS {} {}'.format(self.name, self.age, b+1, len(batches), e+1, epochs, losses[-1], extra_info), end='')
        self.age += 1
        return sum(losses) / len(losses)
                
            
    # Plays the provided game.
    # A group has the following content: (state, distribution, verdict, cell_state, hidden_state)
    def play_game(self, env, generator_epochs, render=False, extra_info=''):
        done = False
        action = 0
        score = 0
        step = 0
        lives = 4
        cell_state = torch.zeros(self.cell_size)
        hidden_state = torch.zeros(self.hidden_size)
        groups = []
        group = []
        approves = []
        env.reset()
        while not done:
            observation, reward, done, info = env.step(action)
            print('\r{} | STEP {} | SCORE {} | AGE {} | THRESHOLD {:0.2f} | APPROVAL {} {}\t\t'.format(self.name, step, score, self.age, self.threshold, sum(approves)/len(approves) if len(approves) > 0 else 0, extra_info), end = '')
            score += reward
            step += 1
            tensor = torch.Tensor.float(torch.from_numpy(observation / 255)).to(self.t_device)
            policy = self.generator(tensor)
            verdict, possible_cell_state, possible_hidden_state = self.discriminator([tensor], [policy], cell_state, hidden_state)
            if verdict[0] < self.threshold:
                approves.append(0)
                states = [tensor]
                past_cell_state = cell_state.detach()
                past_hidden_state = hidden_state.detach()
                states = states[::-1]
                self.train_generator(states, past_cell_state, past_hidden_state, generator_epochs)
                policy = self.generator(tensor)
                group.append((tensor.detach(), policy.detach(), 0, cell_state.detach(), hidden_state.detach()))
                verdict, cell_state, hidden_state = self.discriminator([tensor], [policy], cell_state, hidden_state)
            else:
                approves.append(1)
                group.append((tensor.detach(), policy.detach(), 1, cell_state.detach(), hidden_state.detach()))
                cell_state = possible_cell_state
                hidden_state = possible_hidden_state            
            if min(policy) < 0 or sum(policy) == 0:
                action = rand.randint(0, self.action_size - 1)
            else:
                distribution = torch.distributions.categorical.Categorical(policy)
                action = int(distribution.sample())
            if info['ale.lives'] != lives or done:
                groups.append(group)
                action = 0
                lives = info['ale.lives']
                cell_state = torch.zeros(self.cell_size)
                hidden_state = torch.zeros(self.hidden_size)
                group = []
            if render:
                env.render()
        return groups, score, step

In [8]:
# Population.
class POPULATION:
    
    # Constructor.
    def __init__(self, population_size, number_of_attempts, agent_params, teach_percent, train_percent, age_cutoff):
        self.population_size = population_size
        self.number_of_attempts = number_of_attempts
        self.teach_percent = teach_percent
        self.train_percent = train_percent
        self.population = []
        self.agents_created = population_size
        self.age_cutoff = age_cutoff
        self.agent_params = agent_params
        for i in range(population_size):
            agent = AGENT(f'AGENT_{i}', agent_params[0], agent_params[1], agent_params[2], agent_params[3], agent_params[4], agent_params[5])
            self.population.append(agent)
        self.generation = 0
        
    # Discretizes and labels the provided trajectories.
    def prepare_data(self, positive_examples, negative_examples, unroll_depth):
        trajectories = []
        for groups in positive_examples:
            index = unroll_depth
            while index < len(groups):
                trajectories.append(groups[index-unroll_depth:index])
                index += 1
        rand.shuffle(trajectories)
        return trajectories
        
    # Runs and trains the agents.
    def run_population(self, env, generator_epochs, unroll_depth, epochs, batch_size, render=False):
        new_pop = []
        total_score = 0
        high_score = None
        low_score = None
        manager.print('BEGIN RUNNING POPULATION | GENERATION {}'.format(self.generation))
        rand.shuffle(self.population)
        for agent in self.population:
            candidate_runs = []
            for g in range(self.number_of_attempts):
                groups, score, step = agent.play_game(env, generator_epochs, render, f'| MEMBER {len(new_pop) + 1}/{len(self.population)} | GAME {g+1}/{self.number_of_attempts}')
                total_score += score
                candidate_runs.append((groups, score, step))
                if high_score is None or high_score < score:
                    high_score = score
                if low_score is None or low_score > score:
                    low_score = score
            candidate_runs.sort(key = lambda x: x[1], reverse=True)
            new_pop.append((agent, candidate_runs[0][0], candidate_runs[0][1]))
        print('')
        manager.print('END RUNNING POPULATION | AVERAGE SCORE {} | LOW SCORE {} | HIGH SCORE {}'.format(total_score / (len(new_pop) * self.number_of_attempts), low_score, high_score))
        manager.save()
        new_pop.sort(key = lambda x: x[2], reverse=True)
        teach_pop = new_pop[:int(len(new_pop) * self.teach_percent)]
        train_pop = new_pop[-int(len(new_pop) * self.train_percent):]
        positive_examples = []
        negative_examples = []
        for exp in teach_pop:
            positive_examples += exp[1]
        trajectories = self.prepare_data(positive_examples, negative_examples, unroll_depth)
        manager.print('BEGIN TRAINING POPULATION')
        count = 0
        losses = []
        for train in train_pop:
            agent = train[0]
            # ADD BACK IN AGENT DEATH.
            
            #if agent.age > self.age_cutoff and rand.uniform(0,1) > 0.5:
            #    agent = self.replace_agent(agent)
            rand.shuffle(trajectories)
            loss = agent.train_discriminator(trajectories, unroll_depth, batch_size, epochs, extra_info=f'| MEMBER {count+1}/{len(train_pop)}')
            losses.append(loss)
            count += 1
        print('')
        manager.print('END TRAINING POPULATION | AVG LOSS {}'.format(sum(losses)/len(losses)))
        manager.save()
        self.generation += 1

In [9]:
agent_params = ([0.0001,0.1], 128, 100, 14, torch.device('cpu'), torch.device('cpu'))

In [10]:
population = POPULATION(40, 1, agent_params, 1/10, 1/5, 1000)

In [11]:
import gym
env = gym.make('KungFuMaster-ram-v0')

In [None]:
while True:
    population.run_population(env, 2, 50, 10, 64, True)

BEGIN RUNNING POPULATION | GENERATION 0
AGENT_8 | STEP 1124 | SCORE 800.0 | AGE 0 | THRESHOLD 0.77 | APPROVAL 0.0 | MEMBER 40/40 | GAME 1/1				
END RUNNING POPULATION | AVERAGE SCORE 1027.5 | LOW SCORE 0.0 | HIGH SCORE 8400.0
BEGIN TRAINING POPULATION
TRAINING AGENT_33 | AGE 0 | BATCH 125/125 | EPOCH 10/10 | LOSS 68.03379821777344 | MEMBER 8/88
END TRAINING POPULATION | AVG LOSS 69.41746123695373
BEGIN RUNNING POPULATION | GENERATION 1
AGENT_27 | STEP 1936 | SCORE 2300.0 | AGE 0 | THRESHOLD 0.89 | APPROVAL 0.0 | MEMBER 40/40 | GAME 1/1		40 | GAME 1/1			
END RUNNING POPULATION | AVERAGE SCORE 2050.0 | LOW SCORE 0.0 | HIGH SCORE 8200.0
BEGIN TRAINING POPULATION
TRAINING AGENT_14 | AGE 0 | BATCH 129/129 | EPOCH 10/10 | LOSS 27.748254776000977 | MEMBER 8/8
END TRAINING POPULATION | AVG LOSS 43.001579354437744
BEGIN RUNNING POPULATION | GENERATION 2
AGENT_2 | STEP 733 | SCORE 500.0 | AGE 0 | THRESHOLD 0.93 | APPROVAL 0.0 | MEMBER 37/40 | GAME 1/1		1		0 | GAME 1/1				