In [1]:
import torch, math, time, argparse, os
import random, dataset, utils, losses, net
import numpy as np

from dataset.Inshop import Inshop_Dataset
from net.resnet import *
from net.googlenet import *
from net.bn_inception import *
from dataset import sampler
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.dataloader import default_collate

from tqdm import *
import wandb

In [2]:
# Seed every random number generator used by this notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value.
seed = 1
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

In [3]:
# Output directory setting; it is not referenced by the other visible cells.
LOG_DIR = '.'

# Load the 'train' split of the 'market' dataset from a path relative to the
# notebook, using the training-time transform (non-Inception variant).
trn_dataset = dataset.load(
    'market',
    '../../Market-1501-v15.09.15/',
    'train',
    transform=dataset.utils.make_transform(is_train=True, is_inception=False),
)

In [4]:
# Training loader: shuffled batches of 50 samples served by 4 worker processes.
# drop_last=True discards a trailing partial batch so every batch has 50 items;
# pin_memory=True requests page-locked host buffers for the batches.
dl_tr = torch.utils.data.DataLoader(
    trn_dataset,
    batch_size=50, shuffle=True, drop_last=True,
    num_workers=4, pin_memory=True,
)

In [5]:
# Resnet50 embedding network producing 512-dimensional embeddings, moved to the GPU.
# NOTE(review): is_norm / bn_freeze are passed as 1 exactly as before; their
# precise effect is defined in net.resnet - confirm there.
model = Resnet50(
    embedding_size=512,
    pretrained=True,
    is_norm=1,
    bn_freeze=1,
)
model = model.cuda()



In [6]:
# Proxy-Anchor loss sized to the number of training classes and to the
# 512-dimensional embedding produced by the model; it owns trainable
# parameters (see criterion.parameters() below), so it is moved to the GPU too.
criterion = losses.Proxy_Anchor(
    nb_classes=trn_dataset.nb_classes(),
    sz_embed=512,
    mrg=0.1,
    alpha=32,
)
criterion = criterion.cuda()

In [7]:
# Optimizer parameter groups:
#   1. backbone   - every model parameter except the embedding layer (uses the optimizer's default lr)
#   2. embedding  - the embedding layer, lr = 1e-4 * 1
#   3. criterion  - the loss' own parameters, lr = 1e-4 * 100
#
# The backbone group is built by filtering model.parameters() on object identity
# rather than with a set difference: tensors hash by identity, so iterating a set
# of them does not give a reproducible order, whereas this keeps the order of
# model.parameters() and therefore a stable index for each parameter inside the
# optimizer state.
embedding_params = list(model.model.embedding.parameters())
embedding_param_ids = {id(param) for param in embedding_params}

param_groups = [
    {'params': [param for param in model.parameters() if id(param) not in embedding_param_ids]},
    {'params': embedding_params, 'lr': float(1e-4) * 1},
]
# Lists (not one-shot generators) are stored so the groups can be inspected or reused safely.
param_groups.append({'params': list(criterion.parameters()), 'lr': float(1e-4) * 100})

In [8]:
# AdamW over the parameter groups defined above; groups that do not set their
# own 'lr' fall back to the default of 1e-4 given here.
opt = torch.optim.AdamW(
    param_groups,
    lr=float(1e-4),
    weight_decay=1e-4,
)

In [9]:
# Step decay: multiply every group's learning rate by 0.5 once per 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    opt,
    step_size=10,
    gamma=0.5,
)

In [10]:
# Training loop: 60 epochs of Proxy-Anchor training.
#
# - BatchNorm2d layers are switched back to eval mode after model.train() when
#   bn_freeze is set.
# - Warm-up: during epoch 0 only the embedding layer and the criterion's
#   parameters receive gradients; from epoch 1 on the rest of the model is
#   trainable again.
# - Gradients are clipped element-wise to [-10, 10] before each optimizer step.
# - The learning-rate scheduler is advanced once per epoch.
bn_freeze = True

# Parameters that are trained from the first epoch on; everything else in the
# model is frozen during the warm-up epoch. Computed once instead of every epoch.
unfreeze_model_param = list(model.model.embedding.parameters()) + list(criterion.parameters())
warmup_frozen_params = list(set(model.parameters()).difference(set(unfreeze_model_param)))

for epoch in range(0, 60):
    model.train()

    if bn_freeze:
        # model.train() put every submodule in training mode; put the
        # BatchNorm2d layers back into eval mode.
        # torch.nn is referenced explicitly so this does not depend on a name
        # leaked by one of the star imports.
        for module in model.model.modules():
            if isinstance(module, torch.nn.BatchNorm2d):
                module.eval()

    losses_per_epoch = []

    if epoch == 0:
        for param in warmup_frozen_params:
            param.requires_grad = False
    if epoch == 1:
        for param in warmup_frozen_params:
            param.requires_grad = True

    # total= lets tqdm draw a real progress bar (enumerate() hides the length).
    pbar = tqdm(enumerate(dl_tr), total=len(dl_tr))

    for batch_idx, (x, y) in pbar:
        embeddings = model(x.squeeze().cuda())
        loss = criterion(embeddings, y.squeeze().cuda())

        opt.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_value_(model.parameters(), 10)
        torch.nn.utils.clip_grad_value_(criterion.parameters(), 10)

        losses_per_epoch.append(loss.data.cpu().numpy())
        opt.step()

        pbar.set_description(
            'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                epoch, batch_idx + 1, len(dl_tr),
                100. * batch_idx / len(dl_tr),
                loss.item()))

    # Advance the StepLR schedule once per epoch, after the optimizer updates;
    # without this call the learning rates never decay.
    scheduler.step()

  return F.conv2d(input, weight, bias, self.stride,
