Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tkipf committed Mar 2, 2018
1 parent f31ad4b commit 258b224
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 264 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2018 ethanfetaya
Copyright (c) 2018 Ethan Fetaya, Thomas Kipf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
Empty file added data/__init__.py
Empty file.
12 changes: 0 additions & 12 deletions data/synthetic_sim.py
Expand Up @@ -3,9 +3,6 @@
import time


# np.random.seed(0)


class SpringSim(object):
def __init__(self, n_balls=5, box_size=5., loc_std=.5, vel_norm=.5,
interaction_strength=.1, noise_var=0.):
Expand Down Expand Up @@ -305,17 +302,8 @@ def sample_trajectory(self, T=10000, sample_freq=10,
for i in range(loc.shape[-1]):
plt.plot(loc[:, 0, i], loc[:, 1, i])
plt.plot(loc[0, 0, i], loc[0, 1, i], 'd')
# #plt.plot(vel_norm[:,i])
plt.figure()
energies = [sim._energy(loc[i, :, :], vel[i, :, :], edges) for i in
range(loc.shape[0])]
plt.plot(energies)
# mom = vel.sum(axis=2)
# mom_diff = (mom[1:,:]-mom[:-1,:]).sum(axis=1)
# plt.figure()
# plt.plot(mom_diff)
plt.show()

# np.save("loc.npy", loc)
# np.save("vel.npy", vel)
# np.save("edges.npy", edges)
51 changes: 7 additions & 44 deletions lstm_baseline.py
Expand Up @@ -9,7 +9,6 @@

import torch.optim as optim
from torch.optim import lr_scheduler
from torch import autograd

from utils import *
from modules import *
Expand All @@ -20,15 +19,15 @@
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=500,
help='Number of epochs to train.')
parser.add_argument('--batch_size', type=int, default=128,
parser.add_argument('--batch-size', type=int, default=128,
help='Number of samples per batch.')
parser.add_argument('--lr', type=float, default=0.0005,
help='Initial learning rate.')
parser.add_argument('--hidden', type=int, default=256,
help='Number of hidden units.')
parser.add_argument('--num_atoms', type=int, default=5,
help='Number of atoms in simulation.')
parser.add_argument('--num_layers', type=int, default=2,
parser.add_argument('--num-layers', type=int, default=2,
help='Number of LSTM layers.')
parser.add_argument('--suffix', type=str, default='_springs',
help='Suffix for training data (e.g. "_charged".')
Expand Down Expand Up @@ -58,8 +57,6 @@
parser.add_argument('--var', type=float, default=5e-5,
help='Output variance.')

print("NOTE: For Kuramoto model, set variance to 0.01.")

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
print(args)
Expand Down Expand Up @@ -93,24 +90,8 @@
print("WARNING: No save_folder provided!" +
"Testing (within this script) will throw an error.")

if args.motion:
train_loader, valid_loader, test_loader = load_motion_data(args.batch_size,
args.suffix)
elif args.suffix == "_kuramoto5" or args.suffix == "_kuramoto10":
train_loader, valid_loader, test_loader = load_kuramoto_data(
args.batch_size,
args.suffix)
else:
train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data(
args.batch_size, args.suffix)


# data, relations = train_loader.__iter__().next()
# data, relations = data.cuda(), relations.cuda()
# data, relations = Variable(data), Variable(relations, requires_grad=True)
# logits = encoder(data, rel_rec, rel_send)
# edges = gumbel_softmax(logits, tau=args.temp, hard=False)
# np.save("data/motion_edges.npy", edges.data.cpu().numpy())
train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data(
args.batch_size, args.suffix)


class RecurrentBaseline(nn.Module):
Expand Down Expand Up @@ -146,7 +127,7 @@ def batch_norm(self, inputs):
def step(self, ins, hidden=None):
# Input shape: [num_sims, n_atoms, n_in]
x = F.relu(self.fc1_1(ins))
# x = F.dropout(x, self.dropout_prob, training=self.training)
x = F.dropout(x, self.dropout_prob, training=self.training)
x = F.relu(self.fc1_2(x))
x = x.view(ins.size(0), -1)
# [num_sims, n_atoms*n_hid]
Expand Down Expand Up @@ -205,7 +186,6 @@ def forward(self, inputs, prediction_steps, burn_in=False, burn_in_steps=1):
model.load_state_dict(torch.load(model_file))
args.save_folder = False


optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_decay,
gamma=args.gamma)
Expand Down Expand Up @@ -233,8 +213,6 @@ def train(epoch, best_val_loss):
mse_baseline_val = []
mse_train = []
mse_val = []
mse_last_train = []
mse_last_val = []

model.train()
scheduler.step()
Expand All @@ -246,8 +224,6 @@ def train(epoch, best_val_loss):

optimizer.zero_grad()

# output = model(data, args.prediction_steps)

output = model(data, 100,
burn_in=True,
burn_in_steps=args.timesteps - args.prediction_steps)
Expand All @@ -256,15 +232,13 @@ def train(epoch, best_val_loss):
loss = nll_gaussian(output, target, args.var)

mse = F.mse_loss(output, target)
mse_last = F.mse_loss(output[:, :, -1, :], target[:, :, -1, :])
mse_baseline = F.mse_loss(data[:, :, :-1, :], data[:, :, 1:, :])

loss.backward()
optimizer.step()

loss_train.append(loss.data[0])
mse_train.append(mse.data[0])
mse_last_train.append(mse_last.data[0])
mse_baseline_train.append(mse_baseline.data[0])

model.eval()
Expand All @@ -281,22 +255,18 @@ def train(epoch, best_val_loss):
loss = nll_gaussian(output, target, args.var)

mse = F.mse_loss(output, target)
mse_last = F.mse_loss(output[:, :, -1, :], target[:, :, -1, :])
mse_baseline = F.mse_loss(data[:, :, :-1, :], data[:, :, 1:, :])

loss_val.append(loss.data[0])
mse_val.append(mse.data[0])
mse_last_val.append(mse_last.data[0])
mse_baseline_val.append(mse_baseline.data[0])

print('Epoch: {:04d}'.format(epoch),
'nll_train: {:.10f}'.format(np.mean(loss_train)),
'mse_train: {:.12f}'.format(np.mean(mse_train)),
# 'mse_last_train: {:.12f}'.format(np.mean(mse_last_train)),
'mse_baseline_train: {:.10f}'.format(np.mean(mse_baseline_train)),
'nll_val: {:.10f}'.format(np.mean(loss_val)),
'mse_val: {:.12f}'.format(np.mean(mse_val)),
# 'mse_last_val: {:.12f}'.format(np.mean(mse_last_val)),
'mse_baseline_val: {:.10f}'.format(np.mean(mse_baseline_val)),
'time: {:.4f}s'.format(time.time() - t))
if args.save_folder and np.mean(loss_val) < best_val_loss:
Expand All @@ -305,11 +275,9 @@ def train(epoch, best_val_loss):
print('Epoch: {:04d}'.format(epoch),
'nll_train: {:.10f}'.format(np.mean(loss_train)),
'mse_train: {:.12f}'.format(np.mean(mse_train)),
# 'mse_last_train: {:.12f}'.format(np.mean(mse_last_train)),
'mse_baseline_train: {:.10f}'.format(np.mean(mse_baseline_train)),
'nll_val: {:.10f}'.format(np.mean(loss_val)),
'mse_val: {:.12f}'.format(np.mean(mse_val)),
# 'mse_last_val: {:.12f}'.format(np.mean(mse_last_val)),
'mse_baseline_val: {:.10f}'.format(np.mean(mse_baseline_val)),
'time: {:.4f}s'.format(time.time() - t), file=log)
log.flush()
Expand All @@ -320,7 +288,6 @@ def test():
loss_test = []
mse_baseline_test = []
mse_test = []
mse_last_test = []
tot_mse = 0
tot_mse_baseline = 0
counter = 0
Expand All @@ -346,20 +313,18 @@ def test():
loss = nll_gaussian(output, target, args.var)

mse = F.mse_loss(output, target)
mse_last = F.mse_loss(output[:, :, -1, :], target[:, :, -1, :])
mse_baseline = F.mse_loss(ins_cut[:, :, :-1, :], ins_cut[:, :, 1:, :])

loss_test.append(loss.data[0])
mse_test.append(mse.data[0])
mse_last_test.append(mse_last.data[0])
mse_baseline_test.append(mse_baseline.data[0])

if args.motion or args.non_markov:
# RNN decoder evaluation setting

# For plotting purposes
output = model(inputs, 100, burn_in=True,
burn_in_steps=args.timesteps)
burn_in_steps=args.timesteps)

