In [None]:
# Install the third-party dependencies used below into the notebook environment.
# NOTE(review): versions are unpinned, so a later re-run may pull incompatible releases — TODO pin them.
# NOTE(review): `transformers`, scikit-learn and seaborn are imported below but not installed here
# (transformers only appears in the commented-out line) — presumably preinstalled; confirm.
!pip install datasets
!pip install xgboost
!pip install nltk
!pip install gensim
!pip install -U sentence-transformers
!pip install torch_explain
!pip install torch
#!pip -q install langchain huggingface_hub transformers sentence_transformers
!pip install mistralai

In [None]:
import transformers
import pandas as pd
from transformers import pipeline
import torch
from transformers import BertTokenizer, BertForSequenceClassification, BertForTokenClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import os
from transformers import AutoTokenizer
from torch_explain.nn.concepts import ConceptReasoningLayer, ConceptEmbedding
import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from typing import Dict, List
from utilities import *
import pickle
import json
from sklearn.tree import DecisionTreeClassifier
import xgboost as xgb
import joblib
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the dataset

In [None]:
# cebab
class Args():
    """Run configuration for one experiment; every attribute is read as a class-level default."""

    # --- data ------------------------------------------------------------------
    dataset = 'cebab'
    max_length = 100             # tokenizer max length for the reviews
    N = None                     # number of validation samples drawn for threshold selection
    seed = 42

    # --- concept pseudo-labeling ----------------------------------------------
    pseudo_labeling = 'mixtral'
    # Big: 'all-mpnet-base-v2'  Medium: 'all-MiniLM-L12-v2'  Small: 'all-MiniLM-L6-v2'
    model_name = None
    chuncks = None
    sentences_per_concept = 1
    n_concepts = 4
    threshold_optimization = False

    # --- model -----------------------------------------------------------------
    configuration = 'modified_cem'  # other options include e.g. 'modified_dcl', 'supervised_dcl'
    backbone = 'mixtral'
    n_labels = 2
    concept_size = 16

    # --- optimisation ----------------------------------------------------------
    lr = 1e-6
    epochs = 50
    step_size = 10
    gamma = 0.1
    lambda_coeff = 0.5           # weight of the task loss relative to the concept loss


args = Args()

In [None]:
# Unpack the run configuration into the module-level names used by the rest of the notebook.
max_length = args.max_length
dataset = args.dataset
pseudo_labeling = args.pseudo_labeling
model_name = args.model_name
configuration = args.configuration
chuncks = args.chuncks
sentences_per_concept = args.sentences_per_concept
N = args.N
n_concepts = args.n_concepts
threshold_optimization = args.threshold_optimization
labeler_tokenizer = None
n_labels = args.n_labels
concept_size = args.concept_size
backbone = args.backbone
lambda_coeff = args.lambda_coeff

# Apply `args.seed` to the RNGs this notebook draws from (torch weight initialisation and
# torch.multinomial sampling, numpy) so that runs are reproducible.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Default Hugging Face checkpoint for each LLM pseudo-labeler; a non-None `args.model_name` overrides it.
default_labeler_checkpoints = {
    'mistral': 'mistralai/Mistral-7B-v0.1',
    'mixtral': 'mistralai/Mixtral-8x7B-v0.1',
}
if pseudo_labeling in default_labeler_checkpoints:
    if model_name is None:
        model_name = default_labeler_checkpoints[pseudo_labeling]
    labeler_tokenizer = AutoTokenizer.from_pretrained(model_name)

# BERT tokenizer for the reviews, capped at `max_length` tokens.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", model_max_length=max_length)

In [None]:
# Folder holding the concept-sentence JSON files of the selected dataset.
file_path = 'concept_sentences/{}/'.format(dataset)

def read_from_json_file(file_path, encoding='utf-8'):
    """Load and return the JSON document stored at ``file_path``.

    Args:
        file_path: path of the JSON file to read.
        encoding: text encoding used to decode the file. Defaults to UTF-8 (the encoding
            RFC 8259 mandates for JSON) rather than the platform locale encoding that a bare
            ``open`` would pick, which can mis-decode non-ASCII concept sentences.

    Returns:
        The deserialized Python object, as produced by ``json.load``.
    """
    with open(file_path, 'r', encoding=encoding) as file:
        return json.load(file)

# Every dataset except 'depression' reads its concept sentences from JSON;
# 'depression' gets six empty placeholder sentence pairs instead.
if dataset != 'depression':
    concepts = read_from_json_file(f'{file_path}{sentences_per_concept}_sentenceXconcept.json')
else:
    concepts = [['', ''] for _ in range(6)]

In [None]:
# Cache folder for the pre-processed DataLoaders: one per dataset and pseudo-labeling strategy
# (for SBERT also per sentence-embedding model, chunking and sentences-per-concept setting).
if pseudo_labeling == 'sbert':
    loaded_set_folder = f'loaded_sets/{dataset}/{pseudo_labeling}_{model_name}_{chuncks}_{str(sentences_per_concept)}/'
else:
    loaded_set_folder = f'loaded_sets/{dataset}/{pseudo_labeling}/'

if not os.path.exists(loaded_set_folder):
    os.makedirs(loaded_set_folder)
    print('Path created!')

# The cache is valid only when all three split pickles exist. A non-empty folder is not enough:
# the concept-evaluation cell also writes a `concepts/` report folder in here.
split_names = ('train', 'val', 'test')
cache_paths = {split: os.path.join(loaded_set_folder, f'loaded_{split}.pkl') for split in split_names}

if all(os.path.exists(path) for path in cache_paths.values()):
    # NOTE: pickle.load can execute arbitrary code — only load caches this notebook wrote itself.
    cached_splits = {}
    for split, path in cache_paths.items():
        with open(path, 'rb') as f:
            cached_splits[split] = pickle.load(f)
    loaded_train, loaded_val, loaded_test = (cached_splits[split] for split in split_names)
else:
    if dataset=='cebab':
        loader = Cebab_loader('datasets/cebab', concepts, sentences_per_concept, chuncks, 'cebab_train.csv',
                              'cebab_validation.csv', 'cebab_test.csv', max_length, tokenizer,
                              pseudo_labeling, model_name, labeler_tokenizer)
    elif dataset=='drug':
        loader = Drug_loader('datasets/drug/drug_train.tsv', 'datasets/drug/drug_test.tsv', concepts, sentences_per_concept, chuncks, max_length, tokenizer,
                             pseudo_labeling, model_name, labeler_tokenizer, args.seed)
    elif dataset=='emo':
        loader = Emo_loader('datasets/emo/MultiEmotions-It.tsv', concepts, sentences_per_concept, chuncks, max_length, tokenizer,
                            pseudo_labeling, model_name, labeler_tokenizer, args.seed)
    elif dataset=='depression':
        loader = depressed_loader('datasets/depression', concepts, sentences_per_concept, chuncks, max_length, tokenizer,
                                  pseudo_labeling, model_name, labeler_tokenizer, args.seed)
    else:
        raise ValueError(f'Unknown dataset: {dataset!r}')

    # All loaders are driven through the same load()/collator() calls.
    loader.load()
    loaded_train, loaded_val, loaded_test = loader.collator(batch_train=64, batch_val_test=32)

    for split, split_loader in zip(split_names, (loaded_train, loaded_val, loaded_test)):
        with open(cache_paths[split], 'wb') as f:
            pickle.dump(split_loader, f)

