In [1]:
import json
from pathlib import Path
import numpy as np

from sentence_transformers import SentenceTransformer

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv, GATv2Conv
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import from_networkx
import torch.nn.functional as F
from torch.nn.functional import softmax

import os
import networkx as nx

from transformers import BertTokenizer, BertForSequenceClassification
from tqdm import tqdm


In [2]:
def flatten(list_of_list):
    return [item for sublist in list_of_list for item in sublist]

path_to_training = Path("training")
path_to_test = Path("test")

#####
# training and test sets of transcription ids
#####
training_set = ['ES2002', 'ES2005', 'ES2006', 'ES2007', 'ES2008', 'ES2009', 'ES2010', 'ES2012', 'ES2013', 'ES2015', 'ES2016', 'IS1000', 'IS1001', 'IS1002', 'IS1003', 'IS1004', 'IS1005', 'IS1006', 'IS1007', 'TS3005', 'TS3008', 'TS3009', 'TS3010', 'TS3011', 'TS3012']
#training_set = ['ES2002', 'ES2005', 'ES2006', 'ES2007', 'ES2008', 'ES2009', 'ES2010']
training_set = flatten([[m_id+s_id for s_id in 'abcd'] for m_id in training_set])
training_set.remove('IS1002a')
training_set.remove('IS1005d')
training_set.remove('TS3012c')

test_set = ['ES2003', 'ES2004', 'ES2011', 'ES2014', 'IS1008', 'IS1009', 'TS3003', 'TS3004', 'TS3006', 'TS3007']
test_set = flatten([[m_id+s_id for s_id in 'abcd'] for m_id in test_set])


In [3]:
y_training = []
with open("training_labels.json", "r") as file:
    training_labels = json.load(file)
X_training = []
for transcription_id in training_set:
    with open(path_to_training / f"{transcription_id}.json", "r") as file:
        transcription = json.load(file)
    
    for utterance in transcription:
        X_training.append(utterance["speaker"] + ": " + utterance["text"])
    
    y_training += training_labels[transcription_id]

In [4]:
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(12345)

        self.linear1 = nn.Linear(388,128)

        self.linear2 = nn.Linear(128,128)

        self.conv1 = GATv2Conv(in_channels=128, out_channels=32, dropout=0.2,heads=4,edge_dim=16, aggr="mean")
        self.conv2 = GATv2Conv(in_channels=32*4, out_channels=128, dropout=0.2,heads=4, edge_dim=16, concat=False,aggr="max")

        self.fc = nn.Linear(128, 16)
        self.classifier = nn.Linear(16, 2)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, edge_index, edge_attr):
        x = x.to(torch.float)

        x = self.linear1(x)

        x = self.relu(x)

        x = self.dropout(x)
        print('ok')
        h = self.conv1(x, edge_index=edge_index, edge_attr=edge_attr)
       
        h = self.tanh(h)

        h = self.conv2(h, edge_index=edge_index, edge_attr=edge_attr)
        print('ok2')
        h = self.tanh(h)

        out = self.fc(h)
        
        out = self.dropout(out)
        out = self.relu(out)
        out = self.classifier(out)


        return out, h


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert = SentenceTransformer('all-MiniLM-L6-v2').to(device)
bert.train()
def embed_text(text):
# Embeds text into 384-dimensional Space
    return bert.encode(text, show_progress_bar=True)
    

        

In [6]:
print(device)
print(torch.cuda.is_available())

cpu
False


