In [1]:
from typing import Any

from lightning.pytorch.utilities.types import STEP_OUTPUT
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.4.0


In [None]:
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
import lightning as L
import torch
import torch.nn as nn
import wandb as wndb
from torch_geometric.nn import GATConv
from sklearn.metrics import r2_score, mean_absolute_error, root_mean_squared_error

In [None]:

class WormUpExamplesDataset(Dataset):
  """In-memory dataset of ``(graph, action, reward)`` examples.

  Args:
    graphs: one graph per example.
    actions: the action taken in each graph.
    rewards: the reward obtained for each example.

  Raises:
    ValueError: if ``graphs``, ``actions`` and ``rewards`` differ in length.
  """

  def __init__(self, graphs: "list[Data]", actions: list[int], rewards: list[int]):
    if not len(graphs) == len(actions) == len(rewards):
      raise ValueError(
          f"graphs, actions and rewards must have the same length, got "
          f"{len(graphs)}, {len(actions)} and {len(rewards)}")
    # Materialise the triples: a bare zip() iterator supports neither indexing
    # nor len(), which __getitem__ and __len__ rely on.
    self.data = list(zip(graphs, actions, rewards))

  def __getitem__(self, idx: int):
    """Return the ``(graph, action, reward)`` tuple at position ``idx``."""
    return self.data[idx]

  def __len__(self):
    """Return the number of examples."""
    return len(self.data)

  def collate(self, data: list):
    """Split a list of examples into ``(graphs, actions, rewards)`` lists.

    Graphs are kept as a plain list rather than merged into one batch graph.
    """
    graphs = [example[0] for example in data]
    actions = [example[1] for example in data]
    rewards = [example[2] for example in data]
    return graphs, actions, rewards

  def get_dataloader(self, batch_size: int, shuffle: bool = False):
    """Build a DataLoader over this dataset that uses :meth:`collate`."""
    return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate)




