# Pre-Training Code

## Import Libraries

In [None]:
from os import path, mkdir

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from tqdm.auto import tqdm

### Hugging Face Login

In [None]:
# Uncomment the line below when you need to log in to Hugging Face
#!huggingface-cli login

### Check GPU Availability

In [None]:
# Query GPU status via the nvidia-smi CLI (IPython `!` shell escape; only valid inside Jupyter/IPython).
!nvidia-smi

In [None]:
# Index of the CUDA device to train on (0~7). Ignored when CUDA is not available,
# in which case everything runs on the CPU.
DEVICE_NUM = 0
if torch.cuda.is_available():
    device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")

print(f"INFO: Using device - {device}")

## Load Datasets

In [None]:
from src.datasets import ImageNet1K, CIFAR100, CIFAR10, DatasetHolder, build_augmentation

In [None]:
DATA_ROOT = path.join(".", "data")


def _build_split_holder(dataset_cls):
    """Build a DatasetHolder for a downloadable dataset class rooted at DATA_ROOT.

    Only the train split gets the training augmentation. The holder is then passed
    through split_train_valid() and split_train_attack(), presumably carving valid and
    attack subsets out of the train split (see src.datasets).
    """
    augmentation = build_augmentation(dataset_cls.img_size)
    train_set = dataset_cls(root=DATA_ROOT, download=True, train=True, transform=augmentation)
    test_set = dataset_cls(root=DATA_ROOT, download=True, train=False)
    holder = DatasetHolder(train=train_set, test=test_set)
    return holder.split_train_valid().split_train_attack()


# ImageNet1K is given an explicit valid split and is never downloaded here
# (force_download=False), so only split_train_attack() is applied to it.
IMAGENETs = DatasetHolder(
    train=ImageNet1K(
        root=DATA_ROOT,
        force_download=False,
        train=True,
        transform=build_augmentation(ImageNet1K.img_size),
    ),
    valid=ImageNet1K(root=DATA_ROOT, force_download=False, valid=True),
    test=ImageNet1K(root=DATA_ROOT, force_download=False, train=False),
).split_train_attack()
print(f"INFO: Dataset loaded successfully - {IMAGENETs}")

CIFAR100s = _build_split_holder(CIFAR100)
print(f"INFO: Dataset loaded successfully - {CIFAR100s}")

CIFAR10s = _build_split_holder(CIFAR10)
print(f"INFO: Dataset loaded successfully - {CIFAR10s}")

In [None]:
# Dataset used for the rest of the notebook (one of IMAGENETs / CIFAR100s / CIFAR10s).
CHOSEN_DATASET = CIFAR10s

train_dataset = CHOSEN_DATASET.train
valid_dataset = CHOSEN_DATASET.valid
test_dataset = CHOSEN_DATASET.test
print(f"INFO: Dataset Size - {CHOSEN_DATASET}")

## DataLoader

In [None]:
# Set Batch Size as (train, valid, test).
# NOTE: previously each split's full length was used as its batch size, i.e. one
# whole-dataset batch per step. That is unlikely to fit in GPU memory for image models
# and leaves only a single optimizer step per epoch.
# The LR scheduler is stepped once per batch in the training loop, so its period must
# be expressed in batches, which depends on this value.
BATCH_SIZE = 512, 512, 512

In [None]:
MULTI_PROCESSING = True  # Set False if DataLoader is causing issues

from platform import system

# Fall back to in-process loading (num_workers=0) when multi-processing is switched off
# or when running on Windows, where worker-based loading is treated as unsupported here.
if not MULTI_PROCESSING or system() == "Windows":
    cpu_cores = 0
    print("INFO: Using DataLoader without multi-processing.")
else:
    import multiprocessing
    cpu_cores = multiprocessing.cpu_count()
    print(f"INFO: Number of CPU cores - {cpu_cores}")

# One worker process per CPU core (or none, see above); only the train split is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE[0],
    shuffle=True,
    num_workers=cpu_cores,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE[1],
    shuffle=False,
    num_workers=cpu_cores,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE[2],
    shuffle=False,
    num_workers=cpu_cores,
)

In [None]:
# Preview a grid of training samples, passing the chosen dataset's normalization settings.
# NOTE(review): torch.utils.data.DataLoader has no `show_sample_grid` method; this presumably
# relies on src.datasets adding it to DataLoader. Confirm, otherwise this cell raises AttributeError.
train_loader.show_sample_grid(**CHOSEN_DATASET.config.norm)

## Define Model

In [None]:
from src.models import (
    BaseModel,
    # Target Models
    ViTBase, ViTLarge, SwinTiny, SwinSmall,
    # Compare Models
    ConvNeXtTiny, ConvNeXtSmall, ResNet50
)

In [None]:
# Model to fine-tune; swap in any of the models imported above
# (ViTBase, ViTLarge, SwinTiny, SwinSmall, ConvNeXtTiny, ConvNeXtSmall, ResNet50).
TargetModel: BaseModel = SwinSmall

In [None]:
# Initialize Model (automatically loads ImageNet pretrained weights)
# dataset_name must be set before from_pretrained() is called. Presumably
# from_pretrained() reads it (e.g. to size the classification head); confirm in src.models.
TargetModel.dataset_name = CHOSEN_DATASET.train.name
model = TargetModel.from_pretrained()
model.to(device)

