In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.brain as fob
from tqdm.notebook import tqdm, trange
from PIL import Image
import matplotlib.pyplot as plt
import torchsummary
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
# import resnet18 model from pytorch
from torchvision.models import resnet18
from torch.utils.tensorboard import SummaryWriter
import mxnet as mx
from mxnet import recordio
import torch.multiprocessing as mp
from sklearn.model_selection import train_test_split
from collections import defaultdict
import logging

In [2]:
DIM = (112, 112)
BS = 256
EPOCHS = 60
LR = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
NUM_CLASSES = 10572
NUM_WORKERS = 4
LOG_INTERVAL = 6

In [3]:
class CASIAWebFaceDataset(Dataset):
    def __init__(self, path_imgrec, transform=None):
        self.transform = transform
        assert path_imgrec
        if path_imgrec:
            logging.info('loading recordio %s...',
                         path_imgrec)
            path_imgidx = path_imgrec[0:-4] + ".idx"
            print(path_imgrec, path_imgidx)
            self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
            s = self.imgrec.read_idx(0)
            header, _ = recordio.unpack(s)
            if header.flag > 0:
                print('header0 label', header.label)
                self.header0 = (int(header.label[0]), int(header.label[1]))
                # assert(header.flag==1)
                # self.imgidx = range(1, int(header.label[0]))
                self.imgidx = []
                self.id2range = {}
                self.seq_identity = range(int(header.label[0]), int(header.label[1]))
                for identity in self.seq_identity:
                    s = self.imgrec.read_idx(identity)
                    header, _ = recordio.unpack(s)
                    a, b = int(header.label[0]), int(header.label[1])
                    count = b - a
                    self.id2range[identity] = (a, b)
                    self.imgidx += range(a, b)
                print('id2range', len(self.id2range))
            else:
                self.imgidx = list(self.imgrec.keys)
            self.seq = self.imgidx

    def __getitem__(self, idx):
        # Map global index to class ID and local index
        actual_idx = idx + 1  # MXNet indices start from 1
        
        # Read record
        header, s = recordio.unpack(self.imgrec.read_idx(actual_idx))
        img = mx.image.imdecode(s).asnumpy()
        label = int(header.label)
        
        # # Convert to PIL and apply transforms
        img = Image.fromarray(img)
        if self.transform:
            img = self.transform(img)
        
        return img, label

    def __len__(self):
        return len(self.seq)

In [4]:
# dataset = CASIAWebFaceDataset(path_imgrec='./faces_webface_112x112/train.rec', transform=transforms.Compose([
#     transforms.Resize(112),
#     transforms.ToTensor()
# ]))
# len(dataset)
# print(dataset[0])

In [5]:
class CustomNormalize:
    def __call__(self, img):
        # Convert PIL image to tensor
        img = transforms.ToTensor()(img)
        # Subtract 128 and divide by 128
        img = (img * 255.0 - 128) / 128.0
        return img

train_preprocess = transforms.Compose([
    transforms.Resize(DIM),
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    transforms.RandomRotation(30),  # Randomly rotate the image by up to 10 degrees
    # transforms.RandomVerticalFlip(),  # Randomly flip the image vertically
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # Randomly change brightness, contrast, saturation and hue
    CustomNormalize()
])

test_preprocess = transforms.Compose([
    transforms.Resize(DIM),
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    CustomNormalize()
])

