In [1]:
%set_env PYTHONPATH="/home/dstambl2/doc_alignment_implementations/thompson_2021_doc_align"

env: PYTHONPATH="/home/dstambl2/doc_alignment_implementations/thompson_2021_doc_align"


In [2]:
## Standard Library
import os
import json
import time
import re
import json
import numpy as np
from collections import defaultdict, namedtuple
from random import shuffle, randint

## External Libraries
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from sklearn.model_selection import train_test_split
import torch.nn.functional as functional
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
#DEFINE Training Hyperparams And Constants

## Batch Size
TRAIN_BATCH_SIZE = 32
VAL_BATCH_SIZE = 32

## Learning Rate
LR = 0.001

# Epochs (Consider setting high and implementing early stopping)
NUM_EPOCHS = 100

GPU_BOOL = torch.cuda.is_available()
GPU_BOOL

BASE_DATA_PATH="/home/dstambl2/doc_alignment_implementations/data"
BASE_EMBED_DIR = '/home/dstambl2/doc_alignment_implementations/data/cc_aligned_si_data/embeddings'

BASE_PROCESSED_PATH="/home/dstambl2/doc_alignment_implementations/data/cc_aligned_si_data/processed" 
ALIGNED_PAIRS_DOC = '/home/dstambl2/doc_alignment_implementations/data/cc_aligned_en_si.pairs'

SRC_LANG_CODE="en"
TGT_LANG_CODE="si"


In [4]:
#Split into train, valid and test sets

from utils.common import load_extracted, map_dic2list, \
    filter_empty_docs, regex_extractor_helper, tokenize_doc_to_sentence
from modules.get_embeddings import read_in_embeddings, load_embeddings
from modules.build_document_vector import build_document_vector
from modules.vector_modules.boiler_plate_weighting import LIDFDownWeighting
from align_docs import fit_pca_reducer
from modules.vector_modules.window_func import ModifiedPertV2
from utils.lru_cache import LRUCache
from utils.common import function_timer

CHUNK_RE = re.compile(r'(chunk_\d*(?:_\d*)?)', flags=re.IGNORECASE)
BASE_DOMAIN_RE = re.compile(r'https?\://(?:w{3}\.)?(?:(?:si|en)\.)?(.*?)/', flags=re.IGNORECASE)

#from sklearn.model_selection import train_test_split
embed_chunk_paths = []
for subdir, dirs, files in os.walk(BASE_EMBED_DIR):
    if subdir != BASE_EMBED_DIR:
        embed_chunk_paths.append(subdir)
        
embed_chunk_paths = sorted(embed_chunk_paths, key= lambda x: len(x))

#For now split into 0.7/0.1/0.2 split
#Split into train, test and val sets, remove last one for split since it will go into train set
train_ind, test_ind = train_test_split(list(range(len(embed_chunk_paths[:-1]))), test_size=0.2, random_state=1)
train_ind, val_ind = train_test_split(train_ind, test_size=0.125, random_state=1) # 0.125 x 0.8 = 0.1

train_ind.append(len(embed_chunk_paths) -1) #Add last imbalenced idx to train set

assert not any(set(train_ind) & set(test_ind)) and \
       not any(set(train_ind) & set(val_ind))and \
       not any(set(test_ind) & set(val_ind))

print(len(train_ind), len(val_ind), len(test_ind))

chunks_paths_train = [embed_chunk_paths[t] for t in train_ind]
chunks_paths_val = [embed_chunk_paths[t] for t in val_ind]
chunks_paths_test = [embed_chunk_paths[t] for t in test_ind]



73 11 21


In [5]:
print(sorted(train_ind))

[0, 3, 4, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 34, 36, 40, 41, 42, 43, 44, 45, 48, 49, 50, 51, 52, 54, 55, 57, 60, 61, 62, 63, 64, 66, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 82, 83, 86, 87, 90, 92, 93, 94, 95, 97, 99, 102, 104]