In [None]:
class GraphNN(nn.Module):
  """Stack of graph convolutions producing one embedding per node.

  Args:
    in_size: number of input node features.
    out_size: width of the embeddings returned by the last layer.
    h_size: hidden width between layers (unused when ``deep == 1``).
    deep: number of convolution layers; must be >= 1.
    activation: callable applied after every layer except the last.
    device: device the layers are moved to on construction.
    conv_cls: layer factory called as ``conv_cls(in_features, out_features)``;
      ``GATConv`` is used when None.

  Raises:
    ValueError: if ``deep < 1``.
  """

  def __init__(self, in_size, out_size, h_size, deep, activation, device, conv_cls=None):
    super().__init__()
    if deep < 1:
      raise ValueError(f"deep must be >= 1, got {deep}")
    if conv_cls is None:
      conv_cls = GATConv
    self.activation = activation
    # Widths of consecutive layers: in -> h -> ... -> h -> out.
    sizes = [in_size] + [h_size] * (deep - 1) + [out_size]
    # nn.ModuleList (not a plain list) registers the layers as submodules, so
    # their weights appear in parameters()/state_dict() and follow .to().
    self.layers = nn.ModuleList(
        conv_cls(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ).to(device)

  def forward(self, data):
    """Embed every node of ``data``.

    ``data`` must expose ``x``, ``edge_index`` and ``edge_attr`` attributes
    (``edge_attr`` may be None). The last layer's output is returned without
    an activation.
    """
    x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
    for layer in self.layers[:-1]:
      x = self.activation(layer(x, edge_index, edge_attr))
    return self.layers[-1](x, edge_index, edge_attr)



In [None]:
class LinearNN(nn.Module):
  """Multi-layer perceptron: ``deep`` Linear layers separated by ``activation``.

  Args:
    in_size: input width.
    out_size: output width.
    h_size: hidden width (unused when ``deep == 1``).
    deep: number of Linear layers; must be >= 1.
    activation: nn.Module inserted between consecutive Linear layers. The same
      instance is reused at every position, so it should be stateless.

  The output layer is never followed by an activation, whatever ``deep`` is.

  Raises:
    ValueError: if ``deep < 1``.
  """

  def __init__(self, in_size, out_size, h_size, deep, activation):
    super().__init__()
    if deep < 1:
      raise ValueError(f"deep must be >= 1, got {deep}")
    sizes = [in_size] + [h_size] * (deep - 1) + [out_size]
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
      if layers:
        # activation only *between* Linear layers, never after the output
        layers.append(activation)
      layers.append(nn.Linear(n_in, n_out))
    self.linear = nn.Sequential(*layers)

  def forward(self, data):
    """Apply the MLP to ``data`` (last dimension must equal ``in_size``)."""
    return self.linear(data)

In [None]:

class WormModule(L.LightningModule):
  """Regresses the reward of taking an action in a graph state.

  Each graph is embedded node-wise by ``gnn``. The row selected by the action
  is taken as that example's embedding, the stacked embeddings go through
  ``linear``, and the result is fitted to the rewards with MSE.

  Args:
    linear: decoder applied to the stacked action embeddings.
    gnn: encoder called as ``gnn(graph)``; its output is indexed by the action.
    lr: Adam learning rate.
  """

  def __init__(self, linear: nn.Module, gnn: nn.Module, lr: float = 1e-3):
    super().__init__()
    self.encoder = gnn
    self.decoder = linear
    self.lr = lr
    self.loss = nn.MSELoss()
    # Per-epoch buffers read and cleared by WormCallback. Entries are detached
    # so they do not keep autograd graphs alive across the epoch.
    self.validation_predictions = []
    self.validation_targets = []
    self.validation_loss = []
    self.train_loss = []

  def forward(self, data):
    """Return decoder predictions for ``data = (graphs, actions, ...)``."""
    graphs, actions = data[0], data[1]
    # torch.stack keeps the embeddings on the autograd graph; torch.tensor on a
    # list of tensors copies them (failing for non-scalar rows) and drops grads.
    embeddings = torch.stack([self.encoder(g)[a] for g, a in zip(graphs, actions)])
    return self.decoder(embeddings)

  def _targets(self, rewards, predictions):
    """Convert a batch of rewards into a tensor aligned with ``predictions``."""
    targets = torch.as_tensor(rewards, dtype=predictions.dtype, device=predictions.device)
    # A single-output decoder yields (batch, 1); lift (batch,) rewards to match
    # so MSELoss does not broadcast the pair to (batch, batch).
    if targets.dim() == predictions.dim() - 1 and predictions.size(-1) == 1:
      targets = targets.unsqueeze(-1)
    return targets

  def training_step(self, batch, batch_idx):
    """Compute the MSE loss of one ``(graphs, actions, rewards)`` batch."""
    graphs, actions, rewards = batch
    predictions = self((graphs, actions))
    train_loss = self.loss(predictions, self._targets(rewards, predictions))
    self.train_loss.append(train_loss.detach())
    return train_loss

  def validation_step(self, batch, batch_idx):
    """Record loss, predictions and targets of one validation batch."""
    graphs, actions, rewards = batch
    predictions = self((graphs, actions))
    targets = self._targets(rewards, predictions)
    self.validation_loss.append(self.loss(predictions, targets).detach())
    self.validation_predictions.append(predictions.detach().cpu())
    self.validation_targets.append(targets.detach().cpu())

  def configure_optimizers(self):
    """Adam over all registered parameters with learning rate ``self.lr``."""
    return torch.optim.Adam(self.parameters(), lr=self.lr)
    
  

In [None]:
class WormCallback(L.Callback):
  """Turns WormModule's per-batch buffers into epoch-level metrics.

  Metrics are printed and, when a wandb run is active, logged to wandb. The
  module's buffers are cleared at the end of every epoch.
  """

  @staticmethod
  def _log_to_wandb(metrics):
    # wandb.log needs an active run (wandb.init); check for one explicitly
    # instead of hiding every exception behind a bare except.
    if wndb.run is not None:
      wndb.log(metrics)

  @staticmethod
  def _flatten(batches):
    """Concatenate per-batch tensors or lists into one flat float numpy array."""
    return torch.cat(
        [torch.as_tensor(b, dtype=torch.float).detach().reshape(-1).cpu() for b in batches]
    ).numpy()

  def on_train_epoch_end(self, trainer, pl_module):
    """Print/log the mean training loss of the epoch and reset the buffer."""
    if pl_module.train_loss:
      epoch_mean = float(torch.stack(pl_module.train_loss).mean())
      print("training_epoch_mean loss = ", epoch_mean)
      self._log_to_wandb({"train_loss": epoch_mean})
    # free up the memory
    pl_module.train_loss.clear()

  def on_validation_epoch_end(self, trainer, pl_module: WormModule):
    """Print/log validation loss, MAE, RMSE and R^2, then reset the buffers."""
    losses = list(pl_module.validation_loss)
    predictions = list(pl_module.validation_predictions)
    targets = list(pl_module.validation_targets)
    # Free the buffers only after taking copies: the metrics below need them.
    pl_module.validation_loss.clear()
    pl_module.validation_predictions.clear()
    pl_module.validation_targets.clear()
    if not losses:
      return

    mean_loss = float(torch.stack(losses).mean())
    y_true = self._flatten(targets)
    y_pred = self._flatten(predictions)
    # sklearn metrics take (y_true, y_pred); the order matters for r2.
    r2 = r2_score(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    rmse = root_mean_squared_error(y_true, y_pred)

    print("val_loss = ", mean_loss)
    print("mean_absolute_error = ", mae)
    print("root_mean_squared_error = ", rmse)
    print("r2 = ", r2)
    self._log_to_wandb({"val_loss": mean_loss, "mean_absolute_error": mae, "root_mean_squared_error": rmse, "r2": r2 })

