In [1]:
# Train CAM model with pretrained model in this case, resnet18 and resnet50
# Because data is not fit with resnet back data, imagenet

In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torch.optim import lr_scheduler
from PIL import Image
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
import json
from tqdm import tqdm
import sys
sys.path.append('../')
from random_state import set_seed
sys.path.remove('../')
import time

local_path = 'C:/Users/MZC01-HYUNHOPART/Desktop/PROJECT/PAPER/HAM10000'

# SET RANDOMSEED
set_seed(seed=42)

# Get Mean and Std about MAM10000
with open(f'{local_path}/HAM10000_MeanStd.json', 'r') as ms_json:
    MeanStd = json.load(ms_json)
    
import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

print(f'IMPORTED....')

RANDOM STATE 42 IS STAIBLIZED....
IMPORTED....


In [2]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, label_dir, transform=None, augmented_images=None, augmented_labels=None):
        self.root_dir = root_dir
        self.metadata = pd.read_csv(label_dir)
        self.transform = transform

        # Create a mapping from string labels to integers
        self.label_mapping = {str(label): idx for idx, label in enumerate(self.metadata.iloc[:, 1].unique())}
        
        # Print unique labels for debugging
        print("Unique Labels in Dataset:", self.label_mapping)

        # Augmented data
        self.augmented_images = augmented_images if augmented_images is not None else []
        self.augmented_labels = augmented_labels if augmented_labels is not None else []

    def __len__(self):
        # Return the total number of samples, including both original and augmented
        return len(self.metadata) + len(self.augmented_images)

    def unique_class(self):
        return len(self.metadata.iloc[:, 1].unique())

    def __getitem__(self, idx):
        if idx < len(self.metadata):
            # Original data
            img_name = self.metadata.iloc[idx, 0]
            label_str = self.metadata.iloc[idx, 1]
            label = self.label_mapping[label_str]
            img_path = os.path.join(self.root_dir, img_name + '.jpg')
            image = Image.open(img_path).convert('RGB')
        else:
            # Augmented data
            idx -= len(self.metadata)
            image = self.augmented_images[idx]
            label = self.augmented_labels[idx]
            
        if self.transform:
            image = self.transform(image)

        return image, label


def load_finetune_dataset(data_dir:str, label_dir:str, batch_size:int=32, 
                          MeanStd:dict=MeanStd, num_workers:int=0, apply_augment:bool=True, apply_random:bool=True):
    '''
    Load target fintune dataset to fine tune on pre-trained model
    Argument
        - data_dir: The directory containing the dataset.
        - label_dir: The directory containing label data in csv file.
        - batch_size: The batch size for loading the dataset.    
        - meanstd : Computed the HAM10000 mean and standard deviation value  
    '''
    transform = transforms.Compose([
        transforms.Resize((224, 224)) # ImageNet Trained size
        , transforms.ToTensor()
        , transforms.Normalize(mean=MeanStd['Mean'], std=MeanStd['Std'])
    ])

    dataset = CustomDataset(data_dir, label_dir, transform=transform)
    unique_class = dataset.unique_class()

    # Split the dataset into training, validation and test sets
    train_size = int(0.8 * len(dataset))
    valid_size = int(0.1 * len(dataset))
    test_size  = len(dataset) - train_size - valid_size
        
    # Split train, valid, and test data
    train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size, test_size])

    if apply_augment:
        # Apply data augmentation only to the training dataset
        augmented_train_images, augmented_train_labels, randomness = apply_augmentation(train_dataset, randomness=apply_random)

        # Extend the training dataset with augmented images
        train_dataset.dataset.augmented_images = augmented_train_images
        train_dataset.dataset.augmented_labels = augmented_train_labels
        train_dataset.dataset.metadata = pd.concat([train_dataset.dataset.metadata,
                                                    pd.DataFrame({'Image': [f'image_{i}' for i in range(len(augmented_train_images))],
                                                                'Label': augmented_train_labels})],
                                                ignore_index=True)
    else:
        randomness = False

    # Print information
    print(f'Training dataset size: {len(train_dataset.dataset)}')
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True , num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
    return train_loader, valid_loader, test_loader, unique_class, apply_augment, randomness

