In [1]:
# Environment setup: imports, one RMM-backed GPU allocator shared by CuPy and
# PyTorch, and the locally checked-out pybitgraph / gremlin++ bindings.
import argparse  # NOTE(review): appears unused; configuration lives in the `args` dict further down
import sys, os
import re
import warnings

import numpy as np

import rmm
from rmm.allocators.torch import rmm_torch_allocator
from rmm.allocators.cupy import rmm_cupy_allocator

# Initialize a single RMM memory resource on device 0. CuPy and PyTorch are
# pointed at it below, so both libraries allocate through the same allocator.
# NOTE(review): pool_allocator=False means no memory pool is created; confirm
# this still meets the original "prevent fragmentation" intent.
rmm.reinitialize(devices=0, pool_allocator=False, managed_memory=False)

import cupy
cupy.cuda.set_allocator(rmm_cupy_allocator)

import torch
# PyTorch only allows replacing its CUDA allocator before it has allocated any
# device memory, which is why this call sits right after the import.
torch.cuda.change_current_allocator(rmm_torch_allocator)

import cudf

# Local (non-pip) packages: the BitGraph bindings and the gremlin++ traversal API.
sys.path.append('/mnt/bitgraph')
sys.path.append('/mnt/gremlin++')
from pybitgraph import BitGraph

from preprocess import Sentence_Transformer, Word2Vec_Transformer
from transformers import AutoModel, AutoTokenizer
# Let float32 matmuls use faster reduced-internal-precision kernels (e.g. TF32) where available.
torch.set_float32_matmul_precision('high')

def read_wiki_data(fname, skip_empty=True):
    """Load the hyperlinked-paragraph JSONL dump and build the graph edge list.

    Vertex numbering used throughout this notebook:

    * ``0 .. len(df) - 1`` -- articles, in file row order.
    * ``len(df) + k`` -- the k-th row of ``df.sentences.explode()``. An article
      with an empty ``sentences`` list still reserves one slot (``slens`` is
      forced to at least 1), so the per-article ``sentence_offsets`` stay
      aligned with the exploded rows.

    Two edge families are returned; each column is ``[source, destination]``:

    * mention edges: sentence vertex -> article referenced by the mention's
      first ``ref_ids`` entry;
    * membership edges: sentence vertex -> article containing the sentence.

    Parameters
    ----------
    fname : str
        Path of the JSON-lines file to read.
    skip_empty : bool, default True
        Drop mention edges whose referenced id is not an article in ``fname``.
        Such ids have no vertex (and therefore no embedding). With False,
        every mention is looked up as-is: an id missing from the file maps to
        whatever article sits at its sorted insertion point, or indexes out of
        range if it sorts past the last id.

    Returns
    -------
    eix : torch.Tensor
        ``(2, E)`` CUDA tensor of edges, mention edges first.
    titles : pandas.Series
        Article titles, positionally aligned with article vertex ids.
    sentences : pandas.Series
        Exploded sentence text, positionally aligned with ``vertex id - len(df)``.
    """
    # Honour ``fname``: the path used to be hard-coded here, ignoring the argument.
    df = cudf.read_json(fname, lines=True)

    # One row per mention; keep only mentions that have both a sentence index
    # and at least one referenced article id.
    mentions = df.mentions.explode()
    mentions = mentions[~mentions.struct.field('sent_idx').isna()]
    mentions = mentions[~mentions.struct.field('ref_ids').isna()]

    # First sentence slot of each article. Empty articles still reserve a slot.
    slens = df.sentences.list.len().astype('int64')
    slens[(slens==0)] = 1
    df['sentence_offsets'] = cupy.concatenate([
        cupy.array([0]),
        slens.cumsum().values[:-1]
    ])

    # ---- mention edges: sentence vertex -> referenced article row ----
    mix = torch.as_tensor(
        mentions.struct.field('ref_ids').list.get(0).astype('int64').values,
        device='cuda'
    )
    ids = torch.as_tensor(df.id.astype('int64').values, device='cuda')
    vals, inds = torch.sort(ids)

    sources_m = torch.as_tensor(
        mentions.struct.field('sent_idx').values + df.sentence_offsets[mentions.index].values + len(df),
        device='cuda'
    )

    # searchsorted returns an insertion point, not a match, so an id absent
    # from the file would silently land on a neighbouring article. The old
    # filter ``destinations_m < len(df)`` could never be False (every entry
    # of ``inds`` is a row position), so it never removed anything.
    pos = torch.searchsorted(vals, mix)
    if skip_empty:
        n_ids = vals.shape[0]
        found = pos < n_ids
        if n_ids > 0:
            found &= vals[pos.clamp(max=n_ids - 1)] == mix
        pos = pos[found]
        sources_m = sources_m[found]
        del found
    destinations_m = inds[pos]

    eim = torch.stack([sources_m, destinations_m])

    # ---- membership edges: sentence vertex -> owning article row ----
    sentences = df.sentences.explode().reset_index().rename({"index": 'article'},axis=1)

    sources_s = sentences.index.values + len(df)
    destinations_s = sentences.article.values
    eis = torch.stack([
        torch.as_tensor(sources_s, device='cuda'),
        torch.as_tensor(destinations_s, device='cuda'),
    ])

    eix = torch.concatenate([eim,eis],axis=1)
    del eis
    del eim

    return eix, df.title.to_pandas(), sentences.sentences.to_pandas()


