In [None]:
import sys

# Extend the module search path with the parent directory. The entry is a
# relative path, so it is resolved against the kernel's working directory.
# The project imports below (tools, nn) presumably live there — confirm
# against the repository layout.
path_append = "../" # Go up one directory from where you are.
sys.path.append(path_append) 

import torch
from tools.setting.ml_params import MLParameters
from tools.setting.data_config import DataConfig
from nn.utils.init import set_random_seed
# Seed once, before any data or model code runs. set_random_seed is a project
# helper not shown here; presumably it seeds the RNGs used in training — verify.
set_random_seed(0)

In [None]:
from torchvision import datasets, transforms

# Side length (in pixels) handed to Resize below.
n_img_sz = 28
# One-hot positions that the Mnist wrapper defined later in this notebook keeps
# as the label vector.
attribute_indices = torch.tensor([4, 9])

# Preprocessing applied to every image: resize, convert to a tensor, then
# normalise the single channel with mean 0.5 and std 0.5.
mnist_steps = [
    transforms.Resize(n_img_sz),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform_mnist = transforms.Compose(mnist_steps)

# Both splits share the same root directory and preprocessing. download=False
# means the files are expected to already be present under the root.
mnist_kwargs = dict(root='../data/mnist', transform=transform_mnist, download=False)
trainset = datasets.MNIST(train=True, **mnist_kwargs)
testset = datasets.MNIST(train=False, **mnist_kwargs)


In [None]:
# Describe the data for the trainer: single-channel images at the resolution
# produced by the Resize transform, trained as a multi-label classification task.
# NOTE(review): label_size=5 does not match the two entries in attribute_indices
# that the Mnist wrapper returns as the label vector — confirm the intended value.
data_config = DataConfig(
    dataset_name='mnist',
    task_type='multi_label_classification',
    obs_shape=[1, n_img_sz, n_img_sz],
    label_size=5,
    show_image_indices=[0, 1, 2, 3],
)

# Create an MLParameters object and select the core / encoder models by name.
ml_params = MLParameters()
ml_params.core_model_name = 'gpt'
ml_params.encoder_model_name = 'stylegan'

# Peek at the first sample as a quick sanity check of what the dataset yields.
first_data = trainset[0]
X, y = first_data

print(f"Input shape: {X.shape}")
# y is the label itself at this point (not a shape), so label it accordingly.
print(f"Label: {y}")

print(f"Total number of samples in trainset: {len(trainset)}")

In [None]:
# Custom dataset wrapper that turns MNIST digit labels into a multi-label target
class Mnist(torch.utils.data.Dataset):
    """Wrap an (image, digit) dataset and expose selected one-hot entries as the label.

    Args:
        dataset: Indexable, sized dataset whose items are ``(image, label)``
            pairs, where ``label`` is an integer usable as an index into a
            vector of length ``num_classes``.
        indices: Optional one-hot positions to keep as the label vector. When
            omitted, the notebook-level ``attribute_indices`` tensor is used, so
            ``Mnist(dataset)`` behaves exactly as before.
        num_classes: Length of the intermediate one-hot vector (10 for MNIST).
    """

    def __init__(self, dataset, indices=None, num_classes=10):
        self.dataset = dataset
        self.num_classes = num_classes
        if indices is None:
            # Backward-compatible default: fall back to the notebook global.
            self.attribute_indices = attribute_indices
        else:
            self.attribute_indices = torch.as_tensor(indices, dtype=torch.long)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        ``target`` is a float tensor with one entry per selected position: 1 if
        the sample's digit equals that position, otherwise 0.
        """
        X, y = self.dataset[index]  # Image and integer label of the wrapped dataset
        y_one_hot = torch.zeros(self.num_classes)
        y_one_hot[y] = 1  # Convert the scalar label to a one-hot vector
        # Keep only the selected positions of the one-hot vector.
        return X, y_one_hot[self.attribute_indices]

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)
        
# Wrap only the raw torchvision datasets. Without this guard, re-running the
# cell would wrap an already wrapped dataset, and the outer wrapper would then
# receive a tensor label instead of an integer one.
if isinstance(trainset, datasets.MNIST):
    trainset = Mnist(trainset)
if isinstance(testset, datasets.MNIST):
    testset = Mnist(testset)

In [None]:
from trainer_hub import TrainerHub

# Run on the GPU when torch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the trainer from the model parameters, the data description and the
# chosen device, with printing switched on and wandb switched off.
trainer_hub = TrainerHub(
    ml_params,
    data_config,
    device,
    use_print=True,
    use_wandb=False,
)

In [None]:
# Start training, passing the wrapped train split and the wrapped test split.
trainer_hub.train(trainset, testset)