In [1]:
# !pip install torch_geometric==2.3.1

In [2]:
import argparse
import os.path as osp
from typing import Any, Dict, Optional

import torch
from torch.nn import (
    BatchNorm1d,
    Embedding,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_add_pool
import inspect
from typing import Any, Dict, Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from mamba_ssm import Mamba
from torch_geometric.utils import degree, sort_edge_index

In [3]:
def permute_within_batch(x, batch):
    # Enumerate over unique batch indices
    unique_batches = torch.unique(batch)
    
    # Initialize list to store permuted indices
    permuted_indices = []

    for batch_index in unique_batches:
        # Extract indices for the current batch
        indices_in_batch = (batch == batch_index).nonzero().squeeze()
        
        # Permute indices within the current batch
        permuted_indices_in_batch = indices_in_batch[torch.randperm(len(indices_in_batch))]
        
        # Append permuted indices to the list
        permuted_indices.append(permuted_indices_in_batch)
    
    # Concatenate permuted indices into a single tensor
    permuted_indices = torch.cat(permuted_indices)

    return permuted_indices

In [4]:
# path, subset = '/scratch/ssd004/scratch/tsepaole/ZINC_full/', False
path, subset = '', True

transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')
train_dataset = ZINC(path, subset=subset, split='train', pre_transform=transform)
val_dataset = ZINC(path, subset=subset, split='val', pre_transform=transform)
test_dataset = ZINC(path, subset=subset, split='test', pre_transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

In [5]:
class GPSConv(torch.nn.Module):

    def __init__(
        self,
        channels: int,
        conv: Optional[MessagePassing],
        heads: int = 1,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        act: str = 'relu',
        att_type: str = 'transformer',
        order_by_degree: bool = False,
        shuffle_ind: int = 0,
        d_state: int = 16,
        d_conv: int = 4,
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Optional[str] = 'batch_norm',
        norm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        self.channels = channels
        self.conv = conv
        self.heads = heads
        self.dropout = dropout
        self.att_type = att_type
        self.shuffle_ind = shuffle_ind
        self.order_by_degree = order_by_degree
        
        assert (self.order_by_degree==True and self.shuffle_ind==0) or (self.order_by_degree==False), f'order_by_degree={self.order_by_degree} and shuffle_ind={self.shuffle_ind}'
        
        if self.att_type == 'transformer':
            self.attn = torch.nn.MultiheadAttention(
                channels,
                heads,
                dropout=attn_dropout,
                batch_first=True,
            )
        if self.att_type == 'mamba':
            self.self_attn = Mamba(
                d_model=channels,
                d_state=d_state,
                d_conv=d_conv,
                expand=1
            )
            
        self.mlp = Sequential(
            Linear(channels, channels * 2),
            activation_resolver(act, **(act_kwargs or {})),
            Dropout(dropout),
            Linear(channels * 2, channels),
            Dropout(dropout),
        )

        norm_kwargs = norm_kwargs or {}
        self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)

        self.norm_with_batch = False
        if self.norm1 is not None:
            signature = inspect.signature(self.norm1.forward)
            self.norm_with_batch = 'batch' in signature.parameters

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if self.conv is not None:
            self.conv.reset_parameters()
        self.attn._reset_parameters()
        reset(self.mlp)
        if self.norm1 is not None:
            self.norm1.reset_parameters()
        if self.norm2 is not None:
            self.norm2.reset_parameters()
        if self.norm3 is not None:
            self.norm3.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        batch: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tensor:
        r"""Runs the forward pass of the module."""
        hs = []
        if self.conv is not None:  # Local MPNN.
            h = self.conv(x, edge_index, **kwargs)
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = h + x
            if self.norm1 is not None:
                if self.norm_with_batch:
                    h = self.norm1(h, batch=batch)
                else:
                    h = self.norm1(h)
            hs.append(h)

        ### Global attention transformer-style model.
        if self.att_type == 'transformer':
            h, mask = to_dense_batch(x, batch)
            h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False)
            h = h[mask]
            
        if self.att_type == 'mamba':
            
            if self.order_by_degree:
                deg = degree(edge_index[0], x.shape[0]).to(torch.long)
                order_tensor = torch.stack([batch, deg], 1).T
                _, x = sort_edge_index(order_tensor, edge_attr=x)
                
            if self.shuffle_ind == 0:
                h, mask = to_dense_batch(x, batch)
                h = self.self_attn(h)[mask]
            else:
                mamba_arr = []
                for _ in range(self.shuffle_ind):
                    h_ind_perm = permute_within_batch(x, batch)
                    h_i, mask = to_dense_batch(x[h_ind_perm], batch)
                    h_i = self.self_attn(h_i)[mask][h_ind_perm]
                    mamba_arr.append(h_i)
                h = sum(mamba_arr) / self.shuffle_ind
        ###
        
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = h + x  # Residual connection.
        if self.norm2 is not None:
            if self.norm_with_batch:
                h = self.norm2(h, batch=batch)
            else:
                h = self.norm2(h)
        hs.append(h)

        out = sum(hs)  # Combine local and global outputs.

        out = out + self.mlp(out)
        if self.norm3 is not None:
            if self.norm_with_batch:
                out = self.norm3(out, batch=batch)
            else:
                out = self.norm3(out)

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'conv={self.conv}, heads={self.heads})')

