In [8]:
%set_env PYTHONPATH="/home/dstambl2/doc_alignment_implementations/thompson_2021_doc_align"

env: PYTHONPATH="/home/dstambl2/doc_alignment_implementations/thompson_2021_doc_align"


In [9]:
## Standard Library
import os
import json
import time
import re
import json
import numpy as np
from collections import defaultdict, namedtuple
from random import shuffle, randint, sample, uniform, choice

## External Libraries
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from sklearn.model_selection import train_test_split
import torch.nn.functional as functional
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

In [10]:
#DEFINE Training Hyperparams And Constants

## Batch Size
TRAIN_BATCH_SIZE = 32
VAL_BATCH_SIZE = 32

## Learning Rate
LR = 0.001

# Epochs (Consider setting high and implementing early stopping)
NUM_EPOCHS = 100

GPU_BOOL = torch.cuda.is_available()
GPU_BOOL

BASE_DATA_PATH="/home/dstambl2/doc_alignment_implementations/data"
BASE_EMBED_DIR = '/home/dstambl2/doc_alignment_implementations/data/cc_aligned_si_data/embeddings'

BASE_PROCESSED_PATH="/home/dstambl2/doc_alignment_implementations/data/cc_aligned_si_data/processed" 
ALIGNED_PAIRS_DOC = '/home/dstambl2/doc_alignment_implementations/data/cc_aligned_en_si.pairs'

SRC_LANG_CODE="en"
TGT_LANG_CODE="si"


In [11]:
#Split into train, valid and test sets

from utils.common import load_extracted, map_dic2list, \
    filter_empty_docs, regex_extractor_helper, tokenize_doc_to_sentence
from modules.get_embeddings import read_in_embeddings, load_embeddings
from modules.build_document_vector import build_document_vector
from modules.vector_modules.boiler_plate_weighting import LIDFDownWeighting
from align_docs import fit_pca_reducer
from modules.vector_modules.window_func import ModifiedPertV2
from utils.lru_cache import LRUCache
from utils.common import function_timer

CHUNK_RE = re.compile(r'(chunk_\d*(?:_\d*)?)', flags=re.IGNORECASE)
BASE_DOMAIN_RE = re.compile(r'https?\://(?:w{3}\.)?(?:(?:si|en|sinhala|english)\.)?(.*?)/', flags=re.IGNORECASE)

#from sklearn.model_selection import train_test_split
embed_chunk_paths = []
for subdir, dirs, files in os.walk(BASE_EMBED_DIR):
    if subdir != BASE_EMBED_DIR:
        embed_chunk_paths.append(subdir)
        
embed_chunk_paths = sorted(embed_chunk_paths, key= lambda x: len(x))

#For now split into 0.7/0.1/0.2 split
#Split into train, test and val sets, remove last one for split since it will go into train set
train_ind, test_ind = train_test_split(list(range(len(embed_chunk_paths[:-1]))), test_size=0.2, random_state=1)
train_ind, val_ind = train_test_split(train_ind, test_size=0.125, random_state=1) # 0.125 x 0.8 = 0.1

train_ind.append(len(embed_chunk_paths) -1) #Add last imbalenced idx to train set

assert not any(set(train_ind) & set(test_ind)) and \
       not any(set(train_ind) & set(val_ind))and \
       not any(set(test_ind) & set(val_ind))

print(len(train_ind), len(val_ind), len(test_ind))

chunks_paths_train = [embed_chunk_paths[t] for t in train_ind]
chunks_paths_val = [embed_chunk_paths[t] for t in val_ind]
chunks_paths_test = [embed_chunk_paths[t] for t in test_ind]

#print(chunks_paths_train)

69 10 20


In [11]:
assert not bool(set(chunks_paths_test) & set(chunks_paths_val))
assert not bool(set(chunks_paths_train) & set(chunks_paths_val))
assert not bool(set(chunks_paths_train) & set(chunks_paths_test))

In [16]:

''' The following two funcs are only for getting sent embeds '''
def get_base_embedding(url, embedding_file_path, lang_code):

    chunk = regex_extractor_helper(CHUNK_RE, embedding_file_path)
    doc_path = '%s/%s.%s.gz' % (BASE_PROCESSED_PATH, chunk, lang_code)
    url_doc_dict = filter_empty_docs(load_extracted(doc_path))
    if url not in url_doc_dict:
        print("missing url in get_base_embedding %s for doc path %s" % (url, doc_path))
        #return noise
        return np.random.random((1,1024))
    doc_text = url_doc_dict[url]
        
    _, sent_embeds = read_in_embeddings(doc_text, embedding_file_path, lang_code)
    
    return sent_embeds


def load_embed_pairs(src_url, tgt_url, embed_dict,
                    src_lang=SRC_LANG_CODE,
                    tgt_lang=TGT_LANG_CODE):
    src_path = embed_dict[src_url]
    tgt_path = embed_dict[tgt_url]
    line_embeddings_src = get_base_embedding(src_url, src_path, src_lang)
    line_embeddings_tgt = get_base_embedding(tgt_url, tgt_path, tgt_lang)
    
    #print(line_embeddings_src.shape, line_embeddings_tgt.shape)
    return line_embeddings_src, line_embeddings_tgt
''' END OF PURE SENT EMBED FUNCS '''

def get_matching_url_dicts(input_path = ALIGNED_PAIRS_DOC):
    src_to_tgt = {}
    tgt_to_src = {}
    with open(input_path, 'r') as fp:
        for row in fp:
            src, tgt = row.split('\t')
            src_to_tgt[src] = tgt
            tgt_to_src[tgt]  = src
    return src_to_tgt, tgt_to_src


def load_embed_dict(chunks_paths):
    embed_dict = {}
    
    for chunk_path in chunks_paths:
        embed_dict_path = '%s/embedding_lookup.json' % (chunk_path)
        embed_dict_chunk = {}
        with open(embed_dict_path, 'r') as f:
            embed_dict_chunk = json.load(f)
        embed_dict.update(embed_dict_chunk)
                
    return embed_dict

'''Helper function for finding the most prominent domains to see inner doc sim'''
def filter_for_one_domain(embed_dict, src_to_tgt):
    dom_counter = defaultdict(int)
    
    for k, _ in src_to_tgt.items():
        x = BASE_DOMAIN_RE.findall(k)[0]
        #print(x)
        dom_counter[x] += 1
    return list(reversed(sorted([(k,v) for k, v in dom_counter.items()], key=lambda x: x[1])))[:10]

