In [None]:
%%capture
import sys

# Añade el directorio principal al path de búsqueda para importar módulos desde esa ubicación
sys.path.insert(0, "..")

# Desactivar los warnings para evitar mensajes innecesarios durante la ejecución
import warnings

# Importación de bibliotecas necesarias
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split

from likelihood.models.deep import (
    AutoClassifier,
)  # Modelos de deep learning personalizados
from likelihood.tools import OneHotEncoder, get_metrics, apply_lora  # Herramientas auxiliares

import tensorflow as tf

In [2]:
# Cargar el dataset de cáncer de mama desde sklearn
df = datasets.load_breast_cancer()

# Convertir los datos a un DataFrame de pandas para facilitar la manipulación
df_cancer = pd.DataFrame(data=df.data, columns=df.feature_names)
df_cancer["target"] = df.target  # Añadir la columna de etiquetas 'target'

# OneHotEncoder convierte las etiquetas a formato one-hot encoding
y_encoder = OneHotEncoder()
y = y_encoder.encode(df_cancer["target"].to_list())  # Codificar las etiquetas de la clase (target)
X = df_cancer.drop(
    columns="target"
).to_numpy()  # Extraer las características (sin la columna 'target')
X = np.asarray(X).astype(np.float32)  # Convertir X a tipo float32 para la entrada del modelo
y = np.asarray(y).astype(np.float32)  # Convertir y a tipo float32

# Dividir los datos en conjuntos de entrenamiento y prueba
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [3]:
# Example usage
input_shape = (X.shape[1],)
num_classes = y.shape[1]

# Define the AutoClassifier model
model = AutoClassifier(
    input_shape_parm=input_shape[-1],
    num_classes=num_classes,
    units=17,
    activation="selu",
    l2_reg=0.0,
    num_layers=2,
    lora_mode=True,
)

# Compilación del modelo: optimizador, función de pérdida y métricas
model.compile(
    optimizer="adam",  # Optimizador Adam
    loss=tf.keras.losses.CategoricalCrossentropy(),  # Función de pérdida para clasificación multiclase
    metrics=[
        tf.keras.metrics.F1Score(threshold=0.5)
    ],  # Métrica F1 (threshold = 0.5 para predicciones)
)

history = model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test), verbose=False)

# Hacer predicciones sobre el conjunto de entrenamiento
pred = model.predict(X)

# Convertir las predicciones a las etiquetas predichas (máxima probabilidad)
pred_label = np.argmax(pred, axis=1)

# Añadir las predicciones al DataFrame original para su análisis
df = df_cancer.copy()
y_labels = df.drop(columns="target").columns.to_list()
df_cancer["prediction"] = pred_label  # Columna de las etiquetas predichas
df_cancer["label_0"] = pred[:, 0]  # Probabilidad de la clase 0
df_cancer["label_1"] = pred[:, 1]  # Probabilidad de la clase 1

# Calcular y mostrar las métricas del modelo comparando las etiquetas reales y las predicciones
get_metrics(df_cancer, "target", "prediction", verbose=True)

# Guardar el modelo entrenado en el disco en formato TensorFlow
model.save("lora_model", save_format="tf")

Input shape: (None, 38)
Dense weights shape: 38x17
LoRA weights shape: A(38, 4), B(4, 17)
Accuracy: 91.74%
Precision: 88.94%
Recall: 99.16%
F1-Score: 93.77
Cohen's Kappa: 0.8161


In [4]:
model.classifier.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 LoRA_0 (LoRALayer)          (None, 17)                220       
                                                                 
 activation (Activation)     (None, 17)                0         
                                                                 
 dense_4 (Dense)             (None, 2)                 36        
                                                                 
