In [1]:
import numpy as np
import time

In [2]:
from factorgp import FactorGP
from inference import run_gibbs

In [3]:
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

In [4]:
# GP with a squared-exponential (RBF) kernel, length scale 10 in the units of x.
# It is not fit in the visible cells, so sample_y below draws from the prior.
prior_kernel = RBF(length_scale=10)
gp = GaussianProcessRegressor(kernel=prior_kernel)

In [24]:
# Simulate the latent factor curves F: N_FACTORS independent draws from the GP
# prior, evaluated at N_TIMEPOINTS evenly spaced inputs on [1, 100].
N_TIMEPOINTS = 100
N_FACTORS = 4
x = np.linspace(1, 100, N_TIMEPOINTS)
# reshape(-1, 1) turns x into a column of inputs without repeating the length;
# random_state=0 is sklearn's default, spelled out so the draw is visibly reproducible.
curves = gp.sample_y(x.reshape(-1, 1), N_FACTORS, random_state=0)
F = curves  # one column per latent factor

In [25]:
# Seed the global NumPy RNG so the loadings below and any later np.random draws
# (e.g. the observation noise) are reproducible on Restart & Run All.
np.random.seed(0)
N_OUTPUTS = 50
# Loadings map each latent factor (a column of F) to the outputs. The row count
# is taken from F so the matrix product below always has matching shapes.
loading = np.random.normal(0, 1, (F.shape[1], N_OUTPUTS))
Y_true = np.matmul(F, loading)  # noiseless signal, shape (len(F), N_OUTPUTS)

In [26]:
# Build the observed data: stack n noisy replicates of Y_true vertically.
# Y has shape (n * t, q); rows [i*t, (i+1)*t) hold replicate i.
NOISE_SD = 0.5  # standard deviation of the i.i.d. Gaussian observation noise
t, q = Y_true.shape
n = 1
Y = np.zeros((n * t, q))
for i in range(n):
    # Noise shape comes from Y_true rather than hard-coded numbers, so the cell
    # still works if the simulated dimensions change.
    Y[i * t:(i + 1) * t, :] = Y_true + np.random.normal(0, NOISE_SD, (t, q))

In [27]:
# Sanity check: n replicates of t time points, each with q outputs.
assert Y.shape == (n * t, q), Y.shape
print(Y.shape)

(100, 50)


In [28]:
# Model dimensions come from the simulated data so they cannot drift out of sync:
# n replicates, t time points, q outputs, r latent factors (rows of `loading`).
r = loading.shape[0]
dims = [n, t, q, r]  # n, t, q, r
model = FactorGP(dims)

In [29]:
# Time one call to update_conditional_latent. perf_counter is monotonic and
# high-resolution, unlike time.time, so it is the right clock for durations.
start = time.perf_counter()
model.update_conditional_latent(Y)
end = time.perf_counter()
print(end - start)  # elapsed seconds

7.72900485992


