In [1]:
# Configurations
import torch 
from datetime import datetime
import logging
from pathlib import Path 
import os 


In [3]:
# Configure device: CUDA, MPS, CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA as device")
else:
    # Check that MPS is available
    if not torch.backends.mps.is_available():
        if not torch.backends.mps.is_built():
            print("MPS not available because the current PyTorch install was not "
                  "built with MPS enabled.")
        else:
            print("MPS not available because the current MacOS version is not 12.3+ "
                  "and/or you do not have an MPS-enabled device on this machine.")
        device = torch.device("cpu")
        print("Using CPU as device")
    else:
        device = torch.device("mps")
        print("Using MPS as device")

# for mps 
device = torch.device("cpu")

    
torch.set_default_device(device)
print(f'{device}')

Using MPS as device
mps


In [10]:
# Configure Directory
project_dir = Path(os.getcwd()).parent
rawdata_dir = project_dir / "rawdata"
data_dir = project_dir / "data"
model_dir = project_dir / "models"
log_dir = project_dir / "logs"

rawdata_dir.mkdir(parents=True, exist_ok=True)
data_dir.mkdir(parents=True, exist_ok=True)
model_dir.mkdir(parents=True, exist_ok=True)
log_dir.mkdir(parents=True, exist_ok=True)

print(f'project_dir: {project_dir}')
print(f'rawdata_dir: {rawdata_dir}')
print(f'data_dir: {data_dir}')
print(f'model_dir: {model_dir}')
print(f'log_dir: {log_dir}')

project_dir: /Users/ball/Documents/workspace
rawdata_dir: /Users/ball/Documents/workspace/rawdata
data_dir: /Users/ball/Documents/workspace/data
model_dir: /Users/ball/Documents/workspace/models
log_dir: /Users/ball/Documents/workspace/logs


In [25]:
# Configure Logger 

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = log_dir / f'log_{timestamp}.log'

logger = logging.getLogger('transformer_log')
logger.setLevel(logging.INFO)

file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG)


In [None]:
# Define input image size 
# We are going to resize the original image

DATASET="mnist" 
# "mnist" or "cifar10". default is "mnist"
# if using "mnist", please change IN_CHANNELS=1
# if using "cifar10", please change IN_CHANNELS=3

EPOCHS = 100
BATCH_SIZE = 64
IMG_SIZE = 28
PATCH_SIZE = 4
IN_CHANNELS = 1
N_HEAD = 5
D_MODEL = 200
FFN_HIDDEN = 512 
MLP_HIDDEN = 512
N_LAYERS = 5
CLASS_NUM = 10
DROP_PROB = 0.1 
INIT_LR = 5e-5
NUM_WORKERS=2
WEIGHT_DECAY=1e-4
GRADIENT_CLIP = 1.0


LOAD_MODEL=False
LOADING_MODEL_NAME="model_20241202_234017_1"

logger.info(f'CONFIGURATION START')
logger.info(f'EPOCHS: {EPOCHS}')
logger.info(f'BATCH_SIZE: {BATCH_SIZE}')
logger.info(f'IMG_SIZE: {IMG_SIZE}')
logger.info(f'PATCH_SIZE: {PATCH_SIZE}')
logger.info(f'IN_CHANNELS: {IN_CHANNELS}')
logger.info(f'N_HEAD: {N_HEAD}')
logger.info(f'D_MODEL: {D_MODEL}')
logger.info(f'FFN_HIDDEN: {FFN_HIDDEN}')
logger.info(f'MLP_HIDDEN: {MLP_HIDDEN}')
logger.info(f'N_LAYERS: {N_LAYERS}') 
logger.info(f'CLASS_NUM: {CLASS_NUM}') 
logger.info(f'DROP_PROB: {DROP_PROB}') 
logger.info(f'INIT_LR: {INIT_LR}') 
logger.info(f'NUM_WORKERS: {NUM_WORKERS}')
logger.info(f'WEIGHT_DECAY: {WEIGHT_DECAY}')
logger.info(f'CONFIGURATION END')

# torch.multiprocessing.set_start_method('spawn')