In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.brain as fob
from tqdm.notebook import tqdm, trange
from PIL import Image
import matplotlib.pyplot as plt
import torchsummary
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
# import resnet18 model from pytorch
from torchvision.models import resnet18
from torch.utils.tensorboard import SummaryWriter
import mxnet as mx
from mxnet import recordio
import torch.multiprocessing as mp
from sklearn.model_selection import train_test_split
from collections import defaultdict
import logging

In [2]:
# ---- Experiment configuration ----
DIM = (112, 112)          # input image size (height, width) fed to the backbone
BS = 256                  # mini-batch size
EPOCHS = 30               # training epochs
LR = 0.1                  # initial SGD learning rate
MOMENTUM = 0.9            # SGD momentum
WEIGHT_DECAY = 5e-4       # SGD L2 weight decay
NUM_CLASSES = 10572       # number of identities / output classes of the AM-Softmax head
NUM_WORKERS = 4           # DataLoader worker processes
LOG_INTERVAL = 6          # epochs between checkpoints and accuracy evaluations
# One device-id list per model: [embedding backbone, AM-Softmax classifier].
DEVICE_IDS = [[2], [2]]
DEVICE = torch.device(f"cuda:{DEVICE_IDS[0][0]}" if torch.cuda.is_available() else "cpu")

# Seed every RNG the pipeline draws from (weight init, shuffling, augmentations) so that a
# Restart & Run All reproduces the same run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# class CASIAWebFaceDataset(Dataset):
#     def __init__(self, path_imgrec, transform=None):
#         self.transform = transform
#         assert path_imgrec
#         if path_imgrec:
#             logging.info('loading recordio %s...',
#                          path_imgrec)
#             path_imgidx = path_imgrec[0:-4] + ".idx"
#             print(path_imgrec, path_imgidx)
#             self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
#             s = self.imgrec.read_idx(0)
#             header, _ = recordio.unpack(s)
#             if header.flag > 0:
#                 print('header0 label', header.label)
#                 self.header0 = (int(header.label[0]), int(header.label[1]))
#                 # assert(header.flag==1)
#                 # self.imgidx = range(1, int(header.label[0]))
#                 self.imgidx = []
#                 self.id2range = {}
#                 self.seq_identity = range(int(header.label[0]), int(header.label[1]))
#                 for identity in self.seq_identity:
#                     s = self.imgrec.read_idx(identity)
#                     header, _ = recordio.unpack(s)
#                     a, b = int(header.label[0]), int(header.label[1])
#                     count = b - a
#                     self.id2range[identity] = (a, b)
#                     self.imgidx += range(a, b)
#                 print('id2range', len(self.id2range))
#             else:
#                 self.imgidx = list(self.imgrec.keys)
#             self.seq = self.imgidx

#     def __getitem__(self, idx):
#         # Map global index to class ID and local index
#         actual_idx = idx + 1  # MXNet indices start from 1
        
#         # Read record
#         header, s = recordio.unpack(self.imgrec.read_idx(actual_idx))
#         img = mx.image.imdecode(s).asnumpy()
#         label = int(header.label)
        
#         # # Convert to PIL and apply transforms
#         img = Image.fromarray(img)
#         if self.transform:
#             img = self.transform(img)
        
#         return img, label

#     def __len__(self):
#         return len(self.seq)

