Tidy up some code (#174)

daemon committed Feb 6, 2019
1 parent 84225d2 commit 385525487070dea91dc91110369ab44c88af273c
@@ -17,8 +17,6 @@
from datasets.yelp2014 import Yelp2014CharQuantized as Yelp2014




class UnknownWordVecCache(object):
"""
Caches the first randomly generated word vector for a certain size so that it is reused.
@@ -6,7 +6,7 @@
def get_args():
parser = ArgumentParser(description="Kim CNN")
parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda')
parser.add_argument('--gpu', type=int, default=0) # Use -1 for CPU
parser.add_argument('--gpu', type=int, default=0, help='Use -1 for CPU')
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--lr', type=float, default=0.001)
@@ -5,8 +5,9 @@


class CharCNN(nn.Module):

def __init__(self, config):
super(CharCNN, self).__init__()
super().__init__()
self.is_cuda_enabled = config.cuda
dataset = config.dataset
num_conv_filters = config.num_conv_filters
@@ -15,7 +16,7 @@ def __init__(self, config):
target_class = config.target_class
input_channel = 68

self.conv1 = nn.Conv1d(input_channel, num_conv_filters, kernel_size=7) # Default padding=0
self.conv1 = nn.Conv1d(input_channel, num_conv_filters, kernel_size=7)
self.conv2 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=7)
self.conv3 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=3)
self.conv4 = nn.Conv1d(num_conv_filters, num_conv_filters, kernel_size=3)
@@ -72,7 +72,6 @@ def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, d
train_loader, dev_loader, test_loader = PIT2015.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors)
return PIT2015, embedding, train_loader, test_loader, dev_loader

elif dataset_name == 'snli':
dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'snli_1.0/')
train_loader, dev_loader, test_loader = SNLI.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
@@ -44,11 +44,7 @@ def get_evaluator(dataset_cls, model, embedding, data_loader, batch_size, device
if data_loader is None:
return None

if nce:
evaluator_map = EvaluatorFactory.evaluator_map_nce
else:
evaluator_map = EvaluatorFactory.evaluator_map

evaluator_map = EvaluatorFactory.evaluator_map_nce if nce else EvaluatorFactory.evaluator_map
if not hasattr(dataset_cls, 'NAME'):
raise ValueError('Invalid dataset. Dataset should have NAME attribute.')

@@ -4,10 +4,13 @@
import numpy as np
import torch.utils.data as data


def sst_tokenize(sentence):
return sentence.split()


class SSTEmbeddingLoader(object):

def __init__(self, dirname, fmt="stsa.fine.{}", word2vec_file="word2vec.sst-1"):
self.dirname = dirname
self.fmt = fmt
@@ -30,7 +33,9 @@ def load_embed_data(self):
unk_vocab_set.add(word)
return (id_dict, np.array(weights), list(unk_vocab_set))


class SSTDataset(data.Dataset):

def __init__(self, sentences):
super().__init__()
self.sentences = sentences
@@ -10,6 +10,7 @@


class ConvRNNModel(nn.Module):

def __init__(self, word_model, **config):
super().__init__()
embedding_dim = word_model.dim
@@ -97,7 +98,9 @@ def forward(self, x):
def lookup(self, sentences):
raise NotImplementedError


class SSTWordEmbeddingModel(WordEmbeddingModel):

def __init__(self, id_dict, weights, unknown_vocab=[]):
super().__init__(id_dict, weights, unknown_vocab, padding_idx=16259)

@@ -120,6 +123,7 @@ def lookup(self, sentences):
indices.extend([self.padding_idx] * (max_len - len(indices)))
return indices_list, lengths


def set_seed(seed=0, no_cuda=False):
np.random.seed(seed)
if not no_cuda:
@@ -84,6 +84,7 @@ def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, d
train, val, test = cls.splits(path)
return BucketIterator.splits((train, val, test), batch_size=batch_size, repeat=False, shuffle=shuffle, device=device)


class AAPDHierarchical(AAPD):
NESTING_FIELD = Field(batch_first=True, tokenize=clean_string)
TEXT_FIELD = NestedField(NESTING_FIELD, tokenize=split_sents)
@@ -53,7 +53,7 @@ def __init__(self, path, load_ext_feats=False):
example = Example.fromlist(example_list, fields)
examples.append(example)

super(CastorPairDataset, self).__init__(examples, fields)
super().__init__(examples, fields)

@classmethod
def set_vectors(cls, field, vector_path):
@@ -7,6 +7,7 @@

from datasets.castor_dataset import CastorPairDataset


def get_class_probs(sim, *args):
"""
Convert a single label into class probabilities.
@@ -42,7 +42,7 @@ def __init__(self, path):
"""
Create a SICK dataset instance
"""
super(SICK, self).__init__(path)
super().__init__(path)

@classmethod
def splits(cls, path, train='train', validation='dev', test='test', **kwargs):
@@ -7,6 +7,7 @@

from datasets.castor_dataset import CastorPairDataset


def get_class_probs(sim, *args):
"""
Convert a single label into class probabilities.
@@ -27,11 +27,11 @@ def __init__(self, path):
"""
Create a TRECQA dataset instance
"""
super(TRECQA, self).__init__(path, load_ext_feats=True)
super().__init__(path, load_ext_feats=True)

@classmethod
def splits(cls, path, train='train-all', validation='raw-dev', test='raw-test', **kwargs):
return super(TRECQA, cls).splits(path, train=train, validation=validation, test=test, **kwargs)
return super().splits(path, train=train, validation=validation, test=test, **kwargs)

@classmethod
def iters(cls, path, vectors_name, vectors_dir, batch_size=64, shuffle=True, device=0, pt_file=False, vectors=None, unk_init=torch.Tensor.zero_):
@@ -31,7 +31,7 @@ def __init__(self, path):

@classmethod
def splits(cls, path, train='train', validation='dev', test='test', **kwargs):
return super(WikiQA, cls).splits(path, train=train, validation=validation, test=test, **kwargs)
return super().splits(path, train=train, validation=validation, test=test, **kwargs)

@classmethod
def iters(cls, path, vectors_name, vectors_dir, batch_size=64, shuffle=True, device=0, pt_file=False, vectors=None,
@@ -10,6 +10,7 @@


class DecAtt(nn.Module):

def __init__(self, num_units, num_classes, embedding_size, dropout, device=0,
training=True, project_input=True,
use_intra_attention=False, distance_biases=10, max_sentence_length=30):
@@ -20,7 +20,6 @@
from han.model import HAN



class UnknownWordVecCache(object):
"""
Caches the first randomly generated word vector for a certain size so that it is reused.
@@ -1,6 +1,5 @@
import os

from argparse import ArgumentParser
import os


def get_args():
@@ -1,28 +1,29 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
#from utils import
import torch.nn.functional as F

from han.sent_level_rnn import SentLevelRNN
from han.word_level_rnn import WordLevelRNN


class HAN(nn.Module):
def __init__(self, config):
super(HAN, self).__init__()
dataset = config.dataset
self.mode = config.mode
self.word_attention_rnn = WordLevelRNN(config)
self.sentence_attention_rnn = SentLevelRNN(config)
def forward(self, x, **kwargs):
x = x.permute(1,2,0) ## Expected : #sentences, #words, batch size
num_sentences = x.size()[0]
word_attentions = None
for i in range(num_sentences):
_word_attention = self.word_attention_rnn(x[i,:,:])
if word_attentions is None:
word_attentions = _word_attention
else:
word_attentions = torch.cat((word_attentions, _word_attention),0)
return self.sentence_attention_rnn(word_attentions)

def __init__(self, config):
super().__init__()
dataset = config.dataset
self.mode = config.mode
self.word_attention_rnn = WordLevelRNN(config)
self.sentence_attention_rnn = SentLevelRNN(config)

def forward(self, x, **kwargs):
x = x.permute(1, 2, 0) # Expected : # sentences, # words, batch size
num_sentences = x.size(0)
word_attentions = None
for i in range(num_sentences):
word_attn = self.word_attention_rnn(x[i, :, :])
if word_attentions is None:
word_attentions = word_attn
else:
word_attentions = torch.cat((word_attentions, word_attn), 0)
return self.sentence_attention_rnn(word_attentions)

@@ -1,32 +1,29 @@
import torch
import torch.nn as nn
from torch.autograd import Variable

import torch.nn.functional as F

class SentLevelRNN(nn.Module):

def __init__(self, config):
super(SentLevelRNN, self).__init__()
super().__init__()
dataset = config.dataset
sentence_num_hidden = config.sentence_num_hidden
word_num_hidden = config.word_num_hidden
target_class = config.target_class
self.sentence_context_wghts = nn.Parameter(torch.rand(2*sentence_num_hidden, 1))
self.sentence_context_wghts.data.uniform_(-0.1, 0.1)
self.sentence_GRU = nn.GRU(2*word_num_hidden, sentence_num_hidden, bidirectional = True)
self.sentence_linear = nn.Linear(2*sentence_num_hidden, 2*sentence_num_hidden, bias = True)
self.fc = nn.Linear(2*sentence_num_hidden , target_class)
self.sentence_context_weights = nn.Parameter(torch.rand(2 * sentence_num_hidden, 1))
self.sentence_context_weights.data.uniform_(-0.1, 0.1)
self.sentence_gru = nn.GRU(2 * word_num_hidden, sentence_num_hidden, bidirectional=True)
self.sentence_linear = nn.Linear(2 * sentence_num_hidden, 2 * sentence_num_hidden, bias=True)
self.fc = nn.Linear(2 * sentence_num_hidden , target_class)
self.soft_sent = nn.Softmax()
self.final_log_soft = F.log_softmax

def forward(self,x):
sentence_h,_ = self.sentence_GRU(x)
x = torch.tanh(self.sentence_linear(sentence_h))
x = torch.matmul(x, self.sentence_context_wghts)
x = x.squeeze(dim=2)
x = self.soft_sent(x.transpose(1,0))
x = torch.mul(sentence_h.permute(2,0,1), x.transpose(1,0))
x = torch.sum(x,dim = 1).transpose(1,0).unsqueeze(0)
#x = self.final_log_soft(self.fc(x.squeeze(0)))
x = self.fc(x.squeeze(0))
return x
sentence_h,_ = self.sentence_gru(x)
x = torch.tanh(self.sentence_linear(sentence_h))
x = torch.matmul(x, self.sentence_context_weights)
x = x.squeeze(dim=2)
x = self.soft_sent(x.transpose(1,0))
x = torch.mul(sentence_h.permute(2, 0, 1), x.transpose(1, 0))
x = torch.sum(x, dim=1).transpose(1, 0).unsqueeze(0)
x = self.fc(x.squeeze(0))
return x
@@ -1,49 +1,48 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F


class WordLevelRNN(nn.Module):

def __init__(self, config):
super(WordLevelRNN, self).__init__()
super().__init__()
dataset = config.dataset
word_num_hidden = config.word_num_hidden
words_num = config.words_num
words_dim = config.words_dim
self.mode = config.mode
if self.mode == 'rand':
rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
self.embed = nn.Embedding.from_pretrained(rand_embed_init, freeze = False)
rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
self.embed = nn.Embedding.from_pretrained(rand_embed_init, freeze=False)
elif self.mode == 'static':
self.static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze = True)
self.static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=True)
elif self.mode == 'non-static':
self.non_static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze = False)
self.non_static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=False)
else:
print("Unsupported order")
exit()
self.word_context_wghts = nn.Parameter(torch.rand(2*word_num_hidden,1))
self.GRU = nn.GRU(words_dim, word_num_hidden, bidirectional = True)
self.linear = nn.Linear(2*word_num_hidden, 2*word_num_hidden, bias = True)
self.word_context_wghts.data.uniform_(-0.25, 0.25)
print("Unsupported order")
exit()
self.word_context_weights = nn.Parameter(torch.rand(2 * word_num_hidden, 1))
self.GRU = nn.GRU(words_dim, word_num_hidden, bidirectional=True)
self.linear = nn.Linear(2 * word_num_hidden, 2 * word_num_hidden, bias=True)
self.word_context_weights.data.uniform_(-0.25, 0.25)
self.soft_word = nn.Softmax()

def forward(self, x):
##################
## x expected to be of dimensions--> (num_words, batch_size)
# x expected to be of dimensions--> (num_words, batch_size)
if self.mode == 'rand':
x = self.embed(x)
x = self.embed(x)
elif self.mode == 'static':
x = self.static_embed(x)
x = self.static_embed(x)
elif self.mode == 'non-static':
x = self.non_static_embed(x)
x = self.non_static_embed(x)
else :
print("Unsuported mode")
exit()
h,_ = self.GRU(x)
print("Unsupported mode")
exit()
h, _ = self.GRU(x)
x = torch.tanh(self.linear(h))
x = torch.matmul(x, self.word_context_wghts)
x = torch.matmul(x, self.word_context_weights)
x = x.squeeze(dim=2)
x = self.soft_word(x.transpose(1,0))
x = torch.mul(h.permute(2,0,1), x.transpose(1,0))
x = torch.sum(x, dim = 1).transpose(1,0).unsqueeze(0)

x = self.soft_word(x.transpose(1, 0))
x = torch.mul(h.permute(2, 0, 1), x.transpose(1, 0))
x = torch.sum(x, dim=1).transpose(1, 0).unsqueeze(0)
return x