print(filter_for_one_domain(load_embed_dict(chunks_paths_train), get_matching_url_dicts()[0] ))

TypeError: 'list_reverseiterator' object is not subscriptable

### Building Sample Logic

In [13]:

CandidateTuple = namedtuple(
    "CandidateTuple", "src_embed_path, tgt_embed_path, src_url, tgt_url, y_match_label, augment_neg_data_method")


''' Sampling logic start'''
@function_timer
def create_positive_samples(embed_dict, src_to_tgt_map, tgt_to_src_map, data_list):
    '''
    Builds positive samples
    '''
    for src_url, tgt_url in src_to_tgt_map.items():
        src_url, tgt_url = src_url.strip(), tgt_url.strip()
        if src_url in embed_dict and tgt_url in embed_dict:
            src_embed_path, tgt_embed_path = embed_dict[src_url], embed_dict[tgt_url]
            c = CandidateTuple(src_embed_path, tgt_embed_path, src_url, tgt_url, 1, None)
            data_list.append(c)


def create_all_possible_neg_pairs(src_to_tgt_map, tgt_to_src_map):
    
    src_url_list, tgt_url_list = list(src_to_tgt_map.keys()), list(tgt_to_src_map.keys())
    domain_dict = defaultdict(lambda: defaultdict(list)) #Ex{dom: {src: [], tgt: []}}
    for url in tgt_url_list:
        #url = url.strip()
        base_url = regex_extractor_helper(BASE_DOMAIN_RE, url).strip()
        domain_dict[base_url]['tgt'].append(url)
    
    for url in src_url_list:
        #url = url.strip()
        base_url = regex_extractor_helper(BASE_DOMAIN_RE, url).strip()
        domain_dict[base_url]['src'].append(url)
    
    negative_sample_dict = {}
    #Loop through all domains and create final negative pairing
    for domain, values in domain_dict.items():
        sample_list = []
        src_urls, tgt_urls = values['src'], values['tgt']
        
        if len(src_urls) > 100:
            shuffle(src_urls)
        if len(tgt_urls) > 100:
            shuffle(tgt_urls)
        for src_url in src_urls[:min(100, len(src_urls))]:
            for tgt_url in tgt_urls[:min(100, len(src_urls))]:
                if src_to_tgt_map[src_url] != tgt_url and tgt_to_src_map[tgt_url] != src_url:
                    sample_list.append((src_url, tgt_url))
        if len(sample_list) > 0:
            negative_sample_dict[domain] = sample_list
    return negative_sample_dict
        
    

def same_domain_neg_sample_helper(negative_sample_dict):
    '''
    IDEA: randomly modify docs  
    Helper function for returning
    Negative domains of same idx

    Algo: 1) Pick random domain
         2) Pick random sample from that domain
         3) pop that sample
    '''
    domain_list = list(negative_sample_dict.keys())
    domain = domain_list[randint(0, len(domain_list) - 1)]
    
    neg_pair_list = negative_sample_dict[domain]
    neg_pair_idx = randint(0, len(neg_pair_list) - 1)
    src_url, tgt_url = neg_pair_list[neg_pair_idx]
    
    #Update the dict to remove neg pair
    neg_pair_list.pop(neg_pair_idx)
    if len(neg_pair_list) == 0:
        negative_sample_dict.pop(domain)
    else:
        negative_sample_dict[domain] = neg_pair_list

    return src_url, tgt_url


def create_negative_samples_diff_docs(embed_dict, src_to_tgt_map, tgt_to_src_map, data_list, precent_cutoff):
    '''
    Builds negative samples
    '''

    number_pos_samples = len(data_list)
    visited_urls, MAX_INSTANCES_ALLOWED = defaultdict(int), 10 
    
    negative_sample_dict = create_all_possible_neg_pairs(src_to_tgt_map, tgt_to_src_map)    
    
    
    print("URL LIST LENGTHS", number_pos_samples, len(embed_dict), len(src_to_tgt_map), len(tgt_to_src_map))
    num_samples = int(number_pos_samples * precent_cutoff)
    
    loop_counter = 0 #NOTE: FOR DEBUG
    for i in range(num_samples):
        
        while True:
            src_url, tgt_url = same_domain_neg_sample_helper(negative_sample_dict)

            #Repeat until all conditions are met
            if (visited_urls[src_url.strip()] < MAX_INSTANCES_ALLOWED and \
                visited_urls[tgt_url.strip()] < MAX_INSTANCES_ALLOWED and \
                src_to_tgt_map[src_url] != tgt_url and tgt_to_src_map[tgt_url] != src_url):
                
                src_url, tgt_url = src_url.strip(), tgt_url.strip()
                if src_url in embed_dict and tgt_url in embed_dict:
                    src_embed_path, tgt_embed_path = embed_dict[src_url], embed_dict[tgt_url]
                    c = CandidateTuple(src_embed_path, tgt_embed_path, src_url, tgt_url, 0, None)
                    data_list.append(c)
                
                    visited_urls[src_url] += 1
                    visited_urls[tgt_url] += 1
                    
                break
            else:
                loop_counter += 1

    print("LOOP COUNTER LEN %s" % loop_counter)
    print(len(data_list))

def create_negative_samples_from_aligned_pairs(embed_dict, src_to_tgt_map, tgt_to_src_map, data_list, precent_cutoff):
    '''
    From positive samples, create negative ones, to increase sensitivity of classifier
    By dropping random handful of sents
    In general, the idx method will deal with sent dropping/ adding
    '''
    pos_list = [(src_url, tgt_url) for src_url, tgt_url in src_to_tgt_map.items()]
    
    sample_size = int(len(pos_list) * precent_cutoff)
    pos_list = sample(pos_list, sample_size)

    #Sample half of list and alternate which methods should be used
    method_idx, methods = 0, ["add", "delete", "sub"]
    
    for src_url, tgt_url in pos_list:
        src_url, tgt_url = src_url.strip(), tgt_url.strip()
        if src_url in embed_dict and tgt_url in embed_dict:
            src_embed_path, tgt_embed_path = embed_dict[src_url], embed_dict[tgt_url]
            c = CandidateTuple(src_embed_path, tgt_embed_path, src_url, tgt_url, 0, methods[method_idx]) 
            method_idx  = (method_idx + 1) % len(methods)
            data_list.append(c)

        