In [6]:
class GraphModel(torch.nn.Module):
    def __init__(self, channels: int, pe_dim: int, num_layers: int, model_type: str, shuffle_ind: int, d_state: int, d_conv: int, order_by_degree: False):
        super().__init__()

        self.node_emb = Embedding(28, channels - pe_dim)
        self.pe_lin = Linear(20, pe_dim)
        self.pe_norm = BatchNorm1d(20)
        self.edge_emb = Embedding(4, channels)
        self.model_type = model_type
        self.shuffle_ind = shuffle_ind
        self.order_by_degree = order_by_degree
        
        self.convs = ModuleList()
        for _ in range(num_layers):
            nn = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            if self.model_type == 'gine':
                conv = GINEConv(nn)
                
            if self.model_type == 'mamba':
                conv = GPSConv(channels, GINEConv(nn), heads=4, attn_dropout=0.5,
                               att_type='mamba',
                               shuffle_ind=self.shuffle_ind,
                               order_by_degree=self.order_by_degree,
                               d_state=d_state, d_conv=d_conv)
                
            if self.model_type == 'transformer':
                conv = GPSConv(channels, GINEConv(nn), heads=4, attn_dropout=0.5, att_type='transformer')
                
            # conv = GINEConv(nn)
            self.convs.append(conv)

        self.mlp = Sequential(
            Linear(channels, channels // 2),
            ReLU(),
            Linear(channels // 2, channels // 4),
            ReLU(),
            Linear(channels // 4, 1),
        )

    def forward(self, x, pe, edge_index, edge_attr, batch):
        x_pe = self.pe_norm(pe)
        x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)
        edge_attr = self.edge_emb(edge_attr)

        for conv in self.convs:
            if self.model_type == 'gine':
                x = conv(x, edge_index, edge_attr=edge_attr)
            else:
                x = conv(x, edge_index, batch, edge_attr=edge_attr)
                
        x = global_add_pool(x, batch)
        return self.mlp(x)

In [7]:
def train():
    model.train()
    
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.pe, data.edge_index, data.edge_attr,
                    data.batch)
        loss = (out.squeeze() - data.y).abs().mean()
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)

In [8]:
@torch.no_grad()
def test(loader):
    model.eval()

    total_error = 0
    for data in loader:
        data = data.to(device)
        # print(data.x.shape)
        out = model(data.x, data.pe, data.edge_index, data.edge_attr,
                    data.batch)
        total_error += (out.squeeze() - data.y).abs().sum().item()
    return total_error / len(loader.dataset)

In [9]:
# it.to(device)

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GraphModel(channels=64, pe_dim=8, num_layers=10,
                   model_type='mamba',
                   shuffle_ind=0, order_by_degree=True,
                   d_conv=4, d_state=16,
                  ).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
                              min_lr=0.00001)
arr = []
for epoch in range(1, 30):
    loss = train()
    val_mae = test(val_loader)
    test_mae = test(test_loader)
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')
    arr.append(test_mae)
ordering = arr
print(ordering)

Epoch: 01, Loss: 0.7204, Val: 0.7170, Test: 0.7365
Epoch: 02, Loss: 0.6121, Val: 0.5867, Test: 0.6100
Epoch: 03, Loss: 0.5747, Val: 0.7788, Test: 0.7760
Epoch: 04, Loss: 0.5386, Val: 0.5892, Test: 0.5950
Epoch: 05, Loss: 0.5261, Val: 0.5417, Test: 0.5383
Epoch: 06, Loss: 0.4998, Val: nan, Test: 0.6481
Epoch: 07, Loss: 0.4902, Val: 0.4364, Test: 0.4555
Epoch: 08, Loss: 0.4680, Val: 0.5670, Test: 0.6019
Epoch: 09, Loss: 0.4335, Val: 0.4619, Test: 0.4543
Epoch: 10, Loss: 0.4447, Val: 0.3931, Test: 0.4147
Epoch: 11, Loss: 0.4289, Val: 0.4258, Test: 0.4431
Epoch: 12, Loss: 0.4220, Val: 0.4031, Test: 0.4058
Epoch: 13, Loss: 0.4112, Val: 0.4105, Test: 0.4198
Epoch: 14, Loss: 0.4066, Val: 0.4957, Test: 0.4886
Epoch: 15, Loss: 0.3990, Val: 0.4142, Test: 0.3974
Epoch: 16, Loss: 0.3858, Val: 0.5763, Test: nan
Epoch: 17, Loss: 0.3757, Val: 0.3922, Test: nan
Epoch: 18, Loss: 0.3749, Val: 0.3993, Test: 0.4512
Epoch: 19, Loss: 0.3724, Val: nan, Test: nan
Epoch: 20, Loss: 0.3494, Val: 0.3795, Test: na

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GraphModel(channels=64, pe_dim=8, num_layers=10,
                   model_type='mamba',
                   shuffle_ind=1, order_by_degree=False,
                   d_conv=4, d_state=16,
                  ).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
                              min_lr=0.00001)
