In [1]:
import argparse
import json
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import *
from models import *
from utils import *
import pdb

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet

dim_dict = {
    'resnet18': 512,
    'resnet34': 512,
    'resnet50': 2048,
}


class ModelBase(nn.Module):
    """
    For small size figures:
    (i) replaces conv1 with kernel=3, str=1
    (ii) removes pool1
    """

    def __init__(self, figsize=32, num_classes=10, projection_dim=128, arch=None):
        super(ModelBase, self).__init__()
        resnet_arch = getattr(resnet, arch)

        self.net = resnet_arch(pretrained=True)
        if figsize <= 64:  # adapt to small-size images
            self.net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.net.maxpool = nn.Identity()
        self.net.fc = nn.Identity()

        self.feat_dim = dim_dict[arch]
        self.projector = nn.Sequential(
            nn.Linear(self.feat_dim, 512),
            nn.ReLU(True),
            nn.Linear(512, projection_dim)
        )

        self.classifer = nn.Linear(self.feat_dim, num_classes)

    def forward(self, x, feat=False):
        x = self.net(x)
        if feat:
            return x
        else:
            cls, proj = self.classifer(x), self.projector(x)
            return cls, proj

In [6]:
class MaskConNew(nn.Module):
    def __init__(self, num_classes_coarse=10,  dim=128, K=4096, m=0.9, T1=0.1, T2=0.1, arch='resnet18', mode='mixcon', size=32):
        '''
        Modifed based on MoCo framework.

        :param num_classes_coarse: num of coarse classes
        :param dim: dimension of feature projections
        :param K: size of memory bank
        :param m: momentum encoder
        :param T1: temperature of original contrastive loss
        :param T2: temperature for soft labels generation
        :param arch: architecture of encoder
        :param mode: method mode [maskcon, grafit or coins]
        :param size: dataset image size
        '''
        super(MaskConNew, self).__init__()
        self.K = K
        self.m = m
        self.T1 = T1
        self.T2 = T2
        self.mode = mode
        # create the encoders
        self.encoder_q = ModelBase(figsize=size, num_classes=num_classes_coarse, projection_dim=dim, arch=arch)
        self.encoder_k = ModelBase(figsize=size, num_classes=num_classes_coarse, projection_dim=dim, arch=arch)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        self.num_classes_coarse = num_classes_coarse
        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer('coarse_labels', torch.randint(0, num_classes_coarse, [self.K]).long())
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, coarse_labels):
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.t()  # transpose
        self.coarse_labels[ptr:ptr + batch_size] = coarse_labels

        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_single_gpu(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        """
        # random shuffle index
        idx_shuffle = torch.randperm(x.shape[0]).cuda()

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        return x[idx_shuffle], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        """
        return x[idx_unshuffle]


    def initiate_memorybank(self, dataloader):
        print('Initiate memory bank!')
        num = 0
        iter_data = iter(dataloader)
        for i in range(self.K):  # update the memory bank with image representation
            if num == self.K:
                break
            # print(num)
            try:
                [im_k, _], coarse_label, _ = next(iter_data)
            except:
                iter_data = iter(dataloader)
                [im_k, _], coarse_label, _ = next(iter_data)
            num = num + len(im_k)
            im_k, coarse_label = im_k.cuda(non_blocking=True), coarse_label.cuda(non_blocking=True)
            im_k_, idx_unshufflek = self._batch_shuffle_single_gpu(im_k)
            _, k = self.encoder_k(im_k_)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)  # already normalized
            # undo shuffle
            k = self._batch_unshuffle_single_gpu(k, idx_unshufflek)
            self._dequeue_and_enqueue(k, coarse_label)

    def forward(self, im_k, im_q, coarse_label, args):
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()
        cls_q, q = self.encoder_q(im_q)  # queries:
        q = nn.functional.normalize(q, dim=1)  # already normalized

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # shuffle for making use of BN
            im_k_, idx_unshufflek = self._batch_shuffle_single_gpu(im_k)
            _, k = self.encoder_k(im_k_)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)  # already normalized
            # undo shuffle
            k = self._batch_unshuffle_single_gpu(k, idx_unshufflek)
            
            """
            # soft-labels
            coarse_z = torch.ones(len(q), self.K).cuda()
            new_label = coarse_label.reshape(-1, 1).repeat(1, self.K)
            memory_labels = self.coarse_labels.reshape(1, -1).repeat(len(q), 1)
            coarse_z = coarse_z * (new_label == memory_labels)
            logits_pd = torch.einsum('nc,ck->nk', [k, self.queue.clone().detach()])
            logits_pd /= self.T2
            logits_pd = logits_pd * coarse_z  # mask out non-same-coarse class samples
            logits_pd = logits_pd - logits_pd.max(dim=1, keepdim=True)[0]
            pseudo_soft_z = logits_pd.exp() * coarse_z
            pseudo_sum = torch.sum(pseudo_soft_z, dim=1, keepdim=True)
            maskcon_z = torch.zeros(len(q), self.K + 1).cuda()
            maskcon_z[:, 0] = 1
            tmp = pseudo_soft_z / pseudo_sum
            # rescale by maximum
            tmp = tmp / tmp.max(dim=1, keepdim=True)[0]
            maskcon_z[:, 1:] = tmp
            # generate weighted inter-sample relations
            maskcon_z = maskcon_z / maskcon_z.sum(dim=1, keepdim=True)

            # self-supervised inter-sample relations
            self_z = torch.zeros(len(q), self.K + 1).cuda()
            self_z[:, 0] = 1.0

            labels = args.w * maskcon_z + (1 - args.w) * self_z
            """

        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        
        
        align_loss = 2- 2 * l_pos.mean()
        
        sq_dists = (2 - 2 * l_neg)
        
        mask = (coarse_label.view(-1,1) == self.coarse_labels.view(1,-1))
        
        all_sq_dists = torch.cat([sq_dists, torch.norm(q[:, None] - q, dim=2, p=2).pow(2)], 1)
        all_coarse_labels = torch.cat([self.coarse_labels, coarse_label])
        mask = (coarse_label.view(-1,1) == all_coarse_labels.view(1,-1))
        sqdists_2_average = all_sq_dists[mask].flatten()
        loss_unif = sqdists_2_average.mul(-1/self.T1).exp().mean().log()
        
        
        loss = (align_loss / self.T1) + loss_unif 
        
        
        self._dequeue_and_enqueue(k, coarse_label)
        
        """
            l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
            l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)

            align_loss = 2- 2 * l_pos.mean()

            sq_dists = (2 - 2 * l_pos).flatten()
            sq_dists = torch.cat([sq_dists, torch.pdist(q, p=2).pow(2)])

            loss_unif = sq_dists.mul(-1/self.T1).exp().mean().log()


            loss = (align_loss / self.T1) + loss_unif 

            # inside vs outside?
            self._dequeue_and_enqueue(k, coarse_label)
        """

        return loss

