In [17]:
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import DataLoader

import numpy as np
import pandas as pd

import spacy

from sklearn.preprocessing import LabelEncoder

import ast

In [5]:
df = pd.read_csv("../../local_datasets/bsard_extra/bsard_articles_preprocessed.csv")
df.head()

Unnamed: 0,id,article,pos,dep,heads
0,1,Le présent Code règle une matière visée à l'ar...,"[['DET', 'ADJ', 'NOUN', 'VERB', 'DET', 'NOUN',...","[['det', 'amod', 'nsubj', 'ROOT', 'det', 'obj'...","[[2, 2, 3, 3, 5, 3, 5, 9, 9, 6, 9, 13, 13, 9, 3]]"
1,2,Le présent Code transpose en Région de Bruxell...,"[['DET', 'ADJ', 'NOUN', 'VERB', 'ADP', 'NOUN',...","[['det', 'amod', 'nsubj', 'ROOT', 'case', 'obl...","[[2, 2, 3, 3, 5, 3, 7, 5, 5, 5, 11, 3, 11, 3],..."
2,3,Le présent Code poursuit les objectifs suivant...,"[['DET', 'ADJ', 'NOUN', 'VERB', 'DET', 'NOUN',...","[['det', 'amod', 'nsubj', 'ROOT', 'det', 'obj'...","[[2, 2, 3, 3, 5, 3, 5, 3], [2, 2, 2, 4, 2, 4, ..."
3,4,"Au sens du présent Code, il faut entendre par ...","[['ADP', 'NOUN', 'ADP', 'ADJ', 'NOUN', 'PUNCT'...","[['case', 'obl:mod', 'case', 'amod', 'nmod', '...","[[1, 7, 4, 4, 1, 7, 7, 7, 7, 10, 8, 7, 13, 11,..."
4,5,"Le plan régional Air-Climat-énergie, ci-après ...","[['DET', 'NOUN', 'ADJ', 'PROPN', 'PROPN', 'PRO...","[['det', 'nsubj', 'amod', 'nmod', 'nmod', 'nmo...","[[1, 18, 1, 1, 1, 1, 1, 1, 1, 12, 12, 12, 1, 1..."


In [15]:
def string_to_list_of_lists(df, col_name):
    df[col_name] = df[col_name].apply(ast.literal_eval)
    return df

def integer_encode_list(series):
    label_encoder = LabelEncoder()
    
    # Concatenate all the lists in the series to fit the encoder
    concatenated = [item for sublist in series for item in sublist]
    label_encoder.fit(concatenated)
    
    # Transform each list separately and store in a new series
    encoded_series = series.apply(lambda x: label_encoder.transform(x))
    
    return encoded_series

def create_graph_instance(tokens, pos_encoded, heads, article_id):
    sentence_graphs = []
    
    for sentence_tokens, sentence_pos, sentence_heads in zip(tokens, pos_encoded, heads):
        num_nodes = len(sentence_tokens)
        
        # Create nodes and assign features
        x = torch.tensor(sentence_pos, dtype=torch.long).view(-1, 1)
        
        # Convert heads from string to int and create edges using head indices
        sentence_heads = list(map(int, sentence_heads))
        edge_index = [[head_idx, idx] for idx, head_idx in enumerate(sentence_heads)]
        
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        
        # Assign article_id as ground truth
        y = torch.tensor(article_id, dtype=torch.long)
        
        # Create Data instance for PyTorch Geometric
        graph = Data(x=x, edge_index=edge_index, y=y)
        sentence_graphs.append(graph)
    
    return sentence_graphs

In [18]:
df = string_to_list_of_lists(df, 'pos')
df = string_to_list_of_lists(df, 'dep')
df = string_to_list_of_lists(df, 'heads')
df = string_to_list_of_lists(df, 'tokens')

In [None]:
pos_encoded = df['pos'].apply(integer_encode_list)

In [None]:
# Create graphs
graphs = []
for _, row in df.iterrows():
    tokens = row['tokens']
    pos = pos_encoded.loc[_]
    heads = row['heads']
    article_id = row['id']
    graphs.extend(create_graph_instance(tokens, pos, heads, article_id))

# Saving the list of graphs
torch.save(graphs, 'graphs.pt')

# Loading the list of graphs
graphs = torch.load('graphs.pt')