<a href="https://colab.research.google.com/github/deekshaf7/Graph_transformer_notebook/blob/main/Graph_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from mydataset import *
from torch.utils.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.data import DataLoader as GeometricDataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn as nn
from tqdm import tqdm
import wandb
from torch_geometric.nn import TransformerConv, global_mean_pool
import torch.nn.functional as F

# Debug switch: when True, the model's constructor/forward pass and the training
# loop print tensor shapes, batch contents and intermediate predictions.
PRINT = False

# Graph Transformer model for graph classification
class GraphTransformerModel(nn.Module):
    """Graph-level classifier built from stacked ``TransformerConv`` layers.

    Pipeline (see ``forward``): ``num_layers`` TransformerConv+ReLU layers on the
    nodes, ``global_mean_pool`` over ``data.batch``, ``num_layers`` Linear+ReLU
    layers, then ``num_heads`` independent Linear(…, 3)+Softmax output branches.

    Note that ``num_heads`` is used twice: as the attention-head count of every
    ``TransformerConv`` and as the number of 3-class output branches.

    Args:
        input_dim: Feature size of ``data.x``.
        hidden_dim: Per-attention-head output size of each ``TransformerConv``.
        num_layers: Number of conv layers and also of fully connected layers.
        num_heads: Attention heads per conv layer and number of output branches.
        dropout: ``dropout`` argument forwarded to every ``TransformerConv``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, dropout):
        super().__init__()
        # Every conv layer is fed / produces features of this width from the
        # second layer on (hidden_dim per head, num_heads heads).
        width = hidden_dim * num_heads

        # Convolution stack: the first layer maps input_dim -> width, the rest width -> width.
        self.convs = nn.ModuleList(
            [TransformerConv(input_dim, hidden_dim, heads=num_heads, dropout=dropout)]
        )
        for _ in range(1, num_layers):
            self.convs.append(
                TransformerConv(width, hidden_dim, heads=num_heads, dropout=dropout)
            )
            if PRINT:
                print("conv layer size", self.convs)

        # Fully connected stack applied after pooling.
        self.fc_layers = nn.ModuleList(
            [nn.Linear(width, width) for _ in range(num_layers)]
        )

        # One 3-way linear classifier per output branch.
        self.fc_out_linear = nn.ModuleList(
            [nn.Linear(width, 3) for _ in range(num_heads)]
        )

        # Matching softmax per branch (turns the 3 logits into probabilities).
        self.fc_out_softmax = nn.ModuleList(
            [nn.Softmax(dim=1) for _ in range(num_heads)]
        )

        # Extra projection to 12 features; constructed here but not used by forward().
        self.fc_reduce = nn.Linear(width, 12)

    def forward(self, data):
        """Return a list of ``num_heads`` softmax outputs, one per output branch.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch`` attributes.

        Returns:
            list: One tensor per branch whose dimension 1 holds 3 class
            probabilities (softmax is taken over ``dim=1``).
        """
        feats = data.x
        edge_index = data.edge_index
        batch = data.batch
        if PRINT:
            print("Input x shape:", feats.shape)
            print("Input edge_index shape:", edge_index.shape)
            print("Input batch shape:", batch.shape[0])
            print("batch value: ", batch)

        # Message passing over the graph, ReLU after every conv layer.
        for layer_idx, conv_layer in enumerate(self.convs):
            feats = conv_layer(feats, edge_index)
            if PRINT:
                print(f"Shape after conv {layer_idx}:", feats.shape)
            feats = F.relu(feats)
            if PRINT:
                print(f"Shape after Conv ReLU {layer_idx}:", feats.shape)

        # Collapse node features into one vector per graph.
        feats = global_mean_pool(feats, batch)
        if PRINT:
            print("Shape after global mean pool:", feats.shape)

        # Dense layers on the pooled representation, ReLU after each.
        for layer_idx, dense in enumerate(self.fc_layers):
            feats = dense(feats)
            if PRINT:
                print(f"Shape after fc {layer_idx}:", feats.shape)
            feats = F.relu(feats)
            if PRINT:
                print(f"Shape after FC ReLU {layer_idx}:", feats.shape)

        # Each branch classifies the same pooled features independently.
        outputs = []
        for classifier, to_prob in zip(self.fc_out_linear, self.fc_out_softmax):
            logits = classifier(feats)
            if PRINT:
                print("Shape after logits:", logits.shape)
            prob = to_prob(logits)
            if PRINT:
                print(type(prob))
                print("Shape after prob:", prob.shape)

            outputs.append(prob)
            if PRINT:
                print("Shape after outputs:", len(outputs))
                print("Output Shape: ", len(outputs))

        return outputs

# Random seed for reproducibility: a fresh seed is drawn on every run and printed
# so the run can be repeated by hard-coding the printed value.
# NOTE(review): `random` is not imported explicitly in this file; it is presumably
# re-exported by `from mydataset import *` -- confirm.
seed = random.randint(0, 100000)
print(f"Seed: {seed}")
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset path comes from the first command-line argument.
dataset = load_dataset(sys.argv[1])

# HYPERPARAMS
# H is the single source of truth for the architecture; it is also what
# wandb_init() logs. "num_layers" is 4 because that is the value the model has
# always been built with (the dict previously said 2, which did not match the model).
H = {
    "input_dim": 6,     # Input feature dimension, defined by the dataset
    "hidden_dim": 512,  # Dimension of the model (per attention head)
    "num_heads": 12,    # Number of attention heads (also the number of output branches)
    "num_layers": 4,    # Number of conv layers / fully connected layers
    "dropout": 0.1,     # Dropout rate
}

# Plain module-level names kept for backward compatibility; derived from H so
# the logged config and the real model cannot drift apart.
input_dim = H["input_dim"]
hidden_dim = H["hidden_dim"]
num_heads = H["num_heads"]
num_layers = H["num_layers"]
dropout = H["dropout"]

model = GraphTransformerModel(input_dim, hidden_dim, num_layers, num_heads, dropout).to(device)

max_valid_metric = 0
# NOTE(review): batch_size currently equals num_heads (12). worker() indexes
# targets_batch with the output-branch index, so confirm the intended layout
# before changing either value.
batch_size = 12
learning_rate = 0.01
num_epochs = 14
size = 32  # not referenced anywhere else in this file
criterion = torch.nn.CrossEntropyLoss()

# Module-level optimizer/scheduler kept for backward compatibility only:
# worker() builds and uses its own local instances. (The module-level tqdm bar
# that used to be created here was never advanced and only rendered a second,
# stuck progress bar, so it was removed; worker() owns the real one.)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=500, gamma=0.1)

def wandb_init():
    """Start a Weights & Biases run for project ``DETECTive_transformer``.

    The logged config combines the module-level training settings
    (``learning_rate``, ``num_epochs``, ``batch_size``), the dataset path taken
    from ``sys.argv[1]``, and every entry of the hyperparameter dict ``H``
    (entries of ``H`` win on a key clash).
    """
    run_config = {
        "learning_rate": learning_rate,
        "dataset": sys.argv[1],
        "epochs": num_epochs,
        "batch_size": batch_size,
        # track the model hyperparameters alongside the run metadata
        **H,
    }
    # project selects the wandb project where this run will be logged
    wandb.init(project="DETECTive_transformer", config=run_config)

def _load_checkpoint(path, device):
    """Load the ``'model_sd'`` weights stored at ``path`` into the module-level ``model``."""
    tqdm.write(f"Loading state dictionary from {path}")
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    state_dict = torch.load(path, map_location=device)
    # Drop a leading 'module.' from parameter names (the prefix that
    # DataParallel/DistributedDataParallel wrappers add to every key).
    clean_state_dict = {
        (key[7:] if key.startswith('module.') else key): value
        for key, value in state_dict['model_sd'].items()
    }
    model.load_state_dict(clean_state_dict)


def _score_predictions(pred_test_pattern, targets_batch, accuracies, valid_metrics):
    """Append one ``get_accuracy`` and one ``valid_tv`` score per model output.

    ``accuracies`` and ``valid_metrics`` are lists that are extended in place.
    """
    for i, head_probs in enumerate(pred_test_pattern):
        predicted = torch.argmax(head_probs, dim=-1)
        accuracies.append(get_accuracy(predicted, targets_batch[i][0]))
        valid_metrics.append(valid_tv(predicted, targets_batch[i]))


def worker():
    """Train the module-level ``model`` and checkpoint it when validation improves.

    Uses the module-level ``dataset``, ``model``, ``criterion``, ``batch_size``,
    ``learning_rate`` and ``num_epochs``. ``collate_fn``, ``get_accuracy``,
    ``valid_tv`` and ``np`` are not defined in this file (presumably provided by
    ``from mydataset import *`` -- confirm).

    Steps:
      1. Random 80/20 train/validation split of ``dataset``.
      2. Optional warm start from the checkpoint named by ``sys.argv[2]``.
      3. One pass over the training loader per epoch; every 10th epoch a
         validation pass. Whenever the mean ``valid_tv`` score on the validation
         set exceeds the best seen so far, model and optimizer state are written
         to ``model.pt``.
    """
    #wandb_init()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 80/20 split; the validation set takes the remainder so no sample is dropped.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # NOTE(review): a torch.device never compares equal to a str, so this is
    # always {} and the loaders run single-process without pinned memory. It is
    # left as-is on purpose: switching to `device.type == "cuda"` would enable
    # worker processes / pinning, which is only safe if collate_fn yields CPU
    # tensors -- confirm against mydataset before changing.
    kwargs = {'num_workers': 8, 'pin_memory': True} if device == "cuda" else {}

    train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, drop_last=True, **kwargs)
    val_loader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False, drop_last=True, **kwargs)

    max_valid_metric = 0
    bar = tqdm(total=num_epochs)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): scheduler.step() is called once per epoch, so with
    # step_size=500 the learning rate never decays within num_epochs=14.
    scheduler = StepLR(optimizer, step_size=500, gamma=0.1)

    if len(sys.argv) >= 3:
        _load_checkpoint(sys.argv[2], device)

    #wandb.watch(model)

    for epoch in range(num_epochs):
        model.train()
        epoch_losses = []
        accuracies = []
        valid_metric_train = []
        for graphs_batch, batch_act_paths, batch_prop_paths, targets_batch, fault_type_batch in train_loader:

            optimizer.zero_grad()

            if PRINT:
                print("graphs:", type(graphs_batch))
                print("Graphs batch:", graphs_batch)
                print("fault_type:", type(fault_type_batch))
                print("fault type batch: ", fault_type_batch)
                print("targets_type:", type(targets_batch))
                print("targets type batch: ", targets_batch)

            # NOTE(review): nothing here moves the batch to `device`; this
            # relies on collate_fn/dataset placing tensors where the model lives.
            pred_test_pattern = model(graphs_batch)

            if PRINT:
                print("Predict test pattern ", pred_test_pattern)

            # NOTE(review): pred_test_pattern holds one tensor per output branch
            # of the model, yet targets_batch is indexed with that same branch
            # index i below. Confirm the layout collate_fn produces for targets.
            batch_loss = 0
            for i, head_probs in enumerate(pred_test_pattern):
                head_targets = targets_batch[i]
                if PRINT:
                    print(f"Predicted Test Pattern {i}: \n", head_probs)
                    print(f"Predicted Test Patterni Dim {i}: \n", torch.argmax(head_probs, dim=-1))
                    print(f"Target Pattern {i}: \n", head_targets)

                # The model already returns softmax probabilities, while
                # CrossEntropyLoss expects unnormalised scores and applies
                # log-softmax itself. Feeding it log-probabilities makes that
                # internal log-softmax an identity, so the loss is the true
                # negative log-likelihood instead of a "double softmax".
                # The clamp avoids log(0) when a probability underflows.
                # (Remove the log if the model is ever changed to return logits.)
                log_probs = torch.log(head_probs.clamp_min(1e-12))

                # Average the loss over every target pattern listed for this index.
                target_length = len(head_targets)
                total_error = 0.0
                for j in range(target_length):
                    total_error += criterion(log_probs, torch.tensor(head_targets[j]))
                batch_loss += total_error / target_length

            _score_predictions(pred_test_pattern, targets_batch, accuracies, valid_metric_train)

            epoch_losses.append(batch_loss.item())

            batch_loss.backward()
            optimizer.step()

        avg_loss = np.average(epoch_losses)

        # wandb.log({
        #     "train/Accuracy" : np.average(accuracies),
        #     "train/Valid Metric" : np.average(valid_metric_train),
        #     "train/Loss" : avg_loss
        # })
        tqdm.write(f"[Epoch {epoch}] Loss: {avg_loss : .2f} Train accuracy: {np.average(accuracies) : .2f} Valid Metric: {np.average(valid_metric_train):.2f}")

        scheduler.step()

        ######################### ! VALIDATION ! ############################
        # Runs on every 10th epoch only (with num_epochs=14 that is a single pass).
        if (epoch + 1) % 10 == 0:
            model.eval()
            valid_metric = []
            exact_val_accuracies = []

            for graphs_batch, batch_act_paths, batch_prop_paths, targets_batch, fault_type_batch in val_loader:
                with torch.no_grad():
                    pred_test_pattern = model(graphs_batch)
                    # Scores go into the *validation* lists; previously they were
                    # appended to the training list, which left valid_metric
                    # empty (nan average) so no checkpoint was ever saved.
                    _score_predictions(pred_test_pattern, targets_batch, exact_val_accuracies, valid_metric)

            if not valid_metric:
                # drop_last=True yields no batch when the validation split is
                # smaller than batch_size; report it instead of averaging nothing.
                tqdm.write("Validation skipped: the validation loader produced no batches")
            else:
                valid_metric = np.average(valid_metric)
                validation_accuracy = np.average(exact_val_accuracies)

                if valid_metric > max_valid_metric:
                    tqdm.write("Found a new maximum, storing the model")
                    max_valid_metric = valid_metric
                    # Save the models
                    torch.save({
                        'model_sd' : model.state_dict(),
                        'optim_sd' : optimizer.state_dict()
                    }, 'model.pt')

                # wandb.log({
                #     "val/Validation Accuracy" : validation_accuracy,
                #     "val/Valid Metric" : valid_metric,
                # })
                tqdm.write(f'Validation accuracy = {validation_accuracy : .2f} Average valid tv = {valid_metric : .2f}')
        bar.update()

    # Release the progress bar so it does not linger after training ends.
    bar.close()

# Script entry point: run the training loop only when this file is executed
# directly. Note that the module-level setup above (dataset loading, model
# construction) still runs on import.
if __name__ == "__main__":
    worker()