In [17]:
''' The following two funcs are only for getting sent embeds '''
def get_base_embedding(url, embedding_file_path, lang_code):

    chunk = regex_extractor_helper(CHUNK_RE, embedding_file_path)
    doc_path = '%s/%s.%s.gz' % (BASE_PROCESSED_PATH, chunk, lang_code)
    url_doc_dict = filter_empty_docs(load_extracted(doc_path))
    if url not in url_doc_dict:
        print("missing url in get_base_embedding %s for doc path %s" % (url, doc_path))
        #return noise
        return np.random.random((1,1024))
    doc_text = url_doc_dict[url]
        
    _, sent_embeds = read_in_embeddings(doc_text, embedding_file_path, lang_code)
    
    return sent_embeds


def load_embed_pairs(src_url, tgt_url, embed_dict,
                    src_lang=SRC_LANG_CODE,
                    tgt_lang=TGT_LANG_CODE):
    src_path = embed_dict[src_url]
    tgt_path = embed_dict[tgt_url]
    line_embeddings_src = get_base_embedding(src_url, src_path, src_lang)
    line_embeddings_tgt = get_base_embedding(tgt_url, tgt_path, tgt_lang)
    
    #print(line_embeddings_src.shape, line_embeddings_tgt.shape)
    return line_embeddings_src, line_embeddings_tgt
''' END OF PURE SENT EMBED FUNCS '''

def get_matching_url_dicts(input_path = ALIGNED_PAIRS_DOC):
    src_to_tgt = {}
    tgt_to_src = {}
    with open(input_path, 'r') as fp:
        for row in fp:
            src, tgt = row.split('\t')
            src_to_tgt[src] = tgt
            tgt_to_src[tgt]  = src
    return src_to_tgt, tgt_to_src


def load_embed_dict(chunks_paths):
    embed_dict = {}
    
    for chunk_path in chunks_paths:
        embed_dict_path = '%s/embedding_lookup.json' % (chunk_path)
        embed_dict_chunk = {}
        with open(embed_dict_path, 'r') as f:
            embed_dict_chunk = json.load(f)
        embed_dict.update(embed_dict_chunk)
                
    return embed_dict


### Building Sample Logic

In [18]:

CandidateTuple = namedtuple(
    "CandidateTuple", "src_embed_path, tgt_embed_path, src_url, tgt_url, y_match_label")


''' Sampling logic start'''
@function_timer
def create_positive_samples(embed_dict, src_to_tgt_map, tgt_to_src_map, data_list):
    '''
    Builds positive samples
    '''
    for src_url, tgt_url in src_to_tgt_map.items():
        src_url, tgt_url = src_url.strip(), tgt_url.strip()
        if src_url in embed_dict and tgt_url in embed_dict:
            src_embed_path, tgt_embed_path = embed_dict[src_url], embed_dict[tgt_url]
            c = CandidateTuple(src_embed_path, tgt_embed_path, src_url, tgt_url, 1)
            data_list.append(c)


def create_all_possible_neg_pairs(src_to_tgt_map, tgt_to_src_map):
    
    src_url_list, tgt_url_list = list(src_to_tgt_map.keys()), list(tgt_to_src_map.keys())
    domain_dict = defaultdict(lambda: defaultdict(list)) #Ex{dom: {src: [], tgt: []}}
    for url in tgt_url_list:
        #url = url.strip()
        base_url = regex_extractor_helper(BASE_DOMAIN_RE, url).strip()
        domain_dict[base_url]['tgt'].append(url)
    
    for url in src_url_list:
        #url = url.strip()
        base_url = regex_extractor_helper(BASE_DOMAIN_RE, url).strip()
        domain_dict[base_url]['src'].append(url)
    
    negative_sample_dict = {}
    #Loop through all domains and create final negative pairing
    for domain, values in domain_dict.items():
        sample_list = []
        src_urls, tgt_urls = values['src'], values['tgt']
        
        if len(src_urls) > 100:
            shuffle(src_urls)
        if len(tgt_urls) > 100:
            shuffle(tgt_urls)
        for src_url in src_urls[:min(100, len(src_urls))]:
            for tgt_url in tgt_urls[:min(100, len(src_urls))]:
                if src_to_tgt_map[src_url] != tgt_url and tgt_to_src_map[tgt_url] != src_url:
                    sample_list.append((src_url, tgt_url))
        if len(sample_list) > 0:
            negative_sample_dict[domain] = sample_list
    return negative_sample_dict
        
    

