# Load the model and requirements

In [None]:
!pip install gdown -q
!pip install -q scipy==1.8.1 --force-reinstall

In [None]:
!pip show scipy
!python --version

In [None]:
%%writefile /kaggle/working/Arcface_torch/requirement.txt
tensorboard
easydict
mxnet==1.8.0
onnx
scikit-learn
opencv-python
numpy==1.23.5

In [None]:
%cd /kaggle/working/Arcface_torch
!pip install -q -r requirement.txt

## Load checkpoint

In [None]:
%%writefile /kaggle/working/Arcface_torch/dataset.py
import numbers
import os
import queue as Queue
import threading
from typing import Iterable

import mxnet as mx
import numpy as np
import torch
from functools import partial
from torch import distributed
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from utils.utils_distributed_sampler import DistributedSampler
from utils.utils_distributed_sampler import get_dist_info, worker_init_fn


def get_dataloader(
    root_dir,
    local_rank,
    batch_size,
    dali = False,
    dali_aug = False,
    seed = 2048,
    num_workers = 2,
    image_size = 112,
    ) -> Iterable:
    """Build the training data iterator for ``root_dir``.

    Source selection:
      * ``root_dir == "synthetic"`` -> ``SyntheticDataset`` (DALI forced off);
      * ``train.rec`` and ``train.idx`` exist -> ``MXFaceDataset``;
      * otherwise -> torchvision ``ImageFolder`` rooted at ``root_dir``.

    If ``dali`` is True, a DALI iterator over the RecordIO files is returned
    instead (see ``dali_data_iter``). Otherwise a ``DataLoaderX`` driven by a
    ``DistributedSampler`` is returned.

    Args:
        root_dir: dataset directory, or the literal string "synthetic".
        local_rank: CUDA device ordinal of this process.
        batch_size: per-process batch size.
        dali: use the DALI pipeline (RecordIO only).
        dali_aug: enable DALI-side augmentations.
        seed: sampler / worker seed; ``None`` disables ``worker_init_fn``.
        num_workers: DataLoader worker processes.
        image_size: side length images are resized to for the RecordIO and
            ImageFolder sources. NOTE(review): not forwarded to the DALI or
            synthetic paths.

    Returns:
        An iterable yielding ``(images, labels)`` batches.
    """
    rec = os.path.join(root_dir, 'train.rec')
    idx = os.path.join(root_dir, 'train.idx')
    train_set = None

    # Synthetic
    if root_dir == "synthetic":
        train_set = SyntheticDataset()
        dali = False

    # Mxnet RecordIO
    elif os.path.exists(rec) and os.path.exists(idx):
        train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank, image_size = image_size)

    # Image Folder
    else:
        # Resize like MXFaceDataset does. The backbone's fc layer is sized from
        # image_size, so differently sized images would break the forward pass.
        transform = transforms.Compose([
             transforms.Resize((image_size, image_size)),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ])
        train_set = ImageFolder(root_dir, transform)

    # DALI
    if dali:
        return dali_data_iter(
            batch_size=batch_size, rec_file=rec, idx_file=idx,
            num_threads=2, local_rank=local_rank, dali_aug=dali_aug)

    rank, world_size = get_dist_info()
    train_sampler = DistributedSampler(
        train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=seed)

    if seed is None:
        init_fn = None
    else:
        init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)

    train_loader = DataLoaderX(
        local_rank=local_rank,
        dataset=train_set,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=init_fn,
    )

    return train_loader

class BackgroundGenerator(threading.Thread):
    """Prefetch items from ``generator`` on a daemon thread.

    Up to ``max_prefetch`` items are buffered in a bounded queue. The worker
    thread selects CUDA device ``local_rank`` before iterating. ``None`` marks
    the end of the stream. If the worker raises (while selecting the device or
    while iterating the wrapped generator), the exception is forwarded to the
    consumer and re-raised from ``next()``. Without this, the consumer would
    block on an empty queue forever.
    """

    class _Failure(object):
        """Envelope carrying an exception from the worker to the consumer."""
        __slots__ = ("exc",)

        def __init__(self, exc):
            self.exc = exc

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.start()

    def run(self):
        try:
            torch.cuda.set_device(self.local_rank)
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:  # re-raised in the consumer thread by next()
            self.queue.put(self._Failure(exc))
        # End-of-stream sentinel. It is also sent after a failure, so a consumer
        # that keeps iterating eventually gets StopIteration instead of hanging.
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        if isinstance(next_item, self._Failure):
            raise next_item.exc
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self


class DataLoaderX(DataLoader):
    """DataLoader that prefetches batches on a background thread and copies
    them to GPU ``local_rank`` on a dedicated CUDA stream.

    One batch is always staged ahead. ``__next__`` makes the current stream
    wait for the copy stream before handing the staged batch out.
    """

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        """Fetch the next CPU batch and start its async copy on ``self.stream``.

        Sets ``self.batch`` to ``None`` once the underlying iterator is exhausted.
        """
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)

    def __next__(self):
        current_stream = torch.cuda.current_stream()
        current_stream.wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        # The tensors were allocated on self.stream but are consumed on the
        # current stream. Record that use so the caching allocator does not
        # hand their memory to a later preload() while kernels on the current
        # stream may still be reading it.
        for tensor in batch:
            tensor.record_stream(current_stream)
        self.preload()
        return batch


class MXFaceDataset(Dataset):
    """Face dataset read from an MXNet indexed RecordIO pair in ``root_dir``.

    Expects ``train.rec`` and ``train.idx``. If the header of record 0 has a
    positive flag, image keys are taken as ``1 .. label[0] - 1`` and that
    header's first two labels are kept in ``self.header0``. Otherwise every
    key in the index is used. Each sample is decoded, resized to
    ``image_size`` x ``image_size``, randomly flipped and normalised to
    ``[-1, 1]``.
    """

    def __init__(self, root_dir, local_rank, image_size):
        super(MXFaceDataset, self).__init__()
        pipeline = [
            transforms.ToPILImage(),
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(pipeline)
        self.root_dir = root_dir
        self.local_rank = local_rank
        rec_path = os.path.join(root_dir, 'train.rec')
        idx_path = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(idx_path, rec_path, 'r')
        first_header, _ = mx.recordio.unpack(self.imgrec.read_idx(0))
        if first_header.flag <= 0:
            self.imgidx = np.array(list(self.imgrec.keys))
        else:
            self.header0 = (int(first_header.label[0]), int(first_header.label[1]))
            self.imgidx = np.array(range(1, int(first_header.label[0])))

    def __getitem__(self, index):
        record = self.imgrec.read_idx(self.imgidx[index])
        header, encoded = mx.recordio.unpack(record)
        raw_label = header.label
        # Multi-valued labels: the class id is the first entry.
        if not isinstance(raw_label, numbers.Number):
            raw_label = raw_label[0]
        label = torch.tensor(int(raw_label), dtype=torch.long)

        sample = mx.image.imdecode(encoded).asnumpy()
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)


