# **SETUP**

In [None]:

!pip install torch-geometric lightning wandb gymnasium
!pip install -U scikit-learn


In [1]:
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
import lightning as L
import torch
import torch.nn as nn
import wandb as wndb
from torch_geometric.nn import GATConv
from sklearn.metrics import r2_score, mean_absolute_error, root_mean_squared_error
import numpy as np
from sklearn.model_selection import train_test_split
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import gymnasium as gym
from gymnasium.envs.registration import register

# **MODEL AND RELATED STUFF**

In [None]:
class GreedyWorm(nn.Module):
    """Parameter-free baseline: the score of a position is the feature stored at that node."""

    def __init__(self):
        super().__init__()

    def forward(self, data):
        """Look up the node feature for every (graph, position) pair.

        Args:
            data: pair ``(graphs, positions)``; ``graphs[i].x`` is indexed
                with ``positions[i]``.

        Returns:
            A list with one entry per position, in input order.
        """
        graphs, positions = data
        picked = []
        for idx, node in enumerate(positions):
            picked.append(graphs[idx].x[node])
        return picked
    

In [None]:

class WormUpExamplesDataset(Dataset):
    """Map-style dataset of ``(graph, action, reward)`` training examples."""

    # Annotations are quoted so they are never evaluated at definition time.
    def __init__(self, graphs: "list[Data]", actions: "list[int]", rewards: "list[int]"):
        """Store the examples as a list of aligned triples.

        Args:
            graphs: one graph per example.
            actions: action associated with each graph.
            rewards: reward associated with each (graph, action) pair.

        Raises:
            ValueError: if the three sequences differ in length.
        """
        if not (len(graphs) == len(actions) == len(rewards)):
            raise ValueError(
                "graphs, actions and rewards must have the same length, got "
                f"{len(graphs)}, {len(actions)} and {len(rewards)}"
            )
        # A bare ``zip`` is a one-shot iterator without len() or indexing,
        # both of which a map-style dataset needs; materialise it.
        self.data = list(zip(graphs, actions, rewards))

    def __getitem__(self, idx: int):
        """Return the ``(graph, action, reward)`` triple at ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.data)

    def collate(self, data: list):
        """Turn a list of triples into three parallel lists.

        Graphs are kept as a plain list (not merged into one batch object).

        Returns:
            Tuple ``(graphs, actions, rewards)`` of lists.
        """
        graphs = [sample[0] for sample in data]
        actions = [sample[1] for sample in data]
        rewards = [sample[2] for sample in data]
        return graphs, actions, rewards

    def get_dataloader(self, batch_size: int, shuffle: bool = False):
        """Build a ``DataLoader`` over this dataset that batches with :meth:`collate`."""
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate)




In [None]:
class GraphNN(nn.Module):
    """Stack of ``GATConv`` layers with an activation between consecutive layers."""

    def __init__(self, in_size, out_size, h_size, deep, activation, device="cpu"):
        """Build the layer stack.

        Args:
            in_size: input feature size of the first layer.
            out_size: output feature size of the last layer.
            h_size: hidden feature size (unused when ``deep == 1``).
            deep: number of layers. ``1`` builds a single ``in_size -> out_size``
                layer; any other value builds ``in_size -> h_size``,
                ``deep - 2`` hidden ``h_size -> h_size`` layers, and
                ``h_size -> out_size``.
            activation: callable applied after every layer except the last.
            device: accepted for backward compatibility; not used here.
        """
        super().__init__()
        self.activation = activation
        if deep == 1:
            layers = [GATConv(in_size, out_size)]
        else:
            layers = [GATConv(in_size, h_size)]
            layers.extend(GATConv(h_size, h_size) for _ in range(deep - 2))
            layers.append(GATConv(h_size, out_size))
        # nn.ModuleList registers the layers as sub-modules. With a plain list
        # their weights are missing from parameters()/state_dict(), so they
        # are never optimised, moved by .to(), or saved in checkpoints.
        self.layers = nn.ModuleList(layers)

    def forward(self, data):
        """Run the stack on one graph.

        Args:
            data: object exposing ``x``, ``edge_index`` and ``edge_attr``.

        Returns:
            Output of the last layer (no activation applied to it).
        """
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        x = data.x
        for layer in self.layers[:-1]:
            x = self.activation(layer(x, edge_index, edge_attr))

        return self.layers[-1](x, edge_index, edge_attr)



In [None]:
class LinearNN(nn.Module):
    """Multi-layer perceptron with an activation between consecutive linear layers."""

    def __init__(self, in_size, out_size, h_size, deep, activation):
        """Build the perceptron.

        Args:
            in_size: input feature size.
            out_size: output feature size.
            h_size: hidden feature size (unused when ``deep == 1``).
            deep: number of linear layers. ``1`` builds a single
                ``in_size -> out_size`` layer; any other value builds
                ``in_size -> h_size``, ``deep - 2`` hidden layers, and
                ``h_size -> out_size``.
            activation: module inserted after every linear layer except the last.
        """
        super().__init__()
        if deep == 1:
            # No activation after the output layer, matching every other depth;
            # otherwise e.g. ReLU would clamp the single-layer output.
            layers = [nn.Linear(in_size, out_size)]
        else:
            layers = [nn.Linear(in_size, h_size), activation]
            for _ in range(deep - 2):
                layers += [nn.Linear(h_size, h_size), activation]
            layers.append(nn.Linear(h_size, out_size))
        self.linear = nn.Sequential(*layers)

    def forward(self, data):
        """Apply the layer stack to ``data``."""
        return self.linear(data)

In [None]:

class IntelligentWorm(L.LightningModule):
    """Reward regressor: a graph encoder followed by a decoder on the chosen node.

    For each (graph, action) pair the encoder embeds every node, the embedding
    at index ``action`` is selected, and the decoder maps it to a prediction
    that is trained against the reward with MSE.
    """

    def __init__(self, linear: nn.Module, gnn: nn.Module, lr: float = 1e-3):
        """Store the sub-networks and initialise the bookkeeping buffers.

        Args:
            linear: decoder applied to the selected node embeddings.
            gnn: encoder called with one graph object at a time.
            lr: learning rate for Adam.
        """
        super().__init__()
        self.encoder = gnn
        self.decoder = linear
        self.loss = nn.MSELoss()
        # Per-epoch buffers, read and cleared by the training callback.
        self.validation_predictions = []
        self.validation_targets = []
        self.validation_loss = []
        self.train_loss = []
        # Best validation statistics seen so far.
        self.best_val_loss = 100000000
        self.best_mae = 100000000
        self.best_rmse = 1000000000
        self.best_r2 = -1
        self.best_model = 0
        self.lr = lr

    def update_best_stats(self, val_loss, mae, rmse, r2):
        """Overwrite the stored best validation statistics."""
        self.best_val_loss = val_loss
        self.best_mae = mae
        self.best_rmse = rmse
        self.best_r2 = r2

    @staticmethod
    def _as_target(rewards, predictions):
        """Convert a batch of rewards to a tensor shaped, typed and placed like ``predictions``."""
        target = torch.as_tensor(rewards, dtype=predictions.dtype, device=predictions.device)
        return target.reshape(predictions.shape)

    def forward(self, data):
        """Predict one value per (graph, action) pair.

        Args:
            data: sequence whose first item is the list of graphs and whose
                second item is the list of node indices (actions).

        Returns:
            Decoder output for the stacked selected-node embeddings.
        """
        graphs, actions = data[0], data[1]
        # The encoder receives the whole graph object and reads x / edge_index itself.
        embeddings = [self.encoder(graph)[action] for graph, action in zip(graphs, actions)]
        # torch.stack keeps the autograd graph; torch.tensor(list_of_tensors)
        # would copy the values and cut the encoder off from the loss.
        return self.decoder(torch.stack(embeddings))

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss for one batch and record it for the epoch summary."""
        graphs, actions, rewards = batch
        predictions = self.forward((graphs, actions))
        train_loss = self.loss(predictions, self._as_target(rewards, predictions))
        # Store a detached copy so the buffer does not keep autograd graphs alive.
        self.train_loss.append(train_loss.detach())

        return train_loss

    def validation_step(self, batch, batch_idx):
        """Record predictions, targets and loss of one validation batch."""
        graphs, actions, rewards = batch
        predictions = self.forward((graphs, actions))
        targets = self._as_target(rewards, predictions)
        validation_loss = self.loss(predictions, targets)
        self.validation_targets.append(targets.detach().cpu())
        self.validation_predictions.append(predictions.detach().cpu())
        self.validation_loss.append(validation_loss.detach())

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)



