In [8]:
# Configurations
import torch 
from datetime import datetime
import logging
from pathlib import Path 
import os 


In [9]:
# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA as device")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS as device")
else:
    # Neither accelerator is usable: report why MPS was rejected, then use the CPU.
    if torch.backends.mps.is_built():
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    else:
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    device = torch.device("cpu")
    print("Using CPU as device")

    
# torch.set_default_device(device)

Using MPS as device


In [10]:
# Configure Directory
# All working folders are siblings under the parent of the current working directory.
# Each one is created right after it is defined; creating one that already exists is a no-op.
project_dir = Path.cwd().parent

rawdata_dir = project_dir.joinpath("rawdata")
rawdata_dir.mkdir(parents=True, exist_ok=True)

data_dir = project_dir.joinpath("data")
data_dir.mkdir(parents=True, exist_ok=True)

model_dir = project_dir.joinpath("models")
model_dir.mkdir(parents=True, exist_ok=True)

log_dir = project_dir.joinpath("logs")
log_dir.mkdir(parents=True, exist_ok=True)

# Echo the resolved locations so the reader can see where files will be written.
print(f'project_dir: {project_dir}')
print(f'rawdata_dir: {rawdata_dir}')
print(f'data_dir: {data_dir}')
print(f'model_dir: {model_dir}')
print(f'log_dir: {log_dir}')

project_dir: /Users/ball/Documents/workspace
rawdata_dir: /Users/ball/Documents/workspace/rawdata
data_dir: /Users/ball/Documents/workspace/data
model_dir: /Users/ball/Documents/workspace/models
log_dir: /Users/ball/Documents/workspace/logs


In [25]:
# Configure Logger
# Records at INFO and above go to a fresh, timestamped file inside log_dir.

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = log_dir / f'log_{timestamp}.log'

logger = logging.getLogger('transformer_log')
logger.setLevel(logging.INFO)

# logging.getLogger returns the same logger object on every call, so re-running
# this cell would otherwise stack one more FileHandler per run: every record
# would be written to each earlier log file and the old file handles would stay
# open. Detach and close file handlers left over from previous runs first.
for stale_handler in list(logger.handlers):
    if isinstance(stale_handler, logging.FileHandler):
        logger.removeHandler(stale_handler)
        stale_handler.close()

file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)

In [None]:
# Run configuration: training hyperparameters and model dimensions.
# This cell only defines the values; how each one is consumed is decided by
# later cells, so the notes below are limited to what the numbers imply.
EPOCHS = 100
BATCH_SIZE = 37
IMG_SIZE = 200      # presumably the side length images are resized to -- confirm in the data pipeline
PATCH_SIZE = 5      # divides IMG_SIZE evenly (200 / 5 = 40)
IN_CHANNELS = 3
N_HEAD = 8
D_MODEL = 200       # divisible by N_HEAD (200 / 8 = 25)
FFN_HIDDEN = 128
MLP_HIDDEN = 400
N_LAYERS = 3
CLASS_NUM = 101
DROP_PROB = 0.1
INIT_LR = 0.01
NUM_WORKERS = 2
WARMUP_STEPS = 2500

# set_start_method raises RuntimeError("context has already been set") when it
# is called a second time in the same process, which happens whenever this cell
# is re-run in a live kernel. force=True makes the call safe to repeat.
torch.multiprocessing.set_start_method('spawn', force=True)

In [None]:
# Define learning rate scheduler.
# If you want to modify the logic of Scheduler, please modify this class

class LRScheduler:
    """Learning-rate schedule with a linear warm-up followed by inverse-sqrt decay.

    At step ``s`` the rate is::

        LR_scale * d_model ** -0.5 * min(s ** -0.5, s * warmup_steps ** -1.5)

    The second term of the ``min`` is the smaller one while ``s < warmup_steps``
    (the rate grows linearly with ``s``); from ``s == warmup_steps`` onward the
    first term takes over and the rate decays as ``s ** -0.5``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts that each
            hold an ``'lr'`` entry (the layout used by ``torch.optim`` optimizers).
        d_model: Value whose inverse square root scales the rate.
        warmup_steps: Step at which the schedule switches from ramp to decay.
            Must be positive; the formula raises it to a negative power.
        LR_scale: Constant multiplier applied to every computed rate.
    """

    def __init__(self, optimizer, d_model, warmup_steps, LR_scale=1):
        self.optimizer = optimizer
        self.step_count = 0
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.LR_scale = LR_scale
        # Step-independent part of the formula, computed once.
        self._d_model_factor = self.LR_scale * (self.d_model ** -0.5)

    def step(self):
        """Advance the step counter by one and write the new rate to the optimizer."""
        self.step_count += 1
        lr = self.calculate_learning_rate()
        # Apply the rate to every parameter group; writing only the first group
        # would leave any further groups at their constructor learning rate.
        for group in self.optimizer.param_groups:
            group['lr'] = lr

    def calculate_learning_rate(self):
        """Return the learning rate for the current ``step_count``.

        Returns 0.0 before the first ``step()`` call: that is the value of the
        warm-up ramp at step 0, and it avoids evaluating ``0 ** -0.5``, which
        raises ``ZeroDivisionError``.
        """
        if self.step_count == 0:
            return 0.0
        minimum_factor = min(self.step_count ** -0.5, self.step_count * self.warmup_steps ** -1.5)
        return self._d_model_factor * minimum_factor