# train

# import

In [8]:
import io
import os
import sys
import time
import json
import math
import datetime
import argparse
import numpy as np
from pathlib import Path
from collections import defaultdict, deque

import torch
import torch.distributed as dist
from torch import nn, einsum
from torch.nn import functional as F
import torch.backends.cudnn as cudnn
import torch.utils.checkpoint as checkpoint

from typing import Iterable, Optional
from timm.models import create_model
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
from timm.data import Mixup,create_transform
from timm.models.registry import register_model
from timm.models.layers import DropPath, trunc_normal_
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.utils import accuracy, ModelEma,NativeScaler, get_state_dict, ModelEma

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
from torch.optim.lr_scheduler import _LRScheduler
from functools import partial
from einops import rearrange

In [5]:
# Default compute device for the notebook: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## utils

In [21]:
def get_world_size():
    """Return the number of processes taking part in the current run.

    Returns:
        ``dist.get_world_size()`` when ``is_dist_avail_and_initialized()``
        reports a usable process group, otherwise ``1``.
    """
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1

In [None]:
def cal_flops_params_with_fvcore(model, inputs):
    """
    Print the FLOP count and the parameter count of ``model`` using fvcore.

    Args:
        model: module handed to fvcore's ``FlopCountAnalysis`` and ``parameter_count``.
        inputs: example input(s) forwarded to ``FlopCountAnalysis``.

    Both numbers are printed in millions (divided by 1000**2); nothing is returned.
    """
    # fvcore is imported lazily so the notebook only needs it when this helper is used.
    from fvcore.nn import FlopCountAnalysis, parameter_count_table, parameter_count

    million = 1000 ** 2
    flop_analysis = FlopCountAnalysis(model, inputs)
    param_counts = parameter_count(model)
    total_flops = flop_analysis.total()
    print('flops(fvcore): %f M' % (total_flops / million))
    # The '' key is read as the whole-model total.
    print('params(fvcore): %f M' % (param_counts[''] / million))

In [4]:
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and concatenate the pieces along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient, and the whole
    function runs under ``torch.no_grad()``.

    Args:
        tensor: the local tensor contributed by this process.

    Returns:
        The concatenation (dim 0) of one gathered tensor per process.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per process, shaped like the local tensor.
    gathered = []
    for _ in range(world_size):
        gathered.append(torch.ones_like(tensor))
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)

In [25]:
def save_on_master(*args, **kwargs):
    """
    Forward the arguments to ``torch.save``, but only when ``is_main_process()`` is true.

    On every other process the call is a no-op.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)

In [22]:
def get_rank():
    """
    Return the rank of the current process.

    Returns:
        ``dist.get_rank()`` when ``is_dist_avail_and_initialized()`` reports a
        usable process group, otherwise ``0``.
    """
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0

In [3]:
def set_seed(seed = 1234):
    '''Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices) and
    switches cuDNN to deterministic mode so repeated runs give the same results.

    Args:
        seed: integer seed applied to all generators.
    '''
    # `random` is not part of the notebook's import cell, so import it here;
    # without this the function raised NameError.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    # When running on the cuDNN backend, two further options must be set:
    # deterministic kernels, and no autotuner (it may pick different kernels per run).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed.
    # NOTE: this only affects interpreters started after this point (e.g. worker
    # subprocesses); the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)

In [26]:
class WarmUpLR(_LRScheduler):
    """Linear learning-rate warm-up scheduler.

    Args:
        optimizer: wrapped optimizer (e.g. SGD)
        total_iters: number of scheduler steps that make up the warm-up phase
        last_epoch: index of the last step, forwarded to ``_LRScheduler``
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Stored before delegating to the base constructor so that get_lr()
        # can already read it.
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * last_epoch / total_iters`` for every parameter group.

        A tiny epsilon is added to the denominator so ``total_iters == 0`` does
        not divide by zero.
        """
        denominator = self.total_iters + 1e-8
        step = self.last_epoch
        return [base_lr * step / denominator for base_lr in self.base_lrs]

In [12]:
def build_dataset(is_train, args):
    """Build the train or validation dataset together with its class count.

    Args:
        is_train: selects the ``'train'`` split when truthy, ``'val'`` otherwise.
        args: namespace providing ``use_mcloader`` and, for the ImageFolder
            path, ``data_path``.

    Returns:
        ``(dataset, nb_classes)``; ``nb_classes`` is fixed at 2.
    """
    split = 'train' if is_train else 'val'
    transform = build_transform(is_train, args)

    if args.use_mcloader:
        # Optional loader, imported only when requested.
        from mcloader import ClassificationDataset
        dataset = ClassificationDataset(split, pipeline=transform)
    else:
        split_dir = os.path.join(args.data_path, split)
        dataset = datasets.ImageFolder(split_dir, transform=transform)

    nb_classes = 2
    return dataset, nb_classes


def build_transform(is_train, args):
    """
    Build the torchvision preprocessing pipeline for one split.

    Args:
        is_train: when truthy, random augmentations (brightness jitter, contrast
            jitter, horizontal flip) are added; otherwise the pipeline is
            deterministic.
        args: unused here; kept so the signature matches the caller.

    Returns:
        A ``transforms.Compose`` that converts the image to a tensor, optionally
        augments it, and normalizes it with ``norm_mean`` / ``norm_std``.

    NOTE(review): ``norm_mean`` and ``norm_std`` are not defined in the visible
    part of the notebook; they are assumed to be module-level globals defined
    in another cell -- confirm.
    """
    # Convert to a float tensor first (scales 0-255 to 0-1).
    steps = [transforms.ToTensor()]

    if is_train:
        # Augmentations run BEFORE normalization: ColorJitter on a float tensor
        # clamps the result to [0, 1], so applying it to an already normalized
        # tensor (as the previous ordering did) would clip every negative value.
        steps.extend([
            transforms.ColorJitter(brightness=0.5),   # random brightness change
            transforms.ColorJitter(contrast=0.5),     # random contrast change
            transforms.RandomHorizontalFlip(p=0.5),   # random horizontal flip
        ])
    # The evaluation pipeline deliberately has no random flip so that metrics
    # are deterministic from run to run.

    # Standardize with the dataset mean / std.
    steps.append(transforms.Normalize(norm_mean, norm_std))
    return transforms.Compose(steps)

In [13]:
def build_weight(path='',nu_class=1000):
    """
    Compute one inverse-frequency sampling weight per training file.

    Files are discovered with the glob pattern ``<path>/train/*/*``; the parent
    directory name of each file is taken as its class. Every file of a class
    with ``n`` files receives the weight ``total_files / n``.

    Args:
        path: dataset root that contains the ``train`` directory.
        nu_class: unused; kept for backward compatibility.

    Returns:
        A list of floats, grouped by class in sorted class-name order (so the
        order no longer depends on the arbitrary order returned by ``glob``).
        NOTE(review): this presumably has to line up with the sample order of
        the dataset handed to the sampler -- verify against the caller.
    """
    # `glob` is not part of the notebook's import cell; import it locally
    # (the function previously raised NameError).
    import glob
    from collections import Counter

    train_files = glob.glob(path + '/train/*/*')
    # Use os.path instead of splitting on '/' so Windows separators work too.
    files_per_class = Counter(
        os.path.basename(os.path.dirname(file_path)) for file_path in train_files
    )
    total = len(train_files)

    weights = []
    for class_name in sorted(files_per_class):
        count = files_per_class[class_name]
        weights.extend([total / count] * count)
    return weights

