# Pre-Training Code

## Import Libraries

In [None]:
from os import path

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor

from tqdm.auto import tqdm
from huggingface_hub import login as hf_login

### Huggingface login

In [None]:
# Log in to the Hugging Face Hub.
# Comment out the line below when login is not needed.
hf_login()

### Check GPU Availability

In [None]:
# IPython shell escape: run `nvidia-smi` to show the GPU/driver status
!nvidia-smi

In [None]:
# Index of the CUDA device to use when a GPU is available
DEVICE_NUM = 0

# Fall back to the CPU when CUDA is not available
if torch.cuda.is_available():
    device = torch.device(f"cuda:{DEVICE_NUM}")
else:
    device = torch.device("cpu")
print(f"INFO: Using device - {device}")

## Load Datasets

In [None]:
from reo_gia.datasets import ImageNet1K, CIFAR100, CIFAR10, DatasetHolder, build_augmentation

In [None]:
# Root directory passed to every dataset constructor below
DATA_ROOT = path.join(".", "data")

# ImageNet-1K: train/valid/test datasets are constructed explicitly (force_download=False);
# only the training dataset receives the augmentation transform.
# NOTE(review): the validation dataset is selected with `valid=True` while the others use
# `train=...` — confirm against the ImageNet1K constructor signature.
IMAGENET1Ks = DatasetHolder(
    train=ImageNet1K(root=DATA_ROOT, force_download=False, train=True, transform=build_augmentation(ImageNet1K.img_size)),
    valid=ImageNet1K(root=DATA_ROOT, force_download=False, valid=True),
    test=ImageNet1K(root=DATA_ROOT, force_download=False, train=False)
)
print(f"INFO: Dataset loaded successfully - {IMAGENET1Ks}")

# CIFAR-100: only train/test are constructed (download=True); the holder's
# split_train_valid() is then called — presumably it derives the validation split
# from the training data; verify in reo_gia.datasets.
CIFAR100s = DatasetHolder(
    train=CIFAR100(root=DATA_ROOT, download=True, train=True, transform=build_augmentation(CIFAR100.img_size)),
    test=CIFAR100(root=DATA_ROOT, download=True, train=False)
).split_train_valid()
print(f"INFO: Dataset loaded successfully - {CIFAR100s}")

# CIFAR-10: same construction pattern as CIFAR-100 above
CIFAR10s = DatasetHolder(
    train=CIFAR10(root=DATA_ROOT, download=True, train=True, transform=build_augmentation(CIFAR10.img_size)),
    test=CIFAR10(root=DATA_ROOT, download=True, train=False)
).split_train_valid()
print(f"INFO: Dataset loaded successfully - {CIFAR10s}")

In [None]:
# Dataset holder used by every later cell (config, loaders, evaluation)
ChosenDataset: DatasetHolder = CIFAR10s

In [None]:
# Pull the settings used by later cells out of the chosen dataset holder
dataset_config, dataset_name, num_classes = (
    ChosenDataset.config,
    ChosenDataset.dataset_name,
    ChosenDataset.num_classes,
)

print(f"INFO: Dataset Size - {ChosenDataset}")

## Define Model

In [None]:
from reo_gia.models import (
    BaseModel,
    # Target Models
    ViTSmall, ViTBase, SwinTiny, SwinSmall,
    # Compare Models
    ConvNeXtTiny, ConvNeXtSmall, ResNet34, ResNet50
)

In [None]:
# Model class (not an instance) to train; instantiated later via from_pretrained
TargetModel: type[BaseModel] = ConvNeXtTiny

In [None]:
# Initialize Model
# Record the dataset name on the model class before instantiating it
TargetModel.dataset_name = dataset_name

# Instantiate through from_pretrained with the dataset's class count,
# then move the model to the selected device
model = TargetModel.from_pretrained(num_classes=num_classes)
model.to(device)

In [None]:
# Swin-specific configuration: drop any 'size' entry from the dataset config
if TargetModel in (SwinTiny, SwinSmall) and 'size' in dataset_config:
    del dataset_config['size']

# Load the image processor for the model's checkpoint id
# (dataset_config is currently not forwarded: `, **dataset_config` is disabled)
processor = AutoImageProcessor.from_pretrained(model.model_id)

## Training Loop

In [None]:
def avg(lst):
    """Return the arithmetic mean of ``lst``, or 0 when it is empty."""
    count = len(lst)
    if count == 0:
        return 0
    return sum(lst) / count

In [None]:
# Set Batch Size: a 3-tuple indexed as (train, valid, test) by the DataLoader cell
BATCH_SIZE = 512, 512, 512
# Alternative: one full-dataset batch per split
#BATCH_SIZE = list(map(len, [ChosenDataset.train, ChosenDataset.valid, ChosenDataset.test]))
BATCH_SIZE

In [None]:
def collate_fn(batch):
    """Collate a list of (image, label) samples into a batch.

    The images are run through the module-level ``processor`` and its
    "pixel_values" tensor is returned together with the labels as a tensor.
    """
    image_seq, label_seq = zip(*batch)
    encoded = processor(images=list(image_seq), return_tensors="pt")
    return encoded["pixel_values"], torch.tensor(label_seq)


# Keyword arguments shared by every DataLoader
loader_config = {"collate_fn": collate_fn}

In [None]:
# Toggle for worker-process data loading; set False if DataLoader is causing issues
MULTI_PROCESSING = True