In [7]:
"""### Set arguments"""

parser = argparse.ArgumentParser(description='Masked contrastive learning.')

# training config:
parser.add_argument('--dataset', default='cifar100', choices=['cifar100', 'cifartoy_bad', 'cifartoy_good', 'cars196', 'sop_split1', 'sop_split2', 'imagenet32'], type=str, help='train dataset')
parser.add_argument('--data_path', default='../datasets/cifar100', type=str, help='train dataset')

# model configs: [Almost fixed for all experiments]
parser.add_argument('--arch', default='resnet18')
parser.add_argument('--dim', default=256, type=int, help='feature dimension')
parser.add_argument('--K', default=8192, type=int, help='queue size; number of negative keys')
parser.add_argument('--m', default=0.99, type=float, help='moco momentum of updating key encoder')
parser.add_argument('--t0', default=0.1, type=float, help='softmax temperature for training')

# train configs:
parser.add_argument('--lr', '--learning-rate', default=0.02, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs')
parser.add_argument('--warm_up', default=5, type=int, metavar='N', help='number of warmup epochs')
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')
parser.add_argument('--aug_q', default='strong', type=str, help='augmentation strategy for query image')
parser.add_argument('--aug_k', default='weak', type=str, help='augmentation strategy for key image')
parser.add_argument('--gpu_id', default='0', type=str, help='gpuid')

# method configs:
parser.add_argument('--mode', default='maskcon', type=str, choices=['maskcon', 'grafit', 'coins'], help='training mode')

# maskcon-specific hyperparameters:
parser.add_argument('--w', default=0.5, type=float, help='weight of self-invariance')  # not-used if maskcon
parser.add_argument('--t', default=0.05, type=float, help='softmax temperature weight for soft label')

# logger configs
parser.add_argument('--wandb_id', type=str, default="cifar100",help='wandb user id')


# train for one epoch
def train(net, data_loader, train_optimizer, epoch, args):
    net.train()
    losses, total_num = 0.0, 0.0
    train_bar = tqdm(data_loader)
    for i, [[im_k, im_q], coarse_targets, fine_targets] in enumerate(train_bar):
        adjust_learning_rate(train_optimizer, args.warm_up, epoch, args.epochs, args.lr, i, data_loader.__len__())
        im_k, im_q, coarse_targets, fine_targets = im_k.cuda(), im_q.cuda(), coarse_targets.cuda(), fine_targets.cuda()
        if args.mode == 'grafit' or args.mode == 'coins':
            loss = net.forward_explicit(im_k, im_q, coarse_targets, args)
        else:  # if args.mode == 'maskcon'
            loss = net(im_k, im_q, coarse_targets, args)
        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        total_num += im_k.shape[0]
        losses += loss.item() * im_k.shape[0]
        train_bar.set_description(
            'Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(
                epoch, args.epochs,
                train_optimizer.param_groups[0]['lr'],
                losses / total_num
            ))

    return losses / total_num