In [4]:
class CASIAWebFaceDataset(Dataset):
    """Image-folder dataset: one sub-folder per identity, images inside it.

    Sub-folders are enumerated 0..K-1 to form the integer class labels. By default both the
    sub-folders and the files inside them are sorted by name, so label numbering and sample
    indices (and therefore any index-based train/test split) are identical on every machine.

    Args:
        path_dataset: root folder whose immediate sub-directories are the identities.
        transform: optional callable applied to each PIL image in ``__getitem__``.
        extensions: accepted image-file suffixes, matched case-insensitively.
        sort_entries: sort directory listings (recommended). Pass ``False`` to reproduce the
            raw filesystem order — only needed to stay label-compatible with classifier
            heads trained before sorting was introduced.
    """

    def __init__(self, path_dataset, transform=None, extensions=('.jpg', '.png'),
                 sort_entries=True):
        self.transform = transform
        suffixes = tuple(ext.lower() for ext in extensions)
        order = sorted if sort_entries else list

        imgs, labels = [], []
        # os.scandir / os.listdir return entries in arbitrary filesystem order.
        with os.scandir(path_dataset) as entries:
            subfolders = order(entry.path for entry in entries if entry.is_dir())
        for label, subfolder in enumerate(subfolders):
            for img_file in order(os.listdir(subfolder)):
                if img_file.lower().endswith(suffixes):
                    imgs.append(os.path.join(subfolder, img_file))
                    labels.append(label)

        self.imgs = np.array(imgs)                          # image file paths
        self.labels = np.array(labels, dtype=np.int64)      # class label per image
        self.seq = np.arange(len(self.imgs))
        # label -> list of sample indices belonging to that identity
        self.id2range = defaultdict(list)
        for i, label in enumerate(self.labels):
            self.id2range[label].append(i)
        self.seq_identity = np.unique(self.labels)
        self.imgidx = np.arange(len(self.imgs))

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is transformed if a transform was given."""
        label = int(self.labels[idx])
        # Decode inside a context manager so the file handle is closed right away, and force
        # 3-channel RGB: grayscale/palette files would otherwise reach the 3-input-channel
        # backbone with the wrong number of channels.
        with Image.open(self.imgs[idx]) as im:
            img = im.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return len(self.seq)

    def get_imgidx(self):
        """Return the array of all sample indices (0..len-1)."""
        return self.imgidx

In [5]:
class CustomNormalize:
    """Convert a PIL image to a float tensor with pixel values mapped to roughly [-1, 1].

    ``ToTensor`` yields values in [0, 1]; they are rescaled to [0, 255], shifted by 127.5
    and divided by 128, giving the range [-127.5/128, 127.5/128] (about +/-0.996).
    """

    # ToTensor takes no configuration, so one shared instance is reused for every call.
    _to_tensor = transforms.ToTensor()

    def __call__(self, img):
        img = self._to_tensor(img)
        return (img * 255.0 - 127.5) / 128.0

train_preprocess = transforms.Compose([
    # Random crop covering 8%-100% of the image area, resized to DIM.
    # NOTE(review): 0.08 is the ImageNet default; on aligned face crops it can cut away most
    # of the face — confirm this lower bound is intended.
    transforms.RandomResizedCrop(DIM, scale=(0.08, 1.0)),
    transforms.RandomHorizontalFlip(),  # flip horizontally with probability 0.5
    transforms.RandomRotation(30),  # rotate by a random angle in [-30, 30] degrees
    # transforms.RandomVerticalFlip(),  # Randomly flip the image vertically
    # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # Randomly change brightness, contrast, saturation and hue
    CustomNormalize()
])

test_preprocess = transforms.Compose([
    transforms.Resize(DIM),
    # NOTE(review): this random flip makes evaluation non-deterministic. It is kept because
    # train_on_casia_webface also trains with this transform, where it is the only augmentation.
    transforms.RandomHorizontalFlip(),
    CustomNormalize()
])

In [6]:
class AMSoftmax(nn.Module):
    '''
    Additive-margin softmax head (AM-Softmax), https://arxiv.org/pdf/1801.05599.pdf.

    Embeddings and class weight vectors are both L2-normalised, so ``logits`` are cosine
    similarities cos_j. For every class ``c`` the forward pass returns

        s * (cos_c - m) - log( exp(s * (cos_c - m)) + sum_{j != c} exp(s * cos_j) ),

    i.e. the AM-Softmax log-probability of ``c`` if ``c`` were the target class. Feeding the
    output to ``nn.NLLLoss`` therefore selects exactly the AM-Softmax loss of the true label.
    For s > 0 and m >= 0 this is increasing in cos_c, so argmax over the output equals
    argmax over cosine similarity (usable for prediction).

        in_features: size of the embedding, eg. 512
        n_classes: number of classes on the classification task
        s: scale; default 30.
        m: additive margin; default 0.35 (the paper reports 0.35-0.4 working best).

        *inputs: inputs[0] is a tensor shaped (batch_size X in_features); it does not need
                 to be pre-normalised.
        output : tensor shaped (batch_size X n_classes) of log-probabilities for NLL_loss.

    '''
    def __init__(self, in_features, n_classes, s=30, m=0.35):
        super(AMSoftmax, self).__init__()
        self.linear = nn.Linear(in_features, n_classes, bias=False)
        self.s = s
        self.m = m

    def forward(self, *inputs):
        # Normalise here so the head is correct on its own; for inputs that are already
        # unit-norm this is a no-op up to rounding.
        x_vector = F.normalize(inputs[0], p=2, dim=-1)
        normed_weight = F.normalize(self.linear.weight, p=2, dim=-1, eps=1e-10)
        logits = F.linear(x_vector, normed_weight)  # cosine similarities, (batch, n_classes)
        scaled_logits = (logits - self.m) * self.s
        return scaled_logits - self._am_logsumexp(logits)

    def _am_logsumexp(self, logits):
        '''
        Per class c: log( exp(s*(cos_c - m)) + sum_{j != c} exp(s*cos_j) ).

        The row maximum is factored out so no exponential can overflow. term1 > 0 keeps the
        log finite; the "sum minus own term" subtraction in term2 can lose precision when one
        class dominates, which is bounded by term1.
        '''
        max_x = torch.max(logits, dim=-1)[0].unsqueeze(-1)
        # exp(s*(cos_c - m)) relative to the row max
        term1 = (self.s*(logits - (max_x + self.m))).exp()
        # sum over all classes minus the class's own (unmargined) term = sum_{j != c}
        term2 = (self.s * (logits - max_x)).exp().sum(-1).unsqueeze(-1) \
                - (self.s * (logits - max_x)).exp()
        return self.s*max_x + (term2 + term1).log()

def resface_block(in_channels, out_channels):
    """Two stride-1 (3x3 conv -> BatchNorm -> PReLU) units in sequence; spatial size is kept.

    NOTE(review): despite the "res" name this is a plain ``nn.Sequential`` with no identity
    shortcut, and ``Resface20._make_layer`` stacks these blocks without adding the input
    back. Confirm that is intended — adding a shortcut would change existing checkpoints.
    """
    def conv_bn_act(source_channels):
        # One conv unit; the same module order as before keeps state_dict keys unchanged.
        return [
            nn.Conv2d(source_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.PReLU(out_channels),
        ]

    return nn.Sequential(*conv_bn_act(in_channels), *conv_bn_act(out_channels))

def resface_pre(in_channels, out_channels):
    """Stage entry: stride-2 3x3 conv -> BatchNorm -> PReLU.

    Changes the channel count to ``out_channels`` and halves each spatial dimension
    (rounding up), since kernel 3 / padding 1 / stride 2 maps n to (n - 1) // 2 + 1.
    """
    stem = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.PReLU(out_channels),
    ]
    return nn.Sequential(*stem)

class Resface20(nn.Module):
    """20-layer face-embedding CNN producing L2-normalised embeddings.

    Four stages, each a stride-2 ``resface_pre`` followed by 1 / 2 / 4 / 1 ``resface_block``s
    (64 / 128 / 256 / 512 channels). The final feature map is flattened, passed through
    dropout and a linear bottleneck, then L2-normalised.

    Args:
        bottleneck_layer_size: embedding dimension.
        dropout_prob: dropout probability applied before the bottleneck.
        input_size: (height, width) of the input images; used to size the bottleneck layer.
            Defaults to (112, 112) -> a 7x7 final feature map; e.g. (112, 96) -> 7x6.
    """
    def __init__(self, bottleneck_layer_size=512, dropout_prob=0.4, input_size=(112, 112)):
        super(Resface20, self).__init__()

        self.conv1_pre = resface_pre(3, 64)
        self.conv1 = self._make_layer(64, 64, 1)

        self.conv2_pre = resface_pre(64, 128)
        self.conv2 = self._make_layer(128, 128, 2)

        self.conv3_pre = resface_pre(128, 256)
        self.conv3 = self._make_layer(256, 256, 4)

        self.conv4_pre = resface_pre(256, 512)
        self.conv4 = self._make_layer(512, 512, 1)

        self.dropout = nn.Dropout(p=dropout_prob)
        self.flatten = nn.Flatten()
        # Four stride-2 convolutions shrink each spatial dim; size the bottleneck to match
        # (112 -> 56 -> 28 -> 14 -> 7, 64 -> 32 -> 16 -> 8 -> 4).
        feat_h, feat_w = (self._downsampled(n, 4) for n in input_size)
        self.bottleneck = nn.Linear(512 * feat_h * feat_w, bottleneck_layer_size)

        # Initialize weights using Xavier initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _downsampled(size, times):
        """Spatial size after ``times`` kernel-3 / stride-2 / padding-1 convolutions."""
        for _ in range(times):
            size = (size - 1) // 2 + 1
        return size

    def _make_layer(self, in_channels, out_channels, blocks):
        """Stack ``blocks`` ``resface_block``s; only the first one changes the channel count."""
        layers = []
        for _ in range(blocks):
            layers.append(resface_block(in_channels, out_channels))
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map images (batch, 3, H, W) to unit-norm embeddings (batch, bottleneck_layer_size)."""
        x = self.conv1_pre(x)
        x = self.conv1(x)

        x = self.conv2_pre(x)
        x = self.conv2(x)

        x = self.conv3_pre(x)
        x = self.conv3(x)

        x = self.conv4_pre(x)
        x = self.conv4(x)

        x = self.flatten(x)
        x = self.dropout(x)

        x = self.bottleneck(x)
        x = F.normalize(x, p=2, dim=-1)
        return x