def same_domain_neg_sample_helper(negative_sample_dict):
    '''
    IDEA: randomly modify docs  
    Helper function for returning
    Negative domains of same idx

    Algo: 1) Pick random domain
         2) Pick random sample from that domain
         3) pop that sample
    '''
    domain_list = list(negative_sample_dict.keys())
    domain = domain_list[randint(0, len(domain_list) - 1)]
    
    neg_pair_list = negative_sample_dict[domain]
    neg_pair_idx = randint(0, len(neg_pair_list) - 1)
    src_url, tgt_url = neg_pair_list[neg_pair_idx]
    
    #Update the dict to remove neg pair
    neg_pair_list.pop(neg_pair_idx)
    if len(neg_pair_list) == 0:
        negative_sample_dict.pop(domain)
    else:
        negative_sample_dict[domain] = neg_pair_list

    return src_url, tgt_url


def create_negative_samples(embed_dict, src_to_tgt_map, tgt_to_src_map, data_list):
    '''
    Builds negative samples
    '''

    number_pos_samples = len(data_list)
    visited_urls, MAX_INSTANCES_ALLOWED = defaultdict(int), 10 
    
    negative_sample_dict = create_all_possible_neg_pairs(src_to_tgt_map, tgt_to_src_map)    
    
    
    print("URL LIST LENGTHS", number_pos_samples, len(embed_dict), len(src_to_tgt_map), len(tgt_to_src_map))
    
    loop_counter = 0 #NOTE: FOR DEBUG
    for i in range(number_pos_samples):
        
        while True:
            src_url, tgt_url = same_domain_neg_sample_helper(negative_sample_dict)

            #Repeat until all conditions are met
            if (visited_urls[src_url.strip()] < MAX_INSTANCES_ALLOWED and \
                visited_urls[tgt_url.strip()] < MAX_INSTANCES_ALLOWED and \
                src_to_tgt_map[src_url] != tgt_url and tgt_to_src_map[tgt_url] != src_url):
                
                src_url, tgt_url = src_url.strip(), tgt_url.strip()
                if src_url in embed_dict and tgt_url in embed_dict:
                    src_embed_path, tgt_embed_path = embed_dict[src_url], embed_dict[tgt_url]
                    c = CandidateTuple(src_embed_path, tgt_embed_path, src_url, tgt_url, 0)
                    data_list.append(c)
                
                    visited_urls[src_url] += 1
                    visited_urls[tgt_url] += 1
                    
                break
            else:
                loop_counter += 1

    print("LOOP COUNTER LEN %s" % loop_counter)
    print(len(data_list))
        

def build_pair_dataset(embed_dict, src_to_tgt_map, tgt_to_src_map):
    '''
    Create positive and negative url tuple pairs
    '''
    data_list = []
    create_positive_samples(embed_dict, src_to_tgt_map, tgt_to_src_map, data_list) #First Create positive samples
    print(len(data_list))
    
    subset_src_to_tgt_map = {k: v for k,v in src_to_tgt_map.items() if k.strip() in embed_dict and v.strip() in embed_dict}
    subset_tgt_to_src_map = {k: v for k,v in tgt_to_src_map.items() if k.strip() in embed_dict and v.strip() in embed_dict}
    
    create_negative_samples(embed_dict, subset_src_to_tgt_map, subset_tgt_to_src_map, data_list)  #Now create negative samples
    
    #Shuffle and return data
    shuffle(data_list)
    print(len(data_list))

    return data_list

en_to_si_pairs, si_to_en_pairs = get_matching_url_dicts()    

In [19]:
#TEST ABOVE CODE
embed_dict_train = load_embed_dict(chunks_paths_train)
data_list = build_pair_dataset(embed_dict_train, en_to_si_pairs, si_to_en_pairs)
sample_one = list(embed_dict_train.keys())[10]
sample_two = list(embed_dict_train.keys())[57]
print(sample_one, sample_two)
#load_embed_pairs(sample_one, sample_two, embed_dict_train)

Function create_positive_samples took 0.33 seconds
45606
URL LIST LENGTHS 45606 90001 45606 44260
LOOP COUNTER LEN 75
91212
91212
https://tattoosartideas.com/tiger-tattoo/ https://www.ideabeam.com/mobile/store/idealz-lanka


### Helper logic (PCA/Cache) for Doc Embeds