def retrieval(encoder, test_loader, K, chunks=10):
    encoder.eval()
    feature_bank, target_bank = [], []
    with torch.no_grad():
        # for i, (image, _, fine_label) in enumerate(tqdm(test_loader, desc='Retrieval ...')):
        for i, (image, _, fine_label) in enumerate(test_loader):
            image = image.cuda(non_blocking=True)
            label = fine_label.cuda(non_blocking=True)
            output = encoder(image, feat=True)
            feature_bank.append(output)
            target_bank.append(label)

        feature = F.normalize(torch.cat(feature_bank, dim=0), dim=1)
        label = torch.cat(target_bank, dim=0).contiguous()
    label = label.unsqueeze(-1)
    feat_norm = F.normalize(feature, dim=1)
    split = torch.tensor(np.linspace(0, len(feat_norm), chunks + 1, dtype=int), dtype=torch.long).to(feature.device)
    recall = [[] for i in K]
    ids = [torch.tensor([]).to(feature.device) for i in K]
    correct = [torch.tensor([]).to(feature.device) for i in K]
    k_max = np.max(K)

    with torch.no_grad():
        for j in range(chunks):
            torch.cuda.empty_cache()
            part_feature = feat_norm[split[j]: split[j + 1]]
            similarity = torch.einsum('ab,bc->ac', part_feature, feat_norm.T)

            topmax = similarity.topk(k_max + 1)[1][:, 1:]
            del similarity
            retrievalmax = label[topmax].squeeze()
            for k, i in enumerate(K):
                anchor_label = label[split[j]: split[j + 1]].repeat(1, i)
                topi = topmax[:, :i]
                retrieval_label = retrievalmax[:, :i]
                correct_i = torch.sum(anchor_label == retrieval_label, dim=1, keepdim=True)
                correct[k] = torch.cat([correct[k], correct_i], dim=0)
                ids[k] = torch.cat([ids[k], topi], dim=0)

        # calculate recall @ K
        num_sample = len(feat_norm)
        for k, i in enumerate(K):
            acc_k = float((correct[k] > 0).int().sum() / num_sample)
            recall[k] = acc_k

        ##################################################################
        # calculate precision @ K
        # precision = [[] for i in K]
        # num_sample = len(feat_norm)
        # for k, i in enumerate(K):
        #     acc_k = float((correct[k]).int().sum() / num_sample)
        #     precision[k] = acc_k / i
        ##################################################################

    return recall