def read_embeddings(graph, directory, td, device='cuda'):
    """Stream pre-computed embedding shards into ``graph`` as vertex property 'emb'.

    Shards are read from ``<directory>/titles`` first and then from
    ``<directory>/sentences``. Within each sub-directory they are ordered
    numerically by the two integers in ``part_<a>_<b>.pt``. Vertex ids are
    assigned consecutively across both sub-directories, so title rows take the
    lowest ids, which is intended to match the article-then-sentence vertex
    layout built by ``read_wiki_data``.

    Files whose names are not exactly ``part_<a>_<b>.pt`` are skipped instead
    of crashing the sort.

    Parameters
    ----------
    graph : object
        Must expose ``set_vertex_embeddings(name, first, last, tensor)`` where
        ``first``/``last`` form an inclusive vertex-id range.
    directory : str
        Root folder containing ``titles/`` and ``sentences/``.
    td : int
        Embedding width; each shard is reshaped to ``(-1, td)``.
    device : str or torch.device, default 'cuda'
        ``map_location`` passed to ``torch.load``.
    """
    # Escape the dot and match the whole name so that e.g. "part_1_2.pt.bak"
    # or "part_1_2xpt" are not mistaken for shards.
    shard_re = re.compile(r'part_([0-9]+)_([0-9]+)\.pt')

    def shard_key(name):
        m = shard_re.fullmatch(name)
        return int(m[1]), int(m[2])

    ix = 0

    for emb_type in ['titles', 'sentences']:
        path = os.path.join(directory, emb_type)
        # Ignore stray files (READMEs, editor backups, ...) instead of letting
        # shard_key fail on a non-matching name.
        files = sorted(
            (f for f in os.listdir(path) if shard_re.fullmatch(f)),
            key=shard_key,
        )
        for f in files:
            e = torch.load(os.path.join(path, f), weights_only=True, map_location=device).reshape((-1, td))

            print(ix, e.shape)
            # Inclusive end id: rows ix .. ix + n - 1.
            graph.set_vertex_embeddings('emb', ix, ix + e.shape[0] - 1, e)

            ix += e.shape[0]
            del e


def getem_roberta(model, tokenizer, text, max_tokens=512):
    """Embed ``text`` with ``model``, trimming the text until it fits the encoder.

    The text is shortened 10 characters at a time until the tokenizer yields
    at most ``max_tokens`` tokens (512 is the usual RoBERTa input limit).
    Trimming stops early if the text becomes empty. The result of
    ``model(input_ids, attention_mask)`` is returned.

    NOTE(review): the tokenizer returns CPU tensors; whether ``model`` moves
    them to its own device depends on the wrapped module -- confirm.
    """
    t = tokenizer(text, return_tensors='pt')
    # The original loop sliced an undefined name ``a`` (NameError on any
    # over-long input); trim ``text`` itself. ``and text`` guarantees
    # termination even if an empty string still exceeds ``max_tokens``.
    while t.input_ids.shape[1] > max_tokens and text:
        text = text[:-10]
        t = tokenizer(text, return_tensors='pt')
    return model(t.input_ids, t.attention_mask)


