In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.brain as fob
from tqdm.notebook import tqdm, trange
from PIL import Image
import matplotlib.pyplot as plt
import torchsummary
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
# import resnet18 model from pytorch
from torchvision.models import resnet18
from torch.utils.tensorboard import SummaryWriter
import mxnet as mx
from mxnet import recordio
import torch.multiprocessing as mp
from sklearn.model_selection import train_test_split
from collections import defaultdict
import logging

In [2]:
DIM = (112, 112)
BS = 256
EPOCHS = 60
LR = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
NUM_CLASSES = 10572
NUM_WORKERS = 4
LOG_INTERVAL = 6

In [3]:
class CASIAWebFaceDataset(Dataset):
    def __init__(self, path_imgrec, transform=None):
        self.transform = transform
        assert path_imgrec
        if path_imgrec:
            logging.info('loading recordio %s...',
                         path_imgrec)
            path_imgidx = path_imgrec[0:-4] + ".idx"
            print(path_imgrec, path_imgidx)
            self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
            s = self.imgrec.read_idx(0)
            header, _ = recordio.unpack(s)
            if header.flag > 0:
                print('header0 label', header.label)
                self.header0 = (int(header.label[0]), int(header.label[1]))
                # assert(header.flag==1)
                # self.imgidx = range(1, int(header.label[0]))
                self.imgidx = []
                self.id2range = {}
                self.seq_identity = range(int(header.label[0]), int(header.label[1]))
                for identity in self.seq_identity:
                    s = self.imgrec.read_idx(identity)
                    header, _ = recordio.unpack(s)
                    a, b = int(header.label[0]), int(header.label[1])
                    count = b - a
                    self.id2range[identity] = (a, b)
                    self.imgidx += range(a, b)
                print('id2range', len(self.id2range))
            else:
                self.imgidx = list(self.imgrec.keys)
            self.seq = self.imgidx

    def __getitem__(self, idx):
        # Map global index to class ID and local index
        actual_idx = idx + 1  # MXNet indices start from 1
        
        # Read record
        header, s = recordio.unpack(self.imgrec.read_idx(actual_idx))
        img = mx.image.imdecode(s).asnumpy()
        label = int(header.label)
        
        # # Convert to PIL and apply transforms
        img = Image.fromarray(img)
        if self.transform:
            img = self.transform(img)
        
        return img, label

    def __len__(self):
        return len(self.seq)

In [4]:
# dataset = CASIAWebFaceDataset(path_imgrec='./faces_webface_112x112/train.rec', transform=transforms.Compose([
#     transforms.Resize(112),
#     transforms.ToTensor()
# ]))
# len(dataset)
# print(dataset[0])

In [5]:
class CustomNormalize:
    def __call__(self, img):
        # Convert PIL image to tensor
        img = transforms.ToTensor()(img)
        # Subtract 128 and divide by 128
        img = (img * 255.0 - 128) / 128.0
        return img

train_preprocess = transforms.Compose([
    transforms.Resize(DIM),
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    transforms.RandomRotation(30),  # Randomly rotate the image by up to 10 degrees
    # transforms.RandomVerticalFlip(),  # Randomly flip the image vertically
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # Randomly change brightness, contrast, saturation and hue
    CustomNormalize()
])

test_preprocess = transforms.Compose([
    transforms.Resize(DIM),
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    CustomNormalize()
])

