Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
3riccc committed Apr 10, 2019
1 parent 5b5422a commit f1d5070
Show file tree
Hide file tree
Showing 10 changed files with 373 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -53,4 +53,4 @@ If you use this code in your own work, please cite our paper:
journal={arXiv preprint arXiv:1812.11482},
year={2019}
}
```
```
200 changes: 200 additions & 0 deletions cml_128_not_batch.py
@@ -0,0 +1,200 @@
import matplotlib
matplotlib.use('Agg')
import pickle
import torch.optim as optim
import matplotlib.pyplot as plt
import argparse
from torch.utils.data import DataLoader
from tools import *
from utils.model import *
from utils.process_128 import *
# from tensorboardX import SummaryWriter


def train_dynamics(args, dynamics_learner, gumbel_generator, optimizer,
                   device, train_loader, epoch, experiment, skip_conn, object_matrix):
    """Train the dynamics learner against the ground-truth adjacency matrix.

    The data loader is built with batch_size=1 (see main), so this routine
    emulates a batch of 128 by summing per-sample MSE graphs and stepping the
    optimizer once per group of 128 samples.

    Args:
        args: parsed CLI namespace (uses dynamics_steps, nodes, prediction_steps).
        dynamics_learner: GumbelGraphNetwork being trained.
        gumbel_generator: network generator; sampled once (see note below).
        optimizer: optimizer over dynamics_learner's parameters.
        device: torch device chosen by the caller.
        train_loader: DataLoader of training trajectories.
        epoch, experiment: loop counters (logging context only; unused here).
        skip_conn: forwarded to train_dynamics_learner (skip-connection flag).
        object_matrix: ground-truth adjacency matrix (numpy array).
    """
    # The sampled matrix is discarded, but the call is kept: Gumbel_Generator
    # is built with temp_drop_frac, so sample() presumably decays its
    # temperature as a side effect — TODO confirm before removing.
    gumbel_generator.sample(hard=True)
    # Train against the ground-truth matrix, not the sampled one.
    # BUG FIX: was torch.tensor(object_matrix).cuda(), which ignored `device`
    # and crashed when running with --no-cuda or on a CPU-only machine.
    matrix = torch.tensor(object_matrix).to(device)
    # (Removed a dead matplotlib figure that was drawn and closed without
    # ever being saved or shown.)

    accumulate_every = 128
    for step in range(1, args.dynamics_steps + 1):
        mse_record = []
        mses = 0
        optimizer.zero_grad()
        for batch_idx, data in enumerate(train_loader):
            data = data.to(device)
            mse = train_dynamics_learner(optimizer, dynamics_learner,
                                         matrix, data, args.nodes,
                                         args.prediction_steps, skip_conn)
            mses = mses + mse
            mse_record.append(mse.item())
            # BUG FIX: the original tested `batch_idx % 128 == 0`, which fired
            # at batch 0 and therefore took the first optimizer step after a
            # single sample; step after every *full* group of 128 instead.
            if (batch_idx + 1) % accumulate_every == 0:
                mses.backward()
                optimizer.step()
                optimizer.zero_grad()
                print(np.mean(mse_record))
                mses = 0
        # Flush the trailing partial group (the original silently dropped
        # any remainder that did not fill a full 128-sample group).
        if torch.is_tensor(mses):
            mses.backward()
            optimizer.step()
            optimizer.zero_grad()


def val_dynamics(args, dynamics_learner, gumbel_generator,
                 device, val_loader, epoch, experiment, best_val_loss, skip_conn):
    """Validate the dynamics learner on the validation set.

    Samples an adjacency matrix from the gumbel generator, computes the mean
    loss/MSE over the validation loader, checkpoints both models when the
    mean loss improves on best_val_loss, and returns the mean loss.
    """
    # Adjacency matrix sampled (hard) from the generator's current state.
    matrix = gumbel_generator.sample(hard=True)

    losses = []
    mses = []
    for batch in val_loader:
        batch = batch.to(device)
        loss, mse = val_dynamics_learner(dynamics_learner, matrix, args.nodes,
                                         batch, args.prediction_steps, skip_conn)
        losses.append(loss.item())
        mses.append(mse.item())

    mean_loss = np.mean(losses)
    print('\nDynamics validation: loss: %f, MSE: %f' % (mean_loss, np.mean(mses)))

    # Persist both networks whenever validation improves on the caller's best.
    if mean_loss < best_val_loss:
        torch.save(dynamics_learner.state_dict(), args.dynamics_path)
        torch.save(gumbel_generator.state_dict(), args.gumbel_path)

    return mean_loss


