# Import

In [None]:
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from fvcore.nn import FlopCountAnalysis, flop_count_table
import numpy as np
import matplotlib.pyplot as plt
import os

####################################################
from src.Mydataloader import LoadDataset
from src.Mymodel import MyResNet34
from src.Mymodel import MyResNet_CIFAR
from src.Mytraining import DoTraining
from src.LogViewer import LogViewer

# Setup

In [None]:
# --- Dataset selection ---
DATASET = "CIFAR10"
# DATASET = "CIFAR100"
# DATASET = "ImageNet2012"

# --- Model selection for CIFAR ---
# CIFAR ResNet depth is 6 * NUM_LAYERS_LEVEL + 2 (5 -> ResNet-32).
NUM_LAYERS_LEVEL = 5

# --- DataLoader parameters ---
BATCH = 256
SHUFFLE = True
NUMOFWORKERS = 8
PIN_MEMORY = True
# Train/validation split ratio passed to LoadDataset; 0 disables the
# validation loader (presumably the fraction kept for training — verify in
# src.Mydataloader).
SPLIT_RATIO = 0.9

# --- Optimizer: "SGD", "Adam" or "Adam_decay" ---
OPTIMIZER = "SGD"
# OPTIMIZER = "Adam"
# OPTIMIZER = "Adam_decay"

# --- Training-loop parameters ---
# Resume from logs/<file_path>.pth.tar when it exists.
# LOAD_BEFORE_TRAINING = False
LOAD_BEFORE_TRAINING = True
NUM_EPOCHS = 100000

# --- Early stopping parameters ---
EARLYSTOPPINGPATIENCE = 3000


def build_run_name(dataset, num_layers_level, batch, optimizer, split_ratio):
    """Return the run stem "<dataset>/<model>_<batch>_<optimizer>[_<split%>]".

    The stem is used for both the resumable checkpoint (logs/<stem>.pth.tar)
    and the best-model weights (models/<stem>.pth).
    """
    if dataset == "ImageNet2012":
        model_name = "MyResNet34"
    else:
        model_name = f"MyResNet{num_layers_level * 6 + 2}"
    name = f"{dataset}/{model_name}_{batch}_{optimizer}"
    if split_ratio != 0:
        # round(), not int(): float error makes e.g. 0.29 * 100 slightly below 29.
        name += f"_{round(split_ratio * 100)}"
    return name


file_path = build_run_name(DATASET, NUM_LAYERS_LEVEL, BATCH, OPTIMIZER, SPLIT_RATIO)

In [None]:
# Display the run stem used for the checkpoint (logs/) and best-model (models/) paths.
file_path

# Loading the dataset

## Define DataLoaders

In [None]:
# Build the train/valid/test datasets and the class count for DATASET.
# The keyword spelling `seceted_dataset` is kept as-is to match LoadDataset's call signature.
tmp = LoadDataset(
    root="data",
    seceted_dataset=DATASET,
    split_ratio=SPLIT_RATIO,
)
train_data, valid_data, test_data, COUNT_OF_CLASSES = tmp.Unpack()

In [None]:
# Settings shared by all three loaders. DataLoader rejects
# persistent_workers=True when num_workers == 0, so it is only enabled when
# worker processes actually exist.
loader_kwargs = dict(
    batch_size=BATCH,
    shuffle=SHUFFLE,
    num_workers=NUMOFWORKERS,
    pin_memory=PIN_MEMORY,
    # pin_memory_device="cuda",
    persistent_workers=NUMOFWORKERS > 0,
)

train_dataloader = DataLoader(train_data, **loader_kwargs)
print("train.transforms =", train_data.transform, train_dataloader.batch_size)

if SPLIT_RATIO != 0:
    valid_dataloader = DataLoader(valid_data, **loader_kwargs)
    print("valid.transforms =", valid_data.transform, valid_dataloader.batch_size)
else:
    # No validation split was requested.
    valid_dataloader = None

