In [1]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import hsv_to_rgb
from matplotlib.widgets import Slider, Button, RadioButtons
from torch.autograd import Variable

from utils import Normalization


class Feasibility:
    """Inference wrapper around the project's ``GeneratorProduction`` network.

    Constructing an instance builds the generator, restores its weights from
    the pickled state dict ``generator_production.pkl`` (resolved relative to
    the current working directory) and switches it to evaluation mode.
    :meth:`sample` then queries the generator.
    """

    def __init__(self):
        # Project-local import, deferred until an instance is created.
        from feasibility_gan import GeneratorProduction

        generator = GeneratorProduction()
        # SECURITY: pickle.load can execute arbitrary code; only load
        # checkpoint files that come from a trusted source.
        with open('generator_production.pkl', 'rb') as checkpoint:
            state_dict = pickle.load(checkpoint)
        generator.load_state_dict(state_dict)
        generator.eval()
        self.generator = generator

    def sample(self, width, height, mass, friction, x, y, θ, x_, y_, θ_, n_samples):
        """Ask the generator for ``n_samples`` outputs for a single query.

        The seven arguments ``width, height, mass, friction, x, y, θ`` are
        packed, in that order, into one 1x7 float tensor row. The three
        arguments ``x_, y_, θ_`` are packed into one 1x3 float tensor row.
        Both rows are handed to the generator together with ``n_samples``.

        Returns:
            The generator's output converted to a NumPy array. Its exact shape
            is defined by ``GeneratorProduction`` (callers appear to read
            three values per row -- TODO confirm).
        """
        start_row = [width, height, mass, friction, x, y, θ]
        goal_row = [x_, y_, θ_]
        obj = Variable(torch.FloatTensor([start_row]))
        obj_ = Variable(torch.FloatTensor([goal_row]))
        generated = self.generator(obj, obj_, n_samples=n_samples)
        return generated.data.numpy()
    
    
# Build the generator once when this cell runs; construction reads
# 'generator_production.pkl' from the current working directory.
feasibility = Feasibility()

In [40]:
%matplotlib auto
# Interactive explorer: sliders choose the goal's direction and distance from
# the origin plus the object's radius and friction. Each slider change
# re-samples the generator, moves one marker per sample (coloured by the
# sample's third component, treated as an angle), redraws the start circle
# and the dashed goal outline, and saves the frame to anim/.
fig, ax = plt.subplots(figsize=(5, 6.5))
fig.subplots_adjust(left=0.1, bottom=0.28)  # leave room below the plot for the sliders

n_samples = 256
# One single-point Line2D per sample so each marker can get its own colour.
samples = [ax.plot([0], [0], 'o', markersize=3.0, alpha=0.3)[0] for _ in range(n_samples)]

size = 0.75
ax.axis([-size, size, -size, size])
ax.grid()

# NOTE(review): n_circles / circles are not used in this cell; kept only in
# case another cell reads them.
n_circles = 10
circles = []

# Do not also pass `color=` here: matplotlib lets it override ec/fc.
circle = plt.Circle((0, 0), radius=0.05, linewidth=0.5, ec='k', fc='#00dd00')
goal = plt.Circle((0, 0), radius=0.05, linewidth=0.5, ec='#00aa00', ls='--', fill=False)
ax.add_patch(circle)
ax.add_patch(goal)

axangle = fig.add_axes([0.15, 0.03, 0.65, 0.03])
axdist = fig.add_axes([0.15, 0.08, 0.65, 0.03])
axsize = fig.add_axes([0.15, 0.13, 0.65, 0.03])
axfriction = fig.add_axes([0.15, 0.18, 0.65, 0.03])

angle = Slider(axangle, 'Angle', 0.0, 360, valinit=90)
dist = Slider(axdist, 'Distance', 0.0, 0.6, valinit=0.3)
friction = Slider(axfriction, 'Friction', 0.01, 0.3, valinit=0.15)
radius = Slider(axsize, 'Radius', 0.01, 0.10, valinit=0.05)

# There is no mass slider, but the generator still needs a mass value:
# use a fixed one (0.1 was the disabled slider's initial value).
MASS = 0.1

ANIM_DIR = 'anim'
os.makedirs(ANIM_DIR, exist_ok=True)  # savefig fails if the folder is missing
fig_id = 0  # index of the next frame written by update()


def _redraw():
    """Re-sample the generator for the current slider values and refresh the artists."""
    goal_angle = np.deg2rad(angle.val)
    goal_x = np.cos(goal_angle) * dist.val
    goal_y = np.sin(goal_angle) * dist.val
    res = feasibility.sample(width=radius.val,
                             height=radius.val,
                             mass=MASS,
                             friction=friction.val,
                             x=0.0,
                             y=0.0,
                             θ=0.0,
                             x_=goal_x,
                             y_=goal_y,
                             θ_=0.0,
                             n_samples=n_samples)
    circle.set_radius(radius.val)
    goal.set_radius(radius.val)
    goal.center = (goal_x, goal_y)

    # Assumes res is (n_samples, >=3): columns 0/1 are plotted as x/y and
    # column 2 is wrapped into [-pi, pi], mapped to a hue in [0, 1) and
    # rotated a quarter turn around the colour wheel.
    wrapped = np.arctan2(np.sin(res[:, 2]), np.cos(res[:, 2]))
    hues = ((wrapped + np.pi) / (2 * np.pi) + 0.25) % 1.0
    ones = np.ones_like(hues)
    colors = hsv_to_rgb(np.stack([hues, ones, ones], axis=-1))
    for line, (px, py), rgb in zip(samples, res[:, :2], colors):
        # set_data takes sequences; recent matplotlib rejects scalar x/y.
        line.set_data([px], [py])
        line.set_color(rgb)
    fig.canvas.draw_idle()


def update(val):
    """Slider callback: redraw, then save the frame as anim/animated_NNN.png."""
    global fig_id
    _redraw()
    fig.savefig(os.path.join(ANIM_DIR, 'animated_{:03}.png'.format(fig_id)))
    fig_id += 1


for slider in (angle, dist, friction, radius):
    slider.on_changed(update)

_redraw()  # show samples for the initial slider values, not 256 dots at the origin
plt.show()

Using matplotlib backend: TkAgg
