# Import

In [1]:
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
from torch.utils.data import DataLoader
from fvcore.nn import FlopCountAnalysis, flop_count_table
import numpy as np
import matplotlib.pyplot as plt
import os

####################################################
from src.Mymodel import MyResNet34
from src.Mymodel import MyResNet_CIFAR
from src.Mytraining import DoTraining
from src.Earlystopper import EarlyStopper
from src.LogViewer import LogViewer

In [2]:
import copy
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms.v2 import (
    ToTensor,
    RandomHorizontalFlip,
    Compose,
    RandomCrop,
    RandomShortestSize,
    AutoAugment,
    Normalize,
    TenCrop,
    CenterCrop,
    Pad,
    Resize,
)
from torchvision.transforms.autoaugment import AutoAugmentPolicy

# Setup

In [3]:
"""Dataset selection"""
# DATASET = "CIFAR10"
# DATASET = "CIFAR100"
DATASET = "ImageNet2012"

"""Model selection for CIFAR"""
NUM_LAYERS_LEVEL = 5

"""Dataset parameters"""
BATCH = 256
SHUFFLE = True
NUMOFWORKERS = 8
PIN_MEMORY = True
SPLIT_RATIO = 0

"""optimizer parameters"""
OPTIMIZER = "SGD"
# OPTIMIZER = "Adam"
# OPTIMIZER = "Adam_decay"

"""Learning rate scheduler parameters"""
# LOAD_BEFORE_TRAINING = False
LOAD_BEFORE_TRAINING = True
NUM_EPOCHS = 1000
scheduler_patience_mapping = {"CIFAR10": 100, "CIFAR100": 100, "ImageNet2012": 5}

"""Early stopping parameters"""
EARLYSTOPPINGPATIENCE = 25
file_path = ""
if DATASET == "ImageNet2012":
    file_path = f"{DATASET}/MyResNet34_{BATCH}_{OPTIMIZER}"
else:
    file_path = f"{DATASET}/MyResNet{NUM_LAYERS_LEVEL*6+2}_{BATCH}_{OPTIMIZER}"

if SPLIT_RATIO != 0:
    file_path += f"_{int(SPLIT_RATIO*100)}"

In [4]:
file_path

'ImageNet2012/MyResNet34_256_SGD'

# Loading the dataset

In [5]:
class LoadDataset:
    def __init__(self, root, seceted_dataset, split_ratio=0):
        self.Randp = 0.5
        self.dataset_name = seceted_dataset
        self.split_ratio = split_ratio

        if self.dataset_name[:5] == "CIFAR":
            pass
        elif self.dataset_name == "ImageNet2012":
            self.ImageNetRoot = "data/" + self.dataset_name + "/"

            self.train_data = None
            """논문에 제시된 in testing, 10crop + 멀티스케일"""
            self.valid_data = None
            """
            각 지정된 스케일에 따라 10 crop해야하는데, 5개 scale들의 평균을 내야하니까 좀 번거로움.
            그치만, 학습 중엔 center crop으로 eval하니, 지금 당장 필요하지는 않음.
            """
            compose_totensor = Compose(
                [
                    ToTensor(),
                ]
            )
            ref_test_data = datasets.ImageFolder(
                root=self.ImageNetRoot + "val", transform=compose_totensor
            )
            tmp = []
            scales = [224, 256, 384, 480, 640]
            for i in range(len(scales)):
                tmp.append(copy.deepcopy(ref_test_data))

            test_data = torch.utils.data.ConcatDataset(tmp)
            for i in range(len(scales)):
                test_data.datasets[i].transform = Compose(
                    [
                        RandomShortestSize(min_size=scales[i], antialias=True),
                        TenCrop(size=scales[i]),
                        ToTensor(),
                        Normalize(
                            mean=[0.485, 0.456, 0.406], std=[1, 1, 1], inplace=True
                        ),
                    ]
                )
                test_data.datasets[i].classes = ref_test_data.classes
                test_data.datasets[i].class_to_idx = ref_test_data.class_to_idx

            self.test_data = test_data
            # tmp = [None, None, None, None, None]
            # tmp[0], tmp[1], tmp[2], tmp[3], tmp[4] = random_split(
            #     ref_test_data, [0.2, 0.2, 0.2, 0.2, 0.2]
            # )
            # self.test_data = torch.utils.data.ConcatDataset([tmp])

            # scale = [224, 256, 384, 480, 640]
            # for i in range(len(tmp)):
            #     self.test_data.datasets[i].transform = copy.deepcopy(
            #         ref_test_data.transform
            #     )
            #     self.test_data.datasets[i].transform.transforms.append(
            #         RandomShortestSize(min_size=scale[i], antialias=True)
            #     )
            #     self.test_data.datasets[i].transform.transforms.append(
            #         Pad(padding=int(scale[i] / 8), padding_mode="constant")
            #     )
            #     self.test_data.datasets[i].transform.transforms.append(
            #         TenCrop(size=scale[i])
            #     )

            #     self.test_data.datasets[i].classes = self.train_data.classes
            #     self.test_data.datasets[i].class_to_idx = self.train_data.class_to_idx

        else:
            raise ValueError(f"Unsupported dataset: {self.dataset_name}")

        return

    def Unpack(self, print_info=True):
        if print_info == True:
            print(
                "-----------------------------------------------------------------------"
            )
            print("Dataset : ", self.dataset_name)
            print("- Length of Train Set : ", len(self.train_data))
            if self.valid_data != None:
                print("- Length of Valid Set : ", len(self.valid_data))
            if self.test_data != None:
                print("- Length of Test Set : ", len(self.test_data))
            print("- Count of Classes : ", len(self.train_data.classes))
            print(
                "-----------------------------------------------------------------------"
            )
        return (
            self.train_data,
            self.valid_data,
            self.test_data,
            len(self.train_data.classes),
        )