In [18]:
def extract_data(path,label_file=None):
    set_of_edge_attr = set()
    assert path != None
    # One-hot encode
    dict_speakers =  {'UI': np.array([1,0,0,0]), 'PM': np.array([0,1,0,0]), 'ME': np.array([0,0,1,0]), 'ID': np.array([0,0,0,1])}

    files = [file.split('.')[0] for file in os.listdir(path) if file.endswith('.json')]
    data = []
    vocab_of_texts_edges = set()
    if label_file!=None:
        with open(f"{label_file}.json", "r") as file:
            labels = json.load(file)
    

    for file_name in tqdm(files):
        with open(path +"/"+ f"{file_name}.txt", "r") as file:
            c = file.readline()
            while len(c) != 0:
                index1, text, index2 = c.split(' ')
                set_of_edge_attr.add(text)
                c = file.readline()

    list_of_edge_attr = list(set_of_edge_attr)
    dict_edge_attr = dict()
    for i in range(len(list_of_edge_attr)):
        vec = np.zeros(len(list_of_edge_attr))
        vec[i] = 1
        dict_edge_attr[list_of_edge_attr[i]] = vec


    for file_name in tqdm(files):
        print(file_name)
        # Loading Nodes
        with open(path +'/'+ f"{file_name}.json", "r") as file:
            file_data = json.load(file)
        graph = nx.Graph()

        N_nodes = len(file_data)
        graph.add_nodes_from(list(range(N_nodes)))

        # Loading Nodes attributes
        nodes_attr = [file_data[i]['text'] for i in range(N_nodes)]
        
        #Loading Nodes labels and one hot encoding them
        nodes_labels = np.zeros((N_nodes, 4))
        for i in range(N_nodes):
            nodes_labels[i] = dict_speakers[file_data[i]['speaker']]
        # Embedding Nodes attributes with bert

        # Extracting Edges
        edges_indices = []
        edges_attr = []
        with open(path +"/"+ f"{file_name}.txt", "r") as file:
            c = file.readline()
            while len(c) != 0:
                index1, text, index2 = c.split(' ')
                edges_indices.append((int(index1),int(index2)))

                #nodes_attr[int(index2)] = nodes_attr[int(index1)] +" . "+ nodes_attr[int(index2)]

                edges_attr.append(dict_edge_attr[text])
                edges_attr.append(dict_edge_attr[text])

                # edges_indices.append((int(index2)-1,int(index2)))
                # edges_attr.append(np.zeros((len(dict_edge_attr[text]))))

                c = file.readline()
        #print('finish')

        nodes_attr = embed_text(nodes_attr)
        #print('embedded')
        
        # Concatenating the attributes
        nodes_attr = np.hstack([nodes_labels,nodes_attr])

        # Embedding edges features

        # Add edges list to graph
        graph.add_edges_from(edges_indices)


        data_loader = from_networkx(graph)
        if label_file != None:
            data_loader.y = torch.tensor(labels[file_name],dtype=torch.long)
        data_loader.x = torch.tensor(nodes_attr,dtype=torch.float)
        data_loader.edge_attr = torch.tensor(edges_attr, dtype=torch.float)
        data.append(data_loader)

    
    print(set_of_edge_attr)
    return data, files

In [19]:
import time
epochs= 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
#model = GCN(len(X_training[0]),16,32,2,0.5).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

In [None]:
data,files=extract_data("training",'training_labels')

In [23]:

    
for dataloader in data:
    for epoch in range(5):
        model.train()
        optimizer.zero_grad()
        output,h = model(dataloader.x, dataloader.edge_index, dataloader.edge_attr)
        #loss_train = F.nll_loss(output[idx_train], class_labels_train)
        loss_train = criterion(output, dataloader.y)
        loss_train.backward()
        optimizer.step()

        


ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2
ok
ok2

In [27]:
data_test[0].edge_attr.shape

torch.Size([588, 16])

In [24]:
data_test,files=extract_data("test")