def train_gumbel(args, dynamics_learner, gumbel_generator, optimizer_network,
                 device, train_loader, object_matrix, epoch, experiment, skip_conn):
    """Train the gumbel generator to reconstruct the network structure.

    Runs args.reconstruct_steps passes over the training data, updating the
    generator via train_net_reconstructor, and after each step evaluates the
    reconstruction against the ground-truth object_matrix.

    Args:
        args: parsed CLI namespace (uses reconstruct_steps, nodes, prediction_steps).
        dynamics_learner: trained dynamics network (held fixed here).
        gumbel_generator: generator being trained.
        optimizer_network: optimizer over the generator's parameters.
        device: torch device chosen by the caller.
        train_loader: DataLoader of training trajectories.
        object_matrix: ground-truth adjacency matrix used for evaluation.
        epoch, experiment: loop counters (logging context only; unused here).
        skip_conn: forwarded to train_net_reconstructor (skip-connection flag).
    """
    # BUG FIX: Tensor.to() is not in-place; the original discarded the result,
    # so object_matrix was never actually moved to `device`.
    object_matrix = object_matrix.to(device)

    for step in range(1, args.reconstruct_steps + 1):
        loss_record = []
        for batch_idx, data in enumerate(train_loader):
            data = data.to(device)
            loss, _ = train_net_reconstructor(optimizer_network, gumbel_generator,
                                              dynamics_learner, args.nodes, data,
                                              args.prediction_steps, skip_conn)
            loss_record.append(loss.item())
        # Evaluate reconstruction after every step (the original guard
        # `step % 1 == 0` was always true; kept the every-step behavior).
        net_error, tpr, fpr = constructor_evaluator(gumbel_generator, 500,
                                                    object_matrix, args.nodes)
        print('\nGumbel training step: %d, loss: %f' % (step, np.mean(loss_record)))
        print('Net error: %f, TPR: %f, FPR: %f' % (net_error, tpr, fpr))



def test(args, dynamics_learner, gumbel_generator, device, test_loader, object_matrix, experiment, skip_conn):
    """Evaluate the best checkpointed models on the test set.

    Loads the checkpoints saved during validation, scores the reconstructed
    network against the ground truth, then measures dynamics loss/MSE on the
    test loader using a matrix sampled from the trained generator.
    """
    # Restore the best checkpoints saved by val_dynamics.
    dynamics_learner.load_state_dict(torch.load(args.dynamics_path))
    gumbel_generator.load_state_dict(torch.load(args.gumbel_path))

    # Evaluate the reconstructed network against the ground truth.
    net_error, tpr, fpr = constructor_evaluator(gumbel_generator, 500, object_matrix, args.nodes)

    # Evaluate dynamics with a (hard) sample from the trained generator.
    matrix = gumbel_generator.sample(hard=True)

    dynamics_learner.eval()
    loss_record = []
    mse_record = []
    for batch_idx, data in enumerate(test_loader):
        data = data.to(device)
        # BUG FIX: skip_conn was not forwarded here although every other call
        # site passes it — the original either crashed or silently fell back
        # to the helper's default skip-connection setting.
        loss, mse = val_dynamics_learner(dynamics_learner, matrix, args.nodes,
                                         data, args.prediction_steps, skip_conn)
        loss_record.append(loss.item())
        mse_record.append(mse.item())

    print('\nTest: Net error: %f, TPR: %f, FPR: %f' % (net_error, tpr, fpr))
    print('loss: %f, mse: %f' % (np.mean(loss_record), np.mean(mse_record)))