In [6]:
class AMSoftmax(nn.Module):
    '''
    The am softmax as seen on https://arxiv.org/pdf/1801.05599.pdf,

        in_features: size of the embedding, eg. 512
        n_classes: number of classes on the classification task
        s: s parameter of loss, standard = 30.
        m: m parameter of loss, standard = 0.4, best between 0.35 and 0.4 according to paper.

        *inputs: tensor shaped (batch_size X embedding_size)
        output : tensor shaped (batch_size X n_classes) AM_softmax logits for NLL_loss.

    '''
    def __init__(self, in_features, n_classes, s=30, m=0.4):
        super(AMSoftmax, self).__init__()
        self.linear = nn.Linear(in_features, n_classes, bias=False)
        self.s = s
        self.m = m

    def forward(self, *inputs):
        x_vector = F.normalize(inputs[0], p=2, dim=-1)
        self.linear.weight.data = F.normalize(self.linear.weight.data, p=2, dim=-1)
        logits = self.linear(x_vector)
        scaled_logits = (logits - self.m)*self.s
        return  scaled_logits - self._am_logsumexp(logits)

    def _am_logsumexp(self, logits):
        '''
        logsumexp designed for am_softmax, the computation is numerically stable

        '''
        max_x = torch.max(logits, dim=-1)[0].unsqueeze(-1)
        term1 = (self.s*(logits - (max_x + self.m))).exp()
        term2 = (self.s * (logits - max_x)).exp().sum(-1).unsqueeze(-1) \
                - (self.s * (logits - max_x)).exp()
        return self.s*max_x + (term2 + term1).log()
    

class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # self.relu = nn.PReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        out = self.relu(out)
        return out


# How can I modify the model to output embeddings of size 128?
# 1. Create a new model that outputs embeddings
# 2. Modify the last layer of the model to output embeddings
# 3. Use a hook to extract embeddings from the model
# 4. Use a custom loss function to train the model