In [6]:
class AMSoftmax(nn.Module):
    '''
    The am softmax as seen on https://arxiv.org/pdf/1801.05599.pdf,

        in_features: size of the embedding, eg. 512
        n_classes: number of classes on the classification task
        s: s parameter of loss, standard = 30.
        m: m parameter of loss, standard = 0.4, best between 0.35 and 0.4 according to paper.

        *inputs: tensor shaped (batch_size X embedding_size)
        output : tensor shaped (batch_size X n_classes) AM_softmax logits for NLL_loss.

    '''
    def __init__(self, in_features, n_classes, s=30, m=0.4):
        super(AMSoftmax, self).__init__()
        self.linear = nn.Linear(in_features, n_classes, bias=False)
        self.s = s
        self.m = m

    def forward(self, *inputs):
        x_vector = F.normalize(inputs[0], p=2, dim=-1)
        self.linear.weight.data = F.normalize(self.linear.weight.data, p=2, dim=-1)
        logits = self.linear(x_vector)
        scaled_logits = (logits - self.m)*self.s
        return  scaled_logits - self._am_logsumexp(logits)

    def _am_logsumexp(self, logits):
        '''
        logsumexp designed for am_softmax, the computation is numerically stable

        '''
        max_x = torch.max(logits, dim=-1)[0].unsqueeze(-1)
        term1 = (self.s*(logits - (max_x + self.m))).exp()
        term2 = (self.s * (logits - max_x)).exp().sum(-1).unsqueeze(-1) \
                - (self.s * (logits - max_x)).exp()
        return self.s*max_x + (term2 + term1).log()

In [7]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # self.relu = nn.PReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        out = self.relu(out)
        return out