In [None]:
# Total number of training examples, summed over the batches of the training loader.
cnt = sum(batch['embedded_reviews'].shape[0] for batch in loaded_train)
print(cnt)

# Evaluate quality of unsupervised concept discovery

# Threshold selection
N samples are drawn (uniformly, with replacement) from the validation set, and the threshold that maximizes
the macro F1 score averaged over the concepts is selected. This step only runs for SBERT pseudo-labeling with `threshold_optimization` enabled:
the Mistral/Mixtral annotators already output hard labels (1 = concept present, 0 = absent), so a threshold would not change their annotation performance.

In [None]:
if pseudo_labeling == 'sbert' and threshold_optimization:
    # Collect ground-truth concepts and raw SBERT concept scores for the whole validation set.
    # Both accumulators start with a dummy zero row that is dropped after the loop.
    true_concepts = torch.zeros(1, n_concepts)
    predicted_concepts = torch.zeros(1, n_concepts, 2)
    for sentence_batch in loaded_val:
        if dataset=='cebab':
            food = sentence_batch['food']
            ambiance = sentence_batch['ambiance']
            service = sentence_batch['service']
            noise = sentence_batch['noise']
            true_concepts = torch.cat([true_concepts, torch.hstack([food.unsqueeze(1), ambiance.unsqueeze(1), service.unsqueeze(1), noise.unsqueeze(1)])])
        elif dataset=='drug':
            effectiveness = sentence_batch['effectiveness']
            sideEffects = sentence_batch['sideEffects']
            true_concepts = torch.cat([true_concepts, torch.hstack([effectiveness.unsqueeze(1), sideEffects.unsqueeze(1)])])
        elif dataset=='emo':
            gioia = sentence_batch['joy']
            fiducia = sentence_batch['trust']
            tristezza = sentence_batch['sadness']
            sorpresa = sentence_batch['surprise']
            true_concepts = torch.cat([true_concepts, torch.hstack([gioia.unsqueeze(1), fiducia.unsqueeze(1), tristezza.unsqueeze(1), sorpresa.unsqueeze(1)])])
        unsupervised_concepts = sentence_batch['concept_score']
        predicted_concepts = torch.cat([predicted_concepts, unsupervised_concepts])

    true_concepts = true_concepts[1:,:]
    predicted_concepts = predicted_concepts[1:,:]

    # Draw the "manually labelled" subset uniformly with replacement. N=None means
    # "as many draws as there are validation samples" (multinomial cannot take None).
    n_val = len(loaded_val.dataset)
    n_samples = N if N is not None else n_val
    index = (torch.ones(n_val) / n_val).multinomial(num_samples=n_samples, replacement=True)
    sampled_scores = predicted_concepts[index]
    sampled_true = true_concepts[index]

    # Fallback kept if no candidate threshold reaches a positive mean macro-F1
    # (otherwise `threshold` would be unbound below).
    threshold = 0.5
    best_f1 = 0
    for thr in np.linspace(0, 1, 100):
        if dataset=='cebab':
            preds = torch.where(F.softmax(torch.where(sampled_scores>thr, sampled_scores, 0), dim=-1)[:,:,1]>0.5, 1, 0)
        elif dataset=='drug':
            # argmax does not depend on thr: every candidate yields the same predictions.
            preds = torch.argmax(sampled_scores, dim=-1)
        elif dataset=='emo':
            preds = torch.where(sampled_scores[:,:,1]>thr, 1, 0)
        mean_f1 = np.mean([
            classification_report(sampled_true[:,i], preds[:,i].detach().cpu().numpy(), output_dict=True, zero_division=0)['macro avg']['f1-score']
            for i in range(n_concepts)
        ])
        if mean_f1 > best_f1:
            threshold = thr
            best_f1 = mean_f1
    print(f'Result obtined by sampling and "manually" labeling {n_samples} samples from the validation-set.')
    print('Best threshold:', threshold, '; best f1-score macro:', best_f1)
else:
    threshold = 0.5

# Results folder

In [None]:
# Each run gets its own folder: results/<dataset>/<backbone>/<configuration>[_<model_name>]_E<k>/.
results_folder = f"results/{dataset}/{backbone}/"
os.makedirs(results_folder, exist_ok=True)
run_prefix = f'{configuration}_{model_name}' if pseudo_labeling == 'sbert' else f'{configuration}'

# Start from the number of existing entries, then skip past any index already taken
# (e.g. after an older run folder was deleted) so that makedirs never hits an existing folder.
cnt = len([x for x in os.listdir(results_folder) if 'ipynb' not in x])
result_folder = results_folder + f'{run_prefix}_E{cnt}/'
while os.path.exists(result_folder):
    cnt += 1
    result_folder = results_folder + f'{run_prefix}_E{cnt}/'

os.makedirs(result_folder)
result_folder

In [None]:
# Compare the pseudo-labels (unsupervised concepts) with the annotated concepts of the test split:
# one confusion matrix and one classification report per concept. 'depression' is skipped because
# no per-concept annotations are read for it below.
if dataset != 'depression':
    # Accumulators start with a dummy zero row that is dropped after the loop.
    true_concepts = torch.zeros(1, n_concepts)
    predicted_concepts = torch.zeros(1, n_concepts)
    for sentence_batch in loaded_test:
        if dataset=='cebab':
            food = sentence_batch['food']
            ambiance = sentence_batch['ambiance']
            service = sentence_batch['service']
            noise = sentence_batch['noise']
            true_concepts = torch.cat([true_concepts, torch.hstack([food.unsqueeze(1), ambiance.unsqueeze(1), service.unsqueeze(1), noise.unsqueeze(1)])])
        elif dataset=='drug':
            effectiveness = sentence_batch['effectiveness']
            sideEffects = sentence_batch['sideEffects']
            true_concepts = torch.cat([true_concepts, torch.hstack([effectiveness.unsqueeze(1), sideEffects.unsqueeze(1)])])
        elif dataset=='emo':
            gioia = sentence_batch['joy']
            fiducia = sentence_batch['trust']
            tristezza = sentence_batch['sadness']
            sorpresa = sentence_batch['surprise']
            true_concepts = torch.cat([true_concepts, torch.hstack([gioia.unsqueeze(1), fiducia.unsqueeze(1), tristezza.unsqueeze(1), sorpresa.unsqueeze(1)])])

        # Turn the labeler output into hard 0/1 concept predictions.
        if pseudo_labeling=='sbert' and dataset=='cebab':
            unsupervised_concepts = torch.where(F.softmax(torch.where(sentence_batch['concept_score']>threshold,sentence_batch['concept_score'],0), dim=-1)[:,:,1]>0.5, 1, 0)
        elif pseudo_labeling=='sbert' and dataset=='drug':
            unsupervised_concepts = torch.argmax(sentence_batch['concept_score'], dim=-1)
        elif pseudo_labeling=='sbert' and dataset=='emo':
            unsupervised_concepts = torch.where(sentence_batch['concept_score'][:,:,1]>threshold, 1, 0)
        elif pseudo_labeling in ['mistral', 'mixtral']:
            unsupervised_concepts = sentence_batch['concept_score']

        predicted_concepts = torch.cat([predicted_concepts, unsupervised_concepts])

    true_concepts = true_concepts[1:,:]
    predicted_concepts = predicted_concepts[1:,:]

    if dataset=='cebab':
        names = ['good_food', 'good_ambiance', 'good_service', 'good_noise']
    elif dataset=='drug':
        names = ['effectiveness', 'sideEffects']
    elif dataset=='emo':
        names = ['joy', 'trust', 'sadness', 'surprise']

    # squeeze=False keeps the axes indexable even when there is a single concept.
    fig, axes = plt.subplots(1, n_concepts, squeeze=False)
    ax = axes[0]
    fig.set_size_inches(20,5)
    summary_rows = ('accuracy', 'micro avg', 'macro avg', 'weighted avg')
    for i in range(n_concepts):
        y_true = true_concepts[:,i].numpy()
        y_pred = predicted_concepts[:,i].detach().cpu().to(torch.long).numpy()

        cm = confusion_matrix(y_true, y_pred)
        sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', ax=ax[i])
        # confusion_matrix rows are the true labels (drawn on the y-axis); columns are predictions (x-axis).
        ax[i].set_xlabel("Pred")
        ax[i].set_ylabel("True")

        # Compute the report once; it feeds both the CSV files and the subplot title.
        report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
        cr = pd.DataFrame(report).T
        # Class rows are keyed '0.0'/'1.0' when labels are floats; store them as '0'/'1'.
        # Deriving the index from the keys (instead of hard-coding five rows) also works when a class is absent.
        cr.index = [k if k in summary_rows else str(int(float(k))) for k in cr.index]
        # These configurations presumably feed the pseudo-labels straight in as concepts, so the
        # labeler's concept quality is also the run's concept quality.
        if configuration in ['cbm_fc_ct', 'cbm_ff_ct', 'modified_cem', 'DTree_ct', 'XGBoost_ct', 'cem_ct']:
            os.makedirs(result_folder+'/concepts', exist_ok=True)
            cr.to_csv(result_folder+f'/concepts/{names[i]}_classification_report.csv', index=True)

        # store the results on the folder related to the labeling strategy
        os.makedirs(loaded_set_folder+'/concepts', exist_ok=True)
        cr.to_csv(loaded_set_folder+f'/concepts/{names[i]}_classification_report.csv', index=True)

        ax[i].set_title(names[i]+'; f1-score macro: '+str(round(report['macro avg']['f1-score'],2)))