class SyntheticDataset(Dataset):
    """Benchmark dataset that returns one fixed random image with label 1.

    The image is a random integer array in ``[0, 255)`` laid out as
    ``(3, image_size, image_size)`` and normalised with ``(x / 255 - 0.5) / 0.5``.
    The dataset reports a length of 1,000,000.

    Args:
        image_size: side length of the square image (default 112, the
            previously hard-coded value).
    """

    def __init__(self, image_size=112):
        super(SyntheticDataset, self).__init__()
        img = np.random.randint(0, 255, size=(image_size, image_size, 3), dtype=np.int32)
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        img = torch.from_numpy(img).float()
        img = ((img / 255) - 0.5) / 0.5
        self.img = img
        self.label = 1

    def __getitem__(self, index):
        return self.img, self.label

    def __len__(self):
        return 1000000


def dali_data_iter(
    batch_size: int, rec_file: str, idx_file: str, num_threads: int,
    initial_fill=32768, random_shuffle=True,
    prefetch_queue_depth=1, local_rank=0, name="reader",
    mean=(127.5, 127.5, 127.5), 
    std=(127.5, 127.5, 127.5),
    dali_aug=False
    ):
    """
    Build a DALI pipeline over an MXNet RecordIO dataset and return its
    iterator wrapped in ``DALIWarper``.

    Each rank reads its own shard (``num_shards=world_size``, ``shard_id=rank``),
    so the default ``torch.distributed`` process group must already be
    initialised.

    Parameters:
    ----------
    batch_size: int
        Per-process batch size of the pipeline.
    rec_file, idx_file: str
        Paths of the ``.rec`` data file and its ``.idx`` index.
    num_threads: int
        Number of CPU threads given to the DALI ``Pipeline``.
    initial_fill: int
        Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored.
    prefetch_queue_depth: int
        Forwarded to ``Pipeline``.
    local_rank: int
        Device id of the pipeline.
    name: str
        Reader name, also passed to the iterator as ``reader_name``.
    mean, std: tuple of 3 floats
        Per-channel values for ``crop_mirror_normalize``; the defaults map
        [0, 255] to [-1, 1].
    dali_aug: bool
        If True, apply the random resize / blur / HSV / grayscale augmentations.
    """
    # Shard assignment comes from the (already initialised) default process group.
    rank: int = distributed.get_rank()
    world_size: int = distributed.get_world_size()
    import nvidia.dali.fn as fn
    import nvidia.dali.types as types
    from nvidia.dali.pipeline import Pipeline
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator

    # Downscale to resize_size, then back up to image_size (simulates low resolution).
    def dali_random_resize(img, resize_size, image_size=112):
        img = fn.resize(img, resize_x=resize_size, resize_y=resize_size)
        img = fn.resize(img, size=(image_size, image_size))
        return img
    # Gaussian blur with an odd window of size 2 * window_size + 1.
    def dali_random_gaussian_blur(img, window_size):
        img = fn.gaussian_blur(img, window_size=window_size * 2 + 1)
        return img
    # Saturation is 0 (grayscale) with probability prob_gray, else 1 (unchanged).
    def dali_random_gray(img, prob_gray):
        saturate = fn.random.coin_flip(probability=1 - prob_gray)
        saturate = fn.cast(saturate, dtype=types.FLOAT)
        img = fn.hsv(img, saturation=saturate)
        return img
    def dali_random_hsv(img, hue, saturation):
        img = fn.hsv(img, hue=hue, saturation=saturation)
        return img
    # Per-sample select: condition * true_case + (not condition) * false_case.
    def multiplexing(condition, true_case, false_case):
        neg_condition = condition ^ True
        return condition * true_case + neg_condition * false_case

    # Random augmentation parameters.
    # NOTE(review): image size 112 is hard-coded here and in the resize below.
    condition_resize = fn.random.coin_flip(probability=0.1)
    size_resize = fn.random.uniform(range=(int(112 * 0.5), int(112 * 0.8)), dtype=types.FLOAT)
    condition_blur = fn.random.coin_flip(probability=0.2)
    window_size_blur = fn.random.uniform(range=(1, 2), dtype=types.INT32)
    condition_flip = fn.random.coin_flip(probability=0.5)
    condition_hsv = fn.random.coin_flip(probability=0.2)
    hsv_hue = fn.random.uniform(range=(0., 20.), dtype=types.FLOAT)
    hsv_saturation = fn.random.uniform(range=(1., 1.2), dtype=types.FLOAT)

    pipe = Pipeline(
        batch_size=batch_size, num_threads=num_threads,
        device_id=local_rank, prefetch_queue_depth=prefetch_queue_depth, )
    # NOTE(review): this rebinds condition_flip, so the one created above is unused.
    condition_flip = fn.random.coin_flip(probability=0.5)
    with pipe:
        # Sharded RecordIO reader; pad_last_batch=False leaves the last batch unpadded.
        jpegs, labels = fn.readers.mxnet(
            path=rec_file, index_path=idx_file, initial_fill=initial_fill, 
            num_shards=world_size, shard_id=rank,
            random_shuffle=random_shuffle, pad_last_batch=False, name=name)
        # Decode to RGB using DALI's "mixed" backend.
        images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
        if dali_aug:
            images = fn.cast(images, dtype=types.UINT8)
            images = multiplexing(condition_resize, dali_random_resize(images, size_resize, image_size=112), images)
            images = multiplexing(condition_blur, dali_random_gaussian_blur(images, window_size_blur), images)
            images = multiplexing(condition_hsv, dali_random_hsv(images, hsv_hue, hsv_saturation), images)
            images = dali_random_gray(images, 0.1)

        # Normalise with mean/std, cast to float, random horizontal mirror.
        images = fn.crop_mirror_normalize(
            images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip)
        pipe.set_outputs(images, labels)
    pipe.build()
    return DALIWarper(DALIClassificationIterator(pipelines=[pipe], reader_name=name, ))


class DALIWarper(object):
    """Adapt a DALI classification iterator to yield ``(images, labels)`` tuples.

    Images and labels are moved to the current CUDA device. Labels are cast to
    int64 and flattened to one entry per sample.
    """

    def __init__(self, dali_iter):
        self.iter = dali_iter

    # Applied to __next__ rather than to the class: decorating a class with
    # no_grad only wrapped its construction (and is deprecated in torch 2.x).
    @torch.no_grad()
    def __next__(self):
        data_dict = self.iter.__next__()[0]
        tensor_data = data_dict['data'].cuda()
        tensor_label: torch.Tensor = data_dict['label'].cuda().long()
        # Labels presumably come as (B, 1). Unlike squeeze_(), reshape(-1)
        # keeps the batch dimension when B == 1.
        tensor_label = tensor_label.reshape(-1)
        return tensor_data, tensor_label

    def __iter__(self):
        return self

    def reset(self):
        self.iter.reset()

In [None]:
%%writefile /kaggle/working/Arcface_torch/backbones/iresnet_plus.py
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
import math