# NOTE: the test loader uses the same SHUFFLE setting as the others.
test_dataloader = DataLoader(test_data, **loader_kwargs)
print("test.transforms =", test_data.transform, test_dataloader.batch_size)

## Confirm that the dataset is loaded properly

In [None]:
# Sanity check (non-ImageNet only): print the shape and per-channel mean of one
# test batch and the shape/dtype of its labels, then stop after that batch.
if DATASET != "ImageNet2012":
    for X, y in test_dataloader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print("mean of X", X.mean(dim=(0, 2, 3)))
        print(f"Shape of y: {y.shape} {y.dtype}")
        break

In [None]:
if DATASET != "ImageNet2012":
    # Show the first 10 test images in a 2x5 grid, titled "<class name>, <label>".
    class_names = test_dataloader.dataset.classes
    count = 0
    fig, axs = plt.subplots(2, 5, figsize=(8, 4))

    for images, labels in test_dataloader:
        images = images.numpy()

        for image, label in zip(images, labels):
            # Plain int so the title shows "3", not "tensor(3)".
            label = int(label)
            # CHW -> HWC for imshow. Clip because transformed (e.g. normalized)
            # pixels may fall outside the [0, 1] range imshow expects for floats.
            image = np.clip(np.transpose(image, (1, 2, 0)), 0, 1)
            # axs.flat walks the grid row-major: same order as [count // 5, count % 5].
            ax = axs.flat[count]
            ax.imshow(image)
            ax.set_title(f"{class_names[label]}, {label}")
            ax.axis("off")
            count += 1

            if count == axs.size:
                break
        if count == axs.size:
            break
    plt.tight_layout()
    plt.show()

# Define ResNet

## Confirm the model

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
if DATASET in ("CIFAR10", "CIFAR100"):
    # ResNet{20, 32, 44, 56, 110, 1202} for CIFAR: depth = 6 * NUM_LAYERS_LEVEL + 2.
    model = MyResNet_CIFAR(
        num_classes=COUNT_OF_CLASSES,
        num_layer_factor=NUM_LAYERS_LEVEL,
        Downsample_option="A",
    ).to(device)
    print(f"ResNet-{NUM_LAYERS_LEVEL*6+2} for {DATASET} is loaded.")

elif DATASET == "ImageNet2012":
    # ResNet34 for ImageNet 2012
    model = MyResNet34(
        num_classes=COUNT_OF_CLASSES,
        Downsample_option="A",
    ).to(device)
    # Alternatives (need `from torchvision import models`, which this notebook does not import):
    # model = models.resnet34(pretrained=True).to(device)
    # model = models.resnet34(pretrained=False).to(device)
    print(f"ResNet-34 for {DATASET} is loaded.")

else:
    # Fail loudly instead of silently keeping a stale `model` from an earlier run.
    raise ValueError(f"Unsupported DATASET: {DATASET!r}")


In [None]:
# Display the model; its repr lists every submodule.
# (`model.named_modules` without a call would show a bound-method object instead.)
model

In [None]:
# CIFAR images are 32x32. 224 for ImageNet is an assumption — TODO confirm
# against the ImageNet transform in src.Mydataloader.
INPUT_SIZE = 224 if DATASET == "ImageNet2012" else 32
tmp_input = torch.rand(BATCH, 3, INPUT_SIZE, INPUT_SIZE).to(device)

# FlopCountAnalysis traces lazily, i.e. when flop_count_table runs the model.
# Trace in eval mode without autograd so the random batch does not update
# BatchNorm running statistics; the counts therefore reflect inference mode.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        flops = FlopCountAnalysis(model, tmp_input)
        print(flop_count_table(flops))
finally:
    model.train(was_training)

# Define Training

## (1) Define Criterion

In [None]:
# Multi-class classification loss. nn.CrossEntropyLoss applies log-softmax
# internally, so it expects raw logits and integer class labels.
criterion = nn.CrossEntropyLoss()

## (2) Define Optimizer