# Modules definition

In [None]:
# Width of the per-review embedding fed to the concept layers: the precomputed Mixtral
# embedding (4096) or BERT's last-layer [CLS] hidden state (768).
if backbone=='mixtral':
    embedding_size = 4096
elif backbone=='bert':
    embedding_size = 768
else:
    raise ValueError(f"Unknown backbone: {backbone!r}")


def _with_optional_encoder(*heads):
    """Stack ``heads`` into a ``torch.nn.Sequential``, prefixed by a BERT encoder when ``backbone == 'bert'``.

    With the BERT backbone the freshly loaded encoder is also published as the module-level
    ``concept_encoder``, because later cells freeze its layers and call it directly to embed reviews.
    With the Mixtral backbone the embeddings are precomputed, so only the heads are stacked and no
    BERT weights are loaded.
    """
    global concept_encoder
    if backbone == 'bert':
        concept_encoder = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)
        return torch.nn.Sequential(concept_encoder, *heads)
    return torch.nn.Sequential(*heads)


# Build the modules for the selected configuration. Softmax heads use an explicit dim=-1
# (the class dimension of the 2-D outputs) instead of the deprecated implicit-dim behaviour.
if configuration in ['dcr', 'supervised_dcr']:
    cem = ConceptEmbedding(embedding_size, n_concepts, concept_size).to(device)
    dcr = ConceptReasoningLayer(concept_size, n_labels, temperature=0.1).to(device)
    model = _with_optional_encoder(cem, dcr)
elif configuration=='cem':
    cem = ConceptEmbedding(embedding_size, n_concepts, concept_size).to(device)
    ff = nn.Sequential(nn.Linear(n_concepts*concept_size, 10), nn.ReLU(), nn.Linear(10, n_labels), nn.Softmax(dim=-1)).to(device)
    model = _with_optional_encoder(cem, ff)
elif configuration=='cem_ct':
    cem = Modified_cem(embedding_size, n_concepts, concept_size).to(device)
    ff = nn.Sequential(nn.Linear(n_concepts*concept_size, 10), nn.ReLU(), nn.Linear(10, n_labels), nn.Softmax(dim=-1)).to(device)
    model = _with_optional_encoder(cem, ff)
elif configuration=='cbm_fc':
    ff_concept = nn.Sequential(nn.Linear(embedding_size, n_concepts), nn.Sigmoid()).to(device)
    fc_task = nn.Sequential(nn.Linear(n_concepts, n_labels), nn.Softmax(dim=-1)).to(device)
    model = _with_optional_encoder(ff_concept, fc_task)
elif configuration=='cbm_ff':
    ff_concept = nn.Sequential(nn.Linear(embedding_size, n_concepts), nn.Sigmoid()).to(device)
    ff_task = nn.Sequential(nn.Linear(n_concepts, 3*n_concepts), nn.ReLU(), nn.Linear(3*n_concepts, n_labels), nn.Softmax(dim=-1)).to(device)
    model = _with_optional_encoder(ff_concept, ff_task)
elif configuration=='modified_cem':
    cem = Modified_cem(embedding_size, n_concepts, concept_size).to(device)
    dcr = ConceptReasoningLayer(concept_size, n_labels, temperature=0.1).to(device)
    model = _with_optional_encoder(cem, dcr)
elif configuration in ['dcl', 'supervised_dcl']:
    cem = ConceptEmbedding(embedding_size, n_concepts, concept_size).to(device)
    dcl = ConceptLinearLayer(concept_size, n_labels, bias=True, attention=False).to(device)
    model = _with_optional_encoder(cem, dcl)
elif configuration=='modified_dcl':
    cem = Modified_cem(embedding_size, n_concepts, concept_size).to(device)
    dcl = ConceptLinearLayer(concept_size, n_labels, bias=True, attention=False).to(device)
    model = _with_optional_encoder(cem, dcl)
elif configuration=='cbm_fc_ct':
    # Task head only: the concepts are given as input, so no encoder is attached.
    fc_task = nn.Sequential(nn.Linear(n_concepts, n_labels), nn.Softmax(dim=-1)).to(device)
    model = nn.Sequential(fc_task)
elif configuration=='cbm_ff_ct':
    ff_task = nn.Sequential(nn.Linear(n_concepts, 3*n_concepts), nn.ReLU(), nn.Linear(3*n_concepts, n_labels), nn.Softmax(dim=-1)).to(device)
    model = nn.Sequential(ff_task)
elif configuration in ['DTree', 'DTree_ct']:
    if configuration == 'DTree':
        # Neural concept predictor; the task is then solved by the decision tree.
        ff_concept = nn.Sequential(nn.Linear(embedding_size, n_concepts), nn.Sigmoid()).to(device)
        model = _with_optional_encoder(ff_concept)
    dtree = DecisionTreeClassifier()
elif configuration in ['XGBoost', 'XGBoost_ct']:
    # XGBoost parameters (the booster itself is trained later).
    params = {
        'objective': 'multi:softmax',   # Multiclass classification
        'num_class': n_labels,          # Number of classes
        'max_depth': 3,                 # Maximum depth of each tree
        'eta': 0.3,                     # Learning rate
        'eval_metric': 'merror'         # Evaluation metric
    }
    num_rounds = 5  # Number of boosting rounds
    if configuration == 'XGBoost':
        ff_concept = nn.Sequential(nn.Linear(embedding_size, n_concepts), nn.Sigmoid()).to(device)
        model = _with_optional_encoder(ff_concept)
