In [2]:
import torch, sys
sys.path.insert(0, '../')
from my_utils import gpu_utils
import importlib, gc
from my_utils.alignment_features import *
import my_utils.alignment_features as afeatures
importlib.reload(afeatures)
import gnn_utils.graph_utils as gutils
import postag_utils as posutil



In [3]:
dev = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
dev2 = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')


In [4]:

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

import time
from datetime import datetime

import networkx as nx
import numpy as np
import torch
import torch.optim as optim

from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader

import torch_geometric.transforms as T

from tensorboardX import SummaryWriter
from sklearn.manifold import TSNE
# import matplotlib.pyplot as plt




In [5]:
from my_utils import align_utils as autils, utils
import argparse
from multiprocessing import Pool
import random

# set random seed
config_file = "/mounts/Users/student/ayyoob/Dokumente/code/pbc-ui-demo/config_pbc.ini"
utils.setup(config_file)

params = argparse.Namespace()


params.gold_file = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-grc-gold-alignments_train.txt"
pros, surs = autils.load_gold(params.gold_file)
all_verses = list(pros.keys())
params.gold_file = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-heb-gold-alignments_train.txt"
pros, surs = autils.load_gold(params.gold_file)
all_verses.extend(list(pros.keys()))
all_verses = list(set(all_verses))
print(len(all_verses))

params.editions_file =  "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi_lang_list.txt"
editions, langs = autils.load_simalign_editions(params.editions_file)
current_editions = [editions[lang] for lang in langs]

def get_pruned_verse_alignments(args):
    verse, current_editions = args
    
    verse_aligns_inter = autils.get_verse_alignments(verse)
    verse_aligns_gdfa = autils.get_verse_alignments(verse, gdfa=True)

    autils.prune_non_necessary_alignments(verse_aligns_inter, current_editions)
    autils.prune_non_necessary_alignments(verse_aligns_gdfa, current_editions)

    gc.collect()
    return verse_aligns_inter, verse_aligns_gdfa
    

args = []
for i,verse in enumerate(all_verses):
    args.append((verse, current_editions[:]))


24159


In [6]:
#importlib.reload(afeatures)

class Encoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, features, n_head = 2, has_tagfreq_feature=False,):
        super(Encoder, self).__init__()
        self.conv1 = pyg_nn.GATConv(in_channels, 2*out_channels, heads= n_head)
        self.conv2 = pyg_nn.GATConv(2 * n_head *  out_channels , out_channels, heads= 1)
        #self.fin_lin = nn.Linear(out_channels, out_channels)
        
        if has_tagfreq_feature:
            self.feature_encoder = afeatures.FeatureEncoding(features, [normalized_tag_frequencies, word_vectors])
            #self.feature_encoder = afeatures.FeatureEncoding(features, [normalized_tag_frequencies,train_pos_labels, word_vectors])
        else:
            self.feature_encoder = afeatures.FeatureEncoding(features, [word_vectors])
            #self.feature_encoder = afeatures.FeatureEncoding(features, [train_pos_labels, word_vectors])

    def forward(self, x, edge_index):
        encoded = self.feature_encoder(x, dev)
        x = F.elu(self.conv1(encoded, edge_index, ))
        x = F.elu(self.conv2(x, edge_index))
        #return F.relu(self.fin_lin(x)), encoded
        return x, encoded

In [7]:
def freeze_encoders_embedding(encoder):
    for i,ft in enumerate(encoder.feature_types):
        if ft.type == MAPPING:
            print('doing it')
            encoder.layers[i] = afeatures.MappingEncoding(encoder.layers[i].emb.weight, freeze=True)

In [30]:
def clean_memory():
    gc.collect()
    with torch.no_grad():
        torch.cuda.empty_cache()

class DataEncoder():

    def __init__(self, data_loader, model, mask_language):
        self.data_loader = data_loader
        self.model = model
        self.mask_language = mask_language
    
    def __iter__(self):
        for i,batch in enumerate(tqdm(self.data_loader)):
            try:
                optimizer.zero_grad()
            except NameError:
                # it is ok since if we call test before train, optimizer is not defined
                pass

            x = batch['x'][0].to(dev)  # initial features (not encoded)
            edge_index = batch['edge_index'][0].to(dev) 
            verse = batch['verse'][0]

            if verse in masked_verses:
                continue

            try:
                if self.mask_language:
                    x[:, 0] = 0
                z, encoded = self.model.encode(x, edge_index) # Z will be the output of the GNN
                batch['encoded'] = encoded
            except Exception as e:
                global sag, khar, gav
                sag, khar, gav =  (i, batch, verse)
                print(e)
                1/0
            
            yield z, verse, i, batch

def train(epoch, data_loader, mask_language, test_data_loader, max_batches=999999999):
    global optimizer
    total_loss = 0
    model.train()
    loss_multi_round = 0

    data_encoder = DataEncoder(data_loader, model, mask_language)

    for z, verse, i, batch in data_encoder:
        
        target = batch['pos_classes'][0].to(dev)
        _, labels = torch.max(target, 1)
        
        index = batch['pos_index'][0].to(dev)

        preds = model.decoder(z, index, batch)

        # print(preds.shape, labels.shape)
        loss = criterion(preds, labels)
        loss = loss * target.shape[0] # TODO check if this is necessary
        loss.backward()
        total_loss += loss.item()

        if (i+1) % 1 == 0: # Gradient accumulation
            optimizer.step()
            

        if i % 500 == 499:
            print(f"loss: {total_loss}")
            total_loss = 0
            test(epoch, test_data_loader, mask_language)
            model.train()
            clean_memory()

        if i == max_batches:
            break
    
    print(f"total train loss: {total_loss}")


In [9]:
class POSDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, n_class, drop_out=0):
        super(POSDecoder, self).__init__()

        self.transfer = nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Dropout(drop_out),
                        nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(drop_out),
                        nn.Linear(hidden_size, n_class))

    def forward(self, z, index, batch=None):
        h = z[index, :]

        res = self.transfer(h)

        return res

class POSDecoderTransformer(nn.Module):
    def __init__(self, input_size, hidden_size, n_class, drop_out=0):
        super(POSDecoderTransformer, self).__init__()

        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=8, dim_feedforward=hidden_size)
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=6)

        self.transfer = nn.Sequential( nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Dropout(drop_out), # TODO check what happens if I remove this.
                        nn.Linear(hidden_size, n_class))

    def forward(self, z_, index, batch_):
        z = z_.to(dev2)

        x = F.pad(batch_['encoded'], (0, z.size(1) - batch_['encoded'].size(1))).to(dev2)


        language_based_nodes = batch_['lang_based_nodes'] # determines which node belongs to which language
        transformer_indices = batch_['transformer_indices'] # the reverse of the prev structure

        sentences = []
        for lang_nodes in language_based_nodes: # we rearrange the nodes into sentences of each language
            tensor = z[lang_nodes, :] + x[lang_nodes, :]
            tensor = F.pad(tensor, (0, 0, 0, 150 - tensor.size(0)))
            sentences.append(tensor)
        
        batch = torch.stack(sentences) # A batch contains all translations of one sentence in all training languages.
        batch = torch.transpose(batch, 0, 1)

        h = self.transformer(batch)
        h = torch.transpose(h, 0, 1)
        h = h[transformer_indices[0], transformer_indices[1], :] # rearrange the nodes back to the order in which we recieved (the order that represents the graph)

        res = self.transfer(h)

        return res.to(dev)

In [10]:
def test(epoch, testloader, mask_language, filter_wordtypes=None):
    print('testing',  epoch)
    model.eval()
    total = 0
    correct = 0

    data_encoder = DataEncoder(testloader, model, mask_language)
    
    with torch.no_grad():
        for z, verse, i, batch in data_encoder:
            
            target = batch['pos_classes'][0].to(dev)
            index = batch['pos_index'][0].to(dev)
            
            if filter_wordtypes != None:
                non_filtered_words = filter_wordtypes[batch['x'][0][:, 9].long()] == 1
                non_filtered_words = non_filtered_words[index]
                index = index[non_filtered_words]

                target = target[non_filtered_words, :]

            preds = model.decoder(z, index, batch)
            
            if preds.size(0) > 0:
                _, predicted = torch.max(preds, 1)
                _, labels = torch.max(target, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'test, epoch: {epoch}, total:{total} ACC: {correct/total}')
    clean_memory()


In [11]:
def majority_voting_test(data_loader1, data_loader2):
    total = 0
    correct = 0
    
    for i,(batch, batch2) in enumerate(tqdm(zip(data_loader1, data_loader2))) :
            
        x = batch['x'][0]
        edge_index = batch['edge_index'][0]
        verse = batch['verse'][0]

        if verse in masked_verses:
            continue

        target = batch['pos_classes'][0]
        index = batch['pos_index'][0]

        index2 = batch2['pos_index'][0]
        

        for node, label in zip(index,target):
            other_side = edge_index[1, edge_index[0, :] == node]
            other_side_withpos = other_side[[True if i in index2 else False for i in other_side]]
            other_side_target_indices = [(i == index2).nonzero(as_tuple=True)[0].item() for i in other_side_withpos]
            #print(other_side_target_indices)
            proj_tags = batch2['pos_classes'][0][other_side_target_indices]

            if proj_tags.size(0) > 0:
                _, proj_tags = torch.max(proj_tags, 1)
                #print(target.shape, node, index.shape, proj_tags, other_side)
                
                if torch.argmax(label) == torch.mode(proj_tags)[0]:
                    correct += 1
                
                total += 1

    print(f'test, , total:{total} ACC: {correct/total}')


In [12]:
#print(wordtype_frequencies[736])
#print(wordtype_frequencies[1473])
#print(wordtype_frequencies[3683])
#print(wordtype_frequencies[7367])
#print(wordtype_frequencies[14733])

In [13]:
#frequent_words = torch.zeros(word_frequencies.size(0))
#frequent_words[word_frequencies > -1] = 1

#test(1, test_data_loader, filter_wordtypes=frequent_words)

In [14]:
from gensim.models import Word2Vec
w2v_model = Word2Vec.load("/mounts/work/ayyoob/models/w2v/word2vec_helfi_langs_15e.model")

print(w2v_model.wv.vectors.shape)
word_vectors = torch.from_numpy(w2v_model.wv.vectors).float()

(2354770, 100)


In [15]:
import pickle

train_verses = all_verses[:]
test_verses = all_verses[:] 
editf1 = 'fin-x-bible-helfi'
editf2 = "heb-x-bible-helfi"


if 'jpn-x-bible-newworld' in  current_editions[:]:
     current_editions.remove('jpn-x-bible-newworld')
if 'grc-x-bible-unaccented' in  current_editions[:]:
     current_editions.remove('grc-x-bible-unaccented')



data_dir_train = "/mounts/data/proj/ayyoob/align_induction/dataset/dataset_helfi_train_community_word"
data_dir_blinker = "/mounts/data/proj/ayyoob/align_induction/dataset/pruned_alignments_blinker_inter/"
data_dir_grc = "/mounts/data/proj/ayyoob/align_induction/dataset/dataset_helfi_grc_community_word/"
data_dir_heb = "/mounts/data/proj/ayyoob/align_induction/dataset/dataset_helfi_heb_community_word/"

