In [None]:
import copy
import math
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch

from converter.to_img_converter.MultiLayerStackDecoder import MultiLayerStackDecoder
from data_scripts.LevelDataset import LevelDataset
from level.LevelReader import LevelReader
from level.LevelVisualizer import LevelVisualizer

In [None]:
# Original (unmodified) levels, one level per batch.
dataset_path = "train_datasets/test_run_200/test_run_200.tfrecords"
dataset = LevelDataset(dataset_path=dataset_path, batch_size=1)
dataset.load_dataset()
iter_data = dataset.get_dataset()

In [None]:
# Peek at the first batch only; `original` is its first (and, with batch_size=1, only) level.
first_batch = next(iter(iter_data), None)
if first_batch is not None:
    image_batch, data = first_batch
    print(image_batch.shape)
    print(data.keys())
    original = image_batch[0]

In [None]:
# pip install gym

In [None]:
import math

In [None]:
def load_level_decoder():
    """Return a new ``MultiLayerStackDecoder`` with this notebook's decoding settings.

    The attributes below are set on the fresh decoder in the listed order, with
    exactly the values the rest of the notebook expects.
    """
    decoder = MultiLayerStackDecoder()
    settings = {
        "round_to_next_int": True,
        "custom_kernel_scale": True,
        "minus_one_border": False,
        "combine_layers": True,
        "negative_air_value": -1,
        "cutoff_point": 0.5,
        "display_decoding": False,
    }
    for attribute, value in settings.items():
        setattr(decoder, attribute, value)
    return decoder

In [None]:
# define a custom environment
class Environment:
    """Level-map environment in which each step sets one ``(x, y, type)`` cell of the state to 1.

    NOTE(review): a second ``Environment`` class is defined later in this notebook and
    shadows this one; ``main`` (defined after it) only ever uses that later class.
    """

    def __init__(self, map, max_step=math.inf):
        # Previously misspelled ``init``, so ``Environment(map)`` raised TypeError.
        self.map = map
        self.state = copy.deepcopy(map)  # working copy; ``self.map`` is never mutated
        self.step_count = 0
        self.max_step = max_step
        self.done = False

    def reset(self):
        """Restore the initial map, clear the counters and return the state as a torch tensor."""
        self.state = copy.deepcopy(self.map)
        self.step_count = 0
        self.done = False
        return torch.tensor(self.state)

    def step(self, action, distance):
        """Set ``state[action[0], action[1], action[2]]`` to 1.

        ``action`` is ``(x, y, type)``. ``distance`` is currently unused; it is only
        referenced by the disabled reward call below.

        Returns ``(state_tensor, reward_tensor, done, info)``. The reward is always 0 for now.
        """
        self.state[action[0], action[1], action[2]] = 1
        self.step_count += 1
        # NOTE(review): this disabled call does not match reward_func's (action, average_pixel) signature.
        # reward = self.reward_func(action[0], action[1], action[2], distance)
        reward = 0

        # ``>=`` ends the episode after exactly ``max_step`` steps (``>`` allowed one extra),
        # matching the later Environment class.
        if self.step_count >= self.max_step:
            self.done = True

        info = {'step': self.step_count, 'action': action, 'reward': reward, 'done': self.done}

        return torch.tensor(self.state), torch.tensor(reward), self.done, info

    @staticmethod
    def reward_func(action, average_pixel):
        """Return ``1 / (1 + d)``, where ``d`` is the 2-D (x, y) distance from ``action`` to ``average_pixel``.

        This is a static method: it previously lacked ``self``, so calling it on an instance failed.
        """
        return 1/(1+math.sqrt((action[0] - average_pixel[0])**2 + (action[1] - average_pixel[1])**2))

In [None]:
# Same levels with blocks removed; zipped one-to-one with `iter_data` in main().
dataset_path = "train_datasets/modified_test_run_200/modified_test_run_200.tfrecords"
modified_dataset = LevelDataset(dataset_path=dataset_path, batch_size=1)
modified_dataset.load_dataset()
modified_iter_data = modified_dataset.get_dataset()

In [None]:
# Peek at the first modified batch; `modified` is its first level.
first_modified_batch = next(iter(modified_iter_data), None)
if first_modified_batch is not None:
    image_batch, data = first_modified_batch
    print(image_batch.shape)
    modified = image_batch[0]

In [None]:
from evaluation.GridSearchDecode import run_evaluation_xml_levels_one_by_one, create_tests

# Grid-search decoding parameter sets; parameters[0] is what the (commented-out)
# stability check in main() would pass to run_evaluation_xml_levels_one_by_one.
parameters = create_tests()

In [None]:
import shutil
import os

def move_file(source_file, destination_folder):
    """Move ``source_file`` into ``destination_folder``, creating the folder if needed.

    Returns the new path of the file on success. If ``source_file`` does not exist,
    a message is printed and ``None`` is returned.
    """
    os.makedirs(destination_folder, exist_ok=True)

    if not os.path.exists(source_file):
        print(f"File '{source_file}' does not exist in the source folder.")
        return None

    # basename handles the platform's separators (including '\\' on Windows);
    # split("/") returned the whole path there.
    destination_file = os.path.join(destination_folder, os.path.basename(source_file))
    shutil.move(source_file, destination_file)
    print(f"File '{source_file}' moved successfully!")
    return destination_file

# Example usage:
# move_file("path/to/source/folder/filename.txt", "path/to/destination/folder")