#TODO: Uncomment when done with debug
def build_pair_dataset(embed_dict, src_to_tgt_map, tgt_to_src_map):
    '''
    Create positive and negative url tuple pairs
    Makeup of dataset will be:
        50% pure positive pairs
        25% pure negative pairs
        25% negative pairs that are close to each other
    '''
    data_list = []
    create_positive_samples(embed_dict, src_to_tgt_map, tgt_to_src_map, data_list) #First Create positive samples
    print(len(data_list))
    
    subset_src_to_tgt_map = {k: v for k,v in src_to_tgt_map.items() if k.strip() in embed_dict and v.strip() in embed_dict}
    subset_tgt_to_src_map = {k: v for k,v in tgt_to_src_map.items() if k.strip() in embed_dict and v.strip() in embed_dict}
    
    diff_neg_precent, close_neg_percent = 0.5, 0.5
    create_negative_samples_diff_docs(embed_dict, subset_src_to_tgt_map, subset_tgt_to_src_map, data_list, diff_neg_precent)  #Now create negative samples
    create_negative_samples_from_aligned_pairs(embed_dict, subset_src_to_tgt_map, subset_tgt_to_src_map, data_list, close_neg_percent)  #create negative samples by modding positive ones
    #Shuffle and return data
    shuffle(data_list)
    print(len(data_list))

    return data_list

en_to_si_pairs, si_to_en_pairs = get_matching_url_dicts()    

In [14]:
#TEST ABOVE CODE
embed_dict_train = load_embed_dict(chunks_paths_train)
data_list = build_pair_dataset(embed_dict_train, en_to_si_pairs, si_to_en_pairs)
sample_one = list(embed_dict_train.keys())[10]
sample_two = list(embed_dict_train.keys())[57]
print(sample_one, sample_two)
#load_embed_pairs(sample_one, sample_two, embed_dict_train)

Function create_positive_samples took 0.25 seconds
44385
URL LIST LENGTHS 44385 87552 44385 43034
LOOP COUNTER LEN 14
66577
88769
https://www.jsjlmachinery.com/about-us/contact-us/ http://centers.cultural.gov.lk/batticaloa/index.php?option=com_content&view=frontpage&Itemid=65&lid=mm&mid=1&lang=en


### Helper logic (PCA/Cache) for Doc Embeds