In [None]:
# Build the optimizer selected by OPTIMIZER.
if OPTIMIZER == "Adam":
    # PyTorch's default Adam settings (lr=1e-3), no weight decay.
    optimizer = torch.optim.Adam(model.parameters())
elif OPTIMIZER == "Adam_decay":
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
elif OPTIMIZER == "SGD":
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001
    )
else:
    # Otherwise a typo would leave `optimizer` undefined, or silently reuse a
    # stale optimizer from a previous notebook run.
    raise ValueError(f"Unsupported OPTIMIZER: {OPTIMIZER!r}")

## (3) Define Early Stopping

In [None]:
class EarlyStopper:
    """Signal a stop once the evaluation loss stops improving.

    Each time ``check`` sees a new lowest loss, it saves ``model``'s
    ``state_dict`` to ``models/<file_path>.pth``. After ``patience``
    consecutive calls without improvement, ``check`` returns True.

    Args:
        patience: Number of consecutive non-improving ``check`` calls
            tolerated before signalling a stop.
        model: Module whose weights are saved on each improvement.
        file_path: Run stem; the best weights go to ``models/<file_path>.pth``.
    """

    def __init__(self, patience, model, file_path):
        self.best_eval_loss = float("inf")
        self.early_stop_counter = 0
        self.PATIENCE = patience
        self.file_path = file_path
        self.model = model

    def check(self, eval_loss):
        """Record ``eval_loss`` and report whether training should stop.

        Returns:
            True once ``early_stop_counter`` reaches ``PATIENCE``. Otherwise
            False, including when the loss improved and the model was saved.
        """
        if eval_loss < self.best_eval_loss:
            self.best_eval_loss = eval_loss
            self.early_stop_counter = 0
            print("updated best eval loss :", self.best_eval_loss)
            save_path = "models/" + self.file_path + ".pth"
            # file_path may contain a sub-directory (e.g. "CIFAR10/..."), which
            # torch.save does not create on its own.
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(self.model.state_dict(), save_path)
            return False

        self.early_stop_counter += 1
        if self.early_stop_counter >= self.PATIENCE:
            print(f"Early stop!! best_eval_loss = {self.best_eval_loss}")
            return True
        # Always return a real bool so callers can test the result directly.
        return False

    def state_dict(self):
        """Return the resumable state (patience itself is not included)."""
        return {"best_eval_loss": self.best_eval_loss, "early_stop_counter": self.early_stop_counter}

    def load_state_dict(self, state_dict):
        """Restore state previously produced by ``state_dict``."""
        self.best_eval_loss = state_dict["best_eval_loss"]
        self.early_stop_counter = state_dict["early_stop_counter"]
    
# Best-model weights are written to models/<file_path>.pth whenever the eval loss improves.
earlystopper = EarlyStopper(EARLYSTOPPINGPATIENCE, model, file_path)

## (4) Define Learning Rate Scheduler

In [None]:
# Per-dataset ReduceLROnPlateau patience. Patience is counted in
# scheduler.step() calls without improvement; presumably step() is called once
# per epoch inside DoTraining — verify in src.Mytraining.
scheduler_mapping = {"CIFAR10": 1000, "CIFAR100": 1000, "ImageNet2012": 30}

# Multiply the LR by 0.1 when the monitored value (mode="min") fails to improve
# by the threshold 1e-4 for `patience` steps. Afterwards, wait `cooldown` steps
# before counting non-improving steps again.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    patience=scheduler_mapping[DATASET],
    factor=0.1,
    verbose=True,  # NOTE(review): deprecated in recent PyTorch releases — confirm for the installed version.
    threshold=1e-4,
    cooldown=100
)

## (5) Define AMP scaler

In [None]:
# Gradient scaler for mixed-precision (AMP) training. The checkpoint-loading
# cell below creates a new GradScaler again before restoring any saved state.
scaler = torch.cuda.amp.GradScaler(enabled=True)

