# Loading Trained Model

In [14]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from src.Mymodel import MyResNet34
from src.Mymodel import MyResNet_CIFAR
from src.Mydataloader import LoadDataset
from src.Mytraining import DoTraining

In [15]:
"""Dataset selection"""
# Alternatives: switch by commenting/uncommenting exactly one DATASET line.
# DATASET = "CIFAR10"
# DATASET = "CIFAR100"
DATASET = "ImageNet2012"

"""Model selection for CIFAR"""
# Only consulted for non-ImageNet datasets (depth is NUM_LAYERS_LEVEL*6+2).
NUM_LAYERS_LEVEL = 5

"""Dataset parameters"""
BATCH = 256
SHUFFLE = True
NUMOFWORKERS = 8
PIN_MEMORY = True
SPLIT_RATIO = 0

"""optimizer parameters"""
# The optimizer is only used here as part of the checkpoint name.
OPTIMIZER = "SGD"
# OPTIMIZER = "Adam"
# OPTIMIZER = "Adam_decay"


# `file_path` is the checkpoint file stem and `_model_name` the run directory;
# the two differ only in that `_model_name` also embeds the dataset name.
file_path = (
    f"MyResNet34_{BATCH}_{OPTIMIZER}"
    if DATASET == "ImageNet2012"
    else f"MyResNet{NUM_LAYERS_LEVEL*6+2}_{BATCH}_{OPTIMIZER}"
)
_model_name = (
    f"MyResNet34_{DATASET}_{BATCH}_{OPTIMIZER}"
    if DATASET == "ImageNet2012"
    else f"MyResNet{NUM_LAYERS_LEVEL*6+2}_{DATASET}_{BATCH}_{OPTIMIZER}"
)

# Optional per-case suffix for ImageNet runs (currently disabled):
# casenum = 1
# if DATASET == "ImageNet2012" and casenum != 0:
#     _model_name += f"_case{casenum}"

# A non-zero split ratio is appended to both names as an integer percentage.
if SPLIT_RATIO != 0:
    file_path = f"{file_path}_{int(SPLIT_RATIO*100)}"
    _model_name = f"{_model_name}_{int(SPLIT_RATIO*100)}"

In [16]:
# Load the selected dataset. Only the validation and test splits are used in
# this notebook, so the first value returned by Unpack() is discarded.
# NOTE(review): `seceted_dataset` is kept exactly as spelled at this call site;
# presumably it matches the keyword in LoadDataset's signature.
tmp = LoadDataset(root="data", seceted_dataset=DATASET, split_ratio=SPLIT_RATIO)
_, valid_data, test_data, COUNT_OF_CLASSES = tmp.Unpack()


def _build_eval_dataloader(dataset):
    """Wrap ``dataset`` in a DataLoader configured from the notebook constants.

    Args:
        dataset: A dataset object, or ``None`` when the split is not available.

    Returns:
        A ``DataLoader`` over ``dataset``, or ``None`` if ``dataset`` is ``None``.
    """
    if dataset is None:
        return None
    return DataLoader(
        dataset,
        batch_size=BATCH,
        shuffle=SHUFFLE,
        num_workers=NUMOFWORKERS,
        pin_memory=PIN_MEMORY,
        # pin_memory_device="cuda",
        # DataLoader rejects persistent_workers=True when num_workers is 0,
        # so only request persistent workers when workers actually exist.
        persistent_workers=NUMOFWORKERS > 0,
    )


valid_dataloader = _build_eval_dataloader(valid_data)
if valid_dataloader is not None:
    print("valid.transforms =", valid_data.transform, valid_dataloader.batch_size)

test_dataloader = _build_eval_dataloader(test_data)
if test_dataloader is not None:
    print("test.transforms =", test_data.transform, test_dataloader.batch_size)

# Evaluation loader selection: use the test split when it exists, otherwise
# fall back to the validation split. Having neither is an error, because
# there would be nothing to evaluate on.
if test_dataloader is not None:
    eval_dataloader = test_dataloader
elif valid_dataloader is not None:
    eval_dataloader = valid_dataloader
else:
    raise ValueError("valid_dataloader and test_dataloader cannot be None at the same time")

-----------------------------------------------------------------------
Dataset :  ImageNet2012
- Length of Train Set :  1281167
- Length of Valid Set :  50000
- Count of Classes :  1000
-----------------------------------------------------------------------
valid.transforms = Compose(
      RandomShortestSize(min_size=[256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391

In [17]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture that matches the selected dataset so that the
# checkpoint's parameter names line up with the freshly created model.
if DATASET == "ImageNet2012":
    model = MyResNet34(num_classes=COUNT_OF_CLASSES, Downsample_option="B").to(device)
else:
    model = MyResNet_CIFAR(
        num_classes=COUNT_OF_CLASSES, num_layer_factor=NUM_LAYERS_LEVEL
    ).to(device)

# map_location remaps the stored tensors onto `device`; without it a
# checkpoint saved from a CUDA run cannot be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
# Kept as the last expression so the key-matching result is displayed.
model.load_state_dict(
    torch.load(f"models/{_model_name}/{file_path}.pth", map_location=device)
)

<All keys matched successfully>

In [18]:
# Evaluation only: no optimizer is supplied, and the single call made on the
# wrapper is Forward_eval over the loader chosen earlier in the notebook.
testingclass = DoTraining(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=None,
    device=device,
)

# Forward_eval returns a (loss, accuracy) pair for the given dataloader.
test_loss, test_acc = testingclass.Forward_eval(eval_dataloader)

eval: 100%|██████████| 196/196 [01:14<00:00,  2.64it/s]


In [19]:
# Loss value returned by Forward_eval for the evaluation dataloader.
print("test_loss:", test_loss)

test_loss: 1.2982161787091469


In [20]:
# test_acc is scaled by 100 here and shown as a percentage with two decimals.
print(f"test_acc: {test_acc*100:.2f}%")


test_acc: 72.40%


In [21]:
# Error rate: 100 minus the accuracy percentage, shown with two decimals.
print(f"test_error: {100 - test_acc*100:.2f}%")

test_error: 27.60%


In [22]:
# Display the checkpoint file stem that was evaluated, for the record.
file_path

'MyResNet34_256_SGD'