output = output[:, :, args.timesteps:, :]
target = inputs[:, :, -args.timesteps:, :]
Expand All @@ -380,7 +345,7 @@ def test():

# For plotting purposes
output = model(inputs, 100, burn_in=True,
burn_in_steps=args.timesteps)
burn_in_steps=args.timesteps)

output = output[:, :, args.timesteps:args.timesteps + 20, :]
target = inputs[:, :, args.timesteps + 1:args.timesteps + 21, :]
Expand Down Expand Up @@ -417,7 +382,6 @@ def test():
print('--------------------------------')
print('nll_test: {:.10f}'.format(np.mean(loss_test)),
'mse_test: {:.12f}'.format(np.mean(mse_test)),
# 'mse_last_test: {:.12f}'.format(np.mean(mse_last_test)),
'mse_baseline_test: {:.10f}'.format(np.mean(mse_baseline_test)))
print('MSE: {}'.format(mse_str))
print('MSE Baseline: {}'.format(mse_baseline_str))
Expand All @@ -427,7 +391,6 @@ def test():
print('--------------------------------', file=log)
print('nll_test: {:.10f}'.format(np.mean(loss_test)),
'mse_test: {:.12f}'.format(np.mean(mse_test)),
# 'mse_last_test: {:.12f}'.format(np.mean(mse_last_test)),
'mse_baseline_test: {:.10f}'.format(np.mean(mse_baseline_test)),
file=log)
print('MSE: {}'.format(mse_str), file=log)
Expand Down

0 comments on commit 258b224

Please sign in to comment.