## Define Dateloader

In [6]:
tmp = LoadDataset(root="data", seceted_dataset=DATASET, split_ratio=SPLIT_RATIO)
test_data = tmp.test_data



In [7]:
if test_data is not None:
    test_dataloader = DataLoader(
        test_data,
        batch_size=BATCH,
        shuffle=SHUFFLE,
        num_workers=NUMOFWORKERS,
        pin_memory=PIN_MEMORY,
        # pin_memory_device="cuda",
        persistent_workers=True,
    )
    
else:
    test_dataloader = None

In [8]:
for i in range(5):
    
    print(test_data.datasets[i].transform)
    print(test_data.datasets[i].__len__())

Compose(
      RandomShortestSize(min_size=[224], interpolation=InterpolationMode.BILINEAR, antialias=True)
      TenCrop(size=(224, 224), vertical_flip=False)
      ToTensor()
      Normalize(mean=[0.485, 0.456, 0.406], std=[1, 1, 1], inplace=True)
)
50000
Compose(
      RandomShortestSize(min_size=[256], interpolation=InterpolationMode.BILINEAR, antialias=True)
      TenCrop(size=(256, 256), vertical_flip=False)
      ToTensor()
      Normalize(mean=[0.485, 0.456, 0.406], std=[1, 1, 1], inplace=True)
)
50000
Compose(
      RandomShortestSize(min_size=[384], interpolation=InterpolationMode.BILINEAR, antialias=True)
      TenCrop(size=(384, 384), vertical_flip=False)
      ToTensor()
      Normalize(mean=[0.485, 0.456, 0.406], std=[1, 1, 1], inplace=True)
)
50000
Compose(
      RandomShortestSize(min_size=[480], interpolation=InterpolationMode.BILINEAR, antialias=True)
      TenCrop(size=(480, 480), vertical_flip=False)
      ToTensor()
      Normalize(mean=[0.485, 0.456, 0.406], std=[

## Confirm that the dataset is loaded properly

In [9]:
# if test_data is not None:
#     for X, y in test_dataloader:
#         print(f"Shape of X [N, C, H, W]: {X.shape}")
#         print("mean of X", X.mean(dim=(0, 2, 3)))
#         print(f"Shape of y: {y.shape} {y.dtype}")
#         break
    
#     class_names = test_dataloader.dataset.classes
#     count = 0
#     fig, axs = plt.subplots(2, 5, figsize=(8, 4))

#     for images, labels in test_dataloader:
#         images = images.numpy()

#         for i in range(len(images)):
#             image = images[i]
#             label = labels[i]
#             image = np.transpose(image, (1, 2, 0))
#             image = np.clip(image, 0, 1)
#             ax = axs[count // 5, count % 5]
#             ax.imshow(image)
#             ax.set_title(f"{class_names[label], label}")
#             ax.axis("off")
#             count += 1

#             if count == 10:
#                 break
#         if count == 10:
#             break
#     plt.tight_layout()
#     plt.show()

# Define ResNet

## Model Confirm

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [11]:

"""ResNet34 for ImageNet 2012"""
model = MyResNet34(
    num_classes=1000, 
    Downsample_option="B"
).to(device)
# model = models.resnet34(pretrained=True).to(device)
# model = models.resnet34(pretrained=False).to(device)
print(f"ResNet-34 for {DATASET} is loaded.")


ResNet-34 for ImageNet2012 is loaded.


In [12]:
# model.named_modules

In [13]:
# tmp_input = torch.rand(BATCH, 3, 32, 32).to(device)
# flops = FlopCountAnalysis(model, tmp_input)
# print(flop_count_table(flops))

# Define Training

## (1) Define Criterion

In [14]:
criterion = nn.CrossEntropyLoss()

## (2) Define Optimazer

In [15]:
if OPTIMIZER == "Adam":
    optimizer = torch.optim.Adam(model.parameters())
elif OPTIMIZER == "Adam_decay":
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
elif OPTIMIZER == "SGD":
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001
    )

## (3) Define Early Stopping

In [16]:
earlystopper = EarlyStopper(patience=EARLYSTOPPINGPATIENCE, model=model, file_path=file_path)

## (4) Define Learning Rate schedualer

In [17]:
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    patience=scheduler_patience_mapping[DATASET],
    factor=0.1,
    verbose=True,
    threshold=1e-4,
    cooldown=5,
)


## (5) Define AMP scaler

In [18]:
scaler = torch.cuda.amp.GradScaler(enabled=True)

## Load before process

In [19]:
scaler = torch.cuda.amp.GradScaler(enabled=True)

if LOAD_BEFORE_TRAINING == True and os.path.exists("logs/" + file_path + ".pth.tar"):
    # Read checkpoint as desired, e.g.,
    checkpoint = torch.load(
        "logs/" + file_path + ".pth.tar",
        map_location=lambda storage, loc: storage.cuda(device),
    )
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scaler.load_state_dict(checkpoint["scaler"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    earlystopper.load_state_dict(checkpoint["earlystopper"])
    logs = checkpoint["logs"]

    print("Suceessfully loaded the All setting and Log file.")
    print(file_path)
    print(f"Current epoch is {len(logs['train_loss'])}")
    print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")
else:
    # Create a dictionary to store the variables
    train_loss = []
    train_acc = []
    eval_loss = []
    valid_acc = []
    test_loss = []
    test_acc = []
    lr_log = []
    logs = {
        "train_loss": train_loss,
        "train_acc": train_acc,
        "valid_loss": eval_loss,
        "valid_acc": valid_acc,
        "test_loss": test_loss,
        "test_acc": test_acc,
        "lr_log": lr_log,
    }
    print("File does not exist. Created a new log.")

Suceessfully loaded the All setting and Log file.
ImageNet2012/MyResNet34_256_SGD
Current epoch is 3
Current learning rate: 0.1


In [20]:
optimizer.param_groups[0]["lr"]

0.1

# [Training Loop]

In [21]:
Training = DoTraining(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scaler=scaler,
    scheduler=scheduler,
    earlystopper=earlystopper,
    device=device,
    logs=logs,
    file_path=file_path,
)
pre_epochs = len(Training.logs["train_loss"])

for epoch in range(NUM_EPOCHS):
    now = epoch + 1 + pre_epochs
    print(f"[Epoch {epoch+1+pre_epochs}/{NUM_EPOCHS}] :")

    if DATASET == "ImageNet2012":
        eval_loss = Training.SingleEpoch(test_dataloader, test_dataloader)

    if earlystopper.check(eval_loss) == True:
        break

    print("-" * 50)

[Epoch 4/1000] :


  0%|          | 0/977 [00:03<?, ?it/s]


ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torch/utils/data/dataset.py", line 302, in __getitem__
    return self.datasets[dataset_idx][sample_idx]
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torchvision/datasets/folder.py", line 231, in __getitem__
    sample = self.transform(sample)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torchvision/transforms/v2/_container.py", line 53, in forward
    outputs = transform(*inputs)
              ^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 50, in forward
    flat_outputs = [
                   ^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 51, in <listcomp>
    self._transform(inpt, params) if needs_transform else inpt
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py", line 425, in _transform
    return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torchvision/transforms/v2/_geometry.py", line 418, in _call_kernel
    return super()._call_kernel(functional, inpt, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torchvision/transforms/v2/_transform.py", line 35, in _call_kernel
    return kernel(inpt, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_geometry.py", line 2353, in _ten_crop_image_pil
    non_flipped = _five_crop_image_pil(image, size)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lee/anaconda3/lib/python3.11/site-packages/torchvision/transforms/v2/functional/_geometry.py", line 2268, in _five_crop_image_pil
    raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
ValueError: Requested crop size (480, 480) is bigger than input size (479, 595)


In [None]:
view = LogViewer(logs)
view.draw()

In [None]:
view.print_all()

In [None]:
# CHECK = 5410
# logs["train_loss"] = logs["train_loss"][:CHECK]
# logs["train_acc"] = logs["train_acc"][:CHECK]
# logs["valid_loss"] = logs["valid_loss"][:CHECK]
# logs["valid_acc"] = logs["valid_acc"][:CHECK]
# logs["test_loss"] = logs["test_loss"][:CHECK]
# logs["test_acc"] = logs["test_acc"][:CHECK]
# model.load_state_dict(torch.load(f"models/{file_path}.pth"))

In [None]:
earlystopper.early_stop_counter




In [None]:
from PIL import Image
import matplotlib.pyplot as plt

# Open the image
image_path = "n01440764_190.JPEG"
image = Image.open(image_path)

# Display the image
plt.imshow(image)
plt.axis("off")
plt.show()


In [None]:
image.size