<a href="https://colab.research.google.com/github/jasivan/GAT-UCCA/blob/main/GAT_UCCA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Parameters to define

In [None]:
pos_dim = 20  # size of the sinusoidal position features appended to every node; 0 disables them; must be even
edge_dim = 12  # size of the edge-label features appended to every node; 0 disables them
Train_trial = 20  # number of train/validate rounds; each round trains for `epochs` epochs (set in "Run code")

# Position features are produced in sin/cos pairs, so an odd size would give
# leaf nodes pos_dim-1 features and non-leaf nodes pos_dim, and stacking the
# node features would fail much later. Fail early with a clear message instead.
if pos_dim < 0 or pos_dim % 2:
    raise ValueError(f'pos_dim must be a non-negative even integer, got {pos_dim}')

## Python modules

In [None]:
'''Download functions'''
!pip install transformers
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-geometric

'''IMPORT'''
import xml.etree.ElementTree as ET # extract tree labels

import torch # for neural network
import torch.nn as nn # for neural network functions like dropout
from transformers import BertTokenizer, BertModel # for BERT embeddings

from torch_geometric.data import Data # for GAT
from torch_geometric.nn import GATv2Conv # for GAT
from math import sin, cos, ceil # for position embedding and batching

from sklearn.model_selection import train_test_split # for splitting data
import random # for shuffling
import time # for time stamp

# Run once: load the pre-trained tokenizer and encoder used for the word embeddings.

# WordPiece tokenizer with the uncased BERT vocabulary.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# BERT encoder (checkpoint name indicates an SNLI fine-tuned model); every
# layer's hidden states are returned so several layers can be combined.
Bnet = BertModel.from_pretrained(
    'textattack/bert-base-uncased-snli',
    output_hidden_states=True,
)


## Pre-processing functions

In [None]:
# Processing XML file
def XML_processing(file):
    """Parse a UCCA XML file into node labels, child lists and the sentence text.

    Args:
        file: anything ``ET.parse`` accepts (a path or an open file object).

    Returns:
        node2tag: dict mapping node ID -> label. Layer-0 nodes map to their
            word text; '1.1' maps to 'X' (the root); every other node that is
            the target of a non-'Terminal' edge maps to that edge's type (the
            last such edge in document order wins).
        parent2children: dict mapping the ID of every node that has at least
            one edge -> list of its edges' ``toID`` values, in order.
        sentence: the layer-0 words joined by single spaces.
    """
    node2tag = {}
    parent2children = {}
    words = []
    root = ET.parse(file).getroot()

    for layer in root.iter('layer'):
        if layer.attrib['layerID'] != '0':
            # Structural layer: register the root, then record children and
            # the label of the edge leading into each non-leaf target.
            node2tag['1.1'] = 'X'
            for node in layer.iter('node'):
                edges = list(node.iter('edge'))
                if edges:
                    parent2children[node.attrib['ID']] = [edge.attrib['toID'] for edge in edges]
                for edge in edges:
                    if edge.attrib['type'] != 'Terminal':
                        node2tag[edge.attrib['toID']] = edge.attrib['type']
            continue

        # Layer 0: every <attributes> element carries one word of the sentence.
        for node in layer.iter('node'):
            for attribute in node.iter('attributes'):
                word = attribute.attrib['text']
                node2tag[node.attrib['ID']] = word
                words.append(word)

    return node2tag, parent2children, ' '.join(words)

# Enumerating all nodes in topological order, leaves to root
def enum_node(parent2children, node2tag):
    """Assign a contiguous integer index to every graph node.

    Leaf (layer-0) nodes, whose IDs start with ``'0.'``, get indices
    ``0 .. n_leaves-1`` in their ``node2tag`` order. Every node that has
    children (a key of ``parent2children``) gets the following indices, in
    ``parent2children`` order.

    Args:
        parent2children: dict node ID -> list of child IDs.
        node2tag: dict node ID -> word/label (leaf IDs start with ``'0.'``).

    Returns:
        (L0, L1, L): leaf index map, internal-node index map, and their union.
    """
    # Rank leaves among leaves only. The old version used each leaf's position
    # in node2tag, so any non-leaf entry listed before a leaf shifted the leaf
    # indices and could make them collide with the L1 indices.
    leaf_keys = [key for key in node2tag if key.startswith('0.')]
    L0 = {key: i for i, key in enumerate(leaf_keys)}
    L1 = {key: i + len(L0) for i, key in enumerate(parent2children)}
    return L0, L1, {**L0, **L1}