from platform import system
if not MULTI_PROCESSING or system() == "Windows":  # Multiprocess data loading is not supported on Windows
    cpu_cores = 0
    print("INFO: Using DataLoader without multi-processing.")
else:
    import multiprocessing
    cpu_cores = multiprocessing.cpu_count()
    loader_config['num_workers'] = cpu_cores  # one worker per CPU core
    print(f"INFO: Number of CPU cores - {cpu_cores}")

# One loader per split; only the training loader shuffles
train_loader, valid_loader, test_loader = (
    DataLoader(split, batch_size=size, shuffle=do_shuffle, **loader_config)
    for split, size, do_shuffle in (
        (ChosenDataset.train, BATCH_SIZE[0], True),
        (ChosenDataset.valid, BATCH_SIZE[1], False),
        (ChosenDataset.test, BATCH_SIZE[2], False),
    )
)

In [None]:
# Set Epoch Count & Learning Rate
EPOCHS = 2000
LEARNING_RATE = 1e-4, 1e-6  # (initial LR for AdamW, eta_min for the cosine schedule)
WEIGHT_DECAY = 0.05

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE[0], weight_decay=WEIGHT_DECAY)
# NOTE: T_max is counted in scheduler.step() calls, so with T_max=EPOCHS the cosine
# schedule spans the whole run only when step() is called once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LEARNING_RATE[1])

In [None]:
# Number of batches per epoch for the training and validation loaders
train_length, valid_length = map(len, (train_loader, valid_loader))

try:
    epochs = tqdm(range(EPOCHS), desc="Running Epochs")
    with (tqdm(total=train_length, desc="Training") as train_progress,
        tqdm(total=valid_length, desc="Validation") as valid_progress):  # Set up Progress Bars

        for epoch in epochs:
            # Reuse the same two progress bars for every epoch
            train_progress.reset(total=train_length)
            valid_progress.reset(total=valid_length)

            # Per-batch metrics; averaged with avg() for reporting
            train_acc, train_loss, val_acc, val_loss = [], [], [], []

            # Training
            model.train()
            for i, (inputs, targets) in enumerate(train_loader):
                optimizer.zero_grad()
                inputs, targets = inputs.to(device), targets.to(device)
                # NOTE(review): the loss is read from the model output although no labels
                # are passed here — confirm how the model computes `outputs.loss`.
                outputs = model(inputs)

                outputs.loss.backward()
                optimizer.step()

                train_loss.append(outputs.loss.item())
                train_acc.append((torch.argmax(outputs.logits, dim=1) == targets).sum().item() / len(inputs))

                train_progress.update(1)
                print(f"\rEpoch [{epoch+1:4}/{EPOCHS:4}], Step [{i+1:4}/{train_length:4}], Acc: {avg(train_acc):.6%}, Loss: {avg(train_loss):.6f}", end="")

            # Advance the LR schedule once per epoch. The scheduler is built with
            # T_max=EPOCHS, so stepping it per batch would exhaust the cosine
            # half-period after EPOCHS batches and make the learning rate rise again.
            scheduler.step()

            # Validation
            model.eval()
            with torch.no_grad():
                for inputs, targets in valid_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = model(inputs)  # forward pass only; loss is recorded, no backward

                    val_loss.append(outputs.loss.item())
                    val_acc.append((torch.argmax(outputs.logits, dim=1) == targets).sum().item() / len(inputs))
                    valid_progress.update(1)

            # Keep a permanent line every 5th epoch and on the last one; otherwise overwrite in place
            print(f"\rEpoch [{epoch+1:4}/{EPOCHS:4}], Step [{train_length:4}/{train_length:4}], Acc: {avg(train_acc):.6%}, Loss: {avg(train_loss):.6f}, Valid Acc: {avg(val_acc):.6%}, Valid Loss: {avg(val_loss):.6f}", end="\n" if (epoch+1) % 5 == 0 or (epoch+1) == EPOCHS else "")
except KeyboardInterrupt:
    print("\nINFO: Training interrupted by user. Saving model...")
finally:
    # Model Save: runs on normal completion, on interruption, and on errors
    save_directory = model.save_pretrained()  # Save to "./results/"
    processor.save_pretrained(save_directory)  # Save processor config to the same directory
    print(f"INFO: Model saved successfully at {save_directory}")

## Model Evaluation

In [None]:
# Load Model
# Re-create the model from the directory written by the training cell
# and move it to the selected device
model = TargetModel.from_pretrained(save_directory, num_classes=num_classes)
model.to(device)

In [None]:
# Evaluate top-1 accuracy of the reloaded model on the test split.
# Put the model in evaluation mode explicitly so that train-time-only behaviour
# (e.g. dropout, BatchNorm batch statistics) cannot affect the reported accuracy.
model.eval()

corrects = 0
total = len(ChosenDataset.test)

with torch.no_grad():
    for inputs, targets in tqdm(test_loader, desc=f"Evaluating {ChosenDataset.dataset_name}"):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)

        # Accept both output objects exposing `.logits` and raw logit tensors
        logits = outputs.logits if hasattr(outputs, "logits") else outputs
        preds = torch.argmax(logits, dim=1)
        corrects += (preds == targets).sum().item()

acc = corrects / total
print("\n" + "=" * 40)
print(f"Accuracy: {acc:.6%} ({corrects}/{total})")
print("=" * 40)