def main():
    """Parse CLI arguments, then run the train / validate / reconstruct /
    test loop for each experiment."""
    # Training settings
    parser = argparse.ArgumentParser(description='Coupled Map Lattice and Kuramoto')
    parser.add_argument('--simulation-type', type=str, default='cml',
                        help='simulation type to choose(cml or kuramoto)')
    parser.add_argument('--nodes', type=int, default=10,
                        help='number of nodes in data')
    parser.add_argument('--dims', type=int, default=1,
                        help='information dimension in data(1 for CML and 2 for kuramoto)')
    # FIX: "weather" -> "whether" in the help text.
    parser.add_argument('--skip', type=int, default=1,
                        help='whether to use skip connection(recommend to use in kuramoto)')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='input batch size for training (default: 128)')
    # FIX: removed stray ")" in the help text.
    parser.add_argument('--epochs', type=int, default=15,
                        help='number of epochs (default: 15)')
    parser.add_argument('--experiments', type=int, default=10,
                        help='number of experiments (default: 10)')
    parser.add_argument('--dynamics-steps', type=int, default=30,
                        help='number of steps for dynamics learning (default: 30)')
    # FIX: help text said "default: 5" but the actual default is 0.
    parser.add_argument('--reconstruct-steps', type=int, default=0,
                        help='number of steps for reconstruction (default: 0)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=2050,
                        help='random seed (default: 2050)')
    parser.add_argument('--prediction-steps', type=int, default=10,
                        help='prediction steps in data (default: 10)')
    parser.add_argument('--dynamics-path', type=str, default='./saved/dynamics.pickle',
                        help='path to save dynamics learner (default: ./saved/dynamics.pickle)')
    parser.add_argument('--gumbel-path', type=str, default='./saved/gumbel.pickle',
                        help='path to save gumbel generator (default: ./saved/gumbel.pickle)')
    parser.add_argument('--data-path', type=str, default='./data/test.pickle',
                        help='path to load data (default: ./data/test.pickle)')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    print('Device:', device)

    # Loading data. NOTE(review): batch_size is hard-coded to 1 — consistent
    # with this file's "not_batch" variant, which accumulates gradients over
    # 128 single-sample batches in train_dynamics; --batch-size is unused here.
    train_loader, val_loader, test_loader, object_matrix = load_cml_ggn(batch_size=1)

    for experiment in range(1, args.experiments + 1):
        print('\n---------- Experiment %d ----------' % experiment)

        best_val_loss = np.inf
        best_epoch = 0

        gumbel_generator = Gumbel_Generator(sz=args.nodes, temp=10, temp_drop_frac=0.9999).to(device)
        optimizer_network = optim.Adam(gumbel_generator.parameters(), lr=0.1)

        dynamics_learner = GumbelGraphNetwork(args.dims).to(device)
        optimizer = optim.Adam(dynamics_learner.parameters(), lr=0.0001)

        for epoch in range(1, args.epochs + 1):
            print('\n---------- Experiment %d Epoch %d ----------' % (experiment, epoch))
            train_dynamics(args, dynamics_learner, gumbel_generator, optimizer,
                           device, train_loader, epoch, experiment, args.skip, object_matrix)
            val_loss = val_dynamics(args, dynamics_learner, gumbel_generator,
                                    device, val_loader, epoch, experiment, best_val_loss, args.skip)
            train_gumbel(args, dynamics_learner, gumbel_generator, optimizer_network,
                         device, train_loader, object_matrix, epoch, experiment, args.skip)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch
            print('\nCurrent best epoch: %d, best val loss: %f' % (best_epoch, best_val_loss))

        print('\nBest epoch: %d' % best_epoch)
        test(args, dynamics_learner, gumbel_generator, device, test_loader,
             object_matrix, experiment, args.skip)


# Script entry point: run the full experiment loop when executed directly.
if __name__ == '__main__':
    main()

1 change: 1 addition & 0 deletions data/data_generator_cml.py
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pickle


use_cuda = torch.cuda.is_available()
np.random.seed(2050)
torch.manual_seed(2050)
Expand Down
1 change: 0 additions & 1 deletion tools.py
Expand Up @@ -233,7 +233,6 @@ def load_cml_ggn(batch_size = 128):
print('\nMatrix dimension: %s Train data size: %s Val data size: %s Test data size: %s'
% (object_matrix.shape, train_data.shape, val_data.shape, test_data.shape))


train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
Expand Down
Binary file added utils/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file added utils/__pycache__/model.cpython-37.pyc
Binary file not shown.
Binary file added utils/__pycache__/process_128.cpython-37.pyc
Binary file not shown.
Binary file added utils/__pycache__/util.cpython-37.pyc
Binary file not shown.
17 changes: 15 additions & 2 deletions utils/model.py
Expand Up @@ -84,19 +84,32 @@ def __init__(self, input_size, hidden_size = 128):
self.output = torch.nn.Linear(input_size + hidden_size, input_size)

def forward(self, x, adj,skip_conn=0):
# adj[batch,node,node]
# x[batch,node,dim]
out = x

# innode [batch,node,node,dim]
innode = out.unsqueeze(1).repeat(1, adj.size()[1], 1, 1)
# outnode [batch,node,node,dim]
outnode = innode.transpose(1, 2)
# node2edge:[batch,node,node,2*dim->hidden]
node2edge = F.relu(self.edge1(torch.cat((innode,outnode), 3)))
# edge2edge:[batch,node,node,hidden]
edge2edge = F.relu(self.edge2edge(node2edge))
# adjs:[batch,node,node,1]
adjs = adj.view(adj.size()[0], adj.size()[1], adj.size()[2], 1)
# adjs:[batch,node,node,hidden]
adjs = adjs.repeat(1, 1, 1, edge2edge.size()[3])

# edges:[batch,node,node,hidden]
edges = adjs * edge2edge
out = torch.sum(edges, 1)
# out:[batch,node,hid]
out = torch.sum(edges, 2)
# out:[batch,node,hid]
out = F.relu(self.node2node(out))
out = F.relu(self.node2node2(out))
# out:[batch,node,dim+hid]
out = torch.cat((x, out), dim = -1)
# out:[batch,node,dim]
out = self.output(out)
if skip_conn == 1:
out = out + x
Expand Down

0 comments on commit f1d5070

Please sign in to comment.