# Edge_index: list of source nodes and target nodes as two lists in a tensor
def edge_index(parent2children, L):
    """Build the ``(2, num_edges)`` long tensor of (source, target) node indices.

    Row 0 holds parent indices and row 1 child indices, looked up in ``L``.
    Edges appear in the reverse of their ``parent2children`` order.
    """
    # Flatten the adjacency lists into parallel parent / child key lists.
    sources = [parent for parent, children in parent2children.items() for _ in children]
    targets = [child for children in parent2children.values() for child in children]

    src_row = [L[key] for key in sources[::-1]]
    dst_row = [L[key] for key in targets[::-1]]
    return torch.tensor([src_row, dst_row], dtype=torch.long)

In [None]:
def get_BERT_emb(sent, node2tag):
    """Return one contextual BERT embedding per leaf word of a UCCA graph.

    Each word's vector is the sum, over the WordPiece tokens the word was split
    into, of the sum of BERT's last four hidden layers (e.g. vin ##ken --> vinken).

    Args:
        sent: the sentence as the leaf words joined by spaces. Kept for
            interface compatibility; the words themselves are read from
            ``node2tag`` so the result has exactly one entry per leaf.
        node2tag: node ID -> text; entries whose ID starts with ``'0.'`` are
            the leaf words, in sentence order.

    Returns:
        list of 1-D tensors (BERT hidden size), one per leaf word, in order.
    """
    words = [text for key, text in node2tag.items() if key.startswith('0.')]

    # Tokenise word by word so the number of WordPiece tokens per word is known
    # exactly. BERT's basic tokenizer splits on whitespace first, so the
    # concatenated pieces equal those of the space-joined sentence.
    # This replaces a prefix-matching heuristic that misaligned words in two
    # cases: the lower-cased word did not start with its first token (e.g.
    # accents stripped by the uncased tokenizer), or a word split into several
    # non-'##' pieces (punctuation, contractions) and a later piece happened
    # to prefix the next word.
    pieces_per_word = [tokenizer.tokenize(word) for word in words]
    tokenized_text = ['[CLS]'] + [piece for pieces in pieces_per_word for piece in pieces] + ['[SEP]']
    tokens_tensor = torch.tensor([tokenizer.convert_tokens_to_ids(tokenized_text)])

    Bnet.eval()  # evaluation mode: feed-forward only
    with torch.no_grad():  # BERT weights are never updated
        # BertModel's second positional argument is attention_mask. The previous
        # call passed an all-ones "segment" tensor there, i.e. attend to every
        # token with default token_type_ids. The keyword makes that explicit.
        outputs = Bnet(tokens_tensor, attention_mask=torch.ones_like(tokens_tensor))
        hidden_states = outputs[2]

    # Sum the last 4 layers; [0] drops the batch dimension -> (seq_len, hidden)
    token_embeddings = (hidden_states[-1] + (hidden_states[-2] + (hidden_states[-3] + hidden_states[-4])))[0]

    emb = []
    position = 1  # token 0 is [CLS]
    for pieces in pieces_per_word:
        n_pieces = len(pieces)
        if n_pieces:
            emb.append(token_embeddings[position:position + n_pieces].sum(dim=0))
        else:
            # The word produced no tokens; keep a zero vector so every leaf
            # still gets exactly one embedding.
            emb.append(token_embeddings.new_zeros(token_embeddings.shape[-1]))
        position += n_pieces

    return emb  # list of word embeddings