arr = []
for epoch in range(1, 30):
    loss = train()
    val_mae = test(val_loader)
    test_mae = test(test_loader)
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')
    arr.append(test_mae)
permute = arr
print(permute)

Epoch: 01, Loss: 0.6478, Val: 0.5409, Test: 0.5737
Epoch: 02, Loss: 0.5205, Val: 0.4522, Test: 0.4622
Epoch: 03, Loss: 0.4889, Val: 0.5605, Test: 0.5807
Epoch: 04, Loss: 0.4440, Val: 0.3877, Test: 0.3950
Epoch: 05, Loss: 0.4151, Val: 0.4781, Test: 0.4825
Epoch: 06, Loss: 0.4200, Val: 0.3819, Test: 0.3898
Epoch: 07, Loss: 0.3929, Val: 0.4256, Test: 0.4256
Epoch: 08, Loss: 0.3695, Val: 0.3649, Test: 0.3617
Epoch: 09, Loss: 0.3680, Val: 0.4223, Test: 0.3876
Epoch: 10, Loss: 0.3569, Val: 0.3277, Test: 0.3216


In [None]:
# import matplotlib.pyplot as plt
# import pandas as pd

# import numpy as np
# res_df = pd.read_csv('30_ep_res.csv')

# WINDOW = 1
# fig, ax = plt.subplots(1, figsize=(15,5))

# for col in res_df.columns:
#     plt.plot(res_df[col].clip(0,0.7).rolling(WINDOW, min_periods=1).mean(), label=col)

# plt.legend()
# plt.show()

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, figsize=(15,5))

plt.plot(permute[:20], label='permute')
plt.plot(ordering[:20], label='order')

plt.legend()
plt.show()

In [None]:
fffffffffff

# Plotting

# Tests

In [None]:
ggggggggggg

In [None]:
it = next(iter(train_loader))
# h, mask = to_dense_batch(it.x, it.batch)
# it.x.shape, h.shape, mask.shape

In [None]:
deg.dtype

In [None]:
it

In [None]:
it.edge_index[0]

In [None]:
it.to(device)
out = model(it.x, it.pe, it.edge_index, it.edge_attr,
                    it.batch)

In [None]:
batch = torch.tensor([0,0,0,1,1,1,1])
x = torch.tensor([0,1,2,3,4,5,6])
batch.shape, x.shape

In [None]:
import torch

def permute_within_batch(x, batch):
    # Enumerate over unique batch indices
    unique_batches = torch.unique(batch)
    
    # Initialize list to store permuted indices
    permuted_indices = []

    for batch_index in unique_batches:
        # Extract indices for the current batch
        indices_in_batch = (batch == batch_index).nonzero().squeeze()
        
        # Permute indices within the current batch
        permuted_indices_in_batch = indices_in_batch[torch.randperm(len(indices_in_batch))]
        
        # Append permuted indices to the list
        permuted_indices.append(permuted_indices_in_batch)

    # Concatenate permuted indices into a single tensor
    permuted_indices = torch.cat(permuted_indices)

    return permuted_indices

# Example usage
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])
x = torch.tensor([0, 10, 20, 30, 40, 50, 60])

# Get permuted indices
permuted_indices = permute_within_batch(x, batch)

# Use permuted indices to get the permuted tensor
permuted_x = x[permuted_indices]

print("Original x:", x)
print("Permuted x:", permuted_x)
print("Permuted indices:", permuted_indices)


In [None]:
mask[0].sum(), (it.batch==0).sum()

In [None]:
self_attn = Mamba(d_model=64, # Model dimension d_model
                                d_state=16,  # SSM state expansion factor
                                d_conv=4,    # Local convolution width
                                expand=1,    # Block expansion factor
                            )
print(sum(p.numel() for p in self_attn.parameters() if p.requires_grad), sum(p.numel() for p in self_attn.parameters()))

In [None]:
self_attn = Mamba(d_model=64, # Model dimension d_model
                                d_state=8,  # SSM state expansion factor
                                d_conv=2,    # Local convolution width
                                expand=1,    # Block expansion factor
                            )
print(sum(p.numel() for p in self_attn.parameters() if p.requires_grad), sum(p.numel() for p in self_attn.parameters()))

In [None]:
self_attn = Mamba(d_model=64, # Model dimension d_model
                                d_state=16,  # SSM state expansion factor
                                d_conv=8,    # Local convolution width
                                expand=1,    # Block expansion factor
                            )
print(sum(p.numel() for p in self_attn.parameters() if p.requires_grad), sum(p.numel() for p in self_attn.parameters()))

In [None]:
self_attn = torch.nn.MultiheadAttention(
                64,
                4,
                dropout=0.5,
                batch_first=True,
            )
print(sum(p.numel() for p in self_attn.parameters() if p.requires_grad), sum(p.numel() for p in self_attn.parameters()))