train_dataset = torch.load(f"{data_dir_train}/train_dataset_nox_noedge.torch.bin")
blinker_test_dataset = torch.load(f"{data_dir_blinker}/train_dataset_nox_noedge.torch.bin")
grc_test_dataset = torch.load(f"{data_dir_grc}/train_dataset_nox_noedge.torch.bin")
heb_test_dataset = torch.load(f"{data_dir_heb}/train_dataset_nox_noedge.torch.bin")


In [16]:
import codecs
import collections

postag_map = {"ADJ": 0, "ADP": 1, "ADV": 2, "AUX": 3, "CCONJ": 4, "DET": 5, "INTJ": 6, "NOUN": 7, "NUM": 8, "PART": 9, "PRON": 10, "PROPN": 11, "PUNCT": 12, "SCONJ": 13, "SYM": 14, "VERB": 15, "X": 16}

pos_lang_list = ["eng-x-bible-mixed", "deu-x-bible-newworld", "ces-x-bible-newworld", 
		"fra-x-bible-louissegond","hin-x-bible-newworld", "ita-x-bible-2009", 
		"prs-x-bible-goodnews", "ron-x-bible-2006", "spa-x-bible-newworld"]

def get_db_nodecount(dataset):
	res = 0
	for lang in dataset.nodes_map.values():
		for verse in lang.values():
			res += len(verse)
	
	return res

def get_language_nodes(dataset, lang_list, sentences):
	node_count = get_db_nodecount(dataset)
	pos_labels = torch.zeros(node_count, len(postag_map))

	pos_node_cover = collections.defaultdict(list)
	for lang in lang_list:
		for sentence in sentences:
			if sentence in dataset.nodes_map[lang]:
				for tok in dataset.nodes_map[lang][sentence]:
					pos_node_cover[sentence].append(dataset.nodes_map[lang][sentence][tok])
	
	return pos_labels, pos_node_cover


def get_pos_tags(dataset, pos_lang_list):
	all_tags = {}
	for lang in pos_lang_list:
		if lang not in dataset.nodes_map:
			continue
		all_tags[lang] = {}
		with codecs.open(F"/mounts/work/mjalili/projects/gnn-align/data/pbc_pos_tags/{lang}.conllu", "r", "utf-8") as lang_pos:
			tag_sent = []
			sent_id = ""
			for sline in lang_pos:
				sline = sline.strip()
				if sline == "":
					if sent_id not in dataset.nodes_map[lang]:
						tag_sent = []
						sent_id = ""
						continue

					all_tags[lang][sent_id] = [p[3] for p in tag_sent]
					tag_sent = []
					sent_id = ""
				elif "# verse_id" in sline:
					sent_id = sline.split()[-1]
				elif sline[0] == "#":
					continue
				else:
					tag_sent.append(sline.split("\t"))

	node_count = get_db_nodecount(dataset)
	pos_labels = torch.zeros(node_count, len(postag_map))
	pos_node_cover = collections.defaultdict(list)

	for lang in all_tags:
		for sent_id in all_tags[lang]:
			sent_tags = all_tags[lang][sent_id]
			for w_i in range(len(sent_tags)):
				if w_i not in dataset.nodes_map[lang][sent_id]:
					continue
				pos_labels[dataset.nodes_map[lang][sent_id][w_i], postag_map[sent_tags[w_i]]] = 1
				pos_node_cover[sent_id].append(dataset.nodes_map[lang][sent_id][w_i])

	return pos_labels, pos_node_cover
	#pos_pickle = {"pos_labels": pos_labels, "node_ids_train": pos_ids_train, "node_ids_dev": pos_ids_dev}
	#torch.save(pos_pickle, '/mounts/work/ayyoob/models/gnn/postag')


In [17]:
# blinker_test_dataset = torch.load("/mounts/work/ayyoob/models/gnn/dataset_blinker_full_community_word.pickle", map_location=torch.device('cpu'))
editf12 = "eng-x-bible-mixed"
editf22 = 'fra-x-bible-louissegond'

test_gold_eng_fra = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/eng_fra_pbc/eng-fra.gold"

pros_blinker, surs_blinker = autils.load_gold(test_gold_eng_fra)

blinker_verse_alignments_inter = {}

verses_map = {}

for edit in blinker_test_dataset.nodes_map:
    for verse in blinker_test_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in blinker_test_dataset.nodes_map[edit][verse]:
                verses_map[verse] = blinker_test_dataset.nodes_map[edit][verse][tok]
                break

sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
blinker_verses = [item[0] for item in sorted_verses]


In [18]:
#importlib.reload(afeatures)
#grc_test_dataset = torch.load("/mounts/work/ayyoob/models/gnn/dataset_helfi_grc_test_community_word.pickle", map_location=torch.device('cpu'))
editf_fin = "fin-x-bible-helfi"
editf_grc = 'grc-x-bible-helfi'

test_gold_grc = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-grc-gold-alignments_test.txt"

pros_grc, surs_grc = autils.load_gold(test_gold_grc)

grc_test_verse_alignments_inter = {}
grc_test_verse_alignments_gdfa = {}
gc.collect()

verses_map = {}

for edit in grc_test_dataset.nodes_map:
    for verse in grc_test_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in grc_test_dataset.nodes_map[edit][verse]:
                verses_map[verse] = grc_test_dataset.nodes_map[edit][verse][tok]
                break

sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
grc_test_verses = [item[0] for item in sorted_verses]

gc.collect()

0

In [19]:
#heb_test_dataset = torch.load("/mounts/work/ayyoob/models/gnn/dataset_helfi_heb_test_community_word.pickle", map_location=torch.device('cpu'))

test_gold_heb = "/mounts/Users/student/ayyoob/Dokumente/code/pbc_utils/data/helfi/splits/helfi-fin-heb-gold-alignments_test.txt"

pros_heb, surs_heb = autils.load_gold(test_gold_heb)

verses_map = {}

for edit in heb_test_dataset.nodes_map:
    for verse in heb_test_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in heb_test_dataset.nodes_map[edit][verse]:
                verses_map[verse] = heb_test_dataset.nodes_map[edit][verse][tok]
                break

sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
heb_test_verses = [item[0] for item in sorted_verses]
gc.collect()

0

In [20]:
verses_map = {}

for edit in train_dataset.nodes_map:
    for verse in train_dataset.nodes_map[edit]:
        if verse not in verses_map:
            for tok in train_dataset.nodes_map[edit][verse]:
                verses_map[verse] = train_dataset.nodes_map[edit][verse][tok]
                break

sorted_verses = sorted(verses_map.items(), key = lambda x: x[1])
all_verses = [item[0] for item in sorted_verses]

long_verses = set()

for edit in train_dataset.nodes_map.keys():
    for verse in train_dataset.nodes_map[edit]:
        to_print = False
        for tok in train_dataset.nodes_map[edit][verse]:
            if tok > 150:
                to_print = True
        if to_print == True:
            long_verses.add(verse)


train_verses = all_verses[:]

masked_verses = list(long_verses)
#masked_verses.extend(blinker_verses)

In [39]:
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import random

def get_language_based_nodes(nodes_map, verse, train_nodes, padding):
    res = []
    transformer_indices = [[-1 for i in range(len(train_nodes))],[-1 for i in range(len(train_nodes))]]
    
    lang_ind = 0
    for lang in nodes_map:
        if verse in nodes_map[lang]:
            items = nodes_map[lang][verse].items()
            items = sorted(items, key=lambda i: i[0])

            to_add = []
            for i, it in enumerate(items):
                to_add.append(it[1] - padding)
                if it[1] in train_nodes:
                    index = train_nodes.index(it[1])
                    transformer_indices[0][index] = lang_ind
                    transformer_indices[1][index] = i

            res.append(to_add)
            lang_ind += 1
            
    return res, transformer_indices

class POSTAGGNNDataset(Dataset):

    def __init__(self, dataset, verses, edit_files, alignments, node_cover, pos_labels, data_dir, create_data=False, group_size = 20):
        self.node_cover = node_cover
        self.pos_labels = pos_labels
        self.data_dir = data_dir
        self.items = self.calculate_size(verses, group_size, node_cover)
        self.dataset = dataset

        if create_data:
            self.calculate_verse_stats(verses, edit_files, alignments, dataset, data_dir)            
        
    def calculate_size(self, verses, group_size, node_cover):
        res = []
        for verse in verses:
            covered_nodes = node_cover[verse]
            random.shuffle(covered_nodes)
            items = [covered_nodes[i:i + group_size] for i in range(0, len(covered_nodes), group_size)]
            res.extend([(verse, i) for i in items])

        return res

    def calculate_verse_stats(self,verses, edition_files, alignments, dataset, data_dir):

        min_edge = 0
        for verse in tqdm(verses):
            min_nodes = 99999999999999
            max_nodes = 0
            edges_tmp = [[],[]]
            x_tmp = []
            features = []
            for i,editf1 in enumerate(edition_files):
                for j,editf2 in enumerate(edition_files[i+1:]):
                    aligns = autils.get_aligns(editf1, editf2, alignments[verse])
                    if aligns != None:
                        for align in aligns:
                            try:
                                n1,_ = gutils.node_nom(verse, editf1, align[0], None, dataset.nodes_map, x_tmp, edition_files, features)
                                n2,_ = gutils.node_nom(verse, editf2, align[1], None, dataset.nodes_map, x_tmp, edition_files, features)
                                edges_tmp[0].extend([n1, n2])

                                max_nodes = max(n1, n2, max_nodes)
                                min_nodes = min(n1, n2, min_nodes)
                            except Exception as e:
                                print(editf1, editf2, verse)
                                raise(e)

            self.verse_info = {}

            self.verse_info['padding'] = min_nodes
            
            self.verse_info['x'] = torch.clone(dataset.x[min_nodes:max_nodes+1,:])
            
            self.verse_info['edge_index'] = torch.clone(dataset.edge_index[:, min_edge : min_edge + len(edges_tmp[0])] - min_nodes)

            if torch.min(self.verse_info['edge_index']) != 0:
                print(verse, min_nodes, max_nodes, min_edge, len(edges_tmp[0]))
                print(torch.min(self.verse_info['edge_index']))
            
            if self.verse_info['x'].shape[0] != torch.max(self.verse_info['edge_index']) + 1 :
                print(verse, min_nodes, max_nodes, min_edge, len(edges_tmp[0]))
                print(torch.min(self.verse_info['edge_index']))
            
            min_edge = min_edge + len(edges_tmp[0])

            torch.save(self.verse_info, f"{data_dir}/verses/{verse}_info.torch.bin")
        
        dataset.x = None
        dataset.edge_index = None
        torch.save(dataset, f"{data_dir}/train_dataset_nox_noedge.torch.bin")
    
    def __len__(self):
        return len(self.items)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        verse, nodes = self.items[idx]
        
        self.verse_info = {verse: torch.load(f'{self.data_dir}/verses/{verse}_info.torch.bin')}


        word_number = self.verse_info[verse]['x'][:, 9]
        padding = self.verse_info[verse]['padding']
        
        language_based_nodes, transformer_indices = get_language_based_nodes(self.dataset.nodes_map, verse, nodes, padding)

        ## # Add POSTAG to set of features
        #postags = self.pos_labels[padding: self.verse_info[verse]['x'].size(0) + padding, : ]
        #postags = postags.detach().clone()
        #postags[torch.LongTensor(nodes) - padding, :] = 0
        #self.verse_info[verse]['x'] = torch.cat((self.verse_info[verse]['x'], postags), dim=1)

        # Add token id as a feature, used to extract token information (like token's tag distribution)
        word_number = torch.unsqueeze(word_number, 1)
        self.verse_info[verse]['x'] = torch.cat((self.verse_info[verse]['x'], word_number), dim=1)

        return {'verse':verse, 'x':self.verse_info[verse]['x'], 'edge_index':self.verse_info[verse]['edge_index'], 
                'pos_classes': self.pos_labels[nodes, :], 'pos_index': torch.LongTensor(nodes) - padding, 
                'padding': padding, 'lang_based_nodes': language_based_nodes, 'transformer_indices': transformer_indices}