def main_proc(args, model, train_loader, test_loader):
    # wandb.init(project=args.mode, entity=args.wandb_id, name='train_' + args.results_dir, group=f'train_{args.dataset}_{args.mode}')
    # wandb.config.update(args)
    """### Start training"""
    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
    epoch_start = 0

    with open(f'{args.wandb_id}/{args.results_dir}' + '/args.json', 'w') as fid:
        json.dump(args.__dict__, fid, indent=2)

    train_logs = open(f'{args.wandb_id}/{args.results_dir}/train_logs.txt', 'w')

    # training loop
    best_retrieval_top1 = 0
    best_retrieval_top2 = 0
    best_retrieval_top5 = 0
    best_retrieval_top10 = 0
    best_retrieval_top50 = 0
    best_retrieval_top100 = 0

    model.initiate_memorybank(train_loader)

    for epoch in range(epoch_start, args.epochs):
        if epoch % 10 == 0:
            retrieval_topk = retrieval(model.encoder_q, test_loader, [1, 2, 5, 10, 50, 100])
            retrieval_top1, retrieval_top2, retrieval_top5, retrieval_top10, retrieval_top50, retrieval_top100 = retrieval_topk
            if retrieval_top1 > best_retrieval_top1:
                best_retrieval_top1 = best_retrieval_top1
            if retrieval_top2 > best_retrieval_top2:
                best_retrieval_top2 = best_retrieval_top2
            if retrieval_top5 > best_retrieval_top5:
                best_retrieval_top5 = best_retrieval_top5
            if retrieval_top10 > best_retrieval_top10:
                best_retrieval_top10 = best_retrieval_top10
            if retrieval_top50 > best_retrieval_top50:
                best_retrieval_top50 = best_retrieval_top50
            if retrieval_top100 > best_retrieval_top100:
                best_retrieval_top100 = best_retrieval_top100

            # wandb.log({'R@1': retrieval_top1, 'R@2': retrieval_top2, 'R@5': retrieval_top5, 'R@10': retrieval_top10, 'R@50': retrieval_top50, 'R@100': retrieval_top100}, step=epoch)
            # save statistics
            print(f'Epoch [{epoch}/{args.epochs}]: R@1: {retrieval_top1:.4f}, R@2: {retrieval_top2:.4f}, R@5: {retrieval_top5:.4f}, R@10: {retrieval_top10:.4f},  R@50: {retrieval_top50:.4f},R@100: {retrieval_top100:.4f}')
            train_logs.write(
                f'Epoch [{epoch}/{args.epochs}]: R@1: {retrieval_top1:.4f}, R@2: {retrieval_top2:.4f}, R@5: {retrieval_top5:.4f}, R@10: {retrieval_top10:.4f},  R@50: {retrieval_top50:.4f},R@100: {retrieval_top100:.4f}\n')
            train_logs.flush()

        train(model, train_loader, optimizer, epoch, args)
    # wandb.finish()
    return model

In [None]:

args = parser.parse_args("")
print(args)

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
random.seed(1228)
torch.manual_seed(1228)
torch.cuda.manual_seed_all(1228)
np.random.seed(1228)
torch.backends.cudnn.benchmark = True

"""Define train/test"""
query_transform = get_augment(args.dataset, args.aug_q)
key_transform = get_augment(args.dataset, args.aug_k)
test_transform = get_augment(args.dataset)

if args.dataset == 'cars196':
    train_dataset = CARS196(root=args.data_path, split='train', transform=DMixTransform([key_transform, query_transform], [1, 1]))
    test_dataset = CARS196(root=args.data_path, split='test', transform=test_transform)
    args.num_classes = 8
    args.size = 224