class EmbeddingResNet18(nn.Module):
    def __init__(self, input_size = (112,112), num_classes=512, dropout=0.4):
        super(EmbeddingResNet18, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # self.relu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)


        self.flatten_layer = nn.Flatten()  # Separate Flatten layer
        self.dropout = nn.Dropout(dropout)

        #Calculate the correct input size for the fully connected layer.
        #THIS DEPENDS ENTIRELY ON THE INPUT IMAGE SIZE and the convolutional layers
        #You MUST calculate this based on the input size
        #THIS IS JUST AN EXAMPLE
        fc_input_size = 512 * 16  # Example: Calculated based on a particular input size

        self.fc = nn.Linear(fc_input_size, num_classes)


    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.flatten_layer(out)
        out = self.dropout(out)
        out = self.fc(out)
        out = F.normalize(out, p=2, dim=-1)
        return out

In [8]:
# model = EmbeddingResNet18()
# model = model.to("cuda")
# torchsummary.summary(model, (3, 112, 112))
# torch.cuda.empty_cache()

# assert False

In [9]:
def test_classifier(model, classifier, data_loader, device, message):
    model.eval()
    classifier.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Testing"):
            images, labels = images.to(device), labels.to(device)
            embeddings = model(images)  # Extract features
            logits = classifier(embeddings)  # Compute AMSoftmax logits
            predictions = torch.argmax(logits, dim=1)  # Get class with max probability
            
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total * 100
    print(f"✅ Classification Accuracy for {message}: {accuracy:.2f}%")
    return accuracy

In [10]:
def train_AMSoftmax(model: nn.Module, classifier: nn.Module, data_loader: DataLoader, val_loader: DataLoader,
                    optimizer: optim.Optimizer, scheduler: optim.lr_scheduler, 
                    criterion: nn.Module, epochs: int, device: torch.device, 
                    retain_graph: bool, checkpoint_interval: int = 10):

    train_losses = []
    tmstmp = time.strftime("%Y%m%d-%H%M%S")
    best_loss = np.inf
    vacc = 0

    # tmstmp = "20250313-074759"

    log_dir = f"runs/{DIM[0]}x{DIM[1]}_Embedding-ResNet18_AMSoftmax_{tmstmp}"
    writer = SummaryWriter(log_dir=log_dir)

    print(f"Started Training at {tmstmp}")
    
    for e, epoch in enumerate(tqdm(range(epochs), desc="Epochs")):
        model.train()
        classifier.train()
        running_loss = 0.0
        for i, (images, labels) in enumerate(tqdm(data_loader, desc="Batches")):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            embeddings = model(images)  # Extract embeddings from model
            logits = classifier(embeddings)  # Compute AMSoftmax logits
            loss = criterion(logits, labels)  # Compute NLL loss

            loss.backward(retain_graph=retain_graph)
            optimizer.step()
            running_loss += loss.item()
            # print(f"Batch {i+1}/{len(data_loader)} - Loss: {loss.item():.4f}")

        scheduler.step()
        avg_loss = running_loss / len(data_loader)
        train_losses.append(avg_loss)

        # Validation
        model.eval()
        running_loss = 0.0
        for i, (images, labels) in enumerate(tqdm(val_loader, desc="Validation")):
            images, labels = images.to(device), labels.to(device)
            embeddings = model(images)
            logits = classifier(embeddings)
            loss = criterion(logits, labels)
            running_loss += loss.item()
            # print(f"Batch {i+1}/{len(val_loader)} - Loss: {loss.item():.4f}")


        # Log per epoch
        writer.add_scalar('Loss/train', avg_loss, epoch)
        writer.add_scalar('Loss/val', running_loss / len(val_loader), epoch)

        # Checkpoint saving
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(log_dir, f'checkpoint_epoch_{epoch + 1}.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'classifier_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")
            train_acc = test_classifier(model, classifier, data_loader, device, "Training")
            val_acc = test_classifier(model, classifier, val_loader, device, "Validation")
            writer.add_scalar('Accuracy/train', train_acc, epoch)
            writer.add_scalar('Accuracy/val', val_acc, epoch)
            if val_acc > vacc:
                vacc = val_acc
                torch.save(model.state_dict(), f"{log_dir}{DIM[0]}x{DIM[1]}_Embedding-ResNet18_AMSoftmax_validation_{tmstmp}.pt")
                print(f"Saved best model with validation accuracy {vacc}")
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")
        # Save Best Model
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"{log_dir}/{DIM[0]}x{DIM[1]}_{DIM[0]}x{DIM[1]}_Embedding-ResNet18_AMSoftmax_{tmstmp}.pt")
            print(f"Saved best model with loss: {best_loss:.4f}")

    print(f"Finished Training at {time.strftime('%Y%m%d-%H%M%S')} with best validation accuracy {vacc:.4f}")
    writer.close()

    return model, train_losses, tmstmp

In [11]:
def train_on_casia_webface(embedding_model, classifier, optimizer, device, device_ids):
    # Set up paths
    rec_path = "/home/ichitu/py-files/faces_webface_112x112/train.rec"

    
    # Load datasets
    print("Loading CASIA WebFace dataset...")
    casia_dataset = CASIAWebFaceDataset(
        path_imgrec=rec_path,
        transform=train_preprocess
    )

    train_idx, test_idx = train_test_split(range(len(casia_dataset)), test_size=0.2, random_state=42)
    train_dataset = torch.utils.data.Subset(casia_dataset, train_idx)
    test_dataset = torch.utils.data.Subset(casia_dataset, test_idx)
    # val_idx, test_idx = train_test_split(test_idx, test_size=0.7, random_state=42)
    # val_dataset = torch.utils.data.Subset(casia_dataset, val_idx)
    # test_dataset = torch.utils.data.Subset(casia_dataset, test_idx)

    print(len(train_dataset))
    # print(len(val_dataset))
    print(len(test_dataset))

    

    print(len(casia_dataset))
    
    # print("Loading LFW dataset...")
    # lfw_dataset = foz.load_zoo_dataset("lfw")
    
    # Find and filter overlapping identities
    # print("Finding overlapping identities...")
    # overlapping_ids = find_overlapping_identities(casia_dataset, lfw_dataset)
    
    # print("Creating filtered dataset...")
    # filtered_dataset = FilteredCASIADataset(casia_dataset, overlapping_ids)
    # filtered_dataset = casia_dataset
    
    # Create data loader
    # train_loader = DataLoader(
    #     train_dataset, 
    #     batch_size=BS * len(device_ids),
    #     shuffle=True, 
    #     num_workers=2,
    #     pin_memory=True
    # )

    # val_loader = DataLoader(
    #     val_dataset, 
    #     batch_size=BS * len(device_ids),
    #     shuffle=True, 
    #     num_workers=2,
    #     pin_memory=True
    # )

    # test_loader = DataLoader(
    #     test_dataset, 
    #     batch_size=BS * len(device_ids),
    #     shuffle=True, 
    #     num_workers=2,
    #     pin_memory=True
    # )

    train_loader = DataLoader(
        casia_dataset,
        batch_size=BS,
        num_workers=4,
        sampler=torch.utils.data.SubsetRandomSampler(train_idx),
        pin_memory=True
    )

    # val_loader = DataLoader(
    #     CASIAWebFaceDataset(
    #         path_imgrec=rec_path,
    #         transform=test_preprocess
    #     ),
    #     batch_size=BS,
    #     num_workers=2,
    #     sampler=torch.utils.data.SubsetRandomSampler(val_idx),
    #     pin_memory=True
    # )

    test_loader = DataLoader(
        CASIAWebFaceDataset(
            path_imgrec=rec_path,
            transform=test_preprocess
        ),
        batch_size=BS,
        num_workers=2,
        sampler=torch.utils.data.SubsetRandomSampler(test_idx),
        pin_memory=True
    )
    
    # Set up model
    print("Setting up model...")
    
    # Set up optimizer (include both models' parameters)
    # optimizer = optim.SGD(
    #     list(embedding_model.parameters()) + list(classifier.parameters()),
    #     lr=LR,
    #     momentum=MOMENTUM,
    #     weight_decay=WEIGHT_DECAY
    # )
    optimizer = optim.Adam(
        embedding_model.parameters(),
        lr=0.001,
        weight_decay=5e-4,
        eps=1e-8
    )
    # optimizer = optim.SGD(
    #     embedding_model.parameters(),
    #     lr=optimizer['param_groups'][0]['lr'],
    #     momentum=MOMENTUM,
    # weight_decay=WEIGHT_DECAY
    # )

    # Set up scheduler
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-3)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, 
    #                                            milestones=[16,24,30,50],
    #                                               gamma=0.2)
    # Set up loss
    criterion = nn.NLLLoss()

    device = torch.device("cuda:{}".format(device_ids[1][0]) if torch.cuda.is_available() else "cpu")
    
    # Train
    print("Starting training...")
    model, losses, timestamp = train_AMSoftmax(
        model=embedding_model,
        classifier=classifier,
        data_loader=train_loader,
        # val_loader=val_loader,
        val_loader=test_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        epochs=EPOCHS,  # Adjust as needed
        device=device,
        retain_graph=False,
        checkpoint_interval=LOG_INTERVAL
    )

    print("Testing model...")
    test_acc = test_classifier(embedding_model, classifier, test_loader, device, "Test")
    
    return model, losses, timestamp, test_acc