In [20]:
import math
from sklearn.decomposition import PCA
def fit_pca_reducer_debug(embedding_list_src, embedding_list_tgt):
    '''
    Builds PCA Dim Reduction from sample of sentence embeddings
    in the webdomain
    '''
    all_sent_embeds = np.vstack(embedding_list_src + embedding_list_tgt)

    pca = PCA(n_components=128)
    divide_num = 1
    if len(all_sent_embeds) // 6 >= 128:
        divide_num = 6
    elif len(all_sent_embeds) // 5 >= 128:
        divide_num = 5
    elif len(all_sent_embeds) // 4 >= 128:
        divide_num = 4
    elif len(all_sent_embeds) // 3 >= 128:
        divide_num = 3
    elif len(all_sent_embeds) // 2 >= 128:
        divide_num = 2
    elif len(all_sent_embeds) // 1 >= 128:
        divide_num = 1
    else:
        sent_size = all_sent_embeds.shape[0]
        num_iters = int(math.ceil(128 / sent_size))        
        all_sent_embeds = np.repeat(all_sent_embeds, repeats=num_iters, axis=0)
        

    my_rand_int = np.random.randint(all_sent_embeds.shape[0], size=len(all_sent_embeds) // divide_num)
    pca_fit_data = all_sent_embeds[my_rand_int, :]
    pca.fit(pca_fit_data)
    return pca

#DEFINE Helper functions for building doc vectors
#NOTE: Much of this info will be stored in an LRU cache


class CachedData:
    '''
    Keeps organized cache of data
    Keys will be domain_name
    Since for each domain, we want the source and target lang info
    '''
    def __init__(self, src_text_list_tokenized,
                       src_embed_list,
                       src_url_list,
                       tgt_text_list_tokenized,
                       tgt_embed_list,
                       tgt_url_list,
                       ):
        self.src_text_list_tokenized = src_text_list_tokenized
        self.src_embed_list = src_embed_list
        self.src_url_list = src_url_list

        self.tgt_text_list_tokenized = tgt_text_list_tokenized
        self.tgt_embed_list = tgt_embed_list
        self.tgt_url_list = tgt_url_list
        
        self.lidf_weighter = LIDFDownWeighting(src_text_list_tokenized + tgt_text_list_tokenized)
        self.pca = fit_pca_reducer_debug(src_embed_list, tgt_embed_list)
    
    def get_fitted_objects(self):
        '''
        Return PCA and LIDF
        '''
        return self.pca, self.lidf_weighter

    def get_src_data(self):
        return self.src_text_list_tokenized, self.src_embed_list, self.src_url_list
    
    def get_tgt_data(self):
        return self.tgt_text_list_tokenized, self.tgt_embed_list, self.tgt_url_list

### Domain Specific Doc Embedding Logic

In [21]:

def load_embeds_for_domain(embed_dict, lang_code, text_list, url_list):
    '''
    Load in embeds for a domain
    ''' 
    embed_list = []
    try:
        for ii in range(len(text_list)):
            url, doc_text = url_list[ii], text_list[ii]
            if url in embed_dict: #TEMP FIX: TODO: Rerun embeds and issue of missing URLS here or delete sentences that miss embeds
                embed_file_path = embed_dict[url]
                try:
                    _, embeddings = read_in_embeddings(doc_text, embed_file_path, lang_code)
                    embed_list.append(embeddings)
                except:
                    print(embed_file_path, lang_code, "EMBED EXCEPTION")
            else:
                print("URL NOT IN EMBED DICT: ", url)

    except KeyError as e: #For debugging
        print(e)
        print("EXCEPTION OCCURED in load_embeds_for_domain")
    
    return embed_list


''' Domain Specific Chunk helper data'''

def get_all_chunks_with_domain(embed_dict, base_domain):
    '''
    First get a list of all chunks that contain the domain_name
    Second get all docs in src lang and tgt lang
    Third, get pca, ldf weighter and more
    '''
    chunks = [regex_extractor_helper(CHUNK_RE, value) \
                for key, value in embed_dict.items() if base_domain.strip() in key.strip().lower()]
    if len(chunks) == 0:
        print("WAIT, CHUNK LIST IS EMPTY, so CHUNK_RE error")
    return list(set(chunks))

def load_all_chunks_for_domain(chunk_list, base_domain, lang_code):
    '''
    Given a list of domain chunks and the domain_name
    Get a domain doc dict
    '''
    domain_doc_dict = {}
    for chunk in chunk_list:
        doc_path = '%s/%s.%s.gz' % (BASE_PROCESSED_PATH, chunk, lang_code)
        url_doc_dict = filter_empty_docs(load_extracted(doc_path))
        match_doc_dict = {}
        for k, v in url_doc_dict.items():
            if base_domain.strip().lower() in k.strip().lower():
                match_doc_dict[k] = v
        domain_doc_dict.update(match_doc_dict)
    return domain_doc_dict


def get_all_relevant_domain_data(chunk_list, base_domain, lang_code, embed_dict):
    '''
    return text_list_tokenized, embed_list, url list
    '''
    domain_doc_dict = load_all_chunks_for_domain(chunk_list, base_domain, lang_code) 
    obj_domain = map_dic2list(domain_doc_dict)
    
    text_list = obj_domain['text']
    text_list_tokenized = [tokenize_doc_to_sentence(doc, lang_code) for doc in text_list]
    
    url_list = [url.strip() for url in obj_domain['mapping']]
    embed_list = load_embeds_for_domain(embed_dict, lang_code,
                                        text_list, url_list)
    
    return text_list_tokenized, embed_list, url_list


def get_doc_embedding(url, text_list_tokenized,
                      url_list,
                      embedding_list,
                      lidf_weighter,
                      pca,
                      pert_obj,
                      doc_vec_method):
    
    i = url_list.index(url)
    
    return build_document_vector(text_list_tokenized[i],
                        url_list[i],
                        embedding_list[i],
                        lidf_weighter,
                        pca,
                        pert_obj,
                        doc_vec_method=doc_vec_method).doc_vector


def handle_doc_embed_logic(embed_dict, src_url, tgt_url,
                         src_lang_code, tgt_lang_code, doc_vector_method,
                         pert_obj, 
                         lru_cache):
    
   
    base_domain_src = regex_extractor_helper(BASE_DOMAIN_RE, src_url).strip()
    base_domain_tgt = regex_extractor_helper(BASE_DOMAIN_RE, tgt_url).strip()
    
    #NOTE: NOT ALL Pairs share same base domain, ex: www.buyaas.com, si.buyaas.com, so url regex was adjusted     
    if base_domain_src != base_domain_tgt:
        print(base_domain_src, base_domain_tgt, "DIFFERENT DOMAINS")
    assert base_domain_src == base_domain_tgt  
    
    
    chunk_list_src = get_all_chunks_with_domain(embed_dict, base_domain_src) 
    chunk_list_tgt = get_all_chunks_with_domain(embed_dict, base_domain_src)
    
    cd = lru_cache.get(base_domain_src)
    if cd == -1:
        src_text_list_tokenized, src_embed_list, src_url_list = \
            get_all_relevant_domain_data(chunk_list_src, base_domain_src, src_lang_code, embed_dict)
        
        tgt_text_list_tokenized, tgt_embed_list, tgt_url_list = \
            get_all_relevant_domain_data(chunk_list_tgt, base_domain_tgt, tgt_lang_code, embed_dict)
        cd = CachedData(src_text_list_tokenized, src_embed_list, src_url_list, 
                        tgt_text_list_tokenized, tgt_embed_list, tgt_url_list)
        lru_cache.put(base_domain_src, cd)
    else:
        print("cache hit")
    src_text_list_tokenized, src_embed_list, src_url_list = cd.get_src_data()
    tgt_text_list_tokenized, tgt_embed_list, tgt_url_list = cd.get_tgt_data()
    pca, lidf_weighter = cd.get_fitted_objects()
        

    src_doc_embed = get_doc_embedding(src_url, src_text_list_tokenized, src_url_list, src_embed_list,
                      lidf_weighter, pca, pert_obj, doc_vector_method)
    tgt_doc_embed = get_doc_embedding(tgt_url, tgt_text_list_tokenized, tgt_url_list, tgt_embed_list,
                      lidf_weighter, pca, pert_obj, doc_vector_method)
    return src_doc_embed, tgt_doc_embed
    

### Define Datasets and DataLoader

In [22]:
#Now defime dataloader class
from threading import Lock


#First get embed_dict, src_to_tgt_map, tgt_to_src_map
#Then build pairset
#Finally, at each idx, just get embed_src, embed_tgt, y
class DocEmbedDataset(Dataset):
    
    """
    DocEmbeddingDataset
    """
    
    def __init__(self, chunks_paths_list, src_lang, tgt_lang, doc_vector_method="SENT_ORDER", cache_capacity=500): #NOTE: BE SURE TO ALLOCATE LOTS OF MEM
      self.src_to_tgt_pairs, self.tgt_to_src_pairs = get_matching_url_dicts()
      self.embed_dict = load_embed_dict(chunks_paths_list)
      self.data_list = build_pair_dataset(self.embed_dict, self.src_to_tgt_pairs, self.tgt_to_src_pairs)
      
      self.src_lang = src_lang
      self.tgt_lang = tgt_lang
      
      if doc_vector_method not in ['AVG', 'AVG_BP', 'SENT_ORDER']:
        raise ValueError("""Doc Vec Method must be one of the following:
                            AVF, AVG_BP, SENT_ORDER
                            Not found: %s
                            """ % doc_vector_method)
      self.doc_vector_method = doc_vector_method
      self.pert_obj = ModifiedPertV2(None, None)
      self.lru_cache = LRUCache(cache_capacity)
      
      #self.lock = Lock()

    def __len__(self):
        """
        Get length of the dataset
        """
        return len(self.data_list)

    def __getitem__(self,
                    idx):
        """
        Gets the two vectors and target
        """
        _, _, src_url, tgt_url, y_match_label = self.data_list[idx]
        #print("CACHE LIST: %s " %  list(self.lru_cache.keys())) #REMOVE AFTER DEBUG
        #with self.lock:
        #print(y_match_label, "Y LABEL", src_url, tgt_url, idx)
        #TODO: TEMP, only to test training loop
        src_doc_embedding, tgt_doc_embedding = load_embed_pairs(src_url, tgt_url, 
                                                                self.embed_dict,
                                                                src_lang=self.src_lang,
                                                                tgt_lang=self.tgt_lang)
        src_doc_embedding, tgt_doc_embedding = src_doc_embedding[0], tgt_doc_embedding[0]
        
        '''                                               
        src_doc_embedding, tgt_doc_embedding = handle_doc_embed_logic(self.embed_dict,
                                                                      src_url, tgt_url,
                                                                      self.src_lang, self.tgt_lang,
                                                                      self.doc_vector_method,
                                                                      self.pert_obj,
                                                                      self.lru_cache) '''
        return src_doc_embedding, tgt_doc_embedding, y_match_label


In [26]:
#Create dataloader
train_dataset=DocEmbedDataset(chunks_paths_train[:2], SRC_LANG_CODE, TGT_LANG_CODE) #TODO: Remove after done debug of loop
validation_dataset=DocEmbedDataset(chunks_paths_val[:2], SRC_LANG_CODE, TGT_LANG_CODE)
test_dataset=DocEmbedDataset(chunks_paths_test[:1], SRC_LANG_CODE, TGT_LANG_CODE)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True) #prefetch_factor=5, prefetch maybe could help 
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

