In [1]:
import torch, sys
sys.path.insert(0, '../')
from my_utils import gpu_utils
import importlib, gc
from my_utils.alignment_features import *
import my_utils.alignment_features as afeatures
importlib.reload(afeatures)
import gnn_utils.graph_utils as gutils



In [2]:
# !pip install torch-geometric
# !pip install tensorboardX

# !wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
# !unzip ngrok-stable-linux-amd64.zip

#  print(torch.version.cuda)
#  print(torch.__version__)    

# Compute device used by every `.to(dev)` in this notebook: the current CUDA device when one is visible, else CPU.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



In [3]:

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

import time
from datetime import datetime

import networkx as nx
import numpy as np
import torch
import torch.optim as optim

from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader

import torch_geometric.transforms as T

from tensorboardX import SummaryWriter
from sklearn.manifold import TSNE
# import matplotlib.pyplot as plt




In [4]:
from my_utils import align_utils as autils, utils
import argparse
from multiprocessing import Pool
import random

# Project setup from the PBC config file.
# NOTE(review): no random seed is set in this cell (the old "set random seed" label was misleading);
# utils.setup may seed RNGs internally — confirm, since random.shuffle / DataLoader(shuffle=True) run later.
config_file = "/mounts/Users/student/ayyoob/Dokumente/code/pbc-ui-demo/config_pbc.ini"
utils.setup(config_file)

params = argparse.Namespace()


# Collect the keys (verse ids) of the gold alignments of both HELFI training splits (fin-grc and fin-heb).
params.gold_file = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-grc-gold-alignments_train.txt"
pros, surs = autils.load_gold(params.gold_file)
all_verses = list(pros.keys())
params.gold_file = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-heb-gold-alignments_train.txt"
pros, surs = autils.load_gold(params.gold_file)
all_verses.extend(list(pros.keys()))
# De-duplicate; the resulting order is set-iteration order, not file order.
all_verses = list(set(all_verses))
print(len(all_verses))

# Editions listed in the HELFI language list, in `langs` order.
params.editions_file =  "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi_lang_list.txt"
editions, langs = autils.load_simalign_editions(params.editions_file)
current_editions = [editions[lang] for lang in langs]

def get_pruned_verse_alignments(args):
    """Load one verse's default (`*_inter`, presumably intersection) and GDFA alignments and prune both.

    `args` is a single `(verse, editions)` tuple so the function can be mapped over a list of work items.
    Both alignment objects are passed to autils.prune_non_necessary_alignments, whose return value
    is ignored (presumably it prunes in place), and are returned as `(inter, gdfa)`.
    """
    verse, keep_editions = args

    inter_and_gdfa = (
        autils.get_verse_alignments(verse),
        autils.get_verse_alignments(verse, gdfa=True),
    )
    for alignments in inter_and_gdfa:
        autils.prune_non_necessary_alignments(alignments, keep_editions)

    gc.collect()
    return inter_and_gdfa
    

# One (verse, private copy of current_editions) tuple per verse: the argument shape get_pruned_verse_alignments unpacks.
args = [(verse, current_editions[:]) for verse in all_verses]


24159


In [5]:
#importlib.reload(afeatures)

class Encoder(torch.nn.Module):
    """GAT node encoder: feature encoding -> two ELU-activated GATConv layers -> Linear + ReLU.

    Args:
        in_channels: width of the encoded node features fed to the first GATConv.
        out_channels: width of the final node embeddings (output of `fin_lin`).
        features: feature descriptors forwarded to afeatures.FeatureEncoding.
        n_head: attention heads of the first GATConv; its heads are concatenated, which is why the
            second GATConv takes `2 * n_head * out_channels` inputs.
        has_tagfreq_feature: also hand the notebook global `normalized_tag_frequencies` (not defined in
            the cells shown here) to FeatureEncoding, ahead of `word_vectors`.

    Attribute names (conv1, conv2, fin_lin, feature_encoder) must stay stable: whole-model
    checkpoints are pickled and restored against this class.
    """

    def __init__(self, in_channels, out_channels, features, n_head = 2, has_tagfreq_feature=False,):
        super().__init__()
        hidden = 2 * out_channels
        self.conv1 = pyg_nn.GATConv(in_channels, hidden, heads=n_head)
        self.conv2 = pyg_nn.GATConv(hidden * n_head, out_channels, heads=1)
        self.fin_lin = nn.Linear(out_channels, out_channels)

        lookup_tables = [normalized_tag_frequencies, word_vectors] if has_tagfreq_feature else [word_vectors]
        self.feature_encoder = afeatures.FeatureEncoding(features, lookup_tables)

    def forward(self, x, edge_index):
        # FeatureEncoding receives the notebook-global device `dev`.
        hidden_states = self.feature_encoder(x, dev)
        for gat_layer in (self.conv1, self.conv2):
            hidden_states = F.elu(gat_layer(hidden_states, edge_index))
        return F.relu(self.fin_lin(hidden_states))

In [6]:
def freeze_encoders_embedding(encoder):
    """Swap every MAPPING-type layer of a FeatureEncoding for a frozen MappingEncoding of the same weights.

    `encoder.feature_types[k]` describes `encoder.layers[k]`; only entries whose `.type` is MAPPING are
    replaced, reusing that layer's `emb.weight`.
    """
    mapping_positions = [
        pos for pos, feature_type in enumerate(encoder.feature_types)
        if feature_type.type == MAPPING
    ]
    for pos in mapping_positions:
        print('doing it')
        encoder.layers[pos] = afeatures.MappingEncoding(encoder.layers[pos].emb.weight, freeze=True)

In [7]:
def clean_memory():
    """Free memory between phases: run Python's garbage collector, then drop PyTorch's cached CUDA blocks."""
    collect_garbage = gc.collect
    collect_garbage()
    # empty_cache only acts once CUDA has been initialised, so this is safe on CPU-only runs.
    torch.cuda.empty_cache()

class DataEncoder():
    """Iterate over a verse DataLoader, encoding each verse graph with `model.encode`.

    Yields ``(z, verse, i, batch)``: the encoder output, the verse id, the loader index ``i``
    and the raw batch dict. Verses listed in the notebook global ``masked_verses`` are skipped,
    so ``i`` can jump.

    Args:
        data_loader: loader over POSTAGGNNDataset items (batch_size=1; every field is read at [0]).
        model: object with an ``encode(x, edge_index)`` method (the GAE in this notebook).
        mask_language: if True, zero feature column 0 of ``x`` before encoding
            (presumably the edition/language one-hot — confirm against the feature list).

    Raises:
        Whatever ``model.encode`` raises, after recording the failing item in the globals
        ``sag`` (index), ``khar`` (batch) and ``gav`` (verse) for post-mortem inspection.
    """

    def __init__(self, data_loader, model, mask_language):
        self.data_loader = data_loader
        self.model = model
        self.mask_language = mask_language

    def __iter__(self):
        global sag, khar, gav
        # masked_verses is a list; snapshot it as a set so each per-batch lookup is O(1).
        skipped_verses = set(masked_verses)

        for i, batch in enumerate(tqdm(self.data_loader)):
            verse = batch['verse'][0]
            if verse in skipped_verses:
                continue

            x = batch['x'][0].to(dev)
            edge_index = batch['edge_index'][0].to(dev)

            try:
                if self.mask_language:
                    # On a CPU `dev`, .to(dev) returns the same tensor, so this also edits batch['x'][0].
                    x[:, 0] = 0
                z = self.model.encode(x, edge_index)
            except Exception as e:
                sag, khar, gav = (i, batch, verse)
                print(e)
                # Propagate the original error and traceback instead of a synthetic ZeroDivisionError.
                raise

            yield z, verse, i, batch

def train(epoch, data_loader, mask_language, max_batches=999999999, accumulation_steps=5):
    """Train the global `model` for one pass over `data_loader`, accumulating gradients.

    Gradients of `accumulation_steps` consecutive processed batches are summed before each
    optimizer step, and any remainder is applied at the end of the pass. Whenever the loader
    index hits 499 mod 500, the running loss is printed and the global `test_data_loader` is
    evaluated.

    Args:
        epoch: epoch number, forwarded to `test` for logging.
        data_loader: loader over POSTAGGNNDataset items (batch_size=1).
        mask_language: forwarded to DataEncoder.
        max_batches: stop after the batch whose loader index reaches this value.
        accumulation_steps: processed batches per optimizer step (default 5, as before).
    """
    global optimizer
    total_loss = 0
    model.train()

    data_encoder = DataEncoder(data_loader, model, mask_language)

    optimizer.zero_grad()
    pending = 0  # batches whose gradients are accumulated but not yet applied
    for z, verse, i, batch in data_encoder:
        target = batch['pos_classes'][0].to(dev)
        _, labels = torch.max(target, 1)

        index = batch['pos_index'][0].to(dev)
        preds = model.decoder(z, index, batch)

        # Multiply the batch loss by its node count so larger node groups weigh proportionally more.
        loss = criterion(preds, labels) * target.shape[0]
        loss.backward()
        total_loss += loss.item()
        pending += 1

        # Gradients must only be cleared right after a step, otherwise accumulation is lost.
        if pending == accumulation_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        if i % 500 == 499:
            print(f"loss: {total_loss}")
            total_loss = 0
            test(epoch, test_data_loader, mask_language)
            model.train()
            clean_memory()

        # `>=` because DataEncoder skips masked verses, so `i` may jump past max_batches.
        if i >= max_batches:
            break

    if pending:
        optimizer.step()
        optimizer.zero_grad()

    print(f"total train loss: {total_loss}")