100%|██████████| 40/40 [00:00<00:00, 1146.22it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

ES2003a


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

  2%|▎         | 1/40 [00:01<01:06,  1.72s/it]

ES2003b


Batches:   0%|          | 0/24 [00:00<?, ?it/s]

  5%|▌         | 2/40 [00:05<01:51,  2.93s/it]

ES2003c


Batches:   0%|          | 0/24 [00:00<?, ?it/s]

  8%|▊         | 3/40 [00:09<02:07,  3.46s/it]

ES2003d


Batches:   0%|          | 0/29 [00:00<?, ?it/s]

 10%|█         | 4/40 [00:13<02:11,  3.64s/it]

ES2004a


Batches:   0%|          | 0/14 [00:00<?, ?it/s]

 12%|█▎        | 5/40 [00:15<01:44,  2.98s/it]

ES2004b


Batches:   0%|          | 0/25 [00:00<?, ?it/s]

 15%|█▌        | 6/40 [00:19<01:56,  3.42s/it]

ES2004c


Batches:   0%|          | 0/27 [00:00<?, ?it/s]

 18%|█▊        | 7/40 [00:23<02:00,  3.65s/it]

ES2004d


Batches:   0%|          | 0/34 [00:00<?, ?it/s]

 20%|██        | 8/40 [00:28<02:03,  3.87s/it]

ES2011a


Batches:   0%|          | 0/12 [00:00<?, ?it/s]

 22%|██▎       | 9/40 [00:30<01:42,  3.30s/it]

ES2011b


Batches:   0%|          | 0/20 [00:00<?, ?it/s]

 25%|██▌       | 10/40 [00:33<01:37,  3.23s/it]

ES2011c


Batches:   0%|          | 0/22 [00:00<?, ?it/s]

 28%|██▊       | 11/40 [00:36<01:33,  3.23s/it]

ES2011d


Batches:   0%|          | 0/24 [00:00<?, ?it/s]

 30%|███       | 12/40 [00:39<01:29,  3.21s/it]

ES2014a


Batches:   0%|          | 0/9 [00:00<?, ?it/s]

 32%|███▎      | 13/40 [00:41<01:15,  2.80s/it]

ES2014b


Batches:   0%|          | 0/24 [00:00<?, ?it/s]

 35%|███▌      | 14/40 [00:45<01:19,  3.07s/it]

ES2014c


Batches:   0%|          | 0/25 [00:00<?, ?it/s]

 38%|███▊      | 15/40 [00:48<01:20,  3.21s/it]

ES2014d


Batches:   0%|          | 0/35 [00:00<?, ?it/s]

 40%|████      | 16/40 [00:53<01:30,  3.77s/it]

IS1008a


Batches:   0%|          | 0/8 [00:00<?, ?it/s]

 42%|████▎     | 17/40 [00:55<01:12,  3.14s/it]

IS1008b


Batches:   0%|          | 0/19 [00:00<?, ?it/s]

 45%|████▌     | 18/40 [00:58<01:06,  3.01s/it]

IS1008c


Batches:   0%|          | 0/14 [00:00<?, ?it/s]

 48%|████▊     | 19/40 [01:00<00:59,  2.86s/it]

IS1008d


Batches:   0%|          | 0/17 [00:00<?, ?it/s]

 50%|█████     | 20/40 [01:03<00:55,  2.79s/it]

IS1009a


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

 52%|█████▎    | 21/40 [01:05<00:49,  2.61s/it]

IS1009b


Batches:   0%|          | 0/20 [00:00<?, ?it/s]

 55%|█████▌    | 22/40 [01:08<00:50,  2.83s/it]

IS1009c


Batches:   0%|          | 0/15 [00:00<?, ?it/s]

 57%|█████▊    | 23/40 [01:11<00:47,  2.78s/it]

IS1009d


Batches:   0%|          | 0/22 [00:00<?, ?it/s]

 60%|██████    | 24/40 [01:14<00:46,  2.89s/it]

TS3003a


Batches:   0%|          | 0/15 [00:00<?, ?it/s]

 62%|██████▎   | 25/40 [01:16<00:38,  2.57s/it]

TS3003b


Batches:   0%|          | 0/22 [00:00<?, ?it/s]

 65%|██████▌   | 26/40 [01:19<00:38,  2.75s/it]

TS3003c


Batches:   0%|          | 0/22 [00:00<?, ?it/s]

 68%|██████▊   | 27/40 [01:22<00:37,  2.91s/it]

TS3003d


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 70%|███████   | 28/40 [01:26<00:38,  3.20s/it]

TS3004a


Batches:   0%|          | 0/20 [00:00<?, ?it/s]

 72%|███████▎  | 29/40 [01:29<00:32,  2.96s/it]

TS3004b


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

 75%|███████▌  | 30/40 [01:33<00:34,  3.49s/it]

TS3004c


Batches:   0%|          | 0/43 [00:00<?, ?it/s]

 78%|███████▊  | 31/40 [01:38<00:34,  3.86s/it]

TS3004d


Batches:   0%|          | 0/37 [00:00<?, ?it/s]

 80%|████████  | 32/40 [01:43<00:32,  4.07s/it]

TS3006a


Batches:   0%|          | 0/20 [00:00<?, ?it/s]

 82%|████████▎ | 33/40 [01:45<00:25,  3.65s/it]

TS3006b


Batches:   0%|          | 0/37 [00:00<?, ?it/s]

 85%|████████▌ | 34/40 [01:50<00:23,  3.91s/it]

TS3006c


Batches:   0%|          | 0/40 [00:00<?, ?it/s]

 88%|████████▊ | 35/40 [01:55<00:20,  4.15s/it]

TS3006d


Batches:   0%|          | 0/52 [00:00<?, ?it/s]

 90%|█████████ | 36/40 [02:01<00:19,  4.77s/it]

TS3007a


Batches:   0%|          | 0/21 [00:00<?, ?it/s]

 92%|█████████▎| 37/40 [02:03<00:12,  4.10s/it]

TS3007b


Batches:   0%|          | 0/30 [00:00<?, ?it/s]

 95%|█████████▌| 38/40 [02:07<00:08,  4.11s/it]

TS3007c


Batches:   0%|          | 0/30 [00:00<?, ?it/s]

 98%|█████████▊| 39/40 [02:12<00:04,  4.25s/it]

TS3007d


Batches:   0%|          | 0/46 [00:00<?, ?it/s]

100%|██████████| 40/40 [02:17<00:00,  3.45s/it]

{'Explanation', 'Background', 'Elaboration', 'Clarification_question', 'Alternation', 'Conditional', 'Parallel', 'Correction', 'Q-Elab', 'Contrast', 'Question-answer_pair', 'Acknowledgement', 'Result', 'Narration', 'Comment', 'Continuation'}





In [26]:
test_labels = {}
for i in range(len(data_test)):
    dataloader=data_test[i]
    model.eval()
    y_test, embedding = model(dataloader.x, dataloader.edge_index,dataloader.edge_attr)
    #y_test = (y_test > 0.5).astype(int)
    y_test = torch.argmax(y_test, dim=1)
    print(y_test)
    test_labels[files[i]] = y_test.tolist()

with open("test2.json", "w") as file:
    json.dump(test_labels, file, indent=4)

ok
ok2
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0])
ok
