Using Recurrent and Concat together #30

Closed
cpennington opened this issue Jul 12, 2017 · 10 comments

@cpennington

I'm trying to use Recurrent and Concat together in the same network. In particular, I'm trying to run two LSTMs in parallel against different subsets of the input, and then want to Concat the results together.

I have something like this, so far:

type R = Recurrent
type F = FeedForward

type ShapeInput = 'D1 164

type CropOpponent = Crop 0 0 0 55
type CropPlayer = Crop 0 55 0 164

type LearnPlayer = RecurrentNetwork
    '[ F Reshape
    , F CropPlayer
    , F Reshape
    , R (LSTM 109 20)
    ]
    '[ ShapeInput
    , D2 1 164
    , D2 1 109
    , D1 109
    , D1 20
    ]

type LearnOpponent = RecurrentNetwork
    '[ F Reshape
    , F CropOpponent
    , F Reshape
    , R (LSTM 55 10)
    ]
    '[ ShapeInput
    , D2 1 164
    , D2 1 55
    , D1 55
    , D1 10
    ]

type RecNet = Network
    '[ Concat
        ShapeInput
        LearnPlayer
        ShapeInput
        LearnOpponent
    ]
    '[ ShapeInput
    , 'D1 30
    ]

randomNet :: MonadRandom m => m RecNet
randomNet = randomNetwork

On compilation, the error I'm getting is:

LearningBot.hs:69:13: error:
    • Couldn't match type ‘'False’ with ‘'True’
        arising from a use of ‘randomNetwork’
    • In the expression: randomNetwork
      In an equation for ‘randomNet’: randomNet = randomNetwork

Is there any way to accomplish what I'm looking for with Grenade right now?

P.S. I also tried this way, and got the same error:


type RecNet = RecurrentNetwork
    '[ F (
        Concat
            ShapeInput
            LearnPlayer
            ShapeInput
            LearnOpponent
        )
    ]
    '[ ShapeInput
    , 'D1 30
    ]

type RecInput = RecurrentInputs
    '[ F (
        Concat
            ShapeInput
            LearnPlayer
            ShapeInput
            LearnOpponent
        )
    ]

randomNet :: MonadRandom m => m (RecNet, RecInput)
randomNet = randomRecurrent
@HuwCampbell
Owner

HuwCampbell commented Jul 12, 2017

Great question.

First up, I don't think your Crop layers are correct. The 164 in CropPlayer doesn't look right: the numbers are how many elements are taken off the left and right, not the resulting width.

It seems I should add a version of Crop that works on 1D shapes. You can also do this in your own code if you like.
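
To make the Crop arithmetic concrete, here is a minimal sketch that uses only GHC's type-level naturals, not Grenade itself. The `CroppedWidth` synonym is a hypothetical helper, and the split (opponent features in the first 55 entries, player features in the remaining 109) is an assumption read off the sizes in the question.

```haskell
{-# LANGUAGE DataKinds      #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeOperators  #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits

-- Hypothetical helper (not part of Grenade): the width that remains
-- after removing `left` elements from the left and `right` from the right.
type CroppedWidth (width :: Nat) (left :: Nat) (right :: Nat) =
  width - left - right

-- Keep the last 109 of 164 entries: remove 55 from the left, 0 from the right.
playerWidth :: Integer
playerWidth = natVal (Proxy :: Proxy (CroppedWidth 164 55 0))

-- Keep the first 55 of 164 entries: remove 0 from the left, 109 from the right.
opponentWidth :: Integer
opponentWidth = natVal (Proxy :: Proxy (CroppedWidth 164 0 109))
```

Under these assumptions `playerWidth` is 109 and `opponentWidth` is 55, which match the LSTM input sizes in the question. The crop amounts that produce them are 55 and 109, not 55 and 164.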

I believe it's probably not possible right now*. Really, RecNet should be a recurrent network with a recurrent Concat layer (R instead of F).

Now the problem is that Concat isn't an instance of RecurrentLayer. I can't think of a fundamental reason that it shouldn't be, or at least that a layer just like Concat (RecConcat) couldn't exist which takes two layers which are tagged with F or R and makes a new recurrent layer.

Would you like to try writing it?

EDIT:

* That is, not with the current Concat layer (or without an orphan instance). One can write their own layers downstream.

@cpennington
Author

Ah, good to know about how Crop works.

I'll take a stab at a 1D Crop and making Concat an instance of RecurrentLayer. Hopefully the types should guide me in the right direction (and I'll drop back here for advice if I get stuck).

@HuwCampbell
Owner

Ahh, there's actually another problem. I haven't yet written an instance of RecurrentLayer for RecurrentNetwork.

I think it's possible, but requires packing all the recurrent (sideways travelling) shapes into a single vector.
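
As a rough illustration of the packing idea only, here is a sketch in plain Haskell lists rather than Grenade's shape-indexed types. The `pack` and `unpack` names are hypothetical, and a real instance would track the sizes at the type level instead of at runtime.

```haskell
-- Pack several recurrent state vectors into one flat vector,
-- remembering the length of each so they can be recovered later.
pack :: [[Double]] -> ([Int], [Double])
pack states = (map length states, concat states)

-- Recover the individual state vectors from the flat vector.
unpack :: [Int] -> [Double] -> [[Double]]
unpack [] _ = []
unpack (n : ns) flat =
  let (state, rest) = splitAt n flat
  in state : unpack ns rest
```

For any list of states, `uncurry unpack (pack states)` gives back the original list, which is the round trip a packed recurrent shape would need.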

@HuwCampbell
Owner

You might have to run both LSTM networks forwards individually for now. The GAN mnist example gives a non-recurrent example of something like this.

@cpennington
Author

Ah, ok. I had thought about doing that, but hadn't looked closely enough at runBackwards/runGradient to see that they spit out something input-shaped.

It seems like the approach is: runNetwork for both LSTMs, combine their outputs, and feed that into runNetwork for the combining network. Then take the target output and runBackwards through the combining network to get target results for the two LSTMs, and finally runGradient/applyUpdate for all networks. That should do the trick. I'll give it a try and see how it works out.
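
The data flow can be sketched with a toy stand-in in plain Haskell. This is not Grenade's API: the combiner here is a single linear unit over lists, and `forwardCombiner`, `backwardCombiner` and `splitInputGradient` are hypothetical names. The point it illustrates is that the backward pass yields an input-shaped gradient, which is then split into the parts belonging to each upstream network.

```haskell
-- Forward pass of a toy combiner: one linear unit over the concatenated input.
forwardCombiner :: [Double] -> [Double] -> Double
forwardCombiner weights inputs = sum (zipWith (*) weights inputs)

-- Backward pass: given dLoss/dOutput, return the weight gradient and the
-- input-shaped gradient to hand back to the upstream networks.
backwardCombiner :: [Double] -> [Double] -> Double -> ([Double], [Double])
backwardCombiner weights inputs dOut =
  (map (* dOut) inputs, map (* dOut) weights)

-- Split the input-shaped gradient into the player and opponent parts.
splitInputGradient :: Int -> [Double] -> ([Double], [Double])
splitInputGradient playerSize = splitAt playerSize

-- Made-up numbers standing in for the outputs of the two LSTMs.
demo :: (Double, [Double], [Double])
demo =
  let playerOut   = [1, 2]
      opponentOut = [3]
      weights     = [0.5, -1, 2]
      combined    = playerOut ++ opponentOut
      out         = forwardCombiner weights combined
      (_, dInput) = backwardCombiner weights combined 1
      (dPlayer, dOpponent) = splitInputGradient (length playerOut) dInput
  in (out, dPlayer, dOpponent)
```

In a real setup the two halves of the split gradient would be passed to the backward pass of the player and opponent networks respectively.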

Thanks for your help!

@HuwCampbell
Owner

HuwCampbell commented Jul 12, 2017

That's right.
Only difference is you'll need runRecurrent and backPropagateRecurrent for the LSTM nets.

Edit. Sorry:
runRecurrentForwards and runRecurrentBackwards would also be useful.

@cpennington
Author

Cool, I'm making progress on this. One question that came up as I was working is whether there's an easy way to construct an all-zero vector for a particular RecurrentInput shape. I want to make sure my network is always starting from the same state at the start of every game.

@HuwCampbell
Owner

HuwCampbell commented Jul 13, 2017

You can just use the literal 0.

S is an instance of Num so has fromInteger. In fact RecurrentInputs xs is also an instance of Num, so that should work for the entire stack.

If you look at the code for backPropagateRecurrent you can see I do this (for the back propagated sideways gradients at least).
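
To show why the literal 0 works, here is a small self-contained sketch with a toy sized vector. `Vec` is a hypothetical stand-in for Grenade's S, not its real definition. The mechanism is the same: `fromInteger` in the Num instance builds a value of the right size from an integer literal.

```haskell
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Toy vector whose length is tracked in its type.
newtype Vec (n :: Nat) = Vec [Double]
  deriving (Eq, Show)

instance KnownNat n => Num (Vec n) where
  Vec a + Vec b = Vec (zipWith (+) a b)
  Vec a - Vec b = Vec (zipWith (-) a b)
  Vec a * Vec b = Vec (zipWith (*) a b)
  abs (Vec a)    = Vec (map abs a)
  signum (Vec a) = Vec (map signum a)
  -- An integer literal becomes a vector filled with that value.
  fromInteger k = Vec (replicate len (fromInteger k))
    where
      len = fromIntegral (natVal (Proxy :: Proxy n))

-- The literal 0 is elaborated to `fromInteger 0` at type `Vec 3`.
initialState :: Vec 3
initialState = 0
```

Here `initialState` is a three-element vector of zeros, so the same starting state can be written as a plain `0` wherever the type is known.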

@HuwCampbell
Owner

HuwCampbell commented Jul 14, 2017

Please see #32

In that branch, this will compile:

type R = Recurrent
type F = FeedForward

type ShapeInput = 'D1 10

type LearnPlayer = RecurrentNetwork
   '[ R (LSTM 10 20) ]
   '[ ShapeInput , D1 20 ]

type LearnOpponent = RecurrentNetwork
   '[ R (LSTM 10 20) ]
   '[ ShapeInput, D1 20 ]

type RecNet = RecurrentNetwork
    '[ R (
        ConcatRecurrent
          (D1 20)
          (R LearnPlayer)
          (D1 20)
          (R LearnOpponent)
        )
    ]
   '[ ShapeInput, 'D1 40 ]

randomNet :: MonadRandom m => m RecNet
randomNet = randomRecurrent

@HuwCampbell
Owner

I believe this is fixed, but feel free to follow up with any problems you're having.