In [8]:
class POSDecoder(nn.Module):
    """MLP head that maps the embeddings of the selected nodes to POS-class logits.

    Args:
        input_size: node-embedding width.
        hidden_size: width of the single hidden layer.
        n_class: number of POS classes.
        drop_out: dropout probability after the hidden ReLU.

    The attribute name `transfer` must stay stable (whole-model checkpoints are pickled).
    """

    def __init__(self, input_size, hidden_size, n_class, drop_out=0):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(drop_out),
            nn.Linear(hidden_size, n_class),
        ]
        self.transfer = nn.Sequential(*layers)

    def forward(self, z, index, batch=None):
        # `batch` is accepted only for interface parity with POSDecoderTransformer.
        return self.transfer(z[index, :])

class POSDecoderTransformer(nn.Module):
    """POS decoder that contextualises node embeddings per edition with a Transformer before classifying.

    Each edition's node embeddings for the verse are zero-padded to 150 rows, run through a
    2-layer nn.TransformerEncoder, and the rows named by `transformer_indices` are classified
    by an MLP (`transfer`) into POS-class logits.

    NOTE(review): nhead=8 requires input_size to be divisible by 8.
    NOTE(review): nn.TransformerEncoder builds its layers as copies of `encoder_layer`, so the
    weights of `self.encoder_layer` itself appear unused in forward but are still registered
    parameters. Renaming/removing it would break pickled checkpoints, so it is left as is.
    """
    def __init__(self, input_size, hidden_size, n_class, drop_out=0):
        super(POSDecoderTransformer, self).__init__()

        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=8, dim_feedforward=hidden_size)
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=2)

        self.transfer = nn.Sequential( nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Dropout(drop_out),
                        nn.Linear(hidden_size, n_class))

    def forward(self, z, index, batch):
        """Return logits for the nodes located by batch['transformer_indices'].

        Args:
            z: node embeddings of the verse graph (rows indexed by verse-local node id).
            index: unused here; row selection comes from `transformer_indices` instead.
            batch: dict with 'lang_based_nodes' and 'transformer_indices' as produced by
                get_language_based_nodes in POSTAGGNNDataset.__getitem__ (after DataLoader
                collation — the exact collated structure is not verified here).
        """
        language_based_nodes = batch['lang_based_nodes']
        transformer_indices = batch['transformer_indices']

        # NOTE: `batch` is rebound to the list of padded per-edition sequences.
        batch = []
        for lang_nodes in language_based_nodes:
            tensor = z[lang_nodes, :]
            # Pad rows up to 150. A negative amount (more than 150 nodes) would crop instead;
            # such verses are presumably excluded via masked_verses — TODO confirm.
            # NOTE(review): no src_key_padding_mask is passed, so the zero rows take part in attention.
            tensor = F.pad(tensor, (0, 0, 0, 150 - tensor.size(0)))
            batch.append(tensor)
        
        # (n_editions, 150, dim) -> (150, n_editions, dim): TransformerEncoder defaults to batch_first=False.
        batch = torch.stack(batch)
        batch = torch.transpose(batch, 0, 1)

        h = self.transformer(batch)
        h = torch.transpose(h, 0, 1)
        # Pick (edition position, token rank) for every requested node.
        h = h[transformer_indices[0], transformer_indices[1], :]

        res = self.transfer(h)

        return res

In [9]:
def test(epoch, testloader, mask_language, filter_wordtypes=None):
    print('testing',  epoch)
    model.eval()
    total = 0
    correct = 0

    data_encoder = DataEncoder(testloader, model, mask_language)
    
    with torch.no_grad():
        for z, verse, i, batch in data_encoder:
            
            target = batch['pos_classes'][0].to(dev)
            index = batch['pos_index'][0].to(dev)
            
            if filter_wordtypes != None:
                non_filtered_words = filter_wordtypes[batch['x'][0][:, 9].long()] == 1
                non_filtered_words = non_filtered_words[index]
                index = index[non_filtered_words]

                target = target[non_filtered_words, :]

            preds = model.decoder(z, index, batch)
            
            if preds.size(0) > 0:
                _, predicted = torch.max(preds, 1)
                _, labels = torch.max(target, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'test, epoch: {epoch}, total:{total} ACC: {correct/total}')
    clean_memory()


In [10]:
def majority_voting_test(data_loader1, data_loader2):
    """Accuracy of projecting POS tags by majority vote over graph neighbours.

    For every tagged node of a `data_loader1` batch, take its neighbours
    (edge_index[0] -> edge_index[1]) that are tagged nodes of the matching `data_loader2`
    batch, and predict the most frequent of their tags (torch.mode). Nodes with no tagged
    neighbour are not counted. The loaders are zipped, so they must yield the same verses
    in the same order. Verses in the global `masked_verses` are skipped.

    Returns:
        Accuracy as a float, or 0.0 when no node could be scored.
    """
    total = 0
    correct = 0

    for batch, batch2 in tqdm(zip(data_loader1, data_loader2)):
        edge_index = batch['edge_index'][0]
        verse = batch['verse'][0]

        if verse in masked_verses:
            continue

        target = batch['pos_classes'][0]
        index = batch['pos_index'][0]

        index2 = batch2['pos_index'][0]
        classes2 = batch2['pos_classes'][0]

        for node, label in zip(index, target):
            other_side = edge_index[1, edge_index[0, :] == node]
            # matches[a, b] is True when neighbour other_side[a] is the tagged node index2[b].
            matches = other_side.unsqueeze(1) == index2.unsqueeze(0)
            tagged = matches.any(dim=1)
            if not bool(tagged.any()):
                continue
            # Row of batch2's pos_classes for each tagged neighbour (first match, neighbour order kept).
            rows = matches[tagged].long().argmax(dim=1)
            _, proj_tags = torch.max(classes2[rows], 1)

            if torch.argmax(label) == torch.mode(proj_tags)[0]:
                correct += 1
            total += 1

    accuracy = correct / total if total else 0.0
    print(f'test, , total:{total} ACC: {accuracy}')
    return accuracy


In [11]:
#print(wordtype_frequencies[736])
#print(wordtype_frequencies[1473])
#print(wordtype_frequencies[3683])
#print(wordtype_frequencies[7367])
#print(wordtype_frequencies[14733])

In [12]:
#frequent_words = torch.zeros(word_frequencies.size(0))
#frequent_words[word_frequencies > -1] = 1

#test(1, test_data_loader, filter_wordtypes=frequent_words)

In [13]:
from gensim.models import Word2Vec
# Pretrained word2vec model over the HELFI languages; its vectors become `word_vectors`, which
# Encoder passes to afeatures.FeatureEncoding.
w2v_model = Word2Vec.load("/mounts/work/ayyoob/models/w2v/word2vec_helfi_langs_15e.model")

print(w2v_model.wv.vectors.shape)
word_vectors = torch.from_numpy(w2v_model.wv.vectors).float()

(2354770, 100)


In [14]:
import pickle

# NOTE: train_verses is reassigned later from the train dataset's own verse order.
train_verses = all_verses[:]
test_verses = all_verses[:] 
editf1 = 'fin-x-bible-helfi'
editf2 = "heb-x-bible-helfi"


# Drop these two editions from current_editions when present.
if 'jpn-x-bible-newworld' in  current_editions[:]:
     current_editions.remove('jpn-x-bible-newworld')
if 'grc-x-bible-unaccented' in  current_editions[:]:
     current_editions.remove('grc-x-bible-unaccented')