def create_me_a_gnn_dataset_you_stupid(node_covers, labels, group_size=100, editions=current_editions):

    train_ds = POSTAGGNNDataset(train_dataset, train_verses, editions, {}, node_covers[0], labels[0], data_dir_train, group_size=group_size)
    grc_ds = POSTAGGNNDataset(grc_test_dataset, grc_test_verses, editions, {}, node_covers[1], labels[1], data_dir_grc, group_size=group_size)
    heb_ds = POSTAGGNNDataset(heb_test_dataset, heb_test_verses, editions, {}, node_covers[2], labels[2], data_dir_heb, group_size=group_size)
    blinker_ds = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, editions, {}, node_covers[3], labels[3], data_dir_blinker, group_size=group_size)

    return train_ds, grc_ds, heb_ds, blinker_ds

#train_pos_labels, train_pos_node_cover = get_pos_tags(train_dataset, pos_lang_list)
#torch.save({'pos_labels':train_pos_labels, 'pos_node_cover':train_pos_node_cover}, f'{data_dir_train}/pos_data.torch.bin')
pos_data = torch.load(f'{data_dir_train}/pos_data.torch.bin')
train_pos_labels, train_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']

##blinker_pos_labels, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, pos_lang_list)
##torch.save({'pos_labels':blinker_pos_labels, 'pos_node_cover': blinker_pos_node_cover}, f'{data_dir_blinker}/pos_data.torch.bin')
pos_data = torch.load(f'{data_dir_blinker}/pos_data.torch.bin')
blinker_pos_labels, blinker_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']

grc_pos_labels, grc_pos_node_cover = get_pos_tags(grc_test_dataset, pos_lang_list)
torch.save({'pos_labels':grc_pos_labels, 'pos_node_cover': grc_pos_node_cover}, f'{data_dir_grc}/pos_data.torch.bin')
#pos_data = torch.load(f'{data_dir_grc}/pos_data.torch.bin')
#grc_pos_labels, grc_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']

heb_pos_labels, heb_pos_node_cover = get_pos_tags(heb_test_dataset, pos_lang_list)
torch.save({'pos_labels':heb_pos_labels, 'pos_node_cover': heb_pos_node_cover}, f'{data_dir_heb}/pos_data.torch.bin')
#pos_data = torch.load(f'{data_dir_heb}/pos_data.torch.bin')
#heb_pos_labels, heb_pos_node_cover = pos_data['pos_labels'], pos_data['pos_node_cover']

gnn_dataset_train_pos, gnn_dataset_grc_pos, gnn_dataset_heb_pos, gnn_dataset_blinker_pos = create_me_a_gnn_dataset_you_stupid(
    [train_pos_node_cover, grc_pos_node_cover, heb_pos_node_cover, blinker_pos_node_cover], [train_pos_labels, grc_pos_labels, heb_pos_labels, blinker_pos_labels], group_size=100)

gnn_dataset_train_pos_bigbatch, gnn_dataset_grc_pos_bigbatch, gnn_dataset_heb_pos_bigbatch, gnn_dataset_blinker_pos_bigbatch = create_me_a_gnn_dataset_you_stupid(
    [train_pos_node_cover, grc_pos_node_cover, heb_pos_node_cover, blinker_pos_node_cover], [train_pos_labels, grc_pos_labels, heb_pos_labels, blinker_pos_labels], group_size=10000)

train_data_loader_bigbatch = DataLoader(gnn_dataset_train_pos_bigbatch, batch_size=1, shuffle=False)
grc_data_loader_bigbatch = DataLoader(gnn_dataset_grc_pos_bigbatch, batch_size=1, shuffle=False)
heb_data_loader_bigbatch = DataLoader(gnn_dataset_heb_pos_bigbatch, batch_size=1, shuffle=False)
blinker_data_loader_bigbatch = DataLoader(gnn_dataset_blinker_pos_bigbatch, batch_size=1, shuffle=False)


In [22]:
#no_eng_langs = pos_lang_list[:]
#no_eng_langs.remove('eng-x-bible-mixed')

## # train_pos_labels, train_pos_node_cover = get_pos_tags(train_dataset, no_eng_langs)
## # gnn_dataset_train_pos = POSTAGGNNDataset(train_dataset, train_verses, current_editions, verse_alignments_inter,
## #                        train_pos_node_cover, train_pos_labels, data_dir_train, group_size = 10)


#_, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, ['eng-x-bible-mixed'])
#blinker_pos_labels, _ = get_pos_tags(blinker_test_dataset, pos_lang_list)
#gnn_dataset_blinker_pos_onlyeng = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
#                            blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 500)



In [23]:
def save_model(model, name):
    #model.encoder.feature_encoder.feature_types[0] = afeatures.OneHotFeature(20, 83, 'editf')
    #model.encoder.feature_encoder.feature_types[1] = afeatures.OneHotFeature(32, 150, 'position')
    #model.encoder.feature_encoder.feature_types[2] = afeatures.FloatFeature(4, 'degree_centrality')
    #model.encoder.feature_encoder.feature_types[3] = afeatures.FloatFeature(4, 'closeness_centrality')
    #model.encoder.feature_encoder.feature_types[4] = afeatures.FloatFeature(4, 'betweenness_centrality')
    #model.encoder.feature_encoder.feature_types[5] = afeatures.FloatFeature(4, 'load_centrality')
    #model.encoder.feature_encoder.feature_types[6] = afeatures.FloatFeature(4, 'harmonic_centrality')
    #model.encoder.feature_encoder.feature_types[7] = afeatures.OneHotFeature(32, 250, 'greedy_modularity_community')
    #model.encoder.feature_encoder.feature_types[8] = afeatures.OneHotFeature(32, 250, 'community_2')
    #model.encoder.feature_encoder.feature_types[9] = afeatures.MappingFeature(100, 'word')
    #model.encoder.feature_encoder.feature_types[10] = afeatures.MappingFeature(len(postag_map), 'tag_priors', freeze=True)
    torch.save(model, f'/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_{name}_' + datetime.now().strftime("%Y%m%d-%H%M%S") + '.pickle')

In [24]:
eng_test_pos_labels, eng_test_pos_node_cover = get_pos_tags(blinker_test_dataset, ['eng-x-bible-mixed'])
gnn_dataset_engtest_pos = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, {},
                      eng_test_pos_node_cover, eng_test_pos_labels, data_dir_blinker, group_size = 500)
engtest_data_loader = DataLoader(gnn_dataset_engtest_pos, batch_size=1, shuffle=False)

#test(1, engtest_data_loader, mask_language=False) 

finetune_pos_labels, finetune_pos_node_cover = get_pos_tags(train_dataset, ['eng-x-bible-mixed'])
gnn_dataset_finetune_pos = POSTAGGNNDataset(train_dataset, train_verses, current_editions, {},
                      finetune_pos_node_cover, finetune_pos_labels, data_dir_train, group_size = 100)
finetune_data_loader = DataLoader(gnn_dataset_finetune_pos, batch_size=1, shuffle=False)


#train(1, finetune_data_loader, mask_language=False)
#test(1, engtest_data_loader, mask_language=False) 

# blinker_pos_labels, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, ['eng-x-bible-mixed'])
# gnn_dataset_blinker_pos = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
#                              blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 10000)

# blinker_pos_labels, blinker_pos_node_cover = get_pos_tags(blinker_test_dataset, no_eng_langs)
# gnn_dataset_blinker_pos_majvoting_test = POSTAGGNNDataset(blinker_test_dataset, blinker_verses, current_editions, blinker_verse_alignments_inter,
#                              blinker_pos_node_cover, blinker_pos_labels, data_dir_blinker, group_size = 10000)

# test_data_loader = DataLoader(gnn_dataset_blinker_pos, batch_size=1, shuffle=False)
# test_data_loader_majvoting = DataLoader(gnn_dataset_blinker_pos_majvoting_test, batch_size=1, shuffle=False)
# majority_voting_test(test_data_loader, test_data_loader_majvoting)

In [25]:
#importlib.reload(posutil)
#target_data_loader_train = get_data_loadrs_for_target_editions(['yor-x-bible-2010'], train_dataset, train_pos_node_cover, train_verses, data_dir_train)
#target_data_loader_grc = get_data_loadrs_for_target_editions(['yor-x-bible-2010'], grc_test_dataset, grc_pos_node_cover, grc_test_verses, data_dir_grc)
#target_data_loader_heb = get_data_loadrs_for_target_editions(['yor-x-bible-2010'], heb_test_dataset, heb_pos_node_cover, heb_test_verses, data_dir_heb)
#target_data_loader_blinker = get_data_loadrs_for_target_editions(['yor-x-bible-2010'], blinker_test_dataset, blinker_pos_node_cover, blinker_verses, data_dir_blinker)

#res_ = posutil.get_tag_frequencies_node_tags(model, ['yor-x-bible-2010'], train_pos_node_cover, train_pos_labels, grc_pos_node_cover, grc_pos_labels, heb_pos_node_cover, heb_pos_labels, blinker_pos_node_cover, blinker_pos_labels,
#                                    len(postag_map), target_data_loader_train, target_data_loader_grc, target_data_loader_heb, target_data_loader_blinker,
#                                    train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch, DataEncoder)
##torch.save(res_, f'/mounts/work/ayyoob/results/gnn_postag/data/feature_vectors_posfea{True}_targetyor.torch.bin')
##res_ = torch.load('/mounts/work/ayyoob/results/gnn_postag/data/feature_vectors_posfeaTrue_targetyor.torch.bin')
#tag_frequencies, tag_frequencies_target, train_pos_node_cover_ext, train_pos_labels_ext, grc_pos_node_cover_ext, grc_pos_labels_ext, heb_pos_node_cover_ext, heb_pos_labels_ext, blinker_pos_node_cover_ext, blinker_pos_labels_ext = res_


#word_frequencies_target = torch.sum(tag_frequencies_target, dim=1)
#tag_frequencies_copy = tag_frequencies.detach().clone()

#tag_frequencies_copy[torch.logical_and(word_frequencies_target>0.1, word_frequencies_target<3), :] = 0.0000001

## We have to give uniform noise to some training examples to prevent the model from returning one of the most frequent tags always!!
#uniform_noise = torch.BoolTensor(tag_frequencies.size(0))
#uniform_noise[:] = True
#shuffle_tensor = torch.randperm(tag_frequencies.size(0))[:int(tag_frequencies.size(0)*0.7)]
#uniform_noise[shuffle_tensor] = False
#tag_frequencies_copy[torch.logical_and(uniform_noise, word_frequencies_target < 0.1), :] = 0.0000001