elif configuration=='e2e':
    # Concept-free baseline: embedding -> task directly.
    ff_task = nn.Sequential(nn.Linear(embedding_size, 100), nn.ReLU(), nn.Linear(100, n_labels), nn.Softmax(dim=-1)).to(device)
    model = _with_optional_encoder(ff_task)
else:
    raise ValueError(f"Unknown configuration: {configuration!r}")

In [None]:
# BERT backbone: freeze the whole encoder, then re-enable gradients for its last transformer layer
# only, so that just that layer is fine-tuned.
if backbone == 'bert':
    concept_encoder.bert.requires_grad_(False)
    concept_encoder.bert.encoder.layer[-1].requires_grad_(True)

# The tree-based *_ct configurations build no torch model, so there is nothing to count.
if configuration not in ['DTree_ct', 'XGBoost_ct']:
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Number of trainable parameters:', n_trainable)

# Training

In [None]:
# Human-readable concept and class names per dataset, passed to the DCR/DCL `explain` calls.
_names_by_dataset = {
    'cebab': (['good food', 'good ambiance', 'good service', 'good noise'],
              ['bad', 'good']),
    'drug': (['effectiveness', 'side effects'],
             ['0', '1', '2']),
    'emo': (['joy', 'trust', 'sadness', 'surprise'],
            ['0', '1', '2']),
    'depression': (['Deprecation', 'Loss_of_Interest', 'Hopelessness', 'Sleep_Disturbances', 'Appetite_Changes', 'Fatigue'],
                   ['not_depressed', 'depressed']),
}
if dataset in _names_by_dataset:
    concept_names, class_names = _names_by_dataset[dataset]

In [None]:
@torch.no_grad()
def evaluate(loaded_set, rules=False):
    """Run the current `model` over `loaded_set` without gradients and collect its outputs.

    Relies on module-level state from earlier cells: the model parts (`cem`, `dcr`, `dcl`, `ff`,
    `ff_concept`, `fc_task`, `ff_task`, `concept_encoder`), the loss callables `loss_form` /
    `loss_form_concepts` (defined in the training cell), and `configuration`, `backbone`,
    `dataset`, `pseudo_labeling`, `threshold`, `n_labels`, `n_concepts`, `max_length`,
    `device`, `concept_names`, `class_names`.

    Args:
        loaded_set: iterable of batch dicts (e.g. a DataLoader) providing 'labels' and
            'concept_score', plus 'embedded_reviews' (mixtral backbone) or
            'input_ids'/'attention_mask' (bert backbone), and the per-dataset concept columns.
        rules: when True, also collect local explanations from the DCR/DCL layer
            (only the dcr/supervised_dcr, dcl/supervised_dcl, modified_cem and modified_dcl
            branches produce any).

    Returns:
        If `rules` is True: (list of local explanations, concatenated input_ids, true labels).
        Otherwise: (task loss, concept loss, task predictions, concept predictions, true labels),
        where both losses are summed over batches and divided by len(loaded_set)
        (the number of batches for a DataLoader).

    Side effect: puts the model in eval mode for the loop and always switches it back to
    train mode before returning, whatever its mode was on entry.
    """
    model.eval()
    running_task_loss = 0
    running_concept_loss = 0
    # Each accumulator starts with a dummy first row/element that is sliced off before returning.
    task_preds = torch.zeros(1,n_labels).to(device)
    concept_preds =  torch.zeros(1,n_concepts).to(device)
    pred_rules = []
    original_sentences = torch.zeros(1,max_length).to(device)
    real_labels = torch.zeros(1)

    for sentence_batch in loaded_set:
        y = torch.Tensor(sentence_batch['labels'])

        # one-hot task targets, compared against the model's class scores by loss_form
        target = torch.nn.functional.one_hot(y.to(torch.long), n_labels).to(device).to(torch.float)

        # Pseudo-labelled concepts: targets of the concept loss for dcr/dcl/cem/cbm_*, and the
        # concepts fed directly into the model for the modified_* and *_ct configurations.
        # true_concepts = torch.where(sentence_batch['concept_score']>0.5, 1, 0).to(device).to(torch.float)
        if pseudo_labeling=='sbert':
            if dataset=='cebab':
                true_concepts = torch.where(F.softmax(torch.where(sentence_batch['concept_score']>threshold,sentence_batch['concept_score'],0), dim=-1)[:,:,1]>0.5, 1, 0).to(torch.float).to(device)
            elif dataset=='drug':
                true_concepts = torch.argmax(sentence_batch['concept_score'], dim=-1).to(torch.float).to(device)
            elif dataset=='emo':
                true_concepts = torch.where(sentence_batch['concept_score'][:,:,1]>threshold, 1, 0).to(torch.float).to(device)
        elif pseudo_labeling in ['mistral', 'mixtral']:
            # LLM labelers already provide 0/1 concept labels.
            true_concepts = sentence_batch['concept_score'].to(torch.float).to(device)

        # Review embedding: BERT's last-layer [CLS] vector, or the precomputed Mixtral embedding.
        if backbone=='bert':
            emb = concept_encoder(sentence_batch['input_ids'].squeeze().to(torch.long).to(device),
                                    sentence_batch['attention_mask'].squeeze().to(torch.long).to(device),
                                    output_hidden_states=True).hidden_states[-1][:,0,:]
        elif backbone=='mixtral':
            emb = sentence_batch['embedded_reviews'].to(device)

        # Annotated concepts, only consumed by the supervised_* configurations.
        # NOTE(review): not built for 'depression', so supervised_* configs would fail there.
        if dataset=='emo':
            concept_labels = torch.cat([sentence_batch['joy'].unsqueeze(1),
                                        sentence_batch['trust'].unsqueeze(1),
                                        sentence_batch['sadness'].unsqueeze(1),
                                        sentence_batch['surprise'].unsqueeze(1)], axis=1).to(device)
        elif dataset=='cebab':
            concept_labels = torch.cat([sentence_batch['food'].unsqueeze(1), sentence_batch['ambiance'].unsqueeze(1),
                                        sentence_batch['service'].unsqueeze(1), sentence_batch['noise'].unsqueeze(1)], axis=1).to(device)
        elif dataset=='drug':
            concept_labels = torch.cat([sentence_batch['effectiveness'].unsqueeze(1),
                                        sentence_batch['sideEffects'].unsqueeze(1)], axis=1).to(device)

        if configuration in ['dcr', 'supervised_dcr']:
            c_emb, c_pred = cem(emb)
            y_pred = dcr(c_emb, c_pred)
            if configuration=='dcr':
                running_concept_loss += loss_form(c_pred, true_concepts)
            else:
                running_concept_loss += loss_form(c_pred, concept_labels)
            concept_preds = torch.cat([concept_preds, c_pred])
            if rules:
                local_explanations = dcr.explain(c_emb, c_pred, 'local',
                                                 concept_names=concept_names, class_names=class_names)
                pred_rules += local_explanations
                # NOTE(review): assumes input_ids is (batch, max_length); the bert path squeezes
                # input_ids, hinting at an extra singleton dim that would break this cat — confirm.
                original_sentences = torch.cat([original_sentences, sentence_batch['input_ids'].to(device)])
        elif configuration in ['dcl', 'supervised_dcl']:
            c_emb, c_pred = cem(emb)
            y_pred = dcl(c_emb, c_pred)
            if configuration=='dcl':
                running_concept_loss += loss_form_concepts(c_pred, true_concepts)
            else:
                running_concept_loss += loss_form_concepts(c_pred, concept_labels)
            concept_preds = torch.cat([concept_preds, c_pred])
            if rules:
                local_explanations = dcl.explain(c_emb, c_pred, 'local',
                                                 concept_names=concept_names, class_names=class_names)
                pred_rules += local_explanations
                original_sentences = torch.cat([original_sentences, sentence_batch['input_ids'].to(device)])
        elif configuration=='cem':
            c_emb, c_pred = cem(emb)
            y_pred = ff(c_emb.flatten(start_dim=1))
            running_concept_loss += loss_form(c_pred, true_concepts)
            concept_preds = torch.cat([concept_preds, c_pred])
        elif configuration=='cbm_fc':
            c_pred = ff_concept(emb)
            y_pred = fc_task(c_pred)
            running_concept_loss += loss_form(c_pred, true_concepts)
            concept_preds = torch.cat([concept_preds, c_pred])
        elif configuration in ['DTree', 'XGBoost']:
            # Only the neural concept predictor is evaluated here; the task prediction is a
            # zero placeholder.
            # NOTE(review): c_pred is never appended to concept_preds, so the returned concept
            # predictions are empty for these configurations — confirm this is intended.
            c_pred = ff_concept(emb)
            concept_loss = loss_form(c_pred, true_concepts)
            y_pred = torch.zeros(emb.shape[0], n_labels).to(device)
            task_loss = torch.Tensor([0]).to(device)
            running_task_loss += task_loss.item()
            running_concept_loss += concept_loss
            loss = concept_loss
        elif configuration=='cbm_ff':
            c_pred = ff_concept(emb)
            y_pred = ff_task(c_pred)
            running_concept_loss += loss_form(c_pred, true_concepts)
            concept_preds = torch.cat([concept_preds, c_pred])
        # The branches below take the pseudo-labels as the concepts themselves, so there is no
        # concept loss (reset to 0 each batch) and the reported concept predictions are the pseudo-labels.
        elif configuration=='modified_cem':
            c_emb = cem(emb, true_concepts)
            y_pred = dcr(c_emb, true_concepts)
            running_concept_loss = torch.Tensor([0])
            concept_preds = torch.cat([concept_preds, true_concepts])
            if rules:
                local_explanations = dcr.explain(c_emb, true_concepts, 'local',
                                                 concept_names=concept_names, class_names=class_names)
                pred_rules += local_explanations
                original_sentences = torch.cat([original_sentences, sentence_batch['input_ids'].to(device)])
        elif configuration=='cem_ct':
            c_emb = cem(emb, true_concepts)
            y_pred = ff(c_emb.flatten(start_dim=1))
            running_concept_loss = torch.Tensor([0])
            concept_preds = torch.cat([concept_preds, true_concepts])
        elif configuration=='modified_dcl':
            c_emb = cem(emb, true_concepts)
            y_pred = dcl(c_emb, true_concepts)
            # task_loss = loss_form(y_pred, target)
            running_concept_loss = torch.Tensor([0])
            concept_preds = torch.cat([concept_preds, true_concepts])
            if rules:
                local_explanations = dcl.explain(c_emb, true_concepts, 'local',
                                                 concept_names=concept_names, class_names=class_names)
                pred_rules += local_explanations
                original_sentences = torch.cat([original_sentences, sentence_batch['input_ids'].to(device)])
        elif configuration=='cbm_fc_ct':
            y_pred = fc_task(true_concepts)
            running_concept_loss = torch.Tensor([0])
            concept_preds = torch.cat([concept_preds, true_concepts])
        elif configuration=='cbm_ff_ct':
            y_pred = ff_task(true_concepts)
            running_concept_loss = torch.Tensor([0])
            concept_preds = torch.cat([concept_preds, true_concepts])
        elif configuration=='e2e':
            # No concept bottleneck; pseudo-labels are recorded only for reporting.
            y_pred = ff_task(emb)
            running_concept_loss = torch.Tensor([0])
            concept_preds = torch.cat([concept_preds, true_concepts])

        running_task_loss += loss_form(y_pred, target)
        task_preds = torch.cat([task_preds, y_pred])
        real_labels = torch.cat([real_labels, y])
    model.train()

    # Drop the dummy first row/element of every accumulator.
    if rules:
        return pred_rules, original_sentences[1:,:], real_labels[1:]
    else:
        return running_task_loss.item()/len(loaded_set), running_concept_loss.item()/len(loaded_set), task_preds[1:,:], concept_preds[1:,:], real_labels[1:]

