[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ellisalicante/GraphRewiring-Tutorial/blob/main/3-Inductive-Rewiring-CTLayer.ipynb)
# Inductive rewiring using CT-Layer
***Tutorial on Graph Rewiring: From Theory to Applications in Fairness***

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ellisalicante/GraphRewiring-Tutorial/blob/main/3-Inductive-Rewiring-CTLayer.ipynb)

In [30]:
# Set to True when running on Google Colab: it enables the setup cell below,
# which clones the tutorial repository and installs the PyG packages.
COLLAB_ENV = False

In [31]:
import os
import torch

In [32]:
# Colab-only bootstrap: fetch the tutorial repository (with its git submodules)
# into the working directory and install the PyTorch Geometric packages.
if COLLAB_ENV:
    # Clone, initialise submodules, move the checkout's contents up one level
    # and remove the leftover clone directory.
    !git clone https://github.com/ellisalicante/GraphRewiring-Tutorial
    !cd GraphRewiring-Tutorial && git submodule update --init --recursive
    !mv GraphRewiring-Tutorial/* ./
    !rm -rf GraphRewiring-Tutorial
    
    # Export the installed torch version so the shell lines below can expand
    # ${TORCH} when building the wheel-index URL.
    os.environ['TORCH'] = torch.__version__
    print(torch.__version__)
    # NOTE(review): none of these installs is version-pinned.
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
    !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [33]:
import sys
# Put the local DiffWire checkout on the module search path.
# Presumably this is what lets the later `from transforms import FeatureDegree`
# resolve — verify against the DiffWire directory layout.
sys.path.append("./DiffWire")

In [34]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA instead of failing at `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Graph Classification

### CT-Layer

To use it directly from the **[DiffWire repository](https://github.com/AdrianArnaiz/DiffWire)**:
```python
from DiffWire.layers.CT_layer import dense_CT_rewiring
```
**However, for the sake of clarity, this tutorial explains the content of that function line by line.**

<a href ="https://paperswithcode.com/method/ct-layer"> <img src="https://production-media.paperswithcode.com/methods/305a898a-e0a2-4d74-b8e8-c12839496577.png" alt="CT Layer" style="width:500px;"/> </a>

In [35]:
from DiffWire.layers.utils.ein_utils import _rank3_diag, _rank3_trace

def dense_CT_rewiring(x, adj, s, mask=None, EPS=1e-15):
    """Rewire a dense batch of graphs using a learned commute-time (CT) embedding.

    The raw embedding ``s`` is squashed with ``tanh``. Pairwise Euclidean
    distances between embedded nodes, divided by the graph volume, become the
    new weights of the edges that already exist in ``adj``.

    Args:
        x (Tensor): node features, ``[N, F]`` or ``[B, N, F]``. Only its leading
            sizes and dtype are read (to reshape/cast ``mask``).
        adj (Tensor): dense adjacency, ``[N, N]`` or ``[B, N, N]``.
        s (Tensor): raw CT embedding, ``[N, H]`` or ``[B, N, H]``
            (H: size of the latent space).
        mask (Tensor, optional): per-node mask, reshaped to ``[B, N, 1]``; the
            embedding rows are multiplied by it.
        EPS (float): small constant added to the degree matrix and to the loss
            denominator to avoid NaNs.

    Returns:
        tuple ``(rewired_adj, CT_loss, ortho_loss)``:
            rewired_adj: ``[B, N, N]``, ``cdist(s, s) / vol(G)`` multiplied
                element-wise by the input adjacency.
            CT_loss: scalar, batch mean of ``Tr(S^T L S) / Tr(S^T D S)`` with
                ``L = D - A`` built from the *input* adjacency.
            ortho_loss: scalar, batch mean of the per-graph Frobenius norm of
                ``S^T S / ||S^T S||_F - I``.
    """
    # Promote single graphs to a batch of size one.
    x = x.unsqueeze(0) if x.dim() == 2 else x          # [B, N, F]
    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj  # [B, N, N]
    s = s.unsqueeze(0) if s.dim() == 2 else s          # [B, N, H]

    # Bound the embedding to (-1, 1).
    s = torch.tanh(s)

    # Batch masking: zero the embedding rows selected out by the mask.
    # (x is not used by any computation below, so it is left untouched.)
    (batch_size, num_nodes, _), k = x.size(), s.size(-1)
    if mask is not None:
        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        s = s * mask

    # Degrees (row sums of adj) and degree matrix. Note that EPS is added to
    # every entry of whatever _rank3_diag returns, not only to its diagonal.
    d_flat = torch.einsum('ijk->ij', adj)  # [B, N]
    d = _rank3_diag(d_flat) + EPS          # presumably [B, N, N] — helper not visible here

    # Graph volume, one value per graph.
    vol = _rank3_trace(d)                  # [B]

    # --- Losses -----------------------------------------------------------
    # Laplacian L = D - A. It must use the same adjacency the degrees were
    # computed from; building it from the rewired weights (which are divided by
    # vol and therefore tiny) would make the ratio below collapse towards 1.
    L = d - adj

    s_t = s.transpose(1, 2)                                               # [B, H, N]
    CT_num = _rank3_trace(torch.matmul(torch.matmul(s_t, L), s))          # [B]
    CT_den = _rank3_trace(torch.matmul(torch.matmul(s_t, d), s)) + EPS    # [B]
    CT_loss = torch.mean(CT_num / CT_den)  # mean over the batch

    # Orthogonality regularisation: one Frobenius norm per graph, then the
    # batch mean. (A single norm over the whole [B, H, H] tensor would scale
    # with the batch size and make the mean a no-op.)
    ss = torch.matmul(s_t, s)              # [B, H, H]
    i_s = torch.eye(k).type_as(ss)         # [H, H]
    ortho_residual = ss / torch.norm(ss, dim=(-1, -2), keepdim=True) - i_s
    ortho_loss = torch.mean(torch.norm(ortho_residual, dim=(-1, -2)))

    # --- Rewiring -----------------------------------------------------------
    # Pairwise embedding distances scaled by 1 / vol(G), kept only where the
    # input graph has an edge.
    CT_dist = torch.cdist(s, s) / vol.unsqueeze(1).unsqueeze(1)           # [B, N, N]
    rewired_adj = CT_dist * adj

    return rewired_adj, CT_loss, ortho_loss

**Use $\mathtt{CT-Layer}$ for Graph Classification**
<img src="figs/ctnetwork.png" alt="CT network" style="width:300px;"/>

In [36]:
class CTNet(torch.nn.Module):
    """Graph classifier built around the CT rewiring layer.

    Pipeline (see ``forward``): linear encoder -> CT rewiring -> dense graph
    convolution -> MinCut pooling -> dense graph convolution -> sum readout ->
    two-layer MLP -> log-softmax.

    Args:
        in_channels (int): size of the input node features.
        out_channels (int): number of output classes.
        k_centers (int): width of the CT embedding produced by ``pool1``.
        hidden_channels (int): width of all hidden layers.
        EPS (float): epsilon forwarded to the CT and MinCut layers.
        mincut_centers (int): number of clusters produced by the MinCut pooling
            (``pool2``). Defaults to 16, the value that used to be hard-coded.

    NOTE(review): ``DenseGraphConv``, ``Linear``, ``to_dense_adj``,
    ``to_dense_batch``, ``dense_mincut_pool`` and ``F`` are used below but are
    not imported in any cell visible above; they must already be in scope.
    """

    def __init__(self, in_channels, out_channels, k_centers, hidden_channels=32, EPS=1e-15,
                 mincut_centers=16):
        super().__init__()

        self.EPS = EPS

        # Message-passing layers (both operate in the hidden space, after lin1).
        self.conv1 = DenseGraphConv(hidden_channels, hidden_channels)
        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)

        # Projection producing the CT embedding (k_centers columns per node).
        self.pool1 = Linear(hidden_channels, k_centers)

        # Projection producing the MinCut cluster assignment.
        self.pool2 = Linear(hidden_channels, mincut_centers)

        # lin1 encodes the raw features; lin2/lin3 form the classification head.
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Classify every graph in a sparse mini-batch.

        Returns:
            tuple ``(log_probs, CT_loss, mincut_loss)`` where ``CT_loss`` already
            includes the CT orthogonality term and ``mincut_loss`` already
            includes the MinCut orthogonality term.
        """
        # Convert the sparse batch to dense tensors.
        adj = to_dense_adj(edge_index, batch)   # [B, N, N]
        x, mask = to_dense_batch(x, batch)      # [B, N, F] and its node mask

        # Linear encoder into the hidden space.
        x = self.lin1(x)

        # Diagnostics: report NaNs early (each check reads a tensor value back).
        if torch.isnan(adj).any():
            print("adj nan")
        if torch.isnan(x).any():
            print("x nan")

        # CT rewiring: s1 is the CT embedding [B, N, k_centers]; the returned
        # adj holds the rewired edge weights [B, N, N].
        s1 = self.pool1(x)
        adj, CT_loss, ortho_loss1 = dense_CT_rewiring(x, adj, s1, mask, EPS=self.EPS)

        # First convolution, on the rewired adjacency.
        x = self.conv1(x, adj)

        # MinCut pooling: s2 is the cluster assignment [B, N, mincut_centers].
        # NOTE(review): the EPS keyword requires a dense_mincut_pool that accepts
        # it — confirm which implementation (DiffWire's or PyG's) is in scope.
        s2 = self.pool2(x)
        x, adj, mincut_loss2, ortho_loss2 = dense_mincut_pool(x, adj, s2, mask, EPS=self.EPS)

        # Second convolution, on the coarsened graph.
        x = self.conv2(x, adj)

        # Sum readout over the pooled nodes: one vector per graph.
        x = x.sum(dim=1)

        # Classification head.
        x = F.relu(self.lin2(x))
        x = self.lin3(x)

        # Fold each orthogonality regulariser into its layer's loss.
        CT_loss = CT_loss + ortho_loss1
        mincut_loss = mincut_loss2 + ortho_loss2

        return F.log_softmax(x, dim=-1), CT_loss, mincut_loss


In [37]:
def train(epoch, loader):
    """Run one optimisation pass over ``loader``.

    Relies on the notebook-level ``model``, ``optimizer`` and ``device``.
    ``epoch`` is accepted for the caller's convenience but is not used.

    Returns:
        tuple: (loss averaged over ``len(loader.dataset)``, weighting each batch
        by its number of labels; fraction of correct predictions, measured on
        each batch's forward pass before that batch's optimizer step).
    """
    model.train()
    weighted_loss_sum = 0
    num_hits = 0
    for batch in loader:
        batch = batch.to(device)
        targets = batch.y.view(-1)

        optimizer.zero_grad()
        # The model returns the log-probabilities plus two auxiliary loss terms.
        log_probs, aux_loss_a, aux_loss_b = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = F.nll_loss(log_probs, targets) + aux_loss_a + aux_loss_b
        batch_loss.backward()
        weighted_loss_sum += batch.y.size(0) * batch_loss.item()
        optimizer.step()

        predicted = log_probs.max(dim=1)[1]
        num_hits += predicted.eq(targets).sum().item()

    dataset_size = len(loader.dataset)
    return weighted_loss_sum / dataset_size, num_hits / dataset_size

@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        pred, mc_loss, o_loss = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(pred, data.y.view(-1)) + mc_loss + o_loss
        correct += pred.max(dim=1)[1].eq(data.y.view(-1)).sum().item()

    return loss, correct / len(loader.dataset)

In [39]:
from transforms import FeatureDegree
# Load the REDDIT-BINARY dataset under ./data, applying FeatureDegree as the
# per-graph transform (presumably it derives node features from node degrees —
# confirm in transforms.py).
# NOTE(review): TUDataset is not imported in any cell visible above; it must
# already be in scope for this cell to run on a fresh kernel.
dataset = TUDataset(root='data',name="REDDIT-BINARY", transform = FeatureDegree())

LinkerError: [222] Call to cuLinkAddData results in UNKNOWN_CUDA_ERROR
ptxas application ptx input, line 9; fatal   : Unsupported .version 7.3; current version is '7.1'

In [28]:
# Seed torch's RNG before shuffling the dataset.
# NOTE(review): re-running this cell reshuffles the already-shuffled dataset.
torch.manual_seed(12345)
dataset = dataset.shuffle()

# First TRAIN_SPLIT graphs are used for training, the remainder for testing.
TRAIN_SPLIT = 1500
BATCH_SIZE = 64

train_dataset = dataset[:TRAIN_SPLIT]
test_dataset = dataset[TRAIN_SPLIT:]
# Only the training loader shuffles between epochs.
# NOTE(review): DataLoader is not imported in any cell visible above.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

SyntaxError: invalid syntax (3246366582.py, line 2)

In [29]:
# Width of the CT embedding (k_centers) and the epsilon shared by the
# CT / MinCut layers.
num_of_centers = 200
EPS = 1e-15

# Build the model first: the optimizer is constructed from model.parameters(),
# so creating it before `model` exists raises NameError on a fresh kernel.
model = CTNet(dataset.num_features, dataset.num_classes, k_centers=num_of_centers, EPS=EPS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

NameError: name 'dataset' is not defined

In [None]:
# Train for 99 epochs (range(1, 100)). After each epoch, re-evaluate on both
# splits with test(), which switches the model to eval mode, and print the
# two accuracies.
for epoch in range(1, 100):
    # NOTE(review): train_loss is computed but never reported.
    train_loss, _ = train(epoch, train_loader)
    _, train_acc = test(train_loader)
    _, test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

In [None]:
# Re-create the optimizer over the current `model` with lr=0.01 and no weight
# decay, replacing the one used by the training loop above, and define a
# cross-entropy criterion.
# NOTE(review): nothing in the visible cells uses `criterion`; presumably this
# is set-up for the node-classification section that follows — confirm.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

## Node classification