Author:
        
        PARK, JunHo, junho@ccnets.org

        
        KIM, JeongYoong, jeongyoong@ccnets.org
        
    COPYRIGHT (c) 2024. CCNets. All Rights reserved.

In [None]:
import sys

# Put the parent directory on the import path so the project imports used in
# this notebook (e.g. `nn.utils.init` below) can be resolved. `path_append` is
# also reused later as the prefix of the dataset root directory.
path_append = "../" # Go up one directory from where you are.
sys.path.append(path_append) 

# NOTE(review): presumably seeds the random number generators for
# reproducibility -- confirm against nn.utils.init.
from nn.utils.init import set_random_seed
set_random_seed(0)

# Silence every Python warning for the rest of the session.
import warnings
warnings.filterwarnings("ignore")

In [None]:
import torch
import torchvision.datasets as dset
from torchvision import transforms

# import albumentations
# Side length (in pixels) of the square images fed to the models.
n_img_sz = 64


def _load_celeba_split(split):
    """Return the CelebA `split` ("train" or "test") with the shared preprocessing.

    The dataset is read from (and downloaded to, if missing) the directory
    `path_append + '../data/celeba'`.
    """
    preprocessing = transforms.Compose([
        # An int size rescales the shorter edge to `n_img_sz`, keeping the aspect ratio.
        transforms.Resize(n_img_sz),
        # Cut the centred `n_img_sz` x `n_img_sz` square out of the resized image.
        transforms.CenterCrop(n_img_sz),
        # PIL image -> float tensor with values in [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel, i.e. rescale [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return dset.CelebA(
        root=path_append + '../data/celeba',
        split=split,
        transform=preprocessing,
        download=True,
    )


# Training and test splits share exactly the same preprocessing.
trainset = _load_celeba_split("train")
testset = _load_celeba_split("test")

In [None]:
# Restrict the experiment to the first 40,000 training and 10,000 test images.
NUM_TRAIN_SAMPLES = 40000
NUM_TEST_SAMPLES = 10000
trainset = torch.utils.data.Subset(trainset, range(NUM_TRAIN_SAMPLES))
testset = torch.utils.data.Subset(testset, range(NUM_TEST_SAMPLES))

In [None]:
# Attribute names; the position of a name in this list is the column index
# used to pick that attribute out of a label vector.
label_list = [
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
    'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',
    'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses',
    'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes',
    'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose',
    'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling',
    'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat',
    'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young',
]

# Column indices (int64 tensors) of the two attribute pairs compared in the experiments.
male_and_smiling_attributes = torch.tensor(
    [label_list.index(attribute) for attribute in ('Male', 'Smiling')]
)
eyeglasses_and_young_attributes = torch.tensor(
    [label_list.index(attribute) for attribute in ('Eyeglasses', 'Young')]
)

In [None]:
# Custom dataset class for CelebA dataset
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class CausalModelDataset(Dataset):
    """View of an ``(input, label_vector)`` dataset that keeps only chosen label columns.

    Args:
        dataset: Indexable dataset whose items are ``(X, y)`` pairs, where ``y``
            is a tensor of labels.
        selected_attributes: Integer index tensor naming the entries of ``y`` to keep.
    """

    def __init__(self, dataset, selected_attributes):
        self.dataset = dataset
        self.selected_attributes = selected_attributes

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(X, y_selected)`` for the item at `index`."""
        sample, full_labels = self.dataset[index]
        # Select along a temporary leading batch axis, then drop that axis again.
        picked = full_labels.unsqueeze(0).index_select(1, self.selected_attributes)
        return sample, picked.squeeze(0)

class EncodingDataset(Dataset):
    """Dataset of precomputed ``causal_model.explain`` encodings and selected labels.

    On construction the whole of `dataset` is passed through
    ``causal_model.explain`` in batches of 256 (in order, without gradients);
    the resulting encodings and the chosen label columns are kept on the CPU.

    Args:
        dataset: Dataset yielding ``(images, labels)`` pairs that a ``DataLoader``
            can batch; ``labels`` must be indexable as ``labels[:, attributes]``.
        attributes: Column indices of the labels to keep.
        causal_model: Object exposing a ``device`` attribute and an
            ``explain(images)`` method that returns a tensor.
    """

    def __init__(self, dataset, attributes, causal_model):
        self.dataset = dataset
        self.attributes = attributes
        self.causal_model = causal_model

        encoding_chunks, label_chunks = [], []
        loader = DataLoader(dataset=dataset, batch_size=256, shuffle=False, drop_last=False)
        with torch.no_grad():
            for batch_images, batch_labels in loader:
                on_device = batch_images.to(causal_model.device)
                # Move results off the model's device so the cache lives in host memory.
                encoding_chunks.append(causal_model.explain(on_device).detach().cpu())
                label_chunks.append(batch_labels[:, attributes])

        self.encodings = torch.cat(encoding_chunks, dim=0)
        self.labels = torch.cat(label_chunks, dim=0)

    def __len__(self):
        """Return the number of items in the source dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the cached ``(encoding, selected_labels)`` pair at `index`."""
        return self.encodings[index], self.labels[index]

In [None]:
from tools.setting.ml_params import MLParameters
from tools.setting.data_config import DataConfig
from trainer_hub import TrainerHub
# Every experiment predicts a pair of binary attributes.
num_classes = 2
data_config = DataConfig(
    dataset_name='celebA',
    task_type='multi_label_classification',
    obs_shape=[3, n_img_sz, n_img_sz],
    label_size=num_classes,
)

# Build the parameter set for the 'resnet' CCNet variant, then override
# the fields this experiment cares about.
ml_params = MLParameters(ccnet_network='resnet')
ml_params.training.num_epoch = 1
ml_params.model.ccnet_config.num_layers = 4
ml_params.algorithm.reset_pretrained = True

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
from nn.tabnet import TabNet 
from tools.setting.ml_params import ModelConfig

class AttributeClassifier(torch.nn.Module):
    """Multi-label classifier: Linear -> ReLU -> TabNet -> ReLU -> Linear.

    ``forward`` returns raw logits, not probabilities. Both consumers in this
    file already expect logits: training uses
    ``binary_cross_entropy_with_logits`` and evaluation applies
    ``torch.sigmoid(...).round()`` to the model output. Applying a sigmoid here
    as well would squash the values twice, which at evaluation time pushes
    every score above 0.5 so that (almost) all predictions round to 1.

    Args:
        input_size: Width of the input feature vectors.
        output_size: Number of attributes to predict (one logit each).
        num_layers: Value stored in the TabNet config's ``num_layers`` field.
        hidden_size: Width of the hidden representation; also stored in the
            TabNet config's ``d_model`` field.
    """

    def __init__(self, input_size, output_size, num_layers=3, hidden_size=256):
        super().__init__()

        model_config = ModelConfig('tabnet')
        model_config.num_layers = num_layers
        model_config.d_model = hidden_size

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        self.layers = torch.nn.Sequential(
            # Project the input features to the hidden width.
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.ReLU(),
            # The TabNet block must keep the feature width at `hidden_size`,
            # since the output layer below consumes `hidden_size` features.
            TabNet(model_config),
            torch.nn.ReLU(),
            # One logit per attribute.
            torch.nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Return the per-attribute logits for the batch `x`."""
        return self.layers(x)

In [None]:
# Function to train classifier
# Per-step decay factor: after ITERATION_100K scheduler steps the learning
# rate has shrunk to DECAY_RATE (1%) of its starting value.
DECAY_RATE = 0.01
ITERATION_100K = 100000
gamma = DECAY_RATE ** (1 / ITERATION_100K)


def train_classifier(model, trainset, num_epochs=5, gamma=gamma):
    """Train `model` in place on `trainset` and print the final learning rate.

    Uses Adam (initial lr 2e-4) with a ``StepLR`` scheduler that multiplies the
    learning rate by `gamma` after every batch. The loss is
    ``binary_cross_entropy_with_logits``, so the model's outputs are treated
    as logits. Batches are moved to the module-level `device`.

    Args:
        model: Module to optimise; it is switched to training mode.
        trainset: Dataset yielding ``(data, labels)`` pairs.
        num_epochs: Number of full passes over `trainset`.
        gamma: Multiplicative learning-rate decay applied after each batch.
    """
    adam = torch.optim.Adam(model.parameters(), lr=2e-4)
    lr_decay = torch.optim.lr_scheduler.StepLR(adam, step_size=1, gamma=gamma)
    bce_with_logits = torch.nn.functional.binary_cross_entropy_with_logits

    model.train()
    for _ in range(num_epochs):
        # A fresh shuffling loader is built for every epoch.
        for batch, targets in DataLoader(trainset, batch_size=64, shuffle=True):
            batch = batch.to(device)
            targets = targets.to(device).float()
            adam.zero_grad()
            bce_with_logits(model(batch), targets).backward()
            adam.step()
            lr_decay.step()
    print("Learning rate: ", adam.param_groups[0]['lr'])

In [None]:
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, classification_report, f1_score
import matplotlib.pyplot as plt

# Function to evaluate classifier
def test_classifier(model, dataset, device):
    """Evaluate `model` on `dataset` and return ``(accuracy, weighted_f1)``.

    The model is put in eval mode and run without gradients. Its outputs are
    passed through a sigmoid and rounded to obtain 0/1 predictions, which are
    then scored against the float labels with scikit-learn's
    ``accuracy_score`` and ``f1_score(average='weighted')``.

    Args:
        model: Module mapping a batch of inputs to per-attribute scores.
        dataset: Dataset yielding ``(data, labels)`` pairs.
        device: Device the batches are moved to before the forward pass.
    """
    model.eval()
    prediction_chunks, label_chunks = [], []

    # Batches are drawn in shuffled order; predictions and labels are
    # collected together, so they stay aligned.
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device).float()
            rounded = torch.sigmoid(model(inputs)).round()
            prediction_chunks.append(rounded.cpu())
            label_chunks.append(targets.cpu())

    y_pred = torch.cat(prediction_chunks)
    y_true = torch.cat(label_chunks)

    accuracy = accuracy_score(y_true, y_pred)
    weighted_f1 = f1_score(y_true, y_pred, average='weighted')
    return accuracy, weighted_f1