In [None]:
if configuration not in ['DTree_ct', 'XGBoost_ct']:
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    if configuration in ['dcl', 'modified_dcl', 'supervised_dcl']:
        loss_form = torch.nn.BCEWithLogitsLoss()
        loss_form_concepts =  torch.nn.BCELoss()
    else:
        loss_form =  torch.nn.BCELoss()
    train_task_losses = []
    train_concept_losses = []
    val_task_losses = []
    val_concept_losses = []
    epochs = args.epochs
    scheduler = StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    model.train()
    # Main training loop: for each epoch, one optimisation pass over loaded_train,
    # then the mean training losses and the validation losses from evaluate(loaded_val)
    # are appended to the history lists. `configuration` selects which modules and
    # losses are used; the modules, optimizer, loss functions, threshold and
    # lambda_coeff all come from earlier cells.
    for epoch in range(epochs):
        # Per-batch losses summed over the epoch (averaged after the batch loop).
        running_task_loss = 0
        running_concept_loss = 0
        for sentence_batch in tqdm(loaded_train):

            optimizer.zero_grad()

            y = torch.Tensor(sentence_batch['labels'])
            # One-hot float targets over n_labels classes, consumed by loss_form for the task loss.
            target = torch.nn.functional.one_hot(y.to(torch.long), n_labels).to(device).to(torch.float)

            # Concept targets used by the non-supervised configurations, derived from the
            # batch's `concept_score` according to how the pseudo-labels were produced.
            if pseudo_labeling=='sbert':
                if dataset=='cebab':
                    # Zero the scores below `threshold`, softmax over the last axis and mark a
                    # concept active when the index-1 probability exceeds 0.5.
                    true_concepts = torch.where(F.softmax(torch.where(sentence_batch['concept_score']>threshold,sentence_batch['concept_score'],0), dim=-1)[:,:,1]>0.5, 1, 0).to(torch.float).to(device)
                elif dataset=='drug':
                    # Index of the highest score along the last axis.
                    true_concepts = torch.argmax(sentence_batch['concept_score'], dim=-1).to(torch.float).to(device) #torch.where(F.softmax(sentence_batch['concept_score'], dim=-1)[:,:,1]>threshold, 1, 0).to(torch.float).to(device)
                elif dataset=='emo':
                    # Active when the raw index-1 score exceeds `threshold`.
                    true_concepts = torch.where(sentence_batch['concept_score'][:,:,1]>threshold, 1, 0).to(torch.float).to(device)
            elif pseudo_labeling in ['mistral', 'mixtral']:
                # Scores are used as-is (cast to float) as concept targets.
                true_concepts = sentence_batch['concept_score'].to(torch.float).to(device)

            # Input embedding: the position-0 token vector of the last BERT hidden layer,
            # or the precomputed `embedded_reviews` for the mixtral backbone.
            if backbone=='bert':
                emb = concept_encoder(sentence_batch['input_ids'].squeeze().to(torch.long).to(device), 
                                        sentence_batch['attention_mask'].squeeze().to(torch.long).to(device), 
                                        output_hidden_states=True).hidden_states[-1][:,0,:]
            elif backbone=='mixtral':
                emb = sentence_batch['embedded_reviews'].to(device)

            # Ground-truth concept annotations stacked as columns. Computed for every
            # configuration, but only the supervised_* branches below consume them.
            if dataset=='emo':
                concept_labels = torch.cat([sentence_batch['joy'].unsqueeze(1), 
                                            sentence_batch['trust'].unsqueeze(1), 
                                            sentence_batch['sadness'].unsqueeze(1), 
                                            sentence_batch['surprise'].unsqueeze(1)], axis=1).to(device)
            elif dataset=='cebab':
                concept_labels = torch.cat([sentence_batch['food'].unsqueeze(1), sentence_batch['ambiance'].unsqueeze(1), 
                                            sentence_batch['service'].unsqueeze(1), sentence_batch['noise'].unsqueeze(1)], axis=1).to(device)
            elif dataset=='drug':
                concept_labels = torch.cat([sentence_batch['effectiveness'].unsqueeze(1), 
                                            sentence_batch['sideEffects'].unsqueeze(1)], axis=1).to(device)               

            # DCR: cem predicts concept embeddings + concept scores, dcr predicts the task.
            # Concept loss is against pseudo-labels (dcr) or gold labels (supervised_dcr);
            # total loss = concept_loss + lambda_coeff * task_loss.
            if configuration in ['dcr','supervised_dcr']:
                c_emb, c_pred = cem(emb)
                y_pred = dcr(c_emb, c_pred)
                if configuration=='dcr':
                    concept_loss = loss_form(c_pred, true_concepts)
                else:
                    concept_loss = loss_form(c_pred, concept_labels)
                task_loss = loss_form(y_pred, target)    
                running_task_loss += task_loss.item()
                running_concept_loss += concept_loss.item()
                loss = concept_loss + lambda_coeff * task_loss
            # NOTE(review): this starts a new `if` chain rather than continuing with `elif`.
            # The configuration names are disjoint, so at most one branch runs per batch,
            # but `elif` would make that explicit.
            # DCL: same structure as DCR, with dcl also returning attention weights/biases,
            # which get tiny L1 (weights) and squared (biases) penalties.
            if configuration in ['dcl', 'supervised_dcl']:
                c_emb, c_pred = cem(emb)
                y_pred, weight_attn, bias_attn = dcl(c_emb, c_pred, return_attn=True)
                if configuration=='dcl':
                    concept_loss = loss_form_concepts(c_pred, true_concepts)
                else:
                    concept_loss = loss_form_concepts(c_pred, concept_labels)
                task_loss = loss_form(y_pred, target)    
                running_task_loss += task_loss.item()
                running_concept_loss += concept_loss.item()
                #weight_attn_reg = dcl.entropy_reg(weight_attn)
                penalty = 1e-6
                loss = concept_loss + lambda_coeff * task_loss + penalty * torch.mean(weight_attn.abs()) + penalty * torch.mean(bias_attn.abs()**2)
            # CEM with a feed-forward task head on the flattened concept embeddings.
            elif configuration=='cem':
                c_emb, c_pred = cem(emb)
                y_pred = ff(c_emb.flatten(start_dim=1))
                concept_loss = loss_form(c_pred, true_concepts)
                task_loss = loss_form(y_pred, target)
                running_task_loss += task_loss.item()
                running_concept_loss += concept_loss.item()
                loss = concept_loss + lambda_coeff * task_loss
            # Concept bottleneck: ff_concept predicts concepts, fc_task maps concepts to the task.
            elif configuration=='cbm_fc':
                c_pred = ff_concept(emb)
                y_pred = fc_task(c_pred)
                concept_loss = loss_form(c_pred, true_concepts)
                task_loss = loss_form(y_pred, target) 
                running_task_loss += task_loss.item()
                running_concept_loss += concept_loss.item()
                loss = concept_loss + lambda_coeff*task_loss
            # Tree-based heads: only the concept predictor is trained here; the tree itself is
            # fitted in a later cell, so task_loss is a zero placeholder.
            elif configuration in ['DTree', 'XGBoost']:
                c_pred = ff_concept(emb)
                concept_loss = loss_form(c_pred, true_concepts)
                task_loss = torch.Tensor([0]).to(device)
                running_task_loss += task_loss.item()
                running_concept_loss += concept_loss.item()
                loss = concept_loss
            # Concept bottleneck with a feed-forward task head.
            elif configuration=='cbm_ff':
                c_pred = ff_concept(emb)
                y_pred = ff_task(c_pred)
                concept_loss = loss_form(c_pred, true_concepts)
                task_loss = loss_form(y_pred, target) 
                running_task_loss += task_loss.item()
                running_concept_loss += concept_loss.item()
                loss = concept_loss + lambda_coeff*task_loss
            # The remaining configurations feed the concept targets directly to the task model
            # (no concept prediction), so only the task loss is optimised.
            elif configuration=='modified_cem':
                c_emb = cem(emb, true_concepts)
                y_pred = dcr(c_emb, true_concepts)
                task_loss = loss_form(y_pred, target)    
                running_task_loss += task_loss.item()
                loss = task_loss
            elif configuration=='cem_ct':
                c_emb = cem(emb, true_concepts)
                y_pred = ff(c_emb.flatten(start_dim=1))
                task_loss = loss_form(y_pred, target)    
                running_task_loss += task_loss.item()
                loss = task_loss
            elif configuration=='modified_dcl':
                c_emb = cem(emb, true_concepts)
                y_pred, weight_attn, bias_attn = dcl(c_emb, true_concepts, return_attn=True)
                task_loss = loss_form(y_pred, target)    
                running_task_loss += task_loss.item()
                # NOTE(review): weight_attn_reg is computed but never used in `loss`.
                weight_attn_reg = dcl.entropy_reg(weight_attn)
                penalty = 1e-6
                loss = task_loss + penalty * torch.mean(weight_attn.abs()) + penalty * torch.mean(bias_attn.abs()**2)
            elif configuration=='cbm_fc_ct':
                y_pred = fc_task(true_concepts)
                task_loss = loss_form(y_pred, target)    
                running_task_loss += task_loss.item()
                loss = task_loss
            elif configuration=='cbm_ff_ct':
                y_pred = ff_task(true_concepts)
                task_loss = loss_form(y_pred, target)    
                running_task_loss += task_loss.item()
                loss = task_loss
            # End-to-end baseline: task head directly on the embedding, no concepts.
            elif configuration=='e2e':
                y_pred = ff_task(emb)
                task_loss = loss_form(y_pred, target)    
                running_task_loss += task_loss.item()
                loss = task_loss

            # NOTE(review): a configuration matching none of the branches above leaves `loss`
            # unbound here (NameError) — presumably excluded by the enclosing guard; confirm.
            loss.backward()
            optimizer.step()

        # Epoch means over batches; configurations without a concept loss record 0.
        train_task_losses.append(running_task_loss/len(loaded_train))
        train_concept_losses.append(running_concept_loss/len(loaded_train))
        # evaluate() returns five values; the first two are taken as the validation losses.
        val_task_loss, val_concept_loss, _, _, _ = evaluate(loaded_val)
        val_task_losses.append(val_task_loss)
        val_concept_losses.append(val_concept_loss)
        
    # Training vs. validation curves: task loss on the left panel, concept loss on the
    # right; the figure is also written to result_folder as training_plots.png.
    epoch_axis = range(1, epochs + 1)
    fig, ax = plt.subplots(1, 2)
    fig.set_size_inches(20, 5)
    curve_panels = (
        (ax[0], train_task_losses, val_task_losses, 'Task Loss'),
        (ax[1], train_concept_losses, val_concept_losses, 'Concept Loss'),
    )
    for panel, train_curve, val_curve, panel_title in curve_panels:
        panel.plot(epoch_axis, train_curve, label='Training')
        panel.plot(epoch_axis, val_curve, label='Validation')
        panel.grid()
        panel.set_xlabel("Epochs")
        panel.set_ylabel("Loss")
        panel.set_title(panel_title)
        panel.legend()
    plt.tight_layout()
    plt.savefig(result_folder + 'training_plots.png')
    plt.show()

    # Persist the per-epoch loss history to CSV (epochs are numbered from 0 here,
    # whereas the plot above labels them 1..epochs).
    csv_file = "training_information.csv"
    loss_history = {
        'epoch': range(epochs),
        'train_task_losses': train_task_losses,
        'val_task_losses': val_task_losses,
        'train_concept_losses': train_concept_losses,
        'val_concept_losses': val_concept_losses,
    }
    df = pd.DataFrame(loss_history)
    df.to_csv(result_folder + csv_file, index=False)


