In [1]:
import torch
import torch.nn as nn
from torch.distributions import Categorical

import torch_geometric
import torch_geometric.data as gd
import torch_geometric.nn as gnn
from torch_geometric.utils import from_smiles

import torch_scatter
from torch_scatter import scatter_log_softmax, scatter_max, scatter_sum

# Report the library versions this notebook was executed with.
for label, module in (
    ("torch\t", torch),
    ("torch_scatter\t", torch_scatter),
    ("torch_geometric\t", torch_geometric),
):
    print(label, module.__version__)

torch	 2.0.1
torch_scatter	 2.1.1
torch_geometric	 2.3.1


# Scatter Operation Using `torch_scatter`


Unlike images, text and audio, graphs usually have irregular structures, which makes them hard to batch in tensor frameworks. Many existing implementations pad graphs into dense grid structures, which wastes a lot of computation and memory.


This notebook shows how to handle variable-sized inputs with `torch_scatter`, without padding.


The figure below illustrates how a `torch_scatter` operation works.


<img width="50%" src="https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true" style="background-color:white;padding:20px;">

In [2]:
# `torch` and `scatter_sum` come from the import cell at the top of the notebook.
# Values of `input_` that share the same entry in `index` are summed together:
#   group 0 -> 5 + 1 + 2, group 1 -> 7, group 2 -> 3 + 2, group 3 -> 1 + 3
index = torch.tensor([0, 0, 1, 0, 2, 2, 3, 3], dtype=torch.long)
input_ = torch.tensor([5, 1, 7, 2, 3, 2, 1, 3], dtype=torch.long)

output = scatter_sum(input_, index)
output

tensor([8, 7, 5, 4])

# torch.distributions.Categorical

We want to use something similar to `torch.distributions.Categorical`.

But `torch.distributions.Categorical` can only take fixed-size tensors.

Let's see how `Categorical` works first.

In [3]:
# Build the batch `g` of two small molecules that serves as the example data below.
example_smiles = [
    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # caffeine
    "CC(=O)NC1=CC=C(C=C1)O",  # acetaminophen
]
data_list = [from_smiles(smiles) for smiles in example_smiles]
d1, d2 = data_list

g = gd.Batch.from_data_list(data_list)

# Store the number of edges of each graph on the batch itself; a later cell
# uses these counts to assign every edge to its graph.
edge_counts = [d.num_edges for d in data_list]
g.num_edges = torch.tensor(edge_counts, dtype=torch.long)

g

DataBatch(x=[25, 9], edge_index=[2, 52], edge_attr=[52, 3], smiles=[2], batch=[25], ptr=[3], num_edges=[2])

In [4]:
class GCNPolicy1(nn.Module):
    """Graph-level policy with a fixed number of choices per graph.

    Two GCN layers embed the nodes, the node embeddings are summed per graph,
    and a linear head turns each graph embedding into `output_dim` logits.
    """

    def __init__(self, input_dim, output_dim, emb_dim=64):
        """Create the layers.

        Args:
            input_dim: size of the input node features.
            output_dim: number of categories predicted per graph.
            emb_dim: hidden size of the GCN layers.
        """
        super().__init__()
        self.conv1 = gnn.GCNConv(input_dim, emb_dim)
        self.conv2 = gnn.GCNConv(emb_dim, emb_dim)
        self.linear = nn.Linear(emb_dim, output_dim)

    def forward(self, g):
        """Return a `Categorical` whose logits have one row per graph in `g`."""
        edge_index = g.edge_index
        h = g.x.float()
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index).relu()

        pooled = gnn.global_add_pool(h, g.batch)

        # Example use case: choosing which atom type to add to each input graph.
        # The number of atom types is fixed, so every graph gets the same number
        # of logits and a plain `Categorical` is sufficient.
        return Categorical(logits=self.linear(pooled))




# Node feature size comes from the example batch; 5 output categories per graph.
num_node_features = g.x.size(1)
gcn = GCNPolicy1(input_dim=num_node_features, output_dim=5)

cat = gcn(g)
cat

Categorical(logits: torch.Size([2, 5]))

In [5]:
# Draw one category per graph from the distribution defined by the logits.
cat.sample(sample_shape=torch.Size())

tensor([3, 3])

In [6]:
# Log-probability of a given target category for each graph.
y = torch.tensor([1, 0], dtype=torch.long)
cat.log_prob(y)

tensor([-5.9723, -4.1078], grad_fn=<SqueezeBackward1>)

# Variable-sized Logits

In [7]:
class GCNPolicy2(nn.Module):
    """Policy that scores every graph, node and edge of a batch with one logit each.

    Because graphs differ in their numbers of nodes and edges, the number of
    logits belonging to each graph varies.
    """

    def __init__(self, input_dim, emb_dim=64):
        """Create the layers.

        Args:
            input_dim: size of the input node features.
            emb_dim: hidden size of the GCN layers.
        """
        super().__init__()
        self.conv1 = gnn.GCNConv(input_dim, emb_dim)
        self.conv2 = gnn.GCNConv(emb_dim, emb_dim)

        # One scalar-output head per kind of element.
        self.glob_mlp = nn.Linear(emb_dim, 1)
        self.node_mlp = nn.Linear(emb_dim, 1)
        self.edge_mlp = nn.Linear(emb_dim, 1)

    def logits(self, g):
        """Return a 1-D tensor: graph logits, then node logits, then edge logits."""
        edge_index = g.edge_index
        h = g.x.float()
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index).relu()

        pooled = gnn.global_add_pool(h, g.batch)

        # An edge is represented by the sum of the embeddings of its two endpoints.
        src, dst = edge_index
        edge_repr = h[src] + h[dst]

        heads = (
            (self.glob_mlp, pooled),
            (self.node_mlp, h),
            (self.edge_mlp, edge_repr),
        )
        return torch.cat([mlp(features).flatten() for mlp, features in heads])
        