In [14]:
class MetricLogger(object):
    """Collects named meters and prints periodic progress lines.

    Meters are created on first use as ``SmoothedValue`` instances and are
    reachable both through ``self.meters[name]`` and as attributes
    (``logger.loss``).
    """

    def __init__(self, delimiter="\t"):
        # name -> meter; unknown names get a fresh SmoothedValue on first access
        self.meters = defaultdict(SmoothedValue)
        # string placed between the fields of a log line
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name.

        Tensors are converted with ``.item()``; every value must be an int or float.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes; raise AttributeError for anything else.

        ``__getattr__`` only runs after normal attribute lookup has failed, so
        ``meters`` is read straight from ``__dict__``: going through
        ``self.meters`` would recurse forever on an instance whose ``__init__``
        has not run yet (e.g. while being copied or unpickled).
        """
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Render all meters as ``name: value`` joined by the delimiter."""
        return self.delimiter.join(
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Ask every meter to synchronize its statistics across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) the meter stored under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress lines.

        A line is printed every ``print_freq`` iterations and on the last one;
        a total-time summary is printed once the iterable is exhausted.

        Args:
            iterable: sized iterable (``len()`` is required).
            print_freq: number of iterations between two progress lines.
            header: optional prefix for every printed line.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        num_iters = len(iterable)
        use_cuda = torch.cuda.is_available()
        # Pad the iteration counter to the width of the total count.
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if use_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # time spent waiting for the item itself
            data_time.update(time.time() - end)
            yield obj
            # time for the item plus whatever the consumer did with it
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                fields = dict(
                    eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if use_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, num_iters, **fields))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1): an empty iterable must not end in ZeroDivisionError.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_iters, 1)))



# train&&eval