In [None]:
def _collect_concept_inputs(loader):
    """Run over ``loader`` once and gather concept vectors plus task labels.

    For non-``_ct`` configurations the concept vectors are the outputs of
    ``ff_concept`` applied to the input embedding (position-0 vector of the
    last ``concept_encoder`` hidden layer for the BERT backbone, the
    precomputed ``embedded_reviews`` otherwise). For ``_ct`` configurations
    they are the batch's ``concept_score`` values directly.

    Args:
        loader: iterable of batch dicts (e.g. ``loaded_train`` / ``loaded_test``).

    Returns:
        tuple(np.ndarray, np.ndarray): float32 concept matrix with one row per
        example, and the float32 task labels. Both are empty when the loader
        yields no batches.
    """
    concept_chunks, label_chunks = [], []
    # Pure inference: without no_grad every batch's autograd graph stayed referenced
    # by the growing concatenated tensor until the final .detach().
    with torch.no_grad():
        for sentence_batch in tqdm(loader):
            if 'ct' not in configuration:
                if backbone == 'bert':
                    emb = concept_encoder(
                        sentence_batch['input_ids'].squeeze().to(torch.long).to(device),
                        sentence_batch['attention_mask'].squeeze().to(torch.long).to(device),
                        output_hidden_states=True,
                    ).hidden_states[-1][:, 0, :]
                else:
                    emb = sentence_batch['embedded_reviews'].to(device)
                batch_concepts = ff_concept(emb)
            else:
                batch_concepts = sentence_batch['concept_score']
            concept_chunks.append(batch_concepts.detach().cpu().to(torch.float))
            label_chunks.append(torch.Tensor(sentence_batch['labels']))
    if not concept_chunks:
        return np.empty((0, n_concepts), dtype=np.float32), np.empty(0, dtype=np.float32)
    # One concatenation at the end instead of re-allocating a buffer on every batch.
    return torch.cat(concept_chunks).numpy(), torch.cat(label_chunks).numpy()