#sm = torch.sum(tag_frequencies_copy, dim=1)
#normalized_tag_frequencies = (tag_frequencies_copy.transpose(1,0) / sm).transpose(1,0)

In [31]:
from tqdm import tqdm
from torchvision import models
from torchsummary import summary

def create_model(train_gnn_dataset, grc_gnn_dataset, heb_gnn_dataset, blinker_gnn_dataset, test_gnn_dataset,
                tag_frequencies=False, use_transformers=False, train_word_embedding=False, mask_language=True):
    global model, criterion, optimizer

    features = train_dataset.features[:]
    #features.append(afeatures.PassFeature(name='posTAG', dim=len(postag_map)))
    if tag_frequencies:
        features.append(afeatures.MappingFeature(len(postag_map), 'tag_priors', freeze=True))
    features[9].freeze = not train_word_embedding
    print('len features', len(features))
    
    train_data_loader = DataLoader(train_gnn_dataset, batch_size=1, shuffle=True)
    grc_data_loader = DataLoader(grc_gnn_dataset, batch_size=1, shuffle=True)
    heb_data_loader = DataLoader(heb_gnn_dataset, batch_size=1, shuffle=True)
    blinker_data_loader = DataLoader(blinker_gnn_dataset, batch_size=1, shuffle=True)
    test_data_loader = DataLoader(test_gnn_dataset, batch_size=1, shuffle=True)

    clean_memory()
    drop_out = 0
    n_head = 1
    in_dim = sum(t.out_dim for t in features)


    channels = 512

    decoder_in_dim = n_head * channels 

    # model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/gnn_512_flggll_word_halfTrain_nofeatlinear_encoderlineear_decoderonelayer20210910-235352-.pickle')
    if use_transformers:
        decoder = POSDecoderTransformer(decoder_in_dim, decoder_in_dim*2, len(postag_map)).to(dev2)
    else:
        decoder = POSDecoder(decoder_in_dim, decoder_in_dim*2, len(postag_map))
        
    model = pyg_nn.GAE(Encoder(in_dim, channels, features, n_head, has_tagfreq_feature=tag_frequencies), decoder).to(dev)


    #vgg = models.vgg16()
    #summary(vgg, (3, 224, 224))
    if use_transformers:
        decoder.to(dev2)

    # model.to(dev)
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

    torch.set_printoptions(edgeitems=5)
    print("model params - decoder params - conv1", sum(p.numel() for p in model.parameters()), sum(p.numel() for p in decoder.parameters()))

    for epoch in range(1, 2):
        print(f"\n----------------epoch {epoch} ---------------")
        
        train(epoch, train_data_loader, mask_language, engtest_data_loader)
        train(epoch, grc_data_loader, mask_language, engtest_data_loader)
        train(epoch, heb_data_loader, mask_language, engtest_data_loader)
        train(epoch, blinker_data_loader, mask_language, engtest_data_loader)
        save_model(model, f'posfeat{tag_frequencies}_transformer{use_transformers}_trainWE{train_word_embedding}_maskLang{mask_language}_epoch{epoch}')
        test(epoch, test_data_loader, mask_language) 
        clean_memory()


In [38]:
create_model(gnn_dataset_train_pos, gnn_dataset_grc_pos, gnn_dataset_heb_pos, gnn_dataset_blinker_pos, gnn_dataset_engtest_pos,
train_word_embedding=False, mask_language=True, use_transformers=False, tag_frequencies=False)

len features 10
model params - decoder params - conv1 237862513 1592337

----------------epoch 1 ---------------


  1%|          | 498/57706 [00:38<1:16:52, 12.40it/s]

loss: 60792.10155946016
testing 1


100%|██████████| 250/250 [00:14<00:00, 16.80it/s]


test, epoch: 1, total:7477 ACC: 0.681423030627257


  2%|▏         | 997/57706 [01:33<1:03:50, 14.81it/s] 

loss: 41461.1582795158
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.05it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7302394008292096


  3%|▎         | 1498/57706 [02:28<1:24:13, 11.12it/s] 

loss: 41308.60204675794
testing 1


100%|██████████| 250/250 [00:14<00:00, 17.19it/s]


test, epoch: 1, total:7477 ACC: 0.7190049485087602


  3%|▎         | 1999/57706 [03:23<1:13:46, 12.59it/s] 

loss: 38717.00278094411
testing 1


100%|██████████| 250/250 [00:15<00:00, 16.21it/s]it/s]


test, epoch: 1, total:7477 ACC: 0.7164638223886586


  4%|▍         | 2497/57706 [04:21<1:18:27, 11.73it/s] 

loss: 35806.365319907665
testing 1


100%|██████████| 250/250 [00:13<00:00, 17.88it/s]


test, epoch: 1, total:7477 ACC: 0.7374615487494984


  5%|▌         | 2999/57706 [05:15<1:14:07, 12.30it/s] 

loss: 36571.082430224866
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.34it/s]it/s]


test, epoch: 1, total:7477 ACC: 0.7396014444295841


  6%|▌         | 3498/57706 [06:11<55:15, 16.35it/s]   

loss: 34578.3864111295
testing 1


100%|██████████| 250/250 [00:13<00:00, 17.96it/s]


test, epoch: 1, total:7477 ACC: 0.7409388792296375


  7%|▋         | 3999/57706 [07:07<59:58, 14.92it/s]   

loss: 35110.236918091774
testing 1


100%|██████████| 250/250 [00:14<00:00, 17.39it/s]/s]


test, epoch: 1, total:7477 ACC: 0.7412063661896483


  8%|▊         | 4498/57706 [08:07<1:03:37, 13.94it/s] 

loss: 35730.45427325368
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.17it/s]it/s]


test, epoch: 1, total:7477 ACC: 0.7444162097097766


  9%|▊         | 4998/57706 [09:02<1:02:25, 14.07it/s] 

loss: 34052.00408928841
testing 1


100%|██████████| 250/250 [00:14<00:00, 16.70it/s]
  9%|▊         | 5001/57706 [09:19<28:01:49,  1.91s/it]

test, epoch: 1, total:7477 ACC: 0.7733048013909322


 10%|▉         | 5499/57706 [09:58<59:44, 14.56it/s]   

loss: 36352.707267828286
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.29it/s]/s]


test, epoch: 1, total:7477 ACC: 0.7698274709107931


 10%|█         | 5997/57706 [10:54<1:24:26, 10.21it/s] 

loss: 34441.14432041347
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.17it/s]


test, epoch: 1, total:7477 ACC: 0.7644777317105791


 11%|█▏        | 6498/57706 [11:51<54:55, 15.54it/s]   

loss: 35073.933283150196
testing 1


100%|██████████| 250/250 [00:13<00:00, 17.99it/s]
 11%|█▏        | 6502/57706 [12:07<22:13:02,  1.56s/it]

test, epoch: 1, total:7477 ACC: 0.7547144576701886


 12%|█▏        | 6999/57706 [12:48<1:11:11, 11.87it/s] 

loss: 35425.1607362628
testing 1


100%|██████████| 250/250 [00:13<00:00, 17.95it/s]it/s]


test, epoch: 1, total:7477 ACC: 0.7601979403504079


 13%|█▎        | 7498/57706 [13:44<1:11:15, 11.74it/s] 

loss: 32236.675681129098
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.41it/s]


test, epoch: 1, total:7477 ACC: 0.7521733315500869


 14%|█▍        | 7998/57706 [14:39<1:18:32, 10.55it/s] 

loss: 32964.19366782904
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.07it/s]it/s]


test, epoch: 1, total:7477 ACC: 0.7672863447906915


 15%|█▍        | 8498/57706 [15:35<1:15:56, 10.80it/s] 

loss: 34060.01130712032
testing 1


100%|██████████| 250/250 [00:15<00:00, 16.26it/s]it/s]


test, epoch: 1, total:7477 ACC: 0.7638090143105524


 16%|█▌        | 8999/57706 [16:32<53:05, 15.29it/s]   

loss: 32805.055265404284
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.24it/s]


test, epoch: 1, total:7477 ACC: 0.7832018189113281


 16%|█▋        | 9498/57706 [17:28<1:05:38, 12.24it/s] 

loss: 32272.189907018095
testing 1


100%|██████████| 250/250 [00:15<00:00, 16.30it/s]it/s]


test, epoch: 1, total:7477 ACC: 0.7734385448709375


 17%|█▋        | 9999/57706 [18:26<1:13:31, 10.81it/s] 

loss: 32108.60938668251
testing 1


100%|██████████| 250/250 [00:13<00:00, 17.94it/s]it/s]
 17%|█▋        | 10003/57706 [18:42<21:02:09,  1.59s/it]

test, epoch: 1, total:7477 ACC: 0.7695599839507824


 18%|█▊        | 10498/57706 [19:22<1:14:47, 10.52it/s] 

loss: 33036.3243214041
testing 1


100%|██████████| 250/250 [00:14<00:00, 17.09it/s]


test, epoch: 1, total:7477 ACC: 0.7626053229905042


 19%|█▉        | 10997/57706 [20:14<45:56, 16.94it/s]   

loss: 34477.23181404732
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.08it/s]
 19%|█▉        | 11002/57706 [20:26<13:46:30,  1.06s/it]

test, epoch: 1, total:7477 ACC: 0.7763809014310552


 20%|█▉        | 11497/57706 [20:57<36:51, 20.90it/s]   

loss: 33724.44744684547
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.09it/s]
 20%|█▉        | 11503/57706 [21:08<10:40:33,  1.20it/s]

test, epoch: 1, total:7477 ACC: 0.7861441754714458


 21%|██        | 11997/57706 [21:37<44:10, 17.24it/s]   

loss: 32829.86324210465
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.60it/s]


test, epoch: 1, total:7477 ACC: 0.7722348535508894


 22%|██▏       | 12498/57706 [22:19<45:45, 16.46it/s]   

loss: 32987.28686556965
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.85it/s]
 22%|██▏       | 12498/57706 [22:29<45:45, 16.46it/s]

test, epoch: 1, total:7477 ACC: 0.7836030493513442


 23%|██▎       | 12997/57706 [22:58<41:58, 17.75it/s]   

loss: 32020.25784060359
testing 1


100%|██████████| 250/250 [00:10<00:00, 23.70it/s]
 23%|██▎       | 12997/57706 [23:10<41:58, 17.75it/s]

test, epoch: 1, total:7477 ACC: 0.7845392537113816


 23%|██▎       | 13499/57706 [23:41<1:02:03, 11.87it/s] 

loss: 32685.025623455644
testing 1


100%|██████████| 250/250 [00:10<00:00, 24.24it/s]
 23%|██▎       | 13503/57706 [23:53<15:45:01,  1.28s/it]

test, epoch: 1, total:7477 ACC: 0.7828005884713121


 24%|██▍       | 13998/57706 [24:23<46:03, 15.81it/s]   

loss: 33657.02029967308
testing 1


100%|██████████| 250/250 [00:10<00:00, 24.06it/s]


test, epoch: 1, total:7477 ACC: 0.7632740403905309


 25%|██▌       | 14498/57706 [25:05<46:25, 15.51it/s]   

loss: 31361.95577391982
testing 1


100%|██████████| 250/250 [00:10<00:00, 24.56it/s]
 25%|██▌       | 14502/57706 [25:17<14:39:40,  1.22s/it]

test, epoch: 1, total:7477 ACC: 0.7691587535107663


 26%|██▌       | 14999/57706 [25:45<53:42, 13.25it/s]   