In [None]:
def preprocessing(file):
    """Convert one UCCA XML file into the (node features, edge index) pair fed to the GAT.

    Node ``t`` (numbered by ``enum_node``: leaves first, then nodes with
    children) gets the concatenation of:
      * its BERT word embedding (leaves), or a random vector scaled by 0.001
        (non-leaves);
      * a ``pos_dim`` sinusoidal encoding of ``t`` (leaves), or zeros (non-leaves);
      * zeros (leaves), or ``edge_emb`` of the first letter of the node's label
        (non-leaves).

    Returns:
        node_emb: ``(num_nodes, bert_dim + pos_dim + edge_dim)`` float tensor on ``device``.
        edge_idx: ``(2, num_edges)`` long tensor on ``device``.
    """
    node2tag, parent2children, sent = XML_processing(file)
    L0, _, L = enum_node(parent2children, node2tag)  # L0 => layer0; L => all layers
    n_leaves = len(L0)

    # BERT embeddings for the leaf nodes
    emb = get_BERT_emb(sent, node2tag)
    BERT_emb_size = emb[0].shape[0]  # dimension of the BERT embedding

    # Small-magnitude random placeholder features for the non-leaf nodes
    for _ in range(len(parent2children)):
        emb.append(torch.mul(torch.rand(BERT_emb_size, dtype=torch.float), 0.001))

    # Concatenate position encodings (zero vector for non-leaf nodes)
    for t in range(len(L)):
        if t < n_leaves:
            pos = []
            for k in range(pos_dim // 2):
                angle = t / (10000 ** (2 * k / pos_dim))
                pos.append(sin(angle))
                pos.append(cos(angle))
            emb[t] = torch.cat((emb[t], torch.tensor(pos, dtype=torch.float)))
        else:
            emb[t] = torch.cat((emb[t], torch.zeros(pos_dim)))

    # Concatenate label embeddings (zero vector for leaf nodes).
    # Labels are looked up by node ID. Indexing node2tag's values by position
    # was wrong: node2tag orders internal nodes by first appearance as an edge
    # target, while indices follow parent2children's order, so labels could
    # land on the wrong node.
    for key, t in L.items():
        if t < n_leaves:
            emb[t] = torch.cat((emb[t], torch.zeros(edge_dim)))
        else:
            # [0] keeps the first letter so that e.g. 'C-remote' => 'C'
            emb[t] = torch.cat((emb[t], edge_emb[node2tag[key][0]]))

    # The placeholder vectors and edge_emb are not among the model's
    # parameters, so gradients flowing into them are never used. detach()
    # drops that autograd history instead of keeping it for every sentence.
    node_emb = torch.stack(emb, dim=0).detach().to(device)
    edge_idx = edge_index(parent2children, L).to(device)

    return node_emb, edge_idx

In [None]:
''' Unzip XML files '''
!mkdir sentence-B-full
!mkdir sentence-A-full
%cd /content
!unzip /content/sentence-B-full.zip -d /content/sentence-B-full
!unzip /content/sentence-A-full.zip -d /content/sentence-A-full

mkdir: cannot create directory ‘sentence-B-full’: File exists
mkdir: cannot create directory ‘sentence-A-full’: File exists
/content
Archive:  /content/drive/MyDrive/sentence-B-full.zip
replace /content/sentence-B-full/190740.xml? [y]es, [n]o, [A]ll, [N]one, [r]ename: Archive:  /content/drive/MyDrive/sentence-A-full.zip
replace /content/sentence-A-full/190740.xml? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
# Preprocessing parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialising random embedding for the edge labels
edge_label = ['P', 'S', 'A', 'D', 'C', 'E', 'N', 'R', 'H', 'L', 'G', 'F', 'X', 'Q', 'U', 'T']
edge_emb = {}
for label in edge_label:
    edge_emb[label] = torch.rand(edge_dim, dtype = torch.float, requires_grad=True)

# Preprocessing sentences
%cd /content/sentence-B-full
preprocessB = []
for i in range(4906,9840): # total = 9840
  file1 = '1' + "%04d" % (i+1) + '0.xml'
  node_emb1, edge_idx1 = preprocessing(file1)
  preprocessB.append((node_emb1, edge_idx1))

%cd /content/sentence-A-full
preprocessA = []
for i in range(4906,9840):
  file1 = '1' + "%04d" % (i+1) + '0.xml'
  node_emb1, edge_idx1 = preprocessing(file1)
  preprocessA.append((node_emb1, edge_idx1))


In [None]:
# Store relatedness score to sentence pairs
%cd /content
with open('/content/relatedness-score-full.txt', 'r') as f:
    full_target = [torch.FloatTensor([float(line)]).to(device) for line in f.readlines()]

# Group sentence A, sentence B and relatedness score
preprocess = list(zip(preprocessA, preprocessB, full_target[4906:9840]))

# Split data into train and test set
preprocess_train = list(zip(preprocessA[:4439], preprocessB[:4439], full_target[4906:9345]))
preprocess_test = list(zip(preprocessA[4439:], preprocessB[4439:], full_target[9345:]))
print(len(preprocess_train), len(preprocess_test))
del preprocessA
del preprocessB

## Graph Attention Network model

In [None]:
# Graph Attention Network
class GAT(torch.nn.Module):
    """Two-layer GATv2 graph encoder followed by a linear layer and mean pooling.

    One instance is shared by both sentences of a pair (Siamese set-up).
    """

    def __init__(self, in_channels, out_channels, hidden_channels=None, num_heads=None, linear_channels=None):
        """
        Args:
            in_channels: node feature size.
            out_channels: output size of the second GATv2 layer.
            hidden_channels: output size of the first GATv2 layer; defaults to
                the notebook-level ``hidden``.
            num_heads: attention heads per layer (averaged, ``concat=False``);
                defaults to the notebook-level ``heads``.
            linear_channels: size of the final projection; defaults to the
                notebook-level ``linear_hidden``.
        """
        super().__init__()
        # Fall back to the notebook-level hyper-parameters, read at construction
        # time, so GAT(in_channels, out_channels) behaves exactly as before.
        hidden_channels = hidden if hidden_channels is None else hidden_channels
        num_heads = heads if num_heads is None else num_heads
        linear_channels = linear_hidden if linear_channels is None else linear_channels

        self.conv1 = GATv2Conv(in_channels, hidden_channels, heads=num_heads, concat=False, add_self_loops=True)
        self.conv2 = GATv2Conv(hidden_channels, out_channels, heads=num_heads, concat=False, add_self_loops=True)
        # Plain element-wise dropout. The inputs are 2-D (nodes x channels).
        # PyTorch deprecates 2-D inputs to nn.Dropout2d and recommends nn.Dropout
        # to keep the same behaviour. Dropout has no parameters, so saved
        # state_dicts are unaffected.
        self.dropout1 = nn.Dropout(0.1)
        self.linear = nn.Linear(out_channels, linear_channels)

    def forward(self, preprocessing):
        """Encode one graph.

        Args:
            preprocessing: ``(node_emb, edge_idx)`` pair for one sentence.

        Returns:
            ``(1, linear_channels)`` tensor: the mean over nodes of the
            projected node states.
        """
        x, edge_idx = preprocessing[0], preprocessing[1]
        x = self.dropout1(self.conv1(x, edge_idx))
        x = self.dropout1(self.conv2(x, edge_idx))
        x = self.linear(x)
        return x.mean(dim=0, keepdim=True)

## Training and validation functions

In [None]:
def train(model, data, data_size):
    """Train the Siamese model for ``epochs`` epochs and return the last epoch's loss.

    Each epoch draws ``data_size`` pairs from ``data`` without replacement and
    splits them into mini-batches of ``batch_size``. It minimises ``criterion``
    between the rescaled similarity ``|cos| * 4 + 1`` (range 1..5) and the
    gold score.

    Args:
        model: encoder applied to both graphs of a pair.
        data: list of ``(graph_a, graph_b, target)`` triples.
        data_size: number of pairs drawn per epoch (at most ``len(data)``).

    Returns:
        Mean per-pair training loss of the final epoch.

    Uses the notebook-level ``lr``, ``epochs``, ``batch_size``, ``device``,
    ``cosine`` and ``criterion``.

    Raises:
        ValueError: if ``data_size`` is not positive.
    """
    if data_size <= 0:
        raise ValueError('train() needs at least one training pair')

    print('----------TRAINING----------')
    optimiser = torch.optim.RMSprop(model.parameters(), lr=lr)
    epoch_loss = 0.0
    for epoch in range(epochs):
        model.train().to(device)
        preprocess_rand = random.sample(data, data_size)
        epoch_loss = 0.0

        for start in range(0, data_size, batch_size):
            batch = preprocess_rand[start:start + batch_size]
            sim, target = [], []
            for pair in batch:
                out1r = model(pair[0]).to(device)
                out2r = model(pair[1]).to(device)
                target.append(pair[2])
                sim.append((abs(cosine(out1r, out2r)) * 4) + 1)  # rescale to range 1 to 5

            Csim = torch.cat(tuple(sim), 0).to(device)
            targets = torch.cat(tuple(target), 0).to(device)

            # Reset gradients for every batch. They used to be zeroed once per
            # epoch, so each step applied the sum of all earlier batches' gradients.
            optimiser.zero_grad()
            loss = criterion(Csim, targets)
            # retain_graph kept: node features built during preprocessing may
            # still carry an autograd graph that is reused across epochs.
            loss.backward(retain_graph=True)
            # Clip before stepping so the clipped gradients are the ones applied.
            nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
            optimiser.step()

            # .item() so the running total does not keep each batch's graph alive
            epoch_loss += loss.item() * len(batch)

        print('Epoch:', epoch, 'loss:', '%.3f' % (epoch_loss / data_size))
    return epoch_loss / data_size

def eval(model, data, data_size):
    """Validate the model on the first ``data_size`` pairs of ``data``.

    Predictions ``|cos| * 4 + 1`` are rounded to one decimal, then compared to
    the gold scores. Prints MSE, Pearson r and Spearman rho.

    Args:
        model: encoder applied to both graphs of a pair.
        data: sequence of ``(graph_a, graph_b, target)`` triples.
        data_size: number of leading pairs to score.

    Returns:
        (mse, pearson) as Python floats.

    NOTE: the name shadows the built-in ``eval``. It is kept because the rest
    of the notebook calls it by this name.
    """
    from scipy import stats

    print('----------VALIDATION----------')
    model.eval().to(device)
    # Validation needs no gradients; skipping autograd bookkeeping saves memory.
    with torch.no_grad():
        sim, target = [], []
        for j in range(data_size):
            out1r = model(data[j][0]).to(device)
            out2r = model(data[j][1]).to(device)
            target.append(data[j][2])
            sim.append((abs(cosine(out1r, out2r)) * 4) + 1)

        Csim = torch.cat(tuple(sim), 0).to(device)
        Csim = (Csim * 10).round() / 10
        targets = torch.cat(tuple(target), 0).to(device)

        loss_MSE = criterion(Csim, targets)
        vx = Csim - torch.mean(Csim)
        vy = targets - torch.mean(targets)
        loss_pearson = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)))

    loss_spearman = stats.spearmanr(Csim.tolist(), targets.tolist())

    print('loss_MSE:', '%.3f' % loss_MSE.item())
    print('loss_pearson:', '%.3f' % loss_pearson.item())
    print('loss_spearman:', '%.3f' % loss_spearman[0])

    return loss_MSE.item(), loss_pearson.item()

