In [None]:
import torch
from torch_geometric.data import Data, InMemoryDataset, download_url
import torch_geometric.transforms as T

from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GraphConv
from torch_geometric.nn import global_mean_pool

from torch_geometric.loader import DataLoader

from torchmetrics import Accuracy
from torch.utils.data import random_split

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.loggers import MLFlowLogger

In [None]:
# Define a dataset class
class Dataset(InMemoryDataset):
    """Template in-memory PyG dataset rooted at ``root``.

    On construction the base class runs ``process()`` if the processed file
    (``processed_file_names``) is missing, then the collated ``(data, slices)``
    pair is loaded from ``self.processed_paths[0]``.

    Args:
        root: dataset root directory (forwarded to ``InMemoryDataset``).
        transform: transform applied on every access.
        pre_transform: transform applied once in ``process()`` before saving.
        pre_filter: predicate applied once in ``process()``; samples returning False are dropped.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        print("INFO: self.processed_paths = ",self.processed_paths)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may reject
        # pickled PyG objects; recent PyG versions provide self.load(path) instead — confirm versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Placeholder names: replace with the raw files this dataset is built from.
        return ['some_file_1', 'some_file_2']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        """Build the list of graph samples, filter/transform it, collate and save it.

        Raises:
            NotImplementedError: until the loading step below is filled in.
        """
        # Read data into huge `Data` list.
        data_list = None  # TODO: build a list of torch_geometric.data.Data objects here.
        if data_list is None:
            # Fail loudly instead of letting collate() crash on None.
            raise NotImplementedError(
                "Dataset.process: fill in how the raw files are converted into a list of Data objects"
            )

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [None]:
# Define your model
class GNN(torch.nn.Module):
    """Graph-level classifier: three GraphConv layers, mean pooling, linear head, sigmoid.

    ``forward(x, edge_index, batch)`` returns ``sigmoid(lin(pool(...)))``, i.e. values in
    [0, 1], which suits ``F.binary_cross_entropy`` (not a logits-based loss, no softmax).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Fixed seed -> reproducible initial weights. NOTE: this reseeds torch's global RNG.
        torch.manual_seed(12345)
        # .jittable() is required for the TorchScript export used in the CMake deployment.
        self.conv1 = GraphConv(in_channels, hidden_channels).jittable()
        self.conv2 = GraphConv(hidden_channels, hidden_channels).jittable()
        self.conv3 = GraphConv(hidden_channels, hidden_channels).jittable()
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        # Message passing; the last conv has no activation before pooling.
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.relu(self.conv2(hidden, edge_index))
        node_emb = self.conv3(hidden, edge_index)

        # Readout: one vector per graph, [batch_size, hidden_channels].
        graph_emb = global_mean_pool(node_emb, batch)

        # Classifier head; sigmoid (not softmax) because the loss is BCE.
        graph_emb = F.dropout(graph_emb, p=0.5, training=self.training)
        return torch.sigmoid(self.lin(graph_emb))
        

In [None]:
# Define a generic pytorch-lightning model for binary classification

