# Interactive version of GPLVM

In [1]:
import sys

import random
import numpy as np
import numpy.random

import matplotlib
matplotlib.use("Qt5Agg")
from PyQt5 import QtGui, QtCore, QtWidgets
from PyQt5.QtCore import QPointF
from PyQt5.QtWidgets import QSizePolicy
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

In [2]:
class RenderAndMouse(QtWidgets.QWidget):
    """Widget that draws a polyline and reports mouse moves in data coordinates.

    ``viewport`` is indexed as ``(x_min, y_min, x_max, y_max)``: ``viewport[0]``
    and ``viewport[2]`` map to the left and right widget edges, ``viewport[1]``
    and ``viewport[3]`` to the bottom and top edges (the y axis points up).
    """
    def __init__(self, parent, viewport, cb):
        """Create the widget.

        parent   -- parent Qt widget, forwarded to ``QWidget.__init__``.
        viewport -- 4-item sequence describing the visible data rectangle.
        cb       -- callable invoked as ``cb(x, y)`` on every mouse move, with
                    the cursor position converted to data coordinates.
        """
        super().__init__(parent)
        self.viewport = viewport
        # Placeholder polyline shown until feed() supplies real data.
        self.data = [[-1,-1], (1,0), (1,1)]
        # Deliver mouse-move events even when no button is pressed.
        self.setMouseTracking(True)
        self.setMinimumSize(300, 300)
        self.mousemove = cb
    def mouseMoveEvent(self, event):
        """Convert the cursor position to data coordinates and call the callback."""
        vp = self.viewport
        w, h = self.width(), self.height()
        x = event.x() / w * (vp[2] - vp[0]) + vp[0]
        # Widget y grows downwards, data y grows upwards, hence the flip.
        y = vp[3] - event.y() / h * (vp[3] - vp[1])
        self.mousemove(x, y)
    def sizeHint(self):
        """Prefer the parent's current width and a height of 400 pixels."""
        return QtCore.QSize(self.parent().width(), 400)
    def feed(self, data):
        """Replace the polyline with ``data`` (an iterable of 2-item points) and redraw."""
        self.data = data
        self.repaint()
    def paintEvent(self, QPaintEvent):
        """Draw ``self.data`` as a polyline scaled to the current widget size."""
        vp = self.viewport
        w, h = self.width(), self.height()
        def tr(p):
            # Data coordinates -> widget pixels, flipping y so it points up.
            u, v = p
            x = (u - vp[0]) / (vp[2] - vp[0]) * w
            y = (v - vp[1]) / (vp[3] - vp[1]) * h
            return QPointF(x, h-y)
        # Transform before opening the painter so a malformed point cannot
        # leave an active painter behind.
        points = [tr(p) for p in self.data]
        if not points:
            # drawPolyline() needs at least one point; nothing to draw.
            return
        paint = QtGui.QPainter()
        paint.begin(self)
        try:
            paint.drawPolyline(*points)
        finally:
            # Always release the paint device, even if drawing raised.
            paint.end()

