-
Notifications
You must be signed in to change notification settings - Fork 1
/
TFMolManage.py
129 lines (113 loc) · 4.03 KB
/
TFMolManage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#
# Either trains, tests, evaluates or provides an interface for optimization.
#
from TensorMolData import *
from TFMolInstance import *
import numpy as np
import gc
class TFMolManage:
    """Train, save, reload and evaluate one TensorFlow network instance.

    A manager is either built fresh from a tensor-data object (and optionally
    trained right away) or revived by name from a ``.tfm`` pickle previously
    written by :meth:`Save`.

    Attributes set by the constructor (or restored by :meth:`Load`):
        path: directory prefix where ``.tfm`` files are written/read.
        name: base name of the ``.tfm`` file.
        TData: training data object (supplies the digester ``dig``).
        Test_TData: optional, independent test data handed to the instance.
        NetType: ``"fc_classify"`` or ``"fc_sqdiff"``.
        TrainedAtoms, TrainedNetworks: bookkeeping lists.
        Instances: the live network instance, or None.
    """

    def __init__(self, Name_="", TData_=None, Train_=True, NetType_="fc_sqdiff", Test_TData_=None):
        """Create a manager, either from a saved name or from data.

        Args:
            Name_: if non-empty, the saved manager of that name is loaded
                and every other argument is ignored.
            TData_: data object used for training (required when Name_ is
                empty).
            Train_: when true, training starts immediately.
            NetType_: network flavour, ``"fc_classify"`` or ``"fc_sqdiff"``.
            Test_TData_: some other random, independent test data.
        """
        self.path = "./networks/"
        if Name_ != "":
            # Prepare() unpickles the stored state (TData included) and
            # rebuilds the network instance from it.
            self.name = Name_
            self.Prepare()
            return
        self.TData = TData_
        self.Test_TData = Test_TData_
        self.NetType = NetType_
        print(self.TData.AvailableDataFiles)
        self.name = self.TData.name + self.TData.dig.name + "_" + self.NetType
        print("--- TF will be fed by --- %s" % (self.TData.name,))
        self.TrainedAtoms = []  # In order of the elements in TData
        self.TrainedNetworks = []  # In order of the elements in TData
        self.Instances = None  # In order of the elements in TData
        if Train_:
            self.Train()

    def _instance_class(self):
        """Return the instance constructor matching ``self.NetType``.

        Raises:
            Exception: if ``self.NetType`` is not a known network type.
        """
        if self.NetType == "fc_classify":
            return Instance_fc_classify
        if self.NetType == "fc_sqdiff":
            return Instance_fc_sqdiff
        raise Exception("Unknown Network Type!")

    def Print(self):
        """Print a one-line status banner."""
        print("-- TensorMol, Tensorflow Manager Status--")

    def Save(self):
        """Pickle this manager's attribute dict to ``path + name + ".tfm"``.

        ``TData.CleanScratch()`` is called first, before the state is dumped.
        The target directory is created when it does not exist yet, so a
        finished training run is not lost to a missing folder.
        """
        import os

        print("Saving TFManager.")
        self.TData.CleanScratch()
        if self.path and not os.path.isdir(self.path):
            os.makedirs(self.path)
        # The context manager closes the file even if pickling fails.
        with open(self.path + self.name + ".tfm", "wb") as handle:
            pickle.dump(self.__dict__, handle, protocol=1)

    def Load(self):
        """Restore the attribute dict from ``path + name + ".tfm"``.

        NOTE: unpickling executes arbitrary code; only load ``.tfm`` files
        from a trusted source.
        """
        print("Unpickling TFManager...")
        with open(self.path + self.name + ".tfm", "rb") as handle:
            state = pickle.load(handle)
        self.__dict__.update(state)
        print("TFManager Metadata Loaded, Reviving Networks.")
        self.Print()

    def Train(self, maxstep=10000):
        """Build a network instance for ``NetType``, train it and save.

        Args:
            maxstep: passed through to the instance's ``train`` method.

        Raises:
            Exception: if the digester has no ``eshape`` or the network
                type is unknown.
        """
        if self.TData.dig.eshape is None:
            raise Exception("Must Have Digester")
        # It's up to the TensorData to provide the batches and input/output shapes.
        self.Instances = self._instance_class()(self.TData, None, self.Test_TData)
        self.Instances.train(maxstep)  # Just for the sake of debugging.
        trained_name = self.Instances.name
        # Here we should print some summary of the pupil's progress as well, maybe.
        if trained_name not in self.TrainedNetworks:
            self.TrainedNetworks.append(trained_name)
        self.Save()
        gc.collect()

    def Eval(self, test_input):
        """Return ``self.Instances.evaluate(test_input)``."""
        return self.Instances.evaluate(test_input)

    def Prepare(self):
        """Load the saved state and rebuild the instance for the first trained network.

        Raises:
            Exception: if the loaded ``NetType`` is unknown.
        """
        self.Load()
        self.Instances = None  # In order of the elements in TData
        # Resolve the type first so an unknown type is reported before
        # TrainedNetworks is indexed.
        builder = self._instance_class()
        self.Instances = builder(None, self.TrainedNetworks[0], None)

    def Eval_Mol(self, mol):
        """Evaluate all fragments of ``mol`` at order ``TData.order``.

        Each fragment is digested, the network is evaluated on the stacked
        inputs, outputs are un-normalized with the data's mean and std, and
        the derivative arrays are handed to
        ``mol.Set_Frag_Force_with_Order``.

        Returns:
            The sum of the un-normalized network outputs, or 0.0 when the
            molecule has no fragments at this order.
        """
        order = self.TData.order
        frags = mol.mbe_frags[order]
        n_cases = len(frags)
        if n_cases == 0:
            return 0.0
        natom = frags[0].NAtoms()
        cases = np.zeros((n_cases, self.TData.dig.eshape))
        cases_deri = np.zeros((n_cases, natom, natom, 6))  # x1,y1,z1,x2,y2,z2
        for idx, frag in enumerate(frags):
            ins, embed_deri = self.TData.dig.EvalDigest(frag)
            # One-row slices let the digester output broadcast into the row.
            cases[idx:idx + 1] += ins
            cases_deri[idx:idx + 1] = embed_deri
        print("evaluating order: %s" % (order,))
        nn, nn_deri = self.Eval(cases)
        mean, std = self.TData.Get_Mean_Std()
        nn = nn * std + mean
        nn_deri = nn_deri * std
        mol.Set_Frag_Force_with_Order(cases_deri, nn_deri, order)
        return nn.sum()

    def Test(self, save_file="mbe_test.dat"):
        """Write reference-vs-network values for the test split to a text file.

        The last ``TestRatio`` fraction of the loaded data is used. Column 0
        of the output is the reference, column 1 the network output; both
        are scaled back with the data's mean and std before
        ``np.savetxt``.

        Args:
            save_file: path of the text file to write.
        """
        ti, to = self.TData.LoadData(True)
        n_test = int(self.TData.TestRatio * ti.shape[0])
        ti = ti[ti.shape[0] - n_test:]
        to = to[to.shape[0] - n_test:]
        acc_nn = np.zeros((to.shape[0], 2))
        acc = self.TData.ApplyNormalize(to)
        nn = self.Eval(ti)
        acc_nn[:, 0] = acc.reshape(acc.shape[0])
        acc_nn[:, 1] = nn[0].reshape(nn[0].shape[0])
        mean, std = self.TData.Get_Mean_Std()
        acc_nn = acc_nn * std + mean
        np.savetxt(save_file, acc_nn)