In [11]:
import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error
import os
from datetime import datetime
from collections import defaultdict
import bisect
from matplotlib.ticker import MaxNLocator
from matplotlib import pyplot as plt
from math import sqrt
from catboost import CatBoostRegressor
from scipy import *
from scipy.linalg import norm, pinv
import math


class RBF:

        def __init__(self, indim, numCenters, outdim):
            self.indim = indim
            self.outdim = outdim
            self.numCenters = numCenters
            self.centers = [np.random.uniform(-1, 1, indim) for i in range(numCenters)]
            self.beta = 8
            self.W = np.random.random((self.numCenters, self.outdim))

        def _basisfunc(self, c, d):
            assert len(d) == self.indim
            return norm(c-d)**3

        def _calcAct(self, X):
            # calculate activations of RBFs
            G = np.zeros((X.shape[0], self.numCenters), float)
            for ci, c in enumerate(self.centers):
                for xi, x in enumerate(X):
                    G[xi,ci] = self._basisfunc(c, x)
            return G

        def train(self, X, Y):
            """ X: matrix of dimensions n x indim 
                y: column vector of dimension n x 1 """

            # choose random center vectors from training set
            rnd_idx = np.random.permutation(X.shape[0])[:self.numCenters]
            self.centers = [X[i,:] for i in rnd_idx]

            #print("center", self.centers)
            # calculate activations of RBFs
            G = self._calcAct(X)
            #print(G)

            # calculate output weights (pseudoinverse)
            self.W = np.dot(pinv(G), Y)

        def test(self, X):
            """ X: matrix of dimensions n x indim """

            G = self._calcAct(X)
            Y = np.dot(G, self.W)
            return Y
        
for x in range(2,3):
    print("\n\nObject{}".format(x))
    path = 'FBL/Object{}/Deform/'.format(x)


    X1_train = pd.read_csv(path+'afbl_x1_d_train_{}.csv'.format(x))
    X2_train = pd.read_csv(path+'afbl_x2_d_train_{}.csv'.format(x))
    X1_test = pd.read_csv(path+'afbl_x1_d_test_{}.csv'.format(x))
    X2_test = pd.read_csv(path+'afbl_x2_d_test_{}.csv'.format(x))

    Y_train = pd.read_csv(path+'afbl_f_d_train_{}.csv'.format(x))
    Y_test = pd.read_csv(path+'afbl_f_d_test_{}.csv'.format(x))

    print(X1_train.shape, Y_train.shape, X2_test.shape)

    x_train = np.array(pd.concat((X1_train,X2_train),axis=1))
    x_test = np.array(pd.concat((X1_test,X2_test),axis=1))
    y_train,y_test = np.array(Y_train),np.array(Y_test)
    print(x_train.shape,x_test.shape,y_train.shape,y_test.shape)

    t1 = datetime.now()

    rbf = RBF(2,50, 1)
    rbf.train(x_train, y_train)
    t2 = datetime.now()

    y1 = rbf.test(x_test)
    print('time taken by FBL:', t2-t1)
    print('Root Mean Squared error: ', mean_squared_error(y_test,y1)**0.5)
    
    statement1 = 'time taken by FBL: {} \nRoot Mean Squared error: {}\n'.format(t2-t1, mean_squared_error(y_test,y1)**0.5)
    

    cat = CatBoostRegressor(iterations=1200, learning_rate=0.01, random_seed=42)
    t1 = datetime.now()
    cat.fit(x_train,y_train)
    t2 = datetime.now()

    y2 = cat.predict(x_test)
    print('time taken by Catboost on FBL data:', t2-t1)
    print('Root Mean Squared error: ', mean_squared_error(y_test,y2)**0.5)
    
    statement2 = 'time taken by catboost: {} \nRoot Mean Squared error: {}\n'.format(t2-t1, mean_squared_error(y_test,y2)**0.5)
    
    file1 = open("FBL_Results1.txt","a")
    file1.write("\n Object {} \n".format(x))
    file1.write(statement1)
    file1.write(statement2)
    file1.close()



Object2
(28506, 1) (28506, 1) (146164, 1)
(28506, 2) (146164, 2) (28506, 1) (146164, 1)
time taken by FBL: 0:00:07.079692
Root Mean Squared error:  0.07709059823027975
0:	learn: 1.1648782	total: 2.59ms	remaining: 3.1s
1:	learn: 1.1538992	total: 4.95ms	remaining: 2.96s
2:	learn: 1.1431715	total: 7.3ms	remaining: 2.91s
3:	learn: 1.1324049	total: 9.66ms	remaining: 2.89s
4:	learn: 1.1218931	total: 12ms	remaining: 2.86s
5:	learn: 1.1115298	total: 14.3ms	remaining: 2.84s
6:	learn: 1.1011083	total: 16.7ms	remaining: 2.85s
7:	learn: 1.0909005	total: 19.1ms	remaining: 2.84s
8:	learn: 1.0807560	total: 21.4ms	remaining: 2.83s
9:	learn: 1.0705805	total: 23.9ms	remaining: 2.84s
10:	learn: 1.0606653	total: 26.3ms	remaining: 2.85s
11:	learn: 1.0509112	total: 28.7ms	remaining: 2.84s
12:	learn: 1.0410793	total: 30.9ms	remaining: 2.82s
13:	learn: 1.0313129	total: 33.2ms	remaining: 2.81s
14:	learn: 1.0217004	total: 35.5ms	remaining: 2.81s
15:	learn: 1.0121310	total: 37.9ms	remaining: 2.8s
16:	learn: 1.