In [None]:
def find_closest_distance(predicted_pixel, pixel_list, threshold=7):
    """Reward a 3-D prediction by its proximity to the nearest target pixel.

    Returns ``1 / (1 + d)``, where ``d`` is the Euclidean distance (over the first three
    coordinates) from ``predicted_pixel`` to the closest entry of ``pixel_list``, if
    ``d < threshold``. Otherwise, including when ``pixel_list`` is empty, returns -1.
    """
    def _euclidean(target):
        dx = predicted_pixel[0] - target[0]
        dy = predicted_pixel[1] - target[1]
        dz = predicted_pixel[2] - target[2]
        return math.sqrt(dx**2 + dy**2 + dz**2)

    nearest = min((_euclidean(target) for target in pixel_list), default=math.inf)
    return 1/(1+nearest) if nearest < threshold else -1

In [None]:
# test the function
# (3, 3, 3) is sqrt(27) ≈ 5.2 from (0, 0, 0), inside the default threshold of 7.
pixel_list = [(i, i, i) for i in (0, 10, 20, 30)]
predicted_pixel = (3, 3, 3)
reward = find_closest_distance(predicted_pixel, pixel_list)
print(reward)

In [None]:
def max_step_calc(original, modified):
    """Compare two level tensors element-wise.

    Returns ``(num_differences, indices)``: ``np.sum`` of the inequality mask and
    ``tf.where`` of that mask, i.e. one index row per differing element.
    """
    changed = tf.not_equal(original, modified)
    return np.sum(changed), tf.where(changed)

In [None]:
# num_differences, indices = max_step_calc(original, modified)
# print(indices)

In [None]:
import numpy as np

In [None]:
def plot_removed_blocks(indices, grid_shape=(128, 128), title="removed blocks"):
    """Show a heatmap counting how many changed entries fall on each (row, col) cell.

    ``indices`` is an iterable of index rows whose first two entries are (row, col),
    e.g. the ``tf.where`` output of ``max_step_calc``. Further entries (such as the
    channel) are ignored, so several changed channels at one cell accumulate.
    ``grid_shape`` defaults to the 128x128 level size used throughout this notebook.
    """
    grid = np.zeros(grid_shape)
    for index in indices:
        # int() accepts tf/NumPy scalars alike and yields a plain index.
        grid[int(index[0]), int(index[1])] += 1

    fig, ax = plt.subplots()
    image = ax.imshow(grid, cmap='hot', interpolation='nearest')
    fig.colorbar(image, ax=ax)
    ax.set(title=title, xlabel="column", ylabel="row")
    plt.show()

In [None]:
def calculate_difference(original, modified, predicted_pixel, threshold=7, verbose=True):
    """Score ``predicted_pixel`` against the positions where ``original`` and ``modified`` differ.

    The differing positions come from ``tf.where`` on the inequality mask (one index per
    differing element). Returns ``find_closest_distance(predicted_pixel, positions, threshold)``,
    i.e. ``1 / (1 + d)`` for the nearest position at distance ``d < threshold``, else -1.
    When ``verbose`` is true, the number of differences and each position are printed.
    """
    mask = tf.not_equal(original, modified)
    num_differences = np.sum(mask)

    # Plain int tuples instead of tf int64 rows: tf will not implicitly combine a
    # Python float (from predicted_pixel) with an int64 tensor inside find_closest_distance.
    pixel_list = [tuple(int(c) for c in row) for row in tf.where(mask).numpy()]

    if verbose:
        print("Number of differences:", num_differences)
        for position in pixel_list:
            print(position)

    return find_closest_distance(predicted_pixel, pixel_list, threshold)

In [None]:
from converter.gan_processing.DecodingFunctions import DecodingFunctions

In [None]:
def _save_decoded_png(decoding_function, level, png_path):
    """Flatten ``level`` to a 1-channel image with ``decoding_function`` and save it to ``png_path``."""
    decoded_img, _ = decoding_function(level)
    plt.imshow(decoded_img)
    plt.savefig(png_path)
    plt.clf()  # reuse the current figure for the next image


def xml_convert(original, modified, path, counter, show_fig = False):
    """Decode ``modified`` into a level and write it as XML to ``{path}/modified_level{counter}.xml``.

    When ``show_fig`` is true, decoded previews of ``original`` and ``modified`` are also
    saved as ``{path}/original_level{counter}.png`` and ``{path}/modified1_level{counter}.png``.
    ``path`` is created if it does not exist. Returns the XML path, joined with ``/``.
    """
    if path:
        os.makedirs(path, exist_ok=True)

    if show_fig:
        # functions to move the output from [-1, 1] to [0, 1] range
        decoding_functions = DecodingFunctions(threshold_callback = lambda: 0.5)
        decoding_functions.set_rescaling(rescaling = tf.keras.layers.Rescaling)
        decoding_functions.update_rescale_values(max_value = 1, shift_value = 1)
        # function to flatten the gan output to an image with 1 channel
        decoding_function = decoding_functions.argmax_multilayer_decoding_with_air

        _save_decoded_png(decoding_function, original, f'{path}/original_level{counter}.png')
        _save_decoded_png(decoding_function, modified, f'{path}/modified1_level{counter}.png')

    multilayer_stack_decoder = load_level_decoder()
    level_reader = LevelReader()
    level = multilayer_stack_decoder.decode(modified)

    # Save level to xml; the same path string is written to and returned.
    level_xml = level_reader.create_level_from_structure(level.get_used_elements(), red_birds = True, move_to_ground = True)
    xml_path = f'{path}/modified_level{counter}.xml'
    level_reader.write_xml_file(level_xml, xml_path)
    return xml_path
    

In [None]:
def plot_heatmap(heatmap, title = "default", grid_shape=(128, 128)):
    """Show a 2-D heatmap built from a sparse ``{(row, col, ...): value}`` dict.

    The first two entries of each key give the cell. Cells missing from ``heatmap``
    are 0. ``grid_shape`` defaults to the 128x128 level size used throughout this notebook.
    """
    grid = np.zeros(grid_shape)
    for key, value in heatmap.items():
        grid[key[0], key[1]] = value

    fig, ax = plt.subplots()
    image = ax.imshow(grid, cmap='hot', interpolation='nearest')
    fig.colorbar(image, ax=ax)
    ax.set(title=title, xlabel="column", ylabel="row")
    plt.show()