In [None]:
# Function to plot accuracy
def plot_accuracy(ax, epochs, selected_results_dict, none_selected_results_dict):
    """Redraw `ax` with the accuracy curves of both classifiers.

    Args:
        ax: Axes-like object to clear and draw on.
        epochs: X values shared by both curves.
        selected_results_dict: Results whose ``'accuracy'`` entry is the first curve.
        none_selected_results_dict: Results whose ``'accuracy'`` entry is the second curve.
    """
    curves = (
        (selected_results_dict, 'Selected Attributes Accuracy'),
        (none_selected_results_dict, 'None Selected Attributes Accuracy'),
    )

    # Start from an empty axes so repeated calls do not stack curves.
    ax.cla()
    for results, legend_label in curves:
        ax.plot(epochs, results['accuracy'], label=legend_label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Accuracy')
    ax.set_title('None Selected Attributes should be Accurate')
    ax.legend()

# Function to plot F1 score
def plot_f1_score(ax, epochs, selected_results_dict, none_selected_results_dict):
    """Redraw `ax` with the F1-score curves of both classifiers.

    Args:
        ax: Axes-like object to clear and draw on.
        epochs: X values shared by both curves.
        selected_results_dict: Results whose ``'f1_score'`` entry is the first curve.
        none_selected_results_dict: Results whose ``'f1_score'`` entry is the second curve.
    """
    curves = (
        (selected_results_dict, 'Selected Attributes F1 Score'),
        (none_selected_results_dict, 'None Selected Attributes F1 Score'),
    )

    # Start from an empty axes so repeated calls do not stack curves.
    ax.cla()
    for results, legend_label in curves:
        ax.plot(epochs, results['f1_score'], label=legend_label)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('F1 Score')
    ax.set_title('None Selected Attributes should be higher F1 Score')
    ax.legend()

In [None]:

def test_classifiers(epoch, axs, selected_classifier, none_selected_classifier, testset_selected, testset_none_selected, selected_results_dict, none_selected_results_dict):
    """Evaluate both classifiers, record their metrics, and refresh the two plots.

    Args:
        epoch: Zero-based index of the current outer epoch; also used as the
            random seed before each evaluation.
        axs: Pair of axes; ``axs[0]`` shows accuracy and ``axs[1]`` shows F1.
        selected_classifier: Classifier evaluated on `testset_selected`.
        none_selected_classifier: Classifier evaluated on `testset_none_selected`.
        testset_selected: Evaluation data for the first classifier.
        testset_none_selected: Evaluation data for the second classifier.
        selected_results_dict: Dict with ``'accuracy'`` and ``'f1_score'`` lists;
            one value is appended to each.
        none_selected_results_dict: Same, for the second classifier.
    """
    def _evaluate_and_record(classifier, evaluation_set, history):
        # Reseed before each evaluation, presumably so both classifiers see
        # the same shuffled batch order.
        set_random_seed(epoch)
        accuracy, f1 = test_classifier(classifier, evaluation_set, device)
        history['accuracy'].append(accuracy)
        history['f1_score'].append(f1)
        return accuracy, f1

    print(f"Testing causal classifier on selected attributes at epoch {epoch}...")
    selected_acc, selected_f1 = _evaluate_and_record(
        selected_classifier, testset_selected, selected_results_dict)
    print(f"Testing classifier on selected attributes accuracy {selected_acc}, f1 score {selected_f1}")

    print(f"Testing classifier on none selected attributes at epoch {epoch}...")
    none_selected_acc, none_selected_f1 = _evaluate_and_record(
        none_selected_classifier, testset_none_selected, none_selected_results_dict)
    print(f"Testing classifier on none selected attributes accuracy {none_selected_acc}, f1 score {none_selected_f1}")

    # One recorded point per epoch so far, plotted against 1-based epoch numbers.
    epochs = range(1, epoch + 2)
    for axis, redraw in zip((axs[0], axs[1]), (plot_accuracy, plot_f1_score)):
        redraw(axis, epochs, selected_results_dict, none_selected_results_dict)

In [None]:
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

def train_causal_model_and_classifiers(causal_model_selected_attributes, causal_model_none_selected_attributes, num_epoch = 20):
    """Train a causal model and, after each epoch, probe its encodings with two classifiers.

    The causal model is trained on the module-level `trainset`, labelled with
    `causal_model_selected_attributes`. After every outer epoch two fresh
    `AttributeClassifier`s are trained on the model's `explain` encodings: one
    to predict the selected attributes, the other to predict
    `causal_model_none_selected_attributes`. Both are then evaluated on
    encodings of the module-level `testset`, and the accuracy / F1 curves are
    redrawn.

    Args:
        causal_model_selected_attributes: Index tensor of the attributes the
            causal model is trained on (and the first classifier predicts).
        causal_model_none_selected_attributes: Index tensor of the attributes
            only the second classifier predicts.
        num_epoch: Number of outer epochs; also the number of classifier
            training epochs used in the final outer epoch.

    Relies on the module-level `ml_params`, `data_config`, `device`,
    `num_classes`, `trainset` and `testset`.
    """

    # Per-epoch metric histories, one pair of lists per classifier.
    selected_results_dict = {'accuracy': [], 'f1_score': []}
    none_selected_results_dict = {'accuracy': [], 'f1_score': []}

    # One figure with two side-by-side axes: accuracy (left) and F1 (right).
    plt.ion()  # Turn on interactive mode
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    
    # Replace the cell output with the current state of the figure.
    def display_plot():
        plt.tight_layout()
        # wait=True defers clearing until the new output is ready.
        clear_output(wait=True)
        display(fig)
        plt.pause(0.1)  # Pause to allow the plot to update
        
    # Build the trainer from the training and data configuration; progress is printed (print_interval=200).
    trainer_hub = TrainerHub(ml_params, data_config, device, use_print = True, print_interval=200)
    # NOTE(review): explain_size is read only after TrainerHub is constructed --
    # presumably the trainer fills it in on data_config; confirm.
    encoding_size = data_config.explain_size
    
    # Training data for the causal model: images with only the selected attribute columns as labels.
    causal_model_dataset = CausalModelDataset(trainset, causal_model_selected_attributes)

    for epoch in range(0, num_epoch):
        # Classifiers are re-created every epoch, so each one starts from fresh weights.
        selected_classifier = AttributeClassifier(encoding_size, num_classes).to(device)
        none_selected_classifier = AttributeClassifier(encoding_size, num_classes).to(device)

        print(f"Training causal model at epoch {epoch}...")
        # NOTE(review): presumably each call continues training the same model -- confirm in TrainerHub.
        trainer_hub.train(causal_model_dataset)
        causal_model = trainer_hub.ccnet
        
        # Train and evaluate classifiers on the explanation datasets.
        # Each EncodingDataset construction runs the causal model over the whole dataset.
        print("Training causal classifier on selected attributes...")
        trainset_selected_attributes = EncodingDataset(trainset, causal_model_selected_attributes, causal_model)
        testset_selected_attributes = EncodingDataset(testset, causal_model_selected_attributes, causal_model)
        
        # One classifier epoch per outer epoch, except the last outer epoch, which uses num_epoch epochs.
        num_epoch_for_classifier = 1 if epoch != num_epoch - 1 else num_epoch
        train_classifier(selected_classifier, trainset_selected_attributes, num_epochs=num_epoch_for_classifier)

        print("Training classifier on none selected attributes...")
        # Same encodings, but labelled with the attributes the causal model was not trained on.
        trainset_none_selected_attributes = EncodingDataset(trainset, causal_model_none_selected_attributes, causal_model)
        testset_none_selected_attributes = EncodingDataset(testset, causal_model_none_selected_attributes, causal_model)    
        train_classifier(none_selected_classifier, trainset_none_selected_attributes, num_epochs=num_epoch_for_classifier)

        # Evaluate both classifiers, append to the histories, and redraw the axes.
        test_classifiers(epoch, axs, 
                         selected_classifier, none_selected_classifier, 
                         testset_selected_attributes, testset_none_selected_attributes, 
                         selected_results_dict, none_selected_results_dict)
        display_plot()

    plt.ioff()  # Turn off interactive mode
    plt.show()

In [None]:
# Experiment 1: the causal model is trained on Male/Smiling; Eyeglasses/Young
# are the attributes it was not trained on.
train_causal_model_and_classifiers(
    causal_model_selected_attributes=male_and_smiling_attributes,
    causal_model_none_selected_attributes=eyeglasses_and_young_attributes,
)

In [14]:
# Experiment 2: roles swapped -- the causal model is trained on
# Eyeglasses/Young; Male/Smiling are the attributes it was not trained on.
train_causal_model_and_classifiers(
    causal_model_selected_attributes=eyeglasses_and_young_attributes,
    causal_model_none_selected_attributes=male_and_smiling_attributes,
)