In [1]:
import sys

# Make the project packages importable from this notebook. The relative entry
# "../../" points two directory levels above the current working directory,
# and it must be appended before the `tools` / `nn` imports below run.
path_append = "../../" # Two directory levels up from the current working directory.
sys.path.append(path_append) 

import torch
from tools.config.ml_config import MLConfig
from tools.config.data_config import DataConfig
from nn.utils.init_layer import set_random_seed
# Fix the seed up front (project helper, called with seed 0).
# NOTE(review): presumably seeds every RNG the project uses -- confirm in nn.utils.init_layer.
set_random_seed(0)

In [2]:
from torchvision import datasets, transforms

# NOTE(review): `attribute_indices` is not referenced by any cell visible in this
# notebook; it looks like a leftover from another dataset -- confirm before removing.
attribute_indices = torch.tensor([4, 9])

# Preprocessing applied to every image: convert to a tensor, then normalise the
# single channel with mean 0.5 and std 0.5.
transform_mnist = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Build the train split first, then the test split, from the same root folder
# (downloaded on first use).
trainset, testset = (
    datasets.FashionMNIST(
        root='../data/fmnist',
        train=is_train_split,
        transform=transform_mnist,
        download=True,
    )
    for is_train_split in (True, False)
)

# Number of target classes, taken from the dataset's class-name list.
num_classes = len(trainset.classes)


In [3]:
# Model / optimisation / loss settings for this run.
ml_config = MLConfig(model_name='resnet18')
ml_config.model.num_layers = 4
ml_config.model.d_model = 128
ml_config.optimization.learning_rate = 2e-4
ml_config.training.error_function = 'mae'

# Shape of a single observation, read from the image tensor of the first sample.
obs_shape = trainset[0][0].shape

print(f"Input shape: {obs_shape}")
print(f"Label shape: {num_classes}")
print(f"Total number of samples in trainset: {len(trainset)}")

# For each class id 0..num_classes-1, record the index of the first training
# sample carrying that label. This reads `trainset.targets`, so it has to run
# while `trainset` is still the raw torchvision dataset (a later cell rebinds
# the name to a wrapper object).
indices = [
    torch.where(trainset.targets == class_id)[0][0].item()
    for class_id in range(num_classes)
]

# Dataset description handed to the learner; the per-class sample indices are
# passed as `show_image_indices`.
data_config = DataConfig(
    dataset_name='fashion_mnist',
    task_type='multi_class_classification',
    obs_shape=obs_shape,
    label_size=num_classes,
    show_image_indices=indices,
)

Input shape: torch.Size([1, 28, 28])
Label shape: 10
Total number of samples in trainset: 60000


In [4]:
# Custom dataset wrapper for the Fashion-MNIST dataset
class FashionMnist(torch.utils.data.Dataset):
    """Wrap an indexable ``(image, label)`` dataset so labels come back as tensors.

    Each item is returned as ``(X, y)`` where ``X`` is whatever the wrapped
    dataset yields for the image and ``y`` is a 1-D ``torch.long`` tensor; a
    scalar class index becomes a tensor of shape ``(1,)``.

    Args:
        dataset: Any object supporting ``dataset[index] -> (X, y)`` and
            ``len(dataset)``. ``y`` may be a Python int or a tensor.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        """Return the image and its label tensor for ``index``."""
        X, y = self.dataset[index]  # Image and raw label from the wrapped dataset
        # `as_tensor` accepts both plain ints and existing tensors without the
        # copy-construct warning that `torch.tensor(tensor)` raises. Flattening
        # with reshape(-1) (instead of always adding a dimension) keeps the label
        # at shape (1,) even when the wrapped dataset already returns a (1,)
        # tensor, e.g. if an already-wrapped dataset is wrapped again.
        y = torch.as_tensor(y, dtype=torch.long).reshape(-1)
        return X, y

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)
        
# Rebind both split names to their wrapped versions (train first, then test).
# Because the names are reused, running this cell again wraps the
# already-wrapped objects once more.
trainset, testset = FashionMnist(trainset), FashionMnist(testset)

In [5]:
from causal_learning import CausalLearning

# Run on the GPU when torch reports CUDA as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the learner from the model/training config, the data config and the
# chosen device. Console printing is switched on and wandb logging is off.
causal_learning = CausalLearning(
    ml_config,
    data_config,
    device,
    use_print=True,
    use_wandb=False,
)



Trainer Name: causal_trainer


[1mModelConfig Parameters:[0m


Unnamed: 0,d_model,dropout,model_name,num_layers,use_seq_input
0,128,0.05,resnet18,4,False


[1mTrainConfig Parameters:[0m


Unnamed: 0,batch_size,error_function,max_seq_len,min_seq_len,num_epoch
0,64,mae,,,100


[1mOptimConfig Parameters:[0m


Unnamed: 0,clip_grad_range,decay_rate_100k,learning_rate,max_grad_norm,scheduler_type
0,,0.05,0.0002,1.0,exponential


[1mDataConfig Parameters:[0m


Unnamed: 0,dataset_name,task_type,obs_shape,label_size,explain_size,show_image_indices
0,fashion_mnist,multi_class_classification,"(1, 28, 28)",10,,"[1, 16, 5, 3, 19, 8, 18, 6, 23, 0]"








In [6]:
# Start training on the wrapped train split, passing the wrapped test split as
# the second argument; progress and metric lines are printed in the cell output.
causal_learning.train(trainset, testset)

[15/100][245/937][Time 5.44]
Unified LR across all optimizers: 0.00013030758498070239
CCNet:  Three Resnet18
Inf: 0.0090	Gen: 0.1102	Rec: 0.1100	E: 0.0092	R: 0.0087	P: 0.2112

accuracy: 0.8984
precision: 0.8878
recall: 0.8756
f1_score: 0.8788

accuracy: 0.9180
precision: 0.9124
recall: 0.9110
f1_score: 0.9107

