Fix norm gradient explosion
daemon committed Feb 5, 2018
1 parent a110b03 commit b06547b
Showing 3 changed files with 25 additions and 17 deletions.
9 changes: 5 additions & 4 deletions vdpwi/__main__.py
@@ -12,7 +12,7 @@
import data
import model as mod

Context = namedtuple("Context", "model, train_loader, dev_loader, test_loader, optimizer, criterion")
Context = namedtuple("Context", "model, train_loader, dev_loader, test_loader, optimizer, criterion, params")
EvaluateResult = namedtuple("EvaluateResult", "pearsonr, spearmanr")

def create_context(config):
@@ -45,12 +45,12 @@ def collate_fn(batch)

train_loader = utils.data.DataLoader(train_set, shuffle=True, batch_size=1, collate_fn=collate_fn)
dev_loader = utils.data.DataLoader(dev_set, batch_size=1, collate_fn=collate_fn)
test_loader = utils.data.DataLoader(dev_set, batch_size=1, collate_fn=collate_fn)
test_loader = utils.data.DataLoader(test_set, batch_size=1, collate_fn=collate_fn)

params = list(filter(lambda x: x.requires_grad, model.parameters()))
optimizer = optim.RMSprop(params, lr=config.lr, alpha=config.decay, momentum=config.momentum)
criterion = nn.KLDivLoss()
return Context(model, train_loader, dev_loader, test_loader, optimizer, criterion)
return Context(model, train_loader, dev_loader, test_loader, optimizer, criterion, params)

def test(config):
pass
@@ -72,13 +72,14 @@ def train(config)
for epoch_no in range(config.n_epochs):
print("Epoch number: {}".format(epoch_no + 1))
loader_wrapper = tqdm(enumerate(context.train_loader), total=len(context.train_loader), desc="Loss")
context.model.train()
for i, (sent1, sent2, label_pmf) in loader_wrapper:
context.model.train()
context.optimizer.zero_grad()
scores = F.log_softmax(context.model(sent1, sent2))

loss = context.criterion(scores, label_pmf)
loss.backward()
nn.utils.clip_grad_norm(context.params, 50)
loader_wrapper.set_description("Loss = {}".format(loss.cpu().data[0]))
context.optimizer.step()
result = evaluate(context.model, context.dev_loader)
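In the training loop, the main addition is nn.utils.clip_grad_norm(context.params, 50) between loss.backward() and optimizer.step(), which rescales the parameter gradients whenever their combined L2 norm exceeds 50; context.model.train() also moves out of the inner batch loop. Below is a minimal, self-contained sketch of the clipping pattern, with an illustrative model, data, and loss; newer PyTorch releases spell the function clip_grad_norm_.

# Illustrative sketch of gradient-norm clipping in one training step; the model,
# data, and loss are placeholders, only the clip threshold of 50 and the
# backward -> clip -> step ordering mirror the commit above.
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(10, 5)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.RMSprop(params, lr=1E-4, alpha=0.95, momentum=0.9)
criterion = nn.KLDivLoss()

x = torch.randn(4, 10)
target = torch.full((4, 5), 0.2)   # a valid probability distribution per row

optimizer.zero_grad()
loss = criterion(F.log_softmax(model(x), dim=1), target)
loss.backward()
# Rescale gradients so their global L2 norm is at most 50 before the update.
# (clip_grad_norm was deprecated in favor of clip_grad_norm_ in PyTorch 0.4.)
nn.utils.clip_grad_norm_(params, 50)
optimizer.step()
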
4 changes: 2 additions & 2 deletions vdpwi/data.py
@@ -14,14 +14,14 @@ def base_config()
parser.add_argument("--decay", type=float, default=0.95)
parser.add_argument("--input_model", type=str, default="local_saves/model.pt")
parser.add_argument("--lr", type=float, default=1E-4)
parser.add_argument("--mbatch_size", type=int, default=40)
parser.add_argument("--mbatch_size", type=int, default=1)
parser.add_argument("--mode", type=str, default="train", choices=["train", "test"])
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--n_epochs", type=int, default=40)
parser.add_argument("--n_labels", type=int, default=5)
parser.add_argument("--output_model", type=str, default="local_saves/model.pt")
parser.add_argument("--restore", action="store_true", default=False)
parser.add_argument("--rnn_hidden_dim", type=int, default=300)
parser.add_argument("--rnn_hidden_dim", type=int, default=250)
parser.add_argument("--wordvecs_file", type=str, default="local_data/glove/glove.840B.300d.txt")
return parser.parse_known_args()[0]

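In data.py only the defaults change: mbatch_size drops to 1 and rnn_hidden_dim to 250, matching the model edits below. For readers unfamiliar with the surrounding idiom, base_config() returns parse_known_args()[0], which keeps the parsed Namespace and ignores flags it does not recognize; a small standalone illustration with made-up arguments follows.

# Standalone example of the parse_known_args()[0] idiom used by base_config();
# the flag names and values here are invented for the demonstration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=1E-4)
parser.add_argument("--rnn_hidden_dim", type=int, default=250)

# parse_known_args returns (namespace, leftover_args); taking [0] keeps the
# namespace and drops any flags this parser does not know about.
config, unknown = parser.parse_known_args(["--lr", "5E-5", "--unrelated_flag", "1"])
print(config.lr, config.rnn_hidden_dim)   # 5e-05 250
print(unknown)                            # ['--unrelated_flag', '1']
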
29 changes: 18 additions & 11 deletions vdpwi/model.py
@@ -52,21 +52,23 @@ def pad_side(idx, max_size)
class VDPWIModel(SerializableModule):
def __init__(self, embedding, config, classifier_net=None):
super().__init__()
self.rnn = nn.LSTM(300, config.rnn_hidden_dim, 1, bidirectional=True, batch_first=True)
self.hidden_dim = config.rnn_hidden_dim
self.rnn = nn.LSTM(300, self.hidden_dim, 1, batch_first=True)
self.embedding = embedding
self.use_cuda = not config.cpu
self.classifier_net = VDPWIConvNet(config.n_labels) if classifier_net is None else classifier_net

def compute_sim_cube(self, seq1, seq2):
def compute_sim(prism1, prism2):
prism1_len = torch.sqrt(torch.sum(prism1**2, 2))
prism2_len = torch.sqrt(torch.sum(prism2**2, 2))
prism1_len = prism1.norm(dim=2)
prism2_len = prism2.norm(dim=2)

dot_prod = torch.matmul(prism1.unsqueeze(2), prism2.unsqueeze(3))
dot_prod = dot_prod.squeeze(2).squeeze(2)
cos_dist = dot_prod / (prism1_len * prism2_len + 1E-8)
l2_dist = torch.sqrt(torch.sum((prism1 - prism2)**2, 2))
l2_dist = (prism1 - prism2).norm(dim=2)
return torch.stack([dot_prod, cos_dist, l2_dist], 0)

def compute_prism(seq1, seq2):
prism1 = seq1.repeat(seq2.size(0), 1, 1)
prism2 = seq2.repeat(seq1.size(0), 1, 1)
@@ -75,12 +77,13 @@ def compute_prism(seq1, seq2)
return compute_sim(prism1, prism2)

sim_cube = Variable(torch.Tensor(13, seq1.size(0), seq2.size(0)))
sim_cube[12] = 0
if self.use_cuda:
sim_cube = sim_cube.cuda()
seq1_f = seq1[:, :300]
seq1_b = seq1[:, 300:]
seq2_f = seq2[:, :300]
seq2_b = seq2[:, 300:]
seq1_f = seq1[:, :self.hidden_dim]
seq1_b = seq1[:, self.hidden_dim:]
seq2_f = seq2[:, :self.hidden_dim]
seq2_b = seq2[:, self.hidden_dim:]
sim_cube[0:3] = compute_prism(seq1, seq2)
sim_cube[3:6] = compute_prism(seq1_f, seq2_f)
sim_cube[6:9] = compute_prism(seq1_b, seq2_b)
@@ -103,16 +106,20 @@ def build_mask(index)
if s1tag[pos1] + s2tag[pos2] == 0:
s1tag[pos1] = s2tag[pos2] = 1
mask[:, int(pos1), int(pos2)] = 1
build_mask(9)
build_mask(10)
build_mask(11)
mask[12, :, :] = 1
return mask * sim_cube

def forward(self, x1, x2):
x1 = self.embedding(x1)
x2 = self.embedding(x2)
seq1, _ = self.rnn(x1)
seq2, _ = self.rnn(x2)
seq1f, _ = self.rnn(x1)
seq2f, _ = self.rnn(x2)
seq1b, _ = self.rnn(torch.cat(x1.split(1, 1)[::-1], 1))
seq2b, _ = self.rnn(torch.cat(x2.split(1, 1)[::-1], 1))
seq1 = torch.cat([seq1f, seq1b], 2)
seq2 = torch.cat([seq2f, seq2b], 2)
seq1 = seq1.squeeze(0) # batch size assumed to be 1
seq2 = seq2.squeeze(0)
sim_cube = self.compute_sim_cube(seq1, seq2)
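The model changes are twofold: the explicit torch.sqrt(torch.sum(x**2, 2)) expressions become x.norm(dim=2) calls, computing the same L2 norm over dimension 2, and the built-in bidirectional LSTM is replaced by one unidirectional LSTM that is run a second time on the time-reversed input, with the forward and backward output sequences concatenated on the feature dimension (which is why compute_sim_cube now slices with self.hidden_dim instead of the hard-coded 300). A minimal sketch of that reversal idiom, with illustrative shapes, is below.

# Illustrative sketch of running one unidirectional LSTM over a sequence and over
# its time reversal, as forward() above does; the batch and sequence length are
# made up, only hidden_dim = 250 follows the new default.
import torch
import torch.nn as nn

hidden_dim = 250
rnn = nn.LSTM(300, hidden_dim, 1, batch_first=True)

x = torch.randn(1, 7, 300)                  # (batch=1, time, embedding)

# Reverse along the time dimension. split(1, 1)[::-1] + cat is the diff's
# spelling; torch.flip(x, dims=[1]) is the shorter modern equivalent.
x_rev = torch.cat(x.split(1, 1)[::-1], 1)

seq_f, _ = rnn(x)                           # outputs over the original order
seq_b, _ = rnn(x_rev)                       # same weights over the reversed order
seq = torch.cat([seq_f, seq_b], 2)          # (1, 7, 2 * hidden_dim)
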