## Load the checkpoint before training (if one exists)

In [None]:
# Each run of this cell starts from a fresh GradScaler. If a checkpoint is
# loaded below, its saved scaler state replaces the fresh one.
scaler = torch.cuda.amp.GradScaler(enabled=True)

checkpoint_path = "logs/" + file_path + ".pth.tar"
if LOAD_BEFORE_TRAINING and os.path.exists(checkpoint_path):
    # Resume: restore every stateful training component plus the metric logs.
    # map_location=device loads tensors straight onto the active device.
    # Unlike storage.cuda(...), this also works on CPU-only machines.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scaler.load_state_dict(checkpoint["scaler"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    earlystopper.load_state_dict(checkpoint["earlystopper"])
    logs = checkpoint["logs"]

    print("Suceessfully loaded the All setting and Log file.")
    print(file_path)
    print(f"Current epoch is {len(logs['train_loss'])}")
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")
else:
    # Fresh run: start with empty per-epoch metric histories.
    # NOTE: the training loop later rebinds the name `eval_loss` to a scalar;
    # this list stays reachable as logs["valid_loss"].
    train_loss = []
    train_acc = []
    eval_loss = []
    valid_acc = []
    test_loss = []
    test_acc = []
    lr_log = []
    logs = {
        "train_loss": train_loss,
        "train_acc": train_acc,
        "valid_loss": eval_loss,
        "valid_acc": valid_acc,
        "test_loss": test_loss,
        "test_acc": test_acc,
        "lr_log": lr_log,
    }
    print("File does not exist. Created a new log.")

In [None]:
# Show the current learning rate of the first parameter group (e.g. after resuming).
optimizer.param_groups[0]["lr"]

# [Training Loop]

In [21]:
Training = DoTraining(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scaler=scaler,
    scheduler=scheduler,
    earlystopper=earlystopper,
    device=device,
    logs=logs,
    file_path=file_path,
)
# Epochs already completed (non-zero when resuming from a checkpoint).
pre_epochs = len(Training.logs["train_loss"])

# NUM_EPOCHS is the total epoch budget, so a resumed run only trains the remainder.
for epoch in range(pre_epochs, NUM_EPOCHS):
    now = epoch + 1  # 1-based epoch number
    print(f"[Epoch {now}/{NUM_EPOCHS}] :")

    if DATASET == "ImageNet2012":
        # test_dataloader is not passed for ImageNet.
        eval_loss = Training.SingleEpoch(train_dataloader, valid_dataloader)
    else:
        eval_loss = Training.SingleEpoch(
            train_dataloader, valid_dataloader, test_dataloader
        )

    # Update the early stopper (which also saves the best model) BEFORE
    # checkpointing, so the saved "earlystopper" state already includes this
    # epoch. Assumes Training.Save() snapshots earlystopper.state_dict(); the
    # load cell reads checkpoint["earlystopper"] — TODO confirm in src.Mytraining.
    stop = earlystopper.check(eval_loss)
    Training.Save()
    if stop:
        break

    print("-" * 50)

100%|██████████| 176/176 [00:08<00:00, 19.80it/s]


Train Loss: 0.0024 | Train Acc: 87.00%
Valid Loss: 0.6854 | Valid Acc: 77.76%
Test  Loss: 0.4705 | Test Acc: 85.49%
--------------------------------------------------
[Epoch 317/100000] :


100%|██████████| 176/176 [00:09<00:00, 18.65it/s]


Train Loss: 0.0024 | Train Acc: 85.00%
Valid Loss: 0.6161 | Valid Acc: 79.54%
Test  Loss: 0.3714 | Test Acc: 88.28%
--------------------------------------------------
[Epoch 318/100000] :


100%|██████████| 176/176 [00:08<00:00, 19.92it/s]


Train Loss: 0.0034 | Train Acc: 78.50%
Valid Loss: 0.6072 | Valid Acc: 79.36%
Test  Loss: 0.3768 | Test Acc: 87.96%
--------------------------------------------------
[Epoch 319/100000] :