In [7]:
# embedding_model = Resface20().cuda()
# torchsummary.summary(embedding_model, (3, 112, 112))
# torch.cuda.empty_cache()

In [8]:
def test_classifier(model, classifier, data_loader, device, message):
    """Measure top-1 classification accuracy of ``classifier(model(images))``.

    Both modules are switched to eval mode (and left there); no gradients are tracked.

    Args:
        model: backbone producing embeddings.
        classifier: head producing per-class scores; the argmax is the prediction.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.
        message: label used in the printed summary.

    Returns:
        Accuracy in percent, or ``float('nan')`` when ``data_loader`` yields no samples.
    """
    model.eval()
    classifier.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Testing"):
            # non_blocking overlaps the copy with compute when the loader pins memory.
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = classifier(model(images))
            predictions = logits.argmax(dim=1)

            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    # Guard the empty case instead of raising ZeroDivisionError.
    accuracy = correct / total * 100 if total else float("nan")
    print(f"✅ Classification Accuracy for {message}: {accuracy:.2f}%")
    return accuracy

In [9]:
def train_AMSoftmax(model: nn.Module, classifier: nn.Module, data_loader: DataLoader, val_loader: DataLoader,
                    optimizer: optim.Optimizer, scheduler: optim.lr_scheduler, 
                    criterion: nn.Module, epochs: int, device: torch.device, 
                    retain_graph: bool, checkpoint_interval: int = 10):
    """Jointly train an embedding backbone and an AM-Softmax classifier head.

    Each epoch runs one optimisation pass over ``data_loader``, steps ``scheduler`` once,
    then computes the mean loss over ``val_loader`` without gradients. Both losses are
    written to TensorBoard under ``runs/<H>x<W>_ResFace20_AMSoftmax_<timestamp>`` (H, W taken
    from the module-level ``DIM``).

    Every ``checkpoint_interval`` epochs a checkpoint (model, classifier, optimizer and
    scheduler state) is saved and train/validation accuracy are measured with
    ``test_classifier``. The backbone weights are saved to a ``*_validation_*.pt`` file
    whenever validation accuracy improves. Independently, backbone weights are saved
    whenever the mean *training* loss improves.

    Args:
        model: backbone producing embeddings.
        classifier: head whose output is passed to ``criterion``.
        data_loader: training ``(images, labels)`` batches.
        val_loader: validation ``(images, labels)`` batches.
        optimizer: optimiser over the trainable parameters.
        scheduler: LR scheduler, stepped once per epoch.
        criterion: loss applied to ``(classifier output, labels)``.
        epochs: number of epochs.
        device: device the batches are moved to.
        retain_graph: forwarded to ``loss.backward``.
        checkpoint_interval: epochs between checkpoints / accuracy evaluations.

    Returns:
        ``(model, train_losses, tmstmp)``: the trained backbone, per-epoch mean training
        losses, and the run timestamp used in ``log_dir``.
    """
    train_losses = []
    tmstmp = time.strftime("%Y%m%d-%H%M%S")
    best_loss = np.inf
    vacc = 0

    log_dir = f"runs/{DIM[0]}x{DIM[1]}_ResFace20_AMSoftmax_{tmstmp}"
    writer = SummaryWriter(log_dir=log_dir)

    print(f"Started Training at {tmstmp}")

    try:
        for epoch in tqdm(range(epochs), desc="Epochs"):
            # ---- training pass ----
            model.train()
            classifier.train()
            running_loss = 0.0
            for images, labels in tqdm(data_loader, desc="Batches"):
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                optimizer.zero_grad()
                embeddings = model(images)  # Extract embeddings from model
                logits = classifier(embeddings)  # AMSoftmax log-probabilities
                loss = criterion(logits, labels)

                loss.backward(retain_graph=retain_graph)
                optimizer.step()
                running_loss += loss.item()

            scheduler.step()
            avg_loss = running_loss / len(data_loader)
            train_losses.append(avg_loss)

            # ---- validation loss: eval mode and no autograd graph ----
            model.eval()
            classifier.eval()
            val_running_loss = 0.0
            with torch.no_grad():
                for images, labels in tqdm(val_loader, desc="Validation"):
                    images = images.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)
                    logits = classifier(model(images))
                    val_running_loss += criterion(logits, labels).item()

            # Log per epoch
            writer.add_scalar('Loss/train', avg_loss, epoch)
            writer.add_scalar('Loss/val', val_running_loss / len(val_loader), epoch)

            # Checkpoint saving
            if (epoch + 1) % checkpoint_interval == 0:
                checkpoint_path = os.path.join(log_dir, f'checkpoint_epoch_{epoch + 1}.pth')
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'classifier_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Lets a resumed run continue the LR schedule instead of restarting it.
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': avg_loss,
                }, checkpoint_path)
                print(f"Checkpoint saved at {checkpoint_path}")
                train_acc = test_classifier(model, classifier, data_loader, device, "Training")
                val_acc = test_classifier(model, classifier, val_loader, device, "Validation")
                writer.add_scalar('Accuracy/train', train_acc, epoch)
                writer.add_scalar('Accuracy/val', val_acc, epoch)
                if val_acc > vacc:
                    vacc = val_acc
                    torch.save(model.state_dict(), f"{log_dir}/{DIM[0]}x{DIM[1]}_ResFace20_AMSoftmax_validation_{tmstmp}.pt")
                    print(f"Saved best model with validation accuracy {vacc}")
            print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")
            # Save the backbone with the lowest mean training loss so far
            if avg_loss < best_loss:
                best_loss = avg_loss
                torch.save(model.state_dict(), f"{log_dir}/{DIM[0]}x{DIM[1]}_ResFace20_AMSoftmax_{tmstmp}.pt")
                print(f"Saved best model with loss: {best_loss:.4f}")

        print(f"Finished Training at {time.strftime('%Y%m%d-%H%M%S')} with best validation accuracy {vacc:.4f}")
    finally:
        # Flush and close the TensorBoard writer even if training is interrupted.
        writer.close()

    return model, train_losses, tmstmp