Function create_positive_samples took 0.02 seconds
1306
URL LIST LENGTHS 1306 2581 1306 1267
LOOP COUNTER LEN 0
2612
2612
Function create_positive_samples took 0.01 seconds
1297
URL LIST LENGTHS 1297 2537 1297 1231
LOOP COUNTER LEN 0
2594
2594
Function create_positive_samples took 0.01 seconds
653
URL LIST LENGTHS 653 1287 653 630
LOOP COUNTER LEN 0
1306
1306


In [27]:
idx = 0
print(len(train_dataset), len(validation_dataset), len(test_dataloader))
#TODO: Continue to debug lack of embeds issue
'''for i, vals in enumerate(train_dataloader):
    x_1, x_2, y = vals
    print(i, x_1.shape, x_2.shape, y, "THIS BE PRINTING")
    idx += 1
    if idx > 1:
        break'''


2612 2594 1306


'for i, vals in enumerate(train_dataloader):\n    x_1, x_2, y = vals\n    print(i, x_1.shape, x_2.shape, y, "THIS BE PRINTING")\n    idx += 1\n    if idx > 1:\n        break'

### Defining Plotting Logic

In [28]:
#Define loss func
#Plot loss function
def plot_loss_charts(train_loss_store, validation_loss_store):
  '''
  Plots loss charts over course of training
  '''
  ## Plotting epoch-wise test loss curve:
  plt.plot(train_loss_store, '-o', label = 'train_loss', color = 'orange')
  plt.plot(validation_loss_store, label = 'validation_loss', color = 'blue')
  plt.xlabel('Epoch Number')
  plt.ylabel('Loss At each epoch')
  plt.legend()
  plt.show()