loss: 32199.8413085863
testing 1


100%|██████████| 250/250 [00:16<00:00, 14.77it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7818643841112746


 27%|██▋       | 15499/57706 [26:52<1:01:24, 11.45it/s] 

loss: 31715.11689314805
testing 1


100%|██████████| 250/250 [00:17<00:00, 13.90it/s]5it/s]


test, epoch: 1, total:7477 ACC: 0.7783870536311355


 28%|██▊       | 15999/57706 [28:02<1:01:54, 11.23it/s] 

loss: 31219.561755333096
testing 1


100%|██████████| 250/250 [00:18<00:00, 13.36it/s]3it/s]


test, epoch: 1, total:7477 ACC: 0.7773171057910927


 29%|██▊       | 16499/57706 [29:12<1:02:49, 10.93it/s] 

loss: 31446.97399880737
testing 1


100%|██████████| 250/250 [00:17<00:00, 13.94it/s]3it/s]


test, epoch: 1, total:7477 ACC: 0.7700949578708038


 29%|██▉       | 16998/57706 [30:23<1:03:40, 10.65it/s] 

loss: 32279.265567600727
testing 1


100%|██████████| 250/250 [00:18<00:00, 13.41it/s]5it/s]


test, epoch: 1, total:7477 ACC: 0.7936338103517454


 30%|███       | 17499/57706 [31:33<57:03, 11.74it/s]   

loss: 30920.391665201634
testing 1


100%|██████████| 250/250 [00:18<00:00, 13.38it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7779858231911194


 31%|███       | 17998/57706 [32:41<45:38, 14.50it/s]   

loss: 32226.608943078667
testing 1


100%|██████████| 250/250 [00:12<00:00, 20.75it/s]
 31%|███       | 18001/57706 [32:54<17:48:19,  1.61s/it]

test, epoch: 1, total:7477 ACC: 0.7757121840310285


 32%|███▏      | 18499/57706 [33:30<44:45, 14.60it/s]   

loss: 30466.84911365807
testing 1


100%|██████████| 250/250 [00:12<00:00, 20.49it/s]


test, epoch: 1, total:7477 ACC: 0.7727698274709108


 33%|███▎      | 18999/57706 [34:19<44:46, 14.41it/s]   

loss: 31123.23767172359
testing 1


100%|██████████| 250/250 [00:14<00:00, 17.17it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7786545405911461


 34%|███▍      | 19498/57706 [35:21<52:26, 12.14it/s]   

loss: 31737.671187520027
testing 1


100%|██████████| 250/250 [00:17<00:00, 14.06it/s]


test, epoch: 1, total:7477 ACC: 0.793098836431724


 35%|███▍      | 19999/57706 [36:26<55:23, 11.35it/s]   

loss: 30966.438239596784
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.29it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7882840711515313


 36%|███▌      | 20498/57706 [37:35<1:07:51,  9.14it/s] 

loss: 32172.9316954948
testing 1


100%|██████████| 250/250 [00:17<00:00, 14.39it/s]


test, epoch: 1, total:7477 ACC: 0.7726360839909054


 36%|███▋      | 20998/57706 [38:44<59:18, 10.32it/s]   

loss: 30431.26906323433
testing 1


100%|██████████| 250/250 [00:17<00:00, 14.26it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7845392537113816


 37%|███▋      | 21498/57706 [39:50<49:38, 12.15it/s]   

loss: 31304.073297854513
testing 1


100%|██████████| 250/250 [00:19<00:00, 12.96it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7817306406312692


 38%|███▊      | 21999/57706 [40:56<55:52, 10.65it/s]   

loss: 30403.621094852686
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.30it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7774508492710981


 39%|███▉      | 22499/57706 [42:02<49:59, 11.74it/s]   

loss: 31362.616195239127
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.07it/s]


test, epoch: 1, total:7477 ACC: 0.7809281797512371


 40%|███▉      | 22998/57706 [43:08<1:04:47,  8.93it/s] 

loss: 30669.524622805417
testing 1


100%|██████████| 250/250 [00:17<00:00, 14.46it/s]


test, epoch: 1, total:7477 ACC: 0.7803932058312157


 41%|████      | 23497/57706 [44:12<53:00, 10.76it/s]   

loss: 31071.832025129348
testing 1


100%|██████████| 250/250 [00:20<00:00, 12.43it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7932325799117294


 42%|████▏     | 23999/57706 [45:19<42:15, 13.29it/s]   

loss: 29946.477564591914
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.00it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7777183362311088


 42%|████▏     | 24498/57706 [46:24<47:50, 11.57it/s]   

loss: 29947.293967485428
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.52it/s]
 42%|████▏     | 24498/57706 [46:42<47:50, 11.57it/s]

test, epoch: 1, total:7477 ACC: 0.7856092015514243


 43%|████▎     | 24997/57706 [47:29<36:37, 14.89it/s]   

loss: 29856.787325300276
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.49it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7818643841112746


 44%|████▍     | 25499/57706 [48:35<52:35, 10.21it/s]   

loss: 31261.104525879025
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.06it/s]


test, epoch: 1, total:7477 ACC: 0.7842717667513709


 45%|████▌     | 25999/57706 [49:40<54:52,  9.63it/s]   

loss: 30918.30260427296
testing 1


100%|██████████| 250/250 [00:17<00:00, 14.53it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7975123712719004


 46%|████▌     | 26499/57706 [50:47<50:52, 10.22it/s]   

loss: 31431.326733484864
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.54it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.776648388391066


 47%|████▋     | 26998/57706 [51:52<45:03, 11.36it/s]   

loss: 30042.864431351423
testing 1


100%|██████████| 250/250 [00:15<00:00, 15.99it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7731710579109269


 48%|████▊     | 27498/57706 [52:57<58:16,  8.64it/s]   

loss: 30816.291095875204
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.31it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7759796709910391


 49%|████▊     | 27998/57706 [54:02<41:19, 11.98it/s]   

loss: 31299.452487632632
testing 1


100%|██████████| 250/250 [00:19<00:00, 12.71it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7936338103517454


 49%|████▉     | 28498/57706 [55:06<35:20, 13.77it/s]   

loss: 29201.13902322203
testing 1


100%|██████████| 250/250 [00:15<00:00, 15.66it/s]


test, epoch: 1, total:7477 ACC: 0.7981810886719273


 50%|█████     | 28998/57706 [56:10<43:53, 10.90it/s]   

loss: 29006.763607326895
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.32it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.785074227631403


 51%|█████     | 29498/57706 [57:15<44:03, 10.67it/s]   

loss: 29810.281037926674
testing 1


100%|██████████| 250/250 [00:16<00:00, 15.23it/s]


test, epoch: 1, total:7477 ACC: 0.7822656145512906


 52%|█████▏    | 29999/57706 [58:19<39:56, 11.56it/s]   

loss: 31283.9463442564
testing 1


100%|██████████| 250/250 [00:18<00:00, 13.67it/s]t/s]


test, epoch: 1, total:7477 ACC: 0.7868128928714725


 53%|█████▎    | 30499/57706 [59:26<28:09, 16.10it/s]   

loss: 29055.69918948412
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.70it/s]


test, epoch: 1, total:7477 ACC: 0.7799919753911997


 54%|█████▎    | 30997/57706 [1:00:23<30:26, 14.62it/s]  

loss: 28491.436715815216
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.38it/s]


test, epoch: 1, total:7477 ACC: 0.8088805670723552


 55%|█████▍    | 31498/57706 [1:01:15<27:21, 15.96it/s]   

loss: 30318.052173748612
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.50it/s]
 55%|█████▍    | 31502/57706 [1:01:27<8:11:26,  1.13s/it] 

test, epoch: 1, total:7477 ACC: 0.8027283669921091


 55%|█████▌    | 31998/57706 [1:01:59<23:47, 18.01it/s]  

loss: 30526.386958748102
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.84it/s]
 55%|█████▌    | 32003/57706 [1:02:10<6:29:37,  1.10it/s]

test, epoch: 1, total:7477 ACC: 0.7798582319111943


 56%|█████▋    | 32498/57706 [1:02:41<28:46, 14.60it/s]  

loss: 29547.272876989096
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.77it/s]


test, epoch: 1, total:7477 ACC: 0.7842717667513709


 57%|█████▋    | 32999/57706 [1:03:23<23:52, 17.25it/s]  

loss: 29155.111212275922
testing 1


100%|██████████| 250/250 [00:09<00:00, 27.16it/s]


test, epoch: 1, total:7477 ACC: 0.8037983148321519


 58%|█████▊    | 33497/57706 [1:04:05<19:57, 20.21it/s]   

loss: 29703.844828650355
testing 1


100%|██████████| 250/250 [00:10<00:00, 24.84it/s]
 58%|█████▊    | 33502/57706 [1:04:17<6:29:20,  1.04it/s]

test, epoch: 1, total:7477 ACC: 0.7955062190718203


 59%|█████▉    | 33998/57706 [1:04:49<20:29, 19.28it/s]  

loss: 29525.418810538948
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.21it/s]


test, epoch: 1, total:7477 ACC: 0.7905577103116224


 60%|█████▉    | 34499/57706 [1:05:32<21:21, 18.11it/s]  

loss: 29876.798161805607
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.05it/s]


test, epoch: 1, total:7477 ACC: 0.8023271365520931


 61%|██████    | 34999/57706 [1:06:14<24:00, 15.76it/s]  

loss: 29990.65583114326
testing 1


100%|██████████| 250/250 [00:09<00:00, 27.07it/s]
 61%|██████    | 34999/57706 [1:06:25<24:00, 15.76it/s]

test, epoch: 1, total:7477 ACC: 0.793098836431724


 62%|██████▏   | 35499/57706 [1:06:58<30:10, 12.27it/s]   

loss: 31067.214357439894
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.52it/s]


test, epoch: 1, total:7477 ACC: 0.7945700147117828


 62%|██████▏   | 35997/57706 [1:07:40<24:43, 14.63it/s]  

loss: 27797.19819731172
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.39it/s]
 62%|██████▏   | 36003/57706 [1:07:51<5:50:33,  1.03it/s]

test, epoch: 1, total:7477 ACC: 0.7944362712317774


 63%|██████▎   | 36499/57706 [1:08:26<24:22, 14.50it/s]  

loss: 29384.374792933464
testing 1


100%|██████████| 250/250 [00:10<00:00, 24.85it/s]
 63%|██████▎   | 36501/57706 [1:08:38<10:14:13,  1.74s/it]

test, epoch: 1, total:7477 ACC: 0.7963086799518524


 64%|██████▍   | 36999/57706 [1:09:09<30:24, 11.35it/s]   

loss: 29078.484402140602
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.48it/s]
 64%|██████▍   | 37004/57706 [1:09:20<6:11:00,  1.08s/it]

test, epoch: 1, total:7477 ACC: 0.7953724755918149


 65%|██████▍   | 37499/57706 [1:09:52<19:57, 16.88it/s]  

loss: 28846.706269759685
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.87it/s]
 65%|██████▍   | 37503/57706 [1:10:03<6:13:02,  1.11s/it]

test, epoch: 1, total:7477 ACC: 0.8047345191921894


 66%|██████▌   | 37998/57706 [1:10:34<22:34, 14.55it/s]  

loss: 29915.487880647182
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.61it/s]