def apply_augmentation(dataset, randomness:bool=True):
    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(30),
    ]

    augmented_images = []
    augmented_labels = []
    
    
    for image, label in tqdm(dataset):
        if randomness:
            # Apply a random augmentation to each image
            augmentation = random.choice(augmentations)
            augmented_images.append(transforms.ToPILImage()(augmentation(image)))
            augmented_labels.append(label)

        else :
            # Apply all augmentations to each image
            augmented_images.extend([transforms.ToPILImage()(augmentation(image)) for augmentation in augmentations])
            augmented_labels.extend([label] * len(augmentations))
    
    return augmented_images, augmented_labels, randomness

In [3]:
class EarlyStopping:
    def __init__(self, patience:int=5, threshold:float=0.1):
        self.patience  = patience
        self.threshold = threshold
        self.counter   = 0
        self.best_score= 0
        self.early_stop= False
    
    def __call__(self, val_accuracy):
        score = val_accuracy
        
        if score > self.best_score + self.threshold:
            self.counter   = 0
            self.best_score= score
        
        else:
            self.counter+=1
            if self.counter >= self.patience:
                self.early_stop = True
        
        return self.early_stop

In [4]:

%%time

train_loader, valid_loader, test_loader, unique_class, apply_augment, randomness = load_finetune_dataset(
    data_dir=f'{local_path}/images/', 
    label_dir=f'{local_path}/HAM10000_label.csv',
    batch_size=64, MeanStd=MeanStd, num_workers=0, apply_augment=False, apply_random=True,
    )

print(f'DATA LOADER IS COMPLETED....\n')

Unique Labels in Dataset: {'bkl': 0, 'nv': 1, 'df': 2, 'mel': 3, 'vasc': 4, 'bcc': 5, 'akiec': 6}
Training dataset size: 10015
DATA LOADER IS COMPLETED....

CPU times: total: 0 ns
Wall time: 14 ms


In [5]:
model_architecture = str(input("Please select model architecture betweeb resnet18, resnet50 and resnet101."))

if model_architecture not in ['resnet18', 'resnet50', 'resnet101']:
    raise ValueError(f"Model architecture '{model_architecture}' should be 'resnet18', 'resnet50' and 'resnet101'.")
else:
    print(f'Model set as {model_architecture}......')

Model set as resnet50......


In [6]:

# Trainig Loop
num_epochs = int(input("How many times to fine tune the model? e.g., 10, 20, 30"))

print(f'MODEL WILL FINE TUNE BY {num_epochs} TIMES.....')

MODEL WILL FINE TUNE BY 100 TIMES.....


In [7]:
# Load pre-trained Resnet18 and Resnet50 for training

model_class = getattr(models, model_architecture)
model = model_class(weights=True)
model.fc = nn.Linear(model.fc.in_features, unique_class)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4) # can change sgd for computation power
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Use GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(f'PYTORCH DEVICE : {device}')


PYTORCH DEVICE : cuda:0



In [8]:
%%time

#early_stopping = EarlyStopping(patience=10, threshold=0.1)