def getem_w2v(model, text):
    """Embed ``text`` by invoking the word2vec wrapper ``model`` on it directly."""
    embedding = model(text)
    return embedding


# Run configuration (edit here rather than passing CLI flags).
args = {
    'skip_empty_vertices': True,
    'property_storage': 'managed',
    'fname': '/mnt/para_with_hyperlink.jsonl',
    'embeddings_dir': '/mnt/bitgraph/data/rag/w2v/',
    'embedding_type': 'w2v',
    'w2v_path': '/mnt/GoogleNews-vectors-negative300.bin.gz',
}

# Embedding width of each supported encoder. The type is validated before any
# data is loaded: a typo now fails immediately. Previously it failed only
# after the slow graph and embedding load, which had silently used width 1024.
EMBEDDING_DIMS = {'w2v': 300, 'roberta': 1024}
if args['embedding_type'] not in EMBEDDING_DIMS:
    raise ValueError("Expected 'w2v' or 'roberta' for embedding type")

eix, titles, sentences = read_wiki_data(
    args['fname'],
    args['skip_empty_vertices']
)
print('read wiki data')

# Positional arguments are forwarded unchanged to BitGraph (see pybitgraph).
graph = BitGraph(
    'int64',
    'int64',
    'DEVICE',
    'DEVICE',
    args['property_storage'].upper(),
)

# Vertex count = largest vertex id referenced by any edge, plus one.
graph.add_vertices(eix.max() + 1)
graph.add_edges(eix[0], eix[1], 'link')

read_embeddings(
    graph,
    args['embeddings_dir'],
    td=EMBEDDING_DIMS[args['embedding_type']],
)
print('read embeddings into graph')

g = graph.traversal()
print('constructed graph')

# Query-time encoder; it should produce vectors in the same space as the
# embeddings loaded above, since queries are matched against them.
if args['embedding_type'] == 'w2v':
    import gensim
    warnings.warn("Word2Vec encoder is for testing/debugging purposes only!")
    module = Word2Vec_Transformer(
        gensim.models.KeyedVectors.load_word2vec_format(args['w2v_path'], binary=True),
        dim=EMBEDDING_DIMS['w2v'],
    )
    getem = lambda t : getem_w2v(module, t)
elif args['embedding_type'] == 'roberta':
    model = AutoModel.from_pretrained('sentence-transformers/all-roberta-large-v1')
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-roberta-large-v1')

    mod = Sentence_Transformer(model).cuda()
    import torch._dynamo
    torch._dynamo.reset()

    module = torch.compile(mod, fullgraph=True)
    getem = lambda t : getem_roberta(module, tokenizer, t)
else:
    raise ValueError("Expected 'w2v' or 'roberta' for embedding type")


  from .autonotebook import tqdm as notebook_tqdm