In [3]:
class MatplotlibCanvas(FigureCanvas):
    """Qt widget wrapping a Matplotlib figure with a single subplot in ``self.axes``.

    Subclasses may override ``compute_initial_figure`` to draw initial content;
    it is called once from the constructor, after ``self.axes`` exists.
    """
    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Build the figure and embed it in a Qt canvas.

        parent -- parent Qt widget, or None.
        width  -- figure width, passed to ``Figure(figsize=...)``.
        height -- figure height, passed to ``Figure(figsize=...)``.
        dpi    -- figure resolution, passed to ``Figure(dpi=...)``.
        """
        fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = fig.add_subplot(111)
        # Older Matplotlib cleared the axes on every plot call once hold(False)
        # was set. Axes.hold was deprecated in Matplotlib 2.0 and removed in
        # 3.0, so only call it where it still exists instead of crashing.
        hold = getattr(self.axes, "hold", None)
        if callable(hold):
            hold(False)

        self.compute_initial_figure()

        FigureCanvas.__init__(self, fig)
        self.setParent(parent)

        # Let the canvas grow with the surrounding layout in both directions.
        FigureCanvas.setSizePolicy(self,
                QSizePolicy.Expanding,
                QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

    def compute_initial_figure(self):
        """Hook for subclasses; the base class draws nothing."""
        pass

In [4]:
class PredictedFrame(MatplotlibCanvas):
    """A canvas that shows the image most recently passed to ``update_figure``.

    The canvas does not refresh on its own: no timer is installed, so callers
    must invoke ``update_figure`` whenever a new image is available.
    """
    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``MatplotlibCanvas.__init__``."""
        super().__init__(*args, **kwargs)

    def compute_initial_figure(self):
        """Start with empty axes; the first image arrives via ``update_figure``."""
        pass

    def update_figure(self, data):
        """Replace the displayed image with ``data`` and redraw the canvas.

        data -- array-like accepted by ``Axes.imshow``; rendered with the
                reversed grey colormap ``'Greys_r'``.
        """
        # Clear explicitly rather than relying on Axes.hold(False), which no
        # longer exists in Matplotlib >= 3.0; otherwise every call would stack
        # another image on the axes.
        self.axes.clear()
        self.axes.imshow(data, cmap='Greys_r')
        self.draw()

In [5]:
class Explorer(QtWidgets.QMainWindow):
    """Main window stacking the latent-space view above the predicted image.

    ``produce_image`` is a callable ``(x, y) -> (image, points)``; ``image`` is
    shown in the lower canvas and ``points`` is drawn as the upper polyline.
    """
    def __init__(self, produce_image):
        """Build both panes and render the image for the point (0, 0)."""
        self.produce_image = produce_image
        super().__init__()

        # Central container with a vertical layout holding both panes.
        central = QtWidgets.QWidget()
        self.setCentralWidget(central)
        vbox = QtWidgets.QVBoxLayout()
        central.setLayout(vbox)

        # Mouse moves in the upper pane trigger regeneration of the image.
        self.r_and_m = RenderAndMouse(self, (-3, -3, 3, 3), self.regenerate_image)
        self.predicted = PredictedFrame()
        for child in (self.r_and_m, self.predicted):
            vbox.addWidget(child)

        self.updateGeometry()
        self.regenerate_image(0, 0)

    def regenerate_image(self, x, y):
        """Ask ``produce_image`` for new content at ``(x, y)`` and refresh both panes."""
        image, latent_points = self.produce_image(x, y)
        self.r_and_m.feed(latent_points)
        self.predicted.update_figure(image)

In [6]:
def save_list(fname, l):
    """Write every array in ``l`` to ``fname``, one after another, in ``.npy`` format.

    fname -- path of the file to create or overwrite.
    l     -- iterable of arrays; each is appended with ``np.save``.

    The arrays can be read back in the same order by calling ``np.load``
    repeatedly on one open file handle.
    """
    # The context manager closes the file even if np.save raises part-way.
    with open(fname, "wb") as f:
        for e in l:
            np.save(f, e)
    
def load_list(fname, count):
    """Read ``count`` consecutive ``.npy`` arrays from ``fname`` and return them as a list.

    fname -- path of a file holding arrays written back to back with ``np.save``.
    count -- number of arrays to read, in the order they were written.
    """
    # The context manager closes the file even if np.load raises part-way.
    with open(fname, "rb") as f:
        return [np.load(f) for _ in range(count)]

In [7]:
import cv2
import GPy
import os

# Frame width and height; generated frames are reshaped to (h, w), so D is
# expected to equal w * h.
w, h = 120, 90
sample = np.load('data/samplevideo.npy')
N = sample.shape[0]  # count of samples
D = sample.shape[1]  # dimensionality of observed space
Q = 20  # dimensionality of the latent space

# Normalize every observed dimension to zero mean and unit variance.
sample_mean = np.mean(sample, 0)
sample_std = np.std(sample, 0)
# A dimension that never changes has zero standard deviation; dividing by it
# would fill that column with NaNs. Scale such dimensions by 1 instead, which
# leaves them at exactly zero after centring.
sample_std = np.where(sample_std > 0, sample_std, 1.0)
sample = (sample - sample_mean) / sample_std