# One logit per graph, node and edge of the example batch.
gcn = GCNPolicy2(input_dim=g.x.size(1))
logits = gcn.logits(g)

logits.size()

torch.Size([79])

In [8]:
# For every logit, record which graph it belongs to, in the same order in which
# the logits were concatenated: graph-level, node-level, edge-level.
graph_ids = torch.arange(g.num_graphs)

glob_batch = graph_ids
node_batch = g.batch
# `g.num_edges` holds the per-graph edge counts, so each graph id is repeated
# once per edge of that graph.
edge_batch = graph_ids.repeat_interleave(g.num_edges)

indices = torch.cat([glob_batch, node_batch, edge_batch])

indices

tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1])

In [9]:
# Per-graph log-probabilities are easy to compute from the unsorted logits.
# The open question: how do we pick out the entries that correspond to the targets `y`?

y = torch.tensor([2, 13], dtype=torch.long)
scatter_log_softmax(src=logits, index=indices)

tensor([-2.7812, -2.8636, -3.7907, -3.9286, -3.9379, -3.9802, -3.9975, -4.0006,
        -4.0072, -3.9304, -3.9490, -3.9848, -3.9330, -3.9475, -3.7976, -3.7976,
        -3.5355, -3.6587, -3.6289, -3.6437, -3.6923, -3.6548, -3.6645, -3.7142,
        -3.6645, -3.6548, -3.6392, -3.6885, -3.6885, -3.7759, -3.8233, -3.7759,
        -3.8174, -3.8174, -3.8584, -3.8584, -3.8744, -3.8373, -3.8233, -3.8744,
        -3.8898, -3.8898, -3.8590, -3.8570, -3.8590, -3.8570, -3.8422, -3.7204,
        -3.8422, -3.8459, -3.8379, -3.8459, -3.8373, -3.8379, -3.7160, -3.7160,
        -3.7204, -3.4589, -3.4589, -3.5222, -3.5206, -3.5222, -3.5206, -3.5231,
        -3.5231, -3.5156, -3.5156, -3.5156, -3.4954, -3.4954, -3.5286, -3.5286,
        -3.5286, -3.5239, -3.5286, -3.4954, -3.5156, -3.4954, -3.5239],
       grad_fn=<SubBackward0>)

In [10]:
# Step 1: reorder so that all logits of the same graph are contiguous.
# A stable sort keeps the original relative order within each graph.
sorted_indices, mapping = indices.sort(stable=True)
sorted_logits = logits[mapping]

log_probs = scatter_log_softmax(src=sorted_logits, index=sorted_indices)

sorted_indices

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1])

In [11]:
# Step 2: shift each per-graph target by the position where that graph's
# segment starts in the sorted tensor (exclusive cumulative sum of the sizes).
count = sorted_indices.bincount()
offsets = count.cumsum(dim=0) - count

log_probs[offsets + y]

tensor([-3.9286, -3.4589], grad_fn=<IndexBackward0>)

# Gumbel-max trick

`logits`에서 샘플링을 하고 싶을 때, softmax 값을 직접 구하지 않아도 되는 방법이 있다.

아래 식과 같이 `X`를 샘플링을 하면, 이는 softmax에서 샘플링을 한 것과 같다.

```
u ~ uniform(len(logits))
X = argmax(logits - log(-log(u)))
```

이를 Gumbel-max trick이라 부르는데, 그 이유는 `-log(-log(u))`가 Gumbel 분포를 따르기 때문이다.

*참고*

    Gumbel 분포의 CDF는 다음과 같으며,

$$
F(x) = \exp\{- \exp\{-x\}\}.
$$

    CDF의 역함수는 다음과 같다

$$
F^{-1}(y) = -\log(-\log y).
$$



- Gumbel-max trick 증명

    https://lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/

In [12]:
# Sample one element per group directly from the logits with the Gumbel-max trick.

logits = torch.tensor([0, 1, 10, 2, 1, 0], dtype=torch.float)
indices = torch.tensor([1, 1, 1, 1, 0, 0], dtype=torch.long)

# Make the elements of each group contiguous; the stable sort keeps their
# original relative order inside every group.
indices, mapping = torch.sort(indices, stable=True)
logits = logits[mapping]

# `torch.rand_like` samples from [0, 1), so u can be exactly 0. That would give
# -log(u) = inf and a Gumbel noise of -inf, so clamp u to the smallest positive
# normal value of the dtype before taking logs.
unif = torch.rand_like(logits).clamp_min(torch.finfo(logits.dtype).tiny)
gumbel = -(-unif.log()).log()

# The argmax of (logits + Gumbel noise) within each group is a softmax sample.
_, max_indices = scatter_max(logits + gumbel, indices)

# `max_indices` are positions in the sorted tensor; subtract the start of each
# group's segment to obtain positions relative to the group.
count = torch.bincount(indices)
offsets = torch.cumsum(count, 0) - count
samples = max_indices - offsets

samples

tensor([0, 2])

# Question
    
1. `torch.rand_like(logits)`와 `torch.rand(len(logits))`의 차이는?