if configuration in ['DTree_ct', 'DTree', 'XGBoost_ct', 'XGBoost']:
    # Tree-based task heads on top of concepts. Non-``_ct`` variants reuse the
    # ``ff_concept`` predictor trained by the DTree/XGBoost branch of the training
    # loop; ``_ct`` variants use each batch's ``concept_score`` as the features.
    if backbone == 'bert':
        for p in concept_encoder.parameters():
            p.requires_grad = False

    # 1) Fit the tree model on training-set concepts.
    concept_preds, real_labels = _collect_concept_inputs(loaded_train)
    if 'DTree' in configuration:
        dtree.fit(concept_preds, real_labels)
    elif 'XGBoost' in configuration:
        dtrain = xgb.DMatrix(concept_preds, label=real_labels)
        dtree = xgb.train(params, dtrain, num_rounds)

    # 2) Predict the task label for the test set.
    concept_preds, y_true = _collect_concept_inputs(loaded_test)
    if 'DTree' in configuration:
        y_pred = dtree.predict(concept_preds)
    elif 'XGBoost' in configuration:
        dtest = xgb.DMatrix(concept_preds)
        # NOTE(review): Booster.predict returns class labels or probabilities depending on
        # params['objective'] (probabilities for e.g. 'binary:logistic'); the confusion
        # matrix in the task-results cell needs class labels — confirm the objective.
        y_pred = dtree.predict(dtest)

    # Test-set concept matrix, consumed by the concept-results cell.
    c_pred = torch.Tensor(concept_preds)


# Save models

In [None]:
# Quick check: is the current configuration one of the tree-based (DTree / XGBoost) variants?
any(tree_tag in configuration for tree_tag in ('DTree', 'XGBoost'))

In [None]:
# Persist the trained modules of the current configuration under result_folder/models/.
# File names must stay in sync with the "Load model" cell.
model_path = result_folder + 'models/'
os.makedirs(model_path, exist_ok=True)

if backbone == 'bert' and 'ct' not in configuration:
    # The fine-tuned encoder is needed by every non-_ct configuration on the BERT backbone.
    # (Tree configurations previously wrote ff_concept under this name and lost the encoder.)
    torch.save(concept_encoder, model_path + 'concept_encoder')

if configuration in ['dcr', 'supervised_dcr', 'modified_cem']:
    torch.save(cem, model_path + 'cem')
    torch.save(dcr, model_path + 'dcr')
elif configuration in ['cem', 'cem_ct']:
    torch.save(cem, model_path + 'cem')
    torch.save(ff, model_path + 'ff')
elif configuration in ['dcl', 'supervised_dcl', 'modified_dcl']:
    torch.save(cem, model_path + 'cem')
    torch.save(dcl, model_path + 'dcl')
elif configuration == 'cbm_fc':
    torch.save(ff_concept, model_path + 'ff_concept')
    torch.save(fc_task, model_path + 'fc_task')
elif configuration == 'cbm_ff':
    torch.save(ff_concept, model_path + 'ff_concept')
    torch.save(ff_task, model_path + 'ff_task')
elif configuration == 'cbm_fc_ct':
    torch.save(fc_task, model_path + 'fc_task')
elif configuration in ['cbm_ff_ct', 'e2e']:
    torch.save(ff_task, model_path + 'ff_task')
elif configuration in ['DTree', 'XGBoost', 'DTree_ct', 'XGBoost_ct']:
    joblib.dump(dtree, model_path + 'tree_based_model')
    if 'ct' not in configuration:
        # The concept predictor feeding the tree is trained in the main loop; without it
        # the saved tree cannot be applied to new inputs.
        torch.save(ff_concept, model_path + 'ff_concept')

# Task results (test-set)

In [None]:
# Tree-based configurations already produced y_true / y_pred / c_pred in the tree cell;
# every other configuration is evaluated on the test set here.
if configuration not in ['DTree', 'XGBoost', 'DTree_ct', 'XGBoost_ct']:
    _, _, y_pred, c_pred, y_true = evaluate(loaded_test)
    # Class scores -> predicted class index.
    y_pred = torch.argmax(y_pred, dim=-1).detach().cpu().numpy()

cm = confusion_matrix(y_true, y_pred)

# Plot confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
plt.show()

print(classification_report(y_true, y_pred))

