# Dialogical Argumentation Mining using GNNs

The goal of this notebook is to build a dialogical argumentation extraction model based on Inference Anchoring Theory.

The data is given in the Argument Interchange Format (2024). 

Basic model:
We are given 3 node types, where two are entities and one is a relation. Entities given are locutionary (L-nodes) or propositional (I-nodes). Relations given are TA-nodes (transition nodes) and indicate the dialogic chain (sequence of uttered statements) between L-nodes.

Instead of framing the task as "relation classification", we frame it as a "node-prediction" task on a graph G. The graph was generated through a "nodeset normalisation" procedure as described in Binder (2024).
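
To make the node-prediction framing concrete, here is a minimal sketch (not the pipeline used later in this notebook) of how a typed relation between two entities can be turned into an extra node with two edges. Predicting the relation type then amounts to predicting the label of that node. The helper name `relations_to_graph`, the `rel{k}` ID scheme and the toy IDs are hypothetical.

```python
def relations_to_graph(entities, relations):
    """Turn (source, label, target) relations into relation nodes.

    entities:  list of entity IDs (e.g. L- or I-nodes)
    relations: list of (source_id, label, target_id) triples
    Returns (nodes, edges, labels): every relation becomes its own node,
    connected as source -> relation node -> target.
    """
    nodes = list(entities)
    edges = []
    labels = {}
    for k, (src, label, dst) in enumerate(relations):
        rel_id = f"rel{k}"      # hypothetical ID scheme for relation nodes
        nodes.append(rel_id)
        edges.append((src, rel_id))
        edges.append((rel_id, dst))
        labels[rel_id] = label  # the label is what the model has to predict
    return nodes, edges, labels


nodes, edges, labels = relations_to_graph(
    entities=["I1", "I2"],
    relations=[("I1", "Default Inference", "I2")],
)
print(nodes)   # ['I1', 'I2', 'rel0']
print(edges)   # [('I1', 'rel0'), ('rel0', 'I2')]
print(labels)  # {'rel0': 'Default Inference'}
```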



## Main Code

### Preliminaries

In [1]:
# Any installs
! pip install cowsay 



In [None]:
%pip install --upgrade pip
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
%pip install torch_geometric transformers protobuf tiktoken sentencepiece tokenizers torch-sparse torch-scatter tabulate pandas scikit-learn

Note: you may need to restart the kernel to use updated packages.


In [1]:
# Declare Imports
import os, sys, json
import tabulate
import pandas as pd
pd.set_option('display.max_columns', None)

ModuleNotFoundError: No module named 'tabulate'

In [2]:
from transformers import (
    AutoTokenizer,
    DebertaV2TokenizerFast,
    DataCollatorWithPadding,
    get_scheduler
)



In [38]:
# Create some relevant folders for data persistence
os.makedirs("./data/normalised", exist_ok=True)
os.makedirs("./data/predictions/unnormalised", exist_ok=True)
os.makedirs("./data/predictions", exist_ok=True)

In [39]:
# Define some paths (e.g. to load, save data)
NODESET_PATH = "./data/normalised"
MODEL_NAME = "microsoft/deberta-v3-small"
# MODEL_NAME = "distilbert/distilbert-base-uncased"
DEBUG = 20  # Only the first DEBUG nodesets are converted to graphs and embedded

import torch
def detect_platform(cuda_num):
    if torch.cuda.is_available():
        print("cuda is available")
        return f'cuda:{cuda_num}'
    elif torch.backends.mps.is_available():
        print("mps is available")
        return 'mps'
    else:
        print("cpu is available")
        return 'cpu'

DEVICE = detect_platform(0)

mps is available


### Data overview & Preprocessing

#### Data description & First look

Given:
+ `L`-nodes, and `I`-nodes are in the **nodes** dict.
+ `TA`-nodes are also in the **nodes** dict.

To be predicted:
+ `YA`-nodes are in the **nodes** dict.
+ `S`-nodes are not directly given, but through:
  + `CA`: Conflicting argument
  + `RA`: Default Inference
  + `MA`: Default Rephrase

Our goal is to assign one of 15 classes to each of these nodes. 4 classes (including NONE) are possible for each S-relation and 12 classes (including NONE) are possible for each YA-node. Since NONE is shared by both groups, this leaves a total of **15 classes** (4 + 12 - 1) to be predicted.

The associated timestamp information is unreliable.

In [40]:
import pprint
import json

# Load json file and pprint
def load_json_as_dict(filename):
    with open(filename, 'r') as f:
        return json.load(f)
    
def pprint_dict(dict_data):
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(dict_data)

load_json_as_dict("./data/normalised/nodeset17945.json")
# load_json_as_dict("./data/normalised/nodeset17945.json").keys()
# load_json_as_dict("./data/normalised/nodeset17945.json")["edges"]

{'nodes': [{'nodeID': '513315',
   'text': "Helle Thorning-Schmidt : That's an important point to think about.",
   'type': 'L',
   'timestamp': '2020-05-28 20:37:25'},
  {'nodeID': '513317',
   'text': 'it is an important point to think about the risk for vulnerable children for remaining in homes where they may be facing abuse or other problems which greatly affect their wellbeing',
   'type': 'I',
   'timestamp': '2020-05-28 20:37:25'},
  {'nodeID': '513320',
   'text': "Helle Thorning-Schmidt : What is the consequence for children who can't go to school.",
   'type': 'L',
   'timestamp': '2020-05-28 20:37:26'},
  {'nodeID': '513322',
   'text': "xxx is the consequence for children who can't go to school",
   'type': 'I',
   'timestamp': '2020-05-28 20:37:26'},
  {'nodeID': '513324',
   'text': 'Default Transition',
   'type': 'TA',
   'timestamp': '2020-05-28 20:37:26',
   'scheme': 'Default Transition',
   'schemeID': '82'},
  {'nodeID': '513325',
   'text': "Helle Thorning-Schmid
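
For a quick overview of a nodeset, the node types can be counted. The following is a minimal sketch that assumes only the `nodes` list and the `type` key visible in the output above. The helper name `count_node_types` is hypothetical and the toy nodeset shortens the texts.

```python
from collections import Counter


def count_node_types(nodeset):
    """Count how often each AIF node type occurs in one nodeset dict."""
    return Counter(node["type"] for node in nodeset["nodes"])


# Toy nodeset with the same keys as the printed output (texts shortened).
example_nodeset = {
    "nodes": [
        {"nodeID": "513315", "text": "...", "type": "L"},
        {"nodeID": "513317", "text": "...", "type": "I"},
        {"nodeID": "513320", "text": "...", "type": "L"},
        {"nodeID": "513322", "text": "...", "type": "I"},
        {"nodeID": "513324", "text": "Default Transition", "type": "TA"},
    ]
}
print(count_node_types(example_nodeset))
```

On a real file, the same call would be `count_node_types(load_json_as_dict(path))`.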

#### Normalise the Nodesets (external, Binder et al. (2024))


Change into the baseline directory: <br>

`cd ./baseline/dialam-2024-shared-task-dfki`

Activate conda env: <br>
`conda activate dialam-2024-shared-task`

Ensure you put the relevant nodesets into `./data/noddies`


Normalise the nodesets and find them in `./data/noddies_processed`:

```

python src/utils/prepare_data.py --input_dir="./data/noddies" --output_dir="./data/noddies_processed" # for the evaluation data, or

python src/utils/prepare_data.py --input_dir="./data/noddies" --output_dir="./data/noddies_processed" --integrate_gold_data # for the training data

```

`cd -`

Move the processed nodesets into the designated `./data/normalised` folder (we only work with the `train` data supplied by the qt-30 corpus), for processing by this notebook.
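
The move can also be scripted. The sketch below copies every `*.json` file from the processed folder into `./data/normalised`. The source path is an assumption (the baseline repository cloned under `./baseline`), the helper `copy_nodesets` is hypothetical, and the sketch assumes the processed nodesets lie flat in one directory.

```python
import shutil
from pathlib import Path


def copy_nodesets(src_dir, dst_dir):
    """Copy every *.json nodeset from src_dir into dst_dir; return the count."""
    src, dst = Path(src_dir), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    copied = 0
    for path in sorted(src.glob("*.json")):
        shutil.copy2(path, dst / path.name)
        copied += 1
    return copied


# Assumed locations; adjust to wherever the baseline repository was cloned.
SRC = "./baseline/dialam-2024-shared-task-dfki/data/noddies_processed"
DST = "./data/normalised"
if Path(SRC).is_dir():
    print("copied", copy_nodesets(SRC, DST), "nodesets")
```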




In [41]:
# Load all the nodesets

def load_all_nodesets():
    nodesets = []
    for filename in os.listdir(NODESET_PATH):
        if filename.endswith(".json"):
            nodeset = load_json_as_dict(f"{NODESET_PATH}/{filename}")
            nodeset["filename"] = filename
            nodesets.append(nodeset)
    return nodesets

nodesets = load_all_nodesets()

In [42]:
nodesets[0].keys()
nodesets[0]["nodes"]
# nodesets[0]["edges"]
# nodesets[0]["nodes"]
# nodesets[0]["locutions"]
# nodesets[0]["filename"]

[{'nodeID': '712202',
  'text': 'Nick Thomas-Symonds : Yes, the government made a profound error by not adding India to the red list on 9 April with Pakistan and Bangladesh',
  'type': 'L',
  'timestamp': '2021-05-27 19:27:11'},
 {'nodeID': '712204',
  'text': 'the government made a profound error by not adding India to the red list on 9 April with Pakistan and Bangladesh',
  'type': 'I',
  'timestamp': '2021-05-27 19:27:11'},
 {'nodeID': '712206',
  'text': 'Nick Thomas-Symonds : Nadhim gave an explanation of a variant',
  'type': 'L',
  'timestamp': '2021-05-27 19:27:12'},
 {'nodeID': '712208',
  'text': 'Nadhim gave an explanation of a variant',
  'type': 'I',
  'timestamp': '2021-05-27 19:27:13'},
 {'nodeID': '712210',
  'text': "Nick Thomas-Symonds : That's the second different explanation I've heard from government ministers this week",
  'type': 'L',
  'timestamp': '2021-05-27 19:27:14'},
 {'nodeID': '712212',
  'text': "Nadhim's explanation of a variant is the second different 

#### Tokenise and label nodes

In [44]:
from transformers import AutoModel

# Run the tokeniser
tokenizer = DebertaV2TokenizerFast.from_pretrained(MODEL_NAME)
vocab_size_old = len(tokenizer.vocab)

deberta_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)

MODEL_MAX_LENGTH = 384 # Set this manually, since it appears to be "infinite" in theory for the DeBERTa model

In [45]:
# add custom non-text tokens
NT_TOKENS = ["[S]", "[YA]", "[TA]"]# S:= arginf., YA := illocutionary, TA := transitional between locutions

# Use more precise token framing
# NT_TOKENS = [""]

tokenizer.add_tokens(NT_TOKENS)
vocab_size_new = len(tokenizer.vocab)
print("original vocab size:", vocab_size_old)
print("new vocab size:", vocab_size_new)
print("Special marker tokens", NT_TOKENS)

original vocab size: 128001
new vocab size: 128004
Special marker tokens ['[S]', '[YA]', '[TA]']


In [46]:
# tokenise input and label output
from functools import partial

partial_tok = partial(tokenizer, 
            is_split_into_words=False, 
            truncation=True, 
            padding="max_length", 
            max_length=MODEL_MAX_LENGTH, 
            return_tensors="pt"
        )

txt = partial_tok(text="[S], This is fun [YA]") # we are only interested in the input_ids and possibly the attention mask?
tokenizer.decode(txt["input_ids"][0], skip_special_tokens=True)

'[S] , This is fun [YA]'

In [47]:
import copy

def encode_node_text(node):
    """Encodes a string of text into a tokenised form

    Args:
        node (Dict): A dict of type node in AIF (Argument Interchange Format)

    Raises:
        ValueError: Raised if the node type attribute is not recognised

    Returns:
        node (Dict): A node enriched with tokeniser-encoded "tokens" field (PyTorch tensor)
    """
    n = copy.deepcopy(node)
    if n["type"] in ["I", "L"]:
        n["tokens"] = partial_tok(text=n["text"])
    else:
        if n["type"] in ["CA", "RA", "MA"]: # An S-node
            n["tokens"] = partial_tok(text="[S]")
            n["label"] = n["text"]
        elif n["type"] in ["YA"]: # An Illocutionary-node
            n["tokens"] = partial_tok(text="[YA]") 
            n["label"] = n["text"]
        elif n["type"] in ["TA"]: # A transitional node
            n["tokens"] = partial_tok(text="[TA]")
            # n["label"] = n["text"] # No label needed for transitional nodes
        else:
            raise ValueError("Unknown node type")
    return n
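
The branching in `encode_node_text` can also be read as a lookup table from node type to marker token. The following sketch mirrors that logic without calling the tokeniser, which makes the mapping easy to check in isolation. The names `MARKER_BY_TYPE`, `LABELLED_TYPES` and `text_and_label` are hypothetical and not used elsewhere in this notebook.

```python
# Marker token per node type; mirrors the branches of encode_node_text.
MARKER_BY_TYPE = {"CA": "[S]", "RA": "[S]", "MA": "[S]", "YA": "[YA]", "TA": "[TA]"}
LABELLED_TYPES = {"CA", "RA", "MA", "YA"}  # TA-nodes carry no label


def text_and_label(node):
    """Return (text_to_tokenise, label_or_None) for one AIF node dict."""
    node_type = node["type"]
    if node_type in ("I", "L"):
        return node["text"], None
    if node_type not in MARKER_BY_TYPE:
        raise ValueError(f"Unknown node type: {node_type!r}")
    label = node["text"] if node_type in LABELLED_TYPES else None
    return MARKER_BY_TYPE[node_type], label


print(text_and_label({"type": "TA", "text": "Default Transition"}))  # ('[TA]', None)
print(text_and_label({"type": "YA", "text": "Asserting"}))           # ('[YA]', 'Asserting')
```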

In [48]:
# sanity check
assert tokenizer.decode(encode_node_text(nodesets[0]["nodes"][30])["tokens"]["input_ids"][0], skip_special_tokens=True) == "[TA]", "Problem encoding node text to tokens"
# tokenizer.decode(encode_node_text(nodesets[0]["nodes"][30])["tokens"]["input_ids"][0], skip_special_tokens=True)

In [49]:
encoded_nodesets = list(map(lambda nodeset: {
    "nodes": list(map(encode_node_text, nodeset["nodes"])),
    "edges": nodeset["edges"],
    "locutions": nodeset["locutions"]
}, nodesets))

In [50]:
encoded_nodesets[10]["nodes"][40]

{'nodeID': '854539',
 'text': 'Default Transition',
 'type': 'TA',
 'timestamp': '2022-01-13 18:23:34',
 'scheme': 'Default Transition',
 'schemeID': '82',
 'tokens': {'input_ids': tensor([[     1, 128003,      2,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,
               0,      0,      0,      0,      0,      0,      0,      0,      0,


##### Convert to a PyTorch Geometric Dataset using nodes in string format, no-directionality, only locutions (not propositions)
This is an approximation of Binder (2024)'s attempt.

In [51]:
import torch
from torch_geometric.data import Data

# Build a feature matrix in the form [num_nodes, num_node_features (tokeniser encoded strings?)]
def nodeset_to_pyg_data(nodeset):
    nodes = nodeset["nodes"]
    # edges = nodeset["edges"]
    # locutions = nodeset["locutions"]
    nodeid_to_index = {n["nodeID"]: i for i, n in enumerate(nodes)}
    nodeid_to_labels = {n["nodeID"]: n["label"] for n in nodes if "label" in n}
    node_features = torch.concat([torch.vstack(
        [n["tokens"][k] for k in ["input_ids", "attention_mask"]]
        ).unsqueeze(dim=0) for n in nodes])
    return node_features, nodeid_to_index, nodeid_to_labels

print("Number of Nodes", len(encoded_nodesets[0]["nodes"]))
x, nodeid_to_index, _ = nodeset_to_pyg_data(encoded_nodesets[0])
print("Tensor Graph Shape", x.shape)


Number of Nodes 78
Tensor Graph Shape torch.Size([78, 2, 384])


In [52]:
print(x[0][1])

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [53]:
# Encode the (directed) edges as index pairs; transposing them yields the COO format expected by PyG
def construct_edge_index(nodeset, nodeid_to_index, symmetrical=False):
    """Construct an edge index from a nodeset
        Note: We build the edges as (source, target) pairs. The caller transposes the result to obtain the COO format detailed here:
        https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html#data-handling-of-graphs

    Args:
        nodeset (Dict): A nodeset
        nodeid_to_index (Dict): Maps each nodeID to its row index in the feature matrix
        symmetrical (bool): If True, also add the reverse of every edge
    
    Returns:
        edge_index (Tensor): An Nx2 tensor of edge indices; call .t().contiguous() to get the 2xN COO layout
    """
    edges = nodeset["edges"]
    def generate_edges_in_tuple_format(edges, symmetrical=False):
        for e in edges:
            yield ([nodeid_to_index[e["fromID"]], nodeid_to_index[e["toID"]]], [nodeid_to_index[e["toID"]], nodeid_to_index[e["fromID"]]]) \
                if symmetrical else ([nodeid_to_index[e["fromID"]], nodeid_to_index[e["toID"]]], )
    return torch.tensor([y for x in generate_edges_in_tuple_format(edges, symmetrical) for y in x], dtype=torch.long)
edge_index = construct_edge_index(encoded_nodesets[0], nodeid_to_index, symmetrical=False)
print("Constructed edges", edge_index)
print("Number of edges", edge_index.shape)

Constructed edges tensor([[ 2, 26],
        [26,  4],
        [ 0, 27],
        [27,  2],
        [ 4, 28],
        [28,  6],
        [ 6, 29],
        [29,  8],
        [ 6, 30],
        [30, 10],
        [10, 31],
        [31, 12],
        [12, 32],
        [32, 14],
        [14, 33],
        [33, 16],
        [16, 34],
        [34, 18],
        [18, 35],
        [35, 20],
        [20, 36],
        [36, 22],
        [ 0, 37],
        [37, 22],
        [22, 38],
        [38, 24],
        [ 5, 39],
        [39,  3],
        [ 3, 40],
        [40,  1],
        [ 7, 41],
        [41,  5],
        [ 9, 42],
        [42,  7],
        [11, 43],
        [43,  7],
        [13, 44],
        [44, 11],
        [15, 45],
        [45, 13],
        [17, 46],
        [46, 15],
        [19, 47],
        [47, 17],
        [21, 48],
        [48, 19],
        [23, 49],
        [49, 21],
        [23, 50],
        [50,  1],
        [25, 51],
        [51, 23],
        [ 0, 52],
        [52,  1],
        [ 

In [54]:
# Build the PyG Data object
data = Data(x=x, edge_index=edge_index.t().contiguous())
# Implement train-level targets?


In [55]:
# Just to check whether we use symmetrical or asymmetrical constraints
data.is_directed()

True
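
The effect of the `symmetrical` flag can be illustrated without PyTorch. The sketch below converts (source, target) pairs into the two parallel index lists of the COO layout and checks directedness by looking for missing reverse edges. The helper names `to_coo` and `is_directed` are hypothetical and only loosely mirror what `construct_edge_index` and `Data.is_directed()` do.

```python
def to_coo(pairs, symmetrical=False):
    """Convert (source, target) pairs to a [2, N] COO layout (two parallel lists)."""
    if symmetrical:
        # Add the reverse of every edge so that messages can flow both ways.
        pairs = [p for (s, t) in pairs for p in ((s, t), (t, s))]
    sources = [s for s, _ in pairs]
    targets = [t for _, t in pairs]
    return [sources, targets]


def is_directed(coo):
    """A graph is directed if some edge lacks its reverse edge."""
    edge_set = set(zip(coo[0], coo[1]))
    return any((t, s) not in edge_set for (s, t) in edge_set)


pairs = [(2, 26), (26, 4), (0, 27)]  # the first three edges printed above
print(to_coo(pairs))                                 # [[2, 26, 0], [26, 4, 27]]
print(is_directed(to_coo(pairs)))                    # True
print(is_directed(to_coo(pairs, symmetrical=True)))  # False
```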

##### Create a Dataset from the individual graphs (PyG data-objects)

In [56]:
from torch_geometric.loader import DataLoader
def generate_pyg_dataset(encoded_nodesets):
    """Generates a PyG dataset from a list of encoded nodesets

    Args:
        encoded_nodesets (List[Dict]): A list of encoded nodesets

    Returns:
        dataset (List[Data]): A list of PyG Data objects
    """
    dataset = []
    for nodeset in encoded_nodesets[:DEBUG]:
        x, nodeid_to_index, nodeid_to_labels = nodeset_to_pyg_data(nodeset)
        edge_index = construct_edge_index(nodeset, nodeid_to_index, symmetrical=False)
        data = Data(x=x, edge_index=edge_index.t().contiguous())
        data.nodeid_to_index = nodeid_to_index
        data.index_to_nodeid = {v: k for k, v in nodeid_to_index.items()}
        data.nodeid_to_labels = nodeid_to_labels
        dataset.append(data)
    return dataset

nodeset_graphs = generate_pyg_dataset(encoded_nodesets)

In [57]:
g = nodeset_graphs[0]
g.x[0][0] # gives us the tokens
# g.x[0][1] # gives us the attention mask

g2 = nodeset_graphs[2]
# g2.x[0][0] # gives us the tokens
g2.x[0][1] # gives us the attention mask

tensor([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [58]:
# inline context mgr
deberta_model.eval()

def get_node_embedding(node):
    with torch.no_grad():
        # import pdb; pdb.set_trace()
        return deberta_model(node[0].unsqueeze(dim=0).to(DEVICE), 
                             attention_mask=node[1].unsqueeze(dim=0).to(DEVICE), 
                             output_hidden_states=True
                        ).last_hidden_state[:,0][0].clone().detach()
        # .hidden_states[-1][0,0,:] # get the CLS token
# print(g.x[0][0] == g2.x[0][0])
assert torch.any((get_node_embedding(g.x[0]) == get_node_embedding(g2.x[0]))) == False, "There's a dimension issue here" # test get node embedding
# print("CLS embedding of first node of G1:", get_node_embedding(g.x[0]))
# print("CLS embedding of first node of G2:", get_node_embedding(g2.x[10]))

In [59]:
get_node_embedding(g2.x[0])

tensor([-2.1272e-02, -3.0149e-02,  1.6295e-03, -6.3993e-03,  8.0136e-03,
        -8.7204e-02, -2.5756e-02, -5.6568e-02, -1.3848e-02, -2.0704e-02,
        -5.1214e-02, -4.2049e-02,  4.4788e-02,  3.7260e-02, -3.2503e-02,
        -6.1188e-03, -4.7663e-02, -7.7772e-03, -7.1066e-03, -4.1993e-02,
         1.2513e-02,  3.3260e-02, -3.4277e-02, -1.3659e+00,  8.7290e-03,
        -5.2791e-02, -5.4420e-02, -1.1150e-02,  3.4624e-02, -6.1675e-02,
         1.5734e-02, -6.3536e-02, -6.7162e-02,  1.7815e-02,  2.9668e-02,
         3.0734e-02, -1.6835e-02, -6.3208e+00, -1.4178e-01, -5.8546e+00,
        -1.4041e-02, -1.8020e-02,  6.0095e-03,  6.8651e+00, -4.7702e-02,
        -7.2039e-02,  3.0207e-02, -1.4077e-01, -4.7215e-02, -5.6883e-03,
        -7.0918e-02, -5.2122e-02, -7.3409e-02, -1.7403e-02, -4.3405e-02,
        -9.4914e-02,  5.1728e-02, -8.9908e-02, -1.9623e-02, -5.0707e-02,
        -1.2162e-02,  6.5398e-03,  5.5986e-02,  5.5687e-03,  3.7580e-02,
        -3.1253e-02,  3.7941e-02, -2.4739e-02, -9.5

In [60]:
# embs = torch.vstack([get_node_embedding(node) for node in g.x])
# embs.shape

In [61]:
nodeset_graphs[3].x[0][0].get_device()

-1

In [62]:
# Prefill each graph with the frozen DeBERTa [CLS] embedding of every node
def prefill_node_embeddings(nodeset_graphs):
    for data in nodeset_graphs[:DEBUG]:
        data.node_embeddings = torch.vstack([get_node_embedding(node) for node in data.x])
    return nodeset_graphs

In [63]:
prefill_node_embeddings(nodeset_graphs)

[Data(
   x=[78, 2, 384],
   edge_index=[2, 104],
   nodeid_to_index={
     712202=0,
     712204=1,
     712206=2,
     712208=3,
     712210=4,
     712212=5,
     712214=6,
     712216=7,
     712218=8,
     712220=9,
     712222=10,
     712224=11,
     712226=12,
     712228=13,
     712230=14,
     712232=15,
     712234=16,
     712236=17,
     712238=18,
     712240=19,
     712245=20,
     712248=21,
     712253=22,
     712256=23,
     712261=24,
     712264=25,
     712268=26,
     712269=27,
     712271=28,
     712272=29,
     712275=30,
     712278=31,
     712279=32,
     712282=33,
     712283=34,
     712284=35,
     712289=36,
     712293=37,
     712295=38,
     712296=39,
     712297=40,
     712298=41,
     712299=42,
     712300=43,
     712301=44,
     712302=45,
     712303=46,
     712304=47,
     712305=48,
     712306=49,
     712307=50,
     712308=51,
     712309=52,
     712310=53,
     712311=54,
     712312=55,
     712313=56,
     712314=57,
     712315

In [64]:
nodeset_graphs[1].node_embeddings[10]

tensor([-1.0401e-02, -1.1985e-02,  7.8933e-03, -1.1570e-03, -5.9743e-03,
        -5.6065e-02, -3.4481e-02, -5.1867e-02, -2.2736e-02, -2.9083e-02,
        -4.8538e-02, -7.9283e-02,  8.4873e-03,  2.0089e-02, -5.6976e-02,
        -9.0659e-03, -4.1334e-02,  3.2241e-03, -2.9463e-03, -5.5289e-02,
         2.0411e-02,  3.4901e-02, -1.8888e-02, -1.1499e+00, -1.2357e-02,
        -5.5806e-02, -5.8549e-02, -2.0221e-02, -9.6383e-05, -6.5141e-02,
         3.6296e-03, -5.9688e-02, -6.3217e-02, -2.7842e-02,  9.1217e-03,
         2.7643e-02, -1.7854e-02, -6.3306e+00, -1.1796e-01, -5.8481e+00,
         1.1686e-02, -8.6214e-03,  1.6796e-02,  6.8782e+00, -6.0086e-02,
        -5.2035e-02,  5.0651e-02, -1.2388e-01, -6.8328e-02, -2.8418e-02,
        -5.1385e-02, -2.6733e-02, -5.7782e-02, -2.7517e-02, -6.3961e-02,
        -7.5921e-02,  1.0782e-02, -6.6497e-02, -3.2713e-02, -4.1577e-02,
        -8.8939e-03, -1.1360e-02,  9.2364e-02,  4.9610e-02,  2.3555e-02,
        -4.0781e-02,  3.0912e-02, -1.6614e-02, -7.6

##### Create dataset for node prediction

In [73]:
# Implement test-train-split
    # We work on the node-level here, we'll also develop a dataloader for the nodes.
    # The graphs are secondary.
    # We must properly encode the labels, i.e. one-hot.

def generate_individual_nodes_dataset(nodeset_graphs):
    """
        Generates a dataset of individual nodes, i denotes the graph number
        and j denotes the node number in the graph.
        The last element in the tuple is the label of the node.
    """
    ds = []
    for i, graph in enumerate(nodeset_graphs):
        for j, unlabelled_node in enumerate(graph.nodeid_to_labels.items()):
            # Perform classification only on three nodes
            # import pdb; pdb.set_trace()
            # if unlabelled_node[1] in ["Default Inference", "NONE"]: # reduce to two labels
            dp = (i, graph.nodeid_to_index[unlabelled_node[0]],unlabelled_node[1])
            ds.append(dp)
    return ds

X = generate_individual_nodes_dataset(nodeset_graphs)
# dataset_loader = DataLoader(nodeset_graphs, batch_size=8, shuffle=True)



In [74]:
X # dataset of individual nodes

[(0, 39, 'Default Rephrase'),
 (0, 40, 'NONE'),
 (0, 41, 'NONE'),
 (0, 42, 'Default Conflict'),
 (0, 43, 'Default Rephrase'),
 (0, 44, 'NONE'),
 (0, 45, 'Default Rephrase'),
 (0, 46, 'NONE'),
 (0, 47, 'NONE'),
 (0, 48, 'Default Rephrase'),
 (0, 49, 'Default Inference-rev'),
 (0, 50, 'Default Rephrase'),
 (0, 51, 'Default Rephrase'),
 (0, 52, 'Asserting'),
 (0, 53, 'Asserting'),
 (0, 54, 'Asserting'),
 (0, 55, 'Asserting'),
 (0, 56, 'Asserting'),
 (0, 57, 'Asserting'),
 (0, 58, 'Asserting'),
 (0, 59, 'Asserting'),
 (0, 60, 'Asserting'),
 (0, 61, 'Asserting'),
 (0, 62, 'Asserting'),
 (0, 63, 'Asserting'),
 (0, 64, 'Asserting'),
 (0, 65, 'Restating'),
 (0, 66, 'NONE'),
 (0, 67, 'NONE'),
 (0, 68, 'Disagreeing'),
 (0, 69, 'Restating'),
 (0, 70, 'NONE'),
 (0, 71, 'Restating'),
 (0, 72, 'NONE'),
 (0, 73, 'NONE'),
 (0, 74, 'Restating'),
 (0, 75, 'Arguing'),
 (0, 76, 'Restating'),
 (0, 77, 'Restating'),
 (1, 48, 'Default Rephrase'),
 (1, 49, 'NONE'),
 (1, 50, 'NONE'),
 (1, 51, 'Default Rephrase

In [75]:
# Generate histogram to seek two well-discriminating classes
import numpy as np

def generate_histogram(X):
    labels = [x[2] for x in X]
    unique, counts = np.unique(labels, return_counts=True)
    return dict(zip(unique, counts))

generate_histogram(X)

{'Agreeing': 1,
 'Arguing': 62,
 'Asserting': 304,
 'Assertive Questioning': 1,
 'Default Conflict': 26,
 'Default Illocuting': 6,
 'Default Inference': 35,
 'Default Inference-rev': 29,
 'Default Rephrase': 71,
 'Disagreeing': 26,
 'NONE': 321,
 'Pure Questioning': 10,
 'Restating': 64,
 'Rhetorical Questioning': 2}
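
The same histogram can be built with `collections.Counter`, which also makes it easy to shortlist classes with enough examples. This is a minimal sketch: the helper names and the `min_count` threshold are hypothetical, and the samples are a handful of tuples copied from the dataset listing above.

```python
from collections import Counter


def label_histogram(samples):
    """Count labels in a list of (graph_idx, node_idx, label) tuples."""
    return Counter(label for _, _, label in samples)


def labels_with_min_count(hist, min_count):
    """Return the labels that occur at least min_count times, sorted by name."""
    return sorted(label for label, n in hist.items() if n >= min_count)


# A few samples copied from the dataset listing above.
samples = [
    (0, 39, "Default Rephrase"),
    (0, 40, "NONE"),
    (0, 41, "NONE"),
    (0, 42, "Default Conflict"),
    (0, 65, "Restating"),
]
hist = label_histogram(samples)
print(hist["NONE"])                    # 2
print(labels_with_min_count(hist, 2))  # ['NONE']
```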

In [76]:
# Filter the dataset down to the "Default Rephrase" and "Restating" classes
X_filtered = list(filter(lambda x: x[2] in ["Default Rephrase", "Restating"], X))
X = X_filtered

##### Encoding individual node labels

In [77]:
from sklearn.preprocessing import OneHotEncoder
possible_labels = [[l] for l in sorted(list(set([node[2] for node in X])))]
ohe = OneHotEncoder().fit(possible_labels)
LABEL_NUMBER = ohe.categories_[0].shape[0]
LABEL_NUMBER

2

In [78]:
ohe.categories_[0]

array(['Default Rephrase', 'Restating'], dtype=object)
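
To see what the encoder produces for these two categories, here is a plain-Python sketch of one-hot encoding against a sorted category list. It only illustrates the idea; the helper `one_hot` is hypothetical and is not the scikit-learn implementation.

```python
def one_hot(labels, categories):
    """One-hot encode labels against an ordered list of categories."""
    index = {c: i for i, c in enumerate(categories)}
    return [
        [1.0 if index[label] == i else 0.0 for i in range(len(categories))]
        for label in labels
    ]


categories = sorted({"Restating", "Default Rephrase"})
print(categories)  # ['Default Rephrase', 'Restating']
print(one_hot(["Default Rephrase", "Restating"], categories))
# [[1.0, 0.0], [0.0, 1.0]]
```

The column order follows the sorted categories, so column 0 stands for `Default Rephrase` and column 1 for `Restating`.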

In [79]:
def one_hot_encode_labels(x: list):
    return torch.tensor(ohe.transform([[node[2]] for node in x]).toarray(), dtype=torch.float32)
y = one_hot_encode_labels(X)
y

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1

In [83]:
%pip install scikit-learn 

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting scikit-learn
  Downloading scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl.metadata (11 kB)
Downloading scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl (9.4 MB)
Installing collected packages: scikit-learn
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
shap 0.44.1 requires pandas, which is not installed.
Successfully installed scikit-learn-1.3.2
Note: you may need to restart the kernel to use updated packages.


In [85]:
from sklearn.model_selection import train_test_split

# Basic idea: Create the train-test-split by the indices of the nodesets
ns_idx = list(range(len(nodeset_graphs)))
X_train_idx, X_test_idx = train_test_split(ns_idx, test_size=0.2, random_state=1)
X_train_idx, X_val_idx = train_test_split(X_train_idx, test_size=0.25, random_state=1) # 0.25 x 0.8 = 0.2

ImportError: cannot import name '_check_response_method' from 'sklearn.utils.validation' (/opt/homebrew/Caskroom/miniconda/base/envs/dev/lib/python3.8/site-packages/sklearn/utils/validation.py)

In [86]:
# Allocate nodes given the graph indices:
# train:
X_train = [node for node in X if node[0] in X_train_idx]
y_train = torch.index_select(y, 0, torch.LongTensor([i for i, node in enumerate(X) if node[0] in X_train_idx]))
# test dataset
X_test = [node for node in X if node[0] in X_test_idx]
y_test = torch.index_select(y, 0, torch.LongTensor([i for i, node in enumerate(X) if node[0] in X_test_idx]))
# validation dataset
X_val = [node for node in X if node[0] in X_val_idx]
y_val = torch.index_select(y, 0, torch.LongTensor([i for i, node in enumerate(X) if node[0] in X_val_idx]))


NameError: name 'X_train_idx' is not defined
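
Splitting by graph rather than by node keeps all nodes of one nodeset in the same partition. The sketch below shows that idea using only the standard library, so it runs even when the scikit-learn import fails. It does not reproduce the exact partition that `train_test_split` would yield, and the helper names `split_graph_indices` and `select_nodes` are hypothetical.

```python
import random


def split_graph_indices(n_graphs, test_frac=0.2, val_frac=0.2, seed=1):
    """Shuffle graph indices and cut them into train/val/test index sets."""
    idx = list(range(n_graphs))
    random.Random(seed).shuffle(idx)
    n_test = round(n_graphs * test_frac)
    n_val = round(n_graphs * val_frac)
    test = set(idx[:n_test])
    val = set(idx[n_test:n_test + n_val])
    train = set(idx[n_test + n_val:])
    return train, val, test


def select_nodes(samples, graph_ids):
    """Keep the (graph_idx, node_idx, label) samples whose graph is in graph_ids."""
    return [s for s in samples if s[0] in graph_ids]


train_ids, val_ids, test_ids = split_graph_indices(20)
print(len(train_ids), len(val_ids), len(test_ids))  # 12 4 4
```

Using sets for the graph indices keeps the membership test in `select_nodes` cheap when the number of graphs grows.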

In [40]:
y_train[0:10]

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])

In [41]:
X_train

[(1, 48, 'Default Rephrase'),
 (1, 49, 'NONE'),
 (1, 50, 'NONE'),
 (1, 51, 'Default Rephrase'),
 (1, 52, 'NONE'),
 (1, 53, 'NONE'),
 (1, 54, 'Default Rephrase'),
 (1, 55, 'NONE'),
 (1, 56, 'Default Rephrase'),
 (1, 57, 'Default Inference-rev'),
 (1, 58, 'Default Rephrase'),
 (1, 59, 'Default Inference-rev'),
 (1, 60, 'NONE'),
 (1, 61, 'Default Conflict'),
 (1, 62, 'NONE'),
 (1, 63, 'NONE'),
 (1, 64, 'Asserting'),
 (1, 65, 'Asserting'),
 (1, 66, 'Asserting'),
 (1, 67, 'Asserting'),
 (1, 68, 'Asserting'),
 (1, 69, 'Asserting'),
 (1, 70, 'Asserting'),
 (1, 71, 'Asserting'),
 (1, 72, 'Asserting'),
 (1, 73, 'NONE'),
 (1, 74, 'Asserting'),
 (1, 75, 'Asserting'),
 (1, 76, 'Asserting'),
 (1, 77, 'Asserting'),
 (1, 78, 'Asserting'),
 (1, 79, 'Asserting'),
 (1, 80, 'Restating'),
 (1, 81, 'NONE'),
 (1, 82, 'NONE'),
 (1, 83, 'Restating'),
 (1, 84, 'NONE'),
 (1, 85, 'NONE'),
 (1, 86, 'Restating'),
 (1, 87, 'NONE'),
 (1, 88, 'Restating'),
 (1, 89, 'Arguing'),
 (1, 90, 'Restating'),
 (1, 91, 'Arguing

In [42]:
# Single batch case
train_loader = list(zip(X_train, y_train))
val_loader = list(zip(X_val, y_val))
test_loader = list(zip(X_test, y_test))

### Model Declaration

In [43]:
# from transformers import AutoModel
# deberta_model = AutoModel.from_pretrained(MODEL_NAME)

In [44]:
tokenizer.decode(g.x[0][0].type(torch.int64)) # we want the CLS token.

'[CLS] Nick Thomas-Symonds : Yes, the government made a profound error by not adding India to the red list on 9 April with Pakistan and Bangladesh[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD

In [45]:
list(deberta_model.modules())[-2].embedding_dim

768

In [46]:
g.edge_index.shape
len(nodesets[170]["edges"])
g.edge_index

tensor([[ 2, 26,  0, 27,  4, 28,  6, 29,  6, 30, 10, 31, 12, 32, 14, 33, 16, 34,
         18, 35, 20, 36,  0, 37, 22, 38,  5, 39,  3, 40,  7, 41,  9, 42, 11, 43,
         13, 44, 15, 45, 17, 46, 19, 47, 21, 48, 23, 49, 23, 50, 25, 51,  0, 52,
          2, 53,  4, 54,  6, 55,  8, 56, 10, 57, 12, 58, 14, 59, 16, 60, 18, 61,
         20, 62, 22, 63, 24, 64, 26, 65, 27, 66, 28, 67, 29, 68, 30, 69, 31, 70,
         32, 71, 33, 72, 34, 73, 35, 74, 36, 75, 37, 76, 38, 77],
        [26,  4, 27,  2, 28,  6, 29,  8, 30, 10, 31, 12, 32, 14, 33, 16, 34, 18,
         35, 20, 36, 22, 37, 22, 38, 24, 39,  3, 40,  1, 41,  5, 42,  7, 43,  7,
         44, 11, 45, 13, 46, 15, 47, 17, 48, 19, 49, 21, 50,  1, 51, 23, 52,  1,
         53,  3, 54,  5, 55,  7, 56,  9, 57, 11, 58, 13, 59, 15, 60, 17, 61, 19,
         62, 21, 63, 23, 64, 25, 65, 39, 66, 40, 67, 41, 68, 42, 69, 43, 70, 44,
         71, 45, 72, 46, 73, 47, 74, 48, 75, 49, 76, 50, 77, 51]])
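
The `edge_index` printed above is in COO format: row 0 holds the source nodes and row 1 the target nodes, with one column per edge. A minimal plain-Python sketch of how such a pair of rows maps to a dense adjacency matrix follows. The toy edges and the helper name `coo_to_dense` are assumptions for illustration.

```python
def coo_to_dense(sources, targets, n_nodes=None):
    """Build a dense 0/1 adjacency matrix from COO-style source/target lists."""
    if len(sources) != len(targets):
        raise ValueError("sources and targets must have the same length")
    if n_nodes is None:
        n_nodes = max(sources + targets) + 1
    adj = [[0] * n_nodes for _ in range(n_nodes)]
    for src, dst in zip(sources, targets):
        adj[src][dst] = 1
    return adj

# Toy graph with edges 0 -> 2, 2 -> 1, 1 -> 0
toy_sources = [0, 2, 1]
toy_targets = [2, 1, 0]
toy_adj = coo_to_dense(toy_sources, toy_targets)
for row in toy_adj:
    print(row)
```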

In [47]:
torch.Tensor([[[1,2,3],[4,5,6]], [[7,8,9], [10,11,12]]])[:, 0, :]

tensor([[1., 2., 3.],
        [7., 8., 9.]])

In [48]:
from torch_geometric.nn import GCNConv, SimpleConv
from torch_geometric.utils import to_scipy_sparse_matrix 
from torch_sparse import SparseTensor 
sc = SimpleConv(aggr='sum').to(DEVICE)
# sc = GCNConv(512, 168)
# adj = SparseTensor(row=g.edge_index[0], col=g.edge_index[1])

# Create a dense adjacency matrix from the COO edge_index (the helper is named COO_to_LIL, but the result is dense)
# def COO_to_LIL(edge_index):
#     n_nodes = edge_index.max().item() + 1
#     lil_matrix = torch.zeros(n_nodes, n_nodes).type(torch.int64)
#     for i in range(edge_index.shape[1]):
#         lil_matrix[edge_index[0][i], edge_index[1][i]] = 1
#     return lil_matrix
# g.edge_index
# n_nodes = g.x.shape[0]
# lil_matrix = torch.zeros(n_nodes, n_nodes).type(torch.int64)
# for i in range(g.edge_index.shape[1]):
#     lil_matrix[g.edge_index[0][i], g.edge_index[1][i]] = 1

sc(g.node_embeddings, g.edge_index.to(DEVICE))
# type(g.x[0][0])
# g.x[0][1].dtype
# sc(g.x, adj)
# dense_to_sparse(g.edge_index)
# print(torch.LongTensor(to_scipy_sparse_matrix(g.edge_index)))
# type(g.edge_index)
g.node_embeddings


tensor([[ 0.0018, -0.0248,  0.0303,  ..., -0.0249, -0.0306, -0.0638],
        [ 0.0036, -0.0163,  0.0336,  ..., -0.0375, -0.0101, -0.0463],
        [-0.0120, -0.0297,  0.0119,  ..., -0.0446, -0.0502, -0.0985],
        ...,
        [ 0.0212, -0.0591,  0.0218,  ..., -0.0884, -0.0547, -0.0143],
        [ 0.0212, -0.0591,  0.0218,  ..., -0.0884, -0.0547, -0.0143],
        [ 0.0212, -0.0591,  0.0218,  ..., -0.0884, -0.0547, -0.0143]],
       device='mps:0')
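
As far as the PyG documentation describes it, `SimpleConv(aggr='sum')` has no trainable weights: each node's output is the sum of the features of the nodes that point to it. A plain-Python sketch of that aggregation, on made-up features and edges (the function name `sum_aggregate` is hypothetical):

```python
def sum_aggregate(features, sources, targets):
    """For every edge src -> dst, add the features of src to the output of dst."""
    dim = len(features[0])
    out = [[0.0] * dim for _ in features]
    for src, dst in zip(sources, targets):
        for k in range(dim):
            out[dst][k] += features[src][k]
    return out

toy_features = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
# Edges: 0 -> 2, 1 -> 2, 2 -> 0
toy_out = sum_aggregate(toy_features, sources=[0, 1, 2], targets=[2, 2, 0])
print(toy_out)
```

In this sketch a node without incoming edges ends up with an all-zero vector, since no self-loop is added.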

In [49]:
del deberta_model

In [50]:
# Declare the model, inspired by: https://github.com/ZeroRin/BertGCN/blob/main/model/models.py
from torch_geometric.nn import GCNConv, SimpleConv, GATConv
import torch.nn.functional as F
from transformers import AutoModel
from torch import nn

class GCN(torch.nn.Module):
    def __init__(self, in_feats=None, n_classes=15, n_hidden=200):
        super().__init__()
        torch.manual_seed(1234567)
        self.in_feats = in_feats
        self.conv1 = GCNConv(in_feats, 60)
        self.conv2 = GCNConv(60, n_classes)

        # Note that SimpleConvs are non-trainable!
        # They are simple but efficient; see here for literature and more reference implementations:
        # https://github.com/Tiiiger/SGC?tab=readme-ov-file
        # self.conv1 = SimpleConv(aggr='sum', 
        #                         # combine_root="self_loop" # we don't need self loop, since we include bert encoded embedding.
                            # ) # See: https://pytorch-geometric.readthedocs.io/en/2.5.3/generated/torch_geometric.nn.conv.SimpleConv.html#torch_geometric.nn.conv.SimpleConv
        # self.conv2 = SimpleConv(aggr='sum', 
        #                         # combine_root="self_loop"
                                # )
        self.linear = torch.nn.Linear(self.in_feats, n_classes)
    def forward(self, embeds, edge_index):
        # edge_index is a COO-format tensor of shape [2, num_edges], not a dense adjacency matrix
        # import pdb; pdb.set_trace()
        x = self.conv1(embeds, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        # x = self.conv2(x, edge_index)
        x = self.conv2(x, edge_index)
        # import pdb; pdb.set_trace()
        # x = self.linear(x)
        return x

class GAT(nn.Module):
    def __init__(self,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop=0,
                 attn_drop=0,
                 negative_slope=0.2,
                 residual=False
    ):
        super(GAT, self).__init__()
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        self.feat_drop = feat_drop
        # NOTE: `residual` is accepted for interface compatibility but is not used.
        # PyG's GATConv takes (in_channels, out_channels, heads, concat,
        # negative_slope, dropout, ...), so these are passed by keyword to avoid mix-ups.
        # input projection (no residual)
        self.gat_layers.append(GATConv(
            in_dim, num_hidden, heads=heads[0], concat=True,
            negative_slope=negative_slope, dropout=attn_drop))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(GATConv(
                num_hidden * heads[l-1], num_hidden, heads=heads[l], concat=True,
                negative_slope=negative_slope, dropout=attn_drop))
        # output projection: concat=False averages the heads, giving [n_nodes, num_classes]
        self.gat_layers.append(GATConv(
            num_hidden * heads[-2], num_classes, heads=heads[-1], concat=False,
            negative_slope=negative_slope, dropout=attn_drop))

    def forward(self, inputs, edge_index):
        h = inputs
        for l in range(self.num_layers):
            h = F.dropout(h, p=self.feat_drop, training=self.training)
            h = self.activation(self.gat_layers[l](h, edge_index))
        # output projection
        logits = self.gat_layers[-1](h, edge_index)
        return logits

class BertDialGCN(torch.nn.Module):
    def __init__(self, pretrained_model, no_classes=15, m=0.7, n_hidden=200, dropout=0.5):
        super(BertDialGCN, self).__init__()
        self.m = m
        self.no_classes = no_classes
        # self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        # self.deberta_model = pretrained_model
        self.deberta_model = AutoModel.from_pretrained(pretrained_model)
        self.feat_dim = self.deberta_model.config.hidden_size  # 768 for the model printed above
        self.classifier = torch.nn.Linear(self.feat_dim, no_classes)
        self.gcn = GCN(
            in_feats=self.feat_dim,
            n_classes=no_classes,
            n_hidden=n_hidden,
            # n_classes=no_classes,
            # n_layers=gcn_layers-1,
            # activation=F.elu,
            # dropout=dropout
        )
        # self.gcn = GAT(
        #     num_layers=1,
        #     in_dim=self.feat_dim,
        #     num_hidden=n_hidden,
        #     num_classes=no_classes,
        #     # heads=[8, 1],
        #     heads=[8] * (2-1) + [1],
        #     activation=F.elu
        #     # feat_drop=dropout,
        #     # attn_drop=dropout,
        #     # negative_slope=0.2,
        #     # residual=False
        # )

    def forward(self, g, idx):
        # g.x[idx] holds (input_ids, attention_mask) for node idx; add a batch dimension
        input_ids, attention_mask = g.x[idx][0].unsqueeze(dim=0), g.x[idx][1].unsqueeze(dim=0)
        # Encode the node text and take the [CLS] embedding (position 0) of the single sequence.
        # Training and eval share this path; the cached g.node_embeddings[idx] is not used here.
        feats = self.deberta_model(input_ids.to(DEVICE),
                                   attention_mask=attention_mask.to(DEVICE)
                                   ).last_hidden_state[:, 0][0]
        # g.node_embeddings[idx] = feats # store the updated embeddings for the GraphNN
        cls_logit = self.classifier(feats)
        # feats is 1-D, so the class axis is the last (and only) dimension
        cls_pred = torch.softmax(cls_logit, dim=-1)
        # Obtain logits from the GCN for all nodes, then select node idx
        gcn_logit = self.gcn(g.node_embeddings, g.edge_index.to(DEVICE))[idx]
        gcn_pred = torch.softmax(gcn_logit, dim=-1)
        # Interpolate both distributions; the epsilon guards against log(0)
        pred = (gcn_pred+1e-10) * self.m + cls_pred * (1 - self.m)
        pred = torch.log(pred)
        return pred
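
The forward pass mixes two probability distributions, one from the graph branch and one from the text classifier, with weight `m`, and returns the log of the mixture. A small worked example in plain Python, with made-up logits (all numbers and helper names here are illustrative assumptions), shows that the mixture is still a probability distribution:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def interpolate_log_probs(gcn_logits, cls_logits, m=0.7, eps=1e-10):
    """Per class: log(m * (softmax(gcn) + eps) + (1 - m) * softmax(cls))."""
    gcn_pred = softmax(gcn_logits)
    cls_pred = softmax(cls_logits)
    return [math.log((g + eps) * m + c * (1 - m)) for g, c in zip(gcn_pred, cls_pred)]

toy_gcn_logits = [2.0, 0.5, -1.0]   # made-up graph-branch logits
toy_cls_logits = [0.1, 1.5, 0.3]    # made-up classifier logits
log_pred = interpolate_log_probs(toy_gcn_logits, toy_cls_logits)
print(log_pred, sum(math.exp(v) for v in log_pred))
```

Because the mixed probabilities sum to almost exactly one, applying a log-softmax to these log-probabilities leaves them essentially unchanged, so passing them to a cross-entropy loss should behave like a negative log-likelihood on the mixture.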

### Training

In [51]:
# model = BertDialGCN(no_classes=LABEL_NUMBER, pretrained_model=deberta_model)
model = BertDialGCN(no_classes=LABEL_NUMBER, m=0.7, pretrained_model=MODEL_NAME).to(DEVICE)

# initial test pass through the model
# model.eval()
# out = model(g, 10)
# out

In [52]:
# reset node embeddings to all zeroes
# for g in nodeset_graphs:
#     g.node_embeddings = torch.zeros(g.x.shape[0], list(deberta_model.modules())[-2].embedding_dim)

In [53]:
print(model)

BertDialGCN(
  (deberta_model): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 768, padding_idx=0)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-5): 6 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): StableDropout()
              (dropout): StableDropout()
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
            

In [54]:
nodeset_graphs[0].node_embeddings.get_device() # -1 means the tensor is on the CPU; CUDA tensors return their device index
nodeset_graphs[0].edge_index.get_device()

-1

In [55]:
import torch.optim as optim
import torch.nn as nn

LEARNING_RATE = 0.0001
MOMENTUM = 0.9

# training loop
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
criterion = nn.CrossEntropyLoss()

def train(model, train_loader, optimizer, criterion, n_epoch=10, iter_print=100):
    for epoch in range(n_epoch):
        model.train()
        loss_r = 0.0  # running loss since the last progress print
        count_r = 0
        epoch_loss_sum = 0.0  # loss accumulated over the whole epoch
        validation_loss = 0.0
        for i, batch in enumerate(train_loader):
            # nodes = (nodeset index, node index, gold label); labels = one-hot target
            nodes, labels = batch
            optimizer.zero_grad()
            print("Training", nodes, labels)
            outputs = model(nodeset_graphs[nodes[0]], nodes[1])
            # outputs are log-probabilities; labels are one-hot class probabilities
            loss = criterion(outputs, labels.to(DEVICE))
            loss.backward()
            print(loss)
            optimizer.step()
            nodeset_graphs[nodes[0]].node_embeddings.detach_() # TODO: Is this purposeful?
            loss_r += loss.item()
            epoch_loss_sum += loss.item()
            count_r += 1
            if (i+1) % iter_print == 0:
                print(f"Epoch [{epoch+1}/{n_epoch}], Step [{i+1}/{len(train_loader)}], Average Loss: {loss_r/count_r:.4f}")
                loss_r = 0.0
                count_r = 0
        model.eval()
        with torch.no_grad():
            # NOTE: val_loader is read from the notebook's global scope
            for i, batch in enumerate(val_loader):
                nodes, labels = batch
                outputs = model(nodeset_graphs[nodes[0]], nodes[1])
                loss = criterion(outputs, labels.to(DEVICE))
                validation_loss += loss.item()
        # Print loss after each epoch (loss_r is reset during the epoch, so use the epoch total)
        epoch_loss = epoch_loss_sum / len(train_loader)
        print(f"\nEnd of Epoch {epoch+1}/{n_epoch}, Average Epoch Train Loss: {epoch_loss:.4f}")
        print(f"Average Validation Loss: {validation_loss / len(val_loader):.4f}")

In [57]:
# nodeset_graphs[3].detach_()
nodeset_graphs[3].node_embeddings[65]

tensor([ 2.3517e-02, -6.0290e-02,  1.9953e-02,  4.5002e-03,  2.5003e-02,
        -9.2331e-02,  5.8856e-03, -4.4842e-02,  1.1930e-02,  1.6216e-02,
         3.5501e-03, -1.2816e-04,  2.0527e-02,  1.0193e-01, -2.3250e-02,
         3.7577e-02, -7.0609e-02, -8.1153e-03,  2.3115e-02, -3.2160e-02,
        -2.8714e-02, -6.1022e-03, -1.2259e-02, -1.9077e+00,  8.0444e-02,
        -3.6280e-02, -3.5325e-02, -3.4149e-02,  1.2857e-01, -1.0700e-02,
         3.6063e-02, -9.9011e-02, -1.0003e-01,  4.3351e-02,  4.7411e-02,
         2.1735e-02, -4.6116e-02, -6.3033e+00, -1.3471e-01, -5.9250e+00,
        -2.6059e-02,  7.0537e-03,  7.1264e-02,  6.8249e+00, -6.6979e-02,
        -8.4105e-02,  4.5860e-02, -1.6653e-01, -8.2658e-02, -1.2472e-02,
        -7.0927e-02, -6.6312e-02, -4.2821e-02, -1.3778e-02, -1.7776e-02,
        -1.8262e-01,  1.0060e-01, -1.1793e-01, -7.7510e-02, -6.4461e-02,
        -3.4846e-02,  5.4432e-03,  4.5273e-02,  2.0520e-02,  8.3010e-02,
         2.6622e-03,  5.2072e-02,  1.0953e-02, -9.4

In [58]:
# nodeset_graphs[2].x[29].get_device()
nodeset_graphs[2].x[29]

tensor([[     1, 128001,      2,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,    

In [59]:
train(model, train_loader, optimizer, criterion, n_epoch=5, iter_print=100)

Training (1, 48, 'Default Rephrase') tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.])
tensor(2.9278, device='mps:0', grad_fn=<DivBackward1>)
Training (1, 49, 'NONE') tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])
tensor(2.7337, device='mps:0', grad_fn=<DivBackward1>)
Training (1, 50, 'NONE') tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])
tensor(2.5444, device='mps:0', grad_fn=<DivBackward1>)
Training (1, 51, 'Default Rephrase') tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.])
tensor(2.7463, device='mps:0', grad_fn=<DivBackward1>)
Training (1, 52, 'NONE') tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])
tensor(2.1304, device='mps:0', grad_fn=<DivBackward1>)
Training (1, 53, 'NONE') tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])
tensor(2.6944, device='mps:0', grad_fn=<DivBackward1>)
Training (1, 54, 'Default Rephrase') tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.])
tens
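
A quick sanity check on these first loss values: with a one-hot target, cross-entropy reduces to the negative log-probability of the gold class, so a model that spreads its probability evenly over the 14 classes seen in the targets above would score ln(14). The sketch below computes this in plain Python; the helper name `one_hot_cross_entropy` is hypothetical.

```python
import math

def one_hot_cross_entropy(log_probs, target):
    """Cross-entropy between log-probabilities and a one-hot target list."""
    return -sum(t * lp for t, lp in zip(target, log_probs))

n_classes = 14  # length of the one-hot targets printed above
uniform_log_probs = [math.log(1.0 / n_classes)] * n_classes
target = [0.0] * n_classes
target[8] = 1.0  # position of the 1 in the 'Default Rephrase' target above

baseline = one_hot_cross_entropy(uniform_log_probs, target)
print(round(baseline, 4))
```

Losses close to ln(14), roughly 2.64, therefore suggest predictions that are still near uniform.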

In [60]:
# Test Loop
unencoded_predictions = []
def test(model, test_loader, criterion):
    model.eval()
    loss_r = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            # print("iteration")
            nodes, labels = tuple(t for t in batch)
            print("nodes", nodes)
            print("labels", labels)
            outputs = model(nodeset_graphs[nodes[0]], nodes[1])
            loss = criterion(outputs, labels.to(DEVICE))
            loss_r += loss.item()
            # import pdb; pdb.set_trace()
            print("Outputs", outputs)
            _, predicted = torch.max(outputs.data.unsqueeze(dim=0), 1)
            # total += outputs.size(0)
            predicted = predicted.cpu()
            total += 1
            print("Size", outputs.size(0))
            _, labels_maxed = torch.max(labels.data.unsqueeze(dim=0), 1)
            print("Predicted: ", predicted, "Labels: ", labels_maxed)
            correct += (predicted == labels_maxed).sum().item()
            print(predicted)
            unencoded_predictions.append(ohe.inverse_transform(F.one_hot(torch.LongTensor([predicted]), LABEL_NUMBER))[0][0])

    avg_loss = loss_r / len(test_loader)
    accuracy = 100 * correct / total
    print(f'Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')
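
Accuracy alone can hide a model that always predicts the same class, which is a risk when some labels are far more frequent than others. As a complement to the `test` function, the following plain-Python sketch counts the predicted labels and computes a macro-averaged F1 score. The toy label lists and the helper `macro_f1` are assumptions, and the official evaluation script remains the reference for reported scores.

```python
from collections import Counter

def macro_f1(gold, predicted):
    """Unweighted mean of per-class F1 over all classes present in `gold`."""
    scores = []
    for cls in sorted(set(gold)):
        tp = sum(1 for g, p in zip(gold, predicted) if g == cls and p == cls)
        fp = sum(1 for g, p in zip(gold, predicted) if g != cls and p == cls)
        fn = sum(1 for g, p in zip(gold, predicted) if g == cls and p != cls)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)

# Toy example: the model predicts the same label for every node.
toy_gold = ["NONE", "NONE", "Default Inference", "Asserting"]
toy_pred = ["NONE", "NONE", "NONE", "NONE"]

toy_accuracy = sum(g == p for g, p in zip(toy_gold, toy_pred)) / len(toy_gold)
print(Counter(toy_pred))
print(toy_accuracy, macro_f1(toy_gold, toy_pred))
```

In the toy case accuracy looks moderate, while macro F1 is much lower because two of the three classes are never predicted.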

In [61]:
X_test, y_test

([(3, 59, 'NONE'),
  (3, 60, 'NONE'),
  (3, 61, 'Default Inference'),
  (3, 62, 'NONE'),
  (3, 63, 'NONE'),
  (3, 64, 'NONE'),
  (3, 65, 'NONE'),
  (3, 66, 'Default Rephrase'),
  (3, 67, 'NONE'),
  (3, 68, 'NONE'),
  (3, 69, 'NONE'),
  (3, 70, 'NONE'),
  (3, 71, 'Default Rephrase'),
  (3, 72, 'NONE'),
  (3, 73, 'NONE'),
  (3, 74, 'Default Rephrase'),
  (3, 75, 'Default Rephrase'),
  (3, 76, 'NONE'),
  (3, 77, 'Default Inference-rev'),
  (3, 78, 'Asserting'),
  (3, 79, 'Asserting'),
  (3, 80, 'Asserting'),
  (3, 81, 'Asserting'),
  (3, 82, 'Asserting'),
  (3, 83, 'Asserting'),
  (3, 84, 'NONE'),
  (3, 85, 'Asserting'),
  (3, 86, 'Asserting'),
  (3, 87, 'Asserting'),
  (3, 88, 'Asserting'),
  (3, 89, 'Asserting'),
  (3, 90, 'Asserting'),
  (3, 91, 'Asserting'),
  (3, 92, 'Asserting'),
  (3, 93, 'Asserting'),
  (3, 94, 'Asserting'),
  (3, 95, 'Asserting'),
  (3, 96, 'Asserting'),
  (3, 97, 'Asserting'),
  (3, 98, 'NONE'),
  (3, 99, 'NONE'),
  (3, 100, 'Arguing'),
  (3, 101, 'NONE'),
  (3,

In [62]:
test(model, test_loader, criterion)

nodes (3, 59, 'NONE')
labels tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])
Outputs tensor([-4.2240, -2.8252, -1.1074, -4.7973, -3.3126, -3.7875, -3.1221, -3.0614,
        -2.4242, -3.2018, -1.6405, -3.3894, -2.6041, -5.0188], device='mps:0')
Size 14
Predicted:  tensor([2]) Labels:  tensor([10])
tensor([2])
nodes (3, 60, 'NONE')
labels tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.])
Outputs tensor([-4.2253, -2.8245, -1.1073, -4.7992, -3.3124, -3.7880, -3.1225, -3.0616,
        -2.4240, -3.2034, -1.6398, -3.3916, -2.6042, -5.0213], device='mps:0')
Size 14
Predicted:  tensor([2]) Labels:  tensor([10])
tensor([2])
nodes (3, 61, 'Default Inference')
labels tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])
Outputs tensor([-4.2241, -2.8250, -1.1073, -4.7973, -3.3125, -3.7871, -3.1217, -3.0616,
        -2.4238, -3.2016, -1.6410, -3.3898, -2.6041, -5.0186], device='mps:0')
Size 14
Predicted:  tensor([2]) Labels:  tensor([6])
tensor([2])
nodes (

In [63]:
nodesets[5]["nodes"][70]

{'nodeID': '554544', 'type': 'YA', 'text': 'Asserting'}

In [401]:
X_test

[(3, 59, 'NONE'),
 (3, 60, 'NONE'),
 (3, 61, 'Default Inference'),
 (3, 62, 'NONE'),
 (3, 63, 'NONE'),
 (3, 64, 'NONE'),
 (3, 65, 'NONE'),
 (3, 66, 'Default Rephrase'),
 (3, 67, 'NONE'),
 (3, 68, 'NONE'),
 (3, 69, 'NONE'),
 (3, 70, 'NONE'),
 (3, 71, 'Default Rephrase'),
 (3, 72, 'NONE'),
 (3, 73, 'NONE'),
 (3, 74, 'Default Rephrase'),
 (3, 75, 'Default Rephrase'),
 (3, 76, 'NONE'),
 (3, 77, 'Default Inference-rev'),
 (3, 78, 'Asserting'),
 (3, 79, 'Asserting'),
 (3, 80, 'Asserting'),
 (3, 81, 'Asserting'),
 (3, 82, 'Asserting'),
 (3, 83, 'Asserting'),
 (3, 84, 'NONE'),
 (3, 85, 'Asserting'),
 (3, 86, 'Asserting'),
 (3, 87, 'Asserting'),
 (3, 88, 'Asserting'),
 (3, 89, 'Asserting'),
 (3, 90, 'Asserting'),
 (3, 91, 'Asserting'),
 (3, 92, 'Asserting'),
 (3, 93, 'Asserting'),
 (3, 94, 'Asserting'),
 (3, 95, 'Asserting'),
 (3, 96, 'Asserting'),
 (3, 97, 'Asserting'),
 (3, 98, 'NONE'),
 (3, 99, 'NONE'),
 (3, 100, 'Arguing'),
 (3, 101, 'NONE'),
 (3, 102, 'NONE'),
 (3, 103, 'NONE'),
 (3, 104, 

In [64]:
nodesets[3]["nodes"][105]

{'nodeID': '861549', 'type': 'YA', 'text': 'Restating'}

In [403]:
# Map predicted labels back to graph
import copy

X_test_pred = list(zip([(x[0], x[1]) for x in X_test], unencoded_predictions))
# only include relevant nodesets, i.e. those with predicted labels
relevant_nodeset_ids = sorted({x[0][0] for x in X_test_pred})
nodesets_pred = copy.deepcopy(nodesets)

# replace the nodes with the predicted labels
def node_swap(ns_pred):
    """
    Replace the gold labels with the predicted labels in `nodesets_pred` (modified in place).
    Each item of `ns_pred` is ((nodeset_idx, node_idx), predicted_label).
    """
    for (nodeset_idx, node_idx), label in ns_pred:
        print("Replacing text of node", node_idx, "in nodeset", nodeset_idx)
        nodesets_pred[nodeset_idx]["nodes"][node_idx]["text"] = label

node_swap(X_test_pred)

# subset only relevant nodesets
nodesets_pred_rel = [nodesets_pred[ns_id] for ns_id in relevant_nodeset_ids]

Replacing text of node 59 in nodeset 3
Replacing text of node 60 in nodeset 3
Replacing text of node 61 in nodeset 3
Replacing text of node 62 in nodeset 3
Replacing text of node 63 in nodeset 3
Replacing text of node 64 in nodeset 3
Replacing text of node 65 in nodeset 3
Replacing text of node 66 in nodeset 3
Replacing text of node 67 in nodeset 3
Replacing text of node 68 in nodeset 3
Replacing text of node 69 in nodeset 3
Replacing text of node 70 in nodeset 3
Replacing text of node 71 in nodeset 3
Replacing text of node 72 in nodeset 3
Replacing text of node 73 in nodeset 3
Replacing text of node 74 in nodeset 3
Replacing text of node 75 in nodeset 3
Replacing text of node 76 in nodeset 3
Replacing text of node 77 in nodeset 3
Replacing text of node 78 in nodeset 3
Replacing text of node 79 in nodeset 3
Replacing text of node 80 in nodeset 3
Replacing text of node 81 in nodeset 3
Replacing text of node 82 in nodeset 3
Replacing text of node 83 in nodeset 3
Replacing text of node 84

In [404]:
nodesets_pred[3]["nodes"][105]

{'nodeID': '861549', 'type': 'YA', 'text': 'NONE'}

In [405]:
nodesets_pred_rel[0]["nodes"][105]

{'nodeID': '861549', 'type': 'YA', 'text': 'NONE'}
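
The label swap can be checked on a toy nodeset without touching the real data. The sketch below uses made-up node IDs and a hypothetical helper `swap_labels`; it copies the nodesets first so that the gold labels stay intact, mirroring the `copy.deepcopy` step above.

```python
import copy

def swap_labels(nodesets, predictions):
    """Return a deep copy of `nodesets` with predicted labels written into node texts.

    `predictions` is a list of ((nodeset_idx, node_idx), label) pairs.
    """
    swapped = copy.deepcopy(nodesets)
    for (nodeset_idx, node_idx), label in predictions:
        swapped[nodeset_idx]["nodes"][node_idx]["text"] = label
    return swapped

# Made-up nodeset with one relation node carrying a gold label.
toy_nodesets = [{"nodes": [
    {"nodeID": "n0", "type": "L", "text": "Speaker : some statement"},
    {"nodeID": "n1", "type": "YA", "text": "Restating"},
]}]
toy_predictions = [((0, 1), "NONE")]

toy_swapped = swap_labels(toy_nodesets, toy_predictions)
print(toy_nodesets[0]["nodes"][1]["text"], "->", toy_swapped[0]["nodes"][1]["text"])
```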

In [406]:
nodesets_pred[5]

{'nodes': [{'nodeID': '554359',
   'text': 'Robert Buckland : Clarity is something that people demand from government',
   'type': 'L',
   'timestamp': '2020-06-18 20:39:53'},
  {'nodeID': '554360',
   'text': 'clarity is something that people demand from government',
   'type': 'I',
   'timestamp': '2020-06-18 20:39:54'},
  {'nodeID': '554363',
   'text': 'Robert Buckland : whilst it’s right of us to consider the evidence and to constantly evaluate and question and make adjustments, the two metre rule certainly stays for the foreseeable',
   'type': 'L',
   'timestamp': '2020-06-18 20:39:54'},
  {'nodeID': '554364',
   'text': 'whilst it’s right of the UK government to consider the evidence and to constantly evaluate and question and make adjustments, the two metre rule certainly stays for the foreseeable',
   'type': 'I',
   'timestamp': '2020-06-18 20:39:54'},
  {'nodeID': '554367',
   'text': 'Default Transition',
   'type': 'TA',
   'timestamp': '2020-06-18 20:39:55',
   'scheme':

In [407]:
# Save the predicted nodesets to file:
for ns in nodesets_pred_rel:
    with open(f"./data/predictions/{ns['filename']}", 'w') as f:
        json.dump(ns, f)

In [65]:
# Next steps:

# Cleanup Phase

# Implement training loop (is it fast enough?) ✅
    # (If no) Find an efficient way to train the model for node prediction
    # (If yes) we are done with this

# Implement evaluation ✅

# Finish pipeline
    # Make splitting tree-specific! (for the evaluation script to work) ✅

# Implement GPU acceleration ✅

# Perform data augmentation, i.e. balance the label numbers
    # Does this solve the problem with the all "NONE" predictions?

# Implement Batched training?
# We will most likely need a NodeLoader:
#   https://pytorch-geometric.readthedocs.io/en/2.5.2/modules/loader.html#torch_geometric.loader.NodeLoader
# Accuracy problem, because we assume batches.

# Replicate the experiment with a common base text.

# Figure out how to export the model for evaluation in a form the evaluation script accepts
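
For the label-balancing step listed above, one common option (an assumption here, not something the notebook has settled on) is to weight each class inversely to its frequency. The sketch computes such weights in plain Python from a made-up label list; the resulting values could then be passed to a weighted loss.

```python
from collections import Counter

def inverse_frequency_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}

# Made-up, imbalanced label list
toy_labels = ["NONE"] * 6 + ["Asserting"] * 3 + ["Default Inference"] * 1
weights = inverse_frequency_weights(toy_labels)
for cls, weight in sorted(weights.items()):
    print(f"{cls}: {weight:.3f}")
```

With this scheme, rare classes receive larger weights, and the weights summed over all samples equal the number of samples.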


### Evaluation

#### De-normalise the nodeset

Move the predicted nodesets in `./data/predictions` into the designated `./baseline/dialam-2024-shared-task-dfki` folder.

Move data into:

`mv ./data/predictions/* ./baseline/dialam-2024-shared-task-dfki/data/noddies_predicted/`

Cd into:

`cd ./baseline/dialam-2024-shared-task-dfki`

We want to remove the "-rev" labelled classes, as well as invert the arrows again, so they work well with the evaluation script.

Use the option: `--re_revert_ra_relations`

We might also want to remove the **NONE** relations.

Use the option `--re_remove_none_relations`

Both options only take effect when combined with the option `--integrate_gold_data`.

```
python src/utils/prepare_data.py --input_dir="./data/noddies_predicted" --output_dir="./data/noddies_predicted_unnormalised" --integrate_gold_data --re_revert_ra_relations --re_remove_none_relations # for the training data
```

**Then run the evaluation script on:** <br> <br>
Arguments:
```
python src/evaluation/eval_official.py --predictions_dir="./data/noddies_predicted_unnormalised" --gold_dir="./data/noddies" --mode="arguments"
```

Illocutions:
```
python src/evaluation/eval_official.py --predictions_dir="./data/noddies_predicted_unnormalised" --gold_dir="./data/noddies" --mode="illocutions"
```
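
Since the two evaluation runs differ only in `--mode`, they can be assembled from one template. The sketch below only builds the command lines from the flags shown above and prints them. The helper name `eval_command` is hypothetical, and actually running the commands (for example with `subprocess.run`) is left out on purpose.

```python
import shlex

PREDICTIONS_DIR = "./data/noddies_predicted_unnormalised"
GOLD_DIR = "./data/noddies"

def eval_command(mode):
    """Build the argument list for one run of the official evaluation script."""
    return [
        "python", "src/evaluation/eval_official.py",
        f"--predictions_dir={PREDICTIONS_DIR}",
        f"--gold_dir={GOLD_DIR}",
        f"--mode={mode}",
    ]

commands = [eval_command(mode) for mode in ("arguments", "illocutions")]
for command in commands:
    print(shlex.join(command))
```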

## Any final remarks
- Either in MD or in code (like below)