def plot_accuracy(train_score_store, validation_score_store, skip_plot=False):
  '''
  Plots Accuracy charts over course of training
  '''
  #Don't plot if this flag is set to true
  if skip_plot:
    return

  ## Plotting epoch-wise test loss curve:
  plt.plot(train_score_store, '-o', label = 'train_accuracy', color = 'orange')
  plt.plot(validation_score_store, label = 'validation_accuracy', color = 'blue')
  plt.xlabel('Epoch Number')
  plt.ylabel('Accuracy At each epoch')
  plt.legend()
  plt.show()

### Creating the model

In [29]:
class DocAlignerClassifier(nn.Module):
    def __init__(self, input_size, hidden_size_one=256, hidden_size_two=128, hidden_size_three=64):
        super(DocAlignerClassifier, self).__init__()
        
        self.fc1 = nn.Linear(input_size, hidden_size_one)
        self.fc2 = nn.Linear(hidden_size_one, hidden_size_two)
        self.fc3 = nn.Linear( hidden_size_two, hidden_size_three)
        self.fc_out = nn.Linear(hidden_size_three, 1)

        self.relu = nn.ReLU()
        
        #Optional
        self.dropout = nn.Dropout(p=0.1)
        self.batchnorm1 = nn.BatchNorm1d(64)
                
    def forward(self, src_doc_embed, tgt_doc_embed):
        '''
        Forward data through the lstm
        '''
        combined_input = torch.cat((src_doc_embed.view(src_doc_embed.size(0), -1),
                          tgt_doc_embed.view(tgt_doc_embed.size(0), -1)), dim=1)
        #print(combined_input.shape, "combed input")

        x = self.relu(self.fc1(combined_input))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        output = self.fc_out(x)

        return output

