Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Oct 4, 2019
1 parent 2c3d776 commit a140cc5
Showing 1 changed file with 93 additions and 134 deletions.
227 changes: 93 additions & 134 deletions examples/tree.py
Expand Up @@ -12,9 +12,12 @@
config = {"method": "reinforce", "baseline": "mean", "opt": "adadelta",
"lr_struct": 0.1, "lr_params": 1, "train_model":True,
"var_norm": False, "entropy": 0.001, "v": 3, "RL_K": 5,
"H" : 100,
"H" : 100, "train_len": 100, "div_ent" : 1
}


NAME = "yoyo3"

wandb.init(project="pytorch-struct", config=config)

def clip(p):
Expand All @@ -34,7 +37,7 @@ def expand_spans(spans, words, K, V):
return new_spans


def valid_sup(valid_iter, model, tree_lstm):
def valid_sup(valid_iter, model, tree_lstm, V):
total = 0
correct = 0
struct = CKY_CRF
Expand All @@ -45,7 +48,7 @@ def valid_sup(valid_iter, model, tree_lstm):
words = words.cuda()

def tree_reward(spans):
new_spans = expand_spans(spans.unsqueeze(0))
new_spans = expand_spans(spans, words, 1, V)
g, labels, indices, topo = TreeLSTM.spans_to_dgl(new_spans)
_, am = tree_lstm(g, labels, indices, topo, lengths).max(-1)
return (label == am).sum(), label.shape[0]
Expand All @@ -68,20 +71,23 @@ def run_train(train_iter, valid_iter, model, tree_lstm, V):
opt_params = torch.optim.Adadelta(list(tree_lstm.parameters()), lr=config["lr_params"])

model.train()
tree_lstm.train()
losses = []
struct = CKY_CRF

for epoch in range(50):
entropy_fn = struct(EntropySemiring)
step = 0
trees = None
for epoch in range(100):
print("Epoch", epoch)

for i, ex in enumerate(train_iter):
step += 1
words, lengths = ex.word
label = ex.label
batch = label.shape[0]
_, N = words.shape
words = words.cuda()


def tree_reward(spans, K):
new_spans = expand_spans(spans, words, K, V)
g, labels, indices, topo = TreeLSTM.spans_to_dgl(new_spans)
Expand All @@ -91,53 +97,98 @@ def tree_reward(spans, K):

sc = SelfCritical(CKY_CRF, tree_reward)
phi = model(words, lengths)
structs, rewards, score, max_score = sc.forward(phi, lengths, K=config["RL_K"], )
structs, rewards, score, max_score = sc.forward(phi, lengths, K=config["RL_K"])

if config["train_model"]:
opt_params.zero_grad()
score.mean().backward()
clip(tree_lstm.parameters())
opt_params.step()

opt_params.zero_grad()

if config["method"] == "reinforce":
opt_struct.zero_grad()
log_partition, entropy = struct(EntropySemiring).sum(phi, lengths=lengths, _raw=True).unbind()
log_partition, entropy = entropy_fn.sum(phi, lengths=lengths, _raw=True).unbind()
r = struct().score(phi.unsqueeze(0), structs, batch_dims=[0,1]) \
- log_partition.unsqueeze(0)
obj = rewards.mul(r).mean(-1).mean(-1)
policy = obj - config["entropy"] * entropy.mean()
policy = obj - config["entropy"] * entropy.div(lengths.float().cuda()).mean()
policy.backward()
clip(model.parameters())
opt_struct.step()
losses.append(-max_score.mean().detach())