## Training Loop

In [None]:
def avg(lst):
    """Return the arithmetic mean of ``lst``, or ``0`` when it is empty.

    ``lst`` must support ``len()`` and iteration (e.g. a list of floats).
    """
    count = len(lst)
    if count == 0:
        return 0
    return sum(lst) / count

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = CHOSEN_DATASET.config.epoch
LEARNING_RATE = 1e-4, 1e-6  # (initial LR, minimum LR reached at the end of the cosine schedule)
WEIGHT_DECAY = 0.05

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE[0], weight_decay=WEIGHT_DECAY)

# The training loop calls scheduler.step() after every batch, so the cosine period must
# be measured in optimizer steps (batches), not epochs. With T_max=EPOCHS the LR would
# reach eta_min after only EPOCHS batches and then climb back up, cycling for the rest
# of training. max(1, ...) guards against a zero period.
TOTAL_STEPS = max(1, EPOCHS * len(train_loader))
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_STEPS, eta_min=LEARNING_RATE[1])

In [None]:
# Train for EPOCHS epochs, validating after each epoch. The model is always saved at the
# end: on normal completion, on Ctrl-C, and (before the error propagates) on any other exception.
train_length, valid_length = map(len, (train_loader, valid_loader))

try:
    epochs = tqdm(range(EPOCHS), desc="Running Epochs")
    with (tqdm(total=train_length, desc="Training") as train_progress,
          tqdm(total=valid_length, desc="Validation") as valid_progress):  # Set up Progress Bars

        for epoch in epochs:
            train_progress.reset(total=train_length)
            valid_progress.reset(total=valid_length)

            # Per-epoch running sums. Accuracy and loss are weighted by batch size, so a
            # smaller final batch does not count as much as a full one.
            train_correct, train_loss_sum, train_seen = 0, 0.0, 0
            val_correct, val_loss_sum, val_seen = 0, 0.0, 0

            # Training
            model.train()
            for i, (inputs, targets) in enumerate(train_loader):
                optimizer.zero_grad()
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)

                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                scheduler.step()  # Stepped once per batch, not once per epoch.

                n_batch = len(targets)
                train_seen += n_batch
                # criterion uses the default mean reduction; scale back to a per-sample sum.
                train_loss_sum += loss.item() * n_batch
                train_correct += (outputs.argmax(dim=1) == targets).sum().item()

                train_progress.update(1)
                print(f"\rEpoch [{epoch+1:4}/{EPOCHS:4}], Step [{i+1:4}/{train_length:4}], Acc: {train_correct / train_seen:.6%}, Loss: {train_loss_sum / train_seen:.6f}", end="")

            # Validation (no gradients; same criterion as training, for reporting only)
            model.eval()
            with torch.no_grad():
                for inputs, targets in valid_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = model(inputs)

                    n_batch = len(targets)
                    val_seen += n_batch
                    val_loss_sum += criterion(outputs, targets).item() * n_batch
                    val_correct += (outputs.argmax(dim=1) == targets).sum().item()
                    valid_progress.update(1)

            # max(..., 1) reports 0 for an empty loader instead of dividing by zero.
            train_acc = train_correct / max(train_seen, 1)
            train_loss = train_loss_sum / max(train_seen, 1)
            val_acc = val_correct / max(val_seen, 1)
            val_loss = val_loss_sum / max(val_seen, 1)

            # Keep every 5th epoch (and the last) on its own line; otherwise overwrite in place.
            print(f"\rEpoch [{epoch+1:4}/{EPOCHS:4}], Step [{train_length:4}/{train_length:4}], Acc: {train_acc:.6%}, Loss: {train_loss:.6f}, Valid Acc: {val_acc:.6%}, Valid Loss: {val_loss:.6f}", end="\n" if (epoch+1) % 5 == 0 or (epoch+1) == EPOCHS else "")
except KeyboardInterrupt:
    print("\nINFO: Training interrupted by user. Saving model...")
finally:
    # Model Save
    save_directory = model.save_pretrained()  # Saves to "./results/"
    print(f"INFO: Model saved successfully at {save_directory}")

## Model Evaluation

In [None]:
# Load Model
# Reload the weights written by save_pretrained() at the end of the training cell.
# NOTE(review): unlike the initial from_pretrained() call, num_classes is passed explicitly
# here. Confirm that both calls build the same classification head.
model = TargetModel.from_pretrained(save_directory, num_classes=train_dataset.num_classes)
model.to(device)

In [None]:
# Evaluate top-1 accuracy on the test split.
# corrects is kept as a Python int (not a device tensor). The running print shows the
# accuracy over the samples seen so far rather than a partial count over the full test size.
corrects = 0
seen = 0
test_length = len(test_dataset)

model.eval()
with torch.no_grad():
    for inputs, targets in tqdm(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        corrects += (preds == targets).sum().item()
        seen += len(targets)
        print(f"Model Accuracy: {corrects / seen:%}", end="\r")

# Final result on its own line so it is not overwritten; max(..., 1) avoids 0/0 on an empty split.
print(f"Model Accuracy: {corrects / max(test_length, 1):%}")