### Training Loop

In [30]:
##Define Loss and Accuracy Eval Function
def eval_acc_and_loss_func(model, loader, device, loss_metric, is_train = False, verbose = 1):
    '''
    Evaluate Function for CNN training
    Slightly different than eval function from part 1
    '''
    correct, total, loss_sum = 0, 0, 0
    
    eval_type = "Train" if is_train else "Validation"
    for X_1, X_2, Y in loader:
        outputs, predicted, calculated_loss = None, None, None
        X_1, X_2, Y = X_1.to(device), X_2.to(device), Y.to(device)

        outputs = model(X_1, X_2)
        
        #Reshape output and turn y into a float
        outputs = outputs.view(-1)
        Y = Y.float()
        predicted = torch.round(torch.sigmoid(outputs))
        
        total += Y.size(0)
        
        correct += (predicted == Y).sum().item()
        calculated_loss = loss_metric(outputs,Y).item()
        loss_sum += calculated_loss
        
    outputs, predicted, calculated_loss = None, None, None
    if verbose:
        print('%s accuracy: %f %%' % (eval_type, 100.0 * correct / total))
        print('%s loss: %f' % (eval_type, loss_sum / total))
    print
    return 100.0 * correct / total, loss_sum/ total

In [31]:
#DEFINE TRAIN LOOP HERE
def train(model,
          optimizer,
          loss_metric,
          lr,
          train_dataloader,
          valid_dataloader,
          device,
          epochs=5,
          stopping_threshold=3,
          saving_per_epoch=10,
          base_save_path="",
          model_name="",
          **kwargs):
    """
    For each epoch, loop through batch,
    compute forward and backward passes, apply gradient updates
    Evaluate results and output
    """


    train_loss_store, train_acc_store = [], []
    val_loss_store, val_acc_store, = [], []
    start_epoch = 0

    #Declare variables for early stopping
    last_val_loss, stop_tracker = 100, 0

    #training loop:
    print("Starting Training")
    for epoch in range(start_epoch, epochs):
      time1 = time.time() #timekeeping
      outputs, loss = None, None

      model.train()
      
      correct_train, total, loss_sum = 0, 0, 0
      for i, (x_1, x_2, y) in enumerate(train_dataloader):
        
        # Print device human readable names
        #torch.cuda.get_device_name()

        x_1, x_2, y = x_1.to(device), x_2.to(device), y.to(device)

        #loss calculation and gradient update:

        if i > 0 or epoch > 0:
          optimizer.zero_grad()
        outputs = model.forward(x_1, x_2)
        
        #Reshape output and turn y into a float
        outputs = outputs.view(-1)
        y = y.float()
        #print(y.shape, outputs.shape, outputs, "Loss inp info")

        
        loss = loss_metric(outputs, y)
        loss.backward()
                      
        ##performing update:
        optimizer.step()

        #Update Loss Info
        loss_sum += loss.item()
        predicted = torch.round(torch.sigmoid(outputs))
        total += y.size(0)
        correct_train += (predicted == y).sum().item()
              
      print("Epoch",epoch+1,':')

      
      model.eval()
      with torch.no_grad():

        #Print Train Info
        print('%s accuracy: %f %%' % ("Train", 100.0 * correct_train / total))
        print('%s loss: %f' % ("Train", loss_sum / total))
        print
        
        train_acc, train_loss = 100.0 * correct_train / total, loss_sum/ total
        val_acc, val_loss = eval_acc_and_loss_func(model, valid_dataloader, device, loss_metric, is_train = False)

        val_acc_store.append(val_acc)
        val_loss_store.append(val_loss)

        train_loss_store.append(train_loss)
        train_acc_store.append(train_acc)

      time2 = time.time() #timekeeping
      #if show_progress:
      print('Elapsed time for epoch:',time2 - time1,'s')
      print('ETA of completion:',(time2 - time1)*(epochs - epoch - 1)/60,'minutes')
      


      #Handle early stopping logic
      if val_loss >= last_val_loss:
            stop_tracker += 1
            if stop_tracker >= stopping_threshold:
                print('Early Stopping triggered, Convergence has occured')
                plot_loss_charts(train_loss_store, val_loss_store)
                plot_accuracy(train_acc_store, val_acc_store)
                #base_save_path, model_name, epoch + 1, model, optimizer)
                #print("Model Copy Saved")

                return train_loss_store, val_acc_store
      else:
          stop_tracker = 0
      last_val_loss = val_loss


    plot_loss_charts(train_loss_store, val_loss_store)
    plot_accuracy(train_acc_store, val_acc_store)
    return train_loss_store, val_acc_store