test, epoch: 1, total:7477 ACC: 0.7868128928714725


 67%|██████▋   | 38498/57706 [1:11:15<19:36, 16.33it/s]  

loss: 28954.12161179632
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.21it/s]
 67%|██████▋   | 38502/57706 [1:11:27<6:10:21,  1.16s/it]

test, epoch: 1, total:7477 ACC: 0.7936338103517454


 68%|██████▊   | 38997/57706 [1:11:58<20:21, 15.32it/s]  

loss: 29296.288713146
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.27it/s]
 68%|██████▊   | 39002/57706 [1:12:09<4:51:46,  1.07it/s]

test, epoch: 1, total:7477 ACC: 0.7937675538317507


 68%|██████▊   | 39497/57706 [1:12:40<19:29, 15.57it/s]  

loss: 29806.223991628736
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.30it/s]
 68%|██████▊   | 39500/57706 [1:12:51<6:45:18,  1.34s/it]

test, epoch: 1, total:7477 ACC: 0.7844055102313762


 69%|██████▉   | 39999/57706 [1:13:22<19:48, 14.90it/s]  

loss: 28337.276672810316
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.88it/s]
 69%|██████▉   | 40003/57706 [1:13:34<5:43:44,  1.17s/it]

test, epoch: 1, total:7477 ACC: 0.8028621104721145


 70%|███████   | 40499/57706 [1:14:04<13:36, 21.07it/s]  

loss: 29446.786886418326
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.59it/s]
 70%|███████   | 40502/57706 [1:14:16<5:35:15,  1.17s/it]

test, epoch: 1, total:7477 ACC: 0.7936338103517454


 71%|███████   | 40999/57706 [1:14:48<23:14, 11.98it/s]  

loss: 29189.354666974396
testing 1


100%|██████████| 250/250 [00:10<00:00, 24.71it/s]


test, epoch: 1, total:7477 ACC: 0.8058044670322322


 72%|███████▏  | 41498/57706 [1:15:30<18:02, 14.97it/s]  

loss: 28197.966673851013
testing 1


100%|██████████| 250/250 [00:09<00:00, 27.63it/s]
 72%|███████▏  | 41503/57706 [1:15:41<4:09:40,  1.08it/s]

test, epoch: 1, total:7477 ACC: 0.8078106192323125


 73%|███████▎  | 41998/57706 [1:16:12<17:22, 15.07it/s]  

loss: 29171.886109143496
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.82it/s]
 73%|███████▎  | 41998/57706 [1:16:27<17:22, 15.07it/s]

test, epoch: 1, total:7477 ACC: 0.7944362712317774


 74%|███████▎  | 42498/57706 [1:17:07<23:10, 10.93it/s]  

loss: 29212.413708478212
testing 1


100%|██████████| 250/250 [00:14<00:00, 17.15it/s]3it/s]


test, epoch: 1, total:7477 ACC: 0.7960411929918416


 75%|███████▍  | 42998/57706 [1:18:02<19:07, 12.82it/s]  

loss: 28675.50098800659
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.63it/s]
 75%|███████▍  | 43002/57706 [1:18:17<6:30:14,  1.59s/it]

test, epoch: 1, total:7477 ACC: 0.7940350407917613


 75%|███████▌  | 43499/57706 [1:18:59<17:00, 13.92it/s]  

loss: 29854.08224734664
testing 1


100%|██████████| 250/250 [00:14<00:00, 17.84it/s]


test, epoch: 1, total:7477 ACC: 0.7929650929517186


 76%|███████▌  | 43999/57706 [1:19:54<19:26, 11.75it/s]  

loss: 29510.54767908156
testing 1


100%|██████████| 250/250 [00:14<00:00, 16.89it/s]5it/s]
 76%|███████▋  | 44001/57706 [1:20:10<8:53:13,  2.33s/it] 

test, epoch: 1, total:7477 ACC: 0.8092817975123713


 77%|███████▋  | 44499/57706 [1:20:50<18:53, 11.65it/s]  

loss: 29217.616110488772
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.72it/s]


test, epoch: 1, total:7477 ACC: 0.8067406713922697


 78%|███████▊  | 44998/57706 [1:21:45<17:18, 12.24it/s]  

loss: 30026.592356693
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.74it/s]4it/s]


test, epoch: 1, total:7477 ACC: 0.8037983148321519


 79%|███████▉  | 45498/57706 [1:22:42<16:01, 12.70it/s]  

loss: 28422.209765970707
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.17it/s]


test, epoch: 1, total:7477 ACC: 0.8134278453925371


 80%|███████▉  | 45999/57706 [1:23:37<14:35, 13.37it/s]  

loss: 29034.758199363947
testing 1


100%|██████████| 250/250 [00:14<00:00, 16.67it/s]7it/s]


test, epoch: 1, total:7477 ACC: 0.7944362712317774


 81%|████████  | 46498/57706 [1:24:33<12:44, 14.66it/s]  

loss: 29049.520737092942
testing 1


100%|██████████| 250/250 [00:13<00:00, 19.01it/s]


test, epoch: 1, total:7477 ACC: 0.809549284472382


 81%|████████▏ | 46999/57706 [1:25:27<19:52,  8.98it/s]  

loss: 29414.77227556333
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.98it/s]


test, epoch: 1, total:7477 ACC: 0.7995185234719807


 82%|████████▏ | 47499/57706 [1:26:22<14:25, 11.79it/s]  

loss: 29522.839936850592
testing 1


100%|██████████| 250/250 [00:15<00:00, 16.56it/s]
 82%|████████▏ | 47499/57706 [1:26:38<14:25, 11.79it/s]

test, epoch: 1, total:7477 ACC: 0.8119566671124783


 83%|████████▎ | 47999/57706 [1:27:14<13:19, 12.14it/s]  

loss: 28847.363803632557
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.37it/s]


test, epoch: 1, total:7477 ACC: 0.8040658017921626


 84%|████████▍ | 48498/57706 [1:28:10<12:07, 12.67it/s]  

loss: 29585.21772081405
testing 1


100%|██████████| 250/250 [00:14<00:00, 16.79it/s]
 84%|████████▍ | 48502/57706 [1:28:27<4:28:17,  1.75s/it]

test, epoch: 1, total:7477 ACC: 0.793098836431724


 85%|████████▍ | 48998/57706 [1:29:06<13:52, 10.46it/s]  

loss: 29198.767078474164
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.82it/s]6it/s]
 85%|████████▍ | 49000/57706 [1:29:21<5:28:31,  2.26s/it]

test, epoch: 1, total:7477 ACC: 0.8000534973920022


 86%|████████▌ | 49499/57706 [1:30:02<10:19, 13.24it/s]  

loss: 28942.768796622753
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.77it/s]


test, epoch: 1, total:7477 ACC: 0.7944362712317774


 87%|████████▋ | 49998/57706 [1:31:00<10:21, 12.40it/s]  

loss: 29854.20639207959
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.46it/s]


test, epoch: 1, total:7477 ACC: 0.801524675672061


 88%|████████▊ | 50499/57706 [1:31:56<10:25, 11.52it/s]  

loss: 29821.06729345396
testing 1


100%|██████████| 250/250 [00:15<00:00, 16.11it/s]2it/s]


test, epoch: 1, total:7477 ACC: 0.8043332887521734


 88%|████████▊ | 50998/57706 [1:32:52<07:36, 14.71it/s]  

loss: 29118.702739151195
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.27it/s]


test, epoch: 1, total:7477 ACC: 0.8112879497124516


 89%|████████▉ | 51499/57706 [1:33:49<07:19, 14.14it/s]  

loss: 29554.597927093506
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.74it/s]


test, epoch: 1, total:7477 ACC: 0.804199545272168


 90%|█████████ | 51998/57706 [1:34:43<08:28, 11.22it/s]  

loss: 28355.460974171758
testing 1


100%|██████████| 250/250 [00:14<00:00, 17.82it/s]


test, epoch: 1, total:7477 ACC: 0.8054032365922161


 91%|█████████ | 52498/57706 [1:35:40<06:07, 14.16it/s]  

loss: 30726.339223045507
testing 1


100%|██████████| 250/250 [00:14<00:00, 17.75it/s]


test, epoch: 1, total:7477 ACC: 0.7888190450715528


 92%|█████████▏| 52997/57706 [1:36:36<05:48, 13.49it/s]  

loss: 29731.466311063617
testing 1


100%|██████████| 250/250 [00:12<00:00, 19.44it/s]9it/s]


test, epoch: 1, total:7477 ACC: 0.7995185234719807


 93%|█████████▎| 53499/57706 [1:37:31<04:43, 14.84it/s]  

loss: 28629.71487067826
testing 1


100%|██████████| 250/250 [00:13<00:00, 18.34it/s]


test, epoch: 1, total:7477 ACC: 0.7991172930319647


 94%|█████████▎| 53999/57706 [1:38:26<04:25, 13.97it/s]  

loss: 29594.03818562627
testing 1


100%|██████████| 250/250 [00:15<00:00, 16.03it/s]7it/s]


test, epoch: 1, total:7477 ACC: 0.7953724755918149


 94%|█████████▍| 54498/57706 [1:39:23<04:26, 12.03it/s]  

loss: 28837.669435851276
testing 1


100%|██████████| 250/250 [00:12<00:00, 19.65it/s]
 94%|█████████▍| 54502/57706 [1:39:39<1:22:07,  1.54s/it]

test, epoch: 1, total:7477 ACC: 0.8062056974722482


 95%|█████████▌| 54999/57706 [1:40:20<03:35, 12.55it/s]  

loss: 29995.68530363217
testing 1


100%|██████████| 250/250 [00:13<00:00, 19.05it/s]


test, epoch: 1, total:7477 ACC: 0.7964424234318577


 96%|█████████▌| 55499/57706 [1:41:17<03:19, 11.07it/s]  

loss: 29116.868619788438
testing 1


100%|██████████| 250/250 [00:12<00:00, 19.62it/s]7it/s]
 96%|█████████▌| 55500/57706 [1:41:31<1:34:25,  2.57s/it]

test, epoch: 1, total:7477 ACC: 0.7967099103918684


 97%|█████████▋| 55999/57706 [1:42:11<01:46, 16.09it/s]  

loss: 28866.955785244703
testing 1


100%|██████████| 250/250 [00:14<00:00, 17.29it/s]


test, epoch: 1, total:7477 ACC: 0.8146315367125853


 98%|█████████▊| 56499/57706 [1:43:08<01:24, 14.29it/s]  

loss: 29541.835002522916
testing 1


100%|██████████| 250/250 [00:11<00:00, 21.42it/s]9it/s]
 98%|█████████▊| 56500/57706 [1:43:21<41:44,  2.08s/it]

test, epoch: 1, total:7477 ACC: 0.8098167714323927


 99%|█████████▉| 56998/57706 [1:44:00<00:48, 14.70it/s]

loss: 29635.474113680422
testing 1


100%|██████████| 250/250 [00:11<00:00, 22.00it/s]
 99%|█████████▉| 57001/57706 [1:44:13<16:11,  1.38s/it]

test, epoch: 1, total:7477 ACC: 0.8100842583924034


100%|█████████▉| 57499/57706 [1:44:51<00:15, 13.01it/s]

loss: 27889.344849720597
testing 1