Total params: 256 (1.00 KB)
Trainable params: 256 (1.00 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [None]:
# Cargar el modelo guardado desde el disco
model = tf.keras.models.load_model("lora_model")
model.summary()
# Hacer predicciones nuevamente con el modelo cargado
pred = model.predict(X)

# Obtener las etiquetas predichas para las nuevas predicciones
pred_label = np.argmax(pred, axis=1)

# Añadir las nuevas predicciones al DataFrame original
df["prediction"] = pred_label

# Calcular y mostrar las métricas nuevamente con el modelo cargado
get_metrics(df, "target", "prediction", verbose=True)

AssertionError: Found 4 Python objects that were not bound to checkpointed values, likely due to changes in the Python program. Showing 4 of 4 unmatched objects: [<tf.Variable 'kernel:0' shape=(38, 17) dtype=float32, numpy=
array([[-0.08682367, -0.3138136 , -0.20465228, -0.00373584, -0.32972506,
         0.02088168,  0.1655002 ,  0.3237375 , -0.23878354,  0.22563583,
        -0.0238317 , -0.03691277, -0.23013854, -0.00163841, -0.07690227,
        -0.27483836, -0.16635971],
       [ 0.32840705, -0.00694564,  0.14094016,  0.2901131 ,  0.13030165,
         0.12820882, -0.0227012 ,  0.23289216,  0.2233954 , -0.17634791,
         0.10094994, -0.0710341 ,  0.00509053, -0.23132785, -0.23133186,
        -0.3034119 ,  0.18487024],
       [ 0.2136882 , -0.06205299,  0.20492852, -0.014164  ,  0.08798984,
         0.30495155,  0.15833545, -0.32297817, -0.16988239, -0.0037345 ,
         0.23332423, -0.02173758, -0.21040073, -0.2632298 ,  0.23551631,
        -0.09602039,  0.05802312],
       [ 0.23853642, -0.27450535, -0.24513055,  0.12364203, -0.27806526,
         0.10680291, -0.09182018, -0.21730095, -0.176694  ,  0.15454072,
         0.27692783, -0.11102572,  0.0406535 ,  0.07185465, -0.05374157,
        -0.32055828,  0.2565589 ],
       [-0.26446444,  0.23059076,  0.3298564 , -0.17647107, -0.05213884,
        -0.19932686, -0.01500508,  0.31451428, -0.26958632,  0.14867777,
         0.04414955, -0.1353525 , -0.10732563, -0.27350274,  0.15647009,
         0.27623665, -0.15559152],
       [ 0.10518929,  0.01769391, -0.00269347, -0.19952254,  0.00998127,
         0.2992488 ,  0.09951761, -0.28507438,  0.2028355 ,  0.24906397,
        -0.21115072,  0.09743783, -0.30375624, -0.19869885, -0.11884239,
         0.25981444,  0.32237923],
       [-0.24456562,  0.21920747, -0.18364808,  0.04999879, -0.14999497,
        -0.0497213 , -0.09977773,  0.23952216, -0.13433753,  0.18558031,
        -0.18343113,  0.23260164, -0.11254026, -0.0444046 ,  0.16520709,
         0.12000194,  0.0080407 ],
       [ 0.2555806 , -0.20095353,  0.25825763, -0.22966196, -0.0712164 ,
        -0.08347803,  0.01078764,  0.12110496, -0.1894474 ,  0.22510046,
         0.10930416, -0.03788954, -0.18013997,  0.24639285, -0.05837345,
         0.27180094, -0.18632933],
       [-0.01220888,  0.05444878, -0.16558996,  0.15101197,  0.26512915,
        -0.07565868,  0.16527024,  0.21163857,  0.06546709, -0.15275213,
        -0.04915464, -0.07297143,  0.26028156,  0.08746469,  0.29496902,
         0.22835869, -0.26776144],
       [ 0.17094052, -0.2915169 ,  0.27044016, -0.04535705, -0.21750757,
         0.1578207 ,  0.27081436, -0.07259029,  0.3167696 , -0.26242214,
        -0.19992879, -0.30924225,  0.0873079 , -0.07172447,  0.10897523,
        -0.17452279,  0.26872277],
       [ 0.29443985,  0.01357552,  0.27323878,  0.15225738,  0.06846356,
        -0.15139721,  0.12873247, -0.2328442 , -0.09712017,  0.18183231,
        -0.01494706, -0.15089473,  0.07889518,  0.03955323, -0.30253857,
         0.32959652, -0.21023458],
       [-0.27703714, -0.30233872, -0.24606276, -0.24782267,  0.29057205,
         0.31303406, -0.04165626, -0.03282729,  0.22948968, -0.14880824,
         0.25730026,  0.23347735, -0.14863682, -0.22283372, -0.3271782 ,
         0.15886015,  0.10776252],
       [-0.09907719,  0.13487607,  0.21359527, -0.28831923, -0.2556925 ,
         0.3177144 , -0.13801329,  0.18886822, -0.14633842,  0.14851603,
         0.14484867, -0.11113171, -0.1989877 , -0.20204835,  0.02282885,
        -0.13256651,  0.3003307 ],
       [ 0.04342356,  0.01977387, -0.2321699 ,  0.22949916,  0.32946777,
        -0.3235127 ,  0.30813795,  0.1379436 ,  0.25786096, -0.02381289,
         0.24892396,  0.14716345, -0.29944664, -0.2986673 , -0.1900017 ,
        -0.26191053, -0.03109571],
       [ 0.15157533,  0.03371572, -0.10308038,  0.28116685, -0.20264195,
         0.08352584,  0.02015257,  0.09169394, -0.2788572 ,  0.18193161,
         0.19231498, -0.2900066 ,  0.3184259 ,  0.08061233,  0.22602457,
        -0.25188673,  0.2749244 ],
       [-0.28754127, -0.07609338,  0.07102102, -0.20527516,  0.3095311 ,
         0.27834332,  0.15986913,  0.22996569, -0.21511705,  0.17038858,
        -0.2089751 , -0.19455431, -0.057161  , -0.15915112,  0.03976238,
        -0.08786668, -0.06161556],
       [ 0.28448462,  0.17523742, -0.11065371,  0.2711659 ,  0.12535864,
         0.2933551 ,  0.08345085,  0.18026447, -0.06337541, -0.02397329,
         0.2863741 , -0.2345882 , -0.15958203, -0.25846007,  0.1390037 ,
         0.02790025,  0.1335881 ],
       [ 0.01514006, -0.03702578,  0.32717335, -0.2284517 , -0.2301361 ,
         0.08685794, -0.24446395,  0.16756502, -0.1871683 , -0.14606698,
        -0.2957194 ,  0.1709916 , -0.20180589,  0.15409714, -0.29147375,
         0.13428006,  0.09476352],
       [-0.18060458, -0.11335443, -0.12832975, -0.16638586,  0.12998328,
        -0.12905313, -0.32349783, -0.05046788,  0.01963621, -0.22241288,
         0.12358037, -0.1544389 , -0.3008229 ,  0.28895307,  0.21748096,
         0.20418048,  0.09559351],
       [ 0.2790497 ,  0.3092928 , -0.28391474,  0.3081    , -0.01453009,
        -0.13405207,  0.01874629,  0.14425501, -0.16289799,  0.17734814,
         0.29037303,  0.02083907, -0.1273092 ,  0.03990588,  0.07803819,
        -0.05041859,  0.28808087],
       [-0.10591559, -0.0597887 , -0.09134203,  0.23391372,  0.04770827,
        -0.08228627,  0.23915583, -0.14291899, -0.2688984 ,  0.21232057,
         0.3018698 ,  0.21273702, -0.01775241, -0.06714147,  0.21741301,
         0.0016548 , -0.09447286],
       [ 0.000624  ,  0.23186725, -0.05092534,  0.20046788,  0.25263005,
         0.18426275, -0.06400537,  0.3020339 , -0.16036627,  0.10120934,
        -0.18622696,  0.06728125,  0.05954033,  0.07352716,  0.139947  ,
         0.13005504, -0.27813312],
       [ 0.32575637,  0.05905399, -0.03901121,  0.21882069, -0.13526052,
        -0.01968536, -0.19405064, -0.22876108,  0.13484836, -0.2201623 ,
        -0.01929083,  0.0584296 ,  0.05136222,  0.20822757,  0.06403711,
         0.09232566, -0.06660977],
       [-0.1751167 , -0.23781368, -0.25604063,  0.07410696, -0.3033674 ,
        -0.0718312 ,  0.25290418,  0.17812002,  0.25795478, -0.18255751,
         0.0349288 , -0.2487207 , -0.05831125, -0.32672086, -0.08725277,
         0.14738837,  0.1804238 ],
       [-0.11432578,  0.1342524 ,  0.2792226 ,  0.30641127,  0.14707771,
         0.19627464,  0.1106512 , -0.01858431,  0.24504936, -0.07743159,
        -0.23850234, -0.02413109,  0.01661021, -0.13208506,  0.2742777 ,
         0.18608963,  0.06403318],
       [ 0.00702691, -0.21508792,  0.14327225,  0.14400396, -0.04041904,
         0.32147223,  0.19804895,  0.02896136, -0.17001286, -0.08378419,
         0.18626207,  0.02552813, -0.24635255, -0.10175113, -0.18855882,
        -0.31620252,  0.1786043 ],
       [-0.14661105, -0.15569933, -0.01885518, -0.09986308, -0.27809605,
        -0.07111058,  0.20253879, -0.03773031,  0.22898513, -0.22581989,
        -0.25677907, -0.06393307,  0.28698778, -0.15476696, -0.21757168,
         0.25329566,  0.15088198],
       [-0.30808994, -0.14953516,  0.26526892,  0.04921392, -0.32373634,
         0.03711358,  0.31877118,  0.19129944,  0.25078523,  0.15976772,
         0.01878834,  0.11550108,  0.10022217,  0.26329035,  0.117165  ,
        -0.23610528, -0.14682697],
       [-0.05437469,  0.00977102,  0.09822726, -0.1860571 , -0.185651  ,
        -0.0225901 , -0.2452456 , -0.08149257,  0.0102075 ,  0.2926181 ,
        -0.12787688,  0.11830676,  0.05338499, -0.18644722, -0.08827829,
        -0.03504056,  0.07527637],
       [ 0.19837695,  0.21762043,  0.25320935, -0.06394759,  0.20028418,
         0.2997436 ,  0.2600873 , -0.29362065, -0.19773333, -0.05682665,
        -0.2639742 , -0.00924632,  0.11585921,  0.16984695, -0.01381877,
         0.13525933,  0.24384803],
       [ 0.13621211,  0.03055826, -0.25475866,  0.3076142 , -0.12397969,
        -0.02442765, -0.14706163,  0.1998936 , -0.2331985 ,  0.16550419,
        -0.03808516,  0.16588375, -0.13093156, -0.08983143,  0.20536625,
        -0.01596519, -0.19611949],
       [-0.32878923, -0.32887688,  0.14575198,  0.14154282, -0.03638452,
         0.20965594, -0.20612791, -0.10947196,  0.17917007, -0.15571587,
         0.09546515, -0.0936887 ,  0.07992589,  0.15326172,  0.23362267,
         0.2593785 ,  0.2854147 ],
       [-0.15679383, -0.08869636, -0.28784043,  0.13202828,  0.19181848,
         0.29854357,  0.07885352, -0.28090227, -0.04930684,  0.04700318,
         0.16586336,  0.09856927, -0.15961006, -0.06832087,  0.14918709,
         0.32320708,  0.19829983],
       [-0.32910344, -0.17378311,  0.04157159,  0.11472997,  0.13120577,
        -0.2203383 , -0.3118653 , -0.17045952,  0.17011577,  0.31059337,
         0.03818381,  0.04402891,  0.25369602, -0.30725884, -0.323625  ,
         0.29356158,  0.00900891],
       [ 0.03911775,  0.30498832,  0.24011934, -0.1579007 ,  0.08634514,
        -0.15816687, -0.06660622,  0.2797336 ,  0.06462976, -0.1577243 ,
        -0.04100123, -0.10776126,  0.16077513,  0.22667468, -0.16730104,
        -0.28834292,  0.11314741],
       [-0.09199932, -0.10185452,  0.06555489, -0.16246527,  0.22187054,
         0.22365302, -0.19726029,  0.0861136 , -0.06777924, -0.1792821 ,
        -0.17215297,  0.1541369 ,  0.16586792,  0.17680693, -0.18873104,
         0.22325885, -0.2793512 ],
       [-0.19409765, -0.18654676,  0.2187481 , -0.04761851,  0.0263702 ,
        -0.15415691, -0.02693725,  0.12600034,  0.2987154 ,  0.19341785,
        -0.20995203, -0.08887102, -0.08393264, -0.11564888,  0.11076689,
        -0.11868435, -0.00344819],
       [ 0.18468148,  0.28816336,  0.16236535, -0.30700132, -0.07861587,
        -0.04789397,  0.04000783, -0.21222633,  0.20151657,  0.16836777,
        -0.0273909 , -0.01031068,  0.22546607, -0.22320989, -0.3240307 ,
        -0.04172617, -0.03756258]], dtype=float32)>, <tf.Variable 'bias:0' shape=(17,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)>, <tf.Variable 'kernel:0' shape=(17, 2) dtype=float32, numpy=
array([[-0.38152802,  0.13978821],
       [-0.03806615, -0.377932  ],
       [ 0.07835948,  0.22889477],
       [ 0.34216762,  0.32002866],
       [-0.4514347 , -0.32433408],
       [-0.03198189,  0.18715942],
       [ 0.18358481,  0.18062496],
       [-0.18311751,  0.30289692],
       [ 0.14004147,  0.12718076],
       [-0.35649258,  0.43444365],
       [-0.07424363, -0.5358118 ],
       [-0.32834572,  0.10237652],
       [ 0.061988  ,  0.11317223],
       [-0.01967007,  0.18511099],
       [-0.10125646, -0.22219214],
       [-0.06936985, -0.5337951 ],
       [-0.03329998,  0.33770537]], dtype=float32)>, <tf.Variable 'bias:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]