In [1]:
import rootutils

rootutils.setup_root("./", indicator=".project-root", pythonpath=True)

%load_ext autoreload
%autoreload 2

import hydra
import torch
import torch_geometric
from hydra import compose, initialize
from omegaconf import OmegaConf

from topobenchmarkx.data.preprocessor import PreProcessor
from topobenchmarkx.dataloader.dataloader import TBXDataloader
from topobenchmarkx.data.loaders import GraphLoader

from topobenchmarkx.utils.config_resolvers import (
    get_default_transform,
    get_monitor_metric,
    get_monitor_mode,
    infer_in_channels,
)


initialize(config_path="../configs", job_name="job")


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path="../configs", job_name="job")


hydra.initialize()

In [2]:
cfg = compose(config_name="run.yaml", return_hydra_config=True)

In [3]:
graph_loader = GraphLoader(cfg.dataset.loader.parameters)

In [4]:
dataset, dataset_dir = graph_loader.load()

Download complete.


Extracting /home/aj4332/Research/TopoBenchmarkX/datasets/graph/nyu/LanguageDataset/raw/LanguageDataset.zip
Processing...
Done!


In [5]:
len(dataset.data.attention_scores)



2495008

In [6]:
start_index = 0
attention_scores = []
for token_list in dataset.data.tokens:
    end_index = start_index + len(token_list) * len(token_list)
    attention_scores_sentence = dataset.data.attention_scores[start_index:end_index]
    start_index = end_index
    attention_scores_sentence = torch.reshape(attention_scores_sentence, (len(token_list), len(token_list)))
    attention_scores.append(attention_scores_sentence)
tokens = dataset.data.tokens
ids = dataset.data.ids
tags = dataset.data.tags

In [7]:
num_heads = 32
num_sentences = len(ids)/num_heads

Hence we now have the following data.

The length of each attribute is num_heads*num_sentences:
1. tokens - each element is a list of length sentence_length
2. ids - each element is a list of length sentence_length
3. tags - each element is a list of length sentence_length
4. attention_scores - each element is a tensor of shape (sentence_length, sentence_length)

In [8]:
# Generate a list of tuple occurences
graph_2s = []

# the attention score threshold to consider relation between two tokens 
threshold = 0.0001

# list of tokens to avoid in relations
tokens_avoid = ['<|begin_of_text|>'.lower(), ''.lower()]

# iterate over all sentences for all attention heads
for sentence in range(len(tokens)):
    
    current_attention = attention_scores[sentence]
    current_tokens = tokens[sentence]
    
    for row in range(len(current_attention)):
        
        for col in range(0,row+1):
            
            if row != col and current_attention[row][col] >= threshold:
                word1 = current_tokens[row].lower().strip()
                word2 = current_tokens[col].lower().strip()
                
                # Skip tokens that are empty or beginning of text indicators
                if word1 in tokens_avoid or word2 in tokens_avoid or word1.isnumeric() or word2.isnumeric():
                    continue
                relation = ()
                
                # Create an ordered tuple for consistency
                if word1 < word2:
                    relation = (word1,word2)
                else:
                    relation = (word2,word1)

                graph_2s.append(relation)

In [9]:
print(len(graph_2s))

1001569


In [10]:
import collections
counter=collections.Counter(graph_2s)
# print(collections.OrderedDict(sorted(counter.items())))

In [11]:
attention_scores[0]

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [4.6254e-01, 5.3746e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [4.2101e-01, 9.7475e-02, 4.8151e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [3.7420e-01, 3.1179e-02, 1.3714e-01, 4.5748e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
       

In [12]:
# Generate a list of tuple occurences
attention_head = 0
graph_2s_head = []

# the attention score threshold to consider relation between two tokens 
threshold = 0.0001

# list of tokens to avoid in relations
tokens_avoid = ['<|begin_of_text|>'.lower(), ''.lower()]

# iterate over all sentences for the first attention head
for sentence in range(attention_head,len(tokens), num_heads):
    
    current_attention = attention_scores[sentence]
    current_tokens = tokens[sentence]
    
    for row in range(len(current_attention)):
        
        for col in range(0,row+1):
            
            if row != col and current_attention[row][col] >= threshold:
                word1 = current_tokens[row].lower().strip()
                word2 = current_tokens[col].lower().strip()
                
                # Skip tokens that are empty or beginning of text indicators
                if word1 in tokens_avoid or word2 in tokens_avoid or word1.isnumeric() or word2.isnumeric():
                    continue
                relation = ()
                
                # Create an ordered tuple for consistency
                if word1 < word2:
                    relation = (word1,word2)
                else:
                    relation = (word2,word1)

                graph_2s_head.append(relation)

In [13]:
print(len(graph_2s_head))

32672


In [14]:
import collections
counter2=collections.Counter(graph_2s_head)
# print(collections.OrderedDict(sorted(counter2.items())))