In [1]:
import torch, sys
sys.path.insert(0, '../')
from my_utils import gpu_utils
import importlib, gc
from my_utils.alignment_features import *
import my_utils.alignment_features as afeatures
importlib.reload(afeatures)
import gnn_utils.graph_utils as gutils
import postag_utils as posutil
from my_utils.pytorch_utils import EarlyStopping



In [3]:
# Primary device: the first GPU when CUDA is available, otherwise the CPU.
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Secondary device (POSDecoderTransformer runs on it). 'cuda:1' only exists when at
# least two GPUs are visible; on single-GPU or CPU-only hosts fall back to `dev`
# instead of naming a device that does not exist.
dev2 = torch.device('cuda:1') if torch.cuda.is_available() and torch.cuda.device_count() > 1 else dev


In [4]:

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

import time
from datetime import datetime

import networkx as nx
import numpy as np
import torch
import torch.optim as optim

from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader

import torch_geometric.transforms as T

from tensorboardX import SummaryWriter
from sklearn.manifold import TSNE
# import matplotlib.pyplot as plt




In [5]:
from my_utils import align_utils as autils, utils
import argparse
from multiprocessing import Pool
import random

# Project-wide setup driven by the config file. The original note here said
# "set random seed" -- presumably utils.setup does that; TODO confirm.
config_file = "/mounts/Users/student/ayyoob/Dokumente/code/pbc-ui-demo/config_pbc.ini"
utils.setup(config_file)

# Plain attribute bag for run parameters.
params = argparse.Namespace()


# Collect the verses that have gold alignments in either training split
# (fin-grc first, then fin-heb). params.gold_file is overwritten by the second path.
params.gold_file = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-grc-gold-alignments_train.txt"
pros, surs = autils.load_gold(params.gold_file)
all_verses = list(pros.keys())
params.gold_file = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-heb-gold-alignments_train.txt"
pros, surs = autils.load_gold(params.gold_file)
all_verses.extend(list(pros.keys()))
# De-duplicate. NOTE(review): set() iteration order is arbitrary, so the order of
# all_verses is not guaranteed to be the same between kernel runs.
all_verses = list(set(all_verses))
print(len(all_verses))

# Editions to work with: one entry of `editions` per language in `langs`.
params.editions_file =  "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi_lang_list.txt"
editions, langs = autils.load_simalign_editions(params.editions_file)
current_editions = [editions[lang] for lang in langs]

def get_pruned_verse_alignments(args):
    """Fetch and prune both alignment variants for one verse.

    Args:
        args: a ``(verse, editions)`` pair, packed into one argument
            (presumably so the function can be used with ``Pool.map`` -- TODO confirm).

    Returns:
        ``(inter, gdfa)``: the result of ``autils.get_verse_alignments(verse)``
        and of the same call with ``gdfa=True``, each after
        ``autils.prune_non_necessary_alignments`` was applied to it.
    """
    verse, editions_to_keep = args

    inter = autils.get_verse_alignments(verse)
    gdfa = autils.get_verse_alignments(verse, gdfa=True)

    # Pruning is called for its side effect on each alignment structure.
    for aligns in (inter, gdfa):
        autils.prune_non_necessary_alignments(aligns, editions_to_keep)

    gc.collect()
    return inter, gdfa
    

# One (verse, editions) work item per verse, in the shape that
# get_pruned_verse_alignments unpacks. Each item gets its own copy of the editions list.
args = []
for i,verse in enumerate(all_verses):
    args.append((verse, current_editions[:]))


24159


In [6]:
#importlib.reload(afeatures)

class Encoder(torch.nn.Module):
    """Feature encoder followed by two GAT layers.

    Args:
        in_channels: input width of the first GAT layer.
        out_channels: width of the final node representation.
        features: feature description handed to ``afeatures.FeatureEncoding``.
        n_head: number of attention heads of the first GAT layer.
        has_tagfreq_feature: when true, the module-level
            ``normalized_tag_frequencies`` is passed to the feature encoder in
            front of ``word_vectors``.

    Reads the module-level names ``word_vectors``, ``dev`` and (optionally)
    ``normalized_tag_frequencies``.
    """

    def __init__(self, in_channels, out_channels, features, n_head = 2, has_tagfreq_feature=False,):
        super().__init__()
        hidden_per_head = 2 * out_channels
        self.conv1 = pyg_nn.GATConv(in_channels, hidden_per_head, heads=n_head)
        # conv2 is sized for n_head * hidden_per_head inputs
        # (presumably conv1 concatenates its heads -- TODO confirm).
        self.conv2 = pyg_nn.GATConv(n_head * hidden_per_head, out_channels, heads=1)

        pretrained = [word_vectors]
        if has_tagfreq_feature:
            pretrained = [normalized_tag_frequencies, word_vectors]
        self.feature_encoder = afeatures.FeatureEncoding(features, pretrained)

    def forward(self, x, edge_index):
        """Return ``(gnn_output, encoded_features)`` for raw node features ``x``."""
        encoded = self.feature_encoder(x, dev)
        hidden = F.elu(self.conv1(encoded, edge_index))
        out = F.elu(self.conv2(hidden, edge_index))
        return out, encoded

In [7]:
def freeze_encoders_embedding(encoder):
    """Replace every MAPPING-typed layer of ``encoder`` with a frozen copy.

    For each position whose ``encoder.feature_types`` entry has ``type == MAPPING``,
    the layer is replaced in place by ``afeatures.MappingEncoding`` built from the
    current embedding weight with ``freeze=True``. Returns nothing.
    """
    for pos, feature_type in enumerate(encoder.feature_types):
        if feature_type.type != MAPPING:
            continue
        print('doing it')
        current_weight = encoder.layers[pos].emb.weight
        encoder.layers[pos] = afeatures.MappingEncoding(current_weight, freeze=True)

In [8]:
def clean_memory():
    """Run the Python garbage collector, then release PyTorch's cached CUDA memory."""
    gc.collect()
    with torch.no_grad():
        torch.cuda.empty_cache()

class DataEncoder():
    """Iterate over a data loader and run the model's encoder on every batch.

    Args:
        data_loader: iterable of batches; each batch is indexed as
            ``batch['x'][0]``, ``batch['edge_index'][0]`` and ``batch['verse'][0]``.
        model: object exposing ``encode(x, edge_index)`` that returns
            ``(z, encoded)``.
        mask_language: when true, column 0 of ``x`` is zeroed before encoding.

    Iterating yields ``(z, verse, i, batch)`` where ``i`` is the position of the
    batch in the loader (batches of masked verses are skipped but still counted)
    and ``batch['encoded']`` has been set to the encoder's second output.

    Reads the module-level names ``tqdm``, ``dev`` and ``masked_verses``.
    """

    def __init__(self, data_loader, model, mask_language):
        self.data_loader = data_loader
        self.model = model
        self.mask_language = mask_language

    def __iter__(self):
        for i,batch in enumerate(tqdm(self.data_loader)):

            x = batch['x'][0].to(dev)  # initial features (not encoded)
            edge_index = batch['edge_index'][0].to(dev)
            verse = batch['verse'][0]

            if verse in masked_verses:
                continue

            try:
                if self.mask_language:
                    x[:, 0] = 0
                z, encoded = self.model.encode(x, edge_index) # z is the output of the GNN
                batch['encoded'] = encoded
            except Exception as e:
                # Keep the failing batch reachable from the notebook for post-mortem
                # inspection (the original debugging globals are preserved).
                global sag, khar, gav
                sag, khar, gav =  (i, batch, verse)
                print(e)
                # Re-raise the real error with its traceback. The previous `1/0`
                # replaced it with an unrelated ZeroDivisionError.
                raise

            yield z, verse, i, batch