In [None]:
def plot_total(array, title = "default"):
    """Line-plot a 1-D series (e.g. per-episode reward) under ``title`` and show it."""
    plt.plot(np.array(array))
    plt.title(title)
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import copy

# Define a custom environment
class Environment:
    """Block-placing environment over a 3-D level map.

    A cell equal to -1 counts as empty. Each step tries to set one cell to 1 and
    earns +1 if the cell was empty, -10 otherwise. The episode is done once
    ``max_step`` steps have been taken. Observations are the current state as a
    flattened float32 torch tensor.
    """

    def __init__(self, map, max_step=1000):
        self.map = map  # pristine map; copied, never mutated
        self.max_step = max_step
        self.step_count = 0
        self.done = False
        self.state = copy.deepcopy(map)

    def _observation(self):
        """Current state as a flat float32 tensor."""
        return torch.tensor(self.state, dtype=torch.float32).flatten()

    def reset(self):
        """Restore the pristine map, zero the step counter and return the first observation."""
        self.state = copy.deepcopy(self.map)
        self.step_count, self.done = 0, False
        return self._observation()

    def step(self, action):
        """Try to place a block at ``action = (x, y, z)``.

        Returns ``(observation, reward, done, {})``.
        """
        x, y, z = action
        cell_was_empty = self.state[x, y, z] == -1
        if cell_was_empty:
            self.state[x, y, z] = 1
        self.step_count += 1
        if self.step_count >= self.max_step:
            self.done = True
        reward = 1 if cell_was_empty else -10
        return self._observation(), reward, self.done, {}