## Run code

In [None]:
# Parameters 'Siamese'
heads = 1
hidden = 500
out_channels = 400
linear_hidden = 200
epochs = 5
batch_size = 15
lr = 0.000009

# Siamese network (mean of all nodes)
model = GAT(preprocess_train[0][0][0].shape[1], out_channels).to(device)
cosine = nn.CosineSimilarity(dim=1)  # similarity function
criterion = torch.nn.MSELoss()  # loss function

with open('epochs_loss.txt', 'w') as f:
    f.write('EPOCHS , LOSS_TRAIN , LOSS_DEV , Pearson')

# Best validation Pearson so far. Starting at -inf guarantees a checkpoint
# after the first trial; with 0 a run whose correlation never exceeds 0 would
# never write model-similarity.pt, which the last cell loads.
loss_pearson = float('-inf')
for trial in range(Train_trial):
    print('\nTRIAL:', trial)
    loss_train = train(model, preprocess_train, len(preprocess_train))
    loss_eval, loss_pearson1 = eval(model, preprocess_test, len(preprocess_test))
    with open('epochs_loss.txt', 'a') as f:
        # Log this trial's Pearson (a Python float). The old code called .item()
        # on the running best, which starts as the int 0 and crashed on the
        # first trial. The epoch count is derived from `epochs`, not hard-coded 5.
        f.write(f'\n {epochs * (trial + 1)} , {loss_train:.3f} , {loss_eval:.3f} , {loss_pearson1:.3f}')
    if loss_pearson1 > loss_pearson:
        loss_pearson = loss_pearson1
        torch.save(model.state_dict(), 'model-similarity.pt')

In [None]:
# Preprocessing sentences
%cd /content/sentence-B-full
preprocessB = []
for i in range(4906): # total = 9840
  file1 = '1' + "%04d" % (i+1) + '0.xml'
  node_emb1, edge_idx1 = preprocessing(file1)
  preprocessB.append((node_emb1, edge_idx1))

%cd /content/sentence-A-full
preprocessA = []
for i in range(4906):
  file1 = '1' + "%04d" % (i+1) + '0.xml'
  node_emb1, edge_idx1 = preprocessing(file1)
  preprocessA.append((node_emb1, edge_idx1))

%cd /content


#Group sentence A, sentence B and relatedness score
preprocess_eval = list(zip(preprocessA, preprocessB, full_target[:4906]))
del preprocessA
del preprocessB
_, _ = eval(model, preprocess_eval, len(preprocess_eval))

In [None]:
# Rebuild the encoder with the same input width and restore the best saved checkpoint.
model1 = GAT(preprocess_eval[0][0][0].shape[1], out_channels)
model1.to(device)
model1.load_state_dict(torch.load('/content/model-similarity.pt', map_location=device))
model1.eval()