def train(epoch, data_loader, mask_language, test_data_loader, max_batches=999999999):
    """Run one training pass over ``data_loader`` with gradient accumulation.

    Args:
        epoch: epoch number, only forwarded to ``test`` for logging.
        data_loader: training batches, consumed through ``DataEncoder``.
        mask_language: forwarded to ``DataEncoder`` / ``test``.
        test_data_loader: validation loader evaluated every 1000 batches.
        max_batches: stop once the loader index ``i`` equals this value.

    Relies on module-level state: ``model``, ``optimizer``, ``criterion``,
    ``early_stopping``, ``dev``, ``w2v_model``, ``postag_map`` and the
    ``*_data_loader_*`` / ``*_gold_mostfreq_*`` names used in the periodic
    evaluation below (those are not defined in the part of the notebook shown here).
    Returns nothing; the accumulated loss is printed.
    """
    global optimizer
    total_loss = 0
    model.train()
    loss_multi_round = 0  # never read in this function

    data_encoder = DataEncoder(data_loader, model, mask_language)
    # Gradients left over from a previous call are discarded here.
    optimizer.zero_grad()


    for z, verse, i, batch in data_encoder:
        
        # pos_classes is turned into class indices via argmax over dim 1.
        target = batch['pos_classes'][0].to(dev)
        _, labels = torch.max(target, 1)
        
        index = batch['pos_index'][0].to(dev)

        preds = model.decoder(z, index, batch)

        # print(preds.shape, labels.shape)
        loss = criterion(preds, labels)
        # Scale the loss by the number of target rows in this batch.
        loss = loss * target.shape[0] # TODO check if this is necessary
        loss.backward()
        total_loss += loss.item()

        # `i` is the loader position, so batches skipped by DataEncoder still advance it.
        # NOTE(review): gradients accumulated after the last multiple of 5 are never
        # applied by this function.
        if (i+1) % 5 == 0: # Gradient accumulation
            optimizer.step()
            optimizer.zero_grad()

            

        # Periodic validation. total_loss is reset here, so the final print only
        # covers the batches since the last reset.
        if i % 1000 == 999:
            # print(f"loss: {total_loss}")
            total_loss = 0
            val_loss = test(epoch, test_data_loader, mask_language)
            test_mostfreq(yor_data_loader_heb, True, yor_gold_mostfreq_tag, yor_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
            test_mostfreq(tam_data_loader_grc, True, tam_gold_mostfreq_tag, tam_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
            test_mostfreq(arb_data_loader_grc, True, arb_gold_mostfreq_tag, arb_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
            test_mostfreq(por_data_loader_grc, True, por_gold_mostfreq_tag, por_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
            print('----------------------------------------------------------------------------------------------------------')
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            # The evaluation helpers switch the model to eval mode; switch back.
            model.train()
            clean_memory()
            

        if i == max_batches:
            break
    
    print(f"total train loss: {total_loss}")


In [9]:
class POSDecoder(nn.Module):
    """Three-layer MLP that maps selected node representations to tag scores.

    Args:
        input_size: width of the node representation.
        hidden_size: width of both hidden layers.
        n_class: number of output classes.
        drop_out: dropout probability applied after each hidden ReLU.
    """

    def __init__(self, input_size, hidden_size, n_class, drop_out=0):
        super().__init__()

        layers = [
            nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Dropout(drop_out),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(drop_out),
            nn.Linear(hidden_size, n_class),
        ]
        self.transfer = nn.Sequential(*layers)

    def forward(self, z, index, batch=None):
        """Score the rows of ``z`` selected by ``index``; ``batch`` is ignored."""
        return self.transfer(z[index, :])

class POSDecoderTransformer(nn.Module):
    """Transformer decoder over per-language sentences built from graph nodes.

    Node representations are regrouped into one sequence per language, padded
    to a common length, passed through a 6-layer ``nn.TransformerEncoder``, put
    back into graph order and finally scored by a small MLP.

    Args:
        input_size: width of the node representation (transformer ``d_model``).
        hidden_size: feed-forward width of the transformer layers and of the
            MLP hidden layer.
        n_class: number of output classes.
        residual_connection: when true, the (zero-padded) encoded input
            features are added to the GNN output before the transformer.
        drop_out: dropout probability of the output MLP.
        max_sentence_length: length every per-language sequence is padded to.
            Defaults to 150, the value that used to be hard-coded.

    Computation runs on the module-level device ``dev2``; the result is moved
    back to ``dev``.
    """

    def __init__(self, input_size, hidden_size, n_class, residual_connection, drop_out=0, max_sentence_length=150):
        super(POSDecoderTransformer, self).__init__()

        self.residual_connection = residual_connection
        self.max_sentence_length = max_sentence_length
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=8, dim_feedforward=hidden_size)
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=6)

        self.transfer = nn.Sequential( nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Dropout(drop_out), # TODO check what happens if I remove this.
                        nn.Linear(hidden_size, n_class))

    def forward(self, z_, index, batch_):
        """Return class scores for the nodes addressed by ``batch_['transformer_indices']``.

        ``index`` is accepted for interface parity with ``POSDecoder`` but is not used.
        """
        z = z_.to(dev2)

        # Right-pad the encoded input features with zeros up to the width of z.
        x = F.pad(batch_['encoded'], (0, z.size(1) - batch_['encoded'].size(1))).to(dev2)


        language_based_nodes = batch_['lang_based_nodes'] # determines which node belongs to which language
        transformer_indices = batch_['transformer_indices'] # the reverse of the prev structure

        sentences = []
        for lang_nodes in language_based_nodes: # we rearrange the nodes into sentences of each language
            if self.residual_connection:
                tensor = z[lang_nodes, :] + x[lang_nodes, :]
            else:
                tensor = z[lang_nodes, :]
            # Pad the sequence dimension at the end so all sentences can be stacked.
            # NOTE(review): a sentence longer than max_sentence_length yields a negative
            # pad, which F.pad treats as cropping -- confirm upstream filtering prevents it.
            tensor = F.pad(tensor, (0, 0, 0, self.max_sentence_length - tensor.size(0)))
            sentences.append(tensor)

        batch = torch.stack(sentences) # A batch contains all translations of one sentence in all training languages.
        batch = torch.transpose(batch, 0, 1)  # -> (sequence, language, feature)

        # NOTE(review): no padding mask is passed, so padded positions take part in attention.
        h = self.transformer(batch)
        h = torch.transpose(h, 0, 1)
        h = h[transformer_indices[0], transformer_indices[1], :] # rearrange the nodes back to the order in which we received them (the order that represents the graph)

        res = self.transfer(h)

        return res.to(dev)

In [10]:
def test(epoch, testloader, mask_language, filter_wordtypes=None):
    print('testing',  epoch)
    model.eval()
    total = 0
    correct = 0
    total_loss = 0
    probability_sum = 0 
    probability_count = 0

    data_encoder = DataEncoder(testloader, model, mask_language)
    
    with torch.no_grad():
        for z, verse, i, batch in data_encoder:
            
            target = batch['pos_classes'][0].to(dev)
            index = batch['pos_index'][0].to(dev)
            
            if filter_wordtypes != None:
                non_filtered_words = filter_wordtypes[batch['x'][0][:, 9].long()] == 1
                non_filtered_words = non_filtered_words[index]
                index = index[non_filtered_words]

                target = target[non_filtered_words, :]

            preds = model.decoder(z, index, batch)
            

            if preds.size(0) > 0:
                max_probs, predicted = torch.max(torch.softmax(preds, dim=1), 1)
                _, labels = torch.max(target, 1)

                loss = criterion(preds, labels)
                probability_sum += torch.sum(max_probs)
                probability_count += max_probs.size(0)
                total_loss += loss

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'test, epoch: {epoch}, total:{total} ACC: {correct/total}, loss: {total_loss}, confidence: {probability_sum/probability_count}')
    clean_memory()
    return total_loss

def test_mostfreq(testloader, mask_language, target_mostfreq_tags, target_mostfreq_index, word_types_shape, from_target=False):
    
    res = torch.zeros(word_types_shape[0], word_types_shape[1])
    model.eval()

    data_encoder = DataEncoder(testloader, model, mask_language)
    
    with torch.no_grad():
        for z, verse, i, batch in data_encoder:
            
            index = batch['pos_index'][0].to(dev)
            x = batch['x'][0][index, :]
            target = batch['pos_classes'][0]

            preds = model.decoder(z, index, batch)

            _, pred_max = torch.max(preds, dim=1)
            _, targe_max = torch.max(target, dim=1)

            if from_target:
                res[x[:, 9].long(), targe_max.long()] += 1
            else:
                res[x[:, 9].long(), pred_max.long()] += 1
    
    max_vals, res_tags = torch.max(res, dim=1)
    res_tags = res_tags[target_mostfreq_index]
    max_vals = max_vals[target_mostfreq_index]
    

    
    #print(f'correct = {torch.sum(target_mostfreq_tags == res_tags)}')
    #print(f'most frequency test, total:{target_mostfreq_tags.shape[0]}, accuracy:{torch.sum(target_mostfreq_tags == res_tags)/target_mostfreq_tags.shape[0]}, ')
    #print('target mostfreq tags', target_mostfreq_tags.shape)
    #print('res_tags', res_tags.shape)

    res_tags = res_tags[max_vals>0.1]
    target_mostfreq_tags_cp = target_mostfreq_tags[max_vals>0.1]
    #print(f'correct = {torch.sum(target_mostfreq_tags_cp == res_tags)}')
    print(f'most frequency test, total:{target_mostfreq_tags_cp.shape[0]}, accuracy:{torch.sum(target_mostfreq_tags_cp == res_tags)/target_mostfreq_tags_cp.shape[0]}, ')
    #print('target mostfreq tags', target_mostfreq_tags_cp.shape)
    #print('res_tags', res_tags.shape)


In [11]:
def majority_voting_test(data_loader1, data_loader2):
    """Baseline: tag a node with the most common tag among its tagged neighbours.

    The two loaders are walked in lockstep. For every target node of the first
    batch, the nodes it has an edge to are looked up; those that are also
    target nodes of the second batch contribute their tag, and the mode of the
    contributed tags is compared with the node's own tag. Nodes without any
    tagged neighbour are not counted.

    Prints the number of scored nodes and the accuracy (``nan`` when no node
    could be scored). Returns nothing. Relies on the module-level ``tqdm`` and
    ``masked_verses``.
    """
    total = 0
    correct = 0

    for i,(batch, batch2) in enumerate(tqdm(zip(data_loader1, data_loader2))) :

        edge_index = batch['edge_index'][0]
        verse = batch['verse'][0]

        if verse in masked_verses:
            continue

        target = batch['pos_classes'][0]
        index = batch['pos_index'][0]

        index2 = batch2['pos_index'][0]

        for node, label in zip(index,target):
            # Nodes reachable from `node` through an edge.
            other_side = edge_index[1, edge_index[0, :] == node]
            # Keep only the neighbours that carry a tag in the second batch.
            has_tag = torch.tensor([bool(n in index2) for n in other_side], dtype=torch.bool)
            other_side_withpos = other_side[has_tag]
            # Position of each such neighbour inside batch2's target rows.
            other_side_target_indices = [(n == index2).nonzero(as_tuple=True)[0].item() for n in other_side_withpos]
            proj_tags = batch2['pos_classes'][0][other_side_target_indices]

            if proj_tags.size(0) > 0:
                _, proj_tags = torch.max(proj_tags, 1)

                if torch.argmax(label) == torch.mode(proj_tags)[0]:
                    correct += 1

                total += 1

    # Avoid a ZeroDivisionError when no node had a tagged neighbour.
    accuracy = correct / total if total else float('nan')
    print(f'test, , total:{total} ACC: {accuracy}')


In [12]:
#print(wordtype_frequencies[736])
#print(wordtype_frequencies[1473])
#print(wordtype_frequencies[3683])
#print(wordtype_frequencies[7367])
#print(wordtype_frequencies[14733])

In [13]:
#frequent_words = torch.zeros(word_frequencies.size(0))
#frequent_words[word_frequencies > -1] = 1

#test(1, test_data_loader, filter_wordtypes=frequent_words)

In [14]:
from gensim.models import Word2Vec
# Load the pre-trained word2vec model from disk.
w2v_model = Word2Vec.load("/mounts/work/ayyoob/models/w2v/word2vec_helfi_langs_15e.model")

print(w2v_model.wv.vectors.shape)
# Embedding matrix as a float tensor; Encoder.__init__ reads this module-level name.
word_vectors = torch.from_numpy(w2v_model.wv.vectors).float()

(2354770, 100)


In [15]:
import pickle

# Both lists start as copies of every known verse.
train_verses = all_verses[:]
test_verses = all_verses[:] 
editf1 = 'fin-x-bible-helfi'
editf2 = "heb-x-bible-helfi"


# Drop two editions from the working list if present (mutates current_editions in place,
# so re-running this cell is harmless but earlier copies of the list are unaffected).
if 'jpn-x-bible-newworld' in  current_editions[:]:
     current_editions.remove('jpn-x-bible-newworld')
if 'grc-x-bible-unaccented' in  current_editions[:]:
     current_editions.remove('grc-x-bible-unaccented')



# One directory per dataset; each holds the serialized dataset file loaded below.
data_dir_train = "/mounts/data/proj/ayyoob/align_induction/dataset/dataset_helfi_train_community_word"
data_dir_blinker = "/mounts/data/proj/ayyoob/align_induction/dataset/pruned_alignments_blinker_inter/"
data_dir_grc = "/mounts/data/proj/ayyoob/align_induction/dataset/dataset_helfi_grc_community_word/"
data_dir_heb = "/mounts/data/proj/ayyoob/align_induction/dataset/dataset_helfi_heb_community_word/"

# NOTE(review): torch.load unpickles the file contents -- only load files from a trusted source.
# The file name is the same in every directory, including the three test datasets.
train_dataset = torch.load(f"{data_dir_train}/train_dataset_nox_noedge.torch.bin")
blinker_test_dataset = torch.load(f"{data_dir_blinker}/train_dataset_nox_noedge.torch.bin")
grc_test_dataset = torch.load(f"{data_dir_grc}/train_dataset_nox_noedge.torch.bin")
heb_test_dataset = torch.load(f"{data_dir_heb}/train_dataset_nox_noedge.torch.bin")


In [16]:
import codecs
import collections

# Tag name -> class id; ids follow the order of the tuple below.
postag_map = {
	tag: class_id
	for class_id, tag in enumerate((
		"ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
		"PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X",
	))
}
# Class id -> tag name.
postag_reverse_map = {class_id: tag for tag, class_id in postag_map.items()}

# Editions for which POS tags are read.
# (Previously tried and currently disabled: prs-x-bible-goodnews,
# hin-x-bible-newworld, ron-x-bible-2006.)
pos_lang_list = [
	"eng-x-bible-mixed",
	"deu-x-bible-newworld",
	"ces-x-bible-newworld",
	'dan-x-bible-newworld',
	'fin-x-bible-helfi',
	'nld-x-bible-newworld',
	'pol-x-bible-newworld',
	'swe-x-bible-newworld',
	"ita-x-bible-2009",
	"fra-x-bible-louissegond",
	"spa-x-bible-newworld",
]

def get_db_nodecount(dataset):
	"""Return the total number of entries over all verses of all languages in ``dataset.nodes_map``."""
	return sum(
		len(verse)
		for lang in dataset.nodes_map.values()
		for verse in lang.values()
	)

def get_language_nodes(dataset, lang_list, sentences):
	"""Collect, per sentence, the node ids of all tokens in the given languages.

	Returns:
		``(pos_labels, pos_node_cover)``: an all-zero tensor with one row per
		node and one column per tag (it is not filled in here), and a
		defaultdict mapping each sentence to the list of its node ids, gathered
		language by language in ``lang_list`` order.
	"""
	pos_labels = torch.zeros(get_db_nodecount(dataset), len(postag_map))
	pos_node_cover = collections.defaultdict(list)

	for lang in lang_list:
		if lang not in dataset.nodes_map:
			continue
		lang_sentences = dataset.nodes_map[lang]
		for sentence in sentences:
			if sentence not in lang_sentences:
				continue
			tokens = lang_sentences[sentence]
			for tok in tokens:
				pos_node_cover[sentence].append(tokens[tok])

	return pos_labels, pos_node_cover

def create_structures(dataset, all_tags):
	"""Turn per-sentence tag lists into a one-hot label tensor and a node cover.

	Args:
		dataset: object whose ``nodes_map[lang][sent_id]`` maps word positions to node ids.
		all_tags: ``{lang: {sent_id: [tag id, ...]}}``.

	Returns:
		``(pos_labels, pos_node_cover)``: a (node count x tag count) tensor with
		a 1 at ``[node, tag]`` for every tagged word that has a node, and a
		defaultdict mapping each sentence id to the list of those nodes.
	"""
	pos_labels = torch.zeros(get_db_nodecount(dataset), len(postag_map))
	pos_node_cover = collections.defaultdict(list)

	for lang, lang_tags in all_tags.items():
		for sent_id, sent_tags in lang_tags.items():
			for w_i, tag in enumerate(sent_tags):
				# Words without a node in the graph are skipped.
				if w_i not in dataset.nodes_map[lang][sent_id]:
					continue
				node = dataset.nodes_map[lang][sent_id][w_i]
				pos_labels[node, tag] = 1
				pos_node_cover[sent_id].append(node)
	return pos_labels, pos_node_cover

def get_pos_tags(dataset, pos_lang_list):
	"""Read the CoNLL-U tag files of the given editions and build label structures.

	For every edition in ``pos_lang_list`` that exists in ``dataset.nodes_map``,
	the edition's ``.conllu`` file is parsed and the tag column (index 3) of
	each sentence is mapped through ``postag_map``. Only sentences whose id
	(taken from the ``# verse_id`` comment line) exists in the dataset are kept.

	Returns:
		Whatever ``create_structures(dataset, all_tags)`` returns, i.e.
		``(pos_labels, pos_node_cover)``.

	Raises:
		KeyError: when a token line carries a tag that is not in ``postag_map``.
	"""
	# Imported here because the notebook never imports `os` explicitly.
	import os

	all_tags = {}
	for lang in pos_lang_list:
		if lang not in dataset.nodes_map:
			continue
		all_tags[lang] = {}

		# Prefer the first tag directory; fall back to the second when the file is absent there.
		if os.path.exists(F"/mounts/work/silvia/POS/TAGGED_LANGS/{lang}.conllu"):
			base_path = F"/mounts/work/silvia/POS/TAGGED_LANGS/"
		else:
			base_path = F"/mounts/work/mjalili/projects/gnn-align/data/pbc_pos_tags/"

		with codecs.open(F"{base_path}{lang}.conllu", "r", "utf-8") as lang_pos:
			tag_sent = []
			sent_id = ""
			for sline in lang_pos:
				sline = sline.strip()
				if sline == "":
					# A blank line ends the current sentence.
					if sent_id in dataset.nodes_map[lang]:
						all_tags[lang][sent_id] = [postag_map[p[3]] for p in tag_sent]
					tag_sent = []
					sent_id = ""
				elif "# verse_id" in sline:
					sent_id = sline.split()[-1]
				elif sline[0] == "#":
					continue
				else:
					tag_sent.append(sline.split("\t"))

			# Keep the last sentence even when the file does not end with a blank line.
			if tag_sent and sent_id in dataset.nodes_map[lang]:
				all_tags[lang][sent_id] = [postag_map[p[3]] for p in tag_sent]

	pos_labels, pos_node_cover = create_structures(dataset, all_tags)
	return pos_labels, pos_node_cover
	

def get_pos_tags_from_bronze_data(dataset, file_path, language):
	"""Build label structures for one language from a saved tag file.

	``file_path`` is loaded with ``torch.load`` and is expected to map sentence
	ids either to a dict or to a sequence of items; in both cases ``item[0]``
	is used as the word position and ``item[1]`` as the tag id. Sentences that
	are not in ``dataset.nodes_map[language]`` are ignored.

	NOTE(review): ``torch.load`` unpickles the file -- only use trusted files.

	Returns:
		``(pos_labels, pos_node_cover)`` in the same form as ``create_structures``.
	"""
	file_content = torch.load(file_path)

	pos_labels = torch.zeros(get_db_nodecount(dataset), len(postag_map))
	pos_node_cover = collections.defaultdict(list)

	for sent_id in file_content:
		if sent_id not in dataset.nodes_map[language]:
			continue
		entries = file_content[sent_id]
		if isinstance(entries, dict):
			entries = entries.items()
		sent_nodes = dataset.nodes_map[language][sent_id]
		for item in entries:
			pos_labels[sent_nodes[item[0]], item[1]] = 1
			pos_node_cover[sent_id].append(sent_nodes[item[0]])

	return pos_labels, pos_node_cover

def read_ud_gold_file(f_path, w2v_model, lang):
	"""Count gold tags per word type in a CoNLL-U file and return the most frequent one.

	Each token line contributes its word form (column 1, lower-cased and
	prefixed with ``"{lang}:"``) and its tag (column 3). Words that are not in
	the word2vec vocabulary and tokens whose tag is ``_`` are skipped.

	Args:
		f_path: path of the UTF-8 CoNLL-U file.
		w2v_model: model exposing ``wv.vectors`` and ``wv.key_to_index``.
		lang: language prefix used in the vocabulary keys.

	Returns:
		``(tags, index)``: ``index`` holds the vocabulary indices of all word
		types seen at least once (as returned by ``nonzero()``), and ``tags``
		the most frequent tag id of each of them, indexed the same way.

	Raises:
		KeyError: when a tag other than ``_`` is missing from ``postag_map``.
	"""
	pos_labels = torch.zeros(w2v_model.wv.vectors.shape[0], len(postag_map))
	with codecs.open(f_path, "r", "utf-8") as lang_pos:
		for sline in lang_pos:
			sline = sline.strip()
			# Blank lines and comment lines carry no token.
			if sline == "" or "# verse_id" in sline or sline[0] == "#":
				continue

			line_items = sline.split("\t")
			word = line_items[1]
			tag = line_items[3]
			try:
				idx = w2v_model.wv.key_to_index[f'{lang}:{word.lower()}']
			except KeyError:
				# Some words of the gold data do not occur in the Bible vocabulary; skip them.
				# Only a missing key is tolerated -- any other error is a real problem.
				continue

			if tag == '_':
				continue

			pos_labels[idx, postag_map[tag]] += 1

	index = (torch.sum(pos_labels, dim=1) > 0.1).nonzero()

	_, tags = torch.max(pos_labels, dim=1)
	print(torch.sum(pos_labels))

	return tags[index], index


In [17]:
# blinker_test_dataset = torch.load("/mounts/work/ayyoob/models/gnn/dataset_blinker_full_community_word.pickle", map_location=torch.device('cpu'))
editf12 = "eng-x-bible-mixed"
editf22 = 'fra-x-bible-louissegond'

test_gold_eng_fra = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/eng_fra_pbc/eng-fra.gold"

pros_blinker, surs_blinker = autils.load_gold(test_gold_eng_fra)

blinker_verse_alignments_inter = {}

# verse -> the first value found for it in nodes_map (first edition that contains the
# verse, first entry of that verse's token mapping).
verses_map = {}

for edit in blinker_test_dataset.nodes_map:
    for verse in blinker_test_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in blinker_test_dataset.nodes_map[edit][verse]:
                verses_map[verse] = blinker_test_dataset.nodes_map[edit][verse][tok]
                break

# Order the verses by that value (presumably a node id, i.e. the order in which the
# verses were laid out in the dataset -- TODO confirm).
sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
blinker_verses = [item[0] for item in sorted_verses]


In [18]:
#importlib.reload(afeatures)
#grc_test_dataset = torch.load("/mounts/work/ayyoob/models/gnn/dataset_helfi_grc_test_community_word.pickle", map_location=torch.device('cpu'))
editf_fin = "fin-x-bible-helfi"
editf_grc = 'grc-x-bible-helfi'

test_gold_grc = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-grc-gold-alignments_test.txt"

pros_grc, surs_grc = autils.load_gold(test_gold_grc)

grc_test_verse_alignments_inter = {}
grc_test_verse_alignments_gdfa = {}
gc.collect()

# verse -> the first value found for it in nodes_map (first edition that contains the
# verse, first entry of that verse's token mapping).
verses_map = {}

for edit in grc_test_dataset.nodes_map:
    for verse in grc_test_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in grc_test_dataset.nodes_map[edit][verse]:
                verses_map[verse] = grc_test_dataset.nodes_map[edit][verse][tok]
                break

# Verses ordered by that value (presumably a node id -- TODO confirm).
sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
grc_test_verses = [item[0] for item in sorted_verses]

# As the cell's last expression, the return value of gc.collect() is what the cell displays.
gc.collect()

0

In [19]:
#heb_test_dataset = torch.load("/mounts/work/ayyoob/models/gnn/dataset_helfi_heb_test_community_word.pickle", map_location=torch.device('cpu'))

test_gold_heb = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-heb-gold-alignments_test.txt"

pros_heb, surs_heb = autils.load_gold(test_gold_heb)

# verse -> the first value found for it in nodes_map (first edition that contains the
# verse, first entry of that verse's token mapping).
verses_map = {}

for edit in heb_test_dataset.nodes_map:
    for verse in heb_test_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in heb_test_dataset.nodes_map[edit][verse]:
                verses_map[verse] = heb_test_dataset.nodes_map[edit][verse][tok]
                break

# Verses ordered by that value (presumably a node id -- TODO confirm).
sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
heb_test_verses = [item[0] for item in sorted_verses]
# As the cell's last expression, the return value of gc.collect() is what the cell displays.
gc.collect()

0

In [20]:
# verse -> the first value found for it in nodes_map (first edition that contains the
# verse, first entry of that verse's token mapping).
verses_map = {}

for edit in train_dataset.nodes_map:
    for verse in train_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in train_dataset.nodes_map[edit][verse]:
                verses_map[verse] = train_dataset.nodes_map[edit][verse][tok]
                break

# Replaces the earlier all_verses with the training dataset's verses, ordered by that
# value (presumably a node id -- TODO confirm).
sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
all_verses = [item[0] for item in sorted_verses]

# Verses in which any edition has a token key greater than 150.
long_verses = set()

for edit in train_dataset.nodes_map.keys():
    for verse in train_dataset.nodes_map[edit]:
        to_print = False
        for tok in train_dataset.nodes_map[edit][verse]:
            if tok > 150:
                to_print = True
        if to_print == True:
            long_verses.add(verse)


train_verses = all_verses[:]

# Verses skipped by DataEncoder and majority_voting_test.
# NOTE(review): stored as a list, so every `verse in masked_verses` check is a linear scan.
masked_verses = list(long_verses)
#masked_verses.extend(blinker_verses)

In [21]:
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import random


class POSTAGGNNDataset(Dataset):
    """Dataset of (verse, group of tagged nodes) items backed by per-verse files on disk.

    Every item pairs one verse with up to ``group_size`` of that verse's tagged
    nodes. The verse's node features and edges are read from
    ``{data_dir}/verses/{verse}_info.torch.bin`` when the item is fetched.
    """

    def __init__(self, dataset, verses, edit_files, alignments, node_cover, pos_labels, data_dir, create_data=False, group_size = 20):
        """Build the item list; with ``create_data=True`` also (re)write the per-verse files.

        Args:
            dataset: object providing ``nodes_map`` (and ``x`` / ``edge_index`` when
                ``create_data`` is true).
            verses: verses to include, in order.
            edit_files: editions, used only when creating the data.
            alignments: per-verse alignments, used only when creating the data.
            node_cover: mapping verse -> list of tagged node ids.
            pos_labels: label tensor indexed by node id.
            data_dir: directory that holds the ``verses/`` files.
            create_data: whether to run ``calculate_verse_stats``.
            group_size: maximum number of nodes per item.
        """
        self.node_cover = node_cover
        self.pos_labels = pos_labels
        self.data_dir = data_dir
        self.items = self.calculate_size(verses, group_size, node_cover)
        self.dataset = dataset

        if create_data:
            self.calculate_verse_stats(verses, edit_files, alignments, dataset, data_dir)            
        
    def calculate_size(self, verses, group_size, node_cover):
        """Return the list of ``(verse, node chunk)`` items.

        NOTE: ``node_cover[verse]`` is shuffled in place, so the caller's lists are
        reordered, and the chunking depends on the state of ``random``.
        """
        res = []
        for verse in verses:
            covered_nodes = node_cover[verse]
            random.shuffle(covered_nodes)
            # Consecutive slices of at most group_size nodes.
            items = [covered_nodes[i:i + group_size] for i in range(0, len(covered_nodes), group_size)]
            res.extend([(verse, i) for i in items])

        return res

    def calculate_verse_stats(self,verses, edition_files, alignments, dataset, data_dir):
        """Slice the global feature/edge tensors per verse and save each slice to disk.

        For every verse the alignments of all edition pairs are walked to find the
        smallest and largest node id and the number of edge endpoints; those bounds
        select the verse's rows of ``dataset.x`` and columns of
        ``dataset.edge_index``. Finally ``dataset.x`` and ``dataset.edge_index`` are
        set to ``None`` and the stripped dataset is saved as well.

        NOTE(review): the edge slice starts where the previous verse's slice ended,
        so this presumably requires ``dataset.edge_index`` to store the edges verse
        by verse in the same order as ``verses`` -- TODO confirm.
        """

        min_edge = 0  # first edge column of the current verse
        for verse in tqdm(verses):
            min_nodes = 99999999999999
            max_nodes = 0
            edges_tmp = [[],[]]
            x_tmp = []
            features = []
            for i,editf1 in enumerate(edition_files):
                for j,editf2 in enumerate(edition_files[i+1:]):
                    aligns = autils.get_aligns(editf1, editf2, alignments[verse])
                    if aligns != None:
                        for align in aligns:
                            try:
                                n1,_ = gutils.node_nom(verse, editf1, align[0], None, dataset.nodes_map, x_tmp, edition_files, features)
                                n2,_ = gutils.node_nom(verse, editf2, align[1], None, dataset.nodes_map, x_tmp, edition_files, features)
                                # Two entries per alignment; only the length of this list is used below.
                                edges_tmp[0].extend([n1, n2])

                                max_nodes = max(n1, n2, max_nodes)
                                min_nodes = min(n1, n2, min_nodes)
                            except Exception as e:
                                print(editf1, editf2, verse)
                                raise(e)

            self.verse_info = {}

            # Offset that converts global node ids to verse-local ids.
            self.verse_info['padding'] = min_nodes
            
            self.verse_info['x'] = torch.clone(dataset.x[min_nodes:max_nodes+1,:])
            
            self.verse_info['edge_index'] = torch.clone(dataset.edge_index[:, min_edge : min_edge + len(edges_tmp[0])] - min_nodes)

            # Sanity checks (diagnostics only): local ids should start at 0 ...
            if torch.min(self.verse_info['edge_index']) != 0:
                print(verse, min_nodes, max_nodes, min_edge, len(edges_tmp[0]))
                print(torch.min(self.verse_info['edge_index']))
            
            # ... and the largest local id should match the number of feature rows.
            if self.verse_info['x'].shape[0] != torch.max(self.verse_info['edge_index']) + 1 :
                print(verse, min_nodes, max_nodes, min_edge, len(edges_tmp[0]))
                print(torch.min(self.verse_info['edge_index']))
            
            min_edge = min_edge + len(edges_tmp[0])

            torch.save(self.verse_info, f"{data_dir}/verses/{verse}_info.torch.bin")
        
        dataset.x = None
        dataset.edge_index = None
        torch.save(dataset, f"{data_dir}/train_dataset_nox_noedge.torch.bin")
    
    def __len__(self):
        """Number of (verse, node chunk) items."""
        return len(self.items)
    
    def __getitem__(self, idx):
        """Load the verse of item ``idx`` from disk and return it as a dict.

        Keys: ``verse``, ``x`` (features with column 9 appended again as the last
        column), ``edge_index``, ``pos_classes`` (label rows of the item's nodes),
        ``pos_index`` (the item's node ids minus the verse offset), ``padding``,
        ``lang_based_nodes`` and ``transformer_indices`` (both from
        ``posutil.get_language_based_nodes``).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        verse, nodes = self.items[idx]
        
        # NOTE: overwrites self.verse_info on every access (only the latest verse is kept).
        self.verse_info = {verse: torch.load(f'{self.data_dir}/verses/{verse}_info.torch.bin')}


        word_number = self.verse_info[verse]['x'][:, 9]
        padding = self.verse_info[verse]['padding']
        
        language_based_nodes, transformer_indices = posutil.get_language_based_nodes(self.dataset.nodes_map, verse, nodes, padding)

        ## # Add POSTAG to set of features
        #postags = self.pos_labels[padding: self.verse_info[verse]['x'].size(0) + padding, : ]
        #postags = postags.detach().clone()
        #postags[torch.LongTensor(nodes) - padding, :] = 0
        #self.verse_info[verse]['x'] = torch.cat((self.verse_info[verse]['x'], postags), dim=1)

        # Add token id as a feature, used to extract token information (like token's tag distribution)
        word_number = torch.unsqueeze(word_number, 1)
        self.verse_info[verse]['x'] = torch.cat((self.verse_info[verse]['x'], word_number), dim=1)

        return {'verse':verse, 'x':self.verse_info[verse]['x'], 'edge_index':self.verse_info[verse]['edge_index'], 
                'pos_classes': self.pos_labels[nodes, :], 'pos_index': torch.LongTensor(nodes) - padding, 
                'padding': padding, 'lang_based_nodes': language_based_nodes, 'transformer_indices': transformer_indices}

def create_me_a_gnn_dataset_you_stupid(node_covers, labels, group_size=100, editions=current_editions):
    """Build the train, grc, heb and blinker ``POSTAGGNNDataset`` objects.

    Args:
        node_covers: four node-cover mappings, in the order train, grc, heb, blinker.
        labels: four label tensors, in the same order.
        group_size: maximum number of nodes per dataset item.
        editions: editions passed to every dataset (the default is the value
            ``current_editions`` had when this function was defined).

    Returns:
        ``(train_ds, grc_ds, heb_ds, blinker_ds)``.
    """
    # (dataset, verses, data directory) in the same order as node_covers / labels.
    sources = (
        (train_dataset, train_verses, data_dir_train),
        (grc_test_dataset, grc_test_verses, data_dir_grc),
        (heb_test_dataset, heb_test_verses, data_dir_heb),
        (blinker_test_dataset, blinker_verses, data_dir_blinker),
    )

    built = []
    for position, (source_dataset, source_verses, source_dir) in enumerate(sources):
        built.append(POSTAGGNNDataset(
            source_dataset, source_verses, editions, {},
            node_covers[position], labels[position], source_dir,
            group_size=group_size,
        ))

    return tuple(built)

# if "eng-x-bible-mixed" in pos_lang_list:
#     pos_lang_list.remove("eng-x-bible-mixed")
# Read the POS tags for each of the four datasets. The commented-out lines under every
# call are an alternative path that saves the result to / loads it from pos_data.torch.bin.
train_pos_labels, train_pos_node_cover = get_pos_tags(train_dataset, pos_lang_list)
#torch.save({'pos_labels':train_pos_labels, 'pos_node_cover':train_pos_node_cover}, f'{data_dir_train}/pos_data.torch.bin')
#pos_data = torch.load(f'{data_dir_train}/pos_data.torch.bin')
#train_pos_labels, train_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']

blinker_pos_labels, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, pos_lang_list)
##torch.save({'pos_labels':blinker_pos_labels, 'pos_node_cover': blinker_pos_node_cover}, f'{data_dir_blinker}/pos_data.torch.bin')
#pos_data = torch.load(f'{data_dir_blinker}/pos_data.torch.bin')
#blinker_pos_labels, blinker_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']

grc_pos_labels, grc_pos_node_cover = get_pos_tags(grc_test_dataset, pos_lang_list)
#torch.save({'pos_labels':grc_pos_labels, 'pos_node_cover': grc_pos_node_cover}, f'{data_dir_grc}/pos_data.torch.bin')
#pos_data = torch.load(f'{data_dir_grc}/pos_data.torch.bin')
#grc_pos_labels, grc_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']

heb_pos_labels, heb_pos_node_cover = get_pos_tags(heb_test_dataset, pos_lang_list)
#torch.save({'pos_labels':heb_pos_labels, 'pos_node_cover': heb_pos_node_cover}, f'{data_dir_heb}/pos_data.torch.bin')
#pos_data = torch.load(f'{data_dir_heb}/pos_data.torch.bin')
#heb_pos_labels, heb_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']

# Datasets with at most 256 tagged nodes per item.
gnn_dataset_train_pos, gnn_dataset_grc_pos, gnn_dataset_heb_pos, gnn_dataset_blinker_pos = create_me_a_gnn_dataset_you_stupid(
    [train_pos_node_cover, grc_pos_node_cover, heb_pos_node_cover, blinker_pos_node_cover], [train_pos_labels, grc_pos_labels, heb_pos_labels, blinker_pos_labels], group_size=256)

# Same data with at most 10000 tagged nodes per item.
# Both calls shuffle the shared node-cover lists in place (see POSTAGGNNDataset.calculate_size).
gnn_dataset_train_pos_bigbatch, gnn_dataset_grc_pos_bigbatch, gnn_dataset_heb_pos_bigbatch, gnn_dataset_blinker_pos_bigbatch = create_me_a_gnn_dataset_you_stupid(
    [train_pos_node_cover, grc_pos_node_cover, heb_pos_node_cover, blinker_pos_node_cover], [train_pos_labels, grc_pos_labels, heb_pos_labels, blinker_pos_labels], group_size=10000)

# One item per loader batch, in dataset order.
train_data_loader_bigbatch = DataLoader(gnn_dataset_train_pos_bigbatch, batch_size=1, shuffle=False)
grc_data_loader_bigbatch = DataLoader(gnn_dataset_grc_pos_bigbatch, batch_size=1, shuffle=False)
heb_data_loader_bigbatch = DataLoader(gnn_dataset_heb_pos_bigbatch, batch_size=1, shuffle=False)
blinker_data_loader_bigbatch = DataLoader(gnn_dataset_blinker_pos_bigbatch, batch_size=1, shuffle=False)


In [22]:
# Length of the training GNN dataset that was built with group_size=256.
print(len(gnn_dataset_train_pos))

39817


In [23]:

def get_verse_size_distrib(nodes_map):
    """Print how many verses are covered by each number of editions.

    Args:
        nodes_map: mapping ``edition -> {verse: ...}``; only the keys of the
            inner mappings are used.

    For every verse the number of editions containing it is counted (its
    "size"). Then, for each size shared by more than 10 verses, one line
    ``<size> <number of verses>`` is printed in ascending order of size.
    Nothing is returned.

    The histogram is kept in a dict rather than a fixed-length list, so a verse
    that occurs in any number of editions is counted instead of raising
    ``IndexError`` once the count exceeds a hard-coded bin limit.
    """
    verse_sizes = {}
    for edition in nodes_map:
        for verse in nodes_map[edition]:
            verse_sizes[verse] = verse_sizes.get(verse, 0) + 1

    # size -> number of verses having that size (sparse histogram).
    distrib = {}
    for size in verse_sizes.values():
        distrib[size] = distrib.get(size, 0) + 1

    # Only report well-populated bins; sorted() keeps the ascending output order.
    for size in sorted(distrib):
        if distrib[size] > 10:
            print(size, distrib[size])

def update_verse_sizes(nodes_map, verse_sizes):
    """Add one to ``verse_sizes[verse]`` for every edition that contains the verse.

    ``nodes_map`` maps edition -> {verse: ...}. ``verse_sizes`` is updated in
    place (missing verses start at 1) and nothing is returned.
    """
    for edition_verses in nodes_map.values():
        for verse_id in edition_verses:
            verse_sizes[verse_id] = verse_sizes.get(verse_id, 0) + 1

def update_distribution(nodes_map, verse_sizes, distrib, edition):
    """Add the verses of one edition to a verse-size histogram.

    Args:
        nodes_map: mapping ``edition -> {verse: ...}``.
        verse_sizes: mapping ``verse -> size`` (an integer bin index).
        distrib: list used as histogram; ``distrib[size]`` is incremented in place.
        edition: edition whose verses are counted; if it is not a key of
            ``nodes_map`` the call does nothing.

    A verse whose size does not fit into ``distrib`` is printed as
    ``<verse> <size>`` and skipped. The bound is taken from ``len(distrib)``
    rather than a hard-coded 99, so the guard stays correct for any histogram
    length (for the 100-bin list used by the caller the behavior is unchanged).
    """
    if edition not in nodes_map:
        return

    bin_count = len(distrib)
    for verse in nodes_map[edition]:
        size = verse_sizes[verse]
        if size >= bin_count:
            # Out-of-range size: report it instead of indexing past the end.
            print(verse, size)
            continue
        distrib[size] += 1
                
def get_language_verse_distrib(nodes_map1, nodes_map2, nodes_map3, nodes_map4, edition):
    """Print the verse-size histogram restricted to the verses of one edition.

    Verse sizes are accumulated over all four node maps into the module-level
    dict ``all_verse_sizes`` (reset on every call). The verses of ``edition``
    in each map are then binned by size into a 100-slot histogram, and every
    bin holding more than 10 verses is printed as ``<size> <count>``.
    """
    global all_verse_sizes
    all_verse_sizes = {}

    node_maps = (nodes_map1, nodes_map2, nodes_map3, nodes_map4)

    # First pass: sizes must be complete before any verse is binned.
    for node_map in node_maps:
        update_verse_sizes(node_map, all_verse_sizes)

    histogram = [0] * 100
    for node_map in node_maps:
        update_distribution(node_map, all_verse_sizes, histogram, edition)

    for size, count in enumerate(histogram):
        if count > 10:
            print(size, count)

# Histogram for the verses of the Tamil edition; sizes are accumulated over all
# four splits, and this call also (re)creates the global all_verse_sizes.
get_language_verse_distrib(train_dataset.nodes_map, blinker_test_dataset.nodes_map, grc_test_dataset.nodes_map, heb_test_dataset.nodes_map, 'tam-x-bible-newworld')
# Per-split histograms: number of editions each verse appears in.
get_verse_size_distrib(train_dataset.nodes_map)
get_verse_size_distrib(blinker_test_dataset.nodes_map)
get_verse_size_distrib(grc_test_dataset.nodes_map)
get_verse_size_distrib(heb_test_dataset.nodes_map)
# Disabled experiment: rebuild the training dataset from verses whose size in
# all_verse_sizes exceeds 50.
# big_verses = []
# for verse in train_verses:
#     if all_verse_sizes[verse] > 50:
#         big_verses.append(verse)
# gnn_dataset_train_pos = POSTAGGNNDataset(train_dataset, big_verses, editions, {}, train_pos_node_cover, train_pos_labels, data_dir_train, group_size=128)

44028011 156
41012034 162
45012008 158
44014012 160
58006008 156
58010022 156
66002018 156
43013022 162
44001020 160
46014019 158
50004005 158
50001025 158
42023030 162
44007042 160
40002018 160
44016001 160
48004008 158
52002017 158
59002010 158
44028018 160
44014020 160
66021020 154
66016009 152
49005007 158
66022013 156
46007036 158
44005036 160
56001010 156
46014009 158
44008027 158
58001003 156
41009026 162
40010024 162
43011010 162
44013001 160
41001026 162
51004002 156
48002017 158
49005032 158
46004017 158
40010025 162
54003013 158
55001011 152
41003034 162
44014006 158
40027024 162
44024003 150
42005007 162
44012019 160
44024014 160
51003015 156
47003013 158
42023014 160
44017011 160
58012028 156
44004026 158
40007016 162
44014013 160
54002007 154
44027006 160
44025024 158
54004015 158
44010004 160
40002018 160
40007016 162
40010024 162
40010025 162
40027024 162
41001026 162
41003034 162
41009026 162
41012034 162
42005007 162
42006040 162
42023014 160
42023030 162
43011010 162

In [24]:
def get_data_loadrs_for_target_editions(target_editions, dataset, pos_node_cover, verses, data_dir):
    """Return an unshuffled, batch_size=1 DataLoader over the nodes of the target editions.

    The labels and node cover come from ``get_language_nodes`` for
    ``target_editions``, given the verse keys of ``pos_node_cover``. They are
    wrapped in a ``POSTAGGNNDataset`` over ``verses`` with ``group_size=50000``
    and no edition list (``None``).
    """
    covered_verses = pos_node_cover.keys()
    labels, node_cover = get_language_nodes(dataset, target_editions, covered_verses)
    target_dataset = POSTAGGNNDataset(
        dataset, verses, None, {}, node_cover, labels, data_dir, group_size=50000
    )
    return DataLoader(target_dataset, batch_size=1, shuffle=False)

In [25]:
#no_eng_langs = pos_lang_list[:]
#no_eng_langs.remove('eng-x-bible-mixed')

## # train_pos_labels, train_pos_node_cover = get_pos_tags(train_dataset, no_eng_langs)
## # gnn_dataset_train_pos = POSTAGGNNDataset(train_dataset, train_verses, current_editions, verse_alignments_inter,
## #                        train_pos_node_cover, train_pos_labels, data_dir_train, group_size = 10)


#_, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, ['eng-x-bible-mixed'])
#blinker_pos_labels, _ = get_pos_tags(blinker_test_dataset, pos_lang_list)
#gnn_dataset_blinker_pos_onlyeng = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
#                            blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 500)



In [26]:
def save_model(model, name):
    """Pickle the whole ``model`` object to the POS-tagging checkpoint directory.

    The file is written as ``pos_tagging_<name>_.pickle`` (the trailing
    underscore before the extension is part of the name).
    """
    # Disabled: re-creation of the encoder's feature types before saving.
    #model.encoder.feature_encoder.feature_types[0] = afeatures.OneHotFeature(20, 83, 'editf')
    #model.encoder.feature_encoder.feature_types[1] = afeatures.OneHotFeature(32, 150, 'position')
    #model.encoder.feature_encoder.feature_types[2] = afeatures.FloatFeature(4, 'degree_centrality')
    #model.encoder.feature_encoder.feature_types[3] = afeatures.FloatFeature(4, 'closeness_centrality')
    #model.encoder.feature_encoder.feature_types[4] = afeatures.FloatFeature(4, 'betweenness_centrality')
    #model.encoder.feature_encoder.feature_types[5] = afeatures.FloatFeature(4, 'load_centrality')
    #model.encoder.feature_encoder.feature_types[6] = afeatures.FloatFeature(4, 'harmonic_centrality')
    #model.encoder.feature_encoder.feature_types[7] = afeatures.OneHotFeature(32, 250, 'greedy_modularity_community')
    #model.encoder.feature_encoder.feature_types[8] = afeatures.OneHotFeature(32, 250, 'community_2')
    #model.encoder.feature_encoder.feature_types[9] = afeatures.MappingFeature(100, 'word')
    #model.encoder.feature_encoder.feature_types[10] = afeatures.MappingFeature(len(postag_map), 'tag_priors', freeze=True)
    checkpoint_path = f'/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_{name}_' + '.pickle'
    torch.save(model, checkpoint_path)

In [27]:
# eng_test_pos_labels, eng_test_pos_node_cover = get_pos_tags(blinker_test_dataset, ['eng-x-bible-mixed'])
# gnn_dataset_engtest_pos = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, {},
#                       eng_test_pos_node_cover, eng_test_pos_labels, data_dir_blinker, group_size = 500)
# engtest_data_loader = DataLoader(gnn_dataset_engtest_pos, batch_size=1, shuffle=False)

# #test(1, engtest_data_loader, mask_language=False) 

# finetune_pos_labels, finetune_pos_node_cover = get_pos_tags(train_dataset, ['eng-x-bible-mixed'])
# gnn_dataset_finetune_pos = POSTAGGNNDataset(train_dataset, train_verses, current_editions, {},
#                       finetune_pos_node_cover, finetune_pos_labels, data_dir_train, group_size = 100)
# finetune_data_loader = DataLoader(gnn_dataset_finetune_pos, batch_size=1, shuffle=False)


#train(1, finetune_data_loader, mask_language=False)
#test(1, engtest_data_loader, mask_language=False) 

# blinker_pos_labels, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, ['eng-x-bible-mixed'])
# gnn_dataset_blinker_pos = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
#                              blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 10000)

# blinker_pos_labels, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, no_eng_langs)
# gnn_dataset_blinker_pos_majvoting_test = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
#                              blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 10000)

# test_data_loader = DataLoader(gnn_dataset_blinker_pos, batch_size=1, shuffle=False)
# test_data_loader_majvoting = DataLoader(gnn_dataset_blinker_pos_majvoting_test, batch_size=1, shuffle=False)
# majority_voting_test(test_data_loader, test_data_loader_majvoting)

In [28]:
# Load previously generated ("bronze") POS tags for Yoruba and Tamil and wrap them
# into datasets/loaders. Exactly one Yoruba source is active; the commented lines
# are alternative prediction files produced by other model configurations.
#yor_bronze_labels, yor_bronze_node_cover = get_pos_tags_from_bronze_data(heb_test_dataset, '/mounts/work/ayyoob/results/gnn_align/yoruba/pos_tags_yor-x-bible-2010_posfeatFalse_transformerFalse_trainWEFalse_maskLangTrue.pickle', 'yor-x-bible-2010')
yor_bronze_labels, yor_bronze_node_cover = get_pos_tags_from_bronze_data(heb_test_dataset, '/mounts/work/ayyoob/results/gnn_align/yoruba/pos_tags_yor-x-bible-2010_posfeatTrue_transformerFalse_trainWEFalse_maskLangTrue.pickle', 'yor-x-bible-2010')
# yor_bronze_labels, yor_bronze_node_cover = get_pos_tags_from_bronze_data(heb_test_dataset, '/mounts/work/ayyoob/results/gnn_align/yoruba/pos_tags_yor-x-bible-2010_9lang_pos_tagging_posfeatTrue_transformerTrue6layresresidual_trainWETrue_maskLangFalse_epoch3_trainYoruba_20220219-052901.pickle_maskLangFalse.pickle', 'yor-x-bible-2010')
#yor_bronze_labels, yor_bronze_node_cover = get_pos_tags_from_bronze_data(heb_test_dataset, '/mounts/work/ayyoob/results/gnn_align/yoruba/pos_tags_yor-x-bible-2010_11lang_posfeatTruealltargets_transformerTrue6layresresidual_trainWEFalse_epoch1_noEng_maskLangTrue.pickle', 'yor-x-bible-2010')

# The Tamil file is read from the same ".../gnn_align/yoruba/" directory as the Yoruba ones.
tam_bronze_labels, tam_bronze_node_cover = get_pos_tags_from_bronze_data(grc_test_dataset, '/mounts/work/ayyoob/results/gnn_align/yoruba/pos_tags_tam-x-bible-newworld_12lang_posfeatFalsealltargets_transformerFalse6layresresidual_trainWEFalse_epoch1__maskLangTrue.pickle', 'tam-x-bible-newworld')


# Yoruba bronze tags are attached to the heb split, Tamil ones to the grc split.
# NOTE(review): these datasets pass `current_editions`, while the datasets built by
# create_me_a_gnn_dataset_you_stupid pass `editions` — confirm this is intended.
gnn_dataset_yorbronz_pos = POSTAGGNNDataset(heb_test_dataset, heb_test_verses, current_editions, {},
                      yor_bronze_node_cover, yor_bronze_labels, data_dir_heb, group_size = 500)
yorbronz_data_loader = DataLoader(gnn_dataset_yorbronz_pos, batch_size=1, shuffle=False)

gnn_dataset_tambronz_pos = POSTAGGNNDataset(grc_test_dataset, grc_test_verses, current_editions, {},
                      tam_bronze_node_cover, tam_bronze_labels, data_dir_grc, group_size = 500)
tambronz_data_loader = DataLoader(gnn_dataset_tambronz_pos, batch_size=1, shuffle=False)


In [29]:
# test_mostfreq(tambronz_data_loader, True, tam_gold_mostfreq_tag, tam_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)), True)


In [30]:
def test_existing_model(model_path, use_transformer):
    """Load a saved checkpoint into the global ``model`` and run the Yoruba evaluation.

    ``model_path`` is a file name inside the POS-tagging checkpoint directory.
    The checkpoint is deserialized onto the CPU, moved to ``dev``, and, when
    ``use_transformer`` is true, its decoder is moved on to ``dev2``.
    """
    global model
    checkpoint_file = f'/mounts/work/ayyoob/models/gnn/checkpoint/postagging/{model_path}'
    cpu_device = torch.device('cpu')
    model = torch.load(checkpoint_file, map_location=cpu_device)
    model.to(dev)
    print(model_path)
    if use_transformer:
        model.decoder.to(dev2)

    # (number of word vectors, number of POS tags)
    freq_table_shape = (w2v_model.wv.vectors.shape[0], len(postag_map))
    test_mostfreq(
        yor_data_loader_heb,
        True,
        yor_gold_mostfreq_tag,
        yor_gold_mostfreq_index,
        freq_table_shape,
    )
    #test(1, yorbronz_data_loader, True)
    #test(1, engtest_data_loader, True)

In [31]:
# List the editions whose identifier starts with 'por' (used to pick the
# Portuguese edition name passed to the loaders and generators further down).
for edit in current_editions:
    if edit.startswith('por'):
        print(edit)

por-x-bible-versaointernacional


In [32]:
# For each evaluation language: read a UD gold test file (yielding a most-frequent-tag
# value and an index value, as named by the variables) and build an unshuffled loader
# over that language's edition. Yoruba gets a loader on both the grc and heb splits;
# Tamil, Arabic and Portuguese only on the grc split.
yor_gold_mostfreq_tag, yor_gold_mostfreq_index = read_ud_gold_file('/mounts/work/silvia/POS/yo_ytb-ud-test.conllu', w2v_model, 'yor')
yor_data_loader_grc = get_data_loadrs_for_target_editions(['yor-x-bible-2010'], grc_test_dataset, grc_pos_node_cover, grc_test_verses, data_dir_grc)
yor_data_loader_heb = get_data_loadrs_for_target_editions(['yor-x-bible-2010'], heb_test_dataset, heb_pos_node_cover, heb_test_verses, data_dir_heb)

tam_gold_mostfreq_tag, tam_gold_mostfreq_index = read_ud_gold_file('/mounts/work/silvia/POS/ta_mwtt-ud-test.conllu', w2v_model, 'tam')
tam_data_loader_grc = get_data_loadrs_for_target_editions(['tam-x-bible-newworld'], grc_test_dataset, grc_pos_node_cover, grc_test_verses, data_dir_grc)

arb_gold_mostfreq_tag, arb_gold_mostfreq_index = read_ud_gold_file('/nfs/datx/UD/ar_pud-ud-test.conllu', w2v_model, 'arb')
arb_data_loader_grc = get_data_loadrs_for_target_editions(['arb-x-bible'], grc_test_dataset, grc_pos_node_cover, grc_test_verses, data_dir_grc)

por_gold_mostfreq_tag, por_gold_mostfreq_index = read_ud_gold_file('/nfs/datx/UD/pt_pud-ud-test.conllu', w2v_model, 'por')
por_data_loader_grc = get_data_loadrs_for_target_editions(['por-x-bible-versaointernacional'], grc_test_dataset, grc_pos_node_cover, grc_test_verses, data_dir_grc)

#test_existing_model('pos_tagging_posfeatFalse_transformerFalse_trainWEFalse_maskLangTrue_20220209-190017.pickle', False)

#test_mostfreq(yorbronz_data_loader, True, yor_gold_mostfreq_tag, yor_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)), from_target=True)
#test_mostfreq(tam_data_loader_grc, True, tam_gold_mostfreq_tag, tam_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))


tensor(5424.)
tensor(1763.)
tensor(10835.)
tensor(17562.)


In [33]:
# test_mostfreq(yor_data_loader_grc, True, yor_gold_mostfreq_tag, yor_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
# test_mostfreq(arb_data_loader_grc, True, arb_gold_mostfreq_tag, arb_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
# test_mostfreq(por_data_loader_grc, True, por_gold_mostfreq_tag, por_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
# test_mostfreq(tam_data_loader_grc, True, tam_gold_mostfreq_tag, tam_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))


In [34]:
# Pick up edits to postag_utils without restarting the kernel.
importlib.reload(posutil)
# Value passed as target_train_treshold in the (commented-out) call to
# posutil.get_tag_frequencies_node_tags further down in this cell.
threshold = 0.95
#target_editions = []

#for edition in current_editions:
#    if edition not in pos_lang_list:
#        target_editions.append(edition)


#model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_11langs-nopersian_posfeatFalsealltargets_transformerFalse6layresresidualFalse_trainWEFalse_maskLangTrue_epoch1__20220305-101058_earlystopping-GA_.pickle', map_location=torch.device('cpu'))
#model.to(dev)
##model.decoder.to(dev2)
#test_mostfreq(yor_data_loader_heb, True, yor_gold_mostfreq_tag, yor_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))

#target_data_loader_train = get_data_loadrs_for_target_editions(target_editions, train_dataset, train_pos_node_cover, train_verses, data_dir_train)
#target_data_loader_grc = get_data_loadrs_for_target_editions(target_editions, grc_test_dataset, grc_pos_node_cover, grc_test_verses, data_dir_grc)
#target_data_loader_heb = get_data_loadrs_for_target_editions(target_editions, heb_test_dataset, heb_pos_node_cover, heb_test_verses, data_dir_heb)
#target_data_loader_blinker = get_data_loadrs_for_target_editions(target_editions, blinker_test_dataset, blinker_pos_node_cover, blinker_verses, data_dir_blinker)

#res_ = torch.load(f'/mounts/work/ayyoob/results/gnn_postag/data/11lang-feature_vectors_posfeat{False}_transformer{True}_trainWE{False}_maskLang{True}_epoch{1}_Englishandallothertargets_typecheckFalse.torch.bin')
#tag_frequencies, tag_frequencies_target, train_pos_node_cover_ext, train_pos_labels_ext, grc_pos_node_cover_ext, grc_pos_labels_ext, heb_pos_node_cover_ext, heb_pos_labels_ext, blinker_pos_node_cover_ext, blinker_pos_labels_ext = res_
# res_ = posutil.get_tag_frequencies_node_tags(model, target_editions, train_pos_node_cover, train_pos_labels, grc_pos_node_cover, grc_pos_labels, heb_pos_node_cover, heb_pos_labels, blinker_pos_node_cover, blinker_pos_labels,
#                                     len(postag_map), target_data_loader_train, target_data_loader_grc, target_data_loader_heb, target_data_loader_blinker,
#                                     train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch, DataEncoder, target_train_treshold=threshold, type_check=False
#                                     , source_tag_frequencies=tag_frequencies_source)

# torch.save(res_, f'{len(pos_lang_list)}langs-nopersian_posfeatFalsealltargets_transformer{False}6layresresidual{False}_trainWE{False}_maskLang{True}_epoch{1}__{start_time}_earlystopping-GA_th{threshold}_typecheckFalse')
# Load the cached result instead of recomputing it: tag-frequency tables plus the
# extended ("_ext") node covers and labels for all four splits.
res_ = torch.load('/mounts/work/ayyoob/results/gnn_postag/data/11langs-nopersian_posfeatFalsealltargets_transformerFalse6layresresidualFalse_trainWEFalse_maskLangTrue_epoch1__20220306-170230_earlystopping-GA_th95_typecheckFalse')
tag_frequencies, tag_frequencies_target, train_pos_node_cover_ext, train_pos_labels_ext, grc_pos_node_cover_ext, grc_pos_labels_ext, heb_pos_node_cover_ext, heb_pos_labels_ext, blinker_pos_node_cover_ext, blinker_pos_labels_ext = res_

# Compare the sum of the extended blinker labels with the sum of the original ones.
print(torch.sum(blinker_pos_labels_ext))
print(torch.sum(blinker_pos_labels))
print(torch.sum(blinker_pos_labels_ext) - torch.sum(blinker_pos_labels))
# NOTE(review): 273258 is an unexplained constant — presumably a total node/token
# count used to express the gain as a fraction; confirm and name it.
print((torch.sum(blinker_pos_labels_ext) - torch.sum(blinker_pos_labels))/273258) 

# Split the combined table into source and target parts. With keep_only_type_tags
# commented out, the re-addition below reproduces the loaded tag_frequencies.
tag_frequencies_source = tag_frequencies - tag_frequencies_target
# print(1, torch.sum(tag_frequencies_target))
# posutil.keep_only_type_tags(tag_frequencies_target)
# print(1, torch.sum(tag_frequencies_target))
# Per-row sum of the target table, computed on the CPU.
word_frequencies_target = torch.sum(tag_frequencies_target.to(torch.device('cpu')), dim=1)
tag_frequencies = tag_frequencies_source + tag_frequencies_target
# Work on a detached copy so tag_frequencies itself stays unmodified.
tag_frequencies_copy = tag_frequencies.detach().clone()

# Rows whose target frequency is strictly between 0.1 and 3 get the same tiny value
# in every column, which becomes a uniform distribution after the normalization below.
tag_frequencies_copy[torch.logical_and(word_frequencies_target>0.1, word_frequencies_target<3), :] = 0.0000001

# Some training examples are given a uniform prior so that the model does not learn
# to always return one of the most frequent tags.
# The mask starts all True; a random 70% of the row indices is then switched off,
# leaving about 30% of the rows selected. No random seed is set in this cell.
uniform_noise = torch.BoolTensor(tag_frequencies.size(0))
uniform_noise[:] = True
shuffle_tensor = torch.randperm(tag_frequencies.size(0))[:int(tag_frequencies.size(0)*0.7)]
uniform_noise[shuffle_tensor] = False
# Only selected rows with target frequency below 0.1 are flattened; a row whose
# frequency equals 0.1 exactly matches neither this mask nor the one above.
tag_frequencies_copy[torch.logical_and(uniform_noise, word_frequencies_target < 0.1), :] = 0.0000001

# Row-normalize: divide every row by its own sum.
# NOTE(review): a row that sums to 0 would divide by zero here — confirm none exist.
sm = torch.sum(tag_frequencies_copy, dim=1)
normalized_tag_frequencies = (tag_frequencies_copy.transpose(1,0) / sm).transpose(1,0)

tensor(116768.)
tensor(82326.)
tensor(34442.)
tensor(0.1260)


In [35]:
from tqdm import tqdm
from torchvision import models
from torchsummary import summary

def create_model(train_gnn_dataset, grc_gnn_dataset, heb_gnn_dataset, blinker_gnn_dataset, test_gnn_dataset,
                tag_frequencies=False, use_transformers=False, train_word_embedding=False, mask_language=True, residual_connection = False,
                 params=''):
    """Build, train, checkpoint and evaluate a POS-tagging GAE model.

    The result is published through the module-level globals ``model``,
    ``criterion``, ``optimizer``, ``early_stopping`` and ``start_time``;
    nothing is returned.

    Args:
        train_gnn_dataset, grc_gnn_dataset, heb_gnn_dataset, blinker_gnn_dataset:
            datasets wrapped in shuffled, batch_size=1 loaders. The blinker loader
            is built, but its training call is commented out.
        test_gnn_dataset: dataset wrapped in an unshuffled loader; passed to every
            ``train`` call and to the final ``test`` call.
        tag_frequencies: when true, a frozen ``tag_priors`` MappingFeature is appended
            to the features and the encoder is told it has a tag-frequency feature.
            This boolean shadows any module-level variable of the same name.
        use_transformers: selects POSDecoderTransformer (placed on ``dev2``) instead
            of POSDecoder.
        train_word_embedding: when false, ``features[9]`` is frozen.
        mask_language: forwarded to ``train`` / ``test``.
        residual_connection: forwarded to the transformer decoder only.
        params: free-form string embedded in the saved model's file name.
    """
    global model, criterion, optimizer, early_stopping, start_time

    # Shallow copy: the list is new, but the feature objects are shared with
    # train_dataset.features, so the .freeze assignment below is visible there too.
    features = train_dataset.features[:]
    #features.append(afeatures.PassFeature(name='posTAG', dim=len(postag_map)))
    if tag_frequencies:
        features.append(afeatures.MappingFeature(len(postag_map), 'tag_priors', freeze=True))
    # presumably index 9 is the word-embedding feature — TODO confirm against train_dataset.features
    features[9].freeze = not train_word_embedding
    
    train_data_loader = DataLoader(train_gnn_dataset, batch_size=1, shuffle=True)
    grc_data_loader = DataLoader(grc_gnn_dataset, batch_size=1, shuffle=True)
    heb_data_loader = DataLoader(heb_gnn_dataset, batch_size=1, shuffle=True)
    blinker_data_loader = DataLoader(blinker_gnn_dataset, batch_size=1, shuffle=True)
    test_data_loader = DataLoader(test_gnn_dataset, batch_size=1, shuffle=False)

    clean_memory()
    drop_out = 0
    n_head = 1
    # Encoder input width = sum of the output widths of all features.
    in_dim = sum(t.out_dim for t in features)
    # The timestamp identifies this run in the checkpoint and model file names.
    start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    early_stopping = EarlyStopping(verbose=True, path=f'/mounts/work/ayyoob/models/gnn/checkpoint/postagging/check_point_{start_time}.pt', patience=5, delta=4)
    channels = 512
    decoder_in_dim = n_head * channels 

    print('len features', len(features), f'start time: {start_time}')

    # model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/gnn_512_flggll_word_halfTrain_nofeatlinear_encoderlineear_decoderonelayer20210910-235352-.pickle')
    if use_transformers:
        decoder = POSDecoderTransformer(decoder_in_dim, decoder_in_dim*2, len(postag_map), residual_connection, drop_out=drop_out).to(dev2)
    else:
        decoder = POSDecoder(decoder_in_dim, decoder_in_dim*2, len(postag_map))
        
    # .to(dev) moves the whole GAE, decoder included, to dev ...
    model = pyg_nn.GAE(Encoder(in_dim, channels, features, n_head, has_tagfreq_feature=tag_frequencies), decoder).to(dev)


    # ... so the transformer decoder has to be moved back to dev2 afterwards.
    if use_transformers:
        decoder.to(dev2)

    # model.to(dev)
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

    torch.set_printoptions(edgeitems=5)
    # The label mentions "conv1", but only two counts (model, decoder) are printed.
    print("model params - decoder params - conv1", sum(p.numel() for p in model.parameters()), sum(p.numel() for p in decoder.parameters()))

    # range(1, 2) yields a single epoch.
    for epoch in range(1, 2):
        print(f"\n----------------epoch {epoch} ---------------")
        
        # Two passes over the training split (the second only if early stopping has
        # not triggered), followed by one pass each over the grc and heb splits.
        train(epoch, train_data_loader, mask_language, test_data_loader)
        if not early_stopping.early_stop:
            train(epoch, train_data_loader, mask_language, test_data_loader)
        if not early_stopping.early_stop:
            train(epoch, grc_data_loader, mask_language, test_data_loader)
        if not early_stopping.early_stop:
            train(epoch, heb_data_loader, mask_language, test_data_loader)
        #train(epoch, blinker_data_loader, mask_language, test_data_loader)

        # NOTE(review): this tests the EarlyStopping object itself, not its
        # .early_stop flag, so the checkpoint is presumably reloaded on every run —
        # confirm this is the intent, and that the checkpoint file always exists here.
        if early_stopping:
                model.load_state_dict(torch.load(f'/mounts/work/ayyoob/models/gnn/checkpoint/postagging/check_point_{start_time}.pt'))

        save_model(model, f'{len(pos_lang_list)}langs-nopersian_posfeat{tag_frequencies}alltargets_transformer{use_transformers}6layresresidual{residual_connection}_trainWE{train_word_embedding}_maskLang{mask_language}_epoch{epoch}_{params}_{start_time}_earlystopping-GA')
        # Evaluate on the test loader, then on the four target-language loaders.
        test(epoch, test_data_loader, mask_language) 
        test_mostfreq(yor_data_loader_heb, True, yor_gold_mostfreq_tag, yor_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
        test_mostfreq(tam_data_loader_grc, True, tam_gold_mostfreq_tag, tam_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
        test_mostfreq(arb_data_loader_grc, True, arb_gold_mostfreq_tag, arb_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))
        test_mostfreq(por_data_loader_grc, True, por_gold_mostfreq_tag, por_gold_mostfreq_index, (w2v_model.wv.vectors.shape[0], len(postag_map)))

        clean_memory()


In [36]:
# Train on the extended (source + confidently tagged target) data, then export
# predicted tags for the four evaluation editions.
#
# Run suffix embedded in the model and output file names. It must be an f-string:
# without the `f` prefix the literal text "{threshold}" ends up in every file name.
# NOTE(review): this rebinds `params`, which an earlier cell set to an
# argparse.Namespace — consider a dedicated name if that namespace is still needed.
params = f'traintarget{threshold}_typecheckFalse'
gnn_dataset_train_pos_ext, gnn_dataset_grc_pos_ext, gnn_dataset_heb_pos_ext, gnn_dataset_blinker_pos_ext = create_me_a_gnn_dataset_you_stupid([ train_pos_node_cover_ext, grc_pos_node_cover_ext, heb_pos_node_cover_ext, blinker_pos_node_cover_ext]
   , [train_pos_labels_ext, grc_pos_labels_ext, heb_pos_labels_ext, blinker_pos_labels_ext], group_size=1024)
# create_model publishes the trained model and start_time as globals, which the
# generate_target_lang_tags calls below rely on.
create_model(gnn_dataset_train_pos_ext, gnn_dataset_grc_pos_ext, gnn_dataset_heb_pos_ext, gnn_dataset_blinker_pos_ext, gnn_dataset_blinker_pos_bigbatch,
   train_word_embedding=False, mask_language=True, use_transformers=True, tag_frequencies=True, params=params)

#create_model(gnn_dataset_train_pos, gnn_dataset_grc_pos, gnn_dataset_heb_pos, gnn_dataset_blinker_pos, gnn_dataset_blinker_pos_bigbatch,
#   train_word_embedding=False, mask_language=True, use_transformers=False, tag_frequencies=True)
#model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatTrue_transformerTrue6layresresidual_trainWEFalse_maskLangTrue_epoch1_trainYoruba_20220218-223347.pickle')
#test(1, blinker_data_loader_bigbatch, True)
#importlib.reload(posutil)


# NOTE(review): the Yoruba name below reads "posfeatTrueealltargets" (double "e") and
# ends with an extra "_", unlike the three calls after it. Left unchanged because it
# determines the output file name — confirm before normalizing.
posutil.generate_target_lang_tags(model, 'yor-x-bible-2010', f'{len(pos_lang_list)}langs-nopersian_posfeatTrueealltargets_transformerTrue6layresresidualFalse_trainWEFalse_epoch1_{params}_{start_time}_earlystopping-GA_', True, 
        train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
        DataEncoder)

posutil.generate_target_lang_tags(model, 'por-x-bible-versaointernacional',f'{len(pos_lang_list)}langs-nopersian_posfeatTruealltargets_transformerTrue6layresresidualFalse_trainWEFalse_epoch1_{params}_{start_time}_earlystopping-GA', True, 
        train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
        DataEncoder)
 
posutil.generate_target_lang_tags(model, 'tam-x-bible-newworld', f'{len(pos_lang_list)}langs-nopersian_posfeatTruealltargets_transformerTrue6layresresidualFalse_trainWEFalse_epoch1_{params}_{start_time}_earlystopping-GA', True, 
        train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
        DataEncoder)

# Only the result of the last (Arabic) call is kept, in `rres`.
rres = posutil.generate_target_lang_tags(model, 'arb-x-bible', f'{len(pos_lang_list)}langs-nopersian_posfeatTruealltargets_transformerTrue6layresresidualFalse_trainWEFalse_epoch1_{params}_{start_time}_earlystopping-GA', True, 
        train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
        DataEncoder)

len features 11 start time: 20220307-140939
model params - decoder params - conv1 291580899 15262225

----------------epoch 1 ---------------


  0%|          | 0/24143 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [37]:
# Drop the references to the model and decoder before calling clean_memory().
model = None
decoder = None
clean_memory()
# NOTE(review): engtest_data_loader is only assigned in commented-out code in this
# notebook, and the recorded output of this cell is a NameError. `verse` is first
# assigned in a later cell, so this loop also depends on out-of-order execution.
# Scan for the item whose first 'verse' entry equals `verse`; `item` stays bound to it.
for item in engtest_data_loader:
    if item['verse'][0] == verse:
        print('found')
        break

NameError: name 'engtest_data_loader' is not defined

In [None]:
# Inspect the fields of `item`, the loop variable left over from the previous cell's search.
item.keys()

In [None]:
# Take the first verse key of eng_gen_data (not defined in the visible part of the notebook).
verse = list(eng_gen_data.keys())[0]
# Invert postag_map so that its values become the lookup keys.
number_to_tag_map = {postag_map[i]:i for i in postag_map}
print(verse)
# Each entry of eng_gen_data[verse] is indexed as (entry[0], entry[1]); entries are
# sorted by entry[0], and entry[1] is translated through number_to_tag_map.
print({item[0]:number_to_tag_map[item[1]] for item in sorted(eng_gen_data[verse], key=lambda x: x[0])})

In [None]:
# model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatTrue_transformerTrue_trainWETrue_maskLangFalse_epoch3_20220218-101130.pickle')
# test(1, engtest_data_loader, mask_language=False) 

# Re-import postag_utils so that edits to the module are picked up without a kernel restart.
importlib.reload(posutil)
def get_language_node_cover(all_node_cover, target_edition, dataset):
    """Restrict a node cover to the nodes belonging to one edition.

    Args:
        all_node_cover: mapping ``verse -> list of node ids``.
        target_edition: edition whose nodes are kept.
        dataset: object exposing ``nodes_map`` as
            ``edition -> verse -> {key: node id}``.

    Returns:
        A ``defaultdict(list)`` mapping each verse to those of its covered nodes
        that are values of ``dataset.nodes_map[target_edition][verse]``, in
        their original order. Verses without any match get no key. The result
        is empty when the edition is absent from the node map.
    """
    edition_cover = collections.defaultdict(list)
    nodes_by_edition = dataset.nodes_map

    if target_edition not in nodes_by_edition:
        return edition_cover

    edition_verses = nodes_by_edition[target_edition]
    for verse, covered_nodes in all_node_cover.items():
        if verse not in edition_verses:
            continue
        edition_nodes = list(edition_verses[verse].values())
        kept = [node for node in covered_nodes if node in edition_nodes]
        # Only create the key when something matched.
        if kept:
            edition_cover[verse].extend(kept)

    return edition_cover
    

def finetune_and_generate_for_target_lang(model_path, target_edition):
    """Fine-tune a saved tagger on one target edition and export tags after each stage.

    Args:
        model_path: path of the pickled model; it is reloaded from disk before
            each fine-tuning stage.
        target_edition: edition whose nodes are used for fine-tuning and for
            which tags are generated.

    Stages, each ending in ``posutil.generate_target_lang_tags``:
        * ``finetune1000``: its reload/train lines are commented out, so tags are
          generated with whatever the global ``model`` currently holds.
        * ``finetune10000``: reload, then train with ``max_batches=10000``.
        * ``finetuneALL``: reload, then train on the train, grc, heb and blinker
          target loaders.

    The globals ``model``, ``criterion`` and ``optimizer`` are overwritten;
    nothing is returned.
    """
    global model, criterion, optimizer

    # Keep only the covered nodes that belong to the target edition, per split.
    target_train_node_cover = get_language_node_cover(train_pos_node_cover_ext, target_edition, train_dataset)
    target_grc_node_cover = get_language_node_cover(grc_pos_node_cover_ext, target_edition, grc_test_dataset)
    target_heb_node_cover = get_language_node_cover(heb_pos_node_cover_ext, target_edition, heb_test_dataset)
    target_blinker_node_cover = get_language_node_cover(blinker_pos_node_cover_ext, target_edition, blinker_test_dataset)

    train_ds, grc_ds, heb_ds, blinker_ds = create_me_a_gnn_dataset_you_stupid(
            [target_train_node_cover, target_grc_node_cover, target_heb_node_cover, target_blinker_node_cover], 
            [train_pos_labels_ext, grc_pos_labels_ext,  heb_pos_labels_ext, blinker_pos_labels_ext], editions=[target_edition])

    train_data_loader = DataLoader(train_ds, shuffle=True)
    grc_data_loader = DataLoader(grc_ds, shuffle=True)
    heb_data_loader = DataLoader(heb_ds, shuffle=True)
    blinker_data_loader = DataLoader(blinker_ds, shuffle=True)

    # Disabled stage: fine-tune on 100 batches.
    # model = torch.load(model_path)
    # criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    # train(1, train_data_loader, mask_language=False, test_data_loader=blinker_data_loader, max_batches=100)
    # posutil.generate_target_lang_tags(model, target_edition, 'posfeatTrue_transformerTrue6layerresidual_trainWEFalse_finetune100', False, 
    #         train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
    #         DataEncoder)
    # model = None
    # clean_memory()

    # NOTE(review): the reload/train lines of the 1000-batch stage are commented out,
    # so the tags labeled "finetune1000" come from the current global `model` without
    # any fine-tuning here — confirm whether these lines should be re-enabled.
    # model = torch.load(model_path)
    # criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    # train(1, train_data_loader, mask_language=False, test_data_loader=blinker_data_loader, max_batches=1000)
    posutil.generate_target_lang_tags(model, target_edition, 'posfeatTrue_transformerTrue6layerresidual_trainWEFalse_finetune1000', False, 
            train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
            DataEncoder)
    # Release the model before the next stage loads a fresh copy. clean_memory must
    # actually be called: a bare `clean_memory` is a no-op expression.
    model = None
    clean_memory()

    # Stage: fine-tune on at most 10000 batches of the target training data.
    model = torch.load(model_path)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    train(1, train_data_loader, mask_language=False, test_data_loader=blinker_data_loader, max_batches=10000)
    posutil.generate_target_lang_tags(model, target_edition, 'posfeatTrue_transformerTrue6layerresidual_trainWEFalse_finetune10000', False, 
            train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
            DataEncoder)
    model = None
    clean_memory()

    # Stage: fine-tune on all target data of every split (blinker included, although
    # blinker is also the loader passed as test_data_loader).
    model = torch.load(model_path)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    train(1, train_data_loader, mask_language=False, test_data_loader=blinker_data_loader)
    train(1, grc_data_loader, mask_language=False, test_data_loader=blinker_data_loader)
    train(1, heb_data_loader, mask_language=False, test_data_loader=blinker_data_loader)
    train(1, blinker_data_loader, mask_language=False, test_data_loader=blinker_data_loader)
    posutil.generate_target_lang_tags(model, target_edition, 'posfeatTrue_transformerTrue6layerresidual_trainWEFalse_finetuneALL', False, 
            train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
            DataEncoder)


def generate_target_lang_tags_all_models(target_langs):
    """Load the baseline checkpoint, evaluate it once, then tag every target edition.

    The loaded network is stored in the notebook-level ``model`` global, so
    later cells see the freshly loaded model.

    Args:
        target_langs: iterable of edition identifiers; each one is forwarded
            to ``posutil.generate_target_lang_tags``.

    NOTE(review): other call sites in this notebook pass the model, the
    datasets, the data loaders and ``DataEncoder`` explicitly to
    ``posutil.generate_target_lang_tags``; confirm the helper still accepts
    the short three-argument form used here.
    """
    global model

    checkpoint_path = (
        '/mounts/work/ayyoob/models/gnn/checkpoint/postagging/'
        'pos_tagging_posfeatFalse_transformerFalse_trainWEFalse_maskLangTrue_20220209-201345.pickle'
    )
    run_name = "posfeatFalse_transformerFalse_trainWEFalse"

    # Deserialize onto the CPU first and only then move to the working device.
    cpu_model = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model = cpu_model.to(dev)

    # One evaluation pass on the big-batch blinker loader before tagging
    # (third argument is presumably mask_language -- TODO confirm).
    test(0, blinker_data_loader_bigbatch, True)

    for lang in target_langs:
        posutil.generate_target_lang_tags(lang, run_name, True)

# Driver for this cell: fine-tune the given checkpoint and generate tags for
# the Yoruba edition. The batch run over several editions is kept below,
# disabled.
# NOTE(review): "transformerTru6Lresidual" in the path is presumably the real
# on-disk file name -- verify before "correcting" the spelling.
# posutil.generate_target_lang_tags_all_models(['yor-x-bible-2010', 'tam-x-bible-newworld', 'fin-x-bible-helfi'])
finetune_and_generate_for_target_lang('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatTrue_transformerTru6Lresidual_trainWETrue_maskLangFalse_epoch3_20220218-101130.pickle', 'yor-x-bible-2010')

In [None]:
# Debug scratch cell: inspect the loop position / batch / verse captured in
# the names `sag`, `khar` and `gav`.
# NOTE(review): those three names are not defined in this cell; presumably
# they are debug globals set elsewhere in the notebook -- this cell fails on
# a fresh kernel unless that code ran first. TODO confirm.
i = sag
batch = khar
verse = gav
print(i, verse)

# All verse ids known to the dataset.
keys = list(gnn_dataset.verse_info.keys())

# Bare lookup: because it is not the last statement of the cell its value is
# not displayed; it only surfaces an error if `verse` is missing.
gnn_dataset.verse_info[verse]
print(keys)
#save_model(model, 'freeze-embedding_noLang')

In [None]:
#initial_model = model.to('cpu')
#model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatTrue_transformerFalse_trainWEFalse_maskLangTrue_20220210-173912.pickle')
#torch.cuda.set_device(0)
#model.to(dev)
#epoch = 0
#test_data_loader = DataLoader(gnn_dataset_blinker_pos_bigbatch, batch_size=1, shuffle=False)
#test(epoch, test_data_loader, mask_language=True)
##yoruba_postags = generate_target_lang_tags('yor-x-bible-2010', f"posfeatFalse_transformerFalse_trainWEFalse", True)


In [None]:
##importlib.reload(afeatures)
#initial_model = model
#torch.cuda.set_device(1)
#create_model(gnn_dataset_train_pos, gnn_dataset_grc_pos, gnn_dataset_heb_pos, gnn_dataset_blinker_pos, gnn_dataset_blinker_pos_bigbatch, tag_frequencies=True)
## yoruba_postags = generate_target_lang_tags('yor-x-bible-2010', f"posfeatTrue_transformerFalse_trainWEFalse", True)

In [None]:
# Build the "_ext" GNN datasets from the extended node covers / labels and
# train a model on them (blinker big-batch dataset passed as the fifth input).
# NOTE(review): CUDA device index 4 is hard-coded.
torch.cuda.set_device(4)

ext_node_covers = [
    train_pos_node_cover_ext,
    grc_pos_node_cover_ext,
    heb_pos_node_cover_ext,
    blinker_pos_node_cover_ext,
]
ext_labels = [
    train_pos_labels_ext,
    grc_pos_labels_ext,
    heb_pos_labels_ext,
    blinker_pos_labels_ext,
]
(
    gnn_ds_train_pos_ext,
    gnn_ds_grc_pos_ext,
    gnn_ds_heb_pos_ext,
    gnn_ds_blinker_pos_ext,
) = create_me_a_gnn_dataset_you_stupid(ext_node_covers, ext_labels)

create_model(
    gnn_ds_train_pos_ext,
    gnn_ds_grc_pos_ext,
    gnn_ds_heb_pos_ext,
    gnn_ds_blinker_pos_ext,
    gnn_dataset_blinker_pos_bigbatch,
    tag_frequencies=True,
    use_transformers=False,
    train_word_embedding=True,
    mask_language=False,
)

# yoruba_postags = generate_target_lang_tags('yor-x-bible-2010', f"posfeatTrue_transformerTrue_trainWETrue", False)

In [None]:
# Load a saved checkpoint, evaluate it on the English-only blinker set, then
# generate tags for one target edition.
# NOTE(review): device 4 is selected only around the load and device 0 right
# after; presumably this matches the device the checkpoint was saved from --
# TODO confirm. torch.load unpickles arbitrary objects, so only load trusted
# checkpoint files.
torch.cuda.set_device(4)
model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatTrue_transformerTrue_trainWETrue_maskLangFalse_epoch1_20220214-190824.pickle')
torch.cuda.set_device(0)
# Return value is ignored; this relies on .to() moving the module in place.
model.to(dev)
epoch = 0
# One graph per batch, fixed order.
test_data_loader = DataLoader(gnn_dataset_blinker_pos_onlyeng, batch_size=1, shuffle=False)
test(epoch, test_data_loader, mask_language=False)
# NOTE(review): the result is stored in `yoruba_postags` although the edition
# requested is 'fin-x-bible-helfi'. Also, elsewhere in this notebook the helper
# is called as posutil.generate_target_lang_tags with many more arguments;
# confirm this unqualified three-argument form is still defined.
yoruba_postags = generate_target_lang_tags('fin-x-bible-helfi', f"posfeatTrue_transformerTrue_trainWETrue", False)


In [None]:

#normalized_gold_frequencies, gold_frequencies_all = get_words_tag_frequence(model, 2354770, len(postag_map), english_data_loader, from_gold_data=True)

#gold_frequencies_all = gold_frequencies

# Compare, per word type, the majority gold tag with the majority predicted
# tag, overall and split by word frequency.
# Inputs (defined in earlier cells): `gold_frequencies_all` and
# `tag_frequencies_english`, both indexed [word, tag].

# Total count per word type = sum over the tag axis.
word_frequencies = torch.sum(gold_frequencies_all, dim=1)

# Keep only word types that actually occur in the gold data.
subjectword_indices = word_frequencies > 0.1
print(word_frequencies.shape)
gold_frequencies = gold_frequencies_all[subjectword_indices, :]
predicted_frequencies = tag_frequencies_english[subjectword_indices, :]
wordtype_frequencies = word_frequencies[subjectword_indices]
print(gold_frequencies.shape)

# Majority tag per word type. `.indices` is used instead of unpacking into
# `_` so that IPython's last-output variable is not overwritten.
gold_tags = torch.max(gold_frequencies, dim=1).indices
predicted_tags = torch.max(predicted_frequencies, dim=1).indices

# Order word types from most to least frequent.
sorted_wordtype_frequencies, sort_pattern = torch.sort(wordtype_frequencies, descending=True)

sorted_gold_tags = gold_tags[sort_pattern]
sorted_predicted_tags = predicted_tags[sort_pattern]
# NOTE(review): despite the name (and the printed labels) this is HALF of the
# word types, because the divisor is 2.0. Name and labels are kept as-is;
# confirm which split is intended.
quarter_size = int(sorted_gold_tags.size(0)/2.0)

print('quarter size', quarter_size)
print("general accuracy", torch.sum(gold_tags == predicted_tags)/predicted_tags.size(0))
print('first quarter accuracy', torch.sum(sorted_gold_tags[:quarter_size] == sorted_predicted_tags[:quarter_size])/quarter_size)
print('last part accuracy', torch.sum(sorted_gold_tags[1*quarter_size:] == sorted_predicted_tags[1*quarter_size:])/sorted_predicted_tags[1*quarter_size:].size(0))

print('total token count', torch.sum(wordtype_frequencies))
# Fix: sum over the frequency-sorted, filtered word types -- the same set the
# "first quarter" accuracy above is computed on. The previous code sliced the
# unsorted, unfiltered `word_frequencies`, i.e. an unrelated set of words.
print('first quarter words token count', torch.sum(sorted_wordtype_frequencies[:quarter_size]))

# Frequency at selected positions of the sorted list. Positions beyond the
# number of retained word types are skipped instead of raising IndexError.
# The label for position 736 previously duplicated the '100st' label.
n_wordtypes = sorted_wordtype_frequencies.size(0)
for label, position in (
    ('1st frequency', 0),
    ('10st frequency', 10),
    ('100st frequency', 100),
    ('736th frequency', 736),
    ('1000st frequency', 1000),
    ('10000st frequency', 10000),
):
    if position < n_wordtypes:
        print(label, sorted_wordtype_frequencies[position])


In [None]:
#normalized_tag_frequencies = torch.softmax(tag_frequencies_copy, dim=1)

# Scale every row of `tag_frequencies_copy` by its own row total.
# assumes tag_frequencies_copy is 2-D ([word, tag]) -- TODO confirm.
# Rows whose total is zero are not special-cased.
sm = tag_frequencies_copy.sum(dim=1)
normalized_tag_frequencies = tag_frequencies_copy / sm.unsqueeze(1)

In [None]:
# Drop the notebook-level references to the model and decoder, then ask both
# the Python garbage collector and the CUDA caching allocator to release
# whatever memory they were holding.
#1/0

decoder = model = None

gc.collect()
torch.cuda.empty_cache()

In [None]:

# Inspection cell: dump the attributes of every feature descriptor attached
# to the blinker test dataset, plus the attributes of one embedding inside
# the model's feature encoder. The commented-out lines are earlier
# experiments that patched / re-created the feature list and saved it.

# Slice copy, so the dataset's own `features` object is not rebound here.
features = blinker_test_dataset.features[:]
#features_edge = train_dataset.features_edge[:]
# NOTE(review): `pprint` is imported but not called anywhere in this cell.
from pprint import pprint
#print('indim',in_dim)
#features[-1].out_dim = 50
for i in features:
    #if i.type==3:
    #    i.out_dim=4
    print(vars(i))
# NOTE(review): layer index 10 is hard-coded; presumably it is the layer that
# owns an `emb` attribute for the current feature list -- TODO confirm.
print(vars(model.encoder.feature_encoder.layers[10].emb))
# print(model.encoder.feature_encoder.layers[10].emb)
#sum(p.out_dim for p in features)
#train_dataset.features.pop()
#train_dataset.features[0] = afeatures.OneHotFeature(20, 83, 'editf')
#train_dataset.features[1] = afeatures.OneHotFeature(32, 150, 'position')
#train_dataset.features[2] = afeatures.FloatFeature(4, 'degree_centrality')
#train_dataset.features[3] = afeatures.FloatFeature(4, 'closeness_centrality')
#train_dataset.features[4] = afeatures.FloatFeature(4, 'betweenness_centrality')
#train_dataset.features[5] = afeatures.FloatFeature(4, 'load_centrality')
#train_dataset.features[6] = afeatures.FloatFeature(4, 'harmonic_centrality')
#train_dataset.features[7] = afeatures.OneHotFeature(32, 250, 'greedy_modularity_community')
##train_dataset.features.append(afeatures.MappingFeature(100, 'word'))
#torch.save(train_dataset, "/mounts/work/ayyoob/models/gnn/dataset_helfi_train_community_word.pickle")
#torch.save(train_dataset.features[-3], "./features.tmp")

In [None]:
# Flag every edition that has more than one verse with fewer than two mapped
# nodes. Scanning of an edition stops as soon as its second such verse is
# found.
nodes_map = train_dataset.nodes_map
bad_edition_files = []
for edit, edition_verses in nodes_map.items():
    bad_count = 0
    for verse, verse_nodes in edition_verses.items():
        if len(verse_nodes) >= 2:
            continue
        bad_count += 1
        if bad_count > 1:
            bad_edition_files.append(edit)
            break
print(bad_edition_files)

In [None]:
# Remove from the training graph every edge pair that touches a node of an
# edition listed in `bad_edition_files`.
# NOTE(review): `all_japanese_nodes` holds the nodes of ALL flagged editions,
# whatever their language; the name is kept because other cells may read it.
all_japanese_nodes = set()
nodes_map = train_dataset.nodes_map

for bad_editionf in bad_edition_files:
    for verse, verse_nodes in nodes_map[bad_editionf].items():
        all_japanese_nodes.update(verse_nodes.values())

print(" all japansese nodes: ", len(all_japanese_nodes))
edge_index = train_dataset.edge_index.to('cpu')

# Row 0 of the edge index as plain Python numbers, read once up front.
source_nodes = edge_index[0].tolist()

# Columns are visited two at a time and kept or dropped together. Only row 0
# is inspected for both columns -- presumably column i+1 is the reverse of
# column i, so its row-0 entry is the other endpoint. TODO confirm.
remaining_edges_index = []
for i in tqdm(range(0, edge_index.shape[1], 2)):
    touches_bad_edition = (
        source_nodes[i] in all_japanese_nodes
        or source_nodes[i + 1] in all_japanese_nodes
    )
    if not touches_bad_edition:
        remaining_edges_index += [i, i + 1]

print('original total edges count', edge_index.shape)
print('remaining edge count', len(remaining_edges_index))
train_dataset.edge_index = edge_index[:, remaining_edges_index]
train_dataset.edge_index.shape