In [10]:
def train_on_casia_webface(embedding_model, classifier, optimizer, device, device_ids,
                           path_dataset="/home/ichitu/py-files/faces_webface_112x112_cropped"):
    """Train backbone + AM-Softmax head on CASIA-WebFace with an index-based 80/20 split.

    A single ``CASIAWebFaceDataset`` (transform: ``test_preprocess``) backs both loaders. The
    partition is expressed with ``SubsetRandomSampler`` over indices from
    ``train_test_split(..., random_state=42)``.

    Args:
        embedding_model: backbone mapping images to embeddings.
        classifier: AM-Softmax head whose output feeds ``nn.NLLLoss``.
        optimizer: ignored — a fresh SGD optimiser over both models is always built here.
            Kept only for call-site compatibility.
        device: device batches are moved to.
        device_ids: unused; kept for call-site compatibility.
        path_dataset: root folder with one sub-folder per identity.

    Returns:
        ``(model, losses, timestamp, test_acc)`` where ``test_acc`` is accuracy (percent) on
        the 20% held-out split.

    NOTE(review): the held-out split serves both as validation (best-checkpoint selection
    inside ``train_AMSoftmax``) and as the final test set, so ``test_acc`` is optimistically
    biased. A separate validation split would remove that leak.
    """
    print("Loading CASIA WebFace dataset...")
    casia_dataset = CASIAWebFaceDataset(
        path_dataset=path_dataset,
        transform=test_preprocess
    )

    train_idx, test_idx = train_test_split(range(len(casia_dataset)), test_size=0.2, random_state=42)
    print(len(train_idx))
    print(len(test_idx))
    print(len(casia_dataset))

    # Both loaders read the same dataset object; the samplers restrict each to its indices.
    train_loader = DataLoader(
        casia_dataset,
        batch_size=BS,
        num_workers=NUM_WORKERS,
        sampler=torch.utils.data.SubsetRandomSampler(train_idx),
        pin_memory=True
    )
    test_loader = DataLoader(
        casia_dataset,
        batch_size=BS,
        num_workers=NUM_WORKERS,
        sampler=torch.utils.data.SubsetRandomSampler(test_idx),
        pin_memory=True
    )

    print("Setting up model...")

    # Fresh SGD over backbone + head (the ``optimizer`` argument is intentionally not used).
    optimizer = optim.SGD(
        list(embedding_model.parameters()) + list(classifier.parameters()),
        lr=LR,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY
    )
    # LR x0.1 after epochs 16, 24 and 28.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=[16, 24, 28],
                                               gamma=0.1)
    # AMSoftmax already returns log-probabilities, hence NLLLoss (not CrossEntropyLoss).
    criterion = nn.NLLLoss()

    print("Starting training...")
    model, losses, timestamp = train_AMSoftmax(
        model=embedding_model,
        classifier=classifier,
        data_loader=train_loader,
        val_loader=test_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        epochs=EPOCHS,
        device=device,
        retain_graph=False,
        checkpoint_interval=LOG_INTERVAL
    )

    print("Testing model...")
    test_acc = test_classifier(embedding_model, classifier, test_loader, device, "Test")

    return model, losses, timestamp, test_acc

