In [None]:
import lime
import lime.lime_tabular
import torch
import torch.nn.functional as F
from IPython.display import HTML, display


def predict_fn_1d(x, model=None):
    """Class-probability function in the shape LIME's tabular explainer expects.

    Args:
        x: 2-D array of shape (n_samples, n_features); LIME passes its
            perturbed samples here.
        model: torch module to query. Defaults to the notebook-global
            ``trained_model`` so existing ``predict_fn=predict_fn_1d`` calls
            keep working.

    Returns:
        numpy array of shape (n_samples, n_classes) holding the softmax of
        the model logits along dim 1.
    """
    if model is None:
        model = trained_model

    # Send inputs to wherever the model's weights live, so a mismatch between
    # a config setting and the model's real device cannot cause a crash.
    # Parameterless modules fall back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # unsqueeze(1) inserts a singleton dim: (N, F) -> (N, 1, F).
    # Presumably the model takes (batch, channels=1, length) input.
    x_tensor = torch.as_tensor(x, dtype=torch.float32, device=device).unsqueeze(1)

    with torch.no_grad():
        logits = model(x_tensor)
        probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()

# LIME's tabular explainer needs 2-D input: flatten every normalised training
# sample into a single feature vector, using the normalised array's own row count.
X_train_for_lime = X_train_norm.reshape(X_train_norm.shape[0], -1)

# Fix LIME's sampling seed so perturbations, and therefore explanations,
# are reproducible across Restart & Run All.
LIME_RANDOM_STATE = 42

explainer = lime.lime_tabular.LimeTabularExplainer(
    training_data=X_train_for_lime,
    # Features have no semantic names; use their column index as the label.
    feature_names=[f'{i}' for i in range(X_train_for_lime.shape[1])],
    class_names=cfg.CLASS_NAMES,
    mode='classification',
    random_state=LIME_RANDOM_STATE,
)

# Explain a fixed window of test samples with LIME and render the explanation
# for each sample's predicted class.
EXPLAIN_START, EXPLAIN_STOP = 500, 515
LIME_NUM_FEATURES = 10

# Put dropout / batch-norm layers (if any) in inference mode so repeated LIME
# queries of the same input give the same probabilities.
trained_model.eval()

for idx in range(EXPLAIN_START, EXPLAIN_STOP):
    sample_to_explain = X_test_norm[idx].flatten()
    true_label_name = cfg.CLASS_NAMES[int(Y_test[idx])]

    # Predict through the same function LIME queries, so the printed label is
    # consistent with the probabilities shown in the explanation.
    pred_probs = predict_fn_1d(sample_to_explain.reshape(1, -1))
    pred_label_idx = int(pred_probs[0].argmax())
    pred_label_name = cfg.CLASS_NAMES[pred_label_idx]

    print(f'Sample #{idx}')
    print(f'True label: {true_label_name}')
    print(f'Pred label: {pred_label_name}')

    explanation = explainer.explain_instance(
        data_row=sample_to_explain,
        predict_fn=predict_fn_1d,
        num_features=LIME_NUM_FEATURES,
        # Explain every class rather than a hard-coded (0, 1, 2).
        labels=tuple(range(len(cfg.CLASS_NAMES))),
    )

    # Render the explanation for the predicted class.
    html_explanation = explanation.as_html(labels=(pred_label_idx,), show_table=True)
    display(HTML(html_explanation))

In [None]:
from ptflops import get_model_complexity_info
# ptflops' input resolution excludes the batch dimension: (channels=1, length).
# The length is taken from the flattened training data fed to LIME above, not
# hard-coded (previously 86).
# NOTE(review): this profiles `model`, not `trained_model`. MACs and parameter
# counts do not depend on trained weights, but confirm `model` is the same
# architecture.
input_res = (1, X_train_for_lime.shape[1])
macs, params = get_model_complexity_info(
    model,
    input_res,
    as_strings=True,
    print_per_layer_stat=True,
)
print(f'Computational complexity: {macs}')
print(f'Number of parameters: {params}')