# Graph datasets saved by POSTAGGNNDataset.calculate_verse_stats with x/edge_index stripped;
# per-verse tensors live under <dir>/verses/.
# NOTE: torch.load unpickles — only load files produced by this project.
data_dir_train = "/mounts/data/proj/ayyoob/align_induction/dataset/dataset_helfi_train_community_word"
data_dir_blinker = "/mounts/data/proj/ayyoob/align_induction/dataset/pruned_alignments_blinker_inter/"
data_dir_grc = "/mounts/data/proj/ayyoob/align_induction/dataset/dataset_helfi_grc_community_word/"
data_dir_heb = "/mounts/data/proj/ayyoob/align_induction/dataset/dataset_helfi_heb_community_word/"

train_dataset = torch.load(f"{data_dir_train}/train_dataset_nox_noedge.torch.bin")
blinker_test_dataset = torch.load(f"{data_dir_blinker}/train_dataset_nox_noedge.torch.bin")
grc_test_dataset = torch.load(f"{data_dir_grc}/train_dataset_nox_noedge.torch.bin")
heb_test_dataset = torch.load(f"{data_dir_heb}/train_dataset_nox_noedge.torch.bin")


In [15]:
import codecs
import collections

# Universal POS tags in class-index order; postag_map sends each tag to its label column.
postag_map = {tag: idx for idx, tag in enumerate((
    "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
    "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X",
))}

# Editions with POS-tagged CoNLL-U files, as passed to get_pos_tags.
pos_lang_list = [
    "eng-x-bible-mixed", "deu-x-bible-newworld", "ces-x-bible-newworld",
    "fra-x-bible-louissegond", "hin-x-bible-newworld", "ita-x-bible-2009",
    "prs-x-bible-goodnews", "ron-x-bible-2006", "spa-x-bible-newworld",
]

def get_db_nodecount(dataset):
	"""Total number of graph nodes in `dataset`: one per (edition, verse, token) entry of nodes_map."""
	return sum(
		len(verse_tokens)
		for edition_verses in dataset.nodes_map.values()
		for verse_tokens in edition_verses.values()
	)

def get_language_nodes(dataset, lang_list, sentences):
	"""Collect, per sentence, the node ids of every listed edition that contains it.

	Returns `(pos_labels, pos_node_cover)`: an all-zero `(node_count, len(postag_map))` tensor
	(never filled here) and a defaultdict(list) mapping sentence id -> node ids, in edition order
	and then nodes_map token order.
	"""
	pos_labels = torch.zeros(get_db_nodecount(dataset), len(postag_map))

	pos_node_cover = collections.defaultdict(list)
	for lang in lang_list:
		for sentence in sentences:
			edition_verses = dataset.nodes_map[lang]
			if sentence in edition_verses:
				pos_node_cover[sentence].extend(edition_verses[sentence].values())

	return pos_labels, pos_node_cover


def _read_conllu_upos(path, wanted_ids):
	"""Return {verse_id: [UPOS per word]} for CoNLL-U sentences whose `# verse_id` is in wanted_ids."""
	tags = {}
	sent_id = ""
	sent_rows = []
	with open(path, "r", encoding="utf-8") as lang_pos:
		for sline in lang_pos:
			sline = sline.strip()
			if sline == "":
				# A blank line ends a sentence; keep it only if the graph knows this verse.
				if sent_id in wanted_ids:
					tags[sent_id] = [row[3] for row in sent_rows]
				sent_id = ""
				sent_rows = []
			elif "# verse_id" in sline:
				sent_id = sline.split()[-1]
			elif sline[0] == "#":
				continue
			else:
				row = sline.split("\t")
				# Multiword-token ranges ("1-2") and empty nodes ("1.1") are not surface words;
				# keeping them would shift word indices and carry the UPOS placeholder "_".
				if "-" in row[0] or "." in row[0]:
					continue
				sent_rows.append(row)
	# The final sentence has no terminating blank line when the file does not end with one.
	if sent_rows and sent_id in wanted_ids:
		tags[sent_id] = [row[3] for row in sent_rows]
	return tags


def get_pos_tags(dataset, pos_lang_list, pos_tag_dir="/mounts/work/mjalili/projects/gnn-align/data/pbc_pos_tags"):
	"""Build one-hot POS labels for the graph nodes of the tagged editions.

	Reads `<pos_tag_dir>/<lang>.conllu` for every edition of `pos_lang_list` present in
	dataset.nodes_map, and keeps only sentences whose `# verse_id` is a verse of that edition.

	Returns:
		pos_labels: `(node_count, len(postag_map))` tensor with a 1 at [node, tag index] for every
			tagged word that has a node.
		pos_node_cover: defaultdict(list) mapping verse id -> node ids that received a label.
	"""
	all_tags = {}
	for lang in pos_lang_list:
		if lang not in dataset.nodes_map:
			continue
		all_tags[lang] = _read_conllu_upos(f"{pos_tag_dir}/{lang}.conllu", dataset.nodes_map[lang])

	node_count = get_db_nodecount(dataset)
	pos_labels = torch.zeros(node_count, len(postag_map))
	pos_node_cover = collections.defaultdict(list)

	for lang, lang_tags in all_tags.items():
		for sent_id, sent_tags in lang_tags.items():
			verse_nodes = dataset.nodes_map[lang][sent_id]
			for w_i, tag in enumerate(sent_tags):
				# Words without a graph node are skipped.
				if w_i not in verse_nodes:
					continue
				node = verse_nodes[w_i]
				pos_labels[node, postag_map[tag]] = 1
				pos_node_cover[sent_id].append(node)

	return pos_labels, pos_node_cover


In [16]:
# blinker_test_dataset = torch.load("/mounts/work/ayyoob/models/gnn/dataset_blinker_full_community_word.pickle", map_location=torch.device('cpu'))
editf12 = "eng-x-bible-mixed"
editf22 = 'fra-x-bible-louissegond'

test_gold_eng_fra = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/eng_fra_pbc/eng-fra.gold"

# Gold eng-fra (Blinker) alignments; pros_blinker/surs_blinker are not referenced in the cells shown here.
pros_blinker, surs_blinker = autils.load_gold(test_gold_eng_fra)

# Stays empty; it is passed as `alignments` to POSTAGGNNDataset, which only reads it when create_data=True.
blinker_verse_alignments_inter = {}

# verse -> node id of its first token entry (dict order) in the first edition containing it;
# sorting by that id orders verses by their position in the node numbering.
verses_map = {}

for edit in blinker_test_dataset.nodes_map:
    for verse in blinker_test_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in blinker_test_dataset.nodes_map[edit][verse]:
                verses_map[verse] = blinker_test_dataset.nodes_map[edit][verse][tok]
                break

sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
blinker_verses = [item[0] for item in sorted_verses]


In [17]:
#importlib.reload(afeatures)
#grc_test_dataset = torch.load("/mounts/work/ayyoob/models/gnn/dataset_helfi_grc_test_community_word.pickle", map_location=torch.device('cpu'))
editf_fin = "fin-x-bible-helfi"
editf_grc = 'grc-x-bible-helfi'

test_gold_grc = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-grc-gold-alignments_test.txt"

pros_grc, surs_grc = autils.load_gold(test_gold_grc)

# Placeholders; grc_test_verse_alignments_inter is passed to POSTAGGNNDataset but only read when create_data=True.
grc_test_verse_alignments_inter = {}
grc_test_verse_alignments_gdfa = {}
gc.collect()

# Order the grc test verses by the node id of their first token entry (see the blinker cell for the same pattern).
verses_map = {}

for edit in grc_test_dataset.nodes_map:
    for verse in grc_test_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in grc_test_dataset.nodes_map[edit][verse]:
                verses_map[verse] = grc_test_dataset.nodes_map[edit][verse][tok]
                break

sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
grc_test_verses = [item[0] for item in sorted_verses]

gc.collect()

0

In [18]:
#heb_test_dataset = torch.load("/mounts/work/ayyoob/models/gnn/dataset_helfi_heb_test_community_word.pickle", map_location=torch.device('cpu'))

test_gold_heb = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-heb-gold-alignments_test.txt"

pros_heb, surs_heb = autils.load_gold(test_gold_heb)

# Order the heb test verses by the node id of their first token entry (dict order) in the first edition containing them.
verses_map = {}

for edit in heb_test_dataset.nodes_map:
    for verse in heb_test_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in heb_test_dataset.nodes_map[edit][verse]:
                verses_map[verse] = heb_test_dataset.nodes_map[edit][verse][tok]
                break

sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
heb_test_verses = [item[0] for item in sorted_verses]
gc.collect()

0

In [19]:
# Order the training verses by the node id of their first token entry (dict order) in the
# first edition containing them, i.e. by where each verse sits in the node numbering.
verses_map = {}

for edit in train_dataset.nodes_map:
    for verse in train_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in train_dataset.nodes_map[edit][verse]:
                verses_map[verse] = train_dataset.nodes_map[edit][verse][tok]
                break

sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
# NOTE: this replaces the gold-alignment verse list built in the setup cell.
all_verses = [item[0] for item in sorted_verses]

# Verses in which any edition has a token key above 150; DataEncoder skips them via masked_verses.
# NOTE(review): POSDecoderTransformer pads each edition to 150 rows. If token keys are 0-based,
# a verse whose largest key is exactly 150 has 151 tokens yet is kept — confirm whether this
# should be `tok >= 150`.
long_verses = set()

for edit in train_dataset.nodes_map.keys():
    for verse in train_dataset.nodes_map[edit]:
        to_print = False
        for tok in train_dataset.nodes_map[edit][verse]:
            if tok > 150:
                to_print = True
        if to_print == True:
            long_verses.add(verse)


train_verses = all_verses[:]

masked_verses = list(long_verses)
#masked_verses.extend(blinker_verses)

In [20]:
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import random

def get_language_based_nodes(nodes_map, verse, train_nodes, padding):
    """Group a verse's nodes by edition and locate the requested nodes inside that grouping.

    Args:
        nodes_map: edition -> verse -> {token key: global node id}.
        verse: verse id to extract.
        train_nodes: list of global node ids whose positions are wanted.
        padding: offset subtracted from every node id (the verse's smallest node id in this notebook).

    Returns:
        res: one list per edition containing `verse`, holding its node ids (sorted by token key)
            minus `padding`.
        transformer_indices: two lists aligned with `train_nodes`: [0][k] is the position of the
            edition holding train_nodes[k] among `res`, [1][k] its rank inside that edition;
            both stay -1 if the node is not found. Only the first occurrence of a duplicated
            node id in `train_nodes` is filled.
    """
    # node id -> first position in train_nodes (O(1) lookups instead of list.index per token).
    position_of = {}
    for pos, node in enumerate(train_nodes):
        position_of.setdefault(node, pos)

    res = []
    transformer_indices = [[-1] * len(train_nodes), [-1] * len(train_nodes)]

    lang_ind = 0
    for lang in nodes_map:
        if verse not in nodes_map[lang]:
            continue

        items = sorted(nodes_map[lang][verse].items(), key=lambda item: item[0])

        to_add = []
        for rank, (_, node) in enumerate(items):
            to_add.append(node - padding)
            pos = position_of.get(node)
            if pos is not None:
                transformer_indices[0][pos] = lang_ind
                transformer_indices[1][pos] = rank

        res.append(to_add)
        lang_ind += 1

    return res, transformer_indices

class POSTAGGNNDataset(Dataset):
    """Map-style dataset: one item = (verse graph, POS targets for one group of its nodes).

    Args:
        dataset: graph dataset with `nodes_map` (edition -> verse -> {token: node id}) and, when
            create_data=True, the full `x` / `edge_index` tensors.
        verses: verse ids to include.
        edit_files: editions used when (re)building per-verse files.
        alignments: verse -> alignments, only read when create_data=True.
        node_cover: verse id -> list of labelled node ids. It is not modified.
        pos_labels: (node_count, n_tags) one-hot POS labels indexed by global node id.
        data_dir: directory holding `verses/<verse>_info.torch.bin`.
        create_data: if True, (re)build and save the per-verse files first.
        group_size: maximum number of labelled nodes per item.
    """

    def __init__(self, dataset, verses, edit_files, alignments, node_cover, pos_labels, data_dir, create_data=False, group_size = 20):
        self.node_cover = node_cover
        self.pos_labels = pos_labels
        self.data_dir = data_dir
        self.items = self.calculate_size(verses, group_size, node_cover)
        self.dataset = dataset

        if create_data:
            self.calculate_verse_stats(verses, edit_files, alignments, dataset, data_dir)

    def calculate_size(self, verses, group_size, node_cover):
        """Split each verse's covered nodes into shuffled groups of at most `group_size`.

        A shuffled copy is used so the caller's lists (shared by several datasets in this
        notebook) are not reordered, and missing verses are skipped without inserting keys.
        Returns a list of (verse, [node ids]) items.
        """
        res = []
        for verse in verses:
            covered_nodes = list(node_cover.get(verse, ()))
            random.shuffle(covered_nodes)
            res.extend(
                (verse, covered_nodes[start:start + group_size])
                for start in range(0, len(covered_nodes), group_size)
            )
        return res

    def calculate_verse_stats(self, verses, edition_files, alignments, dataset, data_dir):
        """Cut each verse's subgraph out of the full graph and save it to `<data_dir>/verses/`.

        Edges are consumed as a running slice of dataset.edge_index, which assumes `verses` is in
        the order the edges were laid out — TODO confirm. Afterwards x/edge_index are dropped
        from `dataset`, which is saved as train_dataset_nox_noedge.torch.bin.
        """
        min_edge = 0
        for verse in tqdm(verses):
            min_nodes = 99999999999999
            max_nodes = 0
            verse_edges = 0  # two edge entries are counted per alignment link
            x_tmp = []
            features = []
            for i, editf1 in enumerate(edition_files):
                for editf2 in edition_files[i + 1:]:
                    aligns = autils.get_aligns(editf1, editf2, alignments[verse])
                    if aligns is None:
                        continue
                    for align in aligns:
                        try:
                            n1, _ = gutils.node_nom(verse, editf1, align[0], None, dataset.nodes_map, x_tmp, edition_files, features)
                            n2, _ = gutils.node_nom(verse, editf2, align[1], None, dataset.nodes_map, x_tmp, edition_files, features)
                            verse_edges += 2

                            max_nodes = max(n1, n2, max_nodes)
                            min_nodes = min(n1, n2, min_nodes)
                        except Exception:
                            print(editf1, editf2, verse)
                            raise

            self.verse_info = {}

            self.verse_info['padding'] = min_nodes

            self.verse_info['x'] = torch.clone(dataset.x[min_nodes:max_nodes + 1, :])

            self.verse_info['edge_index'] = torch.clone(dataset.edge_index[:, min_edge:min_edge + verse_edges] - min_nodes)

            # Sanity checks: local node ids should start at 0 and cover every row of x.
            if torch.min(self.verse_info['edge_index']) != 0:
                print(verse, min_nodes, max_nodes, min_edge, verse_edges)
                print(torch.min(self.verse_info['edge_index']))

            if self.verse_info['x'].shape[0] != torch.max(self.verse_info['edge_index']) + 1:
                print(verse, min_nodes, max_nodes, min_edge, verse_edges)
                print(torch.max(self.verse_info['edge_index']))

            min_edge += verse_edges

            torch.save(self.verse_info, f"{data_dir}/verses/{verse}_info.torch.bin")

        dataset.x = None
        dataset.edge_index = None
        torch.save(dataset, f"{data_dir}/train_dataset_nox_noedge.torch.bin")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return one verse graph plus the POS targets of one group of its nodes.

        Node ids in 'pos_index' / 'lang_based_nodes' are shifted by the verse's 'padding'
        (its smallest global node id) so they index the verse-local `x`.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        verse, nodes = self.items[idx]

        # NOTE: torch.load unpickles; data_dir must only contain trusted files.
        verse_data = torch.load(f'{self.data_dir}/verses/{verse}_info.torch.bin')
        self.verse_info = {verse: verse_data}  # kept on the instance for interactive debugging

        padding = verse_data['padding']
        language_based_nodes, transformer_indices = get_language_based_nodes(self.dataset.nodes_map, verse, nodes, padding)

        # Append feature column 9 (token id) once more as the last column, used to look up per-token info.
        word_number = torch.unsqueeze(verse_data['x'][:, 9], 1)
        verse_data['x'] = torch.cat((verse_data['x'], word_number), dim=1)

        return {'verse': verse, 'x': verse_data['x'], 'edge_index': verse_data['edge_index'],
                'pos_classes': self.pos_labels[nodes, :], 'pos_index': torch.LongTensor(nodes) - padding,
                'padding': padding, 'lang_based_nodes': language_based_nodes, 'transformer_indices': transformer_indices}


# POS labels / node coverage are cached next to each dataset (pos_data.torch.bin); the commented
# lines regenerate the caches with get_pos_tags.
# NOTE: torch.load unpickles — only load caches produced by this project.
# NOTE: `DataLoader` here is torch.utils.data.DataLoader (re-imported in the previous cell).
# # train_pos_labels, train_pos_node_cover = get_pos_tags(train_dataset, pos_lang_list)
# # torch.save({'pos_labels':train_pos_labels, 'pos_node_cover':train_pos_node_cover}, f'{data_dir_train}/pos_data.torch.bin')
pos_data = torch.load(f'{data_dir_train}/pos_data.torch.bin')
train_pos_labels, train_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']
# Training items: up to 100 labelled nodes per verse group.
gnn_dataset_train_pos = POSTAGGNNDataset(train_dataset, train_verses, current_editions, {},
                    train_pos_node_cover, train_pos_labels, data_dir_train, group_size = 100)