__all__ = ['iresnet18_plus', 'iresnet34_plus', 'iresnet50_plus', 'iresnet100_plus', 'iresnet200_plus']
using_ckpt = False

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation.

    Because padding == dilation, the layer keeps the spatial size at stride 1.
    """
    return nn.Conv2d(
        in_planes, out_planes, 3,
        stride=stride, padding=dilation, dilation=dilation,
        groups=groups, bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d`` with the given stride."""
    return nn.Conv2d(in_planes, out_planes, 1, stride=stride, bias=False)


class IBasicBlock(nn.Module):
    """Residual block: BN -> conv3x3 -> BN -> PReLU -> conv3x3 -> BN, plus shortcut.

    The first 3x3 convolution carries ``stride``. Both convolutions use
    ``dilation``. When ``downsample`` is given, it is applied to the block
    input to form the shortcut. With the module-level ``using_ckpt`` flag set,
    training forwards run under gradient checkpointing.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')

        self.dilation = dilation
        self.stride = stride

        # Residual branch. Creation order is unchanged so parameter init
        # consumes the RNG in the same sequence.
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, dilation=dilation)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)

        # Optional projection for the shortcut branch.
        self.downsample = downsample

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out) for convs; BatchNorm2d set to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward_impl(self, x):
        residual = self.prelu(self.bn2(self.conv1(self.bn1(x))))
        residual = self.bn3(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return residual

    def forward(self, x):
        use_checkpoint = self.training and using_ckpt
        return checkpoint(self.forward_impl, x) if use_checkpoint else self.forward_impl(x)


class IResNet(nn.Module):
    """IResNet-style face-embedding backbone ("plus" variant).

    A stride-1 3x3 stem with 64 channels is followed by four stages of
    ``block`` with widths 96 / 160 / 320 / 512. Each stage has stride 2. For
    stages 2-4 the stride can be replaced by dilation via
    ``replace_stride_with_dilation``. The last feature map is batch-normalised,
    flattened, passed through dropout and projected by ``fc`` to
    ``num_features``. The result goes through a BatchNorm1d whose weight is
    frozen at 1.

    Args:
        block: residual block class. It must expose ``expansion``, and ``bn3``
            if ``zero_init_residual`` is used.
        layers: number of blocks in each of the four stages.
        dropout: dropout probability applied before ``fc``.
        num_features: embedding dimension.
        zero_init_residual: zero the last BN (``bn3``) of every block so each
            block starts as an identity mapping.
        groups, width_per_group: forwarded to ``block``.
        replace_stride_with_dilation: three booleans for stages 2-4.
        fp16: run the convolutional trunk under CUDA autocast. ``fc`` and
            ``features`` always run in float32.
        image_size: side length of the square input; used to size ``fc``.
    """

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False, image_size=112):
        super(IResNet, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        # Spatial size of the final feature map. A stride-2 stage maps
        # H -> floor((H - 1) / 2) + 1 == ceil(H / 2); this holds both for the
        # 3x3 conv (padding == dilation) and for the 1x1 downsample. A stage
        # whose stride is replaced by dilation keeps H. Layer1 always has
        # stride 2. This equals (image_size // 16) ** 2 only for multiples of
        # 16 without dilation.
        spatial = image_size
        for dilated in [False] + list(replace_stride_with_dilation):
            if not dilated:
                spatial = (spatial + 1) // 2
        self.fc_scale = spatial * spatial

        self.groups = groups
        self.base_width = width_per_group

        # Initial conv layer
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)

        # Residual stages
        self.layer1 = self._make_layer(block, 96, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 160, layers[1], stride=2,
                                     dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 320, layers[2], stride=2,
                                     dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                     dilate=replace_stride_with_dilation[2])

        # Head
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05)
        self.dropout = nn.Dropout(p=dropout, inplace=True)

        fc_input_size = 512 * block.expansion * self.fc_scale
        self.fc = nn.Linear(fc_input_size, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)

        self._initialize_weights()

        # Honour zero_init_residual (previously accepted but ignored).
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, block):
                    nn.init.constant_(m.bn3.weight, 0)

        # Set feature weights to 1 and freeze
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

    def _initialize_weights(self):
        """Kaiming-normal convs, identity BatchNorm2d, N(0, 0.01) linear weights with zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` blocks. Only the first block gets
        ``stride`` and the projection shortcut. With ``dilate``, the stride is
        folded into ``self.dilation`` instead."""
        downsample = None
        previous_dilation = self.dilation

        if dilate:
            self.dilation *= stride
            stride = 1

        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                          self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                              base_width=self.base_width, dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        # The fp16 flag enables autocast for the convolutional trunk only.
        with torch.autocast(device_type="cuda", enabled=self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)

            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x



def _iresnet_plus(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model


def iresnet18_plus(pretrained=False, progress=True, **kwargs):
    """IResNet-18 "plus": [2, 2, 2, 2] IBasicBlocks per stage; kwargs go to IResNet."""
    depths = [2, 2, 2, 2]
    return _iresnet_plus('iresnet18_plus', IBasicBlock, depths, pretrained,
                         progress, **kwargs)


def iresnet34_plus(pretrained=False, progress=True, **kwargs):
    """IResNet-34 "plus": [3, 4, 6, 3] IBasicBlocks per stage; kwargs go to IResNet."""
    depths = [3, 4, 6, 3]
    return _iresnet_plus('iresnet34_plus', IBasicBlock, depths, pretrained,
                         progress, **kwargs)


def iresnet50_plus(pretrained=False, progress=True, **kwargs):
    """IResNet-50 "plus": [3, 4, 8, 3] IBasicBlocks per stage; kwargs go to IResNet."""
    depths = [3, 4, 8, 3]
    return _iresnet_plus('iresnet50_plus', IBasicBlock, depths, pretrained,
                         progress, **kwargs)


def iresnet100_plus(pretrained=False, progress=True, **kwargs):
    """IResNet-100 "plus": [4, 8, 16, 3] IBasicBlocks per stage; kwargs go to IResNet."""
    depths = [4, 8, 16, 3]
    return _iresnet_plus('iresnet100_plus', IBasicBlock, depths, pretrained,
                         progress, **kwargs)


def iresnet200_plus(pretrained=False, progress=True, **kwargs):
    """IResNet-200 "plus": [6, 26, 60, 6] IBasicBlocks per stage; kwargs go to IResNet."""
    depths = [6, 26, 60, 6]
    return _iresnet_plus('iresnet200_plus', IBasicBlock, depths, pretrained,
                         progress, **kwargs)



if __name__ == "__main__":
    import cv2
    import time
    model = iresnet100_plus(image_size=112)  # Specify image size explicitly
    model.eval()

    img = cv2.imread("data/img1.jpg")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (112, 112))
    img = img.transpose(2, 0, 1)  # [H,W,C] -> [C,H,W]
    img = torch.from_numpy(img).float().unsqueeze(0)  # [1,C,H,W]

    # Test forward pass
    with torch.no_grad():
        start = time.time()
        for _ in range(10):
            y = model(img)
        end = time.time()
    print(y.shape)
    infer_time = (end - start) / 10

    # Results
    print(f'Average inference time: {infer_time:.6f} seconds')
    print('Output shape:', y.shape)  # [1, 512]
    print("Vector norm:", torch.norm(y, p=2, dim=1).item())
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters: {num_params}")

In [None]:
%%writefile /kaggle/working/Arcface_torch/losses.py
import torch
import math

class CombinedMarginLoss(torch.nn.Module):
    """Margin-based logit adjustment (ArcFace or CosFace form) for face recognition.

    Supported configurations:
      * ``m1 == 1.0 and m3 == 0.0``: ArcFace. The target logit becomes
        ``cos(arccos(x) + m2)``.
      * ``m3 > 0``: CosFace. The target logit becomes ``x - m3``.
    All logits are then multiplied by ``s``. Rows whose label is ``-1`` get no
    margin. Note: ``logits`` may be modified in place.

    If ``interclass_filtering_threshold > 0``, non-target logits above the
    threshold are zeroed for rows with a valid label.
    """

    def __init__(self, 
                 s, 
                 m1,
                 m2,
                 m3,
                 interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.interclass_filtering_threshold = interclass_filtering_threshold

        # For ArcFace
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False

    def forward(self, logits, labels):
        """Apply the margin to ``logits`` and scale them.

        Args:
            logits: cosine logits indexed as ``logits[row, class]``.
            labels: class ids, shape ``(N,)`` or ``(N, 1)``; ``-1`` marks rows
                without a target.

        Raises:
            NotImplementedError: for an unsupported ``(m1, m2, m3)`` combination.
        """
        index_positive = torch.where(labels != -1)[0]
        # Flatten so both (N,) and (N, 1) label tensors work below.
        target_index = labels[index_positive].view(-1)

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                dirty = logits > self.interclass_filtering_threshold
                dirty = dirty.float()
                mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
                # scatter_ needs a 2-D index, whatever the label layout.
                mask.scatter_(1, target_index.view(-1, 1), 0)
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = tensor_mul * logits

        target_logit = logits[index_positive, target_index]

        if self.m1 == 1.0 and self.m3 == 0.0:
            # ArcFace: add m2 in angle space. The in-place ops run under
            # no_grad on purpose.
            with torch.no_grad():
                target_logit.arccos_()
                logits.arccos_()
                final_target_logit = target_logit + self.m2
                logits[index_positive, target_index] = final_target_logit
                logits.cos_()
            logits = logits * self.s

        elif self.m3 > 0:
            # CosFace: subtract m3 from the target cosine.
            final_target_logit = target_logit - self.m3
            logits[index_positive, target_index] = final_target_logit
            logits = logits * self.s
        else:
            # Subclass of RuntimeError, which the former bare `raise` produced.
            raise NotImplementedError(
                f"Unsupported margin combination m1={self.m1}, m2={self.m2}, m3={self.m3}: "
                "use m1=1.0 and m3=0.0 (ArcFace) or m3>0 (CosFace)."
            )

        return logits


class CombinedDynamicMarginLoss(torch.nn.Module):
    """ArcFace-style loss head whose angular margin grows with sample hardness.

    For each row with a valid label:
        h     = 1 - (cos_y - max_{j != y} cos_j)
        m_i   = m2 + alpha * h
        phi_y = cos(m1 * arccos(cos_y) + m_i) - m3
    The target logit is replaced by ``phi_y`` only where ``phi_y < cos_y``.
    All logits are then multiplied by ``s``. Rows labelled ``-1`` are left
    unchanged apart from scaling.

    If ``interclass_filtering_threshold > 0``, non-target logits above the
    threshold are zeroed (for rows with a valid label) in the returned logits.
    """

    def __init__(self,
                 s: float = 64.0,
                 m1: float = 1.0,
                 m2: float = 0.5,
                 m3: float = 0,
                 interclass_filtering_threshold: float = 0.0,
                 alpha: float = 0.1):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.alpha = alpha
        self.interclass_filtering_threshold = interclass_filtering_threshold

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return scaled, margin-adjusted logits. The input tensor is not modified.

        Args:
            logits: cosine logits indexed as ``logits[row, class]``.
            labels: class ids, shape ``(N,)`` or ``(N, 1)``; ``-1`` marks rows
                without a target.
        """
        # Flatten so (N, 1) labels do not broadcast into (k, k) index grids.
        labels = labels.view(-1)
        index_positive = torch.where(labels != -1)[0]

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                dirty = logits > self.interclass_filtering_threshold
                dirty = dirty.float()
                mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
                mask.scatter_(1, labels[index_positive].unsqueeze(1), 0)
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = logits * tensor_mul

        if index_positive.numel() == 0:
            return logits * self.s

        pos_labels = labels[index_positive]
        rows = torch.arange(index_positive.size(0), device=logits.device)
        cos_y = logits[index_positive, pos_labels]

        # Hardest non-target cosine per row.
        others = logits[index_positive].clone()
        others[rows, pos_labels] = -1e9
        max_other, _ = others.max(dim=1)

        h = 1.0 - (cos_y - max_other)
        m_i = self.m2 + self.alpha * h
        # Keep acos strictly inside (-1, 1): its derivative is infinite at +/-1.
        eps = 1e-7
        theta_y = torch.acos(cos_y.clamp(-1.0 + eps, 1.0 - eps))
        phi_y = torch.cos(self.m1 * theta_y + m_i) - self.m3

        # Only apply the margin where it lowers the target logit, so the
        # adjusted target never exceeds the original cosine.
        final_phi = torch.where(phi_y < cos_y, phi_y, cos_y)

        # Build the output from the (possibly filtered) logits so that the
        # interclass filtering actually takes effect.
        adjusted_logits = logits.clone()
        adjusted_logits[index_positive, pos_labels] = final_phi
        return adjusted_logits * self.s


class ArcFace(torch.nn.Module):
    """ArcFace additive angular margin (https://arxiv.org/pdf/1801.07698v1.pdf).

    For rows whose label is not ``-1``, the target cosine ``x`` becomes
    ``cos(arccos(x) + margin)``. Other entries become ``cos(arccos(x))``.
    The result is scaled by ``s``. ``logits`` is modified in place under
    ``torch.no_grad``.
    """

    def __init__(self, s=64.0, margin=0.5):
        super(ArcFace, self).__init__()
        self.s = s
        self.margin = margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        rows = torch.where(labels != -1)[0]
        cols = labels[rows].view(-1)
        target = logits[rows, cols]

        with torch.no_grad():
            # Go to angle space, add the margin to the targets, come back.
            target.arccos_()
            logits.arccos_()
            logits[rows, cols] = target + self.margin
            logits.cos_()
        return logits * self.s


class CosFace(torch.nn.Module):
    """CosFace additive cosine margin.

    Subtracts ``m`` from the target logit of every row whose label is not
    ``-1`` (writing into ``logits`` in place), then scales everything by ``s``.
    """

    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        rows = torch.where(labels != -1)[0]
        cols = labels[rows].view(-1)
        logits[rows, cols] = logits[rows, cols] - self.m
        return logits * self.s

In [None]:
%%writefile /kaggle/working/Arcface_torch/train_v2.py
import argparse
import logging
import os
from datetime import datetime

import numpy as np
import torch
from backbones import get_model
from dataset import get_dataloader
from losses import CombinedMarginLoss, CombinedDynamicMarginLoss
from lr_scheduler import PolynomialLRWarmup
from partial_fc_v2 import PartialFC_V2
from torch import distributed
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_distributed_sampler import setup_seed
from utils.utils_logging import AverageMeter, init_logging
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

assert torch.__version__ >= "1.12.0", "In order to enjoy the features of the new torch, \
we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future."

try:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    distributed.init_process_group("nccl")
except KeyError:
    rank = 0
    local_rank = 0
    world_size = 1
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
        rank=rank,
        world_size=world_size,
    )


def main(args):
    """Train a face-recognition backbone with a PartialFC_V2 margin head.

    Loads the config from ``args.config``. Builds the data loader, the
    DDP-wrapped backbone, the dynamic-margin loss with PartialFC_V2, the
    optimizer (SGD or AdamW) and a polynomial LR schedule with warmup. Then
    trains for ``cfg.num_epoch`` epochs with optional fp16 gradient scaling
    and gradient accumulation, periodic verification/checkpointing, and
    optional TensorBoard / W&B logging.

    Uses the module-level ``rank``, ``local_rank`` and ``world_size`` set
    when the process group was initialised.

    Raises:
        ValueError: if ``cfg.optimizer`` is neither "sgd" nor "adamw".
    """
    # get config
    cfg = get_config(args.config)
    # global control random seed
    setup_seed(seed=cfg.seed, cuda_deterministic=False)

    torch.cuda.set_device(local_rank)

    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)

    summary_writer = (
        SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
        if rank == 0
        else None
    )

    wandb_logger = None
    if cfg.using_wandb:
        import wandb
        # Sign in to wandb
        try:
            wandb.login(key=cfg.wandb_key)
        except Exception as e:
            print("WandB Key must be provided in config file (base.py).")
            print(f"Config Error: {e}")
        # Initialize wandb
        run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
        run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
        try:
            wandb_logger = wandb.init(
                entity = cfg.wandb_entity, 
                project = cfg.wandb_project, 
                sync_tensorboard = True,
                resume=cfg.wandb_resume,
                name = run_name, 
                notes = cfg.notes) if rank == 0 or cfg.wandb_log_all else None
            if wandb_logger:
                wandb_logger.config.update(cfg)
        except Exception as e:
            print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
            print(f"Config Error: {e}")
    train_loader = get_dataloader(
        cfg.rec,
        local_rank,
        cfg.batch_size,
        cfg.dali,
        cfg.dali_aug,
        cfg.seed,
        cfg.num_workers,
        cfg.image_size,
    )

    backbone = get_model(
        cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size, image_size = cfg.image_size).cuda()

    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
        find_unused_parameters=True)
    backbone.register_comm_hook(None, fp16_compress_hook)

    backbone.train()
    # FIXME using gradient checkpoint if there are some unused parameters will cause error
    backbone._set_static_graph()

    margin_loss = CombinedDynamicMarginLoss(
        64,
        cfg.margin_list[0],
        cfg.margin_list[1],
        cfg.margin_list[2],
        cfg.interclass_filtering_threshold
    )

    # Validate before building anything else (the former bare `raise` gave a
    # meaningless "No active exception to reraise").
    if cfg.optimizer not in ("sgd", "adamw"):
        raise ValueError(f"Unsupported optimizer {cfg.optimizer!r}; expected 'sgd' or 'adamw'.")

    module_partial_fc = PartialFC_V2(
        margin_loss, cfg.embedding_size, cfg.num_classes,
        cfg.sample_rate, False)
    module_partial_fc.train().cuda()
    # TODO the params of partial fc must be last in the params list
    param_groups = [{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}]
    if cfg.optimizer == "sgd":
        opt = torch.optim.SGD(
            params=param_groups, lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)
    else:
        opt = torch.optim.AdamW(
            params=param_groups, lr=cfg.lr, weight_decay=cfg.weight_decay)

    cfg.total_batch_size = cfg.batch_size * world_size
    cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
    cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch

    lr_scheduler = PolynomialLRWarmup(
        optimizer=opt,
        warmup_iters=cfg.warmup_step,
        total_iters=cfg.total_step)

    start_epoch = 0
    global_step = 0
    if cfg.resume:
        # Load onto CPU. load_state_dict copies into the CUDA modules, and the
        # optimizer moves its state to each parameter's device.
        dict_checkpoint = torch.load(
            os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"), map_location="cpu")
        start_epoch = dict_checkpoint["epoch"]
        global_step = dict_checkpoint["global_step"]
        backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
        module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
        opt.load_state_dict(dict_checkpoint["state_optimizer"])
        lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
        del dict_checkpoint

    for key, value in cfg.items():
        num_space = 25 - len(key)
        logging.info(": " + key + " " * num_space + str(value))

    callback_verification = CallBackVerification(
        val_targets=cfg.val_targets, rec_prefix=cfg.rec, 
        summary_writer=summary_writer, wandb_logger = wandb_logger
    )
    callback_logging = CallBackLogging(
        frequent=cfg.frequent,
        total_step=cfg.total_step,
        batch_size=cfg.batch_size,
        start_step = global_step,
        writer=summary_writer
    )

    loss_am = AverageMeter()
    amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)

    for epoch in range(start_epoch, cfg.num_epoch):

        if isinstance(train_loader, DataLoader):
            train_loader.sampler.set_epoch(epoch)
        for img, local_labels in train_loader:
            global_step += 1
            local_embeddings = backbone(img)
            loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels)

            if cfg.fp16:
                amp.scale(loss).backward()
                if global_step % cfg.gradient_acc == 0:
                    amp.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
                    amp.step(opt)
                    amp.update()
                    opt.zero_grad()
            else:
                loss.backward()
                if global_step % cfg.gradient_acc == 0:
                    torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
                    opt.step()
                    opt.zero_grad()
            lr_scheduler.step()

            with torch.no_grad():
                # One host sync per step. Update the meter before logging so
                # the logged running average includes the current step.
                loss_value = loss.item()
                loss_am.update(loss_value, 1)
                if wandb_logger:
                    wandb_logger.log({
                        'Loss/Step Loss': loss_value,
                        'Loss/Train Loss': loss_am.avg,
                        'Process/Step': global_step,
                        'Process/Epoch': epoch
                    })

                callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)

                if global_step % cfg.verbose == 0 and global_step > 0:
                    callback_verification(global_step, backbone)
                    if cfg.save_all_states:
                        checkpoint = {
                            "epoch": epoch,
                            "global_step": global_step,
                            "state_dict_backbone": backbone.module.state_dict(),
                            "state_dict_softmax_fc": module_partial_fc.state_dict(),
                            "state_optimizer": opt.state_dict(),
                            "state_lr_scheduler": lr_scheduler.state_dict()
                        }
                        # Resume checkpoint (always overwritten).
                        torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
                        # Historical checkpoint: backbone state_dict only, no other state.
                        torch.save(backbone.module.state_dict(), os.path.join(cfg.output, f"checkpoint_step_{global_step}_gpu_{rank}.pt"))

        if rank == 0:
            path_module = os.path.join(cfg.output, "model.pt")
            torch.save(backbone.module.state_dict(), path_module)

            if wandb_logger and cfg.save_artifacts:
                artifact_name = f"{run_name}_E{epoch}"
                model = wandb.Artifact(artifact_name, type='model')
                model.add_file(path_module)
                wandb_logger.log_artifact(model)

        if cfg.dali:
            train_loader.reset()

    if rank == 0:
        path_module = os.path.join(cfg.output, "model.pt")
        torch.save(backbone.module.state_dict(), path_module)

        if wandb_logger and cfg.save_artifacts:
            artifact_name = f"{run_name}_Final"
            model = wandb.Artifact(artifact_name, type='model')
            model.add_file(path_module)
            wandb_logger.log_artifact(model)



if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser(
        description="Distributed Arcface Training in Pytorch")
    parser.add_argument("config", type=str, help="py config file")
    main(parser.parse_args())

In [None]:
%%writefile /kaggle/working/Arcface_torch/backbones/__init__.py
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .iresnet_lite import iresnet18_lite, iresnet34_lite, iresnet50_lite, iresnet100_lite, iresnet200_lite
from .iresnet_plus import iresnet18_plus, iresnet34_plus, iresnet50_plus, iresnet100_plus, iresnet200_plus
def get_model(name, **kwargs):
    """Build an IResNet backbone from its short name.

    Supported names are ``r18``, ``r34``, ``r50``, ``r100`` and ``r200``,
    each optionally suffixed with ``_lite`` or ``_plus`` to pick the variant
    from ``iresnet_lite`` / ``iresnet_plus``. The matching constructor is
    called with ``False`` as its first positional argument and ``**kwargs``
    forwarded unchanged.

    Raises:
        ValueError: if ``name`` is not one of the supported names.
    """
    # Plain IResNet
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    # "_lite" variants
    elif name == "r18_lite":
        return iresnet18_lite(False, **kwargs)
    elif name == "r34_lite":
        return iresnet34_lite(False, **kwargs)
    elif name == "r50_lite":
        return iresnet50_lite(False, **kwargs)
    elif name == "r100_lite":
        return iresnet100_lite(False, **kwargs)
    elif name == "r200_lite":
        return iresnet200_lite(False, **kwargs)
    # "_plus" variants
    elif name == "r18_plus":
        return iresnet18_plus(False, **kwargs)
    elif name == "r34_plus":
        return iresnet34_plus(False, **kwargs)
    elif name == "r50_plus":
        return iresnet50_plus(False, **kwargs)
    elif name == "r100_plus":
        return iresnet100_plus(False, **kwargs)
    elif name == "r200_plus":
        return iresnet200_plus(False, **kwargs)
    else:
        # Name the offending value so a typo in a config is easy to spot.
        raise ValueError(
            f"Unknown backbone name {name!r}; expected r18/r34/r50/r100/r200, "
            "optionally with a '_lite' or '_plus' suffix"
        )

In [None]:
def analyze_lst_file(lst_file_path, label_column=2):
    """
    Count images and distinct class labels in a tab-separated ``.lst`` file.

    Each line is split on tabs and the label is read from column
    ``label_column`` (default 2, i.e. the third column), parsed as a number
    and truncated to ``int``. Lines that do not have that column (blank or
    malformed lines) are skipped instead of aborting the scan.

    Args:
        lst_file_path (str): Path to the ``.lst`` file.
        label_column (int): Zero-based index of the label column. Defaults to 2.

    Returns:
        tuple: (total_images, num_classes)
            - total_images: number of lines that contain a label column
            - num_classes: number of distinct integer labels

        ``(0, 0)`` is returned (after printing a message) when the file does
        not exist or cannot be read/parsed.
    """
    total_images = 0
    labels = set()  # distinct labels seen so far

    try:
        with open(lst_file_path, 'r') as f:
            for line in f:
                # Columns are tab-separated.
                columns = line.strip().split('\t')
                # The label column itself must exist; checking only for two
                # columns would let a short line raise IndexError and abort
                # the whole file via the generic handler below.
                if len(columns) > label_column:
                    total_images += 1
                    # float() first so labels written as "3.0" are accepted.
                    label = int(float(columns[label_column]))
                    labels.add(label)

        num_classes = len(labels)
        return total_images, num_classes

    except FileNotFoundError:
        print(f"Không tìm thấy file: {lst_file_path}")
        return 0, 0
    except Exception as e:
        # Best-effort notebook helper: report the problem instead of raising.
        print(f"Lỗi khi đọc file: {e}")
        return 0, 0

# Sanity-check the training list: these counts can be compared with
# config.num_image / config.num_classes in the training config.
lst_file = '/kaggle/input/ms1m-retinaface-t1/ms1m-retinaface-t1/train.lst'
total_images, num_classes = analyze_lst_file(lst_file)
print(f"Num-image: {total_images}", f"Num-class: {num_classes}", sep="\n")

In [1]:
%%writefile /kaggle/working/Arcface_torch/configs/ms1mv3_r100_plus.py
from easydict import EasyDict as edict
# Training config for the r100_plus backbone on the MS1M-RetinaFace (ms1mv3)
# data mounted under /kaggle/input. Passed to train_v2.py as its positional
# `config` argument (see the torchrun cell).

config = edict()
# Margin triple for the margin-based softmax loss; presumably (m1, m2, m3)
# of the combined-margin formulation (ArcFace-style m2=0.5) — confirm against
# the loss implementation.
config.margin_list = (1.0, 0.5, 0.0)
# Resolved by backbones.get_model ("r100_plus" -> iresnet100_plus).
config.network = "r100_plus"
# NOTE(review): presumably resumes from checkpoint_gpu_{rank}.pt in
# config.output, which the training loop rewrites — confirm in train_v2.py.
config.resume = True
# Output directory, relative to the launch directory (/kaggle/working/Arcface_torch).
config.output = 'ms1mv3_112x112_r100_plus_workdirs'
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
# NOTE(review): likely per-process (per-GPU) batch size — confirm in train_v2.py.
config.batch_size = 128
config.lr = 0.1
# NOTE(review): likely the step interval for verification/logging — confirm.
config.verbose = 4000
config.dali = False
config.save_all_states = True
config.image_size = 112

# Dataset directory; presumably it must contain train.rec / train.idx for the
# RecordIO loader (plus the validation .bin files named in val_targets).
config.rec = "/kaggle/input/ms1m-retinaface-t1/ms1m-retinaface-t1"
# Must match the dataset; analyze_lst_file() on train.lst can check both values.
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 30
config.warmup_epoch = 0
config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]

Overwriting /kaggle/working/Arcface_torch/configs/ms1mv3_r100_plus.py


In [None]:
%cd /kaggle/working/Arcface_torch
!torchrun --nproc_per_node=2 train_v2.py configs/ms1mv3_r100_plus.py

/kaggle/working/Arcface_torch
W0603 16:25:56.631000 122 torch/distributed/run.py:793] 
W0603 16:25:56.631000 122 torch/distributed/run.py:793] *****************************************
W0603 16:25:56.631000 122 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0603 16:25:56.631000 122 torch/distributed/run.py:793] *****************************************
  class DALIWarper(object):
  class DALIWarper(object):
2025-06-03 16:26:05.732220: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-03 16:26:05.732219: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has

## Evaluation with dataset

In [None]:
%%writefile /kaggle/working/Arcface_torch/dataset_evaluation.py
import logging
import os
import argparse
import torch

from backbones import get_model
from eval import verification


# Evaluation device: the default CUDA device when one is available, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")


def evaluate_dataset(bin_path: str, model: torch.nn.Module, image_size=(112, 112),
                     batch_size: int = 1, nfolds: int = 10):
    """
    Evaluate one verification ``.bin`` set with ``model`` and print the metrics.

    Args:
        bin_path: Path to the ``.bin`` file. Its base name without the
            extension is used as the dataset label in the printed output.
        model: Backbone to evaluate; it is moved to ``device`` and put in
            eval mode.
        image_size: Passed through unchanged to ``verification.load_bin``.
        batch_size: Batch size forwarded to ``verification.test``
            (default 1, the previously hard-coded value).
        nfolds: Number of folds forwarded to ``verification.test``
            (default 10, the previously hard-coded value).

    Returns:
        The ``(acc1, std1, acc2, std2, xnorm, embeddings_list)`` tuple
        returned by ``verification.test``.
    """
    # splitext keeps dots inside the name ("a.b.bin" -> "a.b"), whereas
    # split('.')[0] would truncate it to "a".
    dataset_name = os.path.splitext(os.path.basename(bin_path))[0]

    # Load the dataset and move every image tensor onto the evaluation device.
    dataset = verification.load_bin(bin_path, image_size)
    dataset = ([img.to(device) for img in dataset[0]], dataset[1])

    model.to(device)
    model.eval()
    acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
        dataset, model, batch_size, nfolds)

    # Log results
    print(f"[{dataset_name}] XNorm: {xnorm:.6f}")
    print(f"[{dataset_name}] Accuracy (Flip): {acc2:.5f} ± {std2:.5f}")
    # NOTE(review): acc1 is the first accuracy returned by verification.test;
    # confirm it is the metric this "Highest Accuracy" label describes.
    print(f"[{dataset_name}] Highest Accuracy: {acc1:.5f}")

    return acc1, std1, acc2, std2, xnorm, embeddings_list

def list_bin_files(folder):
    """Return the names of the ``.bin`` files directly inside *folder*, sorted.

    Sorting makes the evaluation order (and the printed report) deterministic,
    since ``os.listdir`` order is arbitrary. Matching the full ``.bin`` suffix
    and requiring a regular file avoids picking up names such as ``robin`` or
    a directory called ``x.bin``.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        if name.endswith('.bin') and os.path.isfile(os.path.join(folder, name))
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a face recognition dataset.")

    # Command-line arguments
    parser.add_argument("--image_size", type=int, nargs=2, default=(64, 64), help="Image size (width height).")
    parser.add_argument("--model_name", type=str, default='r100', help="backbone model name.")
    parser.add_argument("--model_path", type=str, default='weights/glint360k_cosface_r100_fp16_0.1/backbone.pth', help="backbone model path.")
    # Previously hard-coded; the default keeps the old behaviour.
    parser.add_argument("--bin_folder", type=str, default='/kaggle/working/VN-celeb-mini', help="Folder containing the .bin verification sets.")
    args = parser.parse_args()

    model = get_model(args.model_name, fp16=False, image_size=args.image_size[0])
    model.load_state_dict(torch.load(args.model_path, map_location=device, weights_only=True))

    for bin_name in list_bin_files(args.bin_folder):
        print(bin_name)
        evaluate_dataset(os.path.join(args.bin_folder, bin_name), model, tuple(args.image_size))
        print("__" * 20)

In [None]:
%cd /kaggle/working/Arcface_torch
!python dataset_evaluation.py\
--model_path /kaggle/working/Arcface_torch/VN-celeb-mini_112x112_r18_plus_workdirs_ver2_img_size_128/model.pt\
--model_name 'r18_plus'\
--image_size 128 128

In [None]:
%%writefile /kaggle/working/Arcface_torch/calc_threshold.py
import pickle
import torch
import mxnet as mx
from mxnet import ndarray as nd
import numpy as np
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from models.Recognition.Arcface_torch.backbones import get_model
from tqdm import tqdm
import torch.nn.functional as F

# Device: the default CUDA device when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# Build the r100 backbone and load its weights. The directory name indicates a
# Glint360K CosFace (fp16) checkpoint.
# NOTE(review): this path is relative to the current working directory — confirm
# it exists where calc_threshold.py is launched.
net = get_model("r100", fp16=False)
state_dict = torch.load("models/Recognition/Arcface_torch/weights/glint360k_cosface_r100_fp16_0.1/backbone.pth", map_location=device)
net.load_state_dict(state_dict)
# Module.to() and Module.eval() both return the module itself.
net = net.to(device).eval()

# Load a verification .bin file with the same normalisation used for training.
def load_bin(path, image_size):
    """Load a pickled verification ``.bin`` set into two image tensors.

    NOTE: the file is read with ``pickle``, which can execute code on load —
    only use trusted ``.bin`` files.

    Args:
        path: Path to the ``.bin`` file, a pickled ``(bins, issame_list)`` pair
            where ``bins`` holds encoded images.
        image_size: ``(size0, size1)`` used for the output tensor's spatial
            dims; a decoded image whose ``shape[1]`` differs from
            ``image_size[0]`` is resized with ``mx.image.resize_short``.

    Returns:
        ``(data_list, issame_list)``. ``data_list[0]`` holds the images and
        ``data_list[1]`` their horizontal flips, each a float tensor of shape
        ``(2 * len(issame_list), 3, image_size[0], image_size[1])`` with values
        in [-1, 1]. The whole dataset is returned.
    """
    try:
        with open(path, 'rb') as f:
            bins, issame_list = pickle.load(f)  # default decoding
    except UnicodeDecodeError:
        with open(path, 'rb') as f:
            # Fallback for pickles (e.g. written by Python 2) that need bytes decoding.
            bins, issame_list = pickle.load(f, encoding='bytes')

    num_images = len(issame_list) * 2
    data_list = [
        torch.empty((num_images, 3, image_size[0], image_size[1]))
        for _ in range(2)  # index 0: original, index 1: flipped
    ]
    for idx in range(num_images):
        img = mx.image.imdecode(bins[idx])
        if img.shape[1] != image_size[0]:
            img = mx.image.resize_short(img, image_size[0])
        img = nd.transpose(img, axes=(2, 0, 1))  # HWC -> CHW
        # Match the training transform (ToTensor + Normalize(mean=0.5, std=0.5)):
        # the backbone expects inputs in [-1, 1], not [0, 1].
        img = (img.astype(np.float32) / 255.0 - 0.5) / 0.5
        data_list[0][idx][:] = torch.from_numpy(img.asnumpy())
        # Axis 2 of CHW is the width, so this is a horizontal flip.
        flipped = mx.ndarray.flip(data=img, axis=2)
        data_list[1][idx][:] = torch.from_numpy(flipped.asnumpy())
        if idx % 1000 == 0:
            print('loading bin', idx)
    print(data_list[0].shape)
    return data_list, issame_list

# Load the full lfw.bin verification set (expected: 12000 images, i.e. 6000 pairs).
data_list, issame_list = load_bin('/kaggle/input/ms1m-retinaface-t1/ms1m-retinaface-t1/lfw.bin', (112, 112))

# Embedding extraction
def get_embeddings(data_list, model, batch_size=512):
    """Run ``model`` over the un-flipped images (``data_list[0]``) in batches.

    Each batch is moved to ``device`` and evaluated under ``torch.no_grad()``
    with the model in eval mode. The per-batch outputs are concatenated along
    dim 0 and returned (left on ``device``); the flipped images in
    ``data_list[1]`` are not used.
    """
    images = data_list[0]
    total = images.shape[0]

    model.eval()
    with torch.no_grad():
        chunks = [
            model(images[start:start + batch_size].to(device))
            for start in tqdm(range(0, total, batch_size), desc="Processing flip=0")
        ]
    return torch.cat(chunks, dim=0)

# Compute embeddings for the un-flipped images.
embeddings = get_embeddings(data_list, net, batch_size=512)
print("Embeddings shape:", embeddings.shape)  # Expected: [12000, 512]

# Pairwise distance computation
def compute_distances(embeddings, issame_list):
    """Compute L2 and cosine distances for consecutive embedding pairs.

    Pair ``i`` is ``(embeddings[2*i], embeddings[2*i + 1])``.

    Args:
        embeddings: 2-D tensor with at least ``2 * len(issame_list)`` rows;
            any extra rows are ignored.
        issame_list: Per-pair flags, truthy when both images share an identity.

    Returns:
        ``(l2_distances, cosine_distances, labels)`` as NumPy arrays of length
        ``len(issame_list)``. Distances are float64; the cosine distance is
        ``1 - cosine_similarity``.

    Raises:
        ValueError: if there are fewer than ``2 * len(issame_list)`` embeddings.
    """
    num_pairs = len(issame_list)
    if embeddings.shape[0] < 2 * num_pairs:
        raise ValueError(
            f"need {2 * num_pairs} embeddings for {num_pairs} pairs, "
            f"got {embeddings.shape[0]}"
        )

    # Vectorised over all pairs: one batched op instead of a Python loop that
    # forces a device->host sync (.item()) twice per pair.
    first = embeddings[0:2 * num_pairs:2]
    second = embeddings[1:2 * num_pairs:2]

    l2 = torch.norm(first - second, p=2, dim=1)
    cosine_sim = F.cosine_similarity(first, second, dim=1)

    l2_distances = l2.detach().cpu().numpy().astype(np.float64)
    # Subtract in float64, as the per-pair Python-float arithmetic did.
    cosine_distances = 1.0 - cosine_sim.detach().cpu().numpy().astype(np.float64)
    labels = np.array(list(issame_list))
    return l2_distances, cosine_distances, labels

# Per-pair distances and same/different labels for the whole set.
l2_distances, cosine_distances, labels = compute_distances(embeddings, issame_list)

# Each shape is expected to be (6000,) for the full LFW set.
for _title, _arr in (
    ("L2 distances shape:", l2_distances),
    ("Cosine distances shape:", cosine_distances),
    ("Labels shape:", labels),
):
    print(_title, _arr.shape)

def _simple_threshold(distances, labels, metric):
    """Print mean same/diff-pair distances for *metric* and return their midpoint."""
    # Compare with True/False (rather than using labels as a mask) so 0/1
    # integer labels work the same as booleans.
    same_class = distances[labels == True]  # noqa: E712
    diff_class = distances[labels == False]  # noqa: E712
    print(f"Mean distance {metric} (same class):", same_class.mean())
    print(f"Mean distance {metric} (diff class):", diff_class.mean())
    threshold = (same_class.mean() + diff_class.mean()) / 2
    print(f"Simple {metric} threshold:", threshold)
    return threshold


def _roc_analysis(distances, labels, metric):
    """Print the ROC-optimal threshold and AUC for *metric*, then save and show its ROC curve."""
    # pos_label=0 makes different-identity pairs the positive class, so a larger
    # distance scores as "more positive". The chosen threshold maximises
    # Youden's J statistic (TPR - FPR).
    fpr, tpr, thresholds = roc_curve(labels, distances, pos_label=0)
    roc_auc = auc(fpr, tpr)
    optimal_threshold = thresholds[np.argmax(tpr - fpr)]
    print(f"Optimal {metric} threshold (ROC):", optimal_threshold)
    print(f"{metric} AUC:", roc_auc)

    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve for {metric} Distance')
    plt.legend(loc="lower right")
    # This script is run as a subprocess (`!python calc_threshold.py`), where
    # plt.show() typically has no interactive display; save the figure so the
    # curve is not lost.
    out_path = f"roc_{metric.lower()}.png"
    plt.savefig(out_path)
    print("Saved ROC curve to", out_path)
    plt.show()
    return optimal_threshold, roc_auc


# Simple midpoint thresholds (mean same-pair vs. mean diff-pair distance).
l2_threshold = _simple_threshold(l2_distances, labels, "L2")
cosine_threshold = _simple_threshold(cosine_distances, labels, "Cosine")

# ROC-based thresholds, AUC and curves.
_roc_analysis(l2_distances, labels, "L2")
_roc_analysis(cosine_distances, labels, "Cosine")

In [None]:
!python calc_threshold.py

In [None]:
import torch

# Inspect the resume checkpoint written by rank 1 (checkpoint_gpu_{rank}.pt).
# map_location='cpu' keeps the saved GPU tensors off the notebook's GPU, and
# avoids a failure when that device is not visible in this kernel.
# NOTE(review): this directory differs from config.output
# ('ms1mv3_112x112_r100_plus_workdirs'); update the path to inspect that run.
check_point = torch.load(
    '/kaggle/working/Arcface_torch/ms1m_retinaface_t1_workdirs/checkpoint_gpu_1.pt',
    map_location='cpu',
)
print(check_point['epoch'])
print(check_point['global_step'])

In [None]:
cd /kaggle/working/Arcface_torch/ms1m_retinaface_t1_workdirs

In [None]:
!zip -r /kaggle/working/faces_webface_112x112_r50_se_workdirs.zip /kaggle/working/Arcface_torch/faces_webface_112x112_r50_se_workdirs