read wiki data
0 torch.Size([1, 300])
1 torch.Size([1000000, 300])
1000001 torch.Size([1000000, 300])
2000001 torch.Size([994923, 300])
2994924 torch.Size([1, 300])
2994925 torch.Size([1000000, 300])
3994925 torch.Size([1000000, 300])
4994925 torch.Size([994922, 300])
5989847 torch.Size([1, 300])
5989848 torch.Size([1000000, 300])
6989848 torch.Size([1000000, 300])
7989848 torch.Size([1000000, 300])
8989848 torch.Size([1000000, 300])
9989848 torch.Size([1000000, 300])
10989848 torch.Size([1000000, 300])
11989848 torch.Size([1000000, 300])
12989848 torch.Size([1000000, 300])
13989848 torch.Size([1000000, 300])
14989848 torch.Size([1000000, 300])
15989848 torch.Size([1000000, 300])
16989848 torch.Size([797125, 300])
17786973 torch.Size([1, 300])
17786974 torch.Size([1000000, 300])
18786974 torch.Size([1000000, 300])
19786974 torch.Size([1000000, 300])
20786974 torch.Size([1000000, 300])
21786974 torch.Size([1000000, 300])
22786974 torch.Size([1000000, 300])
23786974 torch.Size([1000000, 



In [2]:
def query(search_query, lim=4):
    """Print the articles and sentences among the ``lim`` vertices nearest to ``search_query``.

    Vertex ids below ``len(titles)`` are articles. The remaining ids are
    sentences, shifted by ``len(titles)``.
    """
    n_articles = len(titles)
    query_vec = getem(search_query)
    nearest = g.V().like('emb', [query_vec], lim).toArray()

    is_article = nearest < n_articles
    article_rows = nearest[is_article]
    sentence_rows = nearest[~is_article] - n_articles

    print('articles:', titles.iloc[article_rows.get()])
    print('sentences:', sentences.iloc[sentence_rows.get()])


In [None]:
import pandas
truth_df = pandas.read_json('/mnt/data/train.json')
truth_df

In [None]:
truth_df.question.iloc[2]

In [None]:
[z[0] for z in truth_df.supporting_facts.iloc[2]]

In [None]:
'\n'.join([' '.join(z[1]) for z in truth_df.context.iloc[2]])

In [None]:
query(truth_df.question.iloc[167453])

In [13]:
g.V().like('emb', [getem("Miley Cyrus")], 1).toArray()

array([2902052])

In [9]:
g.V().like('emb', [getem("Miley Cyrus")], 1)._in().count().toArray()

array([327], dtype=uint64)

In [26]:
sentences.iloc[g.V().like('emb', [getem("Pumpin ' Up The Party")], 1)._in().toArray().get() - len(titles)]

7600544    "Pumpin' Up the Party" is a pop song by Americ...
7600545    She is performing as Hannah Montana – the alte...
7600546    The song was released to Radio Disney as promo...
7600547      The song has teen pop and dance-pop influences.
7600548    In the United States, the song peaked at numbe...
7600549    Its appearance on the "Billboard" Hot 100 made...
7600550    A music video for "Pumpin' Up the Party" was t...
7600551    Cyrus, dressed as Hannah Montana, performed th...
Name: sentences, dtype: object

In [4]:
titles.iloc[2902052]

'Miley (surname)'

In [3]:
g.V().like('emb', [getem("Pumpin ' Up The Party")], 2).toArray()

array([1954484, 4998348])

In [6]:
v = g.V().like('emb', [getem("Pumpin ' Up The Party")], 2).inE().order().toArray().get()
v
#titles.iloc[v]

array([  946208,  3130776,  5072346,  6709252,  8821095, 14736735,
       16161698, 18491732, 29522486, 29522487, 29522488, 29522489,
       29522490, 29522491, 29522492, 29522493, 41512249, 41512250])

In [None]:
g.V([5013434, 374345]).similarity('emb', [getem('Move (1970 film)')]).toArray()

In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

def extract(entsList):
    """Collect the ``'word'`` field of every entity, keeping one list per input row.

    ``entsList`` is a sequence of per-text entity lists; each entity is a
    mapping with a ``'word'`` key (as the ``ner`` pipeline results are used
    elsewhere in this notebook).
    """
    return [[entity['word'] for entity in row] for row in entsList]

tokenizer = AutoTokenizer.from_pretrained("dslim/bert-large-NER")
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-large-NER")

ner = pipeline("ner", model=model, tokenizer=tokenizer, device=0, aggregation_strategy="max")

In [None]:
import numpy as np
vids = np.concatenate([
    g.V().like('emb', [getem(ent['word'])], 4).toArray()
    for ent in ner(truth_df.question.iloc[167453])
])

print(vids)

f = (vids < len(titles))
print('articles:', titles.iloc[vids[f].get()])
print('sentences:', sentences.iloc[vids[~f].get() - len(titles)])

In [9]:
from pygremlinxx import GraphTraversal


def __():
    """Return a fresh anonymous traversal for use inside union/select steps."""
    return GraphTraversal()


# The Gremlin subgraph step does not work due to nanobind limitations, so the
# edge ids are gathered with a traversal and passed to subgraph_coo instead:
#   'h0' -- edges incident to `vids`;
#   plus every edge incident to the in-vertices of those h0 edges.
out = graph.subgraph_coo(
    g.V(vids)
    .bothE()
    .dedup()
    ._as('h0')
    .inV()
    .bothE()
    .dedup()
    ._union([__().select('h0'), __().identity()])
    .dedup()
    .toArray()
)

In [10]:
from torch_geometric.data import Data

def coo_to_data(coo, emb_dim=300, emb_name='emb'):
    """Convert a ``graph.subgraph_coo`` result into a single-graph PyG ``Data``.

    Parameters
    ----------
    coo : mapping
        Must provide ``'src'``, ``'dst'`` and ``'vid'``. ``'src'``/``'dst'``
        must support ``.astype('int64')`` (presumably NumPy/CuPy arrays --
        confirm against pybitgraph).
    emb_dim : int, default 300
        Width of the stored vertex embeddings. The width used to be hard-coded
        to 300, which silently mis-shapes ``x`` for other encoders (e.g. 1024
        for roberta).
    emb_name : str, default 'emb'
        Vertex property holding the embeddings.

    Returns
    -------
    torch_geometric.data.Data
        ``edge_index`` of shape ``(2, E)``, ``x`` of shape ``(N, emb_dim)``,
        and ``batch`` of zeros, all on CUDA.
    """
    data = Data()
    # NOTE(review): row 0 is built from 'dst' and row 1 from 'src'. PyG treats
    # row 0 as the message source, so edges run opposite to the COO direction.
    # Confirm this reversal is intended.
    data.edge_index = torch.stack([
        torch.as_tensor(coo['dst'].astype('int64'), device='cuda'),
        torch.as_tensor(coo['src'].astype('int64'), device='cuda'),
    ])
    data.x = torch.as_tensor(
        g.V(coo['vid']).encode(emb_name).toArray(),
        device='cuda'
    ).reshape((-1, emb_dim))
    # All nodes belong to one graph, so every batch index is 0.
    data.batch = torch.zeros((data.x.shape[0],), dtype=torch.int64, device='cuda')

    return data

In [None]:
from torch_geometric.nn import GRetriever, GAT
from torch_geometric.nn.nlp import LLM

llm = LLM(
    model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
    num_params=1,
)

gnn = GAT(
    in_channels=300,
    hidden_channels=256,
    out_channels=300,
    num_layers=4,
    heads=4,
)

model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048)

In [None]:
# Retrieval knobs: how many vertices to seed per NER entity / for the whole
# question, and how many edges to keep per direction at each hop.
ent_match_limit = 4
que_match_limit = 4

out_limit_h0 = 4
out_limit_h1 = 4
in_limit_h0 = 4
in_limit_h1 = 4

from time import perf_counter

for i in range(3):
    question = truth_df.question.iloc[i]
    answer = truth_df.answer.iloc[i]
    emb_q = getem(question)

    # Seed vertices: nearest matches for each named entity, plus for the
    # question embedding itself (so the list is never empty).
    vids_q = np.concatenate(
        [
            g.V().like('emb', [getem(ent['word'])], ent_match_limit).toArray()
            for ent in ner(question)
        ] + [
            g.V().like('emb', [emb_q], que_match_limit).toArray()
        ]
    )

    # Two-hop expansion. At each hop, per direction, edges are ordered by the
    # similarity of their far endpoint to the question and truncated to the
    # per-hop limits above. Those limits used to be ignored in favour of a
    # literal 4.
    # NOTE(review): Gremlin's order() is ascending by default; if similarity
    # is higher-is-closer this keeps the *least* similar edges -- confirm the
    # gremlin++ semantics.
    # TODO control hops
    start_time = perf_counter()
    eids = g.V(vids_q)._union([
        __().outE().order().by(__().inV().similarity('emb', [emb_q])).limit(out_limit_h0)._as('h0').inV(),
        __().inE().order().by(__().outV().similarity('emb', [emb_q])).limit(in_limit_h0)._as('h0').outV(),
    ])._union([
        __().outE().order().by(__().inV().similarity('emb', [emb_q])).limit(out_limit_h1)._as('h1').inV(),
        __().inE().order().by(__().outV().similarity('emb', [emb_q])).limit(in_limit_h1)._as('h1').outV(),
    ])._union([__().select('h0'), __().select('h1')]).dedup().toArray()
    end_time = perf_counter()

    print('query time:', end_time - start_time)

    out = graph.subgraph_coo(
        eids
    )

    data = coo_to_data(out)
    print(data)

    # Forward pass only: the loss is printed, not back-propagated.
    loss = model(
        question=[f'question: {question}\nanswer:'],
        x=data.x,
        edge_index=data.edge_index,
        batch=data.batch,
        label=[answer],
        edge_attr=None, # edge features
        additional_text_context=None # additional context
    )
    print(loss)

In [None]:
data

In [None]:
g.V(vids).out().toArray()

In [1]:
import torch

import sys
sys.path.append('/mnt/bitgraph')
sys.path.append('/mnt/gremlin++')
from pybitgraph import BitGraph


graph = BitGraph(
    'uint64',
    'uint64',
    'DEVICE',
    'MANAGED',
    'DEVICE',
)

src = torch.tensor([5, 4, 1, 0, 2, 3, 5, 1, 2, 0], dtype=torch.uint64)
dst = torch.tensor([1, 3, 2, 5, 1, 5, 4, 4, 4, 1], dtype=torch.uint64)

graph.add_vertices(6)
graph.add_edges(src, dst, 'e')

g = graph.traversal()


In [None]:
g.E().toArray()

In [None]:
graph.subgraph_coo(torch.tensor([0, 2, 4], dtype=torch.uint64))

In [None]:
g.V(2).bothE().toArray()

In [None]:
from pygremlinxx import GraphTraversal
__ = lambda : GraphTraversal()

g.V([0, ]).bothE().dedup()._as('h0').inV().bothE().dedup()._union([__().select('h0'), __().identity()]).dedup().toArray()

In [1]:
import cudf
df = cudf.read_json('/mnt/para_with_hyperlink.jsonl', lines=True)

In [2]:
import torch
mentions = df.mentions.explode()
mentions = mentions[~mentions.struct.field('sent_idx').isna()]
mentions = mentions[~mentions.struct.field('ref_ids').isna()]

mix = torch.as_tensor(
    mentions.struct.field('ref_ids').list.get(0).astype('int64').values,
    device='cuda'
)
ids = torch.as_tensor(df.id.astype('int64').values, device='cuda')
vals, inds = torch.sort(ids)


destinations_m = inds[torch.searchsorted(vals, mix)]
destinations_m

tensor([ 181012, 5324480, 3286068,  ..., 2423755, 5409000, 2196530],
       device='cuda:0')

In [42]:
import cupy

slens = df.sentences.list.len().astype('int64')
slens[(slens==0)] = 1

df['sentence_offsets'] = cupy.concatenate([
    cupy.array([0]),
    slens.cumsum().values[:-1]
])
df

Unnamed: 0,id,title,sentences,mentions,sentence_offsets
0,17888798,The Circle (Wipers album),[The Circle is the sixth studio album by punk ...,"[{'id': 0, 'start': 40, 'end': 49, 'ref_url': ...",0
1,17888807,Urgand,[Urgand is a village in Badakhshan Province in...,"[{'id': 0, 'start': 12, 'end': 19, 'ref_url': ...",3
2,17888822,"Urup, Afghanistan",[Urup is a village in Badakhshan Province in n...,"[{'id': 0, 'start': 10, 'end': 17, 'ref_url': ...",4
3,17888850,WMIA (AM),"[""For the Miami, Florida radio station, see WM...","[{'id': 0, 'start': 9, 'end': 23, 'ref_url': '...",5
4,17888858,Guido of Acqui,[Saint Guido of Acqui( also Wido)( c. 1004 – 1...,"[{'id': 0, 'start': 62, 'end': 77, 'ref_url': ...",10
...,...,...,...,...,...
5989842,12347579,Hebeclinium,[Hebeclinium is a genus of flowering plant in ...,"[{'id': 0, 'start': 26, 'end': 41, 'ref_url': ...",23333440
5989843,12347585,Hebeclinium recreense,[Hebeclinium recreense is a species of floweri...,"[{'id': 0, 'start': 38, 'end': 53, 'ref_url': ...",23333441
5989844,12347593,Helichrysum aciculare,[Helichrysum aciculare is a species of floweri...,"[{'id': 0, 'start': 38, 'end': 53, 'ref_url': ...",23333445
5989845,12347598,Helichrysum arachnoides,[Helichrysum arachnoides is a species of flowe...,"[{'id': 0, 'start': 40, 'end': 55, 'ref_url': ...",23333448


In [43]:
src = torch.as_tensor(
    mentions.struct.field('sent_idx').values + df.sentence_offsets[mentions.index].values,
    device='cuda'
) + len(df)
src

tensor([ 5989847,  5989847,  5989847,  ..., 29323299, 29323300, 29323300],
       device='cuda:0')

In [44]:
(src[(destinations_m == 4111782)] == 13590393).sum()

tensor(0, device='cuda:0')

In [45]:
sentences = df.sentences.explode().reset_index().rename({"index": 'article'},axis=1)
sentences

Unnamed: 0,article,sentences
0,0,The Circle is the sixth studio album by punk r...
1,0,The album received positive reviews.
2,0,"""The Rough Guide to Rock"" wrote that ""jazzy di..."
3,1,Urgand is a village in Badakhshan Province in ...
4,2,Urup is a village in Badakhshan Province in no...
...,...,...
23333449,5989845,It is found only in Yemen.
23333450,5989845,Its natural habitat is subtropical or tropical...
23333451,5989846,Helichrysum balfourii is a species of flowerin...
23333452,5989846,It is found only in Yemen.


In [46]:
# Sentence table without the null rows that `explode` emits for articles that
# have no sentences, renumbered 0..n-1.
# NOTE(review): dropping rows and resetting the index shifts positions relative
# to `sentence_offsets`, which reserves a slot for every empty article. Ids
# derived from this frame may therefore not line up with the `src` vertex ids
# built earlier -- confirm this is intended.
sentences = (
    df.sentences.explode()
    .reset_index()
    .rename({"index": 'article'}, axis=1)
    .dropna()
    .reset_index(drop=True)
)

In [47]:
destinations_s = sentences.index.values + len(df)
sources_s = sentences.article.values

In [None]:
destinations_s[sources_s==1954484]

In [None]:
sentences.iloc[13560798-len(df)]

In [None]:
sentences[sentences.article==1954484]

In [None]:
7600544+len(df)

In [None]:
mentions[-5:].struct.field('ref_url')

In [None]:
mentions[-5:].struct.field('ref_ids')

In [None]:
src[-5:]

In [None]:
# `src` holds sentence *vertex* ids (offset by len(df) when it was built), so
# subtract len(df) to get row positions in `sentences`.
# NOTE(review): if `sentences` has had null rows dropped, positions may still
# be shifted relative to these ids.
sentences.iloc[src[-5:] - len(df)].sentences.values_host.tolist()

In [8]:
qp = {"question_vertex_match_limit": 1, "hop_1_outgoing_limit": 8, "hop_1_incoming_limit": 8, "hop_0_outgoing_limit": 2, "hop_0_incoming_limit": 2, "entity_vertex_match_limit": 2}

In [None]:
#question = "What is the date of birth of the director of film Rathimanmadhan?"
#question = "What is the place of birth of the director of film Discord (Film)?"
#question = "Did the movies Torkaman (Film) and Shameless (2008 Film), originate from the same country?"
question = "Do both directors of films The Big Bang (1989 Film) and Tender Fictions share the same nationality?"

ents = ner(question)
emb_q = getem(question)
ents

In [41]:
def decode(vids):
    """Print the article titles and sentence texts for a set of vertex ids.

    Ids below ``len(titles)`` are articles. Larger ids are sentences, offset by
    ``len(titles)``. ``vids`` must support ``.get()`` (e.g. a CuPy array, as
    produced by the cells that call this).
    """
    n_articles = len(titles)
    is_article = vids < n_articles

    print('articles:', titles.iloc[vids[is_article].get()])

    sentence_rows = vids[~is_article].get() - n_articles
    print('sentences:', sentences.iloc[sentence_rows])

In [None]:
g.V().like('emb', [emb_q], qp['question_vertex_match_limit']).toArray()

In [None]:
ents

In [None]:
getem('Shameless')

In [None]:


vids_q = cupy.concatenate(
    [
        g.V().like('emb', [getem(ent['word'])], qp['entity_vertex_match_limit']).toArray()
        for ent in ents
    ] + [
        g.V().like('emb', [emb_q], qp['question_vertex_match_limit']).toArray()
    ]
)
decode(vids_q)

In [None]:
qp['question_vertex_match_limit']

In [None]:
vids = g.V().like('emb', [getem('Shameless')], 4).toArray()
decode(vids)

In [None]:
titles.iloc[5956065]

In [None]:
# Only sentence vertices (ids >= len(titles)) map to rows of `sentences`.
# Article ids would become negative offsets, which iloc silently wraps around
# to rows at the end of the frame.
sentences.iloc[vids_q[vids_q >= len(titles)].get() - len(titles)]

In [None]:
from pygremlinxx import GraphTraversal
__ = lambda : GraphTraversal()

vids = g.V(vids_q)._union([
    __().out().order().by(__().similarity('emb', [emb_q])).limit(qp['hop_0_outgoing_limit'])._as('h0'),
    __()._in().order().by(__().similarity('emb', [emb_q])).limit(qp['hop_0_incoming_limit'])._as('h0'),
])._union([
    __().out().order().by(__().similarity('emb', [emb_q])).limit(qp['hop_1_outgoing_limit'])._as('h1'),
    __()._in().order().by(__().similarity('emb', [emb_q])).limit(qp['hop_1_incoming_limit'])._as('h1'),
])._union([__().select('h0'), __().select('h1')]).dedup().toArray()

decode(vids)

In [1]:
import pandas
df = pandas.read_json('/mnt/data/train.json')

In [2]:
df

Unnamed: 0,_id,type,question,context,supporting_facts,evidences,answer
0,13f5ad2c088c11ebbd6fac1f6bf848b6,bridge_comparison,Are director of film Move (1970 Film) and dire...,"[[Stuart Rosenberg, [Stuart Rosenberg (August ...","[[Move (1970 film), 0], [Méditerranée (1963 fi...","[[Move (1970 film), director, Stuart Rosenberg...",no
1,3057c6c4086111ebbd5dac1f6bf848b6,bridge_comparison,Do both films The Falcon (Film) and Valentin T...,"[[The Falcon Takes Over, [The Falcon Takes Ove...","[[The Falcon (film), 0], [Valentin the Good, 0...","[[The Falcon (film), director, Vatroslav Mimic...",no
2,89bc944808a111ebbd79ac1f6bf848b6,bridge_comparison,"Which film whose director is younger, Charge I...","[[Danger: Diabolik, [Danger:, Diabolik is a 1...","[[Charge It to Me, 1], [Danger: Diabolik, 1], ...","[[Charge It to Me, director, Roy William Neill...",Danger: Diabolik
3,633f80660bdd11eba7f7acde48001122,compositional,What is the date of birth of Mina Gerhardsen's...,"[[Pamela Jain, [Pamela Jain is an Indian playb...","[[Mina Gerhardsen, 1], [Rune Gerhardsen, 0]]","[[Mina Gerhardsen, father, Rune Gerhardsen], [...",13 June 1946
4,2dc3f9740bda11eba7f7acde48001122,compositional,What nationality is the director of film Weddi...,"[[Weekend in Paradise (1931 film), [Weekend in...","[[Wedding Night in Paradise (1950 film), 0], [...","[[Wedding Night in Paradise, director, Géza vo...",Hungarian
...,...,...,...,...,...,...,...
167449,56100d300bdc11eba7f7acde48001122,compositional,What is the place of birth of the director of ...,"[[S. N. Mathur, [S.N. Mathur was the Director ...","[[Rolling in Money, 0], [Albert Parker (direct...","[[Rolling in Money, director, Albert Parker], ...",New York
167450,3df1a97108ad11ebbd83ac1f6bf848b6,comparison,"Who was born first, Dušan Ninić or Eszter Balint?","[[Tom Dickinson, [Thomas Eastwood Dickinson( 1...","[[Dušan Ninić, 0], [Eszter Balint, 0]]","[[Dušan Ninić, date of birth, September 6, 195...",Dušan Ninić
167451,8be4ef3e0bdc11eba7f7acde48001122,compositional,When did the director of film Morchha die?,"[[Thomas Scott (diver), [Thomas Scott( 1907- d...","[[Morchha, 0], [Ravikant Nagaich, 0]]","[[Morchha, director, Ravikant Nagaich], [Ravik...",6 January 1991
167452,12357df20bdc11eba7f7acde48001122,compositional,What is the date of birth of the director of f...,"[[Peter Levin, [Peter Levin is an American dir...","[[Double Cross (1951 film), 0], [Riccardo Fred...","[[Double Cross, director, Riccardo Freda], [Ri...",24 February 1909


In [3]:
df.question[30]

'Which film has the director who died later, Aaranya Kandam or One Hundred Nails?'