kernel = GPy.kern.RBF(Q, ARD=True)
if os.path.exists("cache.npy"):
    # Reuse the latent positions and kernel matrix from a previous run.
    (X, Kx) = load_list("cache.npy", 2)
else:
    # Fit the GPLVM from scratch (PCA initialisation, SCG optimizer) and cache
    # the latent positions together with the kernel matrix.
    m = GPy.models.GPLVM(sample, input_dim=Q, kernel=kernel, init='PCA')
    m.optimize(messages=True, optimizer='SCG', max_iters=1000)
    X = m.X
    # Kernel matrix over every latent point except the last one.
    Kx = kernel.K(X[0:N-1])
    save_list("cache.npy", [X, Kx])
    kernel = None  # prevent reusing the kernel as it is not saved/reloaded properly
    # NOTE(review): in the cached branch `kernel` stays a fresh, untrained RBF
    # rather than None -- confirm nothing downstream relies on it.

# A single synthetic latent point, initialised to the last training point.
N_synt = 1
X_synt = np.zeros((N_synt, Q))
Kx_inv = np.linalg.inv(Kx)
X_synt[0, :] = X[N-1:N]

In [None]:
def produce_image(x0, x1):
    """Sample an image from the GP posterior at a latent point chosen by ``(x0, x1)``.

    x0, x1 -- values written into the first two latent coordinates of the
              synthetic point ``X_synt[0]``; its remaining coordinates keep
              whatever ``X_synt`` already holds.

    Returns a tuple ``(frame, points)``: ``frame`` is the sampled observation
    reshaped to ``(h, w)`` and ``points`` is ``X[:, 0:2]``, the first two latent
    coordinates of every training point.

    Side effect: mutates the module-level ``X_synt`` in place.
    """
    X_synt[0, 0] = x0
    X_synt[0, 1] = x1

    # NOTE(review): this is a freshly constructed RBF kernel; nothing visible
    # here restores the hyperparameters found by the GPLVM optimisation --
    # confirm that default hyperparameters are intended.
    kernel2 = GPy.kern.RBF(Q, ARD=True)
    Ky = kernel2.K(X)
    k_y = kernel2.K(X, X_synt)
    k_yy = kernel2.K(X_synt, X_synt)

    # Solve Ky @ alpha = k_y once instead of forming inv(Ky). This also avoids
    # the (D x N) @ (N x N) product the explicit inverse required on every call.
    alpha = np.linalg.solve(Ky, k_y)                   # (N, N_synt)
    meann = np.matmul(np.transpose(sample), alpha)     # (D, N_synt)
    cov = k_yy - np.matmul(np.transpose(k_y), alpha)   # (N_synt, N_synt)

    # Drawing from N(mean, cov) needs a square root of cov: noise multiplied by
    # cov itself would have covariance cov @ cov. Build the root from the
    # eigendecomposition, clipping tiny negative eigenvalues caused by round-off.
    cov = (cov + np.transpose(cov)) / 2
    eigval, eigvec = np.linalg.eigh(cov)
    cov_sqrt = eigvec * np.sqrt(np.clip(eigval, 0.0, None))

    # One independent noise vector per observed dimension, drawn in the same
    # order as a per-dimension loop would draw them.
    noise = np.random.randn(D, N_synt)
    Y_synt = np.transpose(np.matmul(noise, np.transpose(cov_sqrt)) + meann)

    frame = Y_synt[0, :]
    return frame.reshape((h, w)), X[:, 0:2]

In [None]:
if __name__ == '__main__':
    # Script entry point: open the explorer window and hand control to the Qt
    # event loop, exiting with the loop's return code.
    import sys
    qt_app = QtWidgets.QApplication(sys.argv)
    window = Explorer(produce_image)
    window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)

# NOTE(review): only reached when this file is not run as the main script, in
# which case it raises SystemExit in the importer -- confirm this is intended.
sys.exit()