# PPO Model
class PPO(nn.Module):
    """Actor-critic network over a flat state vector, trained with a clipped-surrogate (PPO) loss.

    The actor (``pi``) and critic (``v``) share two hidden layers. The actor head has
    ``input_size`` outputs, i.e. one action per state entry. Transitions are buffered
    with ``put_data`` and consumed by ``train_net``.
    """

    def __init__(self, input_size, lr=0.03):
        super(PPO, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc_pi = nn.Linear(128, input_size)
        self.fc_v = nn.Linear(128, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        # self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.99)
        self.data = []  # buffered (s, a, r, s_prime, done) transitions
        self.total_loss = []
        self.total_reward = []

    def pi(self, x, softmax_dim=0):
        """Action probabilities; use ``softmax_dim=1`` for a (B, input_size) batch."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc_pi(x)
        prob = torch.softmax(x, dim=softmax_dim)
        return prob

    def v(self, x):
        """State value; shape ``(..., 1)``."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        v = self.fc_v(x)
        return v

    def put_data(self, transition):
        """Buffer one ``(s, a, r, s_prime, done)`` transition."""
        self.data.append(transition)

    def make_batch(self):
        """Stack and clear the buffer: states (B, D), actions (B,), rewards (B,), next states (B, D), dones (B,)."""
        s_lst, a_lst, r_lst, s_prime_lst, done_lst = zip(*self.data)
        self.data = []
        return torch.stack(s_lst), torch.tensor(a_lst, dtype=torch.long), torch.tensor(r_lst, dtype=torch.float), torch.stack(s_prime_lst), torch.tensor(done_lst, dtype=torch.float)

    @staticmethod
    def _gae(delta, done, gamma, lmbda):
        """Generalized advantage estimates for 1-D TD errors ``delta``.

        The running sum is reset after any step with ``done == 1``, so advantages do not
        leak across episode boundaries inside one batch.
        """
        advantages = torch.zeros_like(delta)
        running = 0.0
        for t in reversed(range(delta.shape[0])):
            running = delta[t] + gamma * lmbda * (1.0 - done[t]) * running
            advantages[t] = running
        return advantages

    def train_net(self, gamma=0.98, lmbda=0.95, eps_clip=0.2):
        """Run one optimisation step on (and clear) the buffered transitions.

        Returns the detached scalar mean loss.
        """
        s, a, r, s_prime, done = self.make_batch()
        v_s = self.v(s)  # (B, 1)

        # Keep every TD term at (B, 1): adding (B,) rewards/dones to the (B, 1) values
        # used to broadcast to (B, B), mixing r[0]/done[0] into every row.
        td_target = r.unsqueeze(1) + gamma * self.v(s_prime) * (1 - done.unsqueeze(1))
        delta = (td_target - v_s).detach().squeeze(1)  # (B,)
        advantage = self._gae(delta, done, gamma, lmbda)  # (B,), matches ratio below

        pi = self.pi(s, softmax_dim=1)
        pi_a = pi.gather(1, a.unsqueeze(1)).squeeze(1)
        log_pi_a = torch.log(pi_a + 1e-8)  # epsilon avoids log(0) -> NaN ratio
        # Old log-probs are not stored, so the ratio is 1 in value. Its gradient is
        # grad(log pi_a), which is the exact PPO objective for a single update per batch.
        ratio = torch.exp(log_pi_a - log_pi_a.detach())
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1-eps_clip, 1+eps_clip) * advantage
        entropy = -(pi * torch.log(pi + 1e-5)).sum(1).mean()
        value_loss = torch.nn.functional.smooth_l1_loss(v_s, td_target.detach())
        loss = -torch.min(surr1, surr2) + value_loss - 0.01 * entropy

        loss_mean = loss.mean()
        self.optimizer.zero_grad()
        loss_mean.backward()
        self.optimizer.step()
        # self.scheduler.step()
        # Detached so callers storing the loss (e.g. total_loss) don't keep the graph alive.
        return loss_mean.detach()
    
    
def _run_episode(model, env, state, pixel_list, distance_threshold, train_every_step):
    """Roll out one episode of ``model`` in ``env``, starting from observation ``state``.

    Each step samples a flat action index from the policy and decodes it into an
    (x, y, channel) index of ``env.state``. The action is applied, and two rewards are
    added, both divided by ``env.max_step``: the environment's placement reward and the
    ``find_closest_distance`` proximity bonus against ``pixel_list``. When
    ``train_every_step`` is true, ``model.train_net()`` runs after every step.

    Returns ``(score, closest_score, loss, heatmap_dict, empty_point_dict)``. ``loss`` is the
    last training loss of this episode, or ``None`` if no training ran.
    """
    score = 0.0
    closest_score = 0.0
    loss = None
    heatmap_dict = {}      # (x, y) -> number of times the policy picked that cell
    empty_point_dict = {}  # (x, y) -> 1 if the latest pick there earned a positive reward, else -1
    done = False
    while not done:
        # Sampling only; train_net recomputes pi with gradients.
        with torch.no_grad():
            prob = model.pi(state)
        action_index = Categorical(prob).sample().item()
        # Flat index -> (x, y, channel), inverting the row-major flatten of env.state.
        action = tuple(int(c) for c in np.unravel_index(action_index, env.state.shape))
        cell = action[:2]
        heatmap_dict[cell] = heatmap_dict.get(cell, 0) + 1

        state_prime, reward, done, _ = env.step(action)
        reward = reward / env.max_step  # normalize the reward
        score += reward
        empty_point_dict[cell] = -1 if reward < 0 else 1

        closest = find_closest_distance(action, pixel_list, distance_threshold) / env.max_step
        closest_score += closest
        score += closest

        # Store the PRE-step state; it used to be overwritten with state_prime first,
        # so every stored transition had s == s_prime.
        model.put_data((state, action_index, reward + closest, state_prime, done))
        state = state_prime
        if train_every_step:
            loss = model.train_net()
    return score, closest_score, loss, heatmap_dict, empty_point_dict


def main(iter_data, modified_iter_data):
    """Train a PPO agent to re-place the blocks removed from each level.

    ``iter_data`` (original levels) and ``modified_iter_data`` (levels with blocks removed)
    are zipped pairwise. For each pair with 1..400 differing entries, one episode of
    that many steps is played on the modified level. The result is written as XML under
    ``temp/`` and moved to ``evaluation/temp``. Every 20 counted episodes, heatmaps and
    the reward curve are plotted, the proximity threshold shrinks, training becomes less
    frequent, and the model is saved under ``temp_models/``.
    """
    model = PPO(128*128*5)  # one action per entry of a 128x128x5 level
    counter = 0
    loss = None  # last training loss; None until the first training step
    combined_dataset = tf.data.Dataset.zip((iter_data, modified_iter_data))
    distance_threshold = 40
    learn_threshold = 1
    os.makedirs('temp_models', exist_ok=True)

    for (image_batch, data), (modified_image_batch, modified_data) in combined_dataset:
        original = image_batch[0].numpy()
        env = Environment(modified_image_batch[0].numpy())  # Create environment from batch
        state = env.reset()
        diff, indices = max_step_calc(original, env.state)
        if counter % 20 == 0:
            plot_removed_blocks(indices)
        if 0 == diff or diff > 400:
            continue
        env.max_step = diff
        # Proximity targets: this level's removed entries as (x, y, channel) int tuples.
        # (This previously read a leftover global `pixel_list` from a test cell.)
        pixel_list = [tuple(int(c) for c in idx) for idx in indices.numpy()]

        score, closest_score, episode_loss, heatmap_dict, empty_point_dict = _run_episode(
            model, env, state, pixel_list, distance_threshold,
            train_every_step=(counter % learn_threshold == 0),
        )
        if episode_loss is not None:
            loss = episode_loss

        xml_path = xml_convert(original, env.state, "temp", counter, show_fig = True)
        final_path = move_file(xml_path, os.path.join('evaluation', 'temp'))
        # # if level is stable
        # if run_evaluation_xml_levels_one_by_one("temp", parameters[0]):
        #     print("Level is stable")
        #     score += 10
        # else:
        #     print("Level is not stable")
        #     score -= 10
        # os.remove(final_path)

        print(f"distance score: {closest_score}")
        print(f"Episode Score: {score}, Loss: {loss}")
        print(f"Episode {counter} completed")
        if loss is not None:
            model.total_loss.append(float(loss))  # plain float: don't retain autograd graphs
        model.total_reward.append(score)
        print("##############################################")
        counter += 1
        if counter % 20 == 0:
            plot_heatmap(empty_point_dict, title = "empty points")
            plot_heatmap(heatmap_dict, title = "total points")
            print(f"total difference: {diff}")
            diff, _ = max_step_calc(original, env.state)
            print(f"total difference after modifying: {diff}")
            wrong_filled = [value for value in empty_point_dict.values() if value < 0]
            print(f"number of wrong filled pixels: {len(wrong_filled)}")

            print(f"Episode {counter} completed")
            plot_total(model.total_reward, title = "total reward")
            print("###############################################################################")
            if distance_threshold > 4:
                distance_threshold -= 2
            if learn_threshold < 100:
                learn_threshold += 1
            torch.save(model.state_dict(), f'temp_models/ppo_model{counter}.pth')
            print("Model saved successfully")


# Alternative PPO agent (separate actor/critic networks) — experimental, commented out

In [None]:
# import os
# import numpy as np
# import torch as T
# import torch.nn as nn
# import torch.optim as optim
# from torch.distributions.categorical import Categorical
# import os
# import numpy as np
# import torch as T
# import torch.nn as nn
# import torch.optim as optim
# from torch.distributions.categorical import Categorical
# from torch.utils.data import DataLoader, Dataset
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.distributions import Categorical
# import numpy as np
# import copy

# # Define your environment class here (unchanged)
# class BlockPlacingEnvironment:
#     def __init__(self, map, max_step=500):
#         self.map = map
#         self.state = copy.deepcopy(self.map)
#         self.step_count = 0
#         self.max_step = max_step
#         self.done = False

#     def reset(self):
#         self.state = copy.deepcopy(self.map)
#         self.step_count = 0
#         self.done = False
#         state_tensor = torch.tensor(self.state, dtype=torch.float32).flatten()
#         return state_tensor

#     def step(self, action):
#         # print(action)
#         z = action % 5
#         x = action // (128*5)
#         y = (action % (128*5)) // 5
#         if self.state[x, y, z] == -1:
#             self.state[x, y, z] = 1
#             reward = 10
#         else:
#             reward = -10
#         self.step_count += 1
#         if self.step_count >= self.max_step:
#             self.done = True
#         return torch.tensor(self.state, dtype=torch.float32).flatten(), reward, self.done

# class PPOMemory:
#     def __init__(self, batch_size):
#         self.states = []
#         self.probs = []
#         self.vals = []
#         self.actions = []
#         self.rewards = []
#         self.dones = []

#         self.batch_size = batch_size

#     def generate_batches(self):
#         n_states = len(self.states)
#         batch_start = np.arange(0, n_states, self.batch_size)
#         indices = np.arange(n_states, dtype=np.int64)
#         np.random.shuffle(indices)
#         batches = [indices[i:i+self.batch_size] for i in batch_start]

#         return np.array(self.states),\
#                 np.array(self.actions),\
#                 np.array(self.probs),\
#                 np.array(self.vals),\
#                 np.array(self.rewards),\
#                 np.array(self.dones),\
#                 batches

#     def store_memory(self, state, action, probs, vals, reward, done):
#         self.states.append(state)
#         self.actions.append(action)
#         self.probs.append(probs)
#         self.vals.append(vals)
#         self.rewards.append(reward)
#         self.dones.append(done)

#     def clear_memory(self):
#         self.states = []
#         self.probs = []
#         self.actions = []
#         self.rewards = []
#         self.dones = []
#         self.vals = []

# class ActorNetwork(nn.Module):
#     def __init__(self, n_actions, input_dims, alpha,
#             fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
#         super(ActorNetwork, self).__init__()

#         self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
#         self.actor = nn.Sequential(
#                 nn.Linear(*input_dims, fc1_dims),
#                 nn.ReLU(),
#                 nn.Linear(fc1_dims, fc2_dims),
#                 nn.ReLU(),
#                 nn.Linear(fc2_dims, n_actions),
#                 nn.Softmax(dim=-1)
#         )

#         self.optimizer = optim.Adam(self.parameters(), lr=alpha)
#         self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
#         self.to(self.device)

#     def forward(self, state):
#         dist = self.actor(state)
#         dist = Categorical(dist)
        
#         return dist

#     def save_checkpoint(self):
#         T.save(self.state_dict(), self.checkpoint_file)

#     def load_checkpoint(self):
#         self.load_state_dict(T.load(self.checkpoint_file))

# class CriticNetwork(nn.Module):
#     def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256,
#             chkpt_dir='tmp/ppo'):
#         super(CriticNetwork, self).__init__()

#         self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
#         self.critic = nn.Sequential(
#                 nn.Linear(*input_dims, fc1_dims),
#                 nn.ReLU(),
#                 nn.Linear(fc1_dims, fc2_dims),
#                 nn.ReLU(),
#                 nn.Linear(fc2_dims, 1)
#         )

#         self.optimizer = optim.Adam(self.parameters(), lr=alpha)
#         self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
#         self.to(self.device)

#     def forward(self, state):
#         value = self.critic(state)

#         return value

#     def save_checkpoint(self):
#         T.save(self.state_dict(), self.checkpoint_file)

#     def load_checkpoint(self):
#         self.load_state_dict(T.load(self.checkpoint_file))

# class Agent:
#     def __init__(self, n_actions, input_dims, gamma=0.99, alpha=0.0003, gae_lambda=0.95,
#             policy_clip=0.2, batch_size=64, n_epochs=10):
#         self.gamma = gamma
#         self.policy_clip = policy_clip
#         self.n_epochs = n_epochs
#         self.gae_lambda = gae_lambda

#         self.actor = ActorNetwork(n_actions, input_dims, alpha)
#         self.critic = CriticNetwork(input_dims, alpha)
#         self.memory = PPOMemory(batch_size)
       
#     def remember(self, state, action, probs, vals, reward, done):
#         self.memory.store_memory(state, action, probs, vals, reward, done)

#     def save_models(self):
#         print('... saving models ...')
#         self.actor.save_checkpoint()
#         self.critic.save_checkpoint()

#     def load_models(self):
#         print('... loading models ...')
#         self.actor.load_checkpoint()
#         self.critic.load_checkpoint()

#     def choose_action(self, observation):
#         state = observation

#         dist = self.actor(state)
#         value = self.critic(state)
#         action = dist.sample()

#         probs = T.squeeze(dist.log_prob(action)).item()
#         action = T.squeeze(action).item()
#         value = T.squeeze(value).item()

#         return action, probs, value


#     def learn(self):
#         for _ in range(self.n_epochs):
#             state_arr, action_arr, old_prob_arr, vals_arr,\
#             reward_arr, dones_arr, batches = \
#                     self.memory.generate_batches()

#             values = vals_arr
#             advantage = np.zeros(len(reward_arr), dtype=np.float32)

#             for t in range(len(reward_arr)-1):
#                 discount = 1
#                 a_t = 0
#                 for k in range(t, len(reward_arr)-1):
#                     a_t += discount*(reward_arr[k] + self.gamma*values[k+1]*\
#                             (1-int(dones_arr[k])) - values[k])
#                     discount *= self.gamma*self.gae_lambda
#                 advantage[t] = a_t
#             advantage = T.tensor(advantage).to(self.actor.device)

#             values = T.tensor(values).to(self.actor.device)
#             for batch in batches:
#                 print(f"state arr: {state_arr[batch]}")
#                 states = T.tensor(state_arr[batch], dtype=T.float).to(self.actor.device)
#                 old_probs = T.tensor(old_prob_arr[batch]).to(self.actor.device)
#                 actions = T.tensor(action_arr[batch]).to(self.actor.device)

#                 dist = self.actor(states)
#                 critic_value = self.critic(states)

#                 critic_value = T.squeeze(critic_value)

#                 new_probs = dist.log_prob(actions)
#                 prob_ratio = new_probs.exp() / old_probs.exp()
#                 #prob_ratio = (new_probs - old_probs).exp()
#                 weighted_probs = advantage[batch] * prob_ratio
#                 weighted_clipped_probs = T.clamp(prob_ratio, 1-self.policy_clip,
#                         1+self.policy_clip)*advantage[batch]
#                 actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean()

#                 returns = advantage[batch] + values[batch]
#                 critic_loss = (returns-critic_value)**2
#                 critic_loss = critic_loss.mean()

#                 total_loss = actor_loss + 0.5*critic_loss
#                 self.actor.optimizer.zero_grad()
#                 self.critic.optimizer.zero_grad()
#                 total_loss.backward()
#                 self.actor.optimizer.step()
#                 self.critic.optimizer.step()

#         self.memory.clear_memory()    

#         print("Average Actor Loss:", actor_loss.item())
#         print("Average Critic Loss:", critic_loss.item())
#         print("Actor Weights Norm:", sum(p.norm().item() for p in self.actor.parameters()))
#         print("Critic Weights Norm:", sum(p.norm().item() for p in self.critic.parameters()))

# if __name__ == "__main__":
#     agent = Agent(n_actions=128*128*5, input_dims=(128*128*5,), alpha=0.0001, batch_size=64, n_epochs=5)
#     episode = 0
#     for image_batch, data in modified_iter_data:
#         env = BlockPlacingEnvironment(image_batch[0].numpy())
#         observation = env.reset()
#         done = False
#         total_reward = 0

#         while not done:
#             action, probs, value = agent.choose_action(observation)
#             observation, reward, done = env.step(action)
#             agent.remember(observation, action, probs, value, reward, done)
#             total_reward += reward

#         agent.learn()
#         print(f'Episode {episode + 1}: Total Reward = {total_reward}')
#         episode += 1
#     agent.save_models()           



# PPO implementation from GitHub (commented out)

In [None]:
# import os
# import numpy as np
# import torch as T
# import torch.nn as nn
# import torch.optim as optim
# from torch.distributions.categorical import Categorical
# import os
# import numpy as np
# import torch as T
# import torch.nn as nn
# import torch.optim as optim
# from torch.distributions.categorical import Categorical
# from torch.utils.data import DataLoader, Dataset
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.distributions import Categorical
# import numpy as np
# import copy

# # Define your environment class here (unchanged)
# class BlockPlacingEnvironment:
#     def __init__(self, map, max_step=500):
#         self.map = map
#         self.state = copy.deepcopy(self.map)
#         self.step_count = 0
#         self.max_step = max_step
#         self.done = False

#     def reset(self):
#         self.state = copy.deepcopy(self.map)
#         self.step_count = 0
#         self.done = False
#         state_tensor = torch.tensor(self.state, dtype=torch.float32).flatten()
#         return state_tensor

#     def step(self, action):
#         # print(action)
#         z = action % 5
#         x = action // (128*5)
#         y = (action % (128*5)) // 5
#         if self.state[x, y, z] == -1:
#             self.state[x, y, z] = 1
#             reward = 10
#         else:
#             reward = -10
#         self.step_count += 1
#         if self.step_count >= self.max_step:
#             self.done = True
#         return torch.tensor(self.state, dtype=torch.float32).flatten(), reward, self.done

# class PPOMemory:
#     def __init__(self, batch_size):
#         self.states = []
#         self.probs = []
#         self.vals = []
#         self.actions = []
#         self.rewards = []
#         self.dones = []

#         self.batch_size = batch_size

#     def generate_batches(self):
#         n_states = len(self.states)
#         batch_start = np.arange(0, n_states, self.batch_size)
#         indices = np.arange(n_states, dtype=np.int64)
#         np.random.shuffle(indices)
#         batches = [indices[i:i+self.batch_size] for i in batch_start]

#         return np.array(self.states),\
#                 np.array(self.actions),\
#                 np.array(self.probs),\
#                 np.array(self.vals),\
#                 np.array(self.rewards),\
#                 np.array(self.dones),\
#                 batches

#     def store_memory(self, state, action, probs, vals, reward, done):
#         self.states.append(state)
#         self.actions.append(action)
#         self.probs.append(probs)
#         self.vals.append(vals)
#         self.rewards.append(reward)
#         self.dones.append(done)

#     def clear_memory(self):
#         self.states = []
#         self.probs = []
#         self.actions = []
#         self.rewards = []
#         self.dones = []
#         self.vals = []

# class ActorNetwork(nn.Module):
#     def __init__(self, n_actions, input_dims, alpha,
#             fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
#         super(ActorNetwork, self).__init__()

#         self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
#         self.actor = nn.Sequential(
#                 nn.Linear(*input_dims, fc1_dims),
#                 nn.ReLU(),
#                 nn.Linear(fc1_dims, fc2_dims),
#                 nn.ReLU(),
#                 nn.Linear(fc2_dims, n_actions),
#                 nn.Softmax(dim=-1)
#         )

#         self.optimizer = optim.Adam(self.parameters(), lr=alpha)
#         self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
#         self.to(self.device)

#     def forward(self, state):
#         dist = self.actor(state)
#         dist = Categorical(dist)
        
#         return dist

#     def save_checkpoint(self):
#         T.save(self.state_dict(), self.checkpoint_file)

#     def load_checkpoint(self):
#         self.load_state_dict(T.load(self.checkpoint_file))

# class CriticNetwork(nn.Module):
#     def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256,
#             chkpt_dir='tmp/ppo'):
#         super(CriticNetwork, self).__init__()

#         self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
#         self.critic = nn.Sequential(
#                 nn.Linear(*input_dims, fc1_dims),
#                 nn.ReLU(),
#                 nn.Linear(fc1_dims, fc2_dims),
#                 nn.ReLU(),
#                 nn.Linear(fc2_dims, 1)
#         )

#         self.optimizer = optim.Adam(self.parameters(), lr=alpha)
#         self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
#         self.to(self.device)

#     def forward(self, state):
#         value = self.critic(state)

#         return value

#     def save_checkpoint(self):
#         T.save(self.state_dict(), self.checkpoint_file)

#     def load_checkpoint(self):
#         self.load_state_dict(T.load(self.checkpoint_file))

# class Agent:
#     def __init__(self, n_actions, input_dims, gamma=0.99, alpha=0.0003, gae_lambda=0.95,
#             policy_clip=0.2, batch_size=64, n_epochs=10):
#         self.gamma = gamma
#         self.policy_clip = policy_clip
#         self.n_epochs = n_epochs
#         self.gae_lambda = gae_lambda

#         self.actor = ActorNetwork(n_actions, input_dims, alpha)
#         self.critic = CriticNetwork(input_dims, alpha)
#         self.memory = PPOMemory(batch_size)
       
#     def remember(self, state, action, probs, vals, reward, done):
#         self.memory.store_memory(state, action, probs, vals, reward, done)

#     def save_models(self):
#         print('... saving models ...')
#         self.actor.save_checkpoint()
#         self.critic.save_checkpoint()

#     def load_models(self):
#         print('... loading models ...')
#         self.actor.load_checkpoint()
#         self.critic.load_checkpoint()

#     def choose_action(self, observation):
#         state = observation

#         dist = self.actor(state)
#         value = self.critic(state)
#         action = dist.sample()

#         probs = T.squeeze(dist.log_prob(action)).item()
#         action = T.squeeze(action).item()
#         value = T.squeeze(value).item()

#         return action, probs, value


#     def learn(self):
#         for _ in range(self.n_epochs):
#             state_arr, action_arr, old_prob_arr, vals_arr,\
#             reward_arr, dones_arr, batches = \
#                     self.memory.generate_batches()

#             values = vals_arr
#             advantage = np.zeros(len(reward_arr), dtype=np.float32)

#             for t in range(len(reward_arr)-1):
#                 discount = 1
#                 a_t = 0
#                 for k in range(t, len(reward_arr)-1):
#                     a_t += discount*(reward_arr[k] + self.gamma*values[k+1]*\
#                             (1-int(dones_arr[k])) - values[k])
#                     discount *= self.gamma*self.gae_lambda
#                 advantage[t] = a_t
#             advantage = T.tensor(advantage).to(self.actor.device)

#             values = T.tensor(values).to(self.actor.device)
#             for batch in batches:
#                 states = T.tensor(state_arr[batch], dtype=T.float).to(self.actor.device)
#                 old_probs = T.tensor(old_prob_arr[batch]).to(self.actor.device)
#                 actions = T.tensor(action_arr[batch]).to(self.actor.device)

#                 dist = self.actor(states)
#                 critic_value = self.critic(states)

#                 critic_value = T.squeeze(critic_value)

#                 new_probs = dist.log_prob(actions)
#                 prob_ratio = new_probs.exp() / old_probs.exp()
#                 #prob_ratio = (new_probs - old_probs).exp()
#                 weighted_probs = advantage[batch] * prob_ratio
#                 weighted_clipped_probs = T.clamp(prob_ratio, 1-self.policy_clip,
#                         1+self.policy_clip)*advantage[batch]
#                 actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean()

#                 returns = advantage[batch] + values[batch]
#                 critic_loss = (returns-critic_value)**2
#                 critic_loss = critic_loss.mean()

#                 total_loss = actor_loss + 0.5*critic_loss
#                 self.actor.optimizer.zero_grad()
#                 self.critic.optimizer.zero_grad()
#                 total_loss.backward()
#                 self.actor.optimizer.step()
#                 self.critic.optimizer.step()

#         self.memory.clear_memory()    

#         print("Average Actor Loss:", actor_loss.item())
#         print("Average Critic Loss:", critic_loss.item())
#         print("Actor Weights Norm:", sum(p.norm().item() for p in self.actor.parameters()))
#         print("Critic Weights Norm:", sum(p.norm().item() for p in self.critic.parameters()))

# if __name__ == "__main__":
#     agent = Agent(n_actions=128*128*5, input_dims=(128*128*5,), alpha=0.0001, batch_size=64, n_epochs=5)
#     episode = 0
#     for image_batch, data in modified_iter_data:
#         env = BlockPlacingEnvironment(image_batch[0].numpy())
#         observation = env.reset()
#         done = False
#         total_reward = 0


#         diff, indices  = max_step_calc(original, env.state)
#         plot_removed_blocks(indices)
#         env.max_step = diff
#         heatmap_dict = {}
#         empty_point_dict = {}
#         closest_score = 0
#         distance_threshold = 40


#         while not done:
#             action, probs, value = agent.choose_action(observation)
#             observation, reward, done = env.step(action)
#             x = action // (128*5)
#             y = (action % (128*5)) // 5
#             z = action % 5
#             print(f"Action: {x}, {y}, {z}")

#             heatmap_dict[(x, y)] = heatmap_dict.get((x, y), 0) + 1
#             state_prime, reward, done = env.step(action)  # Execute action in the environment
#             # normalize the reward
#             reward = reward / env.max_step
#             state = state_prime
#             total_reward += reward  # Update score
#             if reward < 0:
#                 empty_point_dict[(x,y)] =  - 1
#             else:
#                 empty_point_dict[(x,y)] = 1
#             closest= find_closest_distance((x,y,z), pixel_list, distance_threshold)
#             closest = closest / env.max_step
#             closest_score += closest
#             total_reward += closest

#             agent.remember(observation, action, probs, value, reward, done)
            
#         agent.learn()
#         plot_heatmap(empty_point_dict, title = "empty points")
#         plot_heatmap(heatmap_dict, title = "total points")
#         print(f"toall difference: {diff}")
#         diff, _ = max_step_calc(original, env.state)
#         print(f"toall difference after modifying: {diff}")
#         wrong_filled = empty_point_dict.values()
#         wrong_filled = list(filter(lambda x: x < 0, wrong_filled))
#         print(f"number of wrong filled pixels: {len(wrong_filled)}")
#         xml_path = xml_convert(original, env.state, "temp", episode, show_fig = True)
#         final_path = move_file(xml_path, 'evaluation\\temp')
#         # # if level is stable
#         # if run_evaluation_xml_levels_one_by_one("temp", parameters[0]):
#         #     print("Level is stable")
#         #     score += 10
#         # else:
#         #     print("Level is not stable")
#         #     score -= 10
#         # os.remove(final_path)
#         print(f"Episode Score: {total_reward}, Closest Score: {closest_score}")
#         print("##############################################")
#         episode += 1
#         if episode % 20 == 0:
#             print(f"Episode {episode} completed")
            
#             if distance_threshold > 4:
#                 distance_threshold -= 2
                
#         episode += 1
#     agent.save_models()           



In [None]:
# Entry cell: calls `main` (defined elsewhere in the notebook, not visible here)
# with the original dataset iterator and the modified one.
# NOTE(review): inside a Jupyter kernel `__name__` is always '__main__', so this
# guard never skips anything; it only matters if the code is exported to a script.
if __name__ == '__main__':
    # NOTE(review): `modified_iter_data` must already exist from an earlier cell
    # (presumably the dataset with blocks removed) — confirm before running.
    main(iter_data, modified_iter_data)

In [None]:
# Load the model: rebuild the PPO network and restore the weights saved to
# 'ppo_model.pth'. The constructor argument must match the size used when the
# checkpoint was written (presumably a flat 128 x 128 grid x 5 block-type action
# space, as in the commented-out agent code) — otherwise load_state_dict fails.
print("Loading ...")
model = PPO(128*128*5)  # Recreate the model
# map_location='cpu' lets a checkpoint saved on a GPU machine load on a CPU-only
# one; load_state_dict then copies each tensor onto the model's own parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load('ppo_model.pth', map_location='cpu'))
model.eval()  # eval mode: disables train-only behaviour (dropout, BN updates) if present

In [None]:
def decode_action(action_index, grid_size=128, num_types=5):
    """Map a flat policy action index to an ``(x, y, block_type)`` triple.

    The flat index is interpreted as row-major over a
    ``(grid_size, grid_size, num_types)`` grid, i.e.
    ``index = (x * grid_size + y) * num_types + block_type``. This matches the
    decoding in the commented-out agent loop (x = idx // (128*5),
    y = (idx % (128*5)) // 5, z = idx % 5) and the environment's
    ``state[x, y, type]`` indexing.

    Args:
        action_index: non-negative int sampled from the policy.
        grid_size: side length of the level grid.
        num_types: number of block types (last axis).

    Returns:
        Tuple ``(x, y, block_type)`` of ints.
    """
    xy, block_type = divmod(action_index, num_types)
    x, y = divmod(xy, grid_size)
    return x, y, block_type


# Roll out the loaded policy on every level of the evaluation data.
with torch.no_grad():  # inference only: no autograd graph is needed
    # NOTE(review): `env` is not created in this cell; it leaks from an earlier
    # cell, so every level is played on that same environment and `image_batch`
    # is never used. It likely should be rebuilt per level (the commented-out
    # loop does `BlockPlacingEnvironment(image_batch[0].numpy())`) — confirm.
    for image_batch, data in modified_iter_data:
        state = env.reset()  # Reset the environment
        done = False
        while not done:
            prob = model.pi(state)
            action_index = Categorical(prob).sample().item()
            action = decode_action(action_index)
            # NOTE(review): the visible `Environment.step(action, distance)` takes
            # two arguments; the environment used here must accept one and return
            # a 4-tuple — verify against the actual env class.
            state, reward, done, _ = env.step(action)


In [None]:
# count = 0
# for idx in indices:
#     count += 1

#     row_idx, col_idx, channel_idx = idx[0], idx[1], idx[2]
#     original_value = tf.gather_nd(original, [idx])  # Get original value
#     modified = tf.tensor_scatter_nd_update(
#         modified, [[row_idx, col_idx, channel_idx]], original_value  # Update modified tensor
#     )



# print("Modified tensor with original values rewritten:")
# print(modified)