# Same report as a dict -> one DataFrame row per class plus the summary rows.
report_dict = classification_report(y_true, y_pred, output_dict=True)
report_df = pd.DataFrame(report_dict).T
# Normalise the class row names to "0".."n_labels-1" (the raw keys presumably follow the
# label dtype, e.g. "1.0" for float labels). Works for any n_labels, not just 2 or 3.
report_df.index = [str(label) for label in range(n_labels)] + ["accuracy", "macro avg", "weighted avg"]

# Save the DataFrame to a CSV file
report_df.to_csv(result_folder + 'classification_report.csv', index=True)

# Concept results (test-set)

In [None]:
# Per-concept evaluation on the test set: confusion matrix + classification report for
# each concept, comparing ground-truth annotations with c_pred thresholded at 0.5.
if dataset != 'depression':
    # Ground-truth concept columns per dataset, in the order used by the concept heads.
    concept_columns = {
        'cebab': ['food', 'ambiance', 'service', 'noise'],
        'drug': ['effectiveness', 'sideEffects'],
        'emo': ['joy', 'trust', 'sadness', 'surprise'],
    }.get(dataset, [])

    concept_batches = []
    if concept_columns:
        for sentence_batch in loaded_test:
            concept_batches.append(
                torch.stack([sentence_batch[column] for column in concept_columns], dim=1)
            )
    if concept_batches:
        true_concepts = torch.cat(concept_batches).to(torch.float)
    else:
        true_concepts = torch.zeros(0, n_concepts)

    true_concepts_np = true_concepts.numpy()
    # Binarise every concept prediction once instead of once per metric.
    predicted_concepts = torch.where(c_pred > 0.5, 1, 0).detach().cpu().numpy()

    os.makedirs(result_folder + 'concepts', exist_ok=True)

    # squeeze=False keeps a 2-D axes array so indexing also works when n_concepts == 1.
    concept_fig, concept_axes = plt.subplots(1, n_concepts, squeeze=False)
    concept_fig.set_size_inches(20, 5)

    for i in range(n_concepts):
        panel = concept_axes[0, i]
        concept_true = true_concepts_np[:, i]
        concept_pred = predicted_concepts[:, i]

        cm = confusion_matrix(concept_true, concept_pred)
        sns.heatmap(cm, annot=True, cmap='Blues', fmt='g', ax=panel)
        # Heatmap axes are predicted vs. true labels (they were mislabelled "Epochs"/"Loss").
        panel.set_xlabel("Predicted labels")
        panel.set_ylabel("True labels")

        concept_report = classification_report(concept_true, concept_pred, output_dict=True)
        cr = pd.DataFrame(concept_report).T
        cr.index = ["0", "1", "accuracy", "macro avg", "weighted avg"]
        # result_folder already ends with '/', so no leading slash here.
        cr.to_csv(result_folder + f'concepts/{names[i]}_classification_report.csv', index=True)

        panel.set_title(names[i] + '; f1-score macro: ' + str(round(concept_report['macro avg']['f1-score'], 2)))

    plt.tight_layout()
    plt.savefig(result_folder + 'concepts/concept_confusion_matrices.png')
    plt.show()

# Explanations

In [None]:
# Optionally reload a previously trained run (result folder chosen by configuration/exp)
# instead of using the in-memory models. Disabled by default.
load_model = False
if load_model:
    configuration = 'dcr'
    exp = '4'
    result_folder = f'results/{dataset}/mixtral/{configuration}_E{exp}/'
    print(result_folder)
    model_path = result_folder + 'models/'

    def _load(name):
        """Load the module stored as ``model_path + name`` by the save cell.

        NOTE: torch.load unpickles whole nn.Module objects — only load trusted checkpoints.
        """
        return torch.load(model_path + name)

    # File names mirror the ones written by the "Save models" cell.
    if configuration in ['supervised_dcr', 'dcr', 'modified_cem']:
        cem = _load('cem')
        dcr = _load('dcr')
        model = nn.Sequential(cem, dcr)
    elif configuration == 'cem':
        cem = _load('cem')
        ff = _load('ff')
        # A concept_encoder file is only written for the BERT backbone.
        if os.path.exists(model_path + 'concept_encoder'):
            concept_encoder = _load('concept_encoder')
            model = nn.Sequential(concept_encoder, cem, ff)
        else:
            model = nn.Sequential(cem, ff)
    elif configuration == 'cbm_fc':
        # Previously loaded a non-existent 'fc' file; the save cell writes ff_concept + fc_task.
        ff_concept = _load('ff_concept')
        fc_task = _load('fc_task')
        model = nn.Sequential(ff_concept, fc_task)
    elif configuration in ['supervised_dcl', 'dcl', 'modified_dcl']:
        cem = _load('cem')
        dcl = _load('dcl')
        model = nn.Sequential(cem, dcl)
    elif configuration == 'cbm_ff':
        # Previously loaded a non-existent 'fc' file; the save cell writes ff_concept + ff_task.
        ff_concept = _load('ff_concept')
        ff_task = _load('ff_task')
        model = nn.Sequential(ff_concept, ff_task)
    elif configuration == 'cbm_fc_bert':
        concept_encoder = _load('concept_encoder')
        ff_concept = _load('ff_concept')
        fc_task = _load('fc_task')
        model = nn.Sequential(concept_encoder, ff_concept, fc_task)
    elif configuration == 'cbm_ff_bert':
        ff_concept = _load('ff_concept')
        concept_encoder = _load('concept_encoder')
        ff_task = _load('ff_task')
        model = nn.Sequential(concept_encoder, ff_concept, ff_task)

    # DCL configurations pair a logits-based task loss with a probability-based concept loss.
    if configuration in ['dcl', 'modified_dcl', 'supervised_dcl']:
        loss_form = torch.nn.BCEWithLogitsLoss()
        loss_form_concepts = torch.nn.BCELoss()
    else:
        loss_form = torch.nn.BCELoss()

In [None]:
# Explanation mode of evaluate (second argument True): per-example rule dicts,
# the corresponding input token ids, and the true labels.
explanation_output = evaluate(loaded_test, True)
rules, original, lab = explanation_output

In [None]:
def _format_rule(explanation, round_weights):
    """Return a rule explanation ready to be stored in the rules table.

    Args:
        explanation: the ``'explanation'`` entry of a rule returned by evaluate.
        round_weights: when True (DCL configurations), a non-empty explanation is
            treated as a mapping and each value is rounded to two decimals; an
            empty-string explanation is kept as ``''``.

    Returns:
        The explanation unchanged when ``round_weights`` is False, otherwise ``''``
        or a dict of rounded floats.
    """
    if not round_weights:
        return explanation
    if explanation == '':
        return ''
    return {key: round(float(value), 2) for key, value in explanation.items()}


# One record per test example; the frame is built once at the end instead of being grown
# with pd.concat inside the loop (quadratic, and concatenating onto an empty frame is
# deprecated in recent pandas). This also no longer overwrites the global `df`.
rule_records = [
    {
        'text': tokenizer.decode(tokens, skip_special_tokens=True),
        'true': int(label),
        'prediction': rule['class'],
        'rule': _format_rule(rule['explanation'], 'dcl' in configuration),
    }
    for rule, tokens, label in zip(rules, original, lab)
]
rules_to_store = pd.DataFrame(rule_records, columns=['text', 'true', 'prediction', 'rule'])

In [None]:
# Display the assembled explanation table (one row per test example).
rules_to_store

In [None]:
# Persist the explanation table next to the other artefacts of this run.
rules_to_store.to_csv(f"{result_folder}rules.csv")