In [30]:
# Run the Gibbs sampler on the simulated data. The third argument appears to be
# the iteration count (the progress bar counts to 100).
# NOTE(review): confirm the meaning of 10 and 0.5 against run_gibbs's signature.
N_GIBBS_ITERS = 100
results = run_gibbs(Y, model, N_GIBBS_ITERS, 10, 0.5, verbose=True)




  0%|          | 0/100 [00:00<?, ?it/s][A[A[A


  1%|          | 1/100 [00:10<17:55, 10.87s/it][A[A[A

Current MSE: 0.362173493998
Current length scale: [8.78614837 6.13471585 7.95762993 7.89223949]





  2%|▏         | 2/100 [00:22<18:05, 11.08s/it][A[A[A

Current MSE: 0.257246115738
Current length scale: [9.02450591 6.19349269 7.95866062 7.89247732]





  3%|▎         | 3/100 [00:34<18:18, 11.33s/it][A[A[A

Current MSE: 0.253840641845
Current length scale: [8.7730379  6.41092371 8.89126453 7.71583901]





  4%|▍         | 4/100 [00:45<17:53, 11.18s/it][A[A[A

Current MSE: 0.253264138947
Current length scale: [8.86954942 6.25024528 8.8971833  7.60344375]





  5%|▌         | 5/100 [00:56<17:33, 11.09s/it][A[A[A

Current MSE: 0.251486391106
Current length scale: [8.48053513 6.1472792  9.23801467 7.78945874]





  6%|▌         | 6/100 [01:06<17:17, 11.03s/it][A[A[A

Current MSE: 0.250725810254
Current length scale: [ 8.27469801  6.11510277 10.69698866  8.17831401]





  7%|▋         | 7/100 [01:17<17:01, 10.98s/it][A[A[A

Current MSE: 0.252397725944
Current length scale: [ 8.48611844  6.26035239 10.98266583  7.83037067]





  8%|▊         | 8/100 [01:29<17:11, 11.21s/it][A[A[A

Current MSE: 0.25585093654
Current length scale: [ 8.3663322   6.47141378 11.05400754  7.61648063]





  9%|▉         | 9/100 [01:40<16:48, 11.08s/it][A[A[A

Current MSE: 0.255516117233
Current length scale: [ 8.93035734  6.42488378 11.00306828  7.95496374]





 10%|█         | 10/100 [01:51<16:27, 10.98s/it][A[A[A

Current MSE: 0.254145172662
Current length scale: [ 9.19111516  6.39406668 10.92408694  7.95428745]





 11%|█         | 11/100 [02:02<16:31, 11.14s/it][A[A[A

Current MSE: 0.251720150135
Current length scale: [ 9.11001508  7.09048124 10.89442085  8.40412955]





 12%|█▏        | 12/100 [02:15<17:11, 11.72s/it][A[A[A

Current MSE: 0.251430351618
Current length scale: [ 8.93984211  7.13181422 10.60837608  9.18992184]





 13%|█▎        | 13/100 [02:27<17:00, 11.73s/it][A[A[A

Current MSE: 0.253534029545
Current length scale: [ 9.61864523  7.03732586 10.68185003  9.4370102 ]





 14%|█▍        | 14/100 [02:38<16:39, 11.62s/it][A[A[A

Current MSE: 0.252265070428
Current length scale: [ 9.22729252  6.91513222  9.963425   10.00793565]





 15%|█▌        | 15/100 [02:50<16:24, 11.58s/it][A[A[A

Current MSE: 0.250913552977
Current length scale: [ 9.14736675  6.72721843 10.43597747  9.68164837]





 16%|█▌        | 16/100 [03:02<16:27, 11.76s/it][A[A[A

Current MSE: 0.252510835521
Current length scale: [ 8.90483122  7.29335063 10.27111438  9.64035001]





 17%|█▋        | 17/100 [03:13<16:10, 11.69s/it][A[A[A

Current MSE: 0.252498958962
Current length scale: [ 8.79607948  7.85897325 10.26985634  9.37465006]





 18%|█▊        | 18/100 [03:26<16:09, 11.83s/it][A[A[A

Current MSE: 0.251268552403
Current length scale: [ 9.34955301  8.1599313  10.00982991  9.25893598]





 19%|█▉        | 19/100 [03:42<17:42, 13.12s/it][A[A[A

Current MSE: 0.252269135235
Current length scale: [ 9.34424309  8.37596649 10.19867374  9.59752754]





 20%|██        | 20/100 [03:53<16:50, 12.63s/it][A[A[A

Current MSE: 0.251775891105
Current length scale: [ 9.49641937  9.0453018  10.83477457  9.27221803]





 21%|██        | 21/100 [04:04<16:04, 12.21s/it][A[A[A

Current MSE: 0.25341134406
Current length scale: [ 9.18199864  9.0248623  11.21773818  9.08942085]





 22%|██▏       | 22/100 [04:18<16:12, 12.46s/it][A[A[A

Current MSE: 0.2522275401
Current length scale: [ 9.42759857  9.60053022 10.88064513  9.41602931]





 23%|██▎       | 23/100 [04:30<15:54, 12.40s/it][A[A[A

Current MSE: 0.253297949409
Current length scale: [ 9.8570252   9.15257808 10.71683734  9.18904954]





 24%|██▍       | 24/100 [04:43<15:59, 12.62s/it][A[A[A

Current MSE: 0.253436169658
Current length scale: [ 9.46660333  8.97957585 10.75475622  9.5512656 ]





 25%|██▌       | 25/100 [04:56<16:01, 12.83s/it][A[A[A

Current MSE: 0.253072043294
Current length scale: [ 8.93361001  8.55303041 10.75660042  9.20307567]





 26%|██▌       | 26/100 [05:09<15:37, 12.67s/it][A[A[A

Current MSE: 0.252853957298
Current length scale: [ 9.9127002   8.58601551 10.71126491  8.87262394]





 27%|██▋       | 27/100 [05:21<15:17, 12.57s/it][A[A[A

Current MSE: 0.250482421025
Current length scale: [ 8.77958895  8.86795668 10.39603236  8.95216352]





 28%|██▊       | 28/100 [05:33<14:50, 12.37s/it][A[A[A

Current MSE: 0.251610560349
Current length scale: [8.73817305 9.46835131 9.54030173 9.00738583]





 29%|██▉       | 29/100 [05:45<14:31, 12.28s/it][A[A[A

Current MSE: 0.252329515386
Current length scale: [8.70029474 9.62403771 8.9705243  9.45246566]





 30%|███       | 30/100 [05:56<14:05, 12.08s/it][A[A[A

Current MSE: 0.252101600767
Current length scale: [8.45816322 9.19970695 8.40692025 9.44539414]





 31%|███       | 31/100 [06:08<13:36, 11.84s/it][A[A[A

Current MSE: 0.251335848435
Current length scale: [8.80389469 9.82500522 8.54656706 9.13285696]





 32%|███▏      | 32/100 [06:20<13:38, 12.04s/it][A[A[A

Current MSE: 0.253125212296
Current length scale: [8.65962368 9.49828498 8.73034695 9.5604259 ]





 33%|███▎      | 33/100 [06:34<14:04, 12.60s/it][A[A[A

Current MSE: 0.251623684832
Current length scale: [8.7795011  9.17611467 9.05459487 9.68791409]





 34%|███▍      | 34/100 [06:47<14:00, 12.73s/it][A[A[A

Current MSE: 0.2523153267
Current length scale: [8.98746039 8.83282601 8.96823652 9.30004444]





 35%|███▌      | 35/100 [06:58<13:08, 12.13s/it][A[A[A

Current MSE: 0.25181364526
Current length scale: [8.93908054 9.01615524 8.5772835  8.97482142]





 36%|███▌      | 36/100 [07:09<12:30, 11.72s/it][A[A[A

Current MSE: 0.253586149278
Current length scale: [8.69487945 9.00223155 8.85206975 8.93835974]





 37%|███▋      | 37/100 [07:22<12:42, 12.10s/it][A[A[A

Current MSE: 0.25245830019
Current length scale: [8.93607839 9.05367027 9.02809311 9.39299057]





 38%|███▊      | 38/100 [07:34<12:39, 12.24s/it][A[A[A

Current MSE: 0.252604621152
Current length scale: [8.84282696 9.09755735 9.20662207 9.55739186]





 39%|███▉      | 39/100 [07:45<11:58, 11.78s/it][A[A[A

Current MSE: 0.25013483523
Current length scale: [8.91053301 8.6390057  9.0773074  9.35550197]





 40%|████      | 40/100 [08:08<15:00, 15.02s/it][A[A[A

Current MSE: 0.25468516682
Current length scale: [9.0805732  8.35575654 8.70272617 9.25665854]





 41%|████      | 41/100 [08:19<13:41, 13.92s/it][A[A[A

Current MSE: 0.252810391135
Current length scale: [9.15478223 8.02331727 9.23353276 9.06146935]





 42%|████▏     | 42/100 [08:31<13:03, 13.51s/it][A[A[A

Current MSE: 0.251492557011
Current length scale: [9.5726563  8.31676382 9.33460391 8.51703954]





 43%|████▎     | 43/100 [08:52<14:48, 15.59s/it][A[A[A

Current MSE: 0.250767901941
Current length scale: [9.3735575  8.5711351  9.12951303 8.34609596]





 44%|████▍     | 44/100 [09:19<17:53, 19.16s/it][A[A[A

Current MSE: 0.252565441226
Current length scale: [9.23970974 8.10254365 8.71474192 8.52551619]





 45%|████▌     | 45/100 [09:35<16:41, 18.21s/it][A[A[A

Current MSE: 0.251581785594
Current length scale: [9.23970974 8.57634201 8.05578714 8.42508327]





 46%|████▌     | 46/100 [09:48<14:56, 16.60s/it][A[A[A

Current MSE: 0.252734979848
Current length scale: [9.24192803 8.69026486 8.13556425 8.320482  ]





 47%|████▋     | 47/100 [10:00<13:17, 15.04s/it][A[A[A

Current MSE: 0.251817401261
Current length scale: [9.35866898 8.3623212  8.08717309 8.40860714]





 48%|████▊     | 48/100 [10:11<12:09, 14.03s/it][A[A[A

Current MSE: 0.251111880016
Current length scale: [8.84191074 8.02422022 8.28915947 8.5641187 ]





 49%|████▉     | 49/100 [10:23<11:15, 13.25s/it][A[A[A

Current MSE: 0.250303997418
Current length scale: [9.17317898 8.11680855 8.24406878 8.87076118]





 50%|█████     | 50/100 [10:34<10:35, 12.71s/it][A[A[A

Current MSE: 0.250740193723
Current length scale: [9.35663593 8.5467818  8.10209239 9.37063777]





 51%|█████     | 51/100 [10:46<10:10, 12.46s/it][A[A[A

Current MSE: 0.251969107615
Current length scale: [8.5153591  8.76269754 8.48160706 9.56743471]





 52%|█████▏    | 52/100 [10:58<09:48, 12.26s/it][A[A[A

Current MSE: 0.251338778635
Current length scale: [8.81723705 9.40611842 8.67996693 9.28848717]





 53%|█████▎    | 53/100 [11:09<09:25, 12.04s/it][A[A[A

Current MSE: 0.25236013501
Current length scale: [9.47578551 8.88354782 8.77240959 9.72631985]





 54%|█████▍    | 54/100 [11:21<09:10, 11.96s/it][A[A[A

Current MSE: 0.251583257
Current length scale: [9.03120994 8.84278633 9.34123623 9.81466305]





 55%|█████▌    | 55/100 [11:32<08:49, 11.76s/it][A[A[A

Current MSE: 0.251180508218
Current length scale: [8.78536149 9.05547757 9.22727193 9.9699665 ]





 56%|█████▌    | 56/100 [11:44<08:37, 11.76s/it][A[A[A

Current MSE: 0.252565414971
Current length scale: [ 8.51561693  9.54055855  9.37646939 10.08366044]





 57%|█████▋    | 57/100 [11:56<08:24, 11.73s/it][A[A[A

Current MSE: 0.252915669722
Current length scale: [8.09762751 9.29858077 9.04419147 9.57616019]





 58%|█████▊    | 58/100 [12:07<08:09, 11.65s/it][A[A[A

Current MSE: 0.251922321477
Current length scale: [8.26380934 9.64138012 8.59800282 9.1755052 ]





 59%|█████▉    | 59/100 [12:21<08:27, 12.38s/it][A[A[A

Current MSE: 0.252869777355
Current length scale: [8.57670319 9.87594678 8.09103152 9.59908385]





 60%|██████    | 60/100 [12:35<08:28, 12.72s/it][A[A[A

Current MSE: 0.252452849239
Current length scale: [8.64504163 9.25700611 7.93384397 9.28865128]





 61%|██████    | 61/100 [12:51<08:53, 13.68s/it][A[A[A

Current MSE: 0.252275248148
Current length scale: [8.2091692  9.63572374 7.87600814 9.22797109]





 62%|██████▏   | 62/100 [13:13<10:19, 16.29s/it][A[A[A

Current MSE: 0.25349233674
Current length scale: [8.29060873 9.65513849 8.51359116 8.89069301]





 63%|██████▎   | 63/100 [13:34<10:57, 17.77s/it][A[A[A

Current MSE: 0.250291990663
Current length scale: [8.19637589 8.99474338 8.70774618 8.21895036]





 64%|██████▍   | 64/100 [13:47<09:40, 16.12s/it][A[A[A

Current MSE: 0.252223219557
Current length scale: [8.02153729 8.60001107 8.53668651 8.52581017]





 65%|██████▌   | 65/100 [14:01<09:00, 15.45s/it][A[A[A

Current MSE: 0.252807374341
Current length scale: [7.64252672 8.30138909 8.61778292 8.14520845]





 66%|██████▌   | 66/100 [14:20<09:20, 16.49s/it][A[A[A

Current MSE: 0.25355376648
Current length scale: [8.09902616 9.20687434 8.37074586 8.22187718]





 67%|██████▋   | 67/100 [14:35<08:49, 16.05s/it][A[A[A

Current MSE: 0.249430875734
Current length scale: [8.29444332 9.22131949 8.58648223 8.690597  ]


KeyboardInterrupt: 

In [30]:
from factorgp import IterFactorGP

In [31]:
# Iterative variant, fit with 10 latent factors (the data were simulated with
# loading.shape[0] factors). t and q must describe Y, which has n * t rows and
# q columns. The previous hard-coded t=50 did not match the 100-row Y, so both
# are now taken from the data.
N_ITER_FACTORS = 10
dims = [n, t, q, N_ITER_FACTORS]  # n, t, q, r
model = IterFactorGP(dims)

In [32]:
# Time one call to IterFactorGP.update_conditional_latent. perf_counter is
# monotonic and high-resolution, unlike time.time, so it suits elapsed timing.
start = time.perf_counter()
model.update_conditional_latent(Y)
end = time.perf_counter()
print(end - start)  # elapsed seconds

KeyboardInterrupt: 