In [12]:
def load_model(model, check_point_dir, device='cuda'):
    check_point = torch.load(check_point_dir, map_location=device)
    state_dict = check_point['model_state_dict']
    # Create new OrderedDict without 'module.' prefix
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            name = k[7:] # remove 'module.' prefix
        else:
            name = k
        new_state_dict[name] = v
    
    # Load the weights
    model.load_state_dict(new_state_dict)
    # Set to evaluation mode
    # model.eval()
    # print(f"Model loaded from {path}")

    optimizer = torch.load(check_point_dir, map_location=device)['optimizer_state_dict']

    return model, optimizer

In [13]:
# assert 1==2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

check_point_dir = "/home/ichitu/py-files/runs/112x96_ResNet18_AMSoftmax_20250326-171144/checkpoint_epoch_60.pth"
# check_point_dir = f"/home/ichitu/py-files/runs/{DIM[0]}x{DIM[1]}_Embedding-ResNet18_AMSoftmax_20250326-171144/112x96_ResNet18_AMSoftmax_20250326-171144.pt"

device_ids = [[0],[0]]
classifier = AMSoftmax(512, 10572)
embedding_model = EmbeddingResNet18(512)
# embedding_model, optimizer = load_model(embedding_model, check_point_dir, 'cuda:{}'.format(device_ids[0][0]))
# embedding_model = load_model(embedding_model, '/home/ichitu/py-files/runs/112x96_ResNet18_AMSoftmax_20250322-161021/112x96_ResNet18_AMSoftmax_20250322-161021.pt', 'cuda:3')