In [None]:
class WormCallback(L.Callback):
    """Summarise per-epoch losses and metrics, track the best model, and log to wandb."""

    @staticmethod
    def _flatten(batches) -> np.ndarray:
        """Concatenate per-batch tensors or sequences into one flat float array."""
        parts = []
        for batch in batches:
            if isinstance(batch, torch.Tensor):
                batch = batch.detach().cpu()
            parts.append(np.asarray(batch, dtype=float).reshape(-1))
        return np.concatenate(parts)

    @staticmethod
    def _log_to_wandb(metrics: dict) -> None:
        """Send ``metrics`` to wandb only when a run is active (best-effort logging)."""
        if wndb.run is not None:
            wndb.log(metrics)

    def on_train_epoch_end(self, trainer, pl_module):
        """Print and log the mean training loss of the epoch, then reset the buffer."""
        if not pl_module.train_loss:
            return
        losses = [loss.detach() for loss in pl_module.train_loss]
        epoch_mean = float(torch.stack(losses).mean())
        print("training_epoch_mean loss = ", epoch_mean)
        # free up the memory
        pl_module.train_loss.clear()
        self._log_to_wandb({"train_loss": epoch_mean})

    def on_validation_epoch_end(self, trainer, pl_module: IntelligentWorm):
        """Compute validation metrics, update the best-model counter and log everything.

        The logged ``best_model`` value decreases by one every time the epoch
        is judged an improvement; otherwise the current counter plus one is
        logged without storing it.
        """
        if not pl_module.validation_loss:
            return

        predictions = self._flatten(pl_module.validation_predictions)
        targets = self._flatten(pl_module.validation_targets)
        # Everything must be computed BEFORE the buffers are cleared.
        mean_loss = float(torch.stack(pl_module.validation_loss).mean())
        # sklearn metrics take (y_true, y_pred); the order matters for r2.
        r2 = r2_score(targets, predictions)
        mae = mean_absolute_error(targets, predictions)
        rmse = root_mean_squared_error(targets, predictions)

        pl_module.validation_loss.clear()
        pl_module.validation_predictions.clear()
        pl_module.validation_targets.clear()

        print("val_loss = ", mean_loss)
        print("mean_absolute_error = ", mae)
        print("root_mean_squared_error = ", rmse)
        print("r2 = ", r2)

        loss_improved = mean_loss < pl_module.best_val_loss
        count = sum((
            loss_improved,
            mae < pl_module.best_mae,
            rmse < pl_module.best_rmse,
            r2 > pl_module.best_r2,
        ))
        # Improvement: at least three metrics got better, or exactly two
        # including the validation loss.
        if count >= 3 or (count == 2 and loss_improved):
            pl_module.update_best_stats(mean_loss, mae, rmse, r2)
            pl_module.best_model -= 1
            pl_module.log("best_model", pl_module.best_model)
        else:
            pl_module.log("best_model", pl_module.best_model + 1)

        self._log_to_wandb({"val_loss": mean_loss, "mean_absolute_error": mae,
                            "root_mean_squared_error": rmse, "r2": r2})