100%|██████████| 176/176 [00:08<00:00, 19.62it/s]


Train Loss: 0.0029 | Train Acc: 82.00%
Valid Loss: 0.7127 | Valid Acc: 76.94%
Test  Loss: 0.5086 | Test Acc: 84.11%
--------------------------------------------------
[Epoch 320/100000] :


100%|██████████| 176/176 [00:09<00:00, 19.55it/s]


Train Loss: 0.0025 | Train Acc: 87.00%
Valid Loss: 0.7014 | Valid Acc: 77.44%
Test  Loss: 0.4609 | Test Acc: 86.59%
--------------------------------------------------
[Epoch 321/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.24it/s]


Train Loss: 0.0023 | Train Acc: 86.00%
Valid Loss: 0.6106 | Valid Acc: 80.28%
Test  Loss: 0.3617 | Test Acc: 88.20%
--------------------------------------------------
[Epoch 322/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.25it/s]


Train Loss: 0.0018 | Train Acc: 91.00%
Valid Loss: 0.5935 | Valid Acc: 80.06%
Test  Loss: 0.3902 | Test Acc: 87.37%
--------------------------------------------------
[Epoch 323/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.33it/s]


Train Loss: 0.0024 | Train Acc: 84.00%
Valid Loss: 0.5371 | Valid Acc: 81.84%
Test  Loss: 0.3316 | Test Acc: 89.28%
updated best eval loss : 0.5370827436447143
--------------------------------------------------
[Epoch 324/100000] :


100%|██████████| 176/176 [00:08<00:00, 19.78it/s]


Train Loss: 0.0028 | Train Acc: 84.00%
Valid Loss: 0.5715 | Valid Acc: 81.14%
Test  Loss: 0.3053 | Test Acc: 89.59%
--------------------------------------------------
[Epoch 325/100000] :


100%|██████████| 176/176 [00:10<00:00, 17.12it/s]


Train Loss: 0.0022 | Train Acc: 87.50%
Valid Loss: 0.6461 | Valid Acc: 78.96%
Test  Loss: 0.3596 | Test Acc: 89.07%
--------------------------------------------------
[Epoch 326/100000] :


100%|██████████| 176/176 [00:09<00:00, 19.55it/s]


Train Loss: 0.0022 | Train Acc: 87.00%
Valid Loss: 0.5946 | Valid Acc: 80.12%
Test  Loss: 0.3506 | Test Acc: 89.10%
--------------------------------------------------
[Epoch 327/100000] :


100%|██████████| 176/176 [00:09<00:00, 19.39it/s]


Train Loss: 0.0030 | Train Acc: 82.50%
Valid Loss: 0.5989 | Valid Acc: 80.02%
Test  Loss: 0.3794 | Test Acc: 88.06%
--------------------------------------------------
[Epoch 328/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.10it/s]


Train Loss: 0.0026 | Train Acc: 85.00%
Valid Loss: 0.5746 | Valid Acc: 80.82%
Test  Loss: 0.3569 | Test Acc: 88.51%
--------------------------------------------------
[Epoch 329/100000] :


100%|██████████| 176/176 [00:08<00:00, 19.62it/s]


Train Loss: 0.0016 | Train Acc: 89.50%
Valid Loss: 0.6516 | Valid Acc: 77.80%
Test  Loss: 0.4035 | Test Acc: 87.70%
--------------------------------------------------
[Epoch 330/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.60it/s]


Train Loss: 0.0020 | Train Acc: 87.00%
Valid Loss: 0.6126 | Valid Acc: 79.86%
Test  Loss: 0.3504 | Test Acc: 88.52%
--------------------------------------------------
[Epoch 331/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.12it/s]


Train Loss: 0.0025 | Train Acc: 83.50%
Valid Loss: 0.6582 | Valid Acc: 78.44%
Test  Loss: 0.4586 | Test Acc: 84.98%
--------------------------------------------------
[Epoch 332/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.15it/s]