In [15]:
import math
from sklearn.decomposition import PCA
def fit_pca_reducer_debug(embedding_list_src, embedding_list_tgt):
    '''
    Builds PCA Dim Reduction from sample of sentence embeddings
    in the webdomain
    '''
    all_sent_embeds = np.vstack(embedding_list_src + embedding_list_tgt)

    pca = PCA(n_components=128)
    divide_num = 1
    if len(all_sent_embeds) // 6 >= 128:
        divide_num = 6
    elif len(all_sent_embeds) // 5 >= 128:
        divide_num = 5
    elif len(all_sent_embeds) // 4 >= 128:
        divide_num = 4
    elif len(all_sent_embeds) // 3 >= 128:
        divide_num = 3
    elif len(all_sent_embeds) // 2 >= 128:
        divide_num = 2
    elif len(all_sent_embeds) // 1 >= 128:
        divide_num = 1
    else:
        sent_size = all_sent_embeds.shape[0]
        num_iters = int(math.ceil(128 / sent_size))        
        all_sent_embeds = np.repeat(all_sent_embeds, repeats=num_iters, axis=0)
        

    my_rand_int = np.random.randint(all_sent_embeds.shape[0], size=len(all_sent_embeds) // divide_num)
    pca_fit_data = all_sent_embeds[my_rand_int, :]
    pca.fit(pca_fit_data)
    return pca

#DEFINE Helper functions for building doc vectors
#NOTE: Much of this info will be stored in an LRU cache


class CachedData:
    '''
    Keeps organized cache of data
    Keys will be domain_name
    Since for each domain, we want the source and target lang info
    '''
    def __init__(self, src_text_list_tokenized,
                       src_embed_list,
                       src_url_list,
                       tgt_text_list_tokenized,
                       tgt_embed_list,
                       tgt_url_list,
                       ):
        self.src_text_list_tokenized = src_text_list_tokenized
        self.src_embed_list = src_embed_list
        self.src_url_list = src_url_list

        self.tgt_text_list_tokenized = tgt_text_list_tokenized
        self.tgt_embed_list = tgt_embed_list
        self.tgt_url_list = tgt_url_list
        
        self.lidf_weighter = LIDFDownWeighting(src_text_list_tokenized + tgt_text_list_tokenized)
        self.pca = fit_pca_reducer_debug(src_embed_list, tgt_embed_list)
    
    def get_fitted_objects(self):
        '''
        Return PCA and LIDF
        '''
        return self.pca, self.lidf_weighter

    def get_src_data(self):
        return self.src_text_list_tokenized, self.src_embed_list, self.src_url_list
    
    def get_tgt_data(self):
        return self.tgt_text_list_tokenized, self.tgt_embed_list, self.tgt_url_list

In [16]:
def create_augment_helper(src_text_list_tokenized, tgt_text_list_tokenized, augment_negative_data_method):
    '''
    Goal is to pick which to augment (src or tgt) and to confirm that adding/subing is possible
    This logic is here, because when we add or sub, we are assuming that other docs exist
    1) Randomly pick if augment_negative_data_method is delete
    2) If add/sub, pick rand or larger of two
    '''
    rv = uniform(0, 1)
    random_change_item = "src" if rv >= 0.5 else "tgt"
    
    if augment_negative_data_method in ["add", "sub"]:
        if len(tgt_text_list_tokenized) == 1 and len(src_text_list_tokenized) == 1:
            augment_negative_data_method = "delete"
            item_to_change = random_change_item
        elif len(tgt_text_list_tokenized) == 1 and len(src_text_list_tokenized) > 1:
            item_to_change = "src"
        elif len(tgt_text_list_tokenized) > 1 and len(src_text_list_tokenized) == 1:
            item_to_change = "tgt"
        else:
            item_to_change = random_change_item
    
    else:
        item_to_change = random_change_item
    
    return augment_negative_data_method, item_to_change


#Cell for dealing with augmentation
def create_augmented_negative_sample(src_url, src_text_list_tokenized, src_url_list, src_embed_list,
                                    tgt_url, tgt_text_list_tokenized, tgt_url_list, tgt_embed_list,
                                    augment_negative_data_method
                                    ):
    '''
    Builds negative sample from either
    1 dropping a few sentences
    2. Substituting sents from another doc
    3. Adding sentences from other docs
    
    Algo
    1) Randomly pick if modifying source of target first
        If add/sub, pick domains with more data
    2) If both src and tgt have only 1 doc, default to delete method
    3) For deletion, pick random num of rand indexes of modified item
    4) For sub, delete a random index and sub out rand sents with another doc
    5) For add, just tack on rand sents

    '''
    METHODS = ["add", "delete", "sub"]
    if augment_negative_data_method not in METHODS:
        print("method not in list")
        return
    
    
    def helper_delete_multiple_element(list_object, indices):
        '''
        Helper function taken from
        https://thispointer.com/python-remove-elements-from-list-by-index/
        '''
        indices = sorted(indices, reverse=True)
        for idx in indices:
            if idx < len(list_object):
                list_object.pop(idx)
    
    def delete_logic(url, change_text_list_tokenized, change_url_list, change_embed_list):
        '''
        Handles deletion logic
        '''
        #First get data that will be augmented
        i = change_url_list.index(url)
        change_doc = change_text_list_tokenized[i]
        
        num_sents = change_embed_list[i].shape[0]
        
        #Note call add logic in this case
        if num_sents == 1:
            add_logic(url, change_text_list_tokenized, change_url_list, change_embed_list)
            return
        change_sent_embeds = change_embed_list[i].tolist()
        
        #Randomly pick how many and which sents to drop, then drop them
        num_sents_to_drop = randint(1, max(int(math.ceil((len(change_doc) - 1) / 5)), 1))
        
        #Make sure that there is always going to be one sentence left
        num_sents_to_drop = min(num_sents_to_drop, num_sents - 1)
        
        index_values = sample(list(enumerate(change_doc)), num_sents_to_drop)
        drop_index_vals = [x[0] for x in index_values]
        
        helper_delete_multiple_element(change_doc, drop_index_vals)
        helper_delete_multiple_element(change_sent_embeds, drop_index_vals)
        
        assert len(change_doc) == len(change_sent_embeds)
        
        #Update the altered doc 
        change_text_list_tokenized[i] = change_doc
        change_embed_list[i] = np.asarray(change_sent_embeds)
        
    def sub_logic(url, change_text_list_tokenized, change_url_list, change_embed_list):
        '''
        Handles substitution logic
        '''
        #First get data that will be augmented
        i = change_url_list.index(url)
        change_doc = change_text_list_tokenized[i]
        change_sent_embeds = change_embed_list[i]
        
        #Randomly pick how many sents to sub
        num_sents_to_sub = randint(1, max(int(math.ceil((len(change_doc) - 1) / 3)), 2))
        #Randomly pick which doc
        sub_idx = choice([j for j in range(len(change_text_list_tokenized)) if j!=i])
        
        #Randomly pick sents to sub with
        sub_doc, sub_embed_matrix = change_text_list_tokenized[sub_idx], change_embed_list[sub_idx]
        
        cut_off_point = min(len(change_doc), len(sub_doc))
        num_sents_to_sub = min(num_sents_to_sub, len(sub_doc[:cut_off_point]))
        index_values = sample(list(enumerate(sub_doc[:cut_off_point])), num_sents_to_sub)
        sub_index_vals = [x[0] for x in index_values]
        
        for idx in sub_index_vals:
            change_doc[idx] = sub_doc[idx]
            change_sent_embeds[idx] = sub_embed_matrix[idx]
        
        assert len(change_doc) == len(change_sent_embeds)
        #Update the altered doc
        change_text_list_tokenized[i] = change_doc
        change_embed_list[i] = change_sent_embeds

    
    def add_logic(url, change_text_list_tokenized, change_url_list, change_embed_list):
        '''
        Handles addition logic
        '''
        #First get data that will be augmented
        i = change_url_list.index(url)
        change_doc = change_text_list_tokenized[i]
        change_sent_embeds = change_embed_list[i]
        
        #Randomly pick how many sents to adds
        num_sents_to_add = randint(1, max(int(math.ceil((len(change_doc) - 1) / 5)), 1))
        #Randomly pick which doc
        add_idx = choice([j for j in range(len(change_text_list_tokenized)) if j!=i])
        
        #Randomly pick sents to add
        add_doc, add_embed_matrix = change_text_list_tokenized[add_idx], change_embed_list[add_idx]
        
        index_values = sample(list(enumerate(add_doc)), num_sents_to_add)
        add_index_vals = [x[0] for x in index_values]
        add_doc_sents = [add_doc[idx] for idx in add_index_vals]
        add_embed_sents = [add_embed_matrix[idx] for idx in add_index_vals]
        
        #Add sents
        change_doc += add_doc_sents
        change_sent_embeds = np.append(change_sent_embeds, np.asarray(add_embed_sents), axis=0)
        
        assert len(change_doc) == len(change_sent_embeds)
        
        #Update the altered doc
        change_text_list_tokenized[i] = change_doc
        change_embed_list[i] = change_sent_embeds
    
    
    augment_negative_data_method, item_to_change = create_augment_helper(src_text_list_tokenized,
                                                                         tgt_text_list_tokenized,
                                                                         augment_negative_data_method)
    if item_to_change == "src":
        url, change_text_list_tokenized, change_url_list, change_embed_list = src_url, src_text_list_tokenized, src_url_list, src_embed_list
    else:
        url, change_text_list_tokenized, change_url_list, change_embed_list = tgt_url, tgt_text_list_tokenized, tgt_url_list, tgt_embed_list
    
    if augment_negative_data_method == "add":
        #Add from one doc to another
        add_logic(url, change_text_list_tokenized, change_url_list, change_embed_list)
    elif augment_negative_data_method == "delete":
        delete_logic(url, change_text_list_tokenized, change_url_list, change_embed_list)
        #delete random number of sentences from source and target
    elif augment_negative_data_method == "sub":
        sub_logic(url, change_text_list_tokenized, change_url_list, change_embed_list)
        #swap sentences 


### Domain Specific Doc Embedding Logic

In [17]:

def load_embeds_for_domain(embed_dict, lang_code, text_list, url_list):
    '''
    Load in embeds for a domain
    ''' 
    embed_list = []
    try:
        for ii in range(len(text_list)):
            url, doc_text = url_list[ii], text_list[ii]
            #if url in embed_dict: #TEMP FIX: TODO: Rerun embeds and issue of missing URLS here or delete sentences that miss embeds
            embed_file_path = embed_dict[url]
            try:
                _, embeddings = read_in_embeddings(doc_text, embed_file_path, lang_code)
                embed_list.append(embeddings)
            except Exception as e:
                print(embed_file_path, lang_code, "EMBED EXCEPTION", e)

    except KeyError as e: #For debugging
        print(e)
        print("EXCEPTION OCCURED in load_embeds_for_domain")
    
    return embed_list


''' Domain Specific Chunk helper data'''

def get_all_chunks_with_domain(embed_dict, base_domain):
    '''
    First get a list of all chunks that contain the domain_name
    Second get all docs in src lang and tgt lang
    Third, get pca, ldf weighter and more
    '''
    chunks = [regex_extractor_helper(CHUNK_RE, value) \
                for key, value in embed_dict.items() if base_domain.strip() in key.strip().lower()]
    if len(chunks) == 0:
        print("WAIT, CHUNK LIST IS EMPTY, so CHUNK_RE error")
    return list(set(chunks))

def load_all_chunks_for_domain(chunk_list, base_domain, lang_code):
    '''
    Given a list of domain chunks and the domain_name
    Get a domain doc dict
    '''
    domain_doc_dict = {}
    for chunk in chunk_list:
        doc_path = '%s/%s.%s.gz' % (BASE_PROCESSED_PATH, chunk, lang_code)
        url_doc_dict = filter_empty_docs(load_extracted(doc_path))
        match_doc_dict = {}
        for k, v in url_doc_dict.items():
            if base_domain.strip().lower() in k.strip().lower():
                match_doc_dict[k] = v
        domain_doc_dict.update(match_doc_dict)
    return domain_doc_dict


def get_all_relevant_domain_data(chunk_list, base_domain, lang_code, embed_dict):
    '''
    return text_list_tokenized, embed_list, url list
    '''
    domain_doc_dict = load_all_chunks_for_domain(chunk_list, base_domain, lang_code) 
    obj_domain = map_dic2list(domain_doc_dict)
    
    text_list = obj_domain['text']
    text_list_tokenized = [tokenize_doc_to_sentence(doc, lang_code) for doc in text_list]
    
    url_list = [url.strip() for url in obj_domain['mapping']]
    embed_list = load_embeds_for_domain(embed_dict, lang_code,
                                        text_list, url_list)
    
    return text_list_tokenized, embed_list, url_list



def get_doc_embedding(url, text_list_tokenized,
                      url_list,
                      embedding_list,
                      lidf_weighter,
                      pca,
                      pert_obj,
                      doc_vec_method):
    '''
    Call doc embedding method
    '''
    i = url_list.index(url)
    
    return build_document_vector(text_list_tokenized[i],
                        url_list[i],
                        embedding_list[i],
                        lidf_weighter,
                        pca,
                        pert_obj,
                        doc_vec_method=doc_vec_method).doc_vector


def handle_doc_embed_logic(embed_dict, src_url, tgt_url,
                         src_lang_code, tgt_lang_code, doc_vector_method,
                         pert_obj, 
                         lru_cache,
                         augment_negative_data_method):
    
   
    base_domain_src = regex_extractor_helper(BASE_DOMAIN_RE, src_url).strip()
    base_domain_tgt = regex_extractor_helper(BASE_DOMAIN_RE, tgt_url).strip()
    
    #NOTE: NOT ALL Pairs share same base domain, ex: www.buyaas.com, si.buyaas.com, so url regex was adjusted     
    if base_domain_src != base_domain_tgt:
        print(base_domain_src, base_domain_tgt, "DIFFERENT DOMAINS")
    assert base_domain_src == base_domain_tgt  
    
    
    chunk_list_src = get_all_chunks_with_domain(embed_dict, base_domain_src) 
    chunk_list_tgt = get_all_chunks_with_domain(embed_dict, base_domain_src)
    
    cd = lru_cache.get(base_domain_src)
    if cd == -1:
        src_text_list_tokenized, src_embed_list, src_url_list = \
            get_all_relevant_domain_data(chunk_list_src, base_domain_src, src_lang_code, embed_dict)
        
        tgt_text_list_tokenized, tgt_embed_list, tgt_url_list = \
            get_all_relevant_domain_data(chunk_list_tgt, base_domain_tgt, tgt_lang_code, embed_dict)
        cd = CachedData(src_text_list_tokenized, src_embed_list, src_url_list, 
                        tgt_text_list_tokenized, tgt_embed_list, tgt_url_list)
        lru_cache.put(base_domain_src, cd)
    else:
        print("cache hit")
    src_text_list_tokenized, src_embed_list, src_url_list = cd.get_src_data()
    tgt_text_list_tokenized, tgt_embed_list, tgt_url_list = cd.get_tgt_data()
    pca, lidf_weighter = cd.get_fitted_objects()
    
    #Handle negative augment data logic
    if augment_negative_data_method is not None:
        print(augment_negative_data_method, "method")
        create_augmented_negative_sample(src_url, src_text_list_tokenized, src_url_list, src_embed_list,
                            tgt_url, tgt_text_list_tokenized, tgt_url_list, tgt_embed_list,
                            augment_negative_data_method)
    
    src_doc_embed = get_doc_embedding(src_url, src_text_list_tokenized, src_url_list, src_embed_list,
                      lidf_weighter, pca, pert_obj, doc_vector_method)
    tgt_doc_embed = get_doc_embedding(tgt_url, tgt_text_list_tokenized, tgt_url_list, tgt_embed_list,
                      lidf_weighter, pca, pert_obj, doc_vector_method)
    return src_doc_embed, tgt_doc_embed
    

### Define Datasets and DataLoader

In [18]:
#Now defime dataloader class
#from threading import Lock

#NOTE: augment_negative_data_flag is for telling dataset class to slightly modify matching docs to form close, but neg pairs
# This helps the sensitivity of the classifier

#First get embed_dict, src_to_tgt_map, tgt_to_src_map
#Then build pairset
#Finally, at each idx, just get embed_src, embed_tgt, y
class DocEmbedDataset(Dataset):
    
    """
    DocEmbeddingDataset
    """
    
    def __init__(self, chunks_paths_list, src_lang, tgt_lang, doc_vector_method="SENT_ORDER", cache_capacity=500): #NOTE: BE SURE TO ALLOCATE LOTS OF MEM
      self.src_to_tgt_pairs, self.tgt_to_src_pairs = get_matching_url_dicts()
      self.embed_dict = load_embed_dict(chunks_paths_list)
      self.data_list = build_pair_dataset(self.embed_dict, self.src_to_tgt_pairs, self.tgt_to_src_pairs)
      
      self.src_lang = src_lang
      self.tgt_lang = tgt_lang
      
      if doc_vector_method not in ['AVG', 'AVG_BP', 'SENT_ORDER']:
        raise ValueError("""Doc Vec Method must be one of the following:
                            AVF, AVG_BP, SENT_ORDER
                            Not found: %s
                            """ % doc_vector_method)
      self.doc_vector_method = doc_vector_method
      self.pert_obj = ModifiedPertV2(None, None)
      self.lru_cache = LRUCache(cache_capacity)
      
      self.exception_counter = 0
      
      #self.lock = Lock()

    def __len__(self):
        """
        Get length of the dataset
        """
        return len(self.data_list)

    def __getitem__(self,
                    idx):
        """
        Gets the two vectors and target
        """
        _, _, src_url, tgt_url, y_match_label, augment_negative_data_method = self.data_list[idx]
        #print("CACHE LIST: %s " %  list(self.lru_cache.keys())) #REMOVE AFTER DEBUG
        #with self.lock:
        #print(y_match_label, "Y LABEL", src_url, tgt_url, augment_negative_data_method)
        '''src_doc_embedding, tgt_doc_embedding = load_embed_pairs(src_url, tgt_url, 
                                                                self.embed_dict,
                                                                src_lang=self.src_lang,
                                                                tgt_lang=self.tgt_lang)
        src_doc_embedding, tgt_doc_embedding = src_doc_embedding[0], tgt_doc_embedding[0] '''
        try:                                               
          src_doc_embedding, tgt_doc_embedding = handle_doc_embed_logic(self.embed_dict,
                                                                        src_url, tgt_url,
                                                                        self.src_lang, self.tgt_lang,
                                                                        self.doc_vector_method,
                                                                        self.pert_obj,
                                                                        self.lru_cache,
                                                                        augment_negative_data_method)
        except Exception as e:
          print(e, "DATA LOADING EXCEPTION, RETURNING NOISE TO BE USED AS NEG PAIR")
          self.exception_counter += 1
          return np.random.random((2048)), np.random.random((2048)), 0 # In rare exception cases
        return src_doc_embedding, tgt_doc_embedding, y_match_label


### New Dataset Classes

In [46]:
import linecache
import json
import numpy as np
len_length = sum(1 for line in open('/home/dstambl2/doc_alignment_implementations/data/cc_aligned_si_data/classifier_train_doc_vecs/valid/data.txt'))
print(len_length)
particular_line  = linecache.getline('/home/dstambl2/doc_alignment_implementations/data/cc_aligned_si_data/classifier_train_doc_vecs/valid/data.txt', 13021)

src, tgt, label = particular_line.split('\t')

def str_float_list_to_np(input):
    '''
    Helper function
    '''
    input = input.replace(']','').replace('[', '')
    input = [float(item) for item in input.split(',')]
    return np.asarray(input)

src, tgt = str_float_list_to_np(src), str_float_list_to_np(tgt)

print(src.shape, src.dtype, tgt.shape, tgt.dtype)
print(len(src), src[:10], src[-10:], tgt[:10], tgt[-10:])
#x = json.loads(src)

#print(x[0:25])

#src = np.frombuffer(src.encode(), dtype=float, count=-1)
#tgt = np.frombuffer(tgt.encode(), dtype=float, count=-1)
#label = int(label)

#print(src.shape)
#print(src.shape, label)



13020


ValueError: not enough values to unpack (expected 3, got 1)

In [47]:
#NOTE: These classes just read in train set
class DocEmbedDataset(Dataset):
    
    """
    DocEmbeddingDataset
    """
    
    def __init__(self, file_path, src_lang, tgt_lang): #NOTE: BE SURE TO ALLOCATE LOTS OF MEM
      self.src_lang = src_lang
      self.tgt_lang = tgt_lang
      self.file_path = file_path
      
      
      self.file_len = sum(1 for line in open(file_path))
      
    
    def __len__(self):
      return self.file_len
    
    def __getitem__(self, idx):
      
      def str_float_list_to_np(input):
        '''
        Helper function
        '''
        input = input.replace(']','').replace('[', '')
        input = [float(item) for item in input.split(',')]
        return np.asarray(input)

      line_cache_idx = idx + 1 #Line cache is indexed at 1 for some reason
      data_record  = linecache.getline(self.file_path, line_cache_idx)
      try:
        src_raw, tgt_raw, label_str = data_record.split('\t')
      except Exception as e:
        print(idx, data_record, "EXCEPTION")
      src_doc_embedding, tgt_doc_embedding = str_float_list_to_np(src_raw), str_float_list_to_np(tgt_raw)
      return src_doc_embedding, tgt_doc_embedding, int(label_str)

    

      


In [48]:
#Create dataloader (USING NEW DATALOADER FORMAT)
val_dataset= DocEmbedDataset('/home/dstambl2/doc_alignment_implementations/data/cc_aligned_si_data/classifier_train_doc_vecs/valid/data.txt', SRC_LANG_CODE, TGT_LANG_CODE)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True) #prefetch_factor=5, prefetch maybe could help 