# DEBUG
if i % 50 == 9:
print(torch.tensor(losses).mean(), words.shape)
print("Round")
print("Entropy", entropy.mean())
print("Reward", rewards.mean())
if i % 200 == 9:
valid_loss = valid_sup(valid_iter, model, tree_lstm)
print("Entropy", entropy.mean().item())
print("Reward", rewards.mean().item())
if i % 1000 == 9:
valid_loss = valid_sup(valid_iter, model, tree_lstm, V)
fname = "/tmp/checkpoint.%s.%0d.%0d.%s"%(NAME, epoch, i, valid_loss)
torch.save((model, tree_lstm), fname)
wandb.save(fname)
trees = valid_show(valid_iter, model)
else:
print(valid_loss)
# valid_show()
wandb.log({"entropy": entropy.mean(), "valid_loss": valid_loss,

wandb.log({"entropy": entropy.mean(),
"valid_loss": valid_loss,
"reward": rewards.mean(),
"step": step,
"tree": trees,
"reward_var": rewards.var(),
"loss" : torch.tensor(losses).mean()})
losses = []


def valid_show(valid_iter, model):
    """Render the first validation batch as a wandb table of bracketed trees.

    For each sentence in the first batch, logs the sentence text, the model's
    argmax parse, and the gold tree side by side. Only one batch is rendered
    (the loop breaks after the first iteration).

    Args:
        valid_iter: iterator yielding batches with ``.word`` (tokens, lengths),
            ``.label``, and ``.tree`` (gold span tensors) fields.
        model: span-scoring model producing potentials ``phi`` for CKY_CRF.

    Returns:
        A ``wandb.Table`` with columns ["Sent", "Predicted Tree", "True Tree"].
    """
    struct = CKY_CRF
    table = wandb.Table(columns=["Sent", "Predicted Tree", "True Tree"])
    for i, ex in enumerate(valid_iter):
        words, lengths = ex.word
        label = ex.label
        batch = label.shape[0]
        words = words.cuda()
        phi = model(words, lengths)
        # Argmax (Viterbi-style) parse under the current model potentials.
        argmax = struct(MaxSemiring).marginals(phi, lengths=lengths)
        argmax_tree = struct().from_parts(argmax.detach())[0].cpu()
        for b in range(words.shape[0]):
            out = [WORD.vocab.itos[w.item()] for w in words[b]]
            sent = " ".join(out)

            def show(tree):
                # Convert a span tensor into a bracketed string: each nonzero
                # span (i, j) opens a "(" at token i and closes a ")" at token
                # j; setdefault(-1) discounts the trivial single-token span.
                output = ""
                start = {}
                end = {}
                for i, j, _ in tree.nonzero():
                    i = i.item()
                    j = j.item()
                    start.setdefault(i, -1)
                    end.setdefault(j, -1)
                    start[i] += 1
                    end[j] += 1
                for i, w in enumerate(out):
                    for _ in range(start.get(i, 0)):
                        output += "( "
                    output += w + " "
                    for _ in range(end.get(i, 0)):
                        output += ") "
                return output

            # Fix: the model's argmax parse is the *prediction* and the
            # dataset's ex.tree is the *gold* tree — the original assigned
            # them the other way round, so the table columns were swapped.
            predict_text = show(argmax_tree[b].cpu())
            true_text = show(ex.tree[b].cpu())
            table.add_data(sent, predict_text, true_text)
        break
    return table

WORD = None
def main():
global WORD
WORD = data.Field(include_lengths=True, batch_first=True, eos_token=None, init_token=None)
LABEL = data.Field(sequential=False, batch_first=True)
TREE = data.RawField(postprocessing=ListOpsDataset.tree_field(WORD))
TREE.is_target=False
train = ListOpsDataset("data/train_d20s.tsv", (("word", WORD), ("label", LABEL), ("tree", TREE)),
filter_pred=lambda x: 5 < len(x.word) < 50)
filter_pred=lambda x: 5 < len(x.word) < config["train_len"])
WORD.build_vocab(train)
LABEL.build_vocab(train)
valid = ListOpsDataset("data/test_d20s.tsv", (("word", WORD), ("label", LABEL), ("tree", TREE)),
Expand All @@ -149,128 +200,36 @@ def main():
train_iter.repeat = False
valid_iter = data.BucketIterator(train,
batch_size=50,
train=False,
sort=False,
device="cuda:0")

NT = 1
T = len(WORD.vocab)
V = T

tree_lstm = TreeLSTM(config["H"],
len(WORD.vocab) + 100, len(LABEL.vocab)).cuda()
for p in tree_lstm.parameters():
if p.dim() > 1:
torch.nn.init.xavier_uniform_(p)

model = SpanLSTM(NT, len(WORD.vocab), config["H"]).cuda()
for p in model.parameters():
if p.dim() > 1:
torch.nn.init.xavier_uniform_(p)
struct = CKY_CRF

wandb.watch(model)
tree = run_train(train_iter, valid_iter, model, tree_lstm, V)

if True:
tree_lstm = TreeLSTM(config["H"],
len(WORD.vocab) + 100, len(LABEL.vocab)).cuda()
for p in tree_lstm.parameters():
if p.dim() > 1:
torch.nn.init.xavier_uniform_(p)

model = SpanLSTM(NT, len(WORD.vocab), config["H"]).cuda()
for p in model.parameters():
if p.dim() > 1:
torch.nn.init.xavier_uniform_(p)
struct = CKY_CRF

wandb.watch((model, tree_lstm))
print(wandb.config)
tree = run_train(train_iter, valid_iter, model, tree_lstm, V)
else:
print("loading")
model, tree_lstm = torch.load("cp.yoyo.model")
print(valid_sup(valid_iter, model, tree_lstm, V))
main()



# def valid_show():
# for i, ex in enumerate(valid_iter):
# words, lengths = ex.word
# label = ex.label
# batch = label.shape[0]
# words = words.cuda()
# out = [WORD.vocab.itos[w.item()] for w in words[0]]
# print(" ".join(out))
# def show(tree):
# start = {}
# end = {}
# for i, j, _ in tree.nonzero():
# i = i.item()
# j = j.item()
# start.setdefault(i, -1)
# end.setdefault(j, -1)
# start[i] += 1
# end[j] += 1
# for i, w in enumerate(out):
# for _ in range(start.get(i, 0)):
# print("(", end=" ")
# print(w, end=" ")
# for _ in range(end.get(i, 0)):
# print(")", end=" ")
# print()
# show(ex.tree[0])
# phi = model(words, lengths)
# argmax = struct(MaxSemiring).marginals(phi, lengths=lengths)
# argmax_tree = struct().from_parts(detach(argmax))[0].cpu()
# show(argmax_tree[0])
# break


# if config["method"] == "ppo":
# # Run PPO
# old = None
# for p in range(10):
# opt_struct.zero_grad()
# obj = []
# t("model")
# phi = model(words, lengths)
# for sample, reward in zip(structs, rewards):
# #if running_reward is None:
# # running_reward = reward.var().detach()
# #else:
# # if p == 0:
# # running_reward = running_reward * alpha + reward.var() * (1.0 - alpha)
# # reward = reward / running_reward.sqrt().clamp(min=1.0)
# t("dp")
# reward = reward.detach()
# log_partition, entropy = struct(EntropySemiring).sum(phi, lengths=lengths, _raw=True).unbind()
# t("add")
# cur = struct().score(phi, sample.cuda()) - log_partition

# if p == 0:
# old = cur.clone().detach()
# r = (cur - old).exp()
# clamped_r = torch.clamp(r, 0.98, 1.02)
# obj.append(torch.max(reward.mul(r), reward.mul(clamped_r)).mean())
# t("rest")
# policy = torch.stack(obj).mean(dim=0) - config["entropy"] * entropy.mean()
# (policy).backward()
# t("update")
# torch.nn.utils.clip_grad_norm_(parameters = model.parameters(), max_norm = 0.5, norm_type=float("inf"))
# opt_struct.step()
# opt_struct.zero_grad()

# def sample_baseline_b(reward_fn, phi, lengths, K=5):
# t("sample")
# sample = struct(MultiSampledSemiring).marginals(phi, lengths=lengths)
# sample = detach(sample)

# t("construct")
# trees = []
# samples = []
# for k in range(K):
# tmp_sample = MultiSampledSemiring.to_discrete(sample, k+1)
# samples.append(tmp_sample)
# sampled_tree = struct().from_parts(tmp_sample)[0].cpu()
# trees.append(sampled_tree)
# structs = torch.stack(samples)
# argmax = struct(MaxSemiring).marginals(phi, lengths=lengths)
# argmax_tree = struct().from_parts(detach(argmax))[0].cpu()
# trees.append(argmax_tree)

# t("use")
# sample_score = reward_fn(torch.cat(trees), K+1)

# t("finish")
# total = sample_score[:-1].mean(dim=0)
# # for k in range(K):
# # samples.append([trees[k][1], sample_scores[k]])
# # if k == 0:
# # total = sample_score.clone()
# # else:
# # total += sample_score
# max_score = sample_score[-1].clone().detach()
# rewards = sample_score[:-1] - max_score.view(1, sample_score.shape[1])
# return structs, rewards, total, max_score

0 comments on commit a140cc5

Please sign in to comment.