# **AGENT**

In [None]:


class WormsMasterAgent:
    """Epsilon-greedy agent that periodically retrains a reward model on its own experience.

    It acts with a ``GreedyWorm`` baseline until the first training round;
    afterwards it acts with the trained ``learning_model``.
    """

    def __init__(
            self,
            model,
            initial_epsilon: float,
            epsilon_decay: float,
            final_epsilon: float,
            learning_rate: float = 1,
            discount_factor: float = 0.95,
            decay_after: int = 1,
            # trainer params
            batch_size: int = 32,
            episodes_for_batch: int = 20,
            trainer_deterministic: bool = True,
            trainer_max_epochs: int = 20,
            trainer_accelerator: str = "cpu"

    ):
        """Configure exploration, reward accumulation and the Lightning trainer.

        Args:
            model: trainable model fitted every ``episodes_for_batch`` episodes.
            initial_epsilon: starting probability of taking a random action.
            epsilon_decay: amount subtracted from epsilon at each decay.
            final_epsilon: lower bound for epsilon.
            learning_rate: factor applied to each reward before accumulation.
            discount_factor: stored on the agent; not applied by the current
                reward accumulation.
            decay_after: epsilon decays every ``decay_after`` episodes.
            batch_size: batch size of the training/validation dataloaders.
            episodes_for_batch: number of episodes between training rounds.
            trainer_deterministic: forwarded to ``L.Trainer(deterministic=...)``.
            trainer_max_epochs: forwarded to ``L.Trainer(max_epochs=...)``.
            trainer_accelerator: forwarded to ``L.Trainer(accelerator=...)``.
        """
        self.batch_size = batch_size
        self.episodes_for_batch = episodes_for_batch
        self.episode = 0
        self.learning_model = model
        self.model = GreedyWorm()
        self.lr = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon
        self.decay_after = decay_after
        self._trainer_kwargs = {
            "deterministic": trainer_deterministic,
            "max_epochs": trainer_max_epochs,
            "accelerator": trainer_accelerator,
        }
        self._build_trainer()
        # Buffers of the episode currently being played.
        self.actual_rewards = np.array([])
        self.actual_observations = []
        self.actual_actions = []
        # Experience accumulated over all finished episodes.
        self.model_training_data = {
            "actions": [],
            "observations": [],
            "rewards": []
        }

    def _build_trainer(self):
        """Create a fresh trainer and fresh callbacks for one training round."""
        self.checkpoint_callback = ModelCheckpoint(dirpath="Model/", filename="worms_model.ckpt", save_top_k=1,
                                                   mode='min', monitor='best_model')
        self.worm_callback = WormCallback()
        self.early_stopping_callback = EarlyStopping(monitor='best_model', mode='min', patience=3)
        self.trainer = L.Trainer(callbacks=[self.checkpoint_callback, self.worm_callback,
                                            self.early_stopping_callback],
                                 **self._trainer_kwargs)

    def get_action(self, observation, available_actions: list[int]) -> int:
        """Choose an action epsilon-greedily and open a reward slot for this step.

        Args:
            observation: mapping whose ``"field"`` entry is the current graph.
            available_actions: candidate actions (node indices) to choose from.

        Returns:
            One element of ``available_actions``.

        Raises:
            ValueError: if ``available_actions`` is empty, or the model does
                not return exactly one score per candidate action.
        """
        if len(available_actions) == 0:
            raise ValueError("available_actions must not be empty")
        graph = observation["field"]
        # np.append returns a new array; it has to be assigned back for the
        # slot to exist when update() accumulates rewards.
        self.actual_rewards = np.append(self.actual_rewards, 0.0)

        # with probability epsilon return a random action to explore the environment
        if np.random.random() < self.epsilon:
            i = np.random.randint(0, high=len(available_actions))
            return available_actions[i]

        # with probability (1 - epsilon) act greedily (exploit)
        self.model.eval()
        with torch.no_grad():
            # Both models take a single (graphs, positions) pair.
            scores = self.model(([graph] * len(available_actions), available_actions))
        if not isinstance(scores, torch.Tensor):
            scores = torch.stack([torch.as_tensor(s, dtype=torch.float) for s in scores])
        scores = scores.reshape(-1)
        if scores.numel() != len(available_actions):
            raise ValueError("the model must return exactly one score per available action")
        # argmax is an index into the candidates, not the action itself.
        return available_actions[int(torch.argmax(scores))]

    def update(self, observation, action: int, reward: int, terminated: bool):
        """Record one transition; at episode end archive it and maybe retrain.

        The scaled reward is added to every reward slot opened so far in the
        episode, so each step ends up with the undiscounted sum of the rewards
        received from that step onward.
        """
        graph = observation["field"]
        self.actual_observations.append(graph)
        self.actual_actions.append(action)
        self.actual_rewards += self.lr * reward

        if not terminated:
            return

        self.episode += 1
        self.model_training_data["actions"] += self.actual_actions
        self.actual_actions = []
        self.model_training_data["observations"] += self.actual_observations
        self.actual_observations = []
        self.model_training_data["rewards"] += self.actual_rewards.tolist()
        self.actual_rewards = np.array([])
        if self.episode % self.decay_after == 0:
            self.decay_epsilon()

        if self.episode % self.episodes_for_batch == 0:
            self.train_model()
            self.model = self.learning_model

    def decay_epsilon(self):
        """Lower epsilon by ``epsilon_decay`` without going below ``final_epsilon``."""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)

    def prepare_data(self):
        """Split the accumulated experience 80/20 and wrap it in dataloaders.

        Returns:
            Tuple ``(train_dataloader, val_dataloader)``.
        """
        # train_test_split needs indexable sequences, so the parallel lists
        # are split together rather than passed as a zip iterator.
        (train_graphs, val_graphs,
         train_actions, val_actions,
         train_rewards, val_rewards) = train_test_split(
            self.model_training_data["observations"],
            self.model_training_data["actions"],
            self.model_training_data["rewards"],
            test_size=0.2,
        )
        train_dataset = WormUpExamplesDataset(train_graphs, train_actions, train_rewards)
        val_dataset = WormUpExamplesDataset(val_graphs, val_actions, val_rewards)
        train_dataloader = train_dataset.get_dataloader(self.batch_size, shuffle=True)
        # Validation order does not affect the metrics; no shuffling needed.
        val_dataloader = val_dataset.get_dataloader(self.batch_size, shuffle=False)
        return train_dataloader, val_dataloader

    def train_model(self):
        """Fit ``learning_model`` on the accumulated experience and reload its best checkpoint."""
        train_dataloader, val_dataloader = self.prepare_data()
        # A reused Trainer keeps its epoch counter and the callbacks keep their
        # best-score / patience state, so each round starts from fresh ones.
        self._build_trainer()
        wndb.init(
            # set the wandb project where this run will be logged
            project="WormsWarmingUp",

            # track hyperparameters and run metadata
            config={
                "learning_rate": self.learning_model.lr,
                "architecture": str(self.learning_model),
                "batch": self.episode // self.episodes_for_batch
            }
        )
        try:
            self.trainer.fit(self.learning_model, train_dataloader, val_dataloader)
        finally:
            # Close the wandb run even when fitting fails.
            wndb.finish()

        # Ask the callback where it saved the best checkpoint: it appends the
        # extension and version suffixes itself, so the path is not fixed.
        best_path = self.checkpoint_callback.best_model_path
        if best_path:
            self.learning_model = IntelligentWorm.load_from_checkpoint(
                checkpoint_path=best_path,
                # Constructor arguments are not stored in the checkpoint.
                linear=self.learning_model.decoder,
                gnn=self.learning_model.encoder,
                lr=self.learning_model.lr,
            )
        

# **INITIALIZE MODEL, AGENT AND ENVIRONMENT**

In [None]:

# Encoder: a single graph-attention layer, 1 input feature -> 64-dim node embedding
# (h_size is ignored when deep == 1).
gnn_part = GraphNN(in_size=1, out_size=64, h_size=0, deep=1, activation=nn.ReLU())
# Decoder: 64 -> 128 -> 1 perceptron producing one prediction per embedding.
linear_part = LinearNN(in_size=64, out_size=1, h_size=128, deep=2, activation=nn.ReLU())

worm_model = IntelligentWorm(linear_part, gnn_part)



In [None]:

# Start almost fully exploratory (95 % random actions) and decay towards 15 %.
agent = WormsMasterAgent(
    worm_model,
    initial_epsilon=0.95,
    epsilon_decay=0.005,
    final_epsilon=0.15,
)


In [None]:

# Register the custom environment with gymnasium; episodes are capped at 300 steps.
register(id="worms_env", entry_point="worms_env:WormsEnv", max_episode_steps=300)

# Instantiate the environment on the example level file, with on-screen rendering.
environment = gym.make(
    'worms_env',
    env_file="Data/00-example.txt",
    render_mode="human",
)