In [None]:
#Create dataloader (USING OLDER DATALOADER FORMAT)
train_dataset=DocEmbedDataset(chunks_paths_train, SRC_LANG_CODE, TGT_LANG_CODE) #TODO: Remove after done debug of loop
validation_dataset=DocEmbedDataset(chunks_paths_val, SRC_LANG_CODE, TGT_LANG_CODE)
test_dataset=DocEmbedDataset(chunks_paths_test, SRC_LANG_CODE, TGT_LANG_CODE)

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True) #prefetch_factor=5, prefetch maybe could help 
validation_dataloader = DataLoader(validation_dataset, batch_size=1, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [49]:
idx = 0
#, len(validation_dataset), len(test_dataloader))
for i, vals in enumerate(val_dataloader):
    x_1, x_2, y = vals

407


### Defining Plotting Logic

In [33]:
#Define loss func
#Plot loss function
def plot_loss_charts(train_loss_store, validation_loss_store):
  '''
  Plots loss charts over course of training
  '''
  ## Plotting epoch-wise test loss curve:
  plt.plot(train_loss_store, '-o', label = 'train_loss', color = 'orange')
  plt.plot(validation_loss_store, label = 'validation_loss', color = 'blue')
  plt.xlabel('Epoch Number')
  plt.ylabel('Loss At each epoch')
  plt.legend()
  plot_loss_charts
  #plt.show()
  plt.savefig('/home/dstambl2/doc_alignment_implementations/thompson_2021_doc_align/devtools/plots/%s_two.png' % "loss_plots")
  plt.clf()


def plot_accuracy(train_score_store, validation_score_store, skip_plot=False):
  '''
  Plots Accuracy charts over course of training
  '''
  #Don't plot if this flag is set to true
  if skip_plot:
    return

  ## Plotting epoch-wise test loss curve:
  plt.plot(train_score_store, '-o', label = 'train_accuracy', color = 'orange')
  plt.plot(validation_score_store, label = 'validation_accuracy', color = 'blue')
  plt.xlabel('Epoch Number')
  plt.ylabel('Accuracy At each epoch')
  plt.legend()
  #plt.show()
  plt.savefig('/home/dstambl2/doc_alignment_implementations/thompson_2021_doc_align/devtools/plots/%s_two.png' % "acc_plots")
  plt.clf()

### Creating the model

In [32]:
class DocAlignerClassifier(nn.Module):
    def __init__(self, input_size, hidden_size_one=256, hidden_size_two=128, hidden_size_three=64):
        super(DocAlignerClassifier, self).__init__()
        
        self.fc1 = nn.Linear(input_size, hidden_size_one)
        self.fc2 = nn.Linear(hidden_size_one, hidden_size_two)
        self.fc3 = nn.Linear( hidden_size_two, hidden_size_three)
        self.fc_out = nn.Linear(hidden_size_three, 1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
        #Optional
        self.dropout = nn.Dropout(p=0.1)
        #self.batchnorm1 = nn.BatchNorm1d(128)
        #self.batchnorm2 = nn.BatchNorm1d(64)
        
        nn.init.kaiming_uniform_(self.fc1.weight) #He init for relu activation layers
        nn.init.kaiming_uniform_(self.fc2.weight)
        #nn.init.kaiming_uniform_(self.fc3.weight)
        nn.init.xavier_uniform_(self.fc_out.weight) #Xavier init for sigmoid

                
    def forward(self, src_doc_embed, tgt_doc_embed):
        '''
        Forward data through the lstm
        '''
        combined_input = torch.cat((src_doc_embed.view(src_doc_embed.size(0), -1),
                          tgt_doc_embed.view(tgt_doc_embed.size(0), -1)), dim=1)
        #print(combined_input.shape, "combed input")

        x = self.relu(self.fc1(combined_input))
        #x = self.batchnorm1(x)
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x)) #NOTE: Better to leave out batch norm
        #x = self.batchnorm2(x)
        x = self.dropout(x)
        output = self.sigmoid(self.fc_out(x))

        return output