In [11]:
def train_on_casia_webface_normal_split(embedding_model, classifier, optimizer, device, device_ids,
                                        path_dataset="/home/ichitu/py-files/faces_webface_112x112_cropped",
                                        seed=42):
    """Train backbone + AM-Softmax head on CASIA-WebFace with a seeded ``random_split`` 80/20.

    Training samples use ``train_preprocess``. Held-out samples are read through a shallow
    copy of the dataset that uses ``test_preprocess``, so evaluation is not done on random
    crops/rotations.

    Args:
        embedding_model: backbone mapping images to embeddings.
        classifier: AM-Softmax head whose output feeds ``nn.NLLLoss``.
        optimizer: ignored — a fresh SGD optimiser over both models is always built here.
            Kept only for call-site compatibility.
        device: device batches are moved to.
        device_ids: unused; kept for call-site compatibility.
        path_dataset: root folder with one sub-folder per identity.
        seed: seed for the split generator, so the same split is produced on every run.

    Returns:
        ``(model, losses, timestamp, test_acc)`` where ``test_acc`` is accuracy (percent) on
        the held-out 20%.

    NOTE(review): the held-out split is used both for checkpoint selection and as the final
    test set, so ``test_acc`` is optimistically biased.
    """
    import copy

    print("Loading CASIA WebFace dataset...")
    casia_dataset = CASIAWebFaceDataset(
        path_dataset=path_dataset,
        transform=train_preprocess
    )
    # Shallow copy shares the read-only path/label arrays but has its own transform.
    eval_dataset = copy.copy(casia_dataset)
    eval_dataset.transform = test_preprocess

    train_length = int(len(casia_dataset) * 0.8)
    test_length = len(casia_dataset) - train_length
    train_dataset, held_out = torch.utils.data.random_split(
        casia_dataset, [train_length, test_length],
        generator=torch.Generator().manual_seed(seed)
    )
    # Same held-out indices, but read through the evaluation view.
    test_dataset = torch.utils.data.Subset(eval_dataset, held_out.indices)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BS,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        shuffle=True
    )
    # Evaluation does not need shuffling; accuracy is order-independent.
    test_loader = DataLoader(
        test_dataset,
        batch_size=BS,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        shuffle=False
    )

    print("Setting up model...")

    # Fresh SGD over backbone + head (the ``optimizer`` argument is intentionally not used).
    optimizer = optim.SGD(
        list(embedding_model.parameters()) + list(classifier.parameters()),
        lr=LR,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY
    )
    # LR x0.1 at 50%, 80% and 90% of EPOCHS.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=[int(EPOCHS*0.5), int(EPOCHS*0.8), int(EPOCHS*0.9)],
                                               gamma=0.1)
    # AMSoftmax already returns log-probabilities, hence NLLLoss.
    criterion = nn.NLLLoss()

    print("Starting training...")
    model, losses, timestamp = train_AMSoftmax(
        model=embedding_model,
        classifier=classifier,
        data_loader=train_loader,
        val_loader=test_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        epochs=EPOCHS,
        device=device,
        retain_graph=False,
        checkpoint_interval=LOG_INTERVAL
    )

    print("Testing model...")
    test_acc = test_classifier(embedding_model, classifier, test_loader, device, "Test")

    return model, losses, timestamp, test_acc