class ResNet18(nn.Module):
    def __init__(self, num_classes=10572, dropout=0.4):
        super(ResNet18, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(dropout)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        self.dropout(out)
        out = self.fc(out)
        out = F.normalize(out, p=2, dim=-1)
        return out


In [7]:
def test_classifier(model, classifier, data_loader, device, message):
    model.eval()
    classifier.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            embeddings = model(images)  # Extract features
            logits = classifier(embeddings)  # Compute AMSoftmax logits
            predictions = torch.argmax(logits, dim=1)  # Get class with max probability
            
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total * 100
    print(f"✅ Classification Accuracy for {message}: {accuracy:.2f}%")
    return accuracy

In [8]:
def train_AMSoftmax(model: nn.Module, classifier: nn.Module, data_loader: DataLoader, val_loader: DataLoader,
                    optimizer: optim.Optimizer, scheduler: optim.lr_scheduler, 
                    criterion: nn.Module, epochs: int, device: torch.device, 
                    retain_graph: bool, checkpoint_interval: int = 10):

    train_losses = []
    tmstmp = time.strftime("%Y%m%d-%H%M%S")
    best_loss = np.inf
    vacc = 0

    # tmstmp = "20250313-074759"

    log_dir = f"runs/{DIM[0]}x{DIM[1]}_ResNet18_AMSoftmax_{tmstmp}"
    writer = SummaryWriter(log_dir=log_dir)

    print(f"Started Training at {tmstmp}")
    
    for e, epoch in enumerate(tqdm(range(epochs), desc="Epochs")):
        model.train()
        classifier.train()
        running_loss = 0.0
        for i, (images, labels) in enumerate(tqdm(data_loader, desc="Batches")):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            embeddings = model(images)  # Extract embeddings from model
            logits = classifier(embeddings)  # Compute AMSoftmax logits
            loss = criterion(logits, labels)  # Compute NLL loss

            loss.backward(retain_graph=retain_graph)
            optimizer.step()
            running_loss += loss.item()
            # print(f"Batch {i+1}/{len(data_loader)} - Loss: {loss.item():.4f}")

        scheduler.step()
        avg_loss = running_loss / len(data_loader)
        train_losses.append(avg_loss)

        # Validation
        model.eval()
        running_loss = 0.0
        for i, (images, labels) in enumerate(tqdm(val_loader, desc="Validation")):
            images, labels = images.to(device), labels.to(device)
            embeddings = model(images)
            logits = classifier(embeddings)
            loss = criterion(logits, labels)
            running_loss += loss.item()
            # print(f"Batch {i+1}/{len(val_loader)} - Loss: {loss.item():.4f}")


        # Log per epoch
        writer.add_scalar('Loss/train', avg_loss, epoch)
        writer.add_scalar('Loss/val', running_loss / len(val_loader), epoch)

        # Checkpoint saving
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(log_dir, f'checkpoint_epoch_{epoch + 1}.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")
            train_acc = test_classifier(model, classifier, data_loader, device, "Training")
            val_acc = test_classifier(model, classifier, val_loader, device, "Validation")
            writer.add_scalar('Accuracy/train', train_acc, epoch)
            writer.add_scalar('Accuracy/val', val_acc, epoch)
            if val_acc > vacc:
                vacc = val_acc
                torch.save(model.state_dict(), f"{log_dir}/{DIM[0]}x{DIM[1]}_ResNet18_AMSoftmax_validation_{tmstmp}.pt")
                print(f"Saved best model with validation accuracy {vacc}")
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")
        # Save Best Model
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"{log_dir}/{DIM[0]}x{DIM[1]}_ResNet18_AMSoftmax_{tmstmp}.pt")
            print(f"Saved best model with loss: {best_loss:.4f}")

    print(f"Finished Training at {time.strftime('%Y%m%d-%H%M%S')} with best validation accuracy {vacc:.4f}")
    writer.close()

    return model, train_losses, tmstmp

In [9]:
def train_on_casia_webface(embedding_model, classifier, optimizer, device, device_ids):
    # Set up paths
    rec_path = "/home/ichitu/py-files/faces_webface_112x112/train.rec"

    
    # Load datasets
    print("Loading CASIA WebFace dataset...")
    casia_dataset = CASIAWebFaceDataset(
        path_imgrec=rec_path,
        transform=train_preprocess
    )

    train_idx, test_idx = train_test_split(range(len(casia_dataset)), test_size=0.2, random_state=42)
    train_dataset = torch.utils.data.Subset(casia_dataset, train_idx)
    test_dataset = torch.utils.data.Subset(casia_dataset, test_idx)
    # val_idx, test_idx = train_test_split(test_idx, test_size=0.7, random_state=42)
    # val_dataset = torch.utils.data.Subset(casia_dataset, val_idx)
    # test_dataset = torch.utils.data.Subset(casia_dataset, test_idx)

    print(len(train_dataset))
    # print(len(val_dataset))
    print(len(test_dataset))

    

    print(len(casia_dataset))
    
    # print("Loading LFW dataset...")
    # lfw_dataset = foz.load_zoo_dataset("lfw")
    
    # Find and filter overlapping identities
    # print("Finding overlapping identities...")
    # overlapping_ids = find_overlapping_identities(casia_dataset, lfw_dataset)
    
    # print("Creating filtered dataset...")
    # filtered_dataset = FilteredCASIADataset(casia_dataset, overlapping_ids)
    # filtered_dataset = casia_dataset
    
    # Create data loader
    # train_loader = DataLoader(
    #     train_dataset, 
    #     batch_size=BS * len(device_ids),
    #     shuffle=True, 
    #     num_workers=2,
    #     pin_memory=True
    # )

    # val_loader = DataLoader(
    #     val_dataset, 
    #     batch_size=BS * len(device_ids),
    #     shuffle=True, 
    #     num_workers=2,
    #     pin_memory=True
    # )

    # test_loader = DataLoader(
    #     test_dataset, 
    #     batch_size=BS * len(device_ids),
    #     shuffle=True, 
    #     num_workers=2,
    #     pin_memory=True
    # )

    train_loader = DataLoader(
        casia_dataset,
        batch_size=BS,
        num_workers=4,
        sampler=torch.utils.data.SubsetRandomSampler(train_idx),
        pin_memory=True
    )

    # val_loader = DataLoader(
    #     CASIAWebFaceDataset(
    #         path_imgrec=rec_path,
    #         transform=test_preprocess
    #     ),
    #     batch_size=BS,
    #     num_workers=2,
    #     sampler=torch.utils.data.SubsetRandomSampler(val_idx),
    #     pin_memory=True
    # )

    test_loader = DataLoader(
        CASIAWebFaceDataset(
            path_imgrec=rec_path,
            transform=test_preprocess
        ),
        batch_size=BS,
        num_workers=2,
        sampler=torch.utils.data.SubsetRandomSampler(test_idx),
        pin_memory=True
    )
    
    # Set up model
    print("Setting up model...")
    
    # Set up optimizer (include both models' parameters)
    # optimizer = optim.SGD(
    #     list(embedding_model.parameters()) + list(classifier.parameters()),
    #     lr=LR,
    #     momentum=MOMENTUM,
    #     weight_decay=WEIGHT_DECAY
    # )
    # optimizer = optim.Adam(
    #     embedding_model.parameters(),
    #     lr=0.001,
    #     weight_decay=5e-4,
    #     eps=1e-8
    # )
    optimizer = optim.SGD(
        list(embedding_model.parameters()) + list(classifier.parameters()),
        lr=optimizer['param_groups'][0]['lr'],
        momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY
    )

    # Set up scheduler
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-5)
    # scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=60, T_mult=1, eta_min=1e-5)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, 
    #                                            milestones=[16,24,30,50],
    #                                               gamma=0.2)
    # Set up loss
    criterion = nn.NLLLoss()

    device = torch.device("cuda:{}".format(device_ids[1][0]) if torch.cuda.is_available() else "cpu")
    
    # Train
    print("Starting training...")
    model, losses, timestamp = train_AMSoftmax(
        model=embedding_model,
        classifier=classifier,
        data_loader=train_loader,
        # val_loader=val_loader,
        val_loader=test_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        epochs=EPOCHS,  # Adjust as needed
        device=device,
        retain_graph=False,
        checkpoint_interval=LOG_INTERVAL
    )

    print("Testing model...")
    test_acc = test_classifier(embedding_model, classifier, test_loader, device, "Test")
    
    return model, losses, timestamp, test_acc