## blinker_pos_labels, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, pos_lang_list)
## torch.save({'pos_labels':blinker_pos_labels, 'pos_node_cover': blinker_pos_node_cover}, f'{data_dir_blinker}/pos_data.torch.bin')
pos_data = torch.load(f'{data_dir_blinker}/pos_data.torch.bin')
blinker_pos_labels, blinker_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']
gnn_dataset_blinker_pos = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
                            blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 10000)

##grc_pos_labels, grc_pos_node_cover = get_pos_tags(grc_test_dataset, pos_lang_list)
##torch.save({'pos_labels':grc_pos_labels, 'pos_node_cover': grc_pos_node_cover}, f'{data_dir_grc}/pos_data.torch.bin')
pos_data = torch.load(f'{data_dir_grc}/pos_data.torch.bin')
grc_pos_labels, grc_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']
gnn_dataset_grc_pos = POSTAGGNNDataset(grc_test_dataset, grc_test_verses, current_editions, grc_test_verse_alignments_inter, grc_pos_node_cover, grc_pos_labels, data_dir_grc, group_size = 100)

##heb_pos_labels, heb_pos_node_cover = get_pos_tags(heb_test_dataset, pos_lang_list)
##torch.save({'pos_labels':heb_pos_labels, 'pos_node_cover': heb_pos_node_cover}, f'{data_dir_heb}/pos_data.torch.bin')
pos_data = torch.load(f'{data_dir_heb}/pos_data.torch.bin')
heb_pos_labels, heb_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']
gnn_dataset_heb_pos = POSTAGGNNDataset(heb_test_dataset, heb_test_verses, current_editions, {}, heb_pos_node_cover, heb_pos_labels, data_dir_heb, group_size = 100)


# "bigbatch" variants: group_size=10000, i.e. presumably one item per verse holding all its
# labelled nodes (TODO confirm no verse exceeds 10000). They share the node_cover dicts above;
# POSTAGGNNDataset.calculate_size shuffles the node lists in place (not the items already built).
gnn_dataset_train_pos_bigbatch = POSTAGGNNDataset(train_dataset, train_verses, current_editions, {},
                   train_pos_node_cover, train_pos_labels, data_dir_train, group_size = 10000)
train_data_loader_bigbatch = DataLoader(gnn_dataset_train_pos_bigbatch, batch_size=1, shuffle=False)

gnn_dataset_grc_pos_bigbatch = POSTAGGNNDataset(grc_test_dataset, grc_test_verses, current_editions, {}, grc_pos_node_cover, grc_pos_labels, data_dir_grc, group_size = 10000)
grc_data_loader_bigbatch = DataLoader(gnn_dataset_grc_pos_bigbatch, batch_size=1, shuffle=False)

gnn_dataset_heb_pos_bigbatch = POSTAGGNNDataset(heb_test_dataset, heb_test_verses, current_editions, {}, heb_pos_node_cover, heb_pos_labels, data_dir_heb, group_size = 10000)
heb_data_loader_bigbatch = DataLoader(gnn_dataset_heb_pos_bigbatch, batch_size=1, shuffle=False)
blinker_data_loader_bigbatch = DataLoader(gnn_dataset_blinker_pos, batch_size=1, shuffle=False)

In [21]:
# POS-tagged editions without English (used by the commented-out majority-voting experiments).
no_eng_langs = pos_lang_list[:]
no_eng_langs.remove('eng-x-bible-mixed')

# # train_pos_labels, train_pos_node_cover = get_pos_tags(train_dataset, no_eng_langs)
# # gnn_dataset_train_pos = POSTAGGNNDataset(train_dataset, train_verses, current_editions, verse_alignments_inter,
# #                        train_pos_node_cover, train_pos_labels, data_dir_train, group_size = 10)


# _, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, ['eng-x-bible-mixed'])
# blinker_pos_labels, _ = get_pos_tags(blinker_test_dataset, pos_lang_list)
# gnn_dataset_blinker_pos = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
#                             blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 100)



In [41]:
def save_model(model, name, checkpoint_dir='/mounts/work/ayyoob/models/gnn/checkpoint/postagging'):
    """Save the whole model object to `<checkpoint_dir>/pos_tagging_<name>_<YYYYmmdd-HHMMSS>.pickle`.

    The full module is pickled (not just a state_dict), so loading it later requires the same
    class definitions and attribute names.

    Args:
        model: the module to save.
        name: run label embedded in the file name.
        checkpoint_dir: target directory (defaults to the previously hard-coded location).

    Returns:
        The path written.
    """
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    path = f'{checkpoint_dir}/pos_tagging_{name}_{timestamp}.pickle'
    torch.save(model, path)
    return path

In [23]:
# #test(1, test_data_loader) 

# #finetune_pos_labels, finetune_pos_node_cover = get_pos_tags(train_dataset, ['eng-x-bible-mixed'])
# #gnn_dataset_finetune_pos = POSTAGGNNDataset(train_dataset, train_verses, current_editions, verse_alignments_inter,
# #                       finetune_pos_node_cover, finetune_pos_labels, data_dir_train, group_size = 100)
# #finetune_data_loader = DataLoader(gnn_dataset_finetune_pos, batch_size=1, shuffle=False)

# # train(1, finetune_data_loader, max_batches=1000)
# # test(1, test_data_loader) 

# blinker_pos_labels, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, ['eng-x-bible-mixed'])
# gnn_dataset_blinker_pos = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
#                              blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 10000)

# blinker_pos_labels, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, no_eng_langs)
# gnn_dataset_blinker_pos_majvoting_test = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
#                              blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 10000)

# test_data_loader = DataLoader(gnn_dataset_blinker_pos, batch_size=1, shuffle=False)
# test_data_loader_majvoting = DataLoader(gnn_dataset_blinker_pos_majvoting_test, batch_size=1, shuffle=False)
# majority_voting_test(test_data_loader, test_data_loader_majvoting)

In [24]:
# NOTE(review): hard-coded GPU index 7; set_device raises on machines with fewer GPUs. Because `dev`
# is the index-less 'cuda' device, this choice decides where `.to(dev)` places tensors.
torch.cuda.set_device(7)
features = train_dataset.features

# features.append(afeatures.PassFeature(name='posTAG', dim=len(postag_map)))
# features.pop()

In [25]:
from tqdm import tqdm

def create_model(tag_frequencies=False, use_transformers=False, train_word_embedding=False, mask_language=True, n_epochs=2):
    """Build a fresh GAE (GAT encoder + POS decoder), then train, checkpoint and evaluate it each epoch.

    `model`, `criterion`, `optimizer` and `test_data_loader` are published as globals because
    `train` / `test` read them from module scope.

    Args:
        tag_frequencies: also feed `normalized_tag_frequencies` to the feature encoder.
        use_transformers: use POSDecoderTransformer instead of the MLP POSDecoder.
        train_word_embedding: keep the MAPPING (word-embedding) layers trainable. When False they
            are replaced by frozen copies via freeze_encoders_embedding.
        mask_language: forwarded to train/test (zeroes feature column 0 in DataEncoder).
        n_epochs: passes over the train, grc and heb loaders (default 2, as before).
    """
    global model, criterion, optimizer, test_data_loader

    # NOTE: mutates the feature list shared with train_dataset.
    features = train_dataset.features
    features[9].freeze = not train_word_embedding

    train_data_loader = DataLoader(gnn_dataset_train_pos, batch_size=1, shuffle=True)
    grc_data_loader = DataLoader(gnn_dataset_grc_pos, batch_size=1, shuffle=True)
    heb_data_loader = DataLoader(gnn_dataset_heb_pos, batch_size=1, shuffle=True)
    test_data_loader = DataLoader(gnn_dataset_blinker_pos, batch_size=1, shuffle=True)

    clean_memory()
    drop_out = 0
    n_head = 1
    channels = 512
    in_dim = sum(t.out_dim for t in features)
    # Encoder ends in fin_lin = Linear(channels, channels), so node embeddings have `channels`
    # dims whatever n_head is (`n_head * channels` only matched for n_head == 1).
    decoder_in_dim = channels

    decoder_cls = POSDecoderTransformer if use_transformers else POSDecoder
    decoder = decoder_cls(decoder_in_dim, decoder_in_dim * 2, len(postag_map), drop_out=drop_out)

    encoder = Encoder(in_dim, channels, features, n_head, has_tagfreq_feature=tag_frequencies)
    model = pyg_nn.GAE(encoder, decoder).to(dev)
    if not train_word_embedding:
        freeze_encoders_embedding(model.encoder.feature_encoder)
    # NOTE(review): when train_word_embedding is True, whether the MAPPING layers are actually
    # trainable depends on afeatures.FeatureEncoding honouring `features[9].freeze` — confirm.

    # freeze_encoders_embedding installs new layers, so move the model again.
    model.to(dev)
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    torch.set_printoptions(edgeitems=5)
    print("model params - decoder params - conv1", sum(p.numel() for p in model.parameters()), sum(p.numel() for p in decoder.parameters()))

    run_name = f'posfeat{tag_frequencies}_transformer{use_transformers}_trainWE{train_word_embedding}_maskLang{mask_language}'
    for epoch in range(1, n_epochs + 1):
        print(f"\n----------------epoch {epoch} ---------------")

        for loader in (train_data_loader, grc_data_loader, heb_data_loader):
            train(epoch, loader, mask_language)
        save_model(model, run_name)
        test(epoch, test_data_loader, mask_language)
        clean_memory()


In [42]:
#create_model()

In [27]:
#initial_model = model.to('cpu')

In [28]:
#i = sag
#batch = khar
#verse = gav
#print(i, verse)

#keys = list(gnn_dataset.verse_info.keys())

#gnn_dataset.verse_info[verse]
#save_model(model, 'freeze-embedding_noLang')

In [29]:
# Restore a trained POS-tagging model. The checkpoint holds a whole model object
# (it is used directly as a module below), not just a state_dict.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load trusted checkpoints.
POS_MODEL_CHECKPOINT = ('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/'
                        'pos_tagging_freeze-embedding_noLang_20211022-125404.pickle')
model = torch.load(POS_MODEL_CHECKPOINT)
torch.cuda.set_device(0)
model.to(dev)
#epoch = 0
#test_data_loader = DataLoader(gnn_dataset_blinker_pos, batch_size=1, shuffle=False)
#test(epoch, test_data_loader, mask_language=True)
#yoruba_postags = generate_target_lang_tags('yor-x-bible-2010', f"posfeatFalse_transformerFalse_trainWEFalse", True)


GAE(
  (encoder): Encoder(
    (conv1): GATConv(236, 1024, heads=1)
    (conv2): GATConv(1024, 512, heads=1)
    (fin_lin): Linear(in_features=512, out_features=512, bias=True)
    (feature_encoder): FeatureEncoding(
      (layers): ModuleList(
        (0): OneHotEncoding(
          (linear): Linear(in_features=83, out_features=20, bias=True)
        )
        (1): OneHotEncoding(
          (linear): Linear(in_features=150, out_features=32, bias=True)
        )
        (2): FloatEncoding(
          (linear): Linear(in_features=1, out_features=4, bias=True)
        )
        (3): FloatEncoding(
          (linear): Linear(in_features=1, out_features=4, bias=True)
        )
        (4): FloatEncoding(
          (linear): Linear(in_features=1, out_features=4, bias=True)
        )
        (5): FloatEncoding(
          (linear): Linear(in_features=1, out_features=4, bias=True)
        )
        (6): FloatEncoding(
          (linear): Linear(in_features=1, out_features=4, bias=True)
        )

In [33]:
from copy import deepcopy
# These are functions to calculate the frequency vectors and also update training data with predictions for target language tags

# This function updates the existing training structures with new (confident) predictions
def update_trainset_with_predictions(node_cover, pos_labels, padding, index, max_values, pos_tags, word_added_to_train_freq, word_pos, confidence_threshold=0.9, max_additions_per_word=10):
    """Add confident predictions to the training structures, in place.

    A prediction is used when its probability exceeds ``confidence_threshold`` and
    the per-word counter stays below ``max_additions_per_word`` after this batch.

    Args:
        node_cover: list of global node ids used for training; extended in place.
        pos_labels: label matrix indexed ``[global_node_id, tag]``; accepted
            entries are set to 1 in place.
        padding: offset added to the batch-local ``index`` to get global node ids.
        index: batch-local indices of the predicted nodes.
        max_values: probability of the predicted tag for each node.
        pos_tags: predicted tag id for each node.
        word_added_to_train_freq: per-word counter of confident predictions;
            updated in place and shared by the caller across batches.
        word_pos: word id of each node (cast to long for indexing).
        confidence_threshold: minimum probability for a prediction to be used.
        max_additions_per_word: counter value at which a word stops being added,
            so high-repetition words do not flood the training set.
    """
    confident = max_values > confidence_threshold
    index = index[confident]
    pos_tags = pos_tags[confident]
    word_ids = word_pos[confident].long().to(word_added_to_train_freq.device)

    # `counter[word_ids] += 1` counts a repeated word id only ONCE per call
    # (index_put_ without accumulate), so accumulate explicitly.
    word_added_to_train_freq.index_put_(
        (word_ids,),
        torch.ones_like(word_ids, dtype=word_added_to_train_freq.dtype),
        accumulate=True,
    )
    below_cap = (word_added_to_train_freq[word_ids] < max_additions_per_word).to(index.device)
    index = index[below_cap]
    pos_tags = pos_tags[below_cap]

    index_global = index + padding
    node_cover.extend(index_global.tolist())
    pos_labels[index_global, pos_tags] = 1

# This function iterates over a dataset to create the tag-frequency vectors and update the training structures
def get_words_tag_frequence(model, word_count, class_count, data_loader, tag_frequencies=None, from_gold_data=False, node_cover=None, pos_labels=None, mask_language=True):
    """Count, for every word id, how often each POS tag is assigned over ``data_loader``.

    The word id of a node is read from column 9 of its features (``batch['x']``).

    Args:
        model: model whose decoder produces tag scores (ignored for the tags
            themselves when ``from_gold_data``; still used by ``DataEncoder``).
        word_count, class_count: shape of a freshly created count matrix.
        data_loader: loader wrapped by ``DataEncoder``.
        tag_frequencies: existing ``(word_count, class_count)`` count matrix to
            keep accumulating into (mutated in place); created when None.
        from_gold_data: count the gold tags (``batch['pos_classes']``) instead of
            model predictions.
        node_cover, pos_labels: when given (and not gold), confident predictions
            are also added to these training structures via
            ``update_trainset_with_predictions``; ``node_cover`` is indexed by verse.
        mask_language: forwarded to ``DataEncoder``.

    Returns:
        ``(row_normalized_counts, counts)``. ``counts`` is ``tag_frequencies``
        itself when one was passed.
    """
    res = tag_frequencies
    if res is None:
        # Tiny pseudo-count so rows of unseen words never sum to zero.
        res = torch.full((word_count, class_count), 0.0000001, dtype=torch.float)

    data_encoder = DataEncoder(data_loader, model, mask_language)
    word_added_to_train_freq = torch.zeros(word_count)

    with torch.no_grad():
        for z, verse, _, batch in data_encoder:
            index = batch['pos_index'][0].to(dev)

            if from_gold_data:
                tags_onehot = batch['pos_classes'][0]
            else:
                tags_onehot = model.decoder(z, index, batch)

            max_values, pos_tags = torch.max(torch.softmax(tags_onehot, dim=1), 1)
            word_pos = batch['x'][0][index, 9]

            if not from_gold_data:
                if node_cover is not None:
                    update_trainset_with_predictions(node_cover[verse], pos_labels, batch['padding'][0], index, max_values, pos_tags, word_added_to_train_freq, word_pos)
                # Only reasonably confident predictions contribute to the counts.
                accepted_values = max_values > 0.5
                word_pos = word_pos[accepted_values]
                pos_tags = pos_tags[accepted_values]

            # `res[rows, cols] += 1` would count a repeated (word, tag) pair only
            # once per batch; accumulate every occurrence instead.
            rows = word_pos.long().to(res.device)
            cols = pos_tags.long().to(res.device)
            res.index_put_((rows, cols), torch.ones(rows.numel(), dtype=res.dtype, device=res.device), accumulate=True)

    res_normalized = res / res.sum(dim=1, keepdim=True)
    return res_normalized, res

def get_data_loadrs_for_target_editions(target_editions, dataset, pos_node_cover, verses, data_dir):
    """Build an unshuffled, batch_size=1 DataLoader over the nodes of ``target_editions``.

    ``get_language_nodes`` selects the target editions' nodes and labels (given the
    keys of ``pos_node_cover``). These are wrapped in a ``POSTAGGNNDataset`` with
    ``group_size=500``.
    """
    labels, node_cover = get_language_nodes(dataset, target_editions, pos_node_cover.keys())
    target_dataset = POSTAGGNNDataset(dataset, verses, None, {}, node_cover, labels, data_dir, group_size=500)
    return DataLoader(target_dataset, batch_size=1, shuffle=False)

# Calls the above functions over any of the datasets (train, grc, heb, blinker) to get pos_tag vectors for each word. (once for source languages and once for target languages)
def get_tag_frequencies_node_tags(editions):
    """Build word x tag count matrices and pseudo-labelled copies of the training structures.

    Source counts come from the gold tags of the train/grc/heb/blinker loaders.
    Target counts come from model predictions for ``editions`` on the same four
    datasets. During prediction, confident target tags are written into deep copies
    of each dataset's node cover / POS labels; the originals are left untouched.

    Relies on notebook globals: ``model``, ``postag_map``, the ``*_bigbatch``
    loaders, the datasets / verses / data dirs, and the ``*_pos_node_cover`` /
    ``*_pos_labels`` structures.

    Returns:
        A 10-tuple with these elements, in order:
        - source + target counts (the source tensor, incremented in place);
        - target counts;
        - the copied (node cover, POS labels) for train, grc, heb and blinker.
    """
    word_vocab_size = 2354770  # number of word ids (hard-coded upstream)

    source_structures = (
        (train_pos_node_cover, train_pos_labels),
        (grc_pos_node_cover, grc_pos_labels),
        (heb_pos_node_cover, heb_pos_labels),
        (blinker_pos_node_cover, blinker_pos_labels),
    )
    (train_cover, train_labels), (grc_cover, grc_labels), (heb_cover, heb_labels), (blinker_cover, blinker_labels) = [
        (deepcopy(cover), deepcopy(labels)) for cover, labels in source_structures
    ]

    # Gold (source-language) counts, accumulated into one matrix.
    tag_frequencies = None
    for gold_loader in (train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch):
        _, tag_frequencies = get_words_tag_frequence(model, word_vocab_size, len(postag_map), gold_loader, from_gold_data=True, tag_frequencies=tag_frequencies)

    # Predicted (target-language) counts; also fills the copied training structures.
    target_jobs = (
        (train_dataset, train_pos_node_cover, train_verses, data_dir_train, train_cover, train_labels),
        (grc_test_dataset, grc_pos_node_cover, grc_test_verses, data_dir_grc, grc_cover, grc_labels),
        (heb_test_dataset, heb_pos_node_cover, heb_test_verses, data_dir_heb, heb_cover, heb_labels),
        (blinker_test_dataset, blinker_pos_node_cover, blinker_verses, data_dir_blinker, blinker_cover, blinker_labels),
    )
    tag_frequencies_target = None
    for dataset, pos_node_cover, verses, data_dir, cover_copy, labels_copy in target_jobs:
        target_loader = get_data_loadrs_for_target_editions(editions, dataset, pos_node_cover, verses, data_dir)
        _, tag_frequencies_target = get_words_tag_frequence(model, word_vocab_size, len(postag_map), target_loader, node_cover=cover_copy, pos_labels=labels_copy, tag_frequencies=tag_frequencies_target)

    tag_frequencies += tag_frequencies_target
    return (tag_frequencies, tag_frequencies_target,
            train_cover, train_labels,
            grc_cover, grc_labels,
            heb_cover, heb_labels,
            blinker_cover, blinker_labels)

# Compute source/target tag counts and pseudo-labelled training copies for Yoruba.
res_ = get_tag_frequencies_node_tags(['yor-x-bible-2010'])
(tag_frequencies, tag_frequencies_target,
 train_pos_node_cover_copy, train_pos_labels_copy,
 grc_pos_node_cover_copy, grc_pos_labels_copy,
 heb_pos_node_cover_copy, heb_pos_labels_copy,
 blinker_pos_node_cover_copy, blinker_pos_labels_copy) = res_

In [38]:
# Build the tag-prior feature: a per-word tag distribution, with some rows reset to a
# (near-)uniform distribution so the model cannot just echo the most frequent tag.
TINY_COUNT = 0.0000001   # pseudo-count that makes a row uninformative (same as the init value)
RARE_TARGET_MAX = 3      # words seen in the target fewer times than this (but at least once) are reset
KEEP_FRACTION = 0.7      # fraction of rows protected from the random reset below

word_frequencies_target = torch.sum(tag_frequencies_target, dim=1)
tag_frequencies_copy = tag_frequencies.detach().clone()

# Target counts are integer increments on top of tiny pseudo-counts, so 0.1 < sum < 3
# means the word was observed once or twice in the target: too noisy, reset it.
rare_in_target = torch.logical_and(word_frequencies_target > 0.1, word_frequencies_target < RARE_TARGET_MAX)
tag_frequencies_copy[rare_in_target, :] = TINY_COUNT

# We have to give uniform noise to some training examples to prevent the model from
# always returning one of the most frequent tags: a random (1 - KEEP_FRACTION) share of
# the words never seen in the target is reset as well.
num_words = tag_frequencies.size(0)
uniform_noise = torch.ones(num_words, dtype=torch.bool)
shuffle_tensor = torch.randperm(num_words)[:int(num_words * KEEP_FRACTION)]
uniform_noise[shuffle_tensor] = False
tag_frequencies_copy[torch.logical_and(uniform_noise, word_frequencies_target < 0.1), :] = TINY_COUNT

# Row-normalise counts into per-word tag distributions.
sm = torch.sum(tag_frequencies_copy, dim=1)
normalized_tag_frequencies = tag_frequencies_copy / sm.unsqueeze(1)

In [45]:
#importlib.reload(afeatures)
#initial_model = model
#features.append(afeatures.MappingFeature(len(postag_map), 'tag_priors', freeze=True))
# Export Yoruba POS tags on GPU 1 with the currently loaded model.
TARGET_EDITION = 'yor-x-bible-2010'
RUN_NAME = "posfeatTrue_transformerFalse_trainWEFalse"
torch.cuda.set_device(1)
#create_model(tag_frequencies=True)
yoruba_postags = generate_target_lang_tags(TARGET_EDITION, RUN_NAME, mask_language=True)

100%|██████████| 24078/24078 [09:09<00:00, 43.85it/s]
100%|██████████| 2225/2225 [00:43<00:00, 51.45it/s]
100%|██████████| 783/783 [00:27<00:00, 28.74it/s]
100%|██████████| 250/250 [00:06<00:00, 39.13it/s]


In [None]:

#normalized_gold_frequencies, gold_frequencies_all = get_words_tag_frequence(model, 2354770, len(postag_map), english_data_loader, from_gold_data=True)

#gold_frequencies_all = gold_frequencies
# Per-word-type tagging accuracy: compare the argmax tag of the gold counts with the
# argmax tag of the predicted counts, overall and split by word-type frequency.
# NOTE(review): `gold_frequencies_all` and `tag_frequencies_english` are produced by
# cells that are commented out / absent here — define them before running this cell.
word_frequencies = torch.sum(gold_frequencies_all, dim=1)

# Keep only word types that occur in the gold data (pure pseudo-count rows sum to << 0.1).
subjectword_indices = word_frequencies > 0.1
print(word_frequencies.shape)
gold_frequencies = gold_frequencies_all[subjectword_indices, :]
predicted_frequencies = tag_frequencies_english[subjectword_indices, :]
wordtype_frequencies = word_frequencies[subjectword_indices]
print(gold_frequencies.shape)

_, gold_tags = torch.max(gold_frequencies, dim=1)
_, predicted_tags = torch.max(predicted_frequencies, dim=1)

# Most frequent word types first, so the split below separates frequent from rare words.
sorted_wordtype_frequencies, sort_pattern = torch.sort(wordtype_frequencies, descending=True)

sorted_gold_tags = gold_tags[sort_pattern]
sorted_predicted_tags = predicted_tags[sort_pattern]
# NOTE: despite its name, this is the size of the first HALF of the sorted word types.
quarter_size = int(sorted_gold_tags.size(0)/2.0)

print('quarter size', quarter_size)
print("general accuracy", torch.sum(gold_tags == predicted_tags)/predicted_tags.size(0))
print('first quarter accuracy', torch.sum(sorted_gold_tags[:quarter_size] == sorted_predicted_tags[:quarter_size])/quarter_size)
print('last part accuracy', torch.sum(sorted_gold_tags[quarter_size:] == sorted_predicted_tags[quarter_size:])/sorted_predicted_tags[quarter_size:].size(0))

print('total token count', torch.sum(wordtype_frequencies))
# Tokens covered by the `quarter_size` most frequent word types; this must use the
# sorted, filtered frequencies, not the raw unsorted `word_frequencies`.
print('first quarter words token count', torch.sum(sorted_wordtype_frequencies[:quarter_size]))

# Frequency of the word type at selected ranks (skipped if the vocabulary is smaller).
for rank, label in ((0, '1st'), (10, '10st'), (100, '100st'), (736, '736th'), (1000, '1000st'), (10000, '10000st')):
    if rank < sorted_wordtype_frequencies.size(0):
        print(f'{label} frequency', sorted_wordtype_frequencies[rank])


In [None]:
# Remove the last node feature (presumably the 'tag_priors' MappingFeature appended in an
# earlier cell) so it can be re-added.
# NOTE(review): not idempotent — every re-run drops another feature; check
# `features[-1].name` before running.
features.pop()
#features.append(afeatures.MappingFeature(len(postag_map), 'tag_priors'))


In [None]:
#normalized_tag_frequencies = torch.softmax(tag_frequencies_copy, dim=1)

# Row-normalise the (modified) tag counts into per-word tag distributions.
sm = tag_frequencies_copy.sum(dim=1)
normalized_tag_frequencies = tag_frequencies_copy / sm.unsqueeze(1)

In [44]:
def get_target_lang_postags(dataset, data_loader, edit, mask_language):
    """Predict a POS tag id for every token of edition ``edit`` in ``dataset``.

    Switches the global ``model`` to eval mode (it is not switched back) and runs
    its decoder over ``data_loader`` without gradients.

    Args:
        dataset: dataset whose ``nodes_map[edit][verse]`` maps token keys to node ids.
        data_loader: loader wrapped by ``DataEncoder``.
        edit: edition identifier to tag.
        mask_language: forwarded to ``DataEncoder``.

    Returns:
        dict ``verse -> {token key: predicted tag id (int)}`` for every verse of
        ``edit`` yielded by the loader.
    """
    model.eval()
    res = {}
    data_encoder = DataEncoder(data_loader, model, mask_language)

    with torch.no_grad():
        for z, verse, _, batch in data_encoder:
            if verse not in dataset.nodes_map[edit]:
                continue

            verse_nodes = dataset.nodes_map[edit][verse]
            toks = list(verse_nodes.keys())
            # nodes_map holds global node ids; subtract the batch padding to get
            # batch-local indices for the decoder.
            index = torch.LongTensor([verse_nodes[tok] for tok in toks]).to(dev) - batch['padding'][0]

            preds = model.decoder(z, index, batch)
            _, predicted = torch.max(preds, 1)

            # One device->host transfer instead of a .item() sync per token.
            res[verse] = dict(zip(toks, predicted.tolist()))

    return res

def generate_target_lang_tags(target_lang, params, mask_language, output_dir='/mounts/work/ayyoob/results/gnn_align/yoruba'):
    """Predict POS tags for ``target_lang`` on all four datasets and save them.

    Datasets are processed in the order train, heb, grc, blinker. For a verse present
    in several of them, the later result overwrites the earlier one (``dict.update``).

    Args:
        target_lang: edition identifier to tag.
        params: run label embedded in the output file name.
        mask_language: forwarded to ``get_target_lang_postags``.
        output_dir: directory for the pickle. It defaults to the original Yoruba
            directory. Pass another directory for other languages: the file name
            does not include ``target_lang``, so runs would otherwise overwrite
            each other.

    Returns:
        dict ``verse -> {token key: tag id}``; also written with ``torch.save`` to
        ``{output_dir}/pos_tags_{params}_maskLang{mask_language}.pickle``.
    """
    target_pos_tags = {}
    for dataset, data_loader in (
        (train_dataset, train_data_loader_bigbatch),
        (heb_test_dataset, heb_data_loader_bigbatch),
        (grc_test_dataset, grc_data_loader_bigbatch),
        (blinker_test_dataset, blinker_data_loader_bigbatch),
    ):
        target_pos_tags.update(get_target_lang_postags(dataset, data_loader, target_lang, mask_language))

    torch.save(target_pos_tags, f'{output_dir}/pos_tags_{params}_maskLang{mask_language}.pickle')
    return target_pos_tags


In [None]:
# Drop the notebook's references to the model and decoder, then release cached GPU memory.
#1/0  (uncomment to stop a "Run All" before this point)

model = decoder = None

gc.collect()
torch.cuda.empty_cache()

In [None]:

#features = blinker_test_dataset.features[:]
#features_edge = train_dataset.features_edge[:]
from pprint import pprint
#print('indim',in_dim)
#features[-1].out_dim = 50
# Show the attribute dict (type, out_dim, name, ...) of every node feature.
# (Earlier tweak: set out_dim = 4 on every feature whose type == 3.)
for feature_config in map(vars, features):
    print(feature_config)

#sum(p.out_dim for p in features)
#train_dataset.features.pop()
#train_dataset.features[0] = afeatures.OneHotFeature(20, 83, 'editf')
#train_dataset.features[1] = afeatures.OneHotFeature(32, 150, 'position')
#train_dataset.features[2] = afeatures.FloatFeature(4, 'degree_centrality')
#train_dataset.features[3] = afeatures.FloatFeature(4, 'closeness_centrality')
#train_dataset.features[4] = afeatures.FloatFeature(4, 'betweenness_centrality')
#train_dataset.features[5] = afeatures.FloatFeature(4, 'load_centrality')
#train_dataset.features[6] = afeatures.FloatFeature(4, 'harmonic_centrality')
#train_dataset.features[7] = afeatures.OneHotFeature(32, 250, 'greedy_modularity_community')
##train_dataset.features.append(afeatures.MappingFeature(100, 'word'))
#torch.save(train_dataset, "/mounts/work/ayyoob/models/gnn/dataset_helfi_train_community_word.pickle")
#torch.save(train_dataset.features[-3], "./features.tmp")

{'type': 1, 'out_dim': 20, 'global_normalize': False, 'name': 'editf', 'Active': True, 'n_classes': 83}
{'type': 1, 'out_dim': 32, 'global_normalize': False, 'name': 'position', 'Active': True, 'n_classes': 150}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'degree_centrality', 'Active': True}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'closeness_centrality', 'Active': True}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'betweenness_centrality', 'Active': True}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'load_centrality', 'Active': True}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'harmonic_centrality', 'Active': True}
{'type': 1, 'out_dim': 32, 'global_normalize': False, 'name': 'greedy_modularity_community', 'Active': True, 'n_classes': 250}
{'type': 1, 'out_dim': 32, 'global_normalize': False, 'name': 'label_propagation_community', 'Active': True, 'n_classes': 250}
{'type': 6, 'out_dim': 100, 'global_

In [None]:
# An edition is "bad" when more than one of its verses has fewer than two nodes (tokens).
nodes_map = train_dataset.nodes_map
bad_edition_files = [
    edition
    for edition, verses_of_edition in nodes_map.items()
    if sum(len(verse_nodes) < 2 for verse_nodes in verses_of_edition.values()) > 1
]
print(bad_edition_files)

In [None]:
# Remove from the training graph every edge pair touching a node of a bad edition
# (the variable name suggests these were the Japanese editions).
all_japanese_nodes = set()
nodes_map = train_dataset.nodes_map

for bad_editionf in bad_edition_files:
    for verse in nodes_map[bad_editionf]:
        all_japanese_nodes.update(nodes_map[bad_editionf][verse].values())

print(" all japansese nodes: ", len(all_japanese_nodes))
edge_index = train_dataset.edge_index.to('cpu')

# Columns are handled in consecutive pairs (i, i+1), presumably the two directions of one
# undirected edge. A pair is dropped when the source node of either column is bad.
# This is vectorised with a boolean lookup table instead of a Python loop with
# per-edge .item() calls.
num_edges = edge_index.shape[1]
assert num_edges % 2 == 0, 'edge_index is expected to hold edges in (i, i+1) pairs'
lookup_size = max(int(edge_index.max()) + 1 if num_edges else 0,
                  max(all_japanese_nodes, default=-1) + 1)
is_bad_node = torch.zeros(lookup_size, dtype=torch.bool)
if all_japanese_nodes:
    is_bad_node[torch.tensor(sorted(all_japanese_nodes), dtype=torch.long)] = True

source_is_bad = is_bad_node[edge_index[0]]
keep_pair = ~(source_is_bad[0::2] | source_is_bad[1::2])
keep_column = keep_pair.repeat_interleave(2)
remaining_edges_index = keep_column.nonzero(as_tuple=True)[0].tolist()

print('original total edges count', edge_index.shape)
print('remaining edge count', len(remaining_edges_index))
train_dataset.edge_index = edge_index[:, keep_column]
train_dataset.edge_index.shape