# Training
for epoch in tqdm(range(num_epochs), desc="Training Progress", unit="epoch"):
    model.train()
    total_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Average training loss for the epoch
    average_loss = total_loss / len(train_loader)
    
    # update the learning rate after every epoch
    scheduler.step()
        
    # Validation check by every step and last one
    step:int = 1
    if (epoch % step == step-1) or (epoch == num_epochs-1):
        model.eval()
        all_preds, all_labels = [], []
        
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        # Calculate validation accuracy
        val_accuracy = accuracy_score(all_labels, all_preds) * 100
        tqdm.write(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {average_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%...')

        # Check the early stopping
        # if early_stopping(val_accuracy):
        #     print('Early Stopped....')
        #     break

Training Progress:   1%|          | 1/100 [12:25<20:30:09, 745.56s/epoch]

Epoch 1/100, Training Loss: 0.8635, Validation Accuracy: 70.03%...


Training Progress:   2%|▏         | 2/100 [26:16<21:39:31, 795.63s/epoch]

Epoch 2/100, Training Loss: 0.7430, Validation Accuracy: 41.16%...


Training Progress:   3%|▎         | 3/100 [39:56<21:44:21, 806.82s/epoch]

Epoch 3/100, Training Loss: 0.7454, Validation Accuracy: 73.53%...


Training Progress:   4%|▍         | 4/100 [53:27<21:33:54, 808.69s/epoch]

Epoch 4/100, Training Loss: 0.6504, Validation Accuracy: 75.02%...


Training Progress:   5%|▌         | 5/100 [1:07:07<21:26:43, 812.67s/epoch]

Epoch 5/100, Training Loss: 0.6010, Validation Accuracy: 70.23%...


Training Progress:   6%|▌         | 6/100 [1:20:46<21:16:38, 814.88s/epoch]

Epoch 6/100, Training Loss: 0.5030, Validation Accuracy: 80.32%...


Training Progress:   7%|▋         | 7/100 [1:34:26<21:05:17, 816.32s/epoch]

Epoch 7/100, Training Loss: 0.4320, Validation Accuracy: 80.52%...


Training Progress:   8%|▊         | 8/100 [1:47:59<20:50:27, 815.52s/epoch]

Epoch 8/100, Training Loss: 0.3967, Validation Accuracy: 80.82%...


Training Progress:   9%|▉         | 9/100 [2:01:43<20:40:44, 818.07s/epoch]

Epoch 9/100, Training Loss: 0.3468, Validation Accuracy: 81.92%...


Training Progress:  10%|█         | 10/100 [2:15:27<20:29:49, 819.88s/epoch]

Epoch 10/100, Training Loss: 0.2843, Validation Accuracy: 79.62%...


Training Progress:  11%|█         | 11/100 [2:29:08<20:16:43, 820.26s/epoch]

Epoch 11/100, Training Loss: 0.1880, Validation Accuracy: 81.62%...


Training Progress:  12%|█▏        | 12/100 [2:42:38<19:58:11, 816.95s/epoch]

Epoch 12/100, Training Loss: 0.1519, Validation Accuracy: 82.02%...


Training Progress:  13%|█▎        | 13/100 [2:56:34<19:53:11, 822.89s/epoch]

Epoch 13/100, Training Loss: 0.1292, Validation Accuracy: 81.22%...


Training Progress:  14%|█▍        | 14/100 [3:10:32<19:45:57, 827.42s/epoch]

Epoch 14/100, Training Loss: 0.1158, Validation Accuracy: 82.02%...


Training Progress:  15%|█▌        | 15/100 [3:24:29<19:36:26, 830.43s/epoch]

Epoch 15/100, Training Loss: 0.1014, Validation Accuracy: 82.02%...


Training Progress:  16%|█▌        | 16/100 [3:38:00<19:14:09, 824.40s/epoch]

Epoch 16/100, Training Loss: 0.0811, Validation Accuracy: 81.82%...


Training Progress:  17%|█▋        | 17/100 [3:51:35<18:56:46, 821.77s/epoch]

Epoch 17/100, Training Loss: 0.0813, Validation Accuracy: 82.02%...


Training Progress:  18%|█▊        | 18/100 [4:05:10<18:39:56, 819.47s/epoch]

Epoch 18/100, Training Loss: 0.0802, Validation Accuracy: 81.72%...


Training Progress:  19%|█▉        | 19/100 [4:18:28<18:17:53, 813.25s/epoch]

Epoch 19/100, Training Loss: 0.0887, Validation Accuracy: 81.82%...


Training Progress:  20%|██        | 20/100 [4:32:02<18:04:36, 813.46s/epoch]

Epoch 20/100, Training Loss: 0.0751, Validation Accuracy: 81.72%...


Training Progress:  21%|██        | 21/100 [4:45:38<17:51:59, 814.17s/epoch]

Epoch 21/100, Training Loss: 0.0787, Validation Accuracy: 81.92%...


Training Progress:  22%|██▏       | 22/100 [4:59:07<17:36:26, 812.65s/epoch]

Epoch 22/100, Training Loss: 0.0715, Validation Accuracy: 81.92%...


Training Progress:  23%|██▎       | 23/100 [5:12:29<17:18:45, 809.42s/epoch]

Epoch 23/100, Training Loss: 0.0749, Validation Accuracy: 81.42%...


Training Progress:  24%|██▍       | 24/100 [5:26:00<17:05:52, 809.90s/epoch]

Epoch 24/100, Training Loss: 0.0731, Validation Accuracy: 81.92%...


Training Progress:  25%|██▌       | 25/100 [5:39:36<16:54:40, 811.74s/epoch]

Epoch 25/100, Training Loss: 0.0753, Validation Accuracy: 81.92%...


Training Progress:  26%|██▌       | 26/100 [5:52:50<16:34:43, 806.53s/epoch]

Epoch 26/100, Training Loss: 0.0798, Validation Accuracy: 81.82%...


Training Progress:  27%|██▋       | 27/100 [6:06:24<16:23:58, 808.74s/epoch]

Epoch 27/100, Training Loss: 0.0740, Validation Accuracy: 81.22%...


Training Progress:  28%|██▊       | 28/100 [6:20:17<16:19:07, 815.94s/epoch]

Epoch 28/100, Training Loss: 0.0759, Validation Accuracy: 82.42%...


Training Progress:  29%|██▉       | 29/100 [6:34:18<16:14:23, 823.43s/epoch]

Epoch 29/100, Training Loss: 0.0785, Validation Accuracy: 82.52%...


Training Progress:  30%|███       | 30/100 [6:48:11<16:04:00, 826.30s/epoch]

Epoch 30/100, Training Loss: 0.0720, Validation Accuracy: 81.72%...


Training Progress:  31%|███       | 31/100 [7:02:06<15:53:22, 829.02s/epoch]

Epoch 31/100, Training Loss: 0.0751, Validation Accuracy: 82.22%...


Training Progress:  32%|███▏      | 32/100 [7:16:04<15:42:27, 831.58s/epoch]

Epoch 32/100, Training Loss: 0.0727, Validation Accuracy: 81.82%...


Training Progress:  33%|███▎      | 33/100 [7:30:00<15:30:10, 832.99s/epoch]

Epoch 33/100, Training Loss: 0.0799, Validation Accuracy: 81.72%...


Training Progress:  34%|███▍      | 34/100 [7:43:49<15:14:44, 831.58s/epoch]

Epoch 34/100, Training Loss: 0.0758, Validation Accuracy: 82.02%...


Training Progress:  35%|███▌      | 35/100 [7:57:23<14:55:12, 826.35s/epoch]

Epoch 35/100, Training Loss: 0.0828, Validation Accuracy: 81.72%...


Training Progress:  36%|███▌      | 36/100 [8:11:07<14:40:39, 825.62s/epoch]

Epoch 36/100, Training Loss: 0.0705, Validation Accuracy: 82.12%...


Training Progress:  37%|███▋      | 37/100 [8:25:20<14:35:32, 833.85s/epoch]

Epoch 37/100, Training Loss: 0.0747, Validation Accuracy: 82.22%...


Training Progress:  38%|███▊      | 38/100 [8:38:57<14:16:39, 829.02s/epoch]

Epoch 38/100, Training Loss: 0.0730, Validation Accuracy: 82.22%...


Training Progress:  39%|███▉      | 39/100 [8:52:21<13:55:06, 821.42s/epoch]

Epoch 39/100, Training Loss: 0.0756, Validation Accuracy: 82.42%...


Training Progress:  40%|████      | 40/100 [9:05:45<13:36:10, 816.17s/epoch]

Epoch 40/100, Training Loss: 0.0753, Validation Accuracy: 82.12%...


Training Progress:  41%|████      | 41/100 [9:19:18<13:21:33, 815.15s/epoch]

Epoch 41/100, Training Loss: 0.0746, Validation Accuracy: 82.12%...


Training Progress:  42%|████▏     | 42/100 [9:33:15<13:14:31, 821.93s/epoch]

Epoch 42/100, Training Loss: 0.0843, Validation Accuracy: 82.32%...


Training Progress:  43%|████▎     | 43/100 [9:47:16<13:05:59, 827.36s/epoch]

Epoch 43/100, Training Loss: 0.0796, Validation Accuracy: 82.32%...


Training Progress:  44%|████▍     | 44/100 [10:01:10<12:54:15, 829.57s/epoch]

Epoch 44/100, Training Loss: 0.0754, Validation Accuracy: 81.52%...


Training Progress:  45%|████▌     | 45/100 [10:14:34<12:33:21, 821.85s/epoch]

Epoch 45/100, Training Loss: 0.0741, Validation Accuracy: 81.92%...


Training Progress:  46%|████▌     | 46/100 [10:28:09<12:17:43, 819.70s/epoch]

Epoch 46/100, Training Loss: 0.0758, Validation Accuracy: 81.82%...


Training Progress:  47%|████▋     | 47/100 [10:41:38<12:01:13, 816.48s/epoch]

Epoch 47/100, Training Loss: 0.0764, Validation Accuracy: 81.52%...


Training Progress:  48%|████▊     | 48/100 [10:54:59<11:43:43, 811.99s/epoch]

Epoch 48/100, Training Loss: 0.0699, Validation Accuracy: 81.32%...


Training Progress:  49%|████▉     | 49/100 [11:08:19<11:27:04, 808.33s/epoch]

Epoch 49/100, Training Loss: 0.0710, Validation Accuracy: 81.72%...


Training Progress:  50%|█████     | 50/100 [11:21:45<11:13:05, 807.72s/epoch]

Epoch 50/100, Training Loss: 0.0746, Validation Accuracy: 81.32%...


Training Progress:  51%|█████     | 51/100 [11:35:08<10:58:19, 806.11s/epoch]

Epoch 51/100, Training Loss: 0.0720, Validation Accuracy: 81.62%...


Training Progress:  52%|█████▏    | 52/100 [11:48:29<10:43:48, 804.75s/epoch]

Epoch 52/100, Training Loss: 0.0799, Validation Accuracy: 81.72%...


Training Progress:  53%|█████▎    | 53/100 [12:01:58<10:31:15, 805.86s/epoch]

Epoch 53/100, Training Loss: 0.0768, Validation Accuracy: 81.52%...


Training Progress:  54%|█████▍    | 54/100 [12:15:44<10:22:36, 812.09s/epoch]

Epoch 54/100, Training Loss: 0.0753, Validation Accuracy: 82.02%...


Training Progress:  55%|█████▌    | 55/100 [12:29:26<10:11:18, 815.08s/epoch]

Epoch 55/100, Training Loss: 0.0753, Validation Accuracy: 81.52%...


Training Progress:  56%|█████▌    | 56/100 [12:42:38<9:52:30, 807.97s/epoch] 

Epoch 56/100, Training Loss: 0.0796, Validation Accuracy: 82.42%...


Training Progress:  57%|█████▋    | 57/100 [12:55:58<9:37:23, 805.66s/epoch]

Epoch 57/100, Training Loss: 0.0735, Validation Accuracy: 81.82%...


Training Progress:  58%|█████▊    | 58/100 [13:09:27<9:24:44, 806.77s/epoch]

Epoch 58/100, Training Loss: 0.0765, Validation Accuracy: 81.82%...


Training Progress:  59%|█████▉    | 59/100 [13:22:57<9:11:50, 807.57s/epoch]

Epoch 59/100, Training Loss: 0.0760, Validation Accuracy: 82.02%...


Training Progress:  60%|██████    | 60/100 [13:36:28<8:59:04, 808.61s/epoch]

Epoch 60/100, Training Loss: 0.0768, Validation Accuracy: 81.82%...


Training Progress:  61%|██████    | 61/100 [13:49:56<8:45:33, 808.56s/epoch]

Epoch 61/100, Training Loss: 0.0735, Validation Accuracy: 81.82%...


Training Progress:  62%|██████▏   | 62/100 [14:03:32<8:33:31, 810.82s/epoch]

Epoch 62/100, Training Loss: 0.0775, Validation Accuracy: 82.02%...


Training Progress:  63%|██████▎   | 63/100 [14:17:17<8:22:36, 815.05s/epoch]

Epoch 63/100, Training Loss: 0.0770, Validation Accuracy: 81.92%...


Training Progress:  64%|██████▍   | 64/100 [14:30:56<8:09:39, 816.11s/epoch]

Epoch 64/100, Training Loss: 0.0753, Validation Accuracy: 81.72%...


Training Progress:  65%|██████▌   | 65/100 [14:44:32<7:56:01, 816.04s/epoch]

Epoch 65/100, Training Loss: 0.0710, Validation Accuracy: 82.32%...


Training Progress:  66%|██████▌   | 66/100 [14:58:09<7:42:37, 816.40s/epoch]

Epoch 66/100, Training Loss: 0.0757, Validation Accuracy: 82.52%...


Training Progress:  67%|██████▋   | 67/100 [15:11:31<7:26:43, 812.21s/epoch]

Epoch 67/100, Training Loss: 0.0732, Validation Accuracy: 82.02%...


Training Progress:  68%|██████▊   | 68/100 [15:23:49<7:01:17, 789.93s/epoch]

Epoch 68/100, Training Loss: 0.0802, Validation Accuracy: 81.42%...


Training Progress:  69%|██████▉   | 69/100 [15:33:41<6:17:24, 730.46s/epoch]

Epoch 69/100, Training Loss: 0.0763, Validation Accuracy: 82.22%...


Training Progress:  70%|███████   | 70/100 [15:43:48<5:46:43, 693.47s/epoch]

Epoch 70/100, Training Loss: 0.0709, Validation Accuracy: 81.92%...


Training Progress:  71%|███████   | 71/100 [15:53:40<5:20:27, 663.03s/epoch]

Epoch 71/100, Training Loss: 0.0762, Validation Accuracy: 81.82%...


Training Progress:  72%|███████▏  | 72/100 [16:06:38<5:25:24, 697.29s/epoch]

Epoch 72/100, Training Loss: 0.0717, Validation Accuracy: 81.82%...


Training Progress:  73%|███████▎  | 73/100 [16:20:11<5:29:26, 732.09s/epoch]

Epoch 73/100, Training Loss: 0.0775, Validation Accuracy: 82.12%...


Training Progress:  74%|███████▍  | 74/100 [16:33:21<5:24:44, 749.40s/epoch]

Epoch 74/100, Training Loss: 0.0724, Validation Accuracy: 81.92%...


Training Progress:  75%|███████▌  | 75/100 [16:47:11<5:22:25, 773.83s/epoch]

Epoch 75/100, Training Loss: 0.0735, Validation Accuracy: 82.22%...


Training Progress:  76%|███████▌  | 76/100 [17:01:25<5:19:09, 797.89s/epoch]

Epoch 76/100, Training Loss: 0.0754, Validation Accuracy: 81.92%...


Training Progress:  77%|███████▋  | 77/100 [17:12:36<4:51:13, 759.73s/epoch]

Epoch 77/100, Training Loss: 0.0760, Validation Accuracy: 81.52%...


Training Progress:  78%|███████▊  | 78/100 [17:26:53<4:49:11, 788.73s/epoch]

Epoch 78/100, Training Loss: 0.0753, Validation Accuracy: 82.02%...


Training Progress:  79%|███████▉  | 79/100 [17:40:11<4:37:03, 791.57s/epoch]

Epoch 79/100, Training Loss: 0.0786, Validation Accuracy: 82.02%...


Training Progress:  80%|████████  | 80/100 [17:52:29<4:18:29, 775.49s/epoch]

Epoch 80/100, Training Loss: 0.0762, Validation Accuracy: 81.92%...


Training Progress:  81%|████████  | 81/100 [18:03:12<3:52:58, 735.69s/epoch]

Epoch 81/100, Training Loss: 0.0776, Validation Accuracy: 81.82%...


Training Progress:  82%|████████▏ | 82/100 [18:16:07<3:44:18, 747.70s/epoch]

Epoch 82/100, Training Loss: 0.0761, Validation Accuracy: 82.12%...


Training Progress:  83%|████████▎ | 83/100 [18:29:44<3:37:44, 768.53s/epoch]

Epoch 83/100, Training Loss: 0.0791, Validation Accuracy: 82.02%...


Training Progress:  84%|████████▍ | 84/100 [18:43:24<3:29:03, 783.95s/epoch]

Epoch 84/100, Training Loss: 0.0751, Validation Accuracy: 81.02%...


Training Progress:  85%|████████▌ | 85/100 [18:56:51<3:17:42, 790.86s/epoch]

Epoch 85/100, Training Loss: 0.0761, Validation Accuracy: 81.92%...


Training Progress:  86%|████████▌ | 86/100 [19:09:46<3:03:24, 786.04s/epoch]

Epoch 86/100, Training Loss: 0.0758, Validation Accuracy: 82.02%...


Training Progress:  87%|████████▋ | 87/100 [19:23:48<2:53:56, 802.78s/epoch]

Epoch 87/100, Training Loss: 0.0729, Validation Accuracy: 81.82%...


Training Progress:  88%|████████▊ | 88/100 [19:38:17<2:44:32, 822.70s/epoch]

Epoch 88/100, Training Loss: 0.0729, Validation Accuracy: 82.32%...


Training Progress:  89%|████████▉ | 89/100 [19:49:15<2:21:46, 773.36s/epoch]

Epoch 89/100, Training Loss: 0.0764, Validation Accuracy: 82.32%...


Training Progress:  90%|█████████ | 90/100 [19:59:46<2:01:45, 730.58s/epoch]

Epoch 90/100, Training Loss: 0.0727, Validation Accuracy: 82.22%...


Training Progress:  91%|█████████ | 91/100 [20:10:08<1:44:42, 698.01s/epoch]

Epoch 91/100, Training Loss: 0.0775, Validation Accuracy: 81.52%...


Training Progress:  92%|█████████▏| 92/100 [20:20:32<1:30:04, 675.62s/epoch]

Epoch 92/100, Training Loss: 0.0750, Validation Accuracy: 81.32%...


Training Progress:  93%|█████████▎| 93/100 [20:30:59<1:17:08, 661.26s/epoch]

Epoch 93/100, Training Loss: 0.0723, Validation Accuracy: 81.72%...


Training Progress:  94%|█████████▍| 94/100 [20:41:54<1:05:55, 659.24s/epoch]

Epoch 94/100, Training Loss: 0.0749, Validation Accuracy: 81.52%...


Training Progress:  95%|█████████▌| 95/100 [20:55:36<59:00, 708.14s/epoch]  

Epoch 95/100, Training Loss: 0.0749, Validation Accuracy: 81.32%...


Training Progress:  96%|█████████▌| 96/100 [21:09:13<49:23, 740.83s/epoch]

Epoch 96/100, Training Loss: 0.0727, Validation Accuracy: 82.22%...


Training Progress:  97%|█████████▋| 97/100 [21:24:18<39:30, 790.16s/epoch]

Epoch 97/100, Training Loss: 0.0755, Validation Accuracy: 81.92%...


Training Progress:  98%|█████████▊| 98/100 [21:40:47<28:19, 849.55s/epoch]

Epoch 98/100, Training Loss: 0.0714, Validation Accuracy: 82.02%...


Training Progress:  99%|█████████▉| 99/100 [21:54:54<14:08, 848.83s/epoch]

Epoch 99/100, Training Loss: 0.0723, Validation Accuracy: 82.02%...


Training Progress: 100%|██████████| 100/100 [22:09:05<00:00, 797.46s/epoch]

Epoch 100/100, Training Loss: 0.0739, Validation Accuracy: 82.12%...
CPU times: total: 8h 36min 18s
Wall time: 22h 9min 5s





In [9]:
# Test
model.eval()
test_preds, test_labels = [], []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        
        test_preds.extend(preds.cpu().numpy())
        test_labels.extend(labels.cpu().numpy())

# Calculate test accuracy
test_accuracy = accuracy_score(test_labels, test_preds)
print(f'Test Accuracy : {test_accuracy:.2f}%...')

Test Accuracy : 0.80%...


In [10]:
# Save the fine-tuned model

if apply_augment:
    pth_name = f'./fine_tune_model/{model_architecture}_iter_{num_epochs}_aug_{apply_augment}_rand_{randomness}_test_{test_accuracy:.2f}%.pth'
else:
    pth_name = f'./fine_tune_model/{model_architecture}_iter_{num_epochs}_aug_{apply_augment}_test_{test_accuracy:.2f}%.pth'    

torch.save(model.state_dict(), pth_name)
    
print(f'Pretrained model saved.....')

Pretrained model saved.....