elif args.dataset == 'cifar100':
    train_dataset = CIFAR100(root=args.data_path, download=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
    test_dataset = CIFAR100(root=args.data_path, train=False, download=True, transform=test_transform)
    args.num_classes = 20
    args.size = 32

elif args.dataset == 'cifartoy_good':
    train_dataset = CIFARtoy(root=args.data_path, split='good', download=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
    test_dataset = CIFARtoy(root=args.data_path, split='good', train=False, download=True, transform=test_transform)
    args.num_classes = 2
    args.size = 32

elif args.dataset == 'cifartoy_bad':
    train_dataset = CIFARtoy(root=args.data_path, split='bad', download=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
    test_dataset = CIFARtoy(root=args.data_path, split='bad', train=False, download=True, transform=test_transform)
    args.num_classes = 2
    args.size = 32

elif args.dataset == 'sop_split2':
    train_dataset = StanfordOnlineProducts(split='2', root=args.data_path, train=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
    test_dataset = StanfordOnlineProducts(split='2', root=args.data_path, train=False, transform=test_transform)
    args.num_classes = 12
    args.size = 224

elif args.dataset == 'sop_split1':
    train_dataset = StanfordOnlineProducts(split='1', root=args.data_path, train=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
    test_dataset = StanfordOnlineProducts(split='1', root=args.data_path, train=False, transform=test_transform)
    args.num_classes = 12
    args.size = 224

elif args.dataset == 'imagenet32':
    train_dataset = ImageNetDownSample(root=args.data_path, train=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
    test_dataset = ImageNetDownSample(root=args.data_path, train=False, transform=test_transform)
    args.num_classes = 12
    args.size = 32

else:
    raise ValueError(f'{args.dataset} is not supported now!')

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

# create trainer
trainer = MaskConNew(num_classes_coarse=args.num_classes, dim=args.dim, K=args.K, m=args.m, T1=args.t0, arch=args.arch, size=args.size, T2=args.t, mode=args.mode).cuda()

args.results_dir = f'arch_[{args.arch}]_data[{args.dataset}]_epochs[{args.epochs}]_memorysize[{args.K}]_mode[{args.mode}]_contrastive_temperature[{args.t0}]_temperature_maskcon[{args.t}]_weight[{args.w}]]'

if not os.path.exists(args.wandb_id):
    os.mkdir(args.wandb_id)
if not os.path.exists(f'{args.wandb_id}/{args.results_dir}'):
    os.mkdir(f'{args.wandb_id}/{args.results_dir}')

main_proc(args, trainer, train_loader, test_loader)



Namespace(K=8192, arch='resnet18', aug_k='weak', aug_q='strong', batch_size=128, data_path='../datasets/cifar100', dataset='cifar100', dim=256, epochs=200, gpu_id='0', lr=0.02, m=0.99, mode='maskcon', t=0.05, t0=0.1, w=0.5, wandb_id='cifar100', warm_up=5, wd=0.0005)
Files already downloaded and verified
Files already downloaded and verified
Initiate memory bank!
Epoch [0/200]: R@1: 0.1181, R@2: 0.1708, R@5: 0.2727, R@10: 0.3843,  R@50: 0.7080,R@100: 0.8366


Train Epoch: [0/200], lr: 0.003990, Loss: -0.9655: 100%|██████████| 390/390 [00:20<00:00, 19.22it/s]
Train Epoch: [1/200], lr: 0.007990, Loss: -1.8940: 100%|██████████| 390/390 [00:19<00:00, 19.56it/s]
Train Epoch: [2/200], lr: 0.011990, Loss: -2.3971: 100%|██████████| 390/390 [00:20<00:00, 19.46it/s]
Train Epoch: [3/200], lr: 0.015990, Loss: -2.6796: 100%|██████████| 390/390 [00:20<00:00, 19.48it/s]
Train Epoch: [4/200], lr: 0.019990, Loss: -2.8636: 100%|██████████| 390/390 [00:20<00:00, 19.44it/s]
Train Epoch: [5/200], lr: 0.019999, Loss: -2.9757: 100%|██████████| 390/390 [00:20<00:00, 19.48it/s]
Train Epoch: [6/200], lr: 0.019995, Loss: -3.0989: 100%|██████████| 390/390 [00:20<00:00, 19.47it/s]
Train Epoch: [7/200], lr: 0.019988, Loss: -3.1727: 100%|██████████| 390/390 [00:20<00:00, 19.21it/s]
Train Epoch: [8/200], lr: 0.019979, Loss: -3.2075: 100%|██████████| 390/390 [00:20<00:00, 19.43it/s]
Train Epoch: [9/200], lr: 0.019968, Loss: -3.2483: 100%|██████████| 390/390 [00:20<00:00, 1

Epoch [10/200]: R@1: 0.2527, R@2: 0.3494, R@5: 0.5026, R@10: 0.6205,  R@50: 0.8753,R@100: 0.9401


Train Epoch: [10/200], lr: 0.019953, Loss: -3.2574: 100%|██████████| 390/390 [00:20<00:00, 19.43it/s]
Train Epoch: [11/200], lr: 0.019937, Loss: -3.2810: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [12/200], lr: 0.019917, Loss: -3.3255: 100%|██████████| 390/390 [00:20<00:00, 19.40it/s]
Train Epoch: [13/200], lr: 0.019895, Loss: -3.3201: 100%|██████████| 390/390 [00:20<00:00, 19.44it/s]
Train Epoch: [14/200], lr: 0.019871, Loss: -3.3377: 100%|██████████| 390/390 [00:20<00:00, 19.38it/s]
Train Epoch: [15/200], lr: 0.019843, Loss: -3.3612: 100%|██████████| 390/390 [00:19<00:00, 19.50it/s]
Train Epoch: [16/200], lr: 0.019814, Loss: -3.3576: 100%|██████████| 390/390 [00:20<00:00, 19.44it/s]
Train Epoch: [17/200], lr: 0.019782, Loss: -3.3747: 100%|██████████| 390/390 [00:20<00:00, 19.41it/s]
Train Epoch: [18/200], lr: 0.019747, Loss: -3.3554: 100%|██████████| 390/390 [00:20<00:00, 19.40it/s]
Train Epoch: [19/200], lr: 0.019710, Loss: -3.3786: 100%|██████████| 390/390 [00:2

Epoch [20/200]: R@1: 0.2408, R@2: 0.3352, R@5: 0.4870, R@10: 0.6084,  R@50: 0.8625,R@100: 0.9360


Train Epoch: [20/200], lr: 0.019670, Loss: -3.3836: 100%|██████████| 390/390 [00:20<00:00, 19.39it/s]
Train Epoch: [21/200], lr: 0.019627, Loss: -3.4211: 100%|██████████| 390/390 [00:20<00:00, 19.37it/s]
Train Epoch: [22/200], lr: 0.019583, Loss: -3.4113: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [23/200], lr: 0.019535, Loss: -3.4285: 100%|██████████| 390/390 [00:20<00:00, 19.39it/s]
Train Epoch: [24/200], lr: 0.019485, Loss: -3.4307: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [25/200], lr: 0.019433, Loss: -3.4467: 100%|██████████| 390/390 [00:20<00:00, 19.41it/s]
Train Epoch: [26/200], lr: 0.019379, Loss: -3.4540: 100%|██████████| 390/390 [00:19<00:00, 19.54it/s]
Train Epoch: [27/200], lr: 0.019321, Loss: -3.4640: 100%|██████████| 390/390 [00:20<00:00, 19.47it/s]
Train Epoch: [28/200], lr: 0.019262, Loss: -3.4716: 100%|██████████| 390/390 [00:20<00:00, 19.36it/s]
Train Epoch: [29/200], lr: 0.019200, Loss: -3.4648: 100%|██████████| 390/390 [00:1

Epoch [30/200]: R@1: 0.2384, R@2: 0.3382, R@5: 0.4862, R@10: 0.6142,  R@50: 0.8646,R@100: 0.9321


Train Epoch: [30/200], lr: 0.019136, Loss: -3.4998: 100%|██████████| 390/390 [00:20<00:00, 19.43it/s]
Train Epoch: [31/200], lr: 0.019069, Loss: -3.4876: 100%|██████████| 390/390 [00:20<00:00, 19.45it/s]
Train Epoch: [32/200], lr: 0.019000, Loss: -3.5036: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [33/200], lr: 0.018928, Loss: -3.5090: 100%|██████████| 390/390 [00:20<00:00, 19.40it/s]
Train Epoch: [34/200], lr: 0.018855, Loss: -3.5284: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [35/200], lr: 0.018779, Loss: -3.5285: 100%|██████████| 390/390 [00:20<00:00, 19.45it/s]
Train Epoch: [36/200], lr: 0.018700, Loss: -3.5309: 100%|██████████| 390/390 [00:20<00:00, 19.45it/s]
Train Epoch: [37/200], lr: 0.018620, Loss: -3.5233: 100%|██████████| 390/390 [00:20<00:00, 19.44it/s]
Train Epoch: [38/200], lr: 0.018537, Loss: -3.5567: 100%|██████████| 390/390 [00:20<00:00, 19.38it/s]
Train Epoch: [39/200], lr: 0.018452, Loss: -3.5610: 100%|██████████| 390/390 [00:2

Epoch [40/200]: R@1: 0.2443, R@2: 0.3418, R@5: 0.4893, R@10: 0.6086,  R@50: 0.8629,R@100: 0.9334


Train Epoch: [40/200], lr: 0.018365, Loss: -3.5718: 100%|██████████| 390/390 [00:19<00:00, 19.50it/s]
Train Epoch: [41/200], lr: 0.018276, Loss: -3.5860: 100%|██████████| 390/390 [00:20<00:00, 19.24it/s]
Train Epoch: [42/200], lr: 0.018184, Loss: -3.5761: 100%|██████████| 390/390 [00:20<00:00, 19.41it/s]
Train Epoch: [43/200], lr: 0.018090, Loss: -3.5704: 100%|██████████| 390/390 [00:20<00:00, 19.46it/s]
Train Epoch: [44/200], lr: 0.017995, Loss: -3.5808: 100%|██████████| 390/390 [00:20<00:00, 19.49it/s]
Train Epoch: [45/200], lr: 0.017897, Loss: -3.5655: 100%|██████████| 390/390 [00:20<00:00, 19.44it/s]
Train Epoch: [46/200], lr: 0.017797, Loss: -3.5760: 100%|██████████| 390/390 [00:20<00:00, 19.44it/s]
Train Epoch: [47/200], lr: 0.017695, Loss: -3.5835: 100%|██████████| 390/390 [00:20<00:00, 19.33it/s]
Train Epoch: [48/200], lr: 0.017591, Loss: -3.6040: 100%|██████████| 390/390 [00:20<00:00, 19.41it/s]
Train Epoch: [49/200], lr: 0.017485, Loss: -3.5785: 100%|██████████| 390/390 [00:2

Epoch [50/200]: R@1: 0.2427, R@2: 0.3360, R@5: 0.4896, R@10: 0.6124,  R@50: 0.8624,R@100: 0.9327


Train Epoch: [50/200], lr: 0.017378, Loss: -3.6074: 100%|██████████| 390/390 [00:20<00:00, 19.40it/s]
Train Epoch: [51/200], lr: 0.017268, Loss: -3.6209: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [52/200], lr: 0.017156, Loss: -3.6181: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [53/200], lr: 0.017043, Loss: -3.6181: 100%|██████████| 390/390 [00:20<00:00, 19.39it/s]
Train Epoch: [54/200], lr: 0.016928, Loss: -3.6355: 100%|██████████| 390/390 [00:20<00:00, 19.40it/s]
Train Epoch: [55/200], lr: 0.016810, Loss: -3.6203: 100%|██████████| 390/390 [00:20<00:00, 19.39it/s]
Train Epoch: [56/200], lr: 0.016692, Loss: -3.6080: 100%|██████████| 390/390 [00:20<00:00, 19.37it/s]
Train Epoch: [57/200], lr: 0.016571, Loss: -3.6575: 100%|██████████| 390/390 [00:19<00:00, 19.50it/s]
Train Epoch: [58/200], lr: 0.016449, Loss: -3.6249: 100%|██████████| 390/390 [00:20<00:00, 19.50it/s]
Train Epoch: [59/200], lr: 0.016325, Loss: -3.6396: 100%|██████████| 390/390 [00:2

Epoch [60/200]: R@1: 0.2543, R@2: 0.3538, R@5: 0.5018, R@10: 0.6177,  R@50: 0.8635,R@100: 0.9352


Train Epoch: [60/200], lr: 0.016199, Loss: -3.6360: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [61/200], lr: 0.016072, Loss: -3.6446: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [62/200], lr: 0.015943, Loss: -3.6594: 100%|██████████| 390/390 [00:20<00:00, 19.43it/s]
Train Epoch: [63/200], lr: 0.015813, Loss: -3.6563: 100%|██████████| 390/390 [00:20<00:00, 19.39it/s]
Train Epoch: [64/200], lr: 0.015681, Loss: -3.6588: 100%|██████████| 390/390 [00:20<00:00, 19.44it/s]
Train Epoch: [65/200], lr: 0.015548, Loss: -3.6724: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [66/200], lr: 0.015413, Loss: -3.6591: 100%|██████████| 390/390 [00:20<00:00, 19.41it/s]
Train Epoch: [67/200], lr: 0.015277, Loss: -3.6680: 100%|██████████| 390/390 [00:20<00:00, 19.38it/s]
Train Epoch: [68/200], lr: 0.015139, Loss: -3.6799: 100%|██████████| 390/390 [00:20<00:00, 19.44it/s]
Train Epoch: [69/200], lr: 0.015000, Loss: -3.6851: 100%|██████████| 390/390 [00:2

Epoch [70/200]: R@1: 0.2660, R@2: 0.3599, R@5: 0.5038, R@10: 0.6276,  R@50: 0.8722,R@100: 0.9375


Train Epoch: [70/200], lr: 0.014860, Loss: -3.6696: 100%|██████████| 390/390 [00:20<00:00, 19.43it/s]
Train Epoch: [71/200], lr: 0.014719, Loss: -3.6870: 100%|██████████| 390/390 [00:20<00:00, 19.47it/s]
Train Epoch: [72/200], lr: 0.014576, Loss: -3.6738: 100%|██████████| 390/390 [00:20<00:00, 19.47it/s]
Train Epoch: [73/200], lr: 0.014432, Loss: -3.6940: 100%|██████████| 390/390 [00:19<00:00, 19.55it/s]
Train Epoch: [74/200], lr: 0.014287, Loss: -3.6974: 100%|██████████| 390/390 [00:20<00:00, 19.41it/s]
Train Epoch: [75/200], lr: 0.014141, Loss: -3.6916: 100%|██████████| 390/390 [00:20<00:00, 19.41it/s]
Train Epoch: [76/200], lr: 0.013994, Loss: -3.6812: 100%|██████████| 390/390 [00:20<00:00, 19.39it/s]
Train Epoch: [77/200], lr: 0.013846, Loss: -3.6895: 100%|██████████| 390/390 [00:19<00:00, 19.50it/s]
Train Epoch: [78/200], lr: 0.013697, Loss: -3.6856: 100%|██████████| 390/390 [00:20<00:00, 19.49it/s]
Train Epoch: [79/200], lr: 0.013546, Loss: -3.6856: 100%|██████████| 390/390 [00:2

Epoch [80/200]: R@1: 0.2641, R@2: 0.3629, R@5: 0.5126, R@10: 0.6303,  R@50: 0.8692,R@100: 0.9383


Train Epoch: [80/200], lr: 0.013395, Loss: -3.6900: 100%|██████████| 390/390 [00:19<00:00, 19.55it/s]
Train Epoch: [81/200], lr: 0.013243, Loss: -3.7098: 100%|██████████| 390/390 [00:20<00:00, 19.47it/s]
Train Epoch: [82/200], lr: 0.013091, Loss: -3.7129: 100%|██████████| 390/390 [00:20<00:00, 19.46it/s]
Train Epoch: [83/200], lr: 0.012937, Loss: -3.7079: 100%|██████████| 390/390 [00:20<00:00, 19.37it/s]
Train Epoch: [84/200], lr: 0.012783, Loss: -3.7145: 100%|██████████| 390/390 [00:20<00:00, 19.40it/s]
Train Epoch: [85/200], lr: 0.012627, Loss: -3.7196: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [86/200], lr: 0.012472, Loss: -3.6897: 100%|██████████| 390/390 [00:20<00:00, 19.42it/s]
Train Epoch: [87/200], lr: 0.012315, Loss: -3.7205: 100%|██████████| 390/390 [00:20<00:00, 19.44it/s]
Train Epoch: [88/200], lr: 0.012158, Loss: -3.7376: 100%|██████████| 390/390 [00:20<00:00, 19.39it/s]
Train Epoch: [89/200], lr: 0.012063, Loss: -3.7322:  61%|██████    | 236/390 [00:1

In [None]:
args.wandb_id