In [12]:
def load_model(model, check_point_dir, device='cuda'):
    """Load backbone weights from a checkpoint file written by ``train_AMSoftmax``.

    Keys saved from an ``nn.DataParallel``-wrapped model carry a ``module.`` prefix; it is
    stripped so the weights fit an unwrapped ``model``.

    Args:
        model: module to load the weights into (modified in place and returned).
        check_point_dir: path to a checkpoint containing ``'model_state_dict'`` and
            ``'optimizer_state_dict'`` entries.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        ``(model, optimizer_state_dict)``. The second item is the saved optimizer *state
        dict*, not an optimizer; apply it with ``optimizer.load_state_dict``.

    NOTE(review): ``torch.load`` unpickles the file — only use it on trusted checkpoints.
    """
    from collections import OrderedDict

    prefix = 'module.'
    # Deserialise the file once and take both entries from it.
    check_point = torch.load(check_point_dir, map_location=device)
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in check_point['model_state_dict'].items()
    )

    model.load_state_dict(new_state_dict)

    return model, check_point['optimizer_state_dict']

In [13]:
# Path of an earlier run's checkpoint, kept for reference / manual resuming; it is NOT loaded
# below. NOTE(review): it points at a ResNet18 run, while the model built here is Resface20.
check_point_dir = "/home/ichitu/py-files/runs/112x112_ResNet18_AMSoftmax_20250401-182809/checkpoint_epoch_30.pth"