class PLModel(pl.LightningModule):
    """Generic LightningModule for graph-level binary classification.

    Wraps a network built as ``model_class(*model_class_args, **model_class_kwargs)``.
    The network is called as ``model(x, edge_index, batch)`` and must produce one
    probability per graph. That output is compared with ``batch.y`` using ``criterion``,
    which defaults to ``F.binary_cross_entropy`` and is called with a ``weight=`` keyword.
    The module also builds the dataset, splits it and serves the dataloaders.

    Args:
        model_class: nn.Module class to instantiate (required).
        model_class_args, model_class_kwargs: constructor arguments for ``model_class``.
        criterion: loss callable ``criterion(input, target, weight=...)``.
        optimizer: optimizer class; when None, Adam(lr=0.01) is used.
        optimizer_kwargs: keyword args for ``optimizer`` (None means no extra kwargs).
        task: torchmetrics task; only ``"binary"`` is supported.
        num_classes: forwarded to torchmetrics ``Accuracy`` (1 for BCE).
        weight: if True, weight each sample by the inverse frequency of its label in the batch.
        dataset_class, ds_args, ds_kwargs: dataset constructor and its arguments.
        lengths: sizes/fractions for ``random_split`` -> 1: train; 2: train/val; 3: train/val/test.
        dataloader_class: loader class taking (dataset, batch_size, shuffle, num_workers).
        train_batch_size, val_batch_size, test_batch_size, num_workers: loader options.

    Raises:
        TypeError: if ``task`` is not ``"binary"`` or ``model_class`` is None.
    """

    def __init__(self,
                 model_class = None,
                 model_class_args = [],
                 model_class_kwargs = {},
                 criterion = None,
                 optimizer = None,
                 optimizer_kwargs = None,
                 task = "binary",
                 num_classes = 1,
                 weight = True,
                 dataset_class = None,
                 ds_args = [],
                 ds_kwargs = {},
                 lengths = [1.0],
                 dataloader_class = None,
                 train_batch_size = 64,
                 val_batch_size = 64,
                 test_batch_size = 64,
                 num_workers = 4
                ):
        super().__init__()
        if task != 'binary':
            raise TypeError('PLModel: Only binary classification implemented so far')
        if model_class is None:
            raise TypeError('PLModel: model_class must be provided')

        self.criterion = criterion if criterion is not None else F.binary_cross_entropy
        self.optimizer = optimizer
        self.optimizer_kwargs = optimizer_kwargs
        self.task = task
        self.num_classes = num_classes  # NOTE: for BCE loss this should be 1
        self.weight = weight  # per-batch inverse-frequency loss weighting on/off
        self.dataset_class = dataset_class
        # Copies so the shared mutable default arguments are never aliased.
        self.ds_args = list(ds_args)
        self.ds_kwargs = dict(ds_kwargs)
        self.lengths = list(lengths)
        self.dataloader_class = dataloader_class
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers

        # Filled lazily in setup().
        self.dataset = None
        self.ds_train = None
        self.ds_val = None
        self.ds_test = None

        self.train_accuracy = Accuracy(task=self.task, num_classes=self.num_classes)
        self.val_accuracy = Accuracy(task=self.task, num_classes=self.num_classes)
        self.test_accuracy = Accuracy(task=self.task, num_classes=self.num_classes)

        self.model = model_class(*model_class_args, **model_class_kwargs)

    def forward(self, x, edge_index, batch):
        """Return the wrapped network's per-graph output; also what to_torchscript() compiles."""
        return self.model(x, edge_index, batch)

    def _per_sample_weight(self, labels):
        """Inverse label frequency (batch size / class count) for each entry of 1-D ``labels``.

        Returns None when weighting is disabled. Binary labels (0/1) are assumed.
        """
        if not self.weight:
            return None
        labels = labels.long()
        # minlength=2 keeps class indices aligned even when a batch contains only one class;
        # an absent class gets inf weight, but it is never indexed below.
        class_counts = torch.bincount(labels, minlength=2).to(torch.float32)
        inverse_freq = labels.numel() / class_counts
        return inverse_freq[labels]

    def _shared_step(self, batch, metric, stage):
        """Compute the loss for one batch, update ``metric``, log ``{stage}_loss``/``{stage}_acc``, return the loss."""
        # reshape(-1) rather than squeeze(): a 1-graph batch must stay 1-D to match its target.
        probs = self.model(batch.x, batch.edge_index, batch.batch).reshape(-1)
        target = batch.y.reshape(-1)
        loss = self.criterion(probs, target.float(), weight=self._per_sample_weight(target))

        preds = probs.detach().round()  # 0.5 threshold: binary classification only
        metric(preds, target)

        n_graphs = target.numel()  # actual batch size (the last batch may be smaller)
        self.log(f'{stage}_loss', loss, prog_bar=True, batch_size=n_graphs)
        self.log(f'{stage}_acc', metric, prog_bar=True, batch_size=n_graphs)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_accuracy, 'train')

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_accuracy, 'val')

    @torch.no_grad()
    def test_step(self, batch, batch_nb):
        return self._shared_step(batch, self.test_accuracy, 'test')

    def configure_optimizers(self):
        """Use the configured optimizer class (with optional kwargs), else Adam(lr=0.01)."""
        if self.optimizer is None:
            return torch.optim.Adam(self.parameters(), lr=0.01)
        return self.optimizer(self.parameters(), **(self.optimizer_kwargs or {}))

    def prepare_data(self):  # NOTE: no state assignments here; only download data if needed
        pass

    def setup(self, stage=None):  # NOTE: runs on every process/GPU
        """Build the dataset once and split it once according to ``self.lengths``.

        Raises:
            ValueError: if ``lengths`` does not have 1, 2 or 3 entries.
        """
        if self.dataset is None:
            self.dataset = self.dataset_class(*self.ds_args, **self.ds_kwargs)
        # Guard so the fit and test stages share the same split.
        if self.ds_train is None and self.ds_val is None:
            n_splits = len(self.lengths)
            if not 1 <= n_splits <= 3:
                raise ValueError(
                    f'PLModel: lengths must have 1, 2 or 3 entries (train[/val[/test]]), got {n_splits}'
                )
            splits = random_split(self.dataset, self.lengths)
            self.ds_train = splits[0]
            if n_splits > 1:
                self.ds_val = splits[1]
            if n_splits > 2:
                self.ds_test = splits[2]

    def train_dataloader(self):
        return self.dataloader_class(self.ds_train, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return self.dataloader_class(self.ds_val, batch_size=self.val_batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return self.dataloader_class(self.ds_test, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers)


In [None]:
# Define pytorch lightning model with all the training/validation parameters
DATA_ROOT = '/work/clas12/users/mfmce/pyg_datasets/'  # root directory passed to Dataset
transform = T.Compose([T.ToUndirected(), T.NormalizeFeatures()])
ds_kwargs = {'transform': transform, 'pre_transform': None, 'pre_filter': None}

# Build the dataset here only to read the GNN input size; PLModel.setup() constructs
# its own instance from the same arguments.
dataset = Dataset(DATA_ROOT, **ds_kwargs)

plmodel = PLModel(
         model_class = GNN,
         model_class_args = [],
         model_class_kwargs = {'in_channels':dataset.num_node_features,'hidden_channels':64,'out_channels':1},
         criterion = torch.nn.functional.binary_cross_entropy,
         optimizer = torch.optim.Adam,
         optimizer_kwargs = {'lr':0.01},
         task = 'binary',
         num_classes = 1,
         weight = True,
         dataset_class = Dataset,
         ds_args = [DATA_ROOT],
         ds_kwargs = ds_kwargs,
         lengths = [0.8,0.1,0.1],
         dataloader_class = DataLoader,
         train_batch_size = 16,
         val_batch_size = 16,
         test_batch_size = 16,
         num_workers = 4
        )

# Sanity check
print(type(plmodel))
print(type(plmodel.model))

In [None]:
# Train model
pl.seed_everything(72, workers=True)
use_mlflow = False
mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs") if use_mlflow else None
trainer = Trainer(
    default_root_dir="./", #NOTE: PL AUTOMATICALLY SAVES PL CHECKPOINT TO PWD UNLESS THIS OPTION IS DIFFERENT
    accelerator="auto",
    devices=1,  # one device: a GPU when available, otherwise the CPU (None is rejected on CPU by Lightning 2.x)
    max_epochs=3,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    logger=mlf_logger if mlf_logger is not None else CSVLogger(save_dir="logs/"),
    deterministic=True, #NOTE: For reproducibility use pytorch_lightning.seed_everything and this
)
trainer.fit(plmodel)

# Test model - pl automatically saves best and last checkpoints; restore the best one for testing
trainer.test(plmodel, ckpt_path='best')

# Save for use in production environment: script the wrapped GNN itself so the exported
# module's forward is model(x, edge_index, batch), independent of the Lightning wrapper.
plmodel.model.eval()
script = torch.jit.script(plmodel.model)
torch.jit.save(script, "model.pt")

In [None]:
# List the working directory sorted by modification time (newest last) and keep the tail,
# e.g. to confirm that model.pt was just written.
!ls -lrth | tail