In [1]:
import sys

# Relative path that goes up TWO directory levels from the notebook's working directory;
# the project-local packages imported below (`tools`, `nn`) are presumably found there.
path_append = "../../"
# Guard so that re-running this cell does not keep appending the same entry to sys.path.
if path_append not in sys.path:
    sys.path.append(path_append)

import torch

from tools.setting.ml_params import MLParameters
from tools.setting.data_config import DataConfig
from nn.utils.init import set_random_seed

# Fixed seed for run-to-run reproducibility. `set_random_seed` is project code (nn.utils.init);
# presumably it seeds torch and any other RNGs in use — TODO confirm.
SEED = 0
set_random_seed(SEED)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
from torchvision import datasets, transforms

# Target image side length; MNIST images are natively 28x28, so Resize is a no-op at this value.
n_img_sz = 28

# Image pipeline: resize -> tensor in [0, 1] -> normalize with mean 0.5 / std 0.5 (i.e. to [-1, 1]).
transform_mnist = transforms.Compose(
    [
        transforms.Resize(n_img_sz),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Both splits share the same root and transform; the train split is created (and downloaded) first.
mnist_root = '../data/mnist'
trainset, testset = (
    datasets.MNIST(root=mnist_root, train=is_train, transform=transform_mnist, download=True)
    for is_train in (True, False)
)
num_classes = len(trainset.classes)


In [3]:
# Model configuration: start from the ResNet-18 preset, then override depth and width.
ml_params = MLParameters(model_name = 'resnet18')
ml_params.model.ccnet_config.num_layers = 4
ml_params.model.ccnet_config.d_model = 128

# Shape of a single transformed image tensor (torch.Size([1, 28, 28]) for MNIST per the output).
obs_shape = trainset[0][0].shape

print(f"Input shape: {obs_shape}")
print(f"Label shape: {num_classes}")

print(f"Total number of samples in trainset: {len(trainset)}")

# A later cell rebinds `trainset` to the label-wrapping `Mnist` dataset, which stores the
# torchvision dataset in `.dataset` and has no `.targets`. Read targets from the underlying
# dataset so this cell still works when re-run after that cell.
train_targets = getattr(trainset, "dataset", trainset).targets

# Index of the first training sample of each class; passed to DataConfig as show_image_indices.
indices = [torch.where(train_targets == c)[0][0].item() for c in range(num_classes)]

data_config = DataConfig(
    dataset_name='mnist',
    task_type='multi_class_classification',
    obs_shape=obs_shape,
    label_size=num_classes,
    show_image_indices=indices,
)

Input shape: torch.Size([1, 28, 28])
Label shape: 10
Total number of samples in trainset: 60000


In [4]:
# Dataset wrapper for MNIST that returns labels as long tensors.
class Mnist(torch.utils.data.Dataset):
    """Wrap an indexable ``(image, label)`` dataset so labels come back as ``torch.long`` tensors.

    Each item is ``(X, y)`` where ``X`` is the wrapped dataset's image, returned unchanged, and
    ``y`` is the class index as a ``torch.long`` tensor of shape ``(1,)``. This is a class index,
    NOT a one-hot vector. No range check is performed on the label.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing that yields ``(X, y)``
            pairs whose label ``y`` is a single value (Python int, 0-d or 1-element tensor).
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        X, y = self.dataset[index]  # image and label at the requested index

        # as_tensor + reshape(1) yields shape (1,) whether y is a Python int or an already-wrapped
        # (1,) tensor, so wrapping twice does not nest the label into (1, 1); torch.tensor() on a
        # tensor would also emit a copy-construct warning. reshape raises if y has >1 element.
        y = torch.as_tensor(y, dtype=torch.long).reshape(1)
        return X, y

    def __len__(self):
        return len(self.dataset)  # same length as the wrapped dataset
        
# Wrap only the raw torchvision datasets so that re-running this cell does not nest
# wrappers (Mnist(Mnist(...))) around the already-wrapped datasets.
if isinstance(trainset, datasets.MNIST):
    trainset = Mnist(trainset)
if isinstance(testset, datasets.MNIST):
    testset = Mnist(testset)

In [5]:
from trainer_hub import TrainerHub

# Use CUDA when a GPU is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Construct the trainer from the model/training config and the data config on the chosen device.
# use_print=True / use_wandb=False are passed as-is; the wandb flag presumably disables
# Weights & Biases logging — see TrainerHub to confirm.
trainer_hub = TrainerHub(
    ml_params,
    data_config,
    device,
    use_print=True,
    use_wandb=False,
)



Trainer Name: causal_trainer


[1mModelParameters Parameters:[0m


Unnamed: 0,ccnet_config,ccnet_network,encoder_config,encoder_network
0,See details below,resnet18,,none


[3m
Detailed ccnet_config Configuration:[0m


Unnamed: 0,ccnet_config_model_name,ccnet_config_num_layers,ccnet_config_d_model,ccnet_config_dropout,ccnet_config_obs_shape
0,resnet18,4,128,0.05,"(1, 28, 28)"


[1mTrainingParameters Parameters:[0m


Unnamed: 0,batch_size,max_iters,max_seq_len,min_seq_len,num_epoch
0,64,100000,,,100


[1mOptimizationParameters Parameters:[0m


Unnamed: 0,clip_grad_range,decay_rate_100k,learning_rate,max_grad_norm,scheduler_type
0,,0.05,0.0002,1.0,exponential


[1mAlgorithmParameters Parameters:[0m


Unnamed: 0,enable_diffusion,error_function,reset_pretrained
0,False,mse,False


[1mDataConfig Parameters:[0m


Unnamed: 0,dataset_name,task_type,obs_shape,label_size,explain_size,explain_layer,state_size,show_image_indices
0,mnist,multi_class_classification,"(1, 28, 28)",10,64,tanh,,"[1, 3, 5, 7, 2, 0, 13, 15, 17, 4]"








In [6]:
# Train on the wrapped training split and pass the wrapped test split as the evaluation set;
# the captured output shows test accuracy/precision/recall/F1 being reported during training.
# NOTE(review): the printed config (num_epoch=100, max_iters=100000) makes this a long-running
# cell — consider checkpointing/caching so Restart & Run All stays practical.
trainer_hub.train(trainset, testset)

[0/100][700/937][Time 13.20]
Unified LR across all optimizers: 0.00019584377661232514
--------------------Training Metrics--------------------
CCNet:  Three Resnet18
Inf: 0.0935	Gen: 0.1558	Rec: 0.1451	E: 0.0142	R: 0.0082	P: 0.0480
--------------------Test Metrics------------------------
accuracy: 0.4414
precision: 0.3534
recall: 0.4468
f1_score: 0.3704