device_ids = DEVICE_IDS
device = DEVICE
classifier = AMSoftmax(128, NUM_CLASSES)
embedding_model = Resface20(128)

# Place both models on the target device up front. Without this the single-GPU/CPU path
# never moves them, and DataParallel requires parameters on device_ids[0] before running.
embedding_model = embedding_model.to(device)
classifier = classifier.to(device)

if torch.cuda.device_count() > 1:
    print(f"Available {torch.cuda.device_count()} GPUs and using {device_ids}")
    embedding_model = nn.DataParallel(embedding_model, device_ids=device_ids[0])
    classifier = nn.DataParallel(classifier, device_ids=device_ids[1])

# train_on_casia_webface builds its own SGD optimizer and ignores this argument, so there is
# no point deserialising a checkpoint just to obtain it.
optimizer = None

embedding_model, train_losses, tmstmp, test_acc = train_on_casia_webface(embedding_model, classifier, optimizer, device, device_ids)

# embedding_model, train_losses, tmstmp, test_acc = train_on_casia_webface_normal_split(embedding_model, classifier, optimizer, device, device_ids)


Avaible 4 GPUs and using [[2], [2]]


Loading CASIA WebFace dataset...


392498
98125
490623
Setting up model...
Starting training...
Started Training at 20250403-220324


