
Working with 2-D series #9

Closed
SwapnilDreams100 opened this issue Dec 8, 2021 · 10 comments

@SwapnilDreams100

Hi Ian,
Is it feasible to modify the RNN to take in a 2-D series instead of a 1-D one?
Or is there some other solution you think is possible, like concatenating the values somehow?
My current solution is just taking the mean across the 2nd dimension so that it becomes 1-D.
Appreciate any help!

@iancovert
Owner

Can you explain a bit more how the time series is 2-dimensional? What would the shape of the tensor be, and what is the size of the input at each time point? It may be possible, but I just want to understand the data more.

@SwapnilDreams100
Author

SwapnilDreams100 commented Dec 9, 2021

The overall time series is 10x750x450, so the output adjacency matrix should be 10x10.
Each series is therefore 750x450. I was taking the mean across the 450 dimension to get 750 data points as the time series, but I was hoping to encode both axes somehow using the RNN; taking the mean/median is quite lossy.

@iancovert
Owner

Okay, so it sounds like this is a multivariate time series with 10 total time series (or 10 components) because we want an adjacency matrix of size 10x10. What do the other dimensions represent? One should be the number of time points, right? And is the other possibly a batch dimension?

@SwapnilDreams100
Author

The other dimension is the epochs, so the same experiment is repeated 450 times.

@iancovert
Owner

Okay great. So if I understand correctly, you have 450 experiments, each of which contains 750 time points, where each time point has values for 10 dimensions. And the goal is to figure out which dimensions Granger cause which other dimensions.

If you check out the notebooks, for example this one with the cLSTM, you'll notice that the data argument X for the function train_model_ista also has three dimensions. The first dimension is a batch dimension (in your case, 450), the second one is the time dimension (in your case, 750), and the third dimension is the dimensionality of the time series (in your case, 10).

So you should have no problem training these models with your data, just make sure you arrange the data tensor with the dimensions in the appropriate order. That is, (450, 750, 10).
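
For example, here's a rough sketch of the reordering (just a sketch, assuming your raw data is a NumPy array of shape (10, 750, 450) ordered as components x time points x experiments; the variable names are placeholders):

import numpy as np
import torch

# Hypothetical raw array: 10 components x 750 time points x 450 experiments.
raw = np.random.randn(10, 750, 450)

# Reorder to (experiments, time points, components) = (450, 750, 10).
X = torch.tensor(np.transpose(raw, (2, 1, 0)), dtype=torch.float32)
print(X.shape)  # torch.Size([450, 750, 10])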

@SwapnilDreams100
Author

Hi Ian,
The approach makes sense, except that the 450 epochs are different measurements of each of the time points taken in succession, so the shape should probably be 750x450x10?
Also, using this data format causes memory issues, maybe because there is no batch_size parameter for the models?

@iancovert
Owner

I may be misunderstanding what you mean by epochs, but whichever dimension corresponds to the batch should go first. That is, if you would say that there are 750 separate trials (for lack of a better word) each of length 450, then the input should be 750x450x10.
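
For instance, under that reading (a sketch reusing the raw 10x750x450 array from the earlier snippet, ordered as components x time points x epochs):

# Reorder to (trials, length, components) = (750, 450, 10).
X = torch.tensor(raw, dtype=torch.float32).permute(1, 2, 0)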

About the memory issues: do you mean that you're running out of GPU memory? If so, that may be because the existing training algorithms do full-batch gradient descent. And to make matters worse, when you use the cRNN or cLSTM, long time series are broken up into many short time series, which further increases the memory required to store the data.

Let me know if it seems like that's what's going on. If so, the solution will be to modify the optimization to do stochastic gradient descent with a mini-batch of time series, rather than taking a gradient step on all the data at once. It should be relatively easy to implement; it's just not something we tested in our paper.

@SwapnilDreams100
Author

Got it. Yeah, I meant GPU memory. Mini-batch gradient descent makes sense. Where should I change the ISTA procedure to implement it?

@iancovert
Owner

Okay makes sense. So I've tried to write a function that can handle this amount of data with the cLSTM. It should also work for the cRNN, although not the cMLP (but my guess is that you may not encounter memory errors with the cMLP anyway, because it doesn't break up long time series into many small ones).

The way it works actually isn't by doing stochastic (minibatch) gradient descent. Instead, it iterates over the data and accumulates gradients, which is functionally equivalent to the original training code but has lower memory requirements. (I think it's up for debate whether doing stochastic gradient descent could help/hurt, but this is more consistent with what we've tested.)

So check out the function below and see how it works. I haven't tested it, so you may have to do a little debugging, sorry. And when you pass the data X as an argument, make sure it's a PyTorch tensor on CPU rather than GPU. (The function automatically moves each mini-batch to the same device as the model right before making predictions.)

import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from clstm import arrange_input, regularize, ridge_regularize, prox_update, restore_parameters
from torch.utils.data import DataLoader, TensorDataset


def train_model_accumulated_ista(clstm, X, context, mbsize, lr, max_iter, lam=0,
                                 lam_ridge=0, lookback=5, check_every=50,
                                 verbose=1):
    '''Train model with ISTA, accumulating gradients over mini-batches.'''
    p = X.shape[-1]
    loss_fn = nn.MSELoss(reduction='mean')
    train_loss_list = []

    # Set up data: arrange each series into overlapping length-`context`
    # subsequences, with targets shifted one step ahead.
    X, Y = zip(*[arrange_input(x, context) for x in X])
    X = torch.cat(X, dim=0)
    Y = torch.cat(Y, dim=0)

    # Set up data loader.
    dataset = TensorDataset(X, Y)
    loader = DataLoader(dataset, batch_size=mbsize, shuffle=True,
                        drop_last=False)
    device = next(clstm.parameters()).device

    # For early stopping.
    best_it = None
    best_loss = np.inf
    best_model = None

    for it in range(max_iter):
        for x, y in loader:
            # Move to device.
            x = x.to(device)
            y = y.to(device)

            # Calculate loss, scaled by the batch fraction so that the
            # gradients accumulated across mini-batches match the full-batch gradient.
            pred = [clstm.networks[i](x)[0] for i in range(p)]
            loss = (len(x) / len(X)) * sum(
                [loss_fn(pred[i][:, :, 0], y[:, :, i]) for i in range(p)])

            # Accumulate gradients.
            loss.backward()

        # Accumulate gradients for smooth penalty.
        ridge = sum(
            [ridge_regularize(net, lam_ridge) for net in clstm.networks])
        ridge.backward()

        # Take gradient step.
        for param in clstm.parameters():
            param.data -= lr * param.grad

        # Take prox step.
        if lam > 0:
            for net in clstm.networks:
                prox_update(net, lam, lr)

        # Zero grad.
        clstm.zero_grad()

        # Check progress.
        if (it + 1) % check_every == 0:
            with torch.no_grad():
                total_loss = 0
                for x, y in loader:
                    # Move to device.
                    x = x.to(device)
                    y = y.to(device)

                    # Calculate loss.
                    pred = [clstm.networks[i](x)[0] for i in range(p)]
                    loss = sum([loss_fn(pred[i][:, :, 0], y[:, :, i])
                                for i in range(p)])
                    total_loss = total_loss + len(x) / len(X) * loss

                # Add smooth penalty.
                ridge = sum([ridge_regularize(net, lam_ridge)
                             for net in clstm.networks])
                smooth = ridge + total_loss

                # Add nonsmooth penalty.
                nonsmooth = sum(
                    [regularize(net, lam) for net in clstm.networks])
                mean_loss = (smooth + nonsmooth) / p
                train_loss_list.append(mean_loss.detach())

            if verbose > 0:
                print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1))
                print('Loss = %f' % mean_loss)
                print('Variable usage = %.2f%%'
                      % (100 * torch.mean(clstm.GC().float())))

            # Check for early stopping.
            if mean_loss < best_loss:
                best_loss = mean_loss
                best_it = it
                best_model = deepcopy(clstm)
            elif (it - best_it) == lookback * check_every:
                if verbose:
                    print('Stopping early')
                break

    # Restore best model.
    restore_parameters(clstm, best_model)

    return train_loss_list
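
And here's a rough sketch of how you might call it (assuming the cLSTM constructor from the notebooks, which takes the number of series and a hidden size; the hyperparameter values below are just placeholders, not recommendations):

import torch
from clstm import cLSTM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clstm = cLSTM(10, hidden=100).to(device)

# X: CPU tensor of shape (450, 750, 10), arranged as discussed above.
# The function moves each mini-batch to the model's device internally.
train_loss_list = train_model_accumulated_ista(
    clstm, X, context=10, mbsize=64, lr=1e-3, max_iter=1000,
    lam=0.1, lam_ridge=0.01, lookback=5, check_every=50)

# Estimated Granger causality adjacency matrix (10 x 10).
GC_est = clstm.GC().cpu().numpy()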

@SwapnilDreams100
Author

Thank you so much for the help! The code is running perfectly. Happy to close the issue.