156:	learn: 0.2867926	total: 335ms	remaining: 2.22s
157:	learn: 0.2844411	total: 337ms	remaining: 2.22s
158:	learn: 0.2821455	total: 339ms	remaining: 2.22s
159:	learn: 0.2798311	total: 341ms	remaining: 2.21s
160:	learn: 0.2775854	total: 343ms	remaining: 2.21s
161:	learn: 0.2753025	total: 345ms	remaining: 2.21s
162:	learn: 0.2730465	total: 347ms	remaining: 2.21s
163:	learn: 0.2707901	total: 349ms	remaining: 2.2s
164:	learn: 0.2685946	total: 351ms	remaining: 2.2s
165:	learn: 0.2664853	total: 353ms	remaining: 2.2s
166:	learn: 0.2643669	total: 355ms	remaining: 2.19s
167:	learn: 0.2622545	total: 357ms	remaining: 2.19s
168:	learn: 0.2601554	total: 359ms	remaining: 2.19s
169:	learn: 0.2580494	total: 361ms	remaining: 2.19s
170:	learn: 0.2559518	total: 363ms	remaining: 2.18s
171:	learn: 0.2539063	total: 365ms	remaining: 2.18s
172:	learn: 0.2519004	total: 367ms	remaining: 2.18s
173:	learn: 0.2499072	total: 369ms	remaining: 2.18s
174:	learn: 0.2479310	total: 371ms	remaining: 2.17s
175:	learn: 0.2

330:	learn: 0.1017549	total: 698ms	remaining: 1.83s
331:	learn: 0.1014504	total: 700ms	remaining: 1.83s
332:	learn: 0.1011219	total: 702ms	remaining: 1.83s
333:	learn: 0.1008253	total: 704ms	remaining: 1.82s
334:	learn: 0.1005292	total: 706ms	remaining: 1.82s
335:	learn: 0.1002336	total: 708ms	remaining: 1.82s
336:	learn: 0.0999163	total: 710ms	remaining: 1.82s
337:	learn: 0.0995938	total: 712ms	remaining: 1.82s
338:	learn: 0.0993076	total: 714ms	remaining: 1.81s
339:	learn: 0.0990179	total: 716ms	remaining: 1.81s
340:	learn: 0.0987644	total: 718ms	remaining: 1.81s
341:	learn: 0.0984992	total: 720ms	remaining: 1.81s
342:	learn: 0.0982226	total: 722ms	remaining: 1.8s
343:	learn: 0.0979414	total: 725ms	remaining: 1.8s
344:	learn: 0.0976531	total: 727ms	remaining: 1.8s
345:	learn: 0.0973785	total: 729ms	remaining: 1.8s
346:	learn: 0.0971329	total: 731ms	remaining: 1.8s
347:	learn: 0.0968861	total: 733ms	remaining: 1.79s
348:	learn: 0.0966262	total: 735ms	remaining: 1.79s
349:	learn: 0.096

509:	learn: 0.0787041	total: 1.07s	remaining: 1.44s
510:	learn: 0.0786721	total: 1.07s	remaining: 1.44s
511:	learn: 0.0786085	total: 1.07s	remaining: 1.44s
512:	learn: 0.0785345	total: 1.07s	remaining: 1.44s
513:	learn: 0.0784784	total: 1.07s	remaining: 1.43s
514:	learn: 0.0784421	total: 1.08s	remaining: 1.43s
515:	learn: 0.0784136	total: 1.08s	remaining: 1.43s
516:	learn: 0.0783563	total: 1.08s	remaining: 1.43s
517:	learn: 0.0783293	total: 1.08s	remaining: 1.43s
518:	learn: 0.0782792	total: 1.08s	remaining: 1.42s
519:	learn: 0.0782359	total: 1.09s	remaining: 1.42s
520:	learn: 0.0781764	total: 1.09s	remaining: 1.42s
521:	learn: 0.0781410	total: 1.09s	remaining: 1.42s
522:	learn: 0.0780898	total: 1.09s	remaining: 1.41s
523:	learn: 0.0780621	total: 1.09s	remaining: 1.41s
524:	learn: 0.0780334	total: 1.1s	remaining: 1.41s
525:	learn: 0.0779951	total: 1.1s	remaining: 1.41s
526:	learn: 0.0779560	total: 1.1s	remaining: 1.41s
527:	learn: 0.0779323	total: 1.1s	remaining: 1.4s
528:	learn: 0.077

