Sketch-RNN #11

Closed
shiffman opened this issue Oct 26, 2017 · 31 comments

Comments

@shiffman
Member

A simple example using sketch-rnn plus p5.js could be integrated as part of this project, or at least linked to! (cc @hardmaru yet again!)

@nonoesp

nonoesp commented May 4, 2018

Hi all!

@garciadelcastillo and I have been working on modularizing SketchRNN for a few projects to generate a drawing prediction from an input set of strokes. I hope this can help with porting SketchRNN to the ml5 library.

Right now, this simple_predict.js module has the logic to load a model from a local file, set input strokes as relative or absolute, and generate a drawing prediction. (It contains the guts of the original simple_predict demo sketch plus some helpers to work with both relative and absolute sketch coordinates.)

One of our main goals was to serve SketchRNN as an HTTP service (see http-server.js) and a WebSocket client (see websocket-client.js). A sample use of the HTTP service is this p5 sketch.
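For a rough idea of the client side, a call to such a service might look like the sketch below; the endpoint path and payload shape here are hypothetical placeholders, not the actual http-server.js API.

// Minimal sketch of calling an HTTP prediction service; the URL and
// request body shape are illustrative assumptions only.
async function requestPrediction(strokes) {
  const response = await fetch('http://localhost:3000/predict', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ strokes }),
  });
  return response.json(); // predicted continuation of the input strokes
}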

Also, this video (minute 4:30) explains how SketchRNN encodes its strokes. Internally, each position is a relative movement from the previous position.
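For illustration, converting absolute points into that relative encoding could look like the following; the helper name and point format are made up for this example.

// Illustrative helper: convert absolute points into relative (dx, dy)
// offsets, where each step is a movement from the previous position.
function toRelativeStrokes(points) {
  const strokes = [];
  for (let i = 1; i < points.length; i++) {
    strokes.push({
      dx: points[i].x - points[i - 1].x,
      dy: points[i].y - points[i - 1].y,
      pen: points[i].pen // whether the pen touches the paper for this segment
    });
  }
  return strokes;
}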

=)

@cvalenzuela
Member

Hi @nonoesp!

This looks very cool! It would be awesome to have an ml5.SketchRNN() class!

I imagine we could have a collection of models and a few examples to play with. The models you are using don't seem to be in the repo you linked to? I couldn't find them. Just wondering if you are preprocessing them in any way.

If you and @garciadelcastillo are interested, I will be glad to help push a PR to incorporate this.
@shiffman, thoughts?

@garciadelcastillo

Hi @cvalenzuela,

We have been working a lot with the library these days and would be happy to find some time to do a formal PR. However, a few thoughts:

  • We are currently wrapping the Sketch-RNN functionality inside an HTTP server and serving it via POST requests. I wonder how ml5 serves other similar models and deals with models-as-a-service ("MaaS"? 💥)

  • All the trained models (which we are not processing at all) take up approximately 1.5 GB, which is quite a large download. The library could be written to request models ad hoc from a CDN, but that would slow the process down tremendously (most models are around 11 MB each); a rough sketch of ad-hoc loading follows below. Thoughts?
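A rough sketch of what ad-hoc loading could look like, with an in-memory cache so each ~11 MB model is only downloaded once per session (the base URL is a placeholder, not a real CDN):

// Illustrative only: fetch a model weight file on demand and cache it.
const MODEL_BASE_URL = 'https://example-cdn.invalid/sketch-rnn-models/'; // placeholder
const modelCache = new Map();

async function loadModel(name) {
  if (!modelCache.has(name)) {
    const res = await fetch(`${MODEL_BASE_URL}${name}.gen.json`);
    modelCache.set(name, await res.json());
  }
  return modelCache.get(name);
}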

JL

@nonoesp

nonoesp commented May 5, 2018

Hi @cvalenzuela! I'd be glad to help with this as well.

Here is a list with all the generative models. By changing gen to vae you can download the full variational auto-encoder model (which allows the use of latent vectors). It would make sense to download them on demand as in the sketch-rnn demo.

From the tensorflow/magenta-demos repo:

Pre-trained weight files

The RNN model has 2 modes: unconditional and conditional generation. Unconditional generation means the model will just generate a random vector image from scratch and not use any latent vectors as an input. Conditional generation mode requires a latent vector (128-dim) as an input, and whatever the model generates will be defined by those 128 numbers that can control various aspects of the image.
 
Whether conditional or not, all of the raw weights of these models are individually stored as .json files inside the models directory. For example, for the 'butterfly' class, there are 2 models that come pretrained:
 
butterfly.gen.json - unconditional model
 
butterfly.vae.json - conditional model

@cvalenzuela
Member

cvalenzuela commented May 5, 2018

thanks @nonoesp and @garciadelcastillo!

What kind of functionality are you wrapping in the server? Or is it just serving the .json files?

We try to keep ml5 as "client-side" as possible. We just fetch weights, when necessary, from a constant URL to keep the library small. So if the server you are running is just storing the URLs for those .json files, I imagine that an on-demand approach would be best. It might look something like this:

// Providing the 'cat' attribute will make the class fetch the right .json file
let catRNN = new ml5.SketchRNN('cat', onModelLoaded);

// Callback when the model loads
function onModelLoaded() {
  // Generate
  catRNN.generate();
}

The sketch-rnn demo takes a couple of seconds to download a model; I guess that's fine for our case too.

@shiffman
Member Author

shiffman commented May 6, 2018

I am so excited about this! I can imagine a server-side component for ml5 eventually, but I agree with @cvalenzuela that coming up with a client-side-only example first would be great. With ml5 we are also not as concerned with perfection/accuracy as we are with ease of use and friendliness, so sacrificing some quality for smaller model files is something we can explore/discuss too.

I wonder as a step 2 (or 200?) if there is a way we can do either transfer learning or training from scratch also with new user data.

@cvalenzuela
Member

@hardmaru mentioned to me he was interested in helping make this happen!

@nonoesp

nonoesp commented Jun 19, 2018

Hi! Sorry for the radio silence!

@garciadelcastillo and I were experimenting with serving SketchRNN (and other libraries) over HTTP and WebSocket for a workshop, to have participants interact with machine learning libraries (such as SketchRNN) from different coding environments.

I'm currently porting the barebones of simple_predict.js as an ml5 module in the nonoesp-sketchrnn branch (still very much a work in progress). This module would run on the client side, potentially loading models from the same source the sketch-rnn demo does:

https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/category.type.json

By substituting category and type in the URL pattern above, you can access, for example, the generative models of a bird, bicycle, or angel. (More on this here.)
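A minimal sketch of loading one of these on demand, with 'bird' and 'gen' as example values for category and type:

// Build the URL from the pattern above and fetch the weight file.
const category = 'bird';
const type = 'gen'; // or 'vae' for the full variational auto-encoder model
const url = `https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/${category}.${type}.json`;

fetch(url)
  .then((res) => res.json())
  .then((weights) => console.log('model weights loaded for', category));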

Of course, I expect @hardmaru to be able to expose a lot more functionality.

@shiffman
Member Author

shiffman commented Jun 20, 2018

This is very exciting! It probably makes sense for us to mirror the ml5.LSTM API to the extent that it's appropriate. Building on @cvalenzuela's earlier comment, I'm thinking a simple example could look something like this:

const sketchRNN = ml5.sketchRNN('rainbow', modelReady);

function modelReady() {
  console.log('Ready to generate');
}

function setup() {
  createCanvas(400, 400);
  // These would all be optional; it could generate something by default?

  // A seed and options would be optional?
  // Seed would be array of objects with x,y, and pen state?
  let initialSketch = [
    {x: 100, y: 100, pen: true},
    {x: 100, y: 200, pen: true},
    {x: 200, y: 200, pen: false}
  ];
  let options = {
    temperature: 0.5,  // temperature
    length: 100,       // how many points of a path to generate
    seed: initialSketch
  };
  sketchRNN.generate(options, gotSketch);
}

function gotSketch(sketch) {
  // "sketch" is an array of objects with x, y, and pen state?
  for (let i = 1; i < sketch.length; i++) {
    let current = sketch[i];
    if (current.pen) {
      let previous = sketch[i-1];
      line(previous.x, previous.y, current.x, current.y);
    }
  }
}

We could consider integrating with p5.Vector but perhaps this would tie it too closely to p5?
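For what it's worth, a purely illustrative sketch of what p5.Vector integration might look like (not a proposal for the actual API):

// Purely illustrative: wrap each generated point in a p5.Vector,
// carrying the pen state along as an extra property.
function toVectors(sketch) {
  return sketch.map((pt) => {
    const v = createVector(pt.x, pt.y);
    v.pen = pt.pen;
    return v;
  });
}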

@hardmaru

hardmaru commented Jun 20, 2018

Hi @shiffman ! Thanks for all the discussion.

That's a nice suggestion. In my original model API, I sample each point incrementally rather than sampling the entire drawing, since it allows for more creative applications, such as letting the algorithm extend what the user has drawn.

When I wrote the model, deeplearn.js wasn't available yet, so I just implemented my own LSTM in JavaScript, but the interface in the code shouldn't be too difficult to port over to your ml5.LSTM. That being said, the code as it is now is fairly efficient and runs quite fast on the client side, even on an old mobile device.

I've also been thinking of cleaning up an old script that can convert TensorFlow-trained sketch-rnn models over to the compressed JSON format that the JS version can use, so in theory we can use non-quickdraw datasets. Will probably try to do that first.

@shiffman
Member Author

@hardmaru yes, that makes a lot of sense! Perhaps the default behavior can be to just sample one point at a time with an option to ask for an array? My concern is that I'm assuming that sampling will require a callback which could get quite confusing for a beginner trying to do something with a draw() loop in p5. Would it be able to do one point at a time without a callback, i.e.?

let sketchRNN = ml5.sketchRNN('rainbow', modelReady);
let ready = false;
let previous = null;

// Using preload would make this much simpler for a beginner example!
function modelReady() {
  ready = true;
}

function setup() {
  createCanvas(400, 400);
}

function draw() {
  if (ready) {
    // optionally can pass in a seed / temperature, etc.?
    let next = sketchRNN.generate();
    if (previous && next.pen) {
      line(previous.x, previous.y, next.x, next.y);
    }
    previous = next;
  }
}

@hardmaru

hardmaru commented Jun 20, 2018

@shiffman that seems nice and simple enough to understand, I like it! I wonder, with this example, how would one create a demo where we let the user start a sketch and have sketchRNN finish it? Maybe we will have to expose that feature separately somehow.

In the version I had, it might be more straightforward to extend to do such things, at the expense of a bit more complexity, which is always a tradeoff:

// ... initialization code before draw(), see doc

function draw() {
  // see if we finished drawing
  if (prev_pen[2] == 1) {
    p.noLoop(); // stop drawing
    return;
  }

  // using the previous pen states, and hidden state, get next hidden state
  // the below line takes the most CPU power, especially for large models.
  rnn_state = model.update([dx, dy, pen_down, pen_up, pen_end], rnn_state);

  // get the parameters of the probability distribution (pdf) from hidden state
  pdf = model.get_pdf(rnn_state);

  // sample the next pen's states from our probability distribution
  [dx, dy, pen_down, pen_up, pen_end] = model.sample(pdf, temperature);

  // only draw on the paper if the pen is touching the paper
  if (prev_pen[0] == 1) {
    p.stroke(line_color);
    p.strokeWeight(2.0);
    p.line(x, y, x+dx, y+dy); // draw line connecting prev point to current point.
  }

  // update the absolute coordinates from the offsets
  x += dx;
  y += dy;

  // update the previous pen's state to the current one we just sampled
  prev_pen = [pen_down, pen_up, pen_end];
}

In this version, since the prev_pen state can either be sampled using SketchRNN's pdf (your generate()) or be overwritten using the user's actual mouse/touch movement, it is easy to incorporate the interactive component to get the human in the loop. Maybe there is an elegant way to incorporate this into your proposed framework too, something like:

let sketchRNN = ml5.sketchRNN('rainbow', modelReady, optionalTemperature, optionalSeed);
let ready = false;
let previous = null;

// Using preload would make this much simpler for a beginner example!
function modelReady() {
  ready = true;
}

function setup() {
  createCanvas(400, 400);
}

function draw() {
  if (ready) {
    let next = sketchRNN.generate();
    if (previous && next.pen) {
      line(previous.x, previous.y, next.x, next.y);
    }
    previous = next;
    // previous can be overwritten by human input, so it doesn't necessarily
    // have to be what sketchRNN generated
    sketchRNN.update(previous);
  }
}

@shiffman
Member Author

shiffman commented Jun 20, 2018

Ah yes, this makes sense! We should definitely allow the user to pass in human input and override the model's generated data. This could also possibly be an argument to generate(), where it's something like:

    // previous can be overwritten by human input!
    let next = sketchRNN.generate(previous);
    if (next.pen) {
      line(previous.x, previous.y, next.x, next.y);
    }
    previous = next;

In looking at your code I see that the model provides dx,dy rather than literal x,y coordinates. I think this makes sense to keep; I was just making stuff up without looking closely!

We can probably conflate pen_up and pen_down into one state pen (true or false)? What is pen_end?

@hardmaru

hardmaru commented Jun 21, 2018

Hi @shiffman

In addition to modelling when the pen should touch the canvas and when it should be lifted away from the canvas, Sketch-RNN also models when to finish the drawing (via the event pen_end). So [pen_down, pen_up, pen_end] is a one-hot vector sampled from a categorical distribution.

Unlike an LSTM generating Hemingway forever, if we let an LSTM doodle birds without end, it will eventually fill the entire canvas with black ink (e.g. the kanji example)!
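As a minimal illustration (the helper name is made up), the one-hot [pen_down, pen_up, pen_end] triple can be collapsed into a single pen state:

// Illustrative only: map the one-hot pen triple onto a single state string.
function penStateToString([penDown, penUp, penEnd]) {
  if (penEnd === 1) return "end";   // the model decided the drawing is finished
  if (penDown === 1) return "down"; // pen touches the paper for the next segment
  return "up";                      // pen moves without drawing
}

penStateToString([0, 0, 1]); // "end"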

@hardmaru

Hi @shiffman @cvalenzuela @nonoesp @garciadelcastillo

A few updates from me:

  1. I ported the sketch-rnn-js model over to TensorFlow.js using the TypeScript style of the magenta.js project. The API is very similar to sketch-rnn-js, but GPU accelerated. I'll try to put this on the magenta.js repo soon, after porting a few demos over and testing a few things.

  2. I wrote a small IPython notebook showing how to quickly train a sketch-rnn model with TensorFlow and convert that model over to the JSON format that can be used by sketch-rnn-js (and the TensorFlow.js version in (1)): https://github.com/tensorflow/magenta-demos/blob/master/jupyter-notebooks/Sketch_RNN_TF_To_JS_Tutorial.ipynb

Once (1) is out, it should be fairly easy to wrap it with ml5.js so that sketch-rnn can be readily available.

Currently, this is how I deal with model loading in magenta.js, but I think the ml5.js way is more elegant:

var sketch = function( p ) {
  "use strict";

  console.log("SketchRNN JS demo.");
  var model;
  var dx, dy; // offsets of the pen strokes, in pixels
  var pen_down, pen_up, pen_end; // keep track of whether pen is touching paper
  var x, y; // absolute coordinates on the screen of where the pen is
  var prev_pen = [1, 0, 0]; // group all p0, p1, p2 together
  var rnn_state; // store the hidden states of rnn's neurons
  var pdf; // store all the parameters of a mixture-density distribution
  var temperature = 0.45; // controls the amount of uncertainty of the model
  var line_color;
  var model_loaded = false;

  // loads the TensorFlow.js version of sketch-rnn model, with the "cat" model's weights.
  model = new ms.SketchRNN("https://storage.googleapis.com/quickdraw-models/sketchRNN/models/cat.gen.json");

  Promise.all([model.initialize()]).then(function() {
    // initialize the scale factor for the model. Bigger -> large outputs
    model.set_pixel_factor(3.0);

    // initialize pen's states to zero.
    [dx, dy, pen_down, pen_up, pen_end] = model.zero_input(); // the pen's states

    // zero out the rnn's initial states
    rnn_state = model.zero_state();

    model_loaded = true;
    console.log("model loaded.");
  });

  p.setup = function() {
    var screen_width = p.windowWidth; //window.innerWidth
    var screen_height = p.windowHeight; //window.innerHeight
    x = screen_width/2.0;
    y = screen_height/3.0;
    p.createCanvas(screen_width, screen_height);
    p.frameRate(60);

    // define color of line
    line_color = p.color(p.random(64, 224), p.random(64, 224), p.random(64, 224));
  };

  p.draw = function() {
    if (!model_loaded) {
      return;
    }
    // see if we finished drawing
    if (prev_pen[2] == 1) {
      p.noLoop(); // stop drawing
      return;
    }

    // using the previous pen states, and hidden state, get next hidden state
    // the below line takes the most CPU power, especially for large models.
    rnn_state = model.update([dx, dy, pen_down, pen_up, pen_end], rnn_state);

    // get the parameters of the probability distribution (pdf) from hidden state
    pdf = model.get_pdf(rnn_state, temperature);

    // sample the next pen's states from our probability distribution
    [dx, dy, pen_down, pen_up, pen_end] = model.sample(pdf);

    // only draw on the paper if the pen is touching the paper
    if (prev_pen[0] == 1) {
      p.stroke(line_color);
      p.strokeWeight(2.0);
      p.line(x, y, x+dx, y+dy); // draw line connecting prev point to current point.
    }

    // update the absolute coordinates from the offsets
    x += dx;
    y += dy;

    // update the previous pen's state to the current one we just sampled
    prev_pen = [pen_down, pen_up, pen_end];
  };

};
var custom_p5 = new p5(sketch, 'sketch');

@cvalenzuela
Member

Amazing @hardmaru! This will be super nice to have in ml5. Let us know when you publish your code so we can make a wrapper around it!

Once that is ready, we can also put the training instructions and script here: https://ml5js.org/docs/training-introduction

@nonoesp

nonoesp commented Jul 13, 2018

Nice! So glad to hear about TypeScript @hardmaru, and thanks so much for sharing the IPython notebook. Looking forward to the release.

@hardmaru

I put the code in my fork for now, but it should be merged in the next few days.

There are 3 working demos that use sketch-rnn with TensorFlow.js, linked in the README.md

@hardmaru

hardmaru commented Aug 6, 2018

The TensorFlow.js version has gone through code review and been accepted into the main repo. The current interface is more or less inspired by the p5.js-style workflow, and in fact all the demos use p5.js:

https://github.com/tensorflow/magenta-js/blob/master/sketch/README.md

The next step is to try to wrap it over with ml5.js and make the ml5.SketchRNN() class.
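For a rough idea of that next step, a wrapper might do something like the sketch below; this is not the eventual ml5 implementation, just an assumption-laden outline built on the ms.SketchRNN calls shown in the code above.

// Rough outline only: an ml5-style class wrapping the magenta SketchRNN model.
// The class name, callback pattern, and URL choice are illustrative assumptions.
class SketchRNNWrapper {
  constructor(category, callback) {
    const url = `https://storage.googleapis.com/quickdraw-models/sketchRNN/models/${category}.gen.json`;
    this.model = new ms.SketchRNN(url);
    this.model.initialize().then(() => {
      this.rnnState = this.model.zero_state(); // zeroed hidden state
      this.penState = this.model.zero_input(); // zeroed pen input
      if (callback) callback();
    });
  }
}

// Usage (assumes the @magenta/sketch bundle is loaded as `ms`):
// let catRNN = new SketchRNNWrapper('cat', () => console.log('ready'));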

@cvalenzuela
Member

Great! I'll make a branch and start working on it.

@shiffman
Member Author

shiffman commented Aug 6, 2018

Yay! I am so excited about this! I would love to help work on this too.

@reiinakano mentioned this issue Aug 6, 2018
@hardmaru

hardmaru commented Aug 7, 2018

Thanks for the help @reiinakano @cvalenzuela @shiffman

I published a more optimized version 0.1.2 (no change to the API) today:

https://www.npmjs.com/package/@magenta/sketch

This version reduces the number of dataSync() calls and improves performance a little.

@cvalenzuela
Member

Added in #189!

@shiffman
Member Author

I'm briefly re-opening this issue to cover some API decisions @cvalenzuela just made in our weekly ml5 meeting!

Instead of storing the sketch data as:

var initialStrokes = [
  [-4, 0, 1, 0, 0],
  [-15, 9, 0, 1, 0],
  [-10, 17, 0, 0, 1]
];

we propose:

var initialStrokes = [
  { dx: -4, dy: 0, pen: "down" },
  { dx: -15, dy: 9, pen: "up" },
  { dx: -10, dy: 17, pen: "end" }
];

and then sketch data generated would look like:

function gotResult(err, result) {
  if (previous.pen === "down") {
    stroke(255, 0, 0);
    strokeWeight(3.0);
    line(x, y, x + result.dx, y + result.dy);
  }

  x += result.dx;
  y += result.dy;
  previous = result;
}
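For reference, a small helper (purely illustrative, not part of the proposal) could convert the raw five-element stroke format into the proposed object format:

// Illustrative helper: convert a raw [dx, dy, pen_down, pen_up, pen_end]
// stroke into the proposed { dx, dy, pen } object format.
function toStrokeObject([dx, dy, penDown, penUp, penEnd]) {
  let pen = "up";
  if (penDown === 1) pen = "down";
  if (penEnd === 1) pen = "end";
  return { dx, dy, pen };
}

toStrokeObject([-4, 0, 1, 0, 0]); // { dx: -4, dy: 0, pen: "down" }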

Feel free to weigh in with any thoughts or comments!

@shiffman reopened this Oct 11, 2018
@hardmaru

hardmaru commented Oct 11, 2018 via email

@hardmaru

hardmaru commented Oct 11, 2018 via email

@shiffman
Member Author

shiffman commented Oct 11, 2018

Ah, yes this is a very good point! I think this relates (?) to the current discussion about stateful LSTMs in this pull request! I wonder if we could adopt a similar API for SketchRNN, where we have a simple "generate a drawing" mode as well as a "generate one pen motion at a time where the user can take over" mode, etc.

See: #221 (comment) for more.

Something like:

function draw() {
  if (userIsDrawing) {
    var next = {
      dx: mouseX - pmouseX,
      dy: mouseY - pmouseY,
      pen: "down" // dynamic based on mouseIsPressed?
    };
    line(mouseX, mouseY, pmouseX, pmouseY);
    sketchRNN.update(next);
    previous = next;
  } else if (modelIsDrawing) {
    let next = sketchRNN.next(0.1);
    if (previous.pen === "down") {
      stroke(255, 0, 0);
      strokeWeight(3.0);
      line(x, y, x + next.dx, y + next.dy);
    }
    x += next.dx;
    y += next.dy;
    previous = next;
    sketchRNN.update(next);
  }
}

I'm ignoring the asynchronous aspect here and making up variables but is this the right idea?

@hardmaru

hardmaru commented Oct 12, 2018

The way I handled the interactivity is to completely abandon the async nature of the API (although this might be the wrong decision, since there is a tradeoff with performance).

In the current magenta version of sketch-rnn (https://www.npmjs.com/package/@magenta/sketch?activeTab=readme), the API is basically completely synchronous, and the code is similar to what your comment describes. Here is the sketch loop for generating a sketch:

function draw() {
 
  // see if we finished drawing
  if (prev_pen[2] == 1) {
    noLoop(); // stop drawing
    return;
  }
 
  // using the previous pen states, and hidden state, get next hidden state
  // the below line takes the most CPU power, especially for large models.
  rnn_state = model.update([dx, dy, pen_down, pen_up, pen_end], rnn_state);
 
  // get the parameters of the probability distribution (pdf) from hidden state
  pdf = model.getPDF(rnn_state, temperature);
 
  // sample the next pen's states from our probability distribution
  [dx, dy, pen_down, pen_up, pen_end] = model.sample(pdf);
 
  // only draw on the paper if the pen is touching the paper
  if (prev_pen[0] == 1) {
    stroke(line_color);
    strokeWeight(3.0);
    line(x, y, x+dx, y+dy); // draw line connecting prev point to current point.
  }
 
  // update the absolute coordinates from the offsets
  x += dx;
  y += dy;
 
  // update the previous pen's state to the current one we just sampled
  prev_pen = [pen_down, pen_up, pen_end];
};

So to incorporate the interactivity, I can just override what sketch-rnn generates with what the user draws using the mouse/tablet data in the draw loop.
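As a minimal illustration of that override (the helper and variable names are hypothetical), the draw loop could swap in the user's movement before feeding the stroke back into model.update():

// Hypothetical helper: when the user is drawing, replace the sampled stroke
// with the mouse movement so the model's hidden state follows the human input.
function nextStroke(sampled, userIsDrawing) {
  if (!userIsDrawing) return sampled; // keep what sketch-rnn generated
  return {
    dx: mouseX - pmouseX,               // p5 globals
    dy: mouseY - pmouseY,
    pen: mouseIsPressed ? "down" : "up" // the user never emits "end"
  };
}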

Maybe an easy way is to leave the current mode for async, and copy in the non-async API from the magenta version (with the syntactic sugar, and also renaming to the dx/dy/pen state names)?

@shiffman mentioned this issue Oct 17, 2018
@shiffman
Member Author

As a matter of update, I have a working example for my A2Z class here:

https://github.com/shiffman/A2Z-F18/tree/master/week8-charRNN/04_sketchRNN

Are certain models there automatically and others I'll need to download? Right now it works with "cat" out of the box. Next, I'll work on the SketchRNN class to implement some of the feature suggestions in this thread, as well as make an example with interactivity.

[cats3: animated GIF of the example's generated output]

@hardmaru

hardmaru commented Nov 1, 2018

Looks fun! The pre-trained models are all in a JSON format that can be dynamically loaded.

There are a few interactive demos in the magenta-js version that could probably be ported to this version (though the API will probably need to be refactored depending on the level of abstraction we want to give the user):

https://github.com/tensorflow/magenta-js/tree/master/sketch

@shiffman
Member Author

Closing!!! (New issues coming with the remaining to-dos...)