In [10]:
def load_model(model, classifier, check_point_dir, device='cuda'):
    check_point = torch.load(check_point_dir, map_location=device)
    state_dict = check_point['model_state_dict']
    # Create new OrderedDict without 'module.' prefix
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:] # remove 'module.' prefix
        else:
            name = k
        new_state_dict[name] = v
    
    # Load the weights
    model.load_state_dict(new_state_dict)
    # Set to evaluation mode
    # model.eval()
    # print(f"Model loaded from {path}")
    state_dict = check_point['classifier_state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:] # remove 'module.' prefix
        else:
            name = k
        new_state_dict[name] = v
    # Load the weights
    classifier.load_state_dict(new_state_dict)

    optimizer = torch.load(check_point_dir, map_location=device)['optimizer_state_dict']

    return model, classifier, optimizer

In [11]:
# assert 1==2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

check_point_dir = "/home/ichitu/py-files/runs/112x112_ResNet18_AMSoftmax_20250327-161025/checkpoint_epoch_60.pth"

device_ids = [[3],[3]]
classifier = AMSoftmax(512, 10572)
# embedding_model = EmbeddingResNet18(512)
embedding_model = ResNet18(512)
# embedding_model = Resface20(512)
embedding_model, classifier, optimizer = load_model(embedding_model, classifier, check_point_dir, 'cuda:{}'.format(device_ids[0][0]))
# embedding_model = load_model(embedding_model, '/home/ichitu/py-files/runs/112x96_ResNet18_AMSoftmax_20250322-161021/112x96_ResNet18_AMSoftmax_20250322-161021.pt', 'cuda:3')

if torch.cuda.device_count() > 1:
    print(f"Avaible {torch.cuda.device_count()} GPUs and using {device_ids}")
    embedding_model = nn.DataParallel(embedding_model, device_ids=device_ids[0])
    classifier = nn.DataParallel(classifier, device_ids=device_ids[1])

# embedding_model = embedding_model.to(device)
# classifier = classifier.to(device)
# embedding_model.load_state_dict(torch.load("Models-pt/112x96_ResNet18_AMSoftmax_20250313-074759.pt"))
print(embedding_model)
print(classifier)
print(optimizer)
# # assert False

# optimizer = torch.load(check_point_dir, map_location=device)['optimizer_state_dict']

embedding_model, train_losses, tmstmp, test_acc = train_on_casia_webface(embedding_model, classifier, optimizer, device, device_ids)


Avaible 4 GPUs and using [[3], [3]]
DataParallel(
  (module): ResNet18(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   

{'state': {0: {'momentum_buffer': tensor([[[[-2.9065e-12, -3.1499e-12, -1.6954e-12,  ..., -4.4676e-13,
            7.7209e-13,  2.6763e-12],
          [-4.9210e-12, -6.9758e-12, -5.9489e-12,  ..., -3.7476e-12,
           -9.0206e-13,  1.2005e-12],
          [-7.3464e-12, -9.5545e-12, -1.0954e-11,  ..., -8.0306e-12,
           -4.2895e-12, -3.2556e-13],
          ...,
          [-6.3286e-12, -6.8873e-12, -6.7828e-12,  ..., -7.3931e-12,
           -4.5429e-12, -2.7180e-12],
          [-4.8799e-12, -4.3100e-12, -3.9241e-12,  ..., -6.5401e-12,
           -2.8370e-12, -1.1823e-12],
          [-3.3216e-12, -2.3598e-12, -2.0927e-12,  ..., -2.9501e-12,
           -1.5456e-12, -1.6786e-12]],

         [[ 8.9513e-13,  1.3341e-12,  2.9428e-12,  ...,  6.9594e-14,
           -1.7243e-12, -2.8670e-12],
          [ 1.0377e-12, -1.6139e-13,  1.4352e-12,  ..., -2.9596e-13,
           -1.7505e-12, -3.6443e-12],
          [ 1.3670e-13, -3.5163e-13, -1.2305e-12,  ..., -2.8407e-12,
           -3.4419e-12, 

header0 label [490624. 501196.]
id2range 10572
392498
98125
490623
/home/ichitu/py-files/faces_webface_112x112/train.rec /home/ichitu/py-files/faces_webface_112x112/train.idx


header0 label [490624. 501196.]
id2range 10572
Setting up model...
Starting training...
Started Training at 20250330-144308


Epochs:   0%|          | 0/60 [00:00<?, ?it/s]

Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 1/60 - Loss: 3.9278
Saved best model with loss: 3.9278


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 2/60 - Loss: 3.8971
Saved best model with loss: 3.8971


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 3/60 - Loss: 3.8628
Saved best model with loss: 3.8628


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 4/60 - Loss: 3.8375
Saved best model with loss: 3.8375


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 5/60 - Loss: 3.8088
Saved best model with loss: 3.8088


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_6.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 93.28%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.57%
Saved best model with validation accuracy 86.56611464968152
Epoch 6/60 - Loss: 3.7812
Saved best model with loss: 3.7812


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 7/60 - Loss: 3.7533
Saved best model with loss: 3.7533


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 8/60 - Loss: 3.7293
Saved best model with loss: 3.7293


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 9/60 - Loss: 3.7016
Saved best model with loss: 3.7016


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 10/60 - Loss: 3.6763
Saved best model with loss: 3.6763


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 11/60 - Loss: 3.6488
Saved best model with loss: 3.6488


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_12.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 93.64%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.43%
Epoch 12/60 - Loss: 3.6222
Saved best model with loss: 3.6222


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 13/60 - Loss: 3.5938
Saved best model with loss: 3.5938


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 14/60 - Loss: 3.5696
Saved best model with loss: 3.5696


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 15/60 - Loss: 3.5482
Saved best model with loss: 3.5482


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 16/60 - Loss: 3.5197
Saved best model with loss: 3.5197


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 17/60 - Loss: 3.4961
Saved best model with loss: 3.4961


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_18.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 93.94%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.21%
Epoch 18/60 - Loss: 3.4679
Saved best model with loss: 3.4679


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 19/60 - Loss: 3.4424
Saved best model with loss: 3.4424


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 20/60 - Loss: 3.4139
Saved best model with loss: 3.4139


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 21/60 - Loss: 3.3874
Saved best model with loss: 3.3874


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 22/60 - Loss: 3.3657
Saved best model with loss: 3.3657


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 23/60 - Loss: 3.3373
Saved best model with loss: 3.3373


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_24.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 94.30%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.28%
Epoch 24/60 - Loss: 3.3128
Saved best model with loss: 3.3128


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 25/60 - Loss: 3.2823
Saved best model with loss: 3.2823


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 26/60 - Loss: 3.2598
Saved best model with loss: 3.2598


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 27/60 - Loss: 3.2329
Saved best model with loss: 3.2329


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 28/60 - Loss: 3.2084
Saved best model with loss: 3.2084


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 29/60 - Loss: 3.1811
Saved best model with loss: 3.1811


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_30.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 94.64%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.26%
Epoch 30/60 - Loss: 3.1566
Saved best model with loss: 3.1566


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 31/60 - Loss: 3.1312
Saved best model with loss: 3.1312


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 32/60 - Loss: 3.1073
Saved best model with loss: 3.1073


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 33/60 - Loss: 3.0805
Saved best model with loss: 3.0805


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 34/60 - Loss: 3.0576
Saved best model with loss: 3.0576


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 35/60 - Loss: 3.0393
Saved best model with loss: 3.0393


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_36.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 94.94%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.25%
Epoch 36/60 - Loss: 3.0154
Saved best model with loss: 3.0154


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 37/60 - Loss: 2.9896
Saved best model with loss: 2.9896


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 38/60 - Loss: 2.9672
Saved best model with loss: 2.9672


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 39/60 - Loss: 2.9505
Saved best model with loss: 2.9505


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 40/60 - Loss: 2.9290
Saved best model with loss: 2.9290


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 41/60 - Loss: 2.9113
Saved best model with loss: 2.9113


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_42.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 95.21%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.31%
Epoch 42/60 - Loss: 2.8962
Saved best model with loss: 2.8962


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 43/60 - Loss: 2.8774
Saved best model with loss: 2.8774


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 44/60 - Loss: 2.8600
Saved best model with loss: 2.8600


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 45/60 - Loss: 2.8460
Saved best model with loss: 2.8460


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 46/60 - Loss: 2.8308
Saved best model with loss: 2.8308


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 47/60 - Loss: 2.8206


Saved best model with loss: 2.8206


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_48.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 95.39%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.23%
Epoch 48/60 - Loss: 2.8049
Saved best model with loss: 2.8049


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 49/60 - Loss: 2.7954
Saved best model with loss: 2.7954


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 50/60 - Loss: 2.7880
Saved best model with loss: 2.7880


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 51/60 - Loss: 2.7776
Saved best model with loss: 2.7776


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 52/60 - Loss: 2.7693
Saved best model with loss: 2.7693


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 53/60 - Loss: 2.7621
Saved best model with loss: 2.7621


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_54.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 95.50%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.25%
Epoch 54/60 - Loss: 2.7581
Saved best model with loss: 2.7581


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 55/60 - Loss: 2.7508
Saved best model with loss: 2.7508


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 56/60 - Loss: 2.7478
Saved best model with loss: 2.7478


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 57/60 - Loss: 2.7452
Saved best model with loss: 2.7452


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 58/60 - Loss: 2.7415
Saved best model with loss: 2.7415


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 59/60 - Loss: 2.7381
Saved best model with loss: 2.7381


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResNet18_AMSoftmax_20250330-144308/checkpoint_epoch_60.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 95.54%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 86.30%
Epoch 60/60 - Loss: 2.7366
Saved best model with loss: 2.7366
Finished Training at 20250330-184136 with best validation accuracy 86.5661
Testing model...


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Test: 86.18%


In [12]:
print(test_acc)

86.17579617834394