685:	learn: 0.0742548	total: 1.43s	remaining: 1.07s
686:	learn: 0.0742351	total: 1.43s	remaining: 1.07s
687:	learn: 0.0742213	total: 1.43s	remaining: 1.06s
688:	learn: 0.0741893	total: 1.43s	remaining: 1.06s
689:	learn: 0.0741823	total: 1.44s	remaining: 1.06s
690:	learn: 0.0741734	total: 1.44s	remaining: 1.06s
691:	learn: 0.0741605	total: 1.44s	remaining: 1.06s
692:	learn: 0.0741525	total: 1.44s	remaining: 1.05s
693:	learn: 0.0741460	total: 1.44s	remaining: 1.05s
694:	learn: 0.0741304	total: 1.45s	remaining: 1.05s
695:	learn: 0.0741181	total: 1.45s	remaining: 1.05s
696:	learn: 0.0741017	total: 1.45s	remaining: 1.04s
697:	learn: 0.0740926	total: 1.45s	remaining: 1.04s
698:	learn: 0.0740629	total: 1.45s	remaining: 1.04s
699:	learn: 0.0740450	total: 1.46s	remaining: 1.04s
700:	learn: 0.0740197	total: 1.46s	remaining: 1.04s
701:	learn: 0.0740121	total: 1.46s	remaining: 1.03s
702:	learn: 0.0739979	total: 1.46s	remaining: 1.03s
703:	learn: 0.0739919	total: 1.46s	remaining: 1.03s
704:	learn: 

860:	learn: 0.0721295	total: 1.78s	remaining: 701ms
861:	learn: 0.0721185	total: 1.78s	remaining: 699ms
862:	learn: 0.0721140	total: 1.78s	remaining: 697ms
863:	learn: 0.0720990	total: 1.79s	remaining: 695ms
864:	learn: 0.0720926	total: 1.79s	remaining: 693ms
865:	learn: 0.0720901	total: 1.79s	remaining: 691ms
866:	learn: 0.0720630	total: 1.79s	remaining: 689ms
867:	learn: 0.0720508	total: 1.79s	remaining: 687ms
868:	learn: 0.0720365	total: 1.8s	remaining: 685ms
869:	learn: 0.0720255	total: 1.8s	remaining: 683ms
870:	learn: 0.0720058	total: 1.8s	remaining: 681ms
871:	learn: 0.0719887	total: 1.8s	remaining: 678ms
872:	learn: 0.0719821	total: 1.8s	remaining: 676ms
873:	learn: 0.0719713	total: 1.81s	remaining: 674ms
874:	learn: 0.0719585	total: 1.81s	remaining: 672ms
875:	learn: 0.0719487	total: 1.81s	remaining: 670ms
876:	learn: 0.0719383	total: 1.81s	remaining: 668ms
877:	learn: 0.0719269	total: 1.81s	remaining: 666ms
878:	learn: 0.0719202	total: 1.82s	remaining: 664ms
879:	learn: 0.071

1040:	learn: 0.0703865	total: 2.14s	remaining: 328ms
1041:	learn: 0.0703726	total: 2.15s	remaining: 326ms
1042:	learn: 0.0703690	total: 2.15s	remaining: 323ms
1043:	learn: 0.0703552	total: 2.15s	remaining: 321ms
1044:	learn: 0.0703464	total: 2.15s	remaining: 319ms
1045:	learn: 0.0703437	total: 2.15s	remaining: 317ms
1046:	learn: 0.0703385	total: 2.16s	remaining: 315ms
1047:	learn: 0.0703364	total: 2.16s	remaining: 313ms
1048:	learn: 0.0703264	total: 2.16s	remaining: 311ms
1049:	learn: 0.0703229	total: 2.16s	remaining: 309ms
1050:	learn: 0.0703066	total: 2.17s	remaining: 307ms
1051:	learn: 0.0702950	total: 2.17s	remaining: 305ms
1052:	learn: 0.0702859	total: 2.17s	remaining: 303ms
1053:	learn: 0.0702730	total: 2.17s	remaining: 301ms
1054:	learn: 0.0702650	total: 2.17s	remaining: 299ms
1055:	learn: 0.0702589	total: 2.17s	remaining: 297ms
1056:	learn: 0.0702567	total: 2.18s	remaining: 295ms
1057:	learn: 0.0702545	total: 2.18s	remaining: 293ms
1058:	learn: 0.0702470	total: 2.18s	remaining: