-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_inference.py
78 lines (66 loc) · 1.82 KB
/
main_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 3 17:26:00 2018
@author: kalifou
"""
import numpy as np
import utils
import torch
from torch.autograd import Variable
from Modules import SketchRNN, Lr, Lkl, early_stopping_Loss
import torch.optim as optim
from matplotlib import pyplot as plt
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import time
import _pickle as pickle
# Wall-clock start of the script (captured here but not reported in this file).
t1 = time.time()

# Sketch-RNN "aaron_sheep" dataset archive with 'train', 'valid' and 'test' splits.
filename = "sketch-rnn-datasets/aaron_sheep/aaron_sheep.npz"
# encoding='latin1' only matters for pickled (object) arrays. Since NumPy
# 1.16.3 such arrays are refused unless allow_pickle=True is passed.
# NOTE(review): allow_pickle can execute code from the archive; only load
# dataset files from a trusted source.
# The context manager closes the underlying archive once the splits are read.
with np.load(filename, encoding='latin1', allow_pickle=True) as load_data:
    train_set = load_data['train']
    valid_set = load_data['valid']  # not used further in this file
    test_set = load_data['test']    # not used further in this file

nb_steps = 10000          # not used in this file
feature_len = 5           # not used in this file
batch_size = 1
max_seq_len = 250
augment_stroke_prob = 0.1
random_scale_factor = 0.1

# Wrap the raw training split in the project's batching helper, then rescale it
# by the normalizing factor computed from that same split.
train_set = utils.DataLoader(
    train_set,
    batch_size,
    max_seq_length=max_seq_len,
    random_scale_factor=random_scale_factor,
    augment_stroke_prob=augment_stroke_prob,
)
normalizing_scale_factor = train_set.calculate_normalizing_scale_factor()
train_set.normalize(normalizing_scale_factor)
# --- Run / model configuration ---------------------------------------------
reload_ = True  # True: unpickle a saved model; False: build a fresh SketchRNN
# Use the GPU only when one is present; a hard-coded True crashed on CPU-only hosts.
cuda = torch.cuda.is_available()
M = 20  # passed to predict(); presumably the number of mixture components (assumption)
obs_size = 5  # stroke vector size ("strokeSize" in the SketchRNN constructor)
# Decoder output size ("Ny" in the constructor). Presumably 6 parameters per
# mixture component plus 3 pen-state values (assumption based on the formula).
Y_size = 6 * M + 3
N_he = 512  # encoder hidden size ("Nhe")
N_hd = 512  # decoder hidden size ("Nhd")
N_z = 128   # latent size ("Nz")
w_lk = 0.5  # not used in this file
lr = 1e-3   # not used in this file (inference only)
N_max = max_seq_len  # not used in this file
X_decoder_size = N_z * batch_size + obs_size  # not used in this file
if reload_:
    # NOTE(review): pickle can execute arbitrary code while loading; only load
    # checkpoints produced by this project from a trusted location.
    # The `with` block closes the file handle that open() previously leaked.
    with open('sketch_rnn_save_80000.p', 'rb') as checkpoint:
        s2s_vae = pickle.load(checkpoint)
    # Overwrite decoder settings stored in the pickled object so they match
    # this run's latent size and batch size.
    s2s_vae.decoder.Nz = N_z
    s2s_vae.decoder.batchSize = batch_size
    print("Model reloaded")
else:
    # SketchRNN(strokeSize, batchSize, Nhe, Nhd, Nz, Ny, max_seq_len)
    s2s_vae = SketchRNN(obs_size, batch_size, N_he, N_hd, N_z, Y_size, max_seq_len)

if cuda:
    s2s_vae.cuda()
    # Let the cuDNN autotuner pick the fastest kernels; it pays off when input
    # shapes stay the same across calls.
    cudnn.benchmark = True
    print("Using cuda")
# Run the model's prediction routine on 10 randomly drawn training batches.
# predict()'s return value (if any) is discarded; presumably it renders or
# prints the sampled sketch itself — confirm against Modules.SketchRNN.
for _ in range(10):
    # random_batch() yields a 3-tuple; only the middle element is fed to the model.
    _, x, s = train_set.random_batch()
    batch = torch.from_numpy(x).float()
    # Move the input to the GPU only when the model was moved there as well;
    # calling .cuda() unconditionally broke runs with cuda disabled.
    if cuda:
        batch = batch.cuda()
    x = Variable(batch)
    s2s_vae.predict(x, M)