if torch.cuda.device_count() > 1:
    print(f"Avaible {torch.cuda.device_count()} GPUs and using {device_ids}")
    embedding_model = nn.DataParallel(embedding_model, device_ids=device_ids[0])
    classifier = nn.DataParallel(classifier, device_ids=device_ids[1])

# embedding_model = embedding_model.to(device)
# classifier = classifier.to(device)
# embedding_model.load_state_dict(torch.load("Models-pt/112x96_ResNet18_AMSoftmax_20250313-074759.pt"))
# print(embedding_model)
# print(classifier)
# print(optimizer)
# # assert False

optimizer = torch.load(check_point_dir, map_location=device)['optimizer_state_dict']

embedding_model, train_losses, tmstmp, test_acc = train_on_casia_webface(embedding_model, classifier, optimizer, device, device_ids)


Avaible 4 GPUs and using [[0], [0]]


Loading CASIA WebFace dataset...
/home/ichitu/py-files/faces_webface_112x112/train.rec /home/ichitu/py-files/faces_webface_112x112/train.idx


header0 label [490624. 501196.]
id2range 10572
392498
98125
490623
/home/ichitu/py-files/faces_webface_112x112/train.rec /home/ichitu/py-files/faces_webface_112x112/train.idx


header0 label [490624. 501196.]
id2range 10572
Setting up model...
Starting training...
Started Training at 20250328-115538


Epochs:   0%|          | 0/60 [00:00<?, ?it/s]

Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 1/60 - Loss: 21.4952
Saved best model with loss: 21.4952


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 2/60 - Loss: 21.0322
Saved best model with loss: 21.0322


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 3/60 - Loss: 20.4763
Saved best model with loss: 20.4763


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 4/60 - Loss: 19.9531
Saved best model with loss: 19.9531


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 5/60 - Loss: 19.4596
Saved best model with loss: 19.4596


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_6.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 17.76%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 17.93%
Saved best model with validation accuracy 17.929171974522294
Epoch 6/60 - Loss: 18.9971
Saved best model with loss: 18.9971


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 7/60 - Loss: 18.5902
Saved best model with loss: 18.5902


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 8/60 - Loss: 18.2359
Saved best model with loss: 18.2359


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 9/60 - Loss: 17.9331
Saved best model with loss: 17.9331


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 10/60 - Loss: 17.6710
Saved best model with loss: 17.6710


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 11/60 - Loss: 17.4492
Saved best model with loss: 17.4492


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_12.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 29.05%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 28.13%
Saved best model with validation accuracy 28.1284076433121
Epoch 12/60 - Loss: 17.2413
Saved best model with loss: 17.2413


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 13/60 - Loss: 17.0702
Saved best model with loss: 17.0702


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 14/60 - Loss: 16.9161
Saved best model with loss: 16.9161


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 15/60 - Loss: 16.7906
Saved best model with loss: 16.7906


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 16/60 - Loss: 16.6676
Saved best model with loss: 16.6676


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 17/60 - Loss: 16.5564
Saved best model with loss: 16.5564


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_18.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 33.47%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 32.19%
Saved best model with validation accuracy 32.193630573248406
Epoch 18/60 - Loss: 16.4643


