In [2]:
import torch_scatter
from torch_scatter import scatter_log_softmax, scatter_max, scatter_sum


import torch_geometric
import torch_geometric.data as gd
import torch_geometric.nn as gnn
from torch_geometric.utils import from_smiles

import torch
from torch.distributions import Categorical
import torch.nn as nn

# Record library versions so the outputs below can be matched to an environment.
for _label, _module in (("torch\t", torch),
                        ("torch_scatter\t", torch_scatter),
                        ("torch_geometric\t", torch_geometric)):
    print(_label, _module.__version__)

torch	 2.0.1
torch_scatter	 2.1.1
torch_geometric	 2.3.1


# Scatter Operation Using `torch_scatter`


Unlike images, text and audio, graphs usually have irregular structures (each graph has a different number of nodes and edges), which makes them hard to batch in tensor frameworks. Many existing implementations pad graphs into dense grid structures, which wastes computation and memory on the padding.


This notebook shows how `torch_scatter` lets us work with such variable-sized inputs directly, without padding.


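The figure below shows how `scatter_sum` (also called `scatter_add`) works: each element of the source tensor is added to the output slot selected by the corresponding entry of `index`.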


<img width="50%" src="https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true" style="background-color:white;padding:20px;">

In [3]:
# Toy example: scatter_sum adds every entry of `input_` into the output slot
# selected by the matching entry of `index`, i.e. output[k] = sum(input_[index == k]).
# (torch and scatter_sum are already imported in the first cell.)
index  = torch.tensor([0, 0, 1, 0, 2, 2, 3, 3])
input_ = torch.tensor([5, 1, 7, 2, 3, 2, 1, 3])

output = scatter_sum(input_, index)
# Summing into groups only redistributes values, so the grand total is unchanged.
assert output.sum() == input_.sum()
output

tensor([8, 7, 5, 4])

In [10]:
# PyG's global_add_pool sums node values per graph id, which gives the same
# result as scatter_sum(input_, index) above (compare the two outputs).
gnn.global_add_pool(x=input_, batch=index)

tensor([8, 7, 5, 4])

# Variable-sized Logits

In [5]:

class GCNPolicy(nn.Module):
    """Two-layer GCN that scores graph-, node- and edge-level actions.

    For a batch of graphs, ``logits`` returns one flat tensor laid out as
    ``[graph logits | node logits | edge logits]``: first one logit per graph,
    then one per node of the whole batch, then one per edge of the whole batch.
    Grouped operations such as ``scatter_log_softmax`` need a per-logit graph
    index built in exactly this order.

    Args:
        input_dim: number of input node features (e.g. ``g.x.shape[1]``).
        emb_dim: embedding size produced by both GCN layers.
    """

    def __init__(self, input_dim, emb_dim=64):
        super().__init__()
        self.conv1 = gnn.GCNConv(input_dim, emb_dim)
        self.conv2 = gnn.GCNConv(emb_dim, emb_dim)

        # Each head is a single linear layer mapping an embedding to one score.
        self.glob_mlp = nn.Linear(emb_dim, 1)
        self.node_mlp = nn.Linear(emb_dim, 1)
        self.edge_mlp = nn.Linear(emb_dim, 1)

    def logits(self, g):
        """Return the flat ``[graph | node | edge]`` logit vector for ``g``.

        Args:
            g: a PyG ``Data``/``Batch`` exposing ``x``, ``edge_index`` and
               ``batch`` (``batch`` may be None for a single graph).

        Returns:
            1-D tensor with one entry per graph, then per node, then per edge.
        """
        x, edge_index = g.x.float(), g.edge_index
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()

        # Pass the number of graphs explicitly: without `size`, the pool infers
        # it from batch.max() + 1, so a trailing graph with no nodes would get
        # no graph-level logit. A plain Data has no `num_graphs` -> None.
        glob = gnn.global_add_pool(x, g.batch, size=getattr(g, "num_graphs", None))

        # Edge embedding = sum of its endpoint embeddings. The sum is symmetric,
        # so edges (i, j) and (j, i) receive identical logits.
        i, j = edge_index
        edge_feature = x[i] + x[j]

        glob_logits = self.glob_mlp(glob).flatten()
        node_logits = self.node_mlp(x).flatten()
        edge_logits = self.edge_mlp(edge_feature).flatten()

        return torch.cat([glob_logits, node_logits, edge_logits])

    def forward(self, g):
        """Alias for :meth:`logits` so the module can be called as ``policy(g)``."""
        return self.logits(g)


# Two small molecules batched into one disjoint PyG graph.
d1 = from_smiles("CN1C=NC2=C1C(=O)N(C(=O)N2C)C")  # caffeine
d2 = from_smiles("CC(=O)NC1=CC=C(C=C1)O")  # acetaminophen

data_list = [d1, d2]
g = gd.Batch.from_data_list(data_list)
# Per-graph edge counts, read back as `g.num_edges` by a later cell.
# NOTE(review): this assigns over the name of PyG's built-in `num_edges`
# attribute; confirm it reads back as this tensor in the installed PyG
# version. A dedicated attribute name would be safer.
g.num_edges = torch.tensor([d.num_edges for d in data_list])

# Seed before building the model: its weights are randomly initialised, so
# without a seed every logit / log-probability printed below changes per run.
torch.manual_seed(0)
gcn = GCNPolicy(g.x.shape[1])
logits = gcn.logits(g)

logits.shape

torch.Size([79])

In [8]:
# For every entry of `logits`, the id of the graph it belongs to. The
# concatenation order must mirror GCNPolicy.logits: [graph | node | edge].
glob_batch = torch.arange(g.num_graphs, device=g.batch.device)  # one logit per graph
node_batch = g.batch                                             # graph id of each node
# Each edge belongs to the graph of its source node. Deriving this from
# edge_index does not depend on the per-graph counts stored in `g.num_edges`.
edge_batch = g.batch[g.edge_index[0]]

indices = torch.cat([glob_batch, node_batch, edge_batch])
assert indices.shape == logits.shape, "indices must align one-to-one with logits"

indices

tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1])

In [13]:
# scatter_log_softmax normalises the logits separately within each graph
# (grouped by `indices`), giving per-graph log-probabilities.
# Open question: how do we pick out the log-probability of each graph's label?
y = torch.tensor([2, 13], dtype=torch.long)  # one target position per graph
scatter_log_softmax(logits, indices)

tensor([-8.1225, -6.6523, -3.7698, -3.8608, -3.8230, -3.8196, -3.8377, -3.8532,
        -3.8855, -3.8294, -3.8896, -3.8974, -3.8367, -3.8742, -3.7648, -3.7648,
        -3.5090, -3.6289, -3.5678, -3.5616, -3.5614, -3.5216, -3.5402, -3.6160,
        -3.5402, -3.5216, -3.5569, -3.7128, -3.7128, -3.7558, -3.7766, -3.7558,
        -3.7646, -3.7646, -3.7853, -3.7853, -3.7949, -3.7774, -3.7766, -3.7949,
        -3.7953, -3.7953, -3.7666, -3.7776, -3.7666, -3.7776, -3.7694, -3.7126,
        -3.7694, -3.7593, -3.7696, -3.7593, -3.7774, -3.7696, -3.7128, -3.7128,
        -3.7126, -3.4040, -3.4040, -3.4231, -3.4430, -3.4231, -3.4430, -3.4777,
        -3.4777, -3.4923, -3.4923, -3.4923, -3.4783, -3.4783, -3.5013, -3.5013,
        -3.5013, -3.4821, -3.5013, -3.4783, -3.4923, -3.4783, -3.4821],
       grad_fn=<SubBackward0>)

In [14]:
# Make each graph's logits contiguous. The sort must be *stable* so that,
# within a graph, entries keep their original [graph | nodes | edges] order.
sorted_indices, mapping = indices.sort(stable=True)
sorted_logits = logits[mapping]

sorted_indices

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1])

In [15]:
# Turn each graph-local label y[b] into a position in the sorted, flat logits
# by adding the start offset of graph b (an exclusive cumulative count).
# minlength guarantees one count per label even if a trailing graph is empty.
count = torch.bincount(sorted_indices, minlength=y.numel())
offsets = torch.cumsum(count, 0) - count
# An out-of-range label would silently point into a different graph's logits.
assert ((y >= 0) & (y < count)).all(), "label out of range for its graph"

y + offsets

tensor([ 2, 58])

In [16]:
# Normalise within each graph, then pick each graph's target entry.
log_probs = scatter_log_softmax(sorted_logits, sorted_indices)
target_positions = y + offsets
log_probs[target_positions]

tensor([-3.8608, -3.4040], grad_fn=<IndexBackward0>)

# Gumbel-max trick

Suppose we want to draw one sample from each of several categorical distributions, whose parameters are given by variable-sized groups of logits.

With the Gumbel-max trick, we can sample without ever computing the softmax:

```
u_i ~ Uniform(0, 1)   independently for each logit i
X   = argmax_i (logits_i - log(-log(u_i)))
```

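The noise added to the logits, `-log(-log(u))`, follows the standard Gumbel distribution, which gives the trick its name. It is the Gumbel inverse CDF (below) applied to a uniform random variable, i.e. inverse transform sampling.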

*Note*

CDF of the standard Gumbel distribution:

$$ 
F(x) = \exp\{- \exp\{-x\}\}.
$$

Inverse CDF of the standard Gumbel distribution:

$$
F^{-1}(y) = -\log(-\log y).
$$



Proof that $X$ is distributed according to $\mathrm{softmax}(\text{logits})$: https://lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/

In [17]:
# Draw one sample per group with the Gumbel-max trick on a toy example.
# The toy_* names keep this cell from overwriting `logits`, `indices`,
# `mapping`, `count` and `offsets` from the cells above, so those cells still
# give the same results if they are re-run afterwards.
torch.manual_seed(0)  # reproducible noise

toy_logits = torch.tensor([0., 1., 10., 2., 1., 0.])
toy_indices = torch.tensor([1, 1, 1, 1, 0, 0])

# Stable sort so each group's logits are contiguous and keep their order.
toy_indices, toy_mapping = torch.sort(toy_indices, stable=True)
toy_logits = toy_logits[toy_mapping]

# Standard Gumbel noise via the inverse CDF: -log(-log(u)), u ~ U[0, 1).
# torch.rand can return exactly 0, which would give -inf noise; clamp it away.
unif = torch.rand_like(toy_logits).clamp_min(torch.finfo(toy_logits.dtype).tiny)
gumbel = -(-unif.log()).log()
# The argmax of the perturbed logits within each group is a sample from that
# group's softmax; scatter_max returns its position in the flat tensor.
_, max_indices = scatter_max(toy_logits + gumbel, toy_indices)

# Convert flat positions into positions within each group.
toy_count = torch.bincount(toy_indices)
toy_offsets = torch.cumsum(toy_count, 0) - toy_count
samples = max_indices - toy_offsets

samples

tensor([1, 2])

# Question
    
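1. `torch.rand_like(logits)`와 `torch.rand(len(logits))`의 차이는? (What is the difference between `torch.rand_like(logits)` and `torch.rand(len(logits))`?)
   - `rand_like` copies the shape, dtype and device of `logits`. `rand(len(logits))` always builds a 1-D tensor of the default dtype on the default device (CPU unless changed), so the two only coincide when `logits` is a 1-D, default-dtype tensor on the default device.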