Skip to content

Commit

Permalink
Added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
svjan5 committed May 27, 2019
1 parent 9d02afe commit 7cb43c0
Show file tree
Hide file tree
Showing 4 changed files with 531 additions and 149 deletions.
28 changes: 5 additions & 23 deletions batch_generator.cpp
Expand Up @@ -93,13 +93,11 @@ int getBatch( int *edges, // Edges in the sentence graph
int num_neg, // Number of negative samples int num_neg, // Number of negative samples
int batch_size, // Batchsize int batch_size, // Batchsize
float sample, // Parameter for deciding rate of subsampling float sample, // Parameter for deciding rate of subsampling
int mode // mode=0: only dependency edges, mode=1: only context, mode=3: both dependency and context
) { ) {


cnt_edges = 0, cnt_wrds = 0, cnt_negs = 0, cnt_sample = 0; // Count of number of edges, words, negs, samples in the entire batch cnt_edges = 0, cnt_wrds = 0, cnt_negs = 0, cnt_sample = 0; // Count of number of edges, words, negs, samples in the entire batch


if(mode == 0 || mode == 2) cntxt_edge_label = de2id.size(); cntxt_edge_label = de2id.size();
else cntxt_edge_label = 0;


for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
b_elen = 0, b_wlen = 0; // Count of number of edges and word in particular element of batch b_elen = 0, b_wlen = 0; // Count of number of edges and word in particular element of batch
Expand Down Expand Up @@ -134,26 +132,10 @@ int getBatch( int *edges, // Edges in the sentence graph


for(j = 0; j < num_deps; j++){ // Including dependency edges for(j = 0; j < num_deps; j++){ // Including dependency edges
tmp = fscanf(fin, "%d|%d|%d ", &src, &dest, &lbl); tmp = fscanf(fin, "%d|%d|%d ", &src, &dest, &lbl);
if (mode == 0 || mode == 2){ edges[cnt_edges*3 + 0] = src;
edges[cnt_edges*3 + 0] = src; edges[cnt_edges*3 + 1] = dest;
edges[cnt_edges*3 + 1] = dest; edges[cnt_edges*3 + 2] = lbl;
edges[cnt_edges*3 + 2] = lbl; cnt_edges++; b_elen++;
cnt_edges++; b_elen++;
}
}

if (mode == 1 || mode == 2){
for(k = 0; k < num_wrds; k++){ // Including context edges
for(j=-win_size; j<=win_size; j++){
idx = k + j;
if (idx >=0 && idx < num_wrds && idx != k){
edges[cnt_edges*3 + 0] = idx;
edges[cnt_edges*3 + 1] = k;
edges[cnt_edges*3 + 2] = cntxt_edge_label;
cnt_edges++; b_elen++;
}
}
}
} }


wlen[i] = b_wlen; wlen[i] = b_wlen;
Expand Down
83 changes: 73 additions & 10 deletions helper.py
Expand Up @@ -8,17 +8,34 @@


np.set_printoptions(precision=4) np.set_printoptions(precision=4)


def mergeList(list_of_list):
	"""
	Flattens a list of lists into a single list, preserving order
	Parameters
	----------
	list_of_list:	List of lists to be merged
	Returns
	-------
	A single flat list containing every element of the sub-lists in order
	"""
	merged = []
	for sub in list_of_list:
		merged.extend(sub)
	return merged

def checkFile(filename):
	"""
	Checks whether the given path points to an existing regular file
	Parameters
	----------
	filename:	Path of the file to check
	Returns
	-------
	True if the path exists and is a file, False otherwise
	"""
	path = pathlib.Path(filename)
	return path.is_file()

def set_gpu(gpus):
	"""
	Sets the GPU to be used for the run
	Parameters
	----------
	gpus:		Comma-separated string listing the GPU ids to be used for the run
	Returns
	-------
	None; mutates os.environ so that CUDA sees only the requested devices
	"""
	# PCI_BUS_ID ordering makes CUDA device ids match nvidia-smi's numbering
	os.environ["CUDA_DEVICE_ORDER"]    = "PCI_BUS_ID"
	os.environ["CUDA_VISIBLE_DEVICES"] = gpus


def debug_nn(res_list, feed_dict): def debug_nn(res_list, feed_dict):
"""
Function for debugging Tensorflow model
Parameters
----------
res_list: List of tensors/variables to view
feed_dict: Feed dict required for getting values
Returns
-------
Returns the list of values of given tensors/variables after execution
"""
import tensorflow as tf import tensorflow as tf


config = tf.ConfigProto() config = tf.ConfigProto()
Expand All @@ -30,6 +47,20 @@ def debug_nn(res_list, feed_dict):
return res return res


def get_logger(name, log_dir, config_dir): def get_logger(name, log_dir, config_dir):
"""
Creates a logger object
Parameters
----------
name: Name of the logger file
log_dir: Directory where logger file needs to be stored
config_dir: Directory from where log_config.json needs to be read
Returns
-------
A logger object which writes to both file and stdout
"""
config_dict = json.load(open( config_dir + 'log_config.json')) config_dict = json.load(open( config_dir + 'log_config.json'))
config_dict['handlers']['file_handler']['filename'] = log_dir + name.replace('/', '-') config_dict['handlers']['file_handler']['filename'] = log_dir + name.replace('/', '-')
logging.config.dictConfig(config_dict) logging.config.dictConfig(config_dict)
Expand All @@ -42,14 +73,33 @@ def get_logger(name, log_dir, config_dir):


return logger return logger


def partition(lst, n):
	"""
	Splits lst into n consecutive pieces of (nearly) equal size
	Parameters
	----------
	lst:	List to be partitioned
	n:	Number of pieces required
	Returns
	-------
	List of n sub-lists covering lst in order
	"""
	step = len(lst) / float(n)
	pieces = []
	for i in range(n):
		start = int(round(step * i))
		stop  = int(round(step * (i + 1)))
		pieces.append(lst[start:stop])
	return pieces

def getChunks(inp_list, chunk_size):
	"""
	Splits inp_list into lists of size chunk_size
	Parameters
	----------
	inp_list:	List to be split
	chunk_size:	Size of each chunk required
	Returns
	-------
	Chunks of the inp_list each of size chunk_size, last one can be smaller (leftover data)
	"""
	return [inp_list[x:x + chunk_size] for x in range(0, len(inp_list), chunk_size)]


def read_mappings(fname): def read_mappings(fname):
"""
A helper function for reading an object to identifier mapping
Parameters
----------
fname: Name of the file containing mapping
Returns
-------
mapping: Dictionary object containing mapping information
"""
mapping = {} mapping = {}
for line in open(fname): for line in open(fname):
vals = line.strip().split('\t') vals = line.strip().split('\t')
Expand All @@ -58,6 +108,19 @@ def read_mappings(fname):
return mapping return mapping


def getEmbeddings(embed_loc, wrd_list, embed_dims): def getEmbeddings(embed_loc, wrd_list, embed_dims):
"""
Gives embedding for each word in wrd_list
Parameters
----------
model: Word2vec model
wrd_list: List of words for which embedding is required
embed_dims: Dimension of the embedding
Returns
-------
embed_matrix: (len(wrd_list) x embed_dims) matrix containing embedding for each word in wrd_list in the same order
"""
embed_list = [] embed_list = []


wrd2embed = {} wrd2embed = {}
Expand Down

0 comments on commit 7cb43c0

Please sign in to comment.