Train Loss: 0.0026 | Train Acc: 82.50%
Valid Loss: 0.6263 | Valid Acc: 78.92%
Test  Loss: 0.3533 | Test Acc: 89.11%
--------------------------------------------------
[Epoch 333/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.03it/s]


Train Loss: 0.0020 | Train Acc: 86.50%
Valid Loss: 0.5544 | Valid Acc: 81.26%
Test  Loss: 0.3776 | Test Acc: 88.06%
--------------------------------------------------
[Epoch 334/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.23it/s]


Train Loss: 0.0035 | Train Acc: 79.00%
Valid Loss: 0.5757 | Valid Acc: 80.38%
Test  Loss: 0.3238 | Test Acc: 89.41%
--------------------------------------------------
[Epoch 335/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.32it/s]


Train Loss: 0.0023 | Train Acc: 84.50%
Valid Loss: 0.5685 | Valid Acc: 81.16%
Test  Loss: 0.3135 | Test Acc: 90.16%
--------------------------------------------------
[Epoch 336/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.47it/s]


Train Loss: 0.0025 | Train Acc: 83.00%
Valid Loss: 0.5268 | Valid Acc: 81.98%
Test  Loss: 0.3285 | Test Acc: 89.63%
updated best eval loss : 0.5268193453550338
--------------------------------------------------
[Epoch 337/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.46it/s]


Train Loss: 0.0018 | Train Acc: 88.50%
Valid Loss: 0.6527 | Valid Acc: 78.30%
Test  Loss: 0.3927 | Test Acc: 87.59%
--------------------------------------------------
[Epoch 338/100000] :


100%|██████████| 176/176 [00:08<00:00, 19.69it/s]


Train Loss: 0.0027 | Train Acc: 82.50%
Valid Loss: 0.6090 | Valid Acc: 79.50%
Test  Loss: 0.3619 | Test Acc: 88.13%
--------------------------------------------------
[Epoch 339/100000] :


100%|██████████| 176/176 [00:08<00:00, 19.98it/s]


Train Loss: 0.0026 | Train Acc: 83.50%
Valid Loss: 0.5996 | Valid Acc: 80.10%
Test  Loss: 0.3484 | Test Acc: 89.17%
--------------------------------------------------
[Epoch 340/100000] :


100%|██████████| 176/176 [00:08<00:00, 20.39it/s]


Train Loss: 0.0019 | Train Acc: 91.50%
Valid Loss: 0.6078 | Valid Acc: 79.38%
Test  Loss: 0.3625 | Test Acc: 88.26%
--------------------------------------------------
[Epoch 341/100000] :


100%|██████████| 176/176 [00:09<00:00, 18.22it/s]


Train Loss: 0.0023 | Train Acc: 84.50%
Valid Loss: 0.5577 | Valid Acc: 81.42%


In [None]:
# Wrap the metric logs in a LogViewer and draw them (presumably plots the
# per-epoch curves — see src.LogViewer).
view = LogViewer(logs)
view.draw()

In [None]:
# Print the logged metrics via LogViewer.print_all (see src.LogViewer for the format).
view.print_all()

In [None]:
# Manual rollback helper (disabled): truncate the logs to the first CHECK
# entries and reload the saved best-model weights from models/<file_path>.pth.
# NOTE(review): logs["lr_log"] is not truncated here — confirm whether it should be.
# CHECK = 5410
# logs["train_loss"] = logs["train_loss"][:CHECK]
# logs["train_acc"] = logs["train_acc"][:CHECK]
# logs["valid_loss"] = logs["valid_loss"][:CHECK]
# logs["valid_acc"] = logs["valid_acc"][:CHECK]
# logs["test_loss"] = logs["test_loss"][:CHECK]
# logs["test_acc"] = logs["test_acc"][:CHECK]
# model.load_state_dict(torch.load(f"models/{file_path}.pth"))