Saved best model with loss: 16.4643


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 19/60 - Loss: 16.3761
Saved best model with loss: 16.3761


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 20/60 - Loss: 16.2943
Saved best model with loss: 16.2943


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 21/60 - Loss: 16.2280
Saved best model with loss: 16.2280


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 22/60 - Loss: 16.1590
Saved best model with loss: 16.1590


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 23/60 - Loss: 16.1009
Saved best model with loss: 16.1009


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_24.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 35.78%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 34.24%
Saved best model with validation accuracy 34.24305732484076
Epoch 24/60 - Loss: 16.0302
Saved best model with loss: 16.0302


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 25/60 - Loss: 15.9925
Saved best model with loss: 15.9925


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 26/60 - Loss: 15.9329
Saved best model with loss: 15.9329


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 27/60 - Loss: 15.8956
Saved best model with loss: 15.8956


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 28/60 - Loss: 15.8426
Saved best model with loss: 15.8426


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 29/60 - Loss: 15.8135
Saved best model with loss: 15.8135


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_30.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 37.95%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 36.04%
Saved best model with validation accuracy 36.035668789808916
Epoch 30/60 - Loss: 15.7681
Saved best model with loss: 15.7681


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 31/60 - Loss: 15.7258
Saved best model with loss: 15.7258


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 32/60 - Loss: 15.6913
Saved best model with loss: 15.6913


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 33/60 - Loss: 15.6633
Saved best model with loss: 15.6633


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 34/60 - Loss: 15.6328
Saved best model with loss: 15.6328


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 35/60 - Loss: 15.6002
Saved best model with loss: 15.6002


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_36.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 38.62%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 36.62%
Saved best model with validation accuracy 36.62471337579618
Epoch 36/60 - Loss: 15.5824
Saved best model with loss: 15.5824


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 37/60 - Loss: 15.5438
Saved best model with loss: 15.5438


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 38/60 - Loss: 15.5187
Saved best model with loss: 15.5187


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 39/60 - Loss: 15.4939
Saved best model with loss: 15.4939


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 40/60 - Loss: 15.4786
Saved best model with loss: 15.4786


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 41/60 - Loss: 15.4412
Saved best model with loss: 15.4412


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_42.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 37.70%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 35.88%
Epoch 42/60 - Loss: 15.4143
Saved best model with loss: 15.4143


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 43/60 - Loss: 15.4012
Saved best model with loss: 15.4012


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 44/60 - Loss: 15.3684
Saved best model with loss: 15.3684


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 45/60 - Loss: 15.3543
Saved best model with loss: 15.3543


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 46/60 - Loss: 15.3326
Saved best model with loss: 15.3326


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 47/60 - Loss: 15.3273
Saved best model with loss: 15.3273


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_48.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 39.85%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 37.84%
Saved best model with validation accuracy 37.83949044585987
Epoch 48/60 - Loss: 15.2948
Saved best model with loss: 15.2948


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 49/60 - Loss: 15.2718
Saved best model with loss: 15.2718


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 50/60 - Loss: 15.2570
Saved best model with loss: 15.2570


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 51/60 - Loss: 15.2459
Saved best model with loss: 15.2459


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 52/60 - Loss: 15.2221
Saved best model with loss: 15.2221


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 53/60 - Loss: 15.2163
Saved best model with loss: 15.2163


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_54.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 40.24%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 38.49%
Saved best model with validation accuracy 38.49375796178344
Epoch 54/60 - Loss: 15.1954


Saved best model with loss: 15.1954


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 55/60 - Loss: 15.1908
Saved best model with loss: 15.1908


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 56/60 - Loss: 15.1701
Saved best model with loss: 15.1701


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 57/60 - Loss: 15.1630
Saved best model with loss: 15.1630


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 58/60 - Loss: 15.1482
Saved best model with loss: 15.1482


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 59/60 - Loss: 15.1331
Saved best model with loss: 15.1331


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_Embedding-ResNet18_AMSoftmax_20250328-115538/checkpoint_epoch_60.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 40.05%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 37.57%
Epoch 60/60 - Loss: 15.1009
Saved best model with loss: 15.1009
Finished Training at 20250328-151716 with best validation accuracy 38.4938
Testing model...


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Test: 37.60%


In [14]:
print(test_acc)

37.60305732484077
