VAMPnet partial fit #130
Hi, we are soon going to add an MD example for VAMPNets (tagging @amardt), then hopefully things become more clear. Here is some boilerplate code for how you could work with your large data:

```python
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader

from deeptime.decomposition import VAMPNet
from deeptime.util.data import TimeLaggedDataset
from deeptime.util.torch import MLP

# `paths`, `ns`, and `n_rounds` are assumed to be defined already
paths_train = paths[:-1]
path_val = paths[1:]

data_val = TimeLaggedDataset.from_trajectory(lagtime=500, data=np.load(path_val).astype(np.float32))
loader_val = DataLoader(val_data, batch_size=len(data_val), shuffle=False)

lobe = MLP(units=ns, nonlinearity=nn.ReLU)
vampnet = VAMPNet(lobe=lobe, learning_rate=1e-4)

from random import shuffle

for _ in range(n_rounds):
    shuffle(paths_train)  # random.shuffle works in place and returns None
    for path in paths_train:
        data = np.load(path)
        dataset = TimeLaggedDataset.from_trajectory(lagtime=500, data=data.astype(np.float32))
        loader_train = DataLoader(dataset, batch_size=512, shuffle=True)
        vampnet.fit(loader_train, n_epochs=80, validation_loader=loader_val)

model = vampnet.fetch_model()
```

Note that I haven't actually run this, so there might be some typos in there. Also, with regard to validation you might need to be more careful depending on what your data looks like.
Thank you for the quick answer! So every time the fit function is called it updates the network, it doesn't overwrite it, right?
Is `paths[1:]` a typo, or did you want to provide a list of trajectories? It seems like a typo, although it's quite consistent between the two lines, so I'm wondering whether I misunderstood something.
Oops yes, it should be

```python
path_val = paths[-1]
```

so this would take all but the last file for training and validate on the last file. Fit indeed doesn't overwrite the model (which it would in any other estimator, so there is a small inconsistency here...). Eventually I want to provide an adapter between the pyemma source object and pytorch datasets (or even implement the pyemma source as such); then everything should work seamlessly, i.e., create the source with pyemma and then fit with deeptime. Also, it should be

```python
loader_val = DataLoader(data_val, batch_size=len(data_val), shuffle=False)
```
Ok, thanks! No problem at all! But then, if I wanted to use multiple trajectories for validation purposes (and not just one), would it be a good idea to put them together?
So for validation you don't have to separate; I would always validate against the full validation set. In case of multiple trajectories I suggest doing something like this (outside the loop):

```python
class ValidationSet:
    def __init__(self, files, lagtime):
        self.data = [np.load(f) for f in files]
        self.lagtime = lagtime

    def __getitem__(self, item):
        # instantaneous and time-lagged view of trajectory `item`
        return self.data[item][:-self.lagtime], self.data[item][self.lagtime:]

    def __len__(self):
        return len(self.data)


class ValidationLoader:
    def __init__(self, val_data):
        self.val_data = val_data
        self.loader_val_internal = DataLoader(val_data, batch_size=1, shuffle=False)

    def __len__(self):
        return len(self.val_data)

    def __iter__(self):
        for X, Y in self.loader_val_internal:
            yield X.squeeze(), Y.squeeze()


val_data = ValidationSet(paths_val, 200)
loader_val = ValidationLoader(val_data)
```

The squeeze bit is important to get the right shape out of the loader.
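To see why the squeeze matters: `DataLoader` collation prepends a batch dimension even when `batch_size=1`, so each yielded tensor carries an extra leading axis of size one. A minimal sketch with a hypothetical toy dataset (the names here are illustrative, not from deeptime; `squeeze(0)` is used to target only the batch axis):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyPairs(Dataset):
    """Toy stand-in: each item is one full (instantaneous, time-lagged) pair."""

    def __init__(self, n_traj=3, length=50, dim=4):
        self.pairs = [(torch.randn(length, dim), torch.randn(length, dim))
                      for _ in range(n_traj)]

    def __getitem__(self, i):
        return self.pairs[i]

    def __len__(self):
        return len(self.pairs)


loader = DataLoader(ToyPairs(), batch_size=1, shuffle=False)
X, Y = next(iter(loader))
print(X.shape)             # collation adds a leading batch dim: (1, 50, 4)
print(X.squeeze(0).shape)  # squeezing removes it again: (50, 4)
```

Note that a bare `.squeeze()` would also drop a feature dimension of size one, so `squeeze(0)` is the safer choice for one-dimensional data.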
That being said, I think it would be a useful addition for TimeLaggedDataset to also accept lists of trajectories so this kind of thing can be handled internally! I will keep it in mind.
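Until such list support exists, one possible stopgap is to build one time-lagged pair dataset per trajectory and chain them with PyTorch's `ConcatDataset`. This is a sketch under my own assumptions (the `LaggedPairs` helper is mine, not part of deeptime):

```python
import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class LaggedPairs(Dataset):
    """All (x_t, x_{t+lag}) pairs of a single trajectory (illustrative helper)."""

    def __init__(self, traj, lagtime):
        self.x = torch.from_numpy(traj[:-lagtime].astype(np.float32))
        self.y = torch.from_numpy(traj[lagtime:].astype(np.float32))

    def __getitem__(self, i):
        return self.x[i], self.y[i]

    def __len__(self):
        return len(self.x)


trajs = [np.random.randn(100, 3), np.random.randn(80, 3)]
dataset = ConcatDataset([LaggedPairs(t, lagtime=10) for t in trajs])
print(len(dataset))  # (100 - 10) + (80 - 10) = 160
loader_train = DataLoader(dataset, batch_size=32, shuffle=True)
```

Because each `LaggedPairs` only pairs frames within its own trajectory, shuffling across the concatenated dataset mixes trajectories in a batch without ever pairing frames across trajectory boundaries.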
Many thanks! It seems to be working perfectly! I do have another question, though. The script doesn't seem to use the GPU at all, and it's actually pretty slow. I can see that PyTorch detects the CUDA installation and the GPU correctly, but when I regularly check the GPU memory used by the script, the device does seem to be set to cuda, yet it isn't actually being used.
Glad it did the trick! For the device, you actually discovered a bug in the documentation: setting it in the VAMPNet constructor is not enough on its own.
I'm glad I've been able to help a little bit this way. I've set this on the estimator, but I get this error.

Note that if I do the same on the CPU I don't get any error, and the data is loaded as we discussed, so it fits very comfortably in memory (~200 MB of data vs. 32 GB).
This is indeed strange. Does the error persist if you restart your machine? What does your nvidia-smi output look like? You could also try decreasing the batch size.
Sorry, you're absolutely right, it was my fault: without realizing it, I was using the GPU memory elsewhere... although unfortunately now I get this problem:

I don't understand where I should load the data onto the GPU; from the documentation I thought it was going to do it automatically.
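For reference, the usual PyTorch pattern is to move both the network and each data batch to the device explicitly; neither happens automatically. A minimal sketch with a plain `torch.nn.Sequential` standing in for the VAMPNet lobe (the shapes are illustrative assumptions):

```python
import torch
import torch.nn as nn

# pick the GPU when one is available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# toy stand-in for the VAMPNet lobe
lobe = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 4))
lobe = lobe.to(device)  # moves the network parameters to the device

batch = torch.randn(32, 3).to(device)  # each batch must be moved as well
out = lobe(batch)
print(out.device.type, out.shape)
```

If the model is on `cuda` but a batch stays on `cpu` (or vice versa), PyTorch raises a device-mismatch error rather than moving anything silently.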
Try this before creating the VAMPNet estimator: clearly the documentation is still lacking on this account, sorry for that!
Thank you very much! That does the trick! I'm happy I could help somehow with debugging and checking the documentation while you helped me get things to work.
Perfect, I am glad it's resolved! Let me know if you have other issues downstream; happy to help, and in the end it also helps to improve the library 🙂 For now I'll close the issue and make an update to the docs soon.
I'd like to use VAMPnet for a large amount of data. I'm coming from pyEMMA, where managing this is easy thanks to the `pyemma.coordinates.source` function. I see that deeptime lacks this function, but I do see the `partial_fit` function on almost all estimators. My problem is how this can be used with VAMPnet. The `fit` and `partial_fit` functions seem to do different things: the first one, for instance, also asks for validation data, while the second is satisfied with just the training data; same thing for the number of epochs. Another question is whether I should fetch the model at the end. Right now I'm trying to do a loop over my data in the following way:

Note I'm not using the `train_data` and the `val_data` as you did in the documentation, since `partial_fit` doesn't require them, but I'm pretty sure that I should somehow. I think the documentation is not clear about how to deal with this kind of problem.

Thank you very much for your time
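As an aside on the fit vs. partial_fit distinction raised here: conceptually, a `partial_fit`-style estimator accumulates sufficient statistics chunk by chunk and only builds the model when asked. A minimal numpy sketch of that pattern (an illustrative running-mean accumulator of my own, not deeptime's implementation):

```python
import numpy as np


class RunningMean:
    """Illustrative partial_fit-style accumulator (not deeptime's API)."""

    def __init__(self):
        self.n = 0
        self.sum = None

    def partial_fit(self, chunk):
        # accumulate sufficient statistics (here: per-feature sum and count)
        chunk = np.asarray(chunk, dtype=np.float64)
        s = chunk.sum(axis=0)
        self.sum = s if self.sum is None else self.sum + s
        self.n += len(chunk)
        return self

    def fetch_model(self):
        # the model is only assembled from the statistics on demand
        return self.sum / self.n


est = RunningMean()
for chunk in [np.ones((10, 2)), 3 * np.ones((30, 2))]:
    est.partial_fit(chunk)
print(est.fetch_model())  # mean over all 40 frames: [2.5, 2.5]
```

This is why `partial_fit` needs no validation data or epoch count: each call is a single pass over one chunk, and `fetch_model` at the end turns the accumulated statistics into the final model.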