In [27]:
def train_one_epoch(model,class_criterion,nce_criterion,
                    data_loader, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,warmup,warmupstep:int, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: network returning ``(cls_out, logits, labels, features, attn)``.
        class_criterion: loss applied to ``(cls_out, targets)``.
        nce_criterion: contrastive loss applied to ``(logits, labels)``.
        data_loader: iterable of ``(samples, targets)`` batches.
        optimizer: optimizer stepped through ``loss_scaler``.
        device: device the batches are moved to.
        epoch: current epoch index (used for logging and warm-up).
        loss_scaler: callable performing backward + optimizer step
            (called as ``loss_scaler(loss, optimizer, clip_grad=..., parameters=..., create_graph=...)``).
        warmup: warm-up LR scheduler; stepped once at the start of the epoch
            while ``epoch < warmupstep``.
        warmupstep: number of warm-up epochs.
        max_norm: gradient clipping value forwarded to ``loss_scaler``.
        model_ema: optional EMA wrapper updated after every step.
        mixup_fn: optional callable applied to ``(samples, targets)``.
        set_training_mode: value passed to ``model.train``.

    Returns:
        Dict mapping every meter name to its global average.
    """
    model.train(set_training_mode)
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10
    if epoch < warmupstep:
        # The scheduler is the `warmup` argument (the previous code referenced
        # an undefined `warmup_scheduler`).
        warmup.step()
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            # outputs: cancer, logits, labels, features, attn
            cl,log,lab,fe,attn = model(samples)
            loss_cls = class_criterion(cl,targets)
            nce_loss = nce_criterion(log,lab)
            # Feature-consistency terms between the two views: features 0-2 are
            # paired with features 3-5.
            aux_loss1= aux_similiarity(fe[0][0],fe[3][0])
            aux_loss2= aux_similiarity(fe[1][0],fe[4][0])
            aux_loss3= aux_similiarity(fe[2][0],fe[5][0])
            # NOTE(review): the original cell also called `similiarity()` with no
            # argument (a TypeError) and never used the result; that term is left
            # out until the intended input tensor is decided.
            # NOTE(review): the original final line referenced the undefined names
            # `criterion` and `outputs`. The terms computed above are combined
            # with equal weights as a placeholder -- the notebook itself marks the
            # loss ratios as undecided, and the sign of the aux terms should be
            # confirmed.
            loss = loss_cls + nce_loss + aux_loss1 + aux_loss2 + aux_loss3

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        # Only synchronize when CUDA exists; the unconditional call fails on
        # CPU-only machines.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}



In [19]:
@torch.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` and report loss and top-1 accuracy.

    Runs under ``torch.no_grad()`` so no autograd graph is kept during evaluation.

    Args:
        data_loader: iterable of ``(images, target)`` batches.
        model: network whose output is fed to cross-entropy and ``accuracy``.
        device: device the batches are moved to.

    Returns:
        Dict mapping every meter name (``loss``, ``acc1``) to its global average.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        # Only top-1 is reported; top-5 was computed but never used and is not
        # meaningful for the 2-class setup this notebook builds.
        acc1 = accuracy(output, target, topk=(1,))[0]

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f}  loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

## 吞吐量

In [20]:
@torch.no_grad()
def throughput(data_loader, model, logger):
    """Measure inference throughput on the first batch of ``data_loader``.

    The batch is moved to the GPU, the model is run 50 times as warm-up, then
    30 timed runs are used to log images per second. Only the first batch is
    used; the function returns ``None`` right after logging.

    Args:
        data_loader: iterable of ``(images, _)`` batches.
        model: network to benchmark (switched to eval mode).
        logger: object with an ``info`` method.
    """
    model.eval()

    for images, _ in data_loader:
        images = images.cuda(non_blocking=True)
        batch_size = images.shape[0]

        # warm-up passes, not timed
        for _run in range(50):
            model(images)
        torch.cuda.synchronize()

        logger.info("throughput averaged with 30 times")
        started = time.time()
        for _run in range(30):
            model(images)
        # wait for all queued GPU work before stopping the clock
        torch.cuda.synchronize()
        finished = time.time()

        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (finished - started)}")
        return

# build

## embedding patch

In [5]:
# Epsilon shared by the BatchNorm layers in this notebook.
NORM_EPS = 1e-5


class PatchEmbed(nn.Module):
    """Pooling + 1x1 convolution + BatchNorm embedding layer.

    The ``(stride, mode)`` pair selects an average-pooling window; pairs that
    are not listed use no pooling. When no pooling is selected and the channel
    count is unchanged, the whole layer is an identity.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the 1x1 convolution.
        stride: pooling stride; together with ``mode`` picks the pooling window.
        mode: ``"V"`` or ``"H"``; selects which family of pooling windows is used.
    """

    # (stride, mode) -> (AvgPool2d kernel size, AvgPool2d stride)
    _POOLING = {
        (4, "V"): ((4, 32), 4),
        (2, "V"): ((2, 16), 2),
        (2, "H"): ((64, 2), 2),
        (1, "H"): ((32, 1), 1),
    }

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 mode = "V"
                ):
        super().__init__()
        pooling = self._POOLING.get((stride, mode))

        if pooling is None:
            self.avgpool = nn.Identity()
        else:
            kernel, pool_stride = pooling
            self.avgpool = nn.AvgPool2d(kernel, stride=pool_stride,
                                        ceil_mode=True, count_include_pad=False)

        if pooling is None and in_channels == out_channels:
            # Nothing to pool and nothing to project: pure pass-through.
            self.conv = nn.Identity()
            self.norm = nn.Identity()
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)

    def forward(self, x):
        """Apply pooling, then the 1x1 convolution, then normalization."""
        pooled = self.avgpool(x)
        projected = self.conv(pooled)
        return self.norm(projected)

## product QKV

In [6]:
class E_PQKV(nn.Module):
    """
    Produce per-head query, key and value tensors from a token sequence.

    Args:
        dim: channel size of the input tokens and of each projection.
        head_dim: channels per attention head; ``num_heads = dim // head_dim``.
        out_dim: stored as ``self.out_dim`` (defaults to ``dim``); not used in forward.
        qkv_bias: whether the three linear projections have a bias.
        qk_scale: optional scale stored as ``self.scale`` (defaults to ``head_dim ** -0.5``).
        attn_drop, proj_drop, sr_ratio: accepted for signature compatibility; unused here.
    """
    def __init__(self, dim, head_dim=32, out_dim=None, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = dim if out_dim is None else out_dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)

    def forward(self, x):
        """Project ``x`` of shape (B, N, C) into multi-head q, k, v.

        Returns:
            q: (B, heads, N, C // heads)
            k: (B, heads, C // heads, N) -- already transposed for ``q @ k``
            v: (B, heads, N, C // heads)
        """
        batch, tokens, channels = x.shape
        heads = self.num_heads
        per_head = int(channels // heads)

        query = self.q(x).reshape(batch, tokens, heads, per_head).transpose(1, 2)
        key = self.k(x).reshape(batch, -1, heads, per_head).permute(0, 2, 3, 1)
        value = self.v(x).reshape(batch, -1, heads, per_head).transpose(1, 2)
        return query, key, value
    

## Efficient Self attention

In [7]:
class E_MHA(nn.Module):
    """
    Scaled dot-product attention over pre-split heads.

    Args:
        head_dim: channels per head; sets the default scale ``head_dim ** -0.5``.
        qk_scale: optional explicit scale overriding the default.
        attn_drop: dropout probability applied to the attention weights.
    """
    def __init__(self, head_dim=32, qk_scale=None,attn_drop=0):
        super().__init__()
        self.scale = qk_scale or head_dim ** -0.5
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, q,k,v,b,n,c):
        """Attend with ``q``/``k``/``v`` and merge the heads back into (b, n, c).

        ``k`` is expected already transposed so that ``q @ k`` yields the
        attention scores.
        """
        scores = torch.matmul(q, k) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        context = torch.matmul(weights, v)
        # (b, heads, n, d) -> (b, n, heads, d) -> (b, n, c)
        return context.transpose(1, 2).reshape(b, n, c)

## MLP and res

In [9]:
class Mlp(nn.Module):
    """Two-layer feed-forward block with a residual connection.

    Computes ``fc2(drop(act(fc1(x))))``, applies dropout again and adds the
    input back, so ``out_features`` must be addable to the input.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: width of the hidden layer (defaults to ``in_features``).
        out_features: size of the output's last dimension (defaults to ``in_features``).
        act_layer: activation class instantiated between the two linear layers.
        drop: dropout probability used after both linear layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        shortcut = x
        hidden = self.drop(self.act(self.fc1(x)))
        out = self.drop(self.fc2(hidden))
        # residual connection, accumulated in place on the fresh output tensor
        out += shortcut
        return out

## Efficient Multi-Head Double-View Attention

In [10]:
class E_MTDVA(nn.Module):
    """
    Efficient Multi-Head Double-View Attention
    x_mh,x_ah,x_mv,x_av->forward

    The two "h" inputs are concatenated into one token sequence and the two
    "v" inputs into another. Both sequences go through: self-attention, FFN,
    self-attention again, cross-attention between the sequences, FFN. Each
    sequence is then averaged over its tokens and the two results are
    concatenated.

    Args:
        dim: token channel size.
        head_dim: channels per attention head.
        out_dim: output size of the FFN (defaults to ``dim``).
        qkv_bias: whether the q/k/v projections have a bias.
        qk_scale: optional attention scale override.
        attn_drop: dropout probability on attention weights.
        proj_drop: dropout probability stored in ``self.proj_drop``.
        sr_ratio: when > 1 an ``AvgPool1d`` is created and ``self.norm`` becomes
            ``BatchNorm1d``. NOTE(review): ``forward`` applies ``self.norm`` to
            (B, N, C) tokens, which only suits the default LayerNorm -- confirm
            before using ``sr_ratio > 1``.
    """
    def __init__(self, dim, head_dim=32, out_dim=None, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super(E_MTDVA,self).__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = E_PQKV( dim, head_dim, out_dim, qkv_bias, qk_scale,attn_drop, proj_drop, sr_ratio)
        self.sa = E_MHA(head_dim,qk_scale,attn_drop)
        self.ca = E_MHA(head_dim,qk_scale,attn_drop)
        self.proj = Mlp(self.dim,self.dim*4, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.norm = nn.LayerNorm(dim)
        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        """Fold ``pre_bn`` (and, for ``sr_ratio > 1``, ``self.norm``) into q/k/v.

        The projections live on ``self.qkv``; the previous code referenced
        ``self.q`` / ``self.k`` / ``self.v``, which do not exist on this module.
        NOTE(review): ``merge_pre_bn`` is not defined in the visible part of the
        notebook -- it must be provided elsewhere.
        """
        merge_pre_bn(self.qkv.q, pre_bn)
        if self.sr_ratio > 1:
            merge_pre_bn(self.qkv.k, pre_bn, self.norm)
            merge_pre_bn(self.qkv.v, pre_bn, self.norm)
        else:
            merge_pre_bn(self.qkv.k, pre_bn)
            merge_pre_bn(self.qkv.v, pre_bn)
        self.is_bn_merged = True

    def _self_attention(self, h, v, B, N, C):
        """Normalize both sequences and add each one's own self-attention output."""
        h = self.norm(h)
        v = self.norm(v)
        q_h,k_h,v_h = self.qkv(h)
        q_v,k_v,v_v = self.qkv(v)
        attn_h = self.sa(q_h,k_h,v_h, B, N, C)
        attn_v = self.sa(q_v,k_v,v_v, B, N, C)
        # Out-of-place residual: `h` and `v` were inputs of the q/k/v linear
        # layers, so autograd has saved them; an in-place `+=` would invalidate
        # the saved tensors and make backward() fail.
        return h + attn_h, v + attn_v

    def _cross_attention(self, h, v, B, N, C):
        """Normalize both sequences and let each one attend to the other's keys."""
        h = self.norm(h)
        v = self.norm(v)
        q_h,k_h,v_h = self.qkv(h)
        q_v,k_v,v_v = self.qkv(v)
        attn_h = self.ca(q_h,k_v,v_h, B, N, C)
        # Mirror of the line above. The previous code repeated (q_h, k_v, v_h)
        # here, so both branches received the identical update.
        # NOTE(review): values are taken from the querying sequence, following
        # the original pattern for `attn_h`; confirm whether the other
        # sequence's values were intended instead.
        attn_v = self.ca(q_v,k_h,v_v, B, N, C)
        return h + attn_h, v + attn_v

    def _ffn(self, h, v):
        """Normalize both sequences and run the shared residual MLP on each."""
        h = self.norm(h)
        v = self.norm(v)
        return self.proj(h), self.proj(v)

    def forward(self, x_mh,x_ah,x_mv,x_av):
        """Fuse the four token sequences into one vector per sample.

        Args:
            x_mh, x_ah: token sequences concatenated along dim 1 into the "h" stream.
            x_mv, x_av: token sequences concatenated along dim 1 into the "v" stream.
                The "v" stream is reshaped with the "h" stream's (B, N, C), so both
                streams must have the same shape.

        Returns:
            Tensor of shape (B, 2 * out_dim): token-mean of the "h" stream
            followed by token-mean of the "v" stream.
        """
        h = torch.cat([x_mh,x_ah],dim = 1)
        v = torch.cat([x_mv,x_av],dim = 1)
        B, N, C = h.shape

        # self-attention + ffn
        h, v = self._self_attention(h, v, B, N, C)
        h, v = self._ffn(h, v)

        # second self-attention, then cross-attention + ffn
        h, v = self._self_attention(h, v, B, N, C)
        h, v = self._cross_attention(h, v, B, N, C)
        h, v = self._ffn(h, v)

        # output: average over tokens, then concatenate both streams
        attn_h = h.mean(1)
        attn_v = v.mean(1)
        return torch.cat([attn_h,attn_v],dim=1)

In [11]:
class ShardExamine(nn.Module):
    """Fuse two feature maps through two patch embeddings and double-view attention.

    Each input feature map is embedded twice (``PEV`` in mode "V", ``PEH`` in
    mode "H"), flattened into token sequences, passed through ``E_MTDVA`` and
    finally projected to ``out_putdim``.

    Args:
        dim: token channel size handed to ``E_MTDVA``.
        in_channels: channels of the incoming feature maps.
        out_channels: channels produced by both patch embeddings.
        out_putdim: size of the final projection.
        num_head: passed positionally as ``E_MTDVA``'s ``head_dim``.
        qkv_bias: passed positionally in ``E_MTDVA``'s ``qk_scale`` slot
            (``qkv_bias`` itself is fixed to ``True`` there).
            NOTE(review): this mapping follows the positional call below -- confirm it is intended.
        attn_drop: attention dropout probability.
        proj_drop: projection dropout probability.
        stride: stride for ``PEH``; ``PEV`` uses ``stride * 2``.
    """

    def __init__(self,dim,
                 in_channels,
                 out_channels,
                 out_putdim,
                 num_head=64, 
                 qkv_bias=False, 
                 attn_drop=0.1, 
                 proj_drop=0.1,
                 stride=1
                ):
        super().__init__()
        self.PEV = PatchEmbed(in_channels, out_channels, stride * 2, mode="V")
        self.PEH = PatchEmbed(in_channels, out_channels, stride, mode="H")
        self.proj = nn.Linear(out_channels * 2, out_putdim)
        self.attn = E_MTDVA(dim, num_head, None, True, qkv_bias, attn_drop, proj_drop)

    def forward(self, u_m, u_a):
        """Return the fused, projected representation of the two feature maps."""
        vert_m = self.PEV(u_m)
        vert_a = self.PEV(u_a)
        horiz_m = self.PEH(u_m)
        horiz_a = self.PEH(u_a)

        # The token layout is taken from the "V" embedding and reused for all
        # four tensors, so the "H" embeddings must hold the same number of elements.
        batch, channels, height, width = vert_m.shape
        tokens = height * width

        def as_sequence(feature_map):
            # (B, C, H, W) -> (B, C, L) -> (B, L, C)
            return feature_map.reshape(batch, channels, tokens).permute(0, 2, 1)

        # argument order: x_mh, x_ah, x_mv, x_av
        fused = self.attn(as_sequence(horiz_m), as_sequence(horiz_a),
                          as_sequence(vert_m), as_sequence(vert_a))
        return self.proj(fused)

In [12]:
class basNet(nn.Module):
    """
    Two-view network: a query encoder, a momentum-updated key encoder with a
    queue of negative keys, and an attention head fusing both views.

    Only two views are supported: the input image is split in two along its width.

    Args:
        K: queue size (number of stored negative keys).
        m: momentum of the key-encoder update.
        T: softmax temperature applied to the contrastive logits.
        dim: feature dimension of the keys stored in the queue.
        nc: ``num_classes`` passed to both timm encoders (their output size).
    """
    def load_pretrain(self, ):
        """Placeholder; does nothing."""
        return

    def __init__(self,K=8192, m=0.999, T=0.07,dim=128,nc=512):
        super(basNet, self).__init__()
        # self.output_type = ['inference', 'loss']

        self.K = K
        self.m = m
        self.T = T
        # `create_model` is the name imported from timm.models; the module name
        # `timm` itself is never imported in this notebook.
        self.encoder_q = create_model('efficientnetv2_m',  # m 512 s 256
                                      pretrained=False,
                                      drop_rate=0.2,
                                      drop_path_rate=0.1,
                                      num_classes=nc
                                      )
        self.encoder_k = create_model('efficientnetv2_m',
                                      pretrained=False,
                                      drop_rate=0.2,
                                      drop_path_rate=0.1,
                                      num_classes=nc
                                      )

        # Extra projection heads attached as `.fc`; forward() calls them
        # explicitly on the (normalized) encoder outputs.
        self.encoder_q.fc = nn.Sequential(
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Linear(512, 128),
        )
        self.encoder_k.fc = nn.Sequential(
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Linear(512, 128),
        )
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        # Classification head input: q_a (512) + q_m (512) + attention (256) = 1280.
        self.mlp = nn.Sequential(
            nn.LayerNorm(1280),
            nn.Linear(1280, 512),
            nn.GELU(),
            nn.Linear(512, 128),
        )#<todo> mlp needs to be deep if backbone is strong?
        self.att = ShardExamine(128,512,128,256)
        self.cancer = nn.Linear(128,1)

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        ``param_k = m * param_k + (1 - m) * param_q``.
        """
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue columns with ``keys`` and advance the pointer.

        ``keys`` is (batch, dim); the queue size must be a multiple of the batch size.
        """
        # gather keys before updating queue
        # keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***

        Returns the shuffled shard for this process and the index needed to
        undo the shuffle.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index, placed on the input's device instead of a
        # hard-coded `.cuda()`
        idx_shuffle = torch.randperm(batch_size_all).to(x.device)

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, x):
        """
        Input:
            x: batch of images (B, C, H, W); the left and right halves (along
               the width) are treated as the two views.
        Output:
            cancer:   per-sample probability (sigmoid already applied), shape (B,)
            logits:   contrastive logits divided by the temperature,
                      4 positive columns followed by 2*K negative columns
            labels:   long tensor built from ``logits.shape[1]``
            features: inputs captured by the forward hooks on the query encoder
                      (entries 0-2 from the first view, 3-5 from the second)
            attn:     output of the attention head
        """
        batch_size,C,H,W = x.shape
        x = x.reshape(-1, C, H, W)
        # Split along the width, so the chunk size comes from W (identical to
        # the previous H // 2 for square inputs).
        halves = torch.split(x, W // 2, dim=-1)
        x_m = halves[0]
        x_a = halves[1]

        features=[]
        def hook(module, input, output):
            features.append(input)
            return None

        # Register the hooks only for the two query passes and remove them
        # afterwards. Previously they were re-registered on every forward call
        # and never removed, so hooks (and the tensors their closures captured)
        # piled up over training.
        handles = [
            self.encoder_q.conv_head.register_forward_hook(hook),
            self.encoder_q.blocks[5].register_forward_hook(hook),
            self.encoder_q.blocks[6].register_forward_hook(hook),
        ]
        try:
            # compute query features
            q_m = self.encoder_q(x_m)  # queries: NxC
            q_a = self.encoder_q(x_a)  # queries: NxC
        finally:
            for handle in handles:
                handle.remove()

        # features[2] / features[5]: last hook fired in each pass
        # NOTE(review): presumably the input of `conv_head` -- confirm hook order.
        attn= self.att(features[2][0],features[5][0])

        q_m = nn.functional.normalize(q_m, dim=1)
        q_a = nn.functional.normalize(q_a, dim=1)

        last = torch.cat([q_a,q_m,attn],-1)
        last = self.mlp(last)
        cancer = self.cancer(last).reshape(-1)
        cancer = torch.sigmoid(cancer)
        q_m = self.encoder_q.fc(q_m)
        q_a = self.encoder_q.fc(q_a)
        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            k_m = self.encoder_k(x_m)  # keys: NxC
            k_m = nn.functional.normalize(k_m, dim=1)
            k_m = self.encoder_k.fc(k_m)
            k_a = self.encoder_k(x_a)  # keys: NxC
            k_a = nn.functional.normalize(k_a, dim=1)
            k_a = self.encoder_k.fc(k_a)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1 each (every query/key view combination)
        l_pos1 = torch.einsum("nc,nc->n", [q_m, k_m]).unsqueeze(-1)
        l_pos2 = torch.einsum("nc,nc->n", [q_m, k_a]).unsqueeze(-1)
        l_pos3 = torch.einsum("nc,nc->n", [q_a, k_m]).unsqueeze(-1)
        l_pos4 = torch.einsum("nc,nc->n", [q_a, k_a]).unsqueeze(-1)
        # negative logits: NxK each
        negatives = self.queue.clone().detach()
        l_neg_m = torch.einsum("nc,ck->nk", [q_m, negatives])
        l_neg_a = torch.einsum("nc,ck->nk", [q_a, negatives])
        # logits: Nx(4+2K)
        logits = torch.cat([l_pos1,l_pos2,l_pos3,l_pos4, l_neg_m,l_neg_a], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators, created on the logits' device
        # instead of a hard-coded `.cuda()`.
        # NOTE(review): the length follows logits.shape[1] (columns), not the
        # batch size, and positions 1-3 hold the values 1-3; a standard
        # cross-entropy target would have length logits.shape[0]. Left as
        # written -- confirm what `nce_criterion` expects.
        labels = torch.zeros(logits.shape[1], dtype=torch.long, device=logits.device)
        labels[1],labels[2],labels[3]=1,2,3

        # dequeue and enqueue
        self._dequeue_and_enqueue(k_m)
        self._dequeue_and_enqueue(k_a)

        return cancer,logits,labels,features,attn

# loss

## 辅助损失函数

In [29]:
def aux_similiarity(x1,x2):#features[0/1/2][0],features[3/4/5][0]
    """Cosine-similarity score between two feature tensors of equal shape.

    The cosine similarity is taken along the last dimension, summed over the
    batch dimension (dim 0), halved, and averaged over all remaining positions.

    Args:
        x1, x2: tensors with at least two dimensions and identical shapes.

    Returns:
        A 0-dim tensor.
    """
    dot = (x1 * x2).sum(-1)
    norm1 = torch.sqrt((x1 * x1).sum(-1))
    norm2 = torch.sqrt((x2 * x2).sum(-1))
    # epsilon keeps all-zero feature vectors from dividing by zero
    cosine = dot / (norm1 * norm2 + 1e-6)
    # Tensor reduction over the batch dimension; the builtin `sum` used before
    # iterated the tensor row by row in Python.
    per_position = torch.flatten(cosine, 1).sum(dim=0) / 2
    return per_position.mean()

# 相似度损失函数

In [31]:
def similiarity(x):#features[0/1/2][0],features[3/4/5][0]
    """Cosine similarity between the two halves of ``x`` along its last dimension.

    Args:
        x: tensor whose last dimension holds two equally sized halves.

    Returns:
        Tensor with the last dimension reduced away, holding the cosine
        similarity of the first half and the second half.
    """
    # The split runs along the last dimension, so the chunk size must come from
    # that dimension; the previous `len(x) // 2` used the size of dim 0 instead.
    half = x.shape[-1] // 2
    x1, x2 = torch.split(x, half, dim=-1)[:2]
    p12 = (x1*x2).sum(-1)
    p1 = torch.sqrt((x1*x1).sum(-1))
    p2 = torch.sqrt((x2*x2).sum(-1))
    # epsilon keeps all-zero halves from dividing by zero
    return p12/(p1*p2+1e-6)

## 对比学习损失函数

In [24]:
# def infoNCE():
#     criterion = nn.CrossEntropyLoss().cuda(args.gpu)
# NEC_loss = criterion(output, target)#logits,labels

## 损失函数比例待定

In [23]:
# class_criterion = nn.BCEWithLogitsLoss().cuda(args.gpu)

# Sampler

In [None]:
class RASampler(torch.utils.data.Sampler):
    """Distributed sampler with repeated augmentation.

    Every dataset index is repeated three times and the repeated list is dealt
    out round-robin to the processes, so the copies of one sample land on
    different processes (GPUs). Each process then keeps only the first
    ``num_selected_samples`` of its share.
    Heavily based on torch.utils.data.DistributedSampler

    Args:
        dataset: sized dataset to sample from.
        num_replicas: number of processes; defaults to ``dist.get_world_size()``.
        rank: rank of this process; defaults to ``dist.get_rank()``.
        shuffle: shuffle the indices (seeded by the epoch) before repeating them.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        dataset_len = len(self.dataset)
        # per-process share of the 3x repeated index list
        self.num_samples = int(math.ceil(dataset_len * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # indices actually yielded per process: dataset length rounded down to
        # a multiple of 256, divided among the processes
        self.num_selected_samples = int(math.floor(dataset_len // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # deterministic shuffle: the generator is seeded with the epoch
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        dataset_len = len(self.dataset)
        if self.shuffle:
            order = torch.randperm(dataset_len, generator=rng).tolist()
        else:
            order = list(range(dataset_len))

        # repeat every index three times, keeping the copies adjacent
        repeated = []
        for index in order:
            repeated.extend((index, index, index))
        # pad from the front so the list divides evenly among the processes
        padding = self.total_size - len(repeated)
        repeated += repeated[:padding]
        assert len(repeated) == self.total_size

        # round-robin share of this process
        share = repeated[self.rank:self.total_size:self.num_replicas]
        assert len(share) == self.num_samples

        return iter(share[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the shuffle."""
        self.epoch = epoch

In [None]:
class WeightedRandomSamplerDDP(DistributedSampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Each rank keeps only its strided shard of the weights
    (``weights[rank::num_replicas]``), draws positions inside that shard with
    ``torch.multinomial`` and maps them back to dataset indices as
    ``rank + position * num_replicas``.

    Args:
        data_set: Dataset used for sampling.
        weights (sequence)   : a sequence of weights, not necessarily summing up to one
        num_replicas (int): Number of processes participating in
            distributed training.
        rank (int): Rank of the current process within :attr:`num_replicas`.
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.
        num_samples (int, optional): number of samples to draw before the
            split across replicas; each rank draws
            ``num_samples // num_replicas``. Defaults to
            ``ceil(len(data_set) * 3 / num_replicas)``.

    Raises:
        ValueError: If ``num_samples`` is not a positive integer or
            ``replacement`` is not a boolean.
    """

    def __init__(self, data_set, weights: Sequence[float], num_replicas: int, rank: int,
                 replacement: bool = True, generator=None,
                 num_samples: Optional[int] = None) -> None:
        super(WeightedRandomSamplerDDP, self).__init__(data_set, num_replicas, rank)
        # ``self.dataset``, ``self.num_replicas`` and ``self.rank`` are set by
        # the parent constructor and are used as-is from here on.
        if num_samples is None:
            # NOTE(review): this default is already divided by num_replicas and
            # is divided again below, so each rank draws roughly
            # 3 * len(data_set) / num_replicas**2 indices. Kept unchanged;
            # confirm that this is the intended epoch length.
            num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.replacement = replacement
        self.generator = generator
        # Keep only the weights of the indices owned by this rank.
        all_weights = torch.as_tensor(weights, dtype=torch.double)
        self.weights = all_weights[self.rank::self.num_replicas]
        # Per-rank share of the requested number of draws.
        self.num_samples = num_samples // self.num_replicas

    def __iter__(self):
        """Draw this rank's positions and return an iterator over dataset indices."""
        local_draws = torch.multinomial(
            self.weights, self.num_samples, self.replacement, generator=self.generator)
        # Undo the strided sharding: position p in the shard is dataset index
        # rank + p * num_replicas.
        global_indices = self.rank + local_draws * self.num_replicas
        return iter(global_indices.tolist())

    def __len__(self):
        """Return the number of indices this rank yields per pass."""
        return self.num_samples

# train_main

In [None]:
def main(args):
    """Run training, evaluation or a throughput test as selected by ``args``.

    Steps, in order: initialise distributed mode, seed the RNGs, build the
    train/val datasets with their samplers and loaders, create the model,
    optimizer, LR scheduler and losses, optionally resume from a checkpoint,
    then either evaluate (``args.eval``), measure throughput
    (``args.throughout``) or train for ``args.epochs`` epochs. During training
    a checkpoint is written after every epoch, a separate best checkpoint is
    written whenever validation ``acc1`` improves, and one JSON line of stats
    per epoch is appended to ``log.txt`` in the output directory.

    Args:
        args: Argument namespace (see ``get_args_parser``). ``args.distributed``
            and ``args.rank`` are also read; presumably they are set by
            ``init_distributed_mode`` -- confirm against that helper.
            ``args.nb_classes``, ``args.lr`` and possibly ``args.output_dir``
            and ``args.start_epoch`` are overwritten in place.
    """
    init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank, so each process uses
    # a different seed)
    seed = args.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)
    sample_weights = build_weight(args.data_path)
    if args.distributed:
        num_tasks = get_world_size()
        global_rank = get_rank()
        if args.repeated_aug:
            # WeightedRandomSamplerDDP has no ``shuffle`` parameter; passing
            # one raised TypeError.
            sampler_train = WeightedRandomSamplerDDP(
                dataset_train, sample_weights, num_replicas=num_tasks, rank=global_rank
            )
        else:
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
            )
        if args.dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_train = torch.utils.data.WeightedRandomSampler(
            sample_weights, args.nb_classes * len(dataset_train))
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=250,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    # Mixup/CutMix is not needed for fine-grained classification, so it stays
    # disabled. The name must still exist because it is passed to
    # train_one_epoch below.
    mixup_fn = None

    print(f"Creating model: {args.model}")
    model = basNet()

    if not args.distributed or args.rank == 0:
        print(model)
        input_tensor = torch.zeros((1, 3, 224, 224), dtype=torch.float32)
        model.eval()
        cal_flops_params_with_fvcore(model, input_tensor)

    model.to(device)
    model_ema = None

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    # Linear LR scaling with the global batch size (reference batch: 512).
    linear_scaled_lr = args.lr * args.batch_size * get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)

    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Place the losses on the same device as the model instead of relying on
    # args.gpu, which this file's argument parser does not define.
    class_criterion = nn.BCEWithLogitsLoss().to(device)
    nce_criterion = nn.CrossEntropyLoss().to(device)

    if not args.output_dir:
        args.output_dir = args.model
    output_dir = Path(args.output_dir)
    if is_main_process():
        # Create the directory up front so checkpoint/log writes cannot fail
        # on a missing path.
        output_dir.mkdir(parents=True, exist_ok=True)

    if args.resume:
        if args.resume.startswith('https'):
            ckpt = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            # NOTE: torch.load unpickles the file; only resume from trusted
            # checkpoints.
            ckpt = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(ckpt['model'] if 'model' in ckpt else ckpt)
        has_train_state = all(key in ckpt for key in ('optimizer', 'lr_scheduler', 'epoch'))
        if not args.finetune and not args.eval and has_train_state:
            optimizer.load_state_dict(ckpt['optimizer'])
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
            args.start_epoch = ckpt['epoch'] + 1
            if 'scaler' in ckpt:
                loss_scaler.load_state_dict(ckpt['scaler'])

    if args.eval:
        # Use the unwrapped model: ``model`` only has ``.module`` under DDP.
        if hasattr(model_without_ddp, "merge_bn"):
            print("Merge pre bn to speedup inference.")
            model_without_ddp.merge_bn()
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    if args.throughout:
        from logger import create_logger
        logger = create_logger(output_dir=output_dir, dist_rank=get_rank(), name=args.model)
        throughput(data_loader_val, model, logger)
        return

    def _snapshot(current_epoch):
        # Everything needed to resume training after ``current_epoch``.
        return {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': current_epoch,
            'scaler': loss_scaler.state_dict(),
            'args': args,
        }

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, class_criterion, nce_criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=True,
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            save_on_master(_snapshot(epoch), output_dir / 'checkpoint.pth')

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        if test_stats["acc1"] > max_accuracy and args.output_dir:
            save_on_master(_snapshot(epoch), output_dir / 'checkpoint_best.pth')
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch}

        if args.output_dir and is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

In [32]:
 def get_args_parser():
    """Build the command-line parser for the training / evaluation script.

    The parser is created with ``add_help=False``, so it can be passed as a
    parent parser to another ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser with model, optimizer, LR-schedule,
        augmentation, dataset, runtime and distributed options registered.
    """
    parser = argparse.ArgumentParser('Next-ViT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=16, type=int)
    parser.add_argument('--epochs', default=300, type=int)

    # Model parameters
    parser.add_argument('--model', default='pvt_small', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=2048, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    # NOTE(review): the previous help text here was a copy of --drop-path's;
    # the exact meaning of this option is not visible in this file -- confirm.
    parser.add_argument('--flops', type=float, default=0.1, metavar='PCT',
                        help='FLOPs setting (default: 0.1)')
    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=5, metavar='NORM',
                        help='Clip gradient norm (default: 5)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    # Learning rate schedule parameters
    # The default used to be the literal string 'sched', which is not a
    # scheduler name; 'cosine' is what the help text always advertised.
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-6, metavar='LR',
                        help='learning rate (default: 5e-6)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=40, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=False)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # * Finetuning params
    parser.add_argument('--finetune', action='store_true', help='Perform finetune.')

    # Dataset parameters
    parser.add_argument('--data-path', default='../../datasets/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='Dataset type')
    parser.add_argument('--use-mcloader', action='store_true', default=False, help='Use mcloader')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output-dir', default='../../outputdir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='./nextvit_small_in1k6m_384.pth', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # throughput test (the flag keeps its original spelling because callers
    # read ``args.throughout``)
    parser.add_argument('--throughout', action='store_true', help='Perform throughout only')
    return parser