In [32]:
#Define the model, and kick off training
LEARNING_RATE = 0.001 #NOTE: CAN ADJUST
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DocAlignerClassifier(2048) #Each doc vec is 2048, so times 2 will be 4096 TODO: Replace with 4096
model = model.to(device)

optimizer = optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
loss_criterion = nn.BCEWithLogitsLoss()

In [33]:
#TODO: Test this with just base sent embeddings (as a place holder) just to see that training is working
# Since its much faster than using doc vector
train(
        model,
        optimizer,
        loss_criterion,
        LEARNING_RATE,
        train_dataloader,
        validation_dataloader,
        device,
        epochs=5
    )

Starting Training
Epoch 1 :
Train accuracy: 50.000000 %
Train loss: 0.021812
Validation accuracy: 50.000000 %
Validation loss: 0.021949
Elapsed time for epoch: 3069.637231826782 s
ETA of completion: 204.6424821217855 minutes
Epoch 2 :
Train accuracy: 50.000000 %
Train loss: 0.021789
Validation accuracy: 50.000000 %
Validation loss: 0.021930
Elapsed time for epoch: 2985.1118845939636 s
ETA of completion: 149.25559422969818 minutes
Epoch 3 :
Train accuracy: 50.000000 %
Train loss: 0.021778
Validation accuracy: 50.000000 %
Validation loss: 0.021922
Elapsed time for epoch: 2978.018518924713 s
ETA of completion: 99.2672839641571 minutes
Epoch 4 :
Train accuracy: 50.000000 %
Train loss: 0.021771
Validation accuracy: 50.000000 %
Validation loss: 0.021918
Elapsed time for epoch: 2978.5933632850647 s
ETA of completion: 49.64322272141774 minutes