100%|██████████| 250/250 [00:08<00:00, 28.66it/s]
100%|█████████▉| 57505/57706 [1:45:01<02:43,  1.23it/s]

test, epoch: 1, total:7477 ACC: 0.8090143105523606


100%|██████████| 57706/57706 [1:45:12<00:00,  9.14it/s]


total train loss: 12154.4075050354


 23%|██▎       | 499/2192 [00:39<02:17, 12.30it/s]

loss: 40078.94637584686
testing 1


100%|██████████| 250/250 [00:09<00:00, 25.71it/s]
 23%|██▎       | 501/2192 [00:50<48:57,  1.74s/it]

test, epoch: 1, total:7477 ACC: 0.8033970843921359


 46%|████▌     | 999/2192 [01:28<01:24, 14.09it/s]

loss: 40443.52976617962
testing 1


100%|██████████| 250/250 [00:10<00:00, 23.46it/s]


test, epoch: 1, total:7477 ACC: 0.8067406713922697


 68%|██████▊   | 1498/2192 [02:18<00:46, 14.91it/s]

loss: 38933.15093395114
testing 1


100%|██████████| 250/250 [00:07<00:00, 31.74it/s]


test, epoch: 1, total:7477 ACC: 0.8078106192323125


 91%|█████████ | 1998/2192 [03:08<00:16, 11.95it/s]

loss: 41104.44333682582
testing 1


100%|██████████| 250/250 [00:08<00:00, 30.81it/s]


test, epoch: 1, total:7477 ACC: 0.8009897017520395


100%|██████████| 2192/2192 [03:32<00:00, 10.33it/s]


total train loss: 15794.375205144286


 10%|▉         | 499/4993 [00:25<03:09, 23.71it/s]

loss: 24111.327615510672
testing 1


100%|██████████| 250/250 [00:09<00:00, 27.53it/s]
 10%|█         | 503/4993 [00:36<1:06:24,  1.13it/s]

test, epoch: 1, total:7477 ACC: 0.8201150193928046


 20%|██        | 999/4993 [00:59<03:45, 17.71it/s]  

loss: 22250.933524686843
testing 1


100%|██████████| 250/250 [00:09<00:00, 27.46it/s]
 20%|██        | 1003/4993 [01:09<1:05:58,  1.01it/s]

test, epoch: 1, total:7477 ACC: 0.8090143105523606


 30%|███       | 1499/4993 [01:34<02:28, 23.48it/s]  

loss: 22975.5532618016
testing 1


100%|██████████| 250/250 [00:09<00:00, 27.48it/s]
 30%|███       | 1505/4993 [01:45<41:39,  1.40it/s]

test, epoch: 1, total:7477 ACC: 0.8114216931924568


 40%|████      | 1999/4993 [02:08<02:53, 17.28it/s]

loss: 23137.5597955361
testing 1


100%|██████████| 250/250 [00:11<00:00, 21.35it/s]


test, epoch: 1, total:7477 ACC: 0.8146315367125853


 50%|█████     | 2498/4993 [02:45<02:00, 20.78it/s]  

loss: 23817.35313051939
testing 1


100%|██████████| 250/250 [00:09<00:00, 27.41it/s]


test, epoch: 1, total:7477 ACC: 0.8136953323525478


 60%|██████    | 2998/4993 [03:19<01:44, 19.16it/s]

loss: 21590.505381432362
testing 1


100%|██████████| 250/250 [00:10<00:00, 22.99it/s]


test, epoch: 1, total:7477 ACC: 0.815701484552628


 70%|███████   | 3497/4993 [03:55<01:23, 18.00it/s]

loss: 23973.96987787634
testing 1


100%|██████████| 250/250 [00:08<00:00, 27.86it/s]
 70%|███████   | 3503/4993 [04:06<19:57,  1.24it/s]

test, epoch: 1, total:7477 ACC: 0.8199812759127992


 80%|████████  | 3998/4993 [04:30<00:39, 25.14it/s]

loss: 22794.724494640715
testing 1


100%|██████████| 250/250 [00:08<00:00, 28.16it/s]


test, epoch: 1, total:7477 ACC: 0.8187775845927511


 90%|█████████ | 4497/4993 [05:04<00:25, 19.27it/s]

loss: 22915.38640972972
testing 1


100%|██████████| 250/250 [00:09<00:00, 26.61it/s]
 90%|█████████ | 4500/4993 [05:15<09:14,  1.12s/it]

test, epoch: 1, total:7477 ACC: 0.8223886585528956


100%|██████████| 4993/4993 [05:40<00:00, 14.64it/s]


total train loss: 21917.511179342866


 78%|███████▊  | 497/641 [00:33<00:08, 17.37it/s]

loss: 31936.211573064327
testing 1


100%|██████████| 250/250 [00:09<00:00, 27.68it/s]
 78%|███████▊  | 503/641 [00:44<02:04,  1.11it/s]

test, epoch: 1, total:7477 ACC: 0.8203825063528153


100%|██████████| 641/641 [00:53<00:00, 12.05it/s]


total train loss: 8679.027367591858
testing 1


100%|██████████| 250/250 [00:09<00:00, 27.55it/s]


test, epoch: 1, total:7477 ACC: 0.8100842583924034


In [42]:
importlib.reload(posutil)
target_langs = ['yor-x-bible-2010', 'tam-x-bible-newworld', 'fin-x-bible-helfi']
for lang in target_langs:
    posutil.generate_target_lang_tags(model, lang, f"new_posfeatFalse_transformerFalse_trainWEFalse", True, train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset,
     train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch, DataEncoder)
    

100%|██████████| 24078/24078 [11:31<00:00, 34.85it/s]
100%|██████████| 783/783 [00:31<00:00, 25.05it/s]
100%|██████████| 2225/2225 [00:52<00:00, 42.48it/s]
100%|██████████| 250/250 [00:08<00:00, 28.63it/s]
100%|██████████| 24078/24078 [11:03<00:00, 36.31it/s]
100%|██████████| 783/783 [00:32<00:00, 24.19it/s]
100%|██████████| 250/250 [00:07<00:00, 32.35it/s]
100%|██████████| 24078/24078 [12:48<00:00, 31.31it/s]
100%|██████████| 783/783 [00:37<00:00, 21.16it/s]
100%|██████████| 2225/2225 [01:03<00:00, 35.10it/s]
100%|██████████| 250/250 [00:09<00:00, 26.78it/s]


In [28]:
# model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatTrue_transformerTrue_trainWETrue_maskLangFalse_epoch3_20220218-101130.pickle')
# test(1, engtest_data_loader, mask_language=False) 

importlib.reload(posutil)
def get_language_node_cover(all_node_cover, target_edition, dataset):
    nodes_map = dataset.nodes_map
    res = collections.defaultdict(list)

    if target_edition in nodes_map:
        for verse in all_node_cover:
            if verse in nodes_map[target_edition]:
                lang_nodes = list(nodes_map[target_edition][verse].values())
                for tok in all_node_cover[verse]:
                    if tok in lang_nodes:
                        res[verse].append(tok)
    
    return res
    

def finetune_and_generate_for_target_lang(model_path, target_edition):
    global model, criterion, optimizer
    
    target_train_node_cover = get_language_node_cover(train_pos_node_cover_ext, target_edition, train_dataset)
    target_grc_node_cover = get_language_node_cover(grc_pos_node_cover_ext, target_edition, grc_test_dataset)
    target_heb_node_cover = get_language_node_cover(heb_pos_node_cover_ext, target_edition, heb_test_dataset)
    target_blinker_node_cover = get_language_node_cover(blinker_pos_node_cover_ext, target_edition, blinker_test_dataset)
    

    train_ds, grc_ds, heb_ds, blinker_ds = create_me_a_gnn_dataset_you_stupid(
            [target_train_node_cover, target_grc_node_cover, target_heb_node_cover, target_blinker_node_cover], 
            [train_pos_labels_ext, grc_pos_labels_ext,  heb_pos_labels_ext, blinker_pos_labels_ext], editions=[target_edition])
    
    train_data_loader = DataLoader(train_ds, shuffle=True)
    grc_data_loader = DataLoader(grc_ds, shuffle=True)
    heb_data_loader = DataLoader(heb_ds, shuffle=True)
    blinker_data_loader = DataLoader(blinker_ds, shuffle=True)

    # model = torch.load(model_path)
    # criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    # train(1, train_data_loader, mask_language=False, test_data_loader=blinker_data_loader, max_batches=100)
    # posutil.generate_target_lang_tags(model, target_edition, 'posfeatTrue_transformerTrue6layerresidual_trainWEFalse_finetune100', False, 
    #         train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
    #         DataEncoder)
    # model = None
    # clean_memory

    # model = torch.load(model_path)
    # criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    # train(1, train_data_loader, mask_language=False, test_data_loader=blinker_data_loader, max_batches=1000)
    posutil.generate_target_lang_tags(model, target_edition, 'posfeatTrue_transformerTrue6layerresidual_trainWEFalse_finetune1000', False, 
            train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
            DataEncoder)
    model = None
    clean_memory

    model = torch.load(model_path)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    train(1, train_data_loader, mask_language=False, test_data_loader=blinker_data_loader, max_batches=10000)
    posutil.generate_target_lang_tags(model, target_edition, 'posfeatTrue_transformerTrue6layerresidual_trainWEFalse_finetune10000', False, 
            train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
            DataEncoder)
    model = None
    clean_memory

    model = torch.load(model_path)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
    train(1, train_data_loader, mask_language=False, test_data_loader=blinker_data_loader)
    train(1, grc_data_loader, mask_language=False, test_data_loader=blinker_data_loader)
    train(1, heb_data_loader, mask_language=False, test_data_loader=blinker_data_loader)
    train(1, blinker_data_loader, mask_language=False, test_data_loader=blinker_data_loader)
    posutil.generate_target_lang_tags(model, target_edition, 'posfeatTrue_transformerTrue6layerresidual_trainWEFalse_finetuneALL', False, 
            train_dataset, grc_test_dataset, heb_test_dataset, blinker_test_dataset, train_data_loader_bigbatch, grc_data_loader_bigbatch, heb_data_loader_bigbatch, blinker_data_loader_bigbatch,
            DataEncoder)


def generate_target_lang_tags_all_models(target_langs):
    global model
    model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatFalse_transformerFalse_trainWEFalse_maskLangTrue_20220209-201345.pickle', map_location=torch.device('cpu')).to(dev)
    test(0, blinker_data_loader_bigbatch, True)
    for lang in target_langs:
        posutil.generate_target_lang_tags(lang,  f"posfeatFalse_transformerFalse_trainWEFalse", True)

# posutil.generate_target_lang_tags_all_models(['yor-x-bible-2010', 'tam-x-bible-newworld', 'fin-x-bible-helfi'])
finetune_and_generate_for_target_lang('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatTrue_transformerTru6Lresidual_trainWETrue_maskLangFalse_epoch3_20220218-101130.pickle', 'yor-x-bible-2010')

NameError: name 'train_pos_node_cover_ext' is not defined

In [None]:
#gnn_dataset_train_pos_ext, gnn_dataset_grc_pos_ext, gnn_dataset_heb_pos_ext, gnn_dataset_blinker_pos_ext = create_me_a_gnn_dataset_you_stupid([ train_pos_node_cover_ext, grc_pos_node_cover_ext, heb_pos_node_cover_ext, blinker_pos_node_cover_ext]
#    , [train_pos_labels_ext, grc_pos_labels_ext, heb_pos_labels_ext, blinker_pos_labels_ext])
#create_model(gnn_dataset_train_pos_ext, gnn_dataset_grc_pos_ext, gnn_dataset_heb_pos_ext, gnn_dataset_blinker_pos_ext, gnn_dataset_blinker_pos_bigbatch,
#    train_word_embedding=False, mask_language=False, use_transformers=True, tag_frequencies=True)

In [None]:
#i = sag
#batch = khar
#verse = gav
#print(i, verse)

#keys = list(gnn_dataset.verse_info.keys())

#gnn_dataset.verse_info[verse]
#save_model(model, 'freeze-embedding_noLang')

In [37]:
#initial_model = model.to('cpu')
#model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatTrue_transformerFalse_trainWEFalse_maskLangTrue_20220210-173912.pickle')
#torch.cuda.set_device(0)
#model.to(dev)
#epoch = 0
#test_data_loader = DataLoader(gnn_dataset_blinker_pos_bigbatch, batch_size=1, shuffle=False)
#test(epoch, test_data_loader, mask_language=True)
##yoruba_postags = generate_target_lang_tags('yor-x-bible-2010', f"posfeatFalse_transformerFalse_trainWEFalse", True)


testing 0


100%|██████████| 250/250 [00:16<00:00, 15.09it/s]


test, epoch: 0, total:52256 ACC: 0.1478107777097367


In [None]:
##importlib.reload(afeatures)
#initial_model = model
#torch.cuda.set_device(1)
#create_model(gnn_dataset_train_pos, gnn_dataset_grc_pos, gnn_dataset_heb_pos, gnn_dataset_blinker_pos, gnn_dataset_blinker_pos_bigbatch, tag_frequencies=True)
## yoruba_postags = generate_target_lang_tags('yor-x-bible-2010', f"posfeatTrue_transformerFalse_trainWEFalse", True)

In [None]:
torch.cuda.set_device(4)
gnn_ds_train_pos_ext, gnn_ds_grc_pos_ext, gnn_ds_heb_pos_ext, gnn_ds_blinker_pos_ext = create_me_a_gnn_dataset_you_stupid([train_pos_node_cover_ext, grc_pos_node_cover_ext, heb_pos_node_cover_ext, blinker_pos_node_cover_ext], 
    [train_pos_labels_ext, grc_pos_labels_ext, heb_pos_labels_ext, blinker_pos_labels_ext])

create_model(gnn_ds_train_pos_ext, gnn_ds_grc_pos_ext, gnn_ds_heb_pos_ext, gnn_ds_blinker_pos_ext, gnn_dataset_blinker_pos_bigbatch, tag_frequencies=True
    , use_transformers=False, train_word_embedding=True, mask_language=False)

# yoruba_postags = generate_target_lang_tags('yor-x-bible-2010', f"posfeatTrue_transformerTrue_trainWETrue", False)

len features 11
model params - decoder params - conv1 277911011 1592337

----------------epoch 1 ---------------


  1%|          | 498/58044 [00:55<1:17:38, 12.35it/s]

loss: 61841.90937706083
testing 1


100%|██████████| 250/250 [00:08<00:00, 28.45it/s]
  1%|          | 502/58044 [01:06<18:54:26,  1.18s/it]

test, epoch: 1, total:7477 ACC: 0.684900361107396


  2%|▏         | 998/58044 [01:48<1:15:17, 12.63it/s]

loss: 43130.26362814009
testing 1


100%|██████████| 250/250 [00:04<00:00, 55.62it/s]


test, epoch: 1, total:7477 ACC: 0.7349204226293968


  3%|▎         | 1498/58044 [02:39<1:19:49, 11.81it/s]

loss: 38608.35494910739
testing 1


100%|██████████| 250/250 [00:04<00:00, 53.10it/s]
  3%|▎         | 1501/58044 [02:45<13:45:12,  1.14it/s]

test, epoch: 1, total:7477 ACC: 0.7581917881503276


  3%|▎         | 1998/58044 [03:28<1:15:55, 12.30it/s]

loss: 35195.79320020508
testing 1


100%|██████████| 250/250 [00:04<00:00, 53.17it/s]
  3%|▎         | 2002/58044 [03:35<11:29:29,  1.35it/s]

test, epoch: 1, total:7477 ACC: 0.7442824662297713


  4%|▍         | 2369/58044 [04:06<1:36:23,  9.63it/s]


KeyboardInterrupt: 

In [None]:
torch.cuda.set_device(4)
model = torch.load('/mounts/work/ayyoob/models/gnn/checkpoint/postagging/pos_tagging_posfeatFalse_transformerFalse_trainWEFalse_maskLangTrue_20220209-201345.pickle', )
torch.cuda.set_device(0)
model.to(dev)
epoch = 0
test_data_loader = DataLoader(gnn_dataset_blinker_pos_onlyeng, batch_size=1, shuffle=False)
test(epoch, test_data_loader, mask_language=False)
yoruba_postags = generate_target_lang_tags('fin-x-bible-helfi', f"posfeatTrue_transformerTrue_trainWETrue", False)


  1%|          | 3/250 [00:00<00:10, 23.34it/s]

testing 0


100%|██████████| 250/250 [00:10<00:00, 23.99it/s]

test, epoch: 0, total:7477 ACC: 0.8324194195532968





In [None]:

#normalized_gold_frequencies, gold_frequencies_all = get_words_tag_frequence(model, 2354770, len(postag_map), english_data_loader, from_gold_data=True)

#gold_frequencies_all = gold_frequencies
word_frequencies = torch.sum(gold_frequencies_all, dim=1)

subjectword_indices =  word_frequencies > 0.1
print(word_frequencies.shape)
gold_frequencies = gold_frequencies_all[subjectword_indices, :]
predicted_frequencies = tag_frequencies_english[subjectword_indices, :]
wordtype_frequencies = word_frequencies[subjectword_indices]
print(gold_frequencies.shape)

_, gold_tags = torch.max(gold_frequencies, dim=1)
_, predicted_tags = torch.max(predicted_frequencies, dim=1)

sorted_wordtype_frequencies, sort_pattern = torch.sort(wordtype_frequencies, descending=True)

sorted_gold_tags = gold_tags[sort_pattern]
sorted_predicted_tags = predicted_tags[sort_pattern]
quarter_size = int(sorted_gold_tags.size(0)/2.0)

print('quarter size', quarter_size)
print("general accuracy", torch.sum(gold_tags == predicted_tags)/predicted_tags.size(0))
print('first quarter accuracy', torch.sum(sorted_gold_tags[:quarter_size] == sorted_predicted_tags[:quarter_size])/quarter_size)
print('last part accuracy', torch.sum(sorted_gold_tags[1*quarter_size:] == sorted_predicted_tags[1*quarter_size:])/sorted_predicted_tags[1*quarter_size:].size(0))

print('total token count', torch.sum(wordtype_frequencies))
print('first quarter words token count', torch.sum(word_frequencies[:quarter_size]))

print('1st frequency', sorted_wordtype_frequencies[0])
print('10st frequency', sorted_wordtype_frequencies[10])
print('100st frequency', sorted_wordtype_frequencies[100])
print('100st frequency', sorted_wordtype_frequencies[736])
print('1000st frequency', sorted_wordtype_frequencies[1000])
print('10000st frequency', sorted_wordtype_frequencies[10000])


NameError: name 'gold_frequencies_all' is not defined

In [None]:
#normalized_tag_frequencies = torch.softmax(tag_frequencies_copy, dim=1)

sm = torch.sum(tag_frequencies_copy, dim=1)
normalized_tag_frequencies = (tag_frequencies_copy.transpose(1,0) / sm).transpose(1,0)

In [None]:
global model, decoder
#1/0

decoder = None
model = None

gc.collect()
with torch.no_grad():
    torch.cuda.empty_cache()

In [None]:

#features = blinker_test_dataset.features[:]
#features_edge = train_dataset.features_edge[:]
from pprint import pprint
#print('indim',in_dim)
#features[-1].out_dim = 50
for i in features:
    #if i.type==3:
    #    i.out_dim=4
    print(vars(i))

#sum(p.out_dim for p in features)
#train_dataset.features.pop()
#train_dataset.features[0] = afeatures.OneHotFeature(20, 83, 'editf')
#train_dataset.features[1] = afeatures.OneHotFeature(32, 150, 'position')
#train_dataset.features[2] = afeatures.FloatFeature(4, 'degree_centrality')
#train_dataset.features[3] = afeatures.FloatFeature(4, 'closeness_centrality')
#train_dataset.features[4] = afeatures.FloatFeature(4, 'betweenness_centrality')
#train_dataset.features[5] = afeatures.FloatFeature(4, 'load_centrality')
#train_dataset.features[6] = afeatures.FloatFeature(4, 'harmonic_centrality')
#train_dataset.features[7] = afeatures.OneHotFeature(32, 250, 'greedy_modularity_community')
##train_dataset.features.append(afeatures.MappingFeature(100, 'word'))
#torch.save(train_dataset, "/mounts/work/ayyoob/models/gnn/dataset_helfi_train_community_word.pickle")
#torch.save(train_dataset.features[-3], "./features.tmp")

{'type': 1, 'out_dim': 20, 'global_normalize': False, 'name': 'editf', 'Active': True, 'n_classes': 83}
{'type': 1, 'out_dim': 32, 'global_normalize': False, 'name': 'position', 'Active': True, 'n_classes': 150}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'degree_centrality', 'Active': True}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'closeness_centrality', 'Active': True}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'betweenness_centrality', 'Active': True}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'load_centrality', 'Active': True}
{'type': 3, 'out_dim': 4, 'global_normalize': False, 'name': 'harmonic_centrality', 'Active': True}
{'type': 1, 'out_dim': 32, 'global_normalize': False, 'name': 'greedy_modularity_community', 'Active': True, 'n_classes': 250}
{'type': 1, 'out_dim': 32, 'global_normalize': False, 'name': 'label_propagation_community', 'Active': True, 'n_classes': 250}
{'type': 6, 'out_dim': 100, 'global_

In [None]:
nodes_map = train_dataset.nodes_map
bad_edition_files = []
for edit in nodes_map:
    bad_count = 0
    for verse in nodes_map[edit]:
        if len(nodes_map[edit][verse].keys()) < 2:
            bad_count += 1
        if bad_count > 1:
            bad_edition_files.append(edit)
            break
print(bad_edition_files)

In [None]:
all_japanese_nodes = set()
nodes_map = train_dataset.nodes_map

for bad_editionf in bad_edition_files:
    for verse in nodes_map[bad_editionf]:
        for item in nodes_map[bad_editionf][verse].items():
            all_japanese_nodes.add(item[1])

print(" all japansese nodes: ", len(all_japanese_nodes))
edge_index = train_dataset.edge_index.to('cpu')
remaining_edges_index = []
for i in tqdm(range(0, edge_index.shape[1], 2)):
    if edge_index[0, i].item() not in all_japanese_nodes and edge_index[0, i+1].item() not in all_japanese_nodes:
        remaining_edges_index.extend([i, i+1])

print('original total edges count', edge_index.shape)
print('remaining edge count', len(remaining_edges_index))
train_dataset.edge_index = edge_index[:, remaining_edges_index]
train_dataset.edge_index.shape