Epochs:   0%|          | 0/30 [00:00<?, ?it/s]

Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 1/30 - Loss: 20.2308
Saved best model with loss: 20.2308


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 2/30 - Loss: 17.3497
Saved best model with loss: 17.3497


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 3/30 - Loss: 14.7226
Saved best model with loss: 14.7226


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 4/30 - Loss: 12.7923
Saved best model with loss: 12.7923


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 5/30 - Loss: 11.5393
Saved best model with loss: 11.5393


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResFace20_AMSoftmax_20250403-220324/checkpoint_epoch_6.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 35.80%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 31.21%
Saved best model with validation accuracy 31.207133757961785
Epoch 6/30 - Loss: 10.7158
Saved best model with loss: 10.7158


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 7/30 - Loss: 10.1362
Saved best model with loss: 10.1362


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 8/30 - Loss: 9.6801
Saved best model with loss: 9.6801


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 9/30 - Loss: 9.3329
Saved best model with loss: 9.3329


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 10/30 - Loss: 9.0741
Saved best model with loss: 9.0741


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 11/30 - Loss: 8.8665
Saved best model with loss: 8.8665


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResFace20_AMSoftmax_20250403-220324/checkpoint_epoch_12.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 67.77%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 61.77%
Saved best model with validation accuracy 61.77426751592356
Epoch 12/30 - Loss: 8.7138


Saved best model with loss: 8.7138


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 13/30 - Loss: 8.5778
Saved best model with loss: 8.5778


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 14/30 - Loss: 8.4496
Saved best model with loss: 8.4496


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 15/30 - Loss: 8.3501
Saved best model with loss: 8.3501


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 16/30 - Loss: 8.2560
Saved best model with loss: 8.2560


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 17/30 - Loss: 5.7972
Saved best model with loss: 5.7972


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResFace20_AMSoftmax_20250403-220324/checkpoint_epoch_18.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 91.03%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 85.16%
Saved best model with validation accuracy 85.16076433121019
Epoch 18/30 - Loss: 5.0227


Saved best model with loss: 5.0227


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 19/30 - Loss: 4.6974
Saved best model with loss: 4.6974


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 20/30 - Loss: 4.4986
Saved best model with loss: 4.4986


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 21/30 - Loss: 4.3736
Saved best model with loss: 4.3736


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 22/30 - Loss: 4.2946
Saved best model with loss: 4.2946


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 23/30 - Loss: 4.2498
Saved best model with loss: 4.2498


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResFace20_AMSoftmax_20250403-220324/checkpoint_epoch_24.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 92.00%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 84.21%
Epoch 24/30 - Loss: 4.2210
Saved best model with loss: 4.2210


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 25/30 - Loss: 3.3320


Saved best model with loss: 3.3320


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 26/30 - Loss: 3.1233
Saved best model with loss: 3.1233


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 27/30 - Loss: 3.0196
Saved best model with loss: 3.0196


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 28/30 - Loss: 2.9370
Saved best model with loss: 2.9370


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Epoch 29/30 - Loss: 2.7917
Saved best model with loss: 2.7917


Batches:   0%|          | 0/1534 [00:00<?, ?it/s]

Validation:   0%|          | 0/384 [00:00<?, ?it/s]

Checkpoint saved at runs/112x112_ResFace20_AMSoftmax_20250403-220324/checkpoint_epoch_30.pth


Testing:   0%|          | 0/1534 [00:00<?, ?it/s]

✅ Classification Accuracy for Training: 95.49%


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Validation: 87.59%
Saved best model with validation accuracy 87.5943949044586
Epoch 30/30 - Loss: 2.7745


Saved best model with loss: 2.7745
Finished Training at 20250404-022030 with best validation accuracy 87.5944
Testing model...


Testing:   0%|          | 0/384 [00:00<?, ?it/s]

✅ Classification Accuracy for Test: 87.58%


In [14]:
# Final held-out accuracy (percent) returned by the training cell above.
print(str(test_acc))

87.57707006369426
