In [17]:
import numpy as np
import pandas as pd

from sklearn.manifold import Isomap, LocallyLinearEmbedding, SpectralEmbedding, TSNE
from sklearn.model_selection import train_test_split

#### Load and prepare data

In [59]:
# Read the iris table from a path relative to the current working directory.
df = pd.read_csv(filepath_or_buffer="datasets/iris.csv")

In [60]:
# Hold out 25% of the rows for testing. The fixed seed makes the split (and
# therefore every downstream result) reproducible under Restart & Run All.
train, test = train_test_split(df, test_size=0.25, random_state=0)
# Rebind instead of resetting in place: each name then refers to a fresh frame
# with a 0..n-1 index, and re-running this cell always gives the same result.
train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

In [61]:
# Number of rows in the training split.
train.shape[0]

112

#### Embedding

In [62]:
# Embed the four numeric iris measurements into 2-D with t-SNE and attach the
# resulting coordinates to a copy of the training rows.
feature_cols = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
# Copy so that adding the embedding columns does not also modify `train`.
df_org = train.copy()
# Use a dedicated name instead of rebinding `df`, so the raw frame loaded
# earlier is still intact if the split cell is re-run. Plain column selection
# raises a KeyError on a missing column instead of silently filling NaN.
features = train[feature_cols]
# t-SNE is stochastic; a fixed seed makes the embedding repeatable.
embedding = TSNE(n_components=2, random_state=0)
df_trans = embedding.fit_transform(features)
df_org["Emb_dim1"] = df_trans[:, 0]
df_org["Emb_dim2"] = df_trans[:, 1]
# Optional export, left disabled.
# NOTE(review): the file name says "isomap" although this cell runs t-SNE;
# rename the target before enabling.
#df_org.to_csv("embedded_datasets/iris_isomap.csv")
df_org.head()

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species,Emb_dim1,Emb_dim2
0,4.4,3.0,1.3,0.2,setosa,-14.11847,17.438848
1,5.0,3.4,1.6,0.4,setosa,-11.968365,17.445724
2,5.5,3.5,1.3,0.2,setosa,-10.834895,17.776367
3,5.5,2.4,3.8,1.1,versicolor,8.29675,-5.880253
4,6.3,2.7,4.9,1.8,virginica,6.35317,-10.326626


#### Fit Model 

In [63]:
import mb_modelbase as mbase

In [64]:
# Build a MixableCondGaussianModel (the string is presumably the model's name
# -- confirm against mb_modelbase) and fit it on the training rows together
# with their two embedding columns.
mymod = mbase.MixableCondGaussianModel("Iris_tsne_testdata")
# NOTE(review): bool_test_data=False presumably stops fit() from splitting off
# its own test data -- verify; the held-out rows are attached manually later.
mymod.fit(df=df_org, bool_test_data=False)

<mb_modelbase.models_core.mixable_cond_gaussian.MixableCondGaussianModel at 0x7f670d2dac88>

In [65]:
# Sanity check: size of the data held by the fitted model; it should equal the
# number of training rows.
len(mymod.data)

112

#### Predict embeddings for the test data

In [66]:
# For every held-out row, condition a copy of the fitted model on that row's
# observed values and take the last two entries of the "maximum" aggregation as
# the predicted embedding coordinates.
embedding_cols = ("Emb_dim1", "Emb_dim2")
# Leave out the embedding columns: once the predictions have been written back
# to `test`, a re-run of this cell would otherwise condition on its own earlier
# output. On a fresh run this list equals test.columns.
condition_cols = [name for name in test.columns if name not in embedding_cols]

emb1, emb2 = [], []
for _, row in test.iterrows():
    # One copy per row is enough: it keeps `mymod` itself unconditioned.
    mymod_cond = mymod.copy()
    for col in condition_cols:
        mymod_cond = mymod_cond.condition(mbase.Condition(col, "==", row[col]))
    argmax = mymod_cond.aggregate("maximum")
    # assumes the two embedding dimensions are the last two entries of the
    # aggregation result -- TODO confirm against the model's field order
    emb1.append(argmax[-2])
    emb2.append(argmax[-1])


In [67]:
# Write the predicted coordinates back onto the held-out rows, one column per
# embedding dimension, then preview the result.
for col_name, values in (("Emb_dim1", emb1), ("Emb_dim2", emb2)):
    test[col_name] = values
test.head()

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species,Emb_dim1,Emb_dim2
0,6.0,2.2,5.0,1.5,virginica,5.654923,-9.591659
1,5.6,2.5,3.9,1.1,versicolor,7.774174,-6.424803
2,6.3,3.3,4.7,1.6,versicolor,7.507542,-9.032277
3,5.4,3.4,1.5,0.4,setosa,-10.992574,16.426212
4,5.1,2.5,3.0,1.1,versicolor,7.724357,-5.162045


In [68]:
# Attach the held-out rows (including their predicted embedding columns) to the
# model through its `test_data` attribute.
mymod.test_data = test

In [69]:
# Write the fitted model to "Iris_tsne_testdata.mdl" (a relative path).
# NOTE(review): the model is passed explicitly as `model=` even though save()
# is called on the instance -- presumably save() is a static-style method;
# verify against mb_modelbase.
mymod.save(model=mymod, filename="Iris_tsne_testdata.mdl")

150