In [62]:
##Define Model Saving functionality
import json
def handle_model_save(val_acc_store, val_loss_store, train_loss_store, train_acc_store,
                      base_path, model_name, epoch_num, model, optimizer):
  '''
  Function for saving models and training data
  '''
  save_dict = {
      'val_acc_store': val_acc_store,
      'val_loss_store': val_loss_store,
      'train_loss_store': train_loss_store,
      'train_acc_store': train_acc_store,
      'epoch': epoch_num,
  }
  traing_info_path = "%s/%s_epoch_%s.json" % (base_path, model_name, epoch_num)
  with open(traing_info_path, 'w') as f:
    json.dump(save_dict, f)

  MODEL_PATH = "%s/%s_%s" % (base_path, model_name, epoch_num)
  torch.save({
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              }, MODEL_PATH)


#Define loading historic training data func
def load_train_data(traing_info_path):
  with open(traing_info_path) as json_file:
    data = json.load(json_file)
  
  return data['val_acc_store'], data['val_loss_store'], data['train_loss_store'], data['train_acc_store'], int(data['epoch'])


### Training Loop

In [63]:
##Define Loss and Accuracy Eval Function
def eval_acc_and_loss_func(model, loader, device, loss_metric, is_train = False, verbose = 1):
    '''
    Evaluate Function for CNN training
    Slightly different than eval function from part 1
    '''
    correct, total, loss_sum = 0, 0, 0
    
    #temp_idx = 0  TODO: remove when down with loop debug
    
    eval_type = "Train" if is_train else "Validation"
    for X_1, X_2, Y in loader:
        outputs, predicted, calculated_loss = None, None, None
        X_1, X_2, Y = X_1.to(device), X_2.to(device), Y.to(device)
        
        X_1, X_2 = X_1.float(), X_2.float()
        outputs = model(X_1, X_2)
        
        #Reshape output and turn y into a float
        outputs = outputs.view(-1)
        Y = Y.float()
        predicted = torch.round(outputs)
        total += Y.size(0)
        
        correct += (predicted == Y).sum().item()
        calculated_loss = loss_metric(outputs,Y).item()
        loss_sum += calculated_loss
        
        #TODO: remove these lines when done with loop debug
        #temp_idx += 1
        #if temp_idx >= 3:
        #    break
        
        
    outputs, predicted, calculated_loss = None, None, None
    if verbose:
        print('%s accuracy: %f %%' % (eval_type, 100.0 * correct / total))
        print('%s loss: %f' % (eval_type, loss_sum / total))
    print
    return 100.0 * correct / total, loss_sum/ total

In [48]:
#DEFINE TRAIN LOOP HERE
def train(model,
          optimizer,
          loss_metric,
          lr,
          train_dataloader,
          valid_dataloader,
          device,
          epochs=5,
          stopping_threshold=3,
          saving_per_epoch=1,
          base_save_path="/home/dstambl2/doc_alignment_implementations/thompson_2021_doc_align/devtools/models",
          model_name="train_loop_model",
          load_train_hist_path=None,
          **kwargs):
    """
    For each epoch, loop through batch,
    compute forward and backward passes, apply gradient updates
    Evaluate results and output
    """

    #If data already exists, that means the model was preloaded and training should
    #resume from where it was interupted
    if load_train_hist_path is not None:
      val_acc_store, val_loss_store, train_loss_store, train_acc_store, start_epoch =load_train_data(load_train_hist_path)
    else:
      train_loss_store, train_acc_store = [], []
      val_loss_store, val_acc_store, = [], []
      start_epoch = 0

    #Declare variables for early stopping
    last_val_loss, stop_tracker = 100, 0

    #training loop:
    print("Starting Training")
    for epoch in range(start_epoch, epochs):
      time1 = time.time() #timekeeping
      outputs, loss = None, None

      model.train()
      
      correct_train, total, loss_sum = 0, 0, 0
      for i, (x_1, x_2, y) in enumerate(train_dataloader):
        
        # Print device human readable names
        #torch.cuda.get_device_name()

        x_1, x_2, y = x_1.to(device), x_2.to(device), y.to(device)

        #loss calculation and gradient update:

        x_1, x_2 = x_1.float(), x_2.float()
        if i > 0 or epoch > 0:
          optimizer.zero_grad()
        outputs = model.forward(x_1, x_2)
        
        #Reshape output and turn y into a float
        outputs = outputs.view(-1)
        y = y.float()
        #print(y.shape, outputs.shape, outputs, "Loss inp info")

        
        loss = loss_metric(outputs, y)
        loss.backward()
                      
        ##performing update:
        optimizer.step()

        #Update Loss Info
        loss_sum += loss.item()
        
        #Acc was likely not increasing, because preds kept rounding to zero
        predicted = torch.round(outputs) #NOTE: MODEL predictions keep rounding to zero
        #print(outputs, predicted, y, "predicted stuff", y.size(0))

        total += y.size(0)
        correct_train += (predicted == y).sum().item()
        
        #TEMP, TODO: Delete when done debug
        #if i >= 3:
        #  break
              
      print("Epoch",epoch+1,':')

      
      model.eval()
      with torch.no_grad():

        #Print Train Info
        print('%s accuracy: %f %%' % ("Train", 100.0 * (correct_train / total)))
        print('%s loss: %f' % ("Train", loss_sum / total))
        print
        
        train_acc, train_loss = 100.0 * correct_train / total, loss_sum/ total
        val_acc, val_loss = eval_acc_and_loss_func(model, valid_dataloader, device, loss_metric, is_train = False)

        val_acc_store.append(val_acc)
        val_loss_store.append(val_loss)

        train_loss_store.append(train_loss)
        train_acc_store.append(train_acc)

      time2 = time.time() #timekeeping
      #if show_progress:
      print('Elapsed time for epoch:',time2 - time1,'s')
      print('ETA of completion:',(time2 - time1)*(epochs - epoch - 1)/60,'minutes')
      
      if (epoch + 1) % saving_per_epoch == 0:
        handle_model_save(val_acc_store, val_loss_store, train_loss_store, train_acc_store,
                          base_save_path, model_name, epoch + 1, model, optimizer)
        print("Model Copy Saved")
      print()


      #Handle early stopping logic
      if val_loss >= last_val_loss:
            stop_tracker += 1
            if stop_tracker >= stopping_threshold:
                print('Early Stopping triggered, Convergence has occured')
                plot_loss_charts(train_loss_store, val_loss_store)
                plot_accuracy(train_acc_store, val_acc_store)
                handle_model_save(val_acc_store, val_loss_store, train_loss_store, train_acc_store,
                  base_save_path, model_name, epoch + 1, model, optimizer)
                print("Model Copy Saved")

                return train_loss_store, val_acc_store
      else:
          stop_tracker = 0
      last_val_loss = val_loss


    plot_loss_charts(train_loss_store, val_loss_store)
    plot_accuracy(train_acc_store, val_acc_store)
    return train_loss_store, val_acc_store

In [55]:
#Define the model, and kick off training
LEARNING_RATE = 0.001 #NOTE: CAN ADJUST
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DocAlignerClassifier(4096) #Each doc vec is 2048, so times 2 will be 4096
model = model.to(device)

optimizer = optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
loss_fn = nn.BCELoss() #nn.BCEWithLogitsLoss()

In [50]:
#TODO: Test this with just base sent embeddings (as a place holder) just to see that training is working
# Since its much faster than using doc vector
train(
        model,
        optimizer,
        loss_fn,
        LEARNING_RATE,
        train_dataloader,
        validation_dataloader,
        device,
        epochs=5
    )
print(train_dataset.exception_counter, "exception counter")

Starting Training
1 Y LABEL https://www.etthawitthi.com/local/ranjith-de-zoysa-passed-away https://www.sinhala.etthawitthi.com/local/ranjith-de-zoysa-passed-away None
0 Y LABEL http://www.sdzthl.com/galvalume-aluzinc-steel-coils-as-per-jis3321.html http://www.sdzthl.com/si/galvalume-aluzinc-steel-coils-as-per-jis3321.html sub
sub method
0 Y LABEL https://www.niceterminal.com/products/2-t-type-connection-terminal-block.html https://www.niceterminal.com/si/products/color-transparent-error-prevention-one-inlet-three-outlet.html None
tensor([0.7942, 0.3255, 0.4217], grad_fn=<ViewBackward0>) tensor([1., 0., 0.], grad_fn=<RoundBackward0>) tensor([1., 0., 0.]) predicted stuff 3
1 Y LABEL http://nirvanadhamma.lk/en/the-platform/dhamma-discourses/ http://www.nirvanadhamma.lk/the-platform/dhamma-discourses/ None
cache hit
1 Y LABEL http://xn--w0ct5a8c.xn--n0chiqomy9ed8bxb2a8e.xn--fzc2c9e2c/en/members-of-parliament/the-system-of-elections-in-sri-lanka/conduct-of-the-elections http://xn--w0ct5a8c.

<Figure size 432x288 with 0 Axes>

In [None]:
#Eval on test set
#model_checkpoint = torch.load('/home/dstambl2/doc_alignment_implementations/thompson_2021_doc_align/devtools/models/train_loop_model_5')
#model.load_state_dict(model_checkpoint['model_state_dict'])
eval_acc_and_loss_func(model, test_dataloader, device, loss_fn, is_train = False, verbose = 1)