In [None]:
import os
import pdb
import glob
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import numpy as np


from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models.resnet import resnet50, resnet18
from torch.autograd import Variable

from sklearn.cluster import _k_means_fast as _k_means
from sklearn.metrics import silhouette_score

# Data Loader

In [None]:
class SneakersX(Dataset):
    """In-memory sneaker image dataset that returns several augmented views per image.

    Only product ids present both in the metadata pickle's ``pid`` column and as
    ``<pid>.jpg`` files directly inside ``data_path`` are kept. Every image is
    decoded once in ``__init__`` and held in memory as an RGB numpy array.

    Args:
        data_path: directory containing ``<pid>.jpg`` images (trailing slash optional).
        transform_org, transform_all, transform_shape, transform_color: callables
            applied to the PIL image to build the original / all-augmented /
            shape-augmented / color-augmented views. They are applied only when
            both ``transform_shape`` and ``transform_color`` are set; otherwise the
            raw PIL image is returned for every view.
        meta_datapath: pandas pickle with (at least) a ``pid`` column.
        mode: ``'meta'`` makes ``__getitem__`` return ``(org_view, pid)``; any other
            value selects ``'train'``, which returns ``(org, all, shape, color)``.
    """

    def __init__(self, data_path='/home/data/sneakers/resized_img/',
                 transform_org=None, transform_all=None, transform_shape=None, transform_color=None,
                 meta_datapath='/home/data/sneakers/val_sneakers_df.pkl', mode='train'):
        # NOTE(review): read_pickle can execute arbitrary code -- only point it at trusted files.
        self.meta = pd.read_pickle(meta_datapath)
        meta_pids = set(self.meta['pid'].values)
        image_pids = self._list_image_pids(data_path)
        print(len(meta_pids))
        print(len(image_pids))

        # Sorted so the index -> pid mapping is reproducible: iteration order of a set of
        # str depends on the per-process hash seed.
        # NOTE(review): file stems are str; if meta['pid'] holds another type the
        # intersection is empty -- confirm the column dtype.
        self.pid = sorted(meta_pids & image_pids)
        self.data_path = data_path
        self.img = [self._load_image(os.path.join(data_path, pid + '.jpg')) for pid in self.pid]

        self.transform_all = transform_all
        self.transform_shape = transform_shape
        self.transform_color = transform_color
        self.transform_org = transform_org
        self.mode = 'meta' if mode == 'meta' else 'train'

    @staticmethod
    def _list_image_pids(data_path):
        """Return the set of stems of the ``*.jpg`` files directly inside ``data_path``.

        Only ``.jpg`` files are considered because images are later reopened as
        ``<pid>.jpg``; stems come from ``os.path.splitext`` instead of assuming a
        4-character extension.
        """
        pattern = os.path.join(glob.escape(data_path), '*.jpg')
        return {os.path.splitext(os.path.basename(path))[0] for path in glob.glob(pattern)}

    @staticmethod
    def _load_image(path):
        """Decode ``path`` into an RGB numpy array, closing the file handle afterwards."""
        # convert('RGB') guarantees 3 channels, so the 3-channel Normalize in the
        # transforms also works for grayscale / CMYK JPEGs.
        with Image.open(path) as img:
            return np.array(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return ``(org, all, shape, color)`` in train mode, ``(org, pid)`` in meta mode."""
        pid = self.pid[idx]
        img = Image.fromarray(self.img[idx])
        if self.transform_shape and self.transform_color:
            pos_org = self.transform_org(img)
            pos_all = self.transform_all(img)
            pos_shape = self.transform_shape(img)
            pos_color = self.transform_color(img)
        else:
            pos_org = pos_all = pos_shape = pos_color = img

        if self.mode == 'train':
            return pos_org, pos_all, pos_shape, pos_color
        return pos_org, pid

    def __len__(self):
        return len(self.pid)

In [None]:
# [Image Augmentation] Training-time views consumed by SneakersX:
#   org   -- resize only (the anchor view)
#   all   -- random crop + flip + color jitter + random grayscale
#   shape -- random crop + flip (geometric changes only)
#   color -- resize + color jitter + random grayscale (photometric changes only)
# RandomApply(..., p=1) always applies the jitter.
# NOTE(review): Resize(256) only fixes the shorter side, so batching assumes square
# source images (presumably true for resized_img/) -- confirm. The mean/std below look
# like the commonly used CIFAR-10 statistics -- confirm they suit this dataset.
def _to_normalized_tensor(*steps):
    """Compose ``steps`` followed by ToTensor and the shared Normalize."""
    return transforms.Compose([
        *steps,
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    ])


train_transform_org = _to_normalized_tensor(
    transforms.Resize(256),
)

train_transform_all = _to_normalized_tensor(
    transforms.RandomResizedCrop(256),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=1),
    transforms.RandomGrayscale(p=0.2),
)

train_transform_shape = _to_normalized_tensor(
    transforms.RandomResizedCrop(256),
    transforms.RandomHorizontalFlip(p=0.5),
)

train_transform_color = _to_normalized_tensor(
    transforms.Resize(256),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=1),
    transforms.RandomGrayscale(p=0.2),
)

In [None]:
# Training dataset: each item is the (org, all, shape, color) views of one sneaker image.
dataset = SneakersX(
    transform_org=train_transform_org,
    transform_all=train_transform_all,
    transform_shape=train_transform_shape,
    transform_color=train_transform_color,
)

# Train

In [None]:
# train for one epoch to learn unique features
def train(net, data_loader, train_optimizer):
    net.train()
    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)

    for pos_org, pos_all, pos_shape, pos_color in train_bar:
        batch_size = pos_org.size(0)
        pos_org, pos_all = pos_org.cuda(non_blocking=True), pos_all.cuda(non_blocking=True)
        pos_shape, pos_color = pos_shape.cuda(non_blocking=True), pos_color.cuda(non_blocking=True)
        
        pos_all_pair = torch.cat((pos_org, pos_all))
        pos_shape_pair = torch.cat((pos_org, pos_shape))
        pos_color_pair = torch.cat((pos_org, pos_color))
        
        _, out_alls = net(pos_all_pair, query_type=torch.tensor(0).repeat(batch_size * 2).cuda())
        _, out_shapes = net(pos_shape_pair, query_type=torch.tensor(1).repeat(batch_size * 2).cuda())
        _, out_colors = net(pos_color_pair, query_type=torch.tensor(2).repeat(batch_size * 2).cuda())
        
        out_org1, out_all = torch.split(out_alls, len(out_alls) // 2)
        out_org2, out_shape = torch.split(out_shapes, len(out_shapes) // 2)
        out_org3, out_color = torch.split(out_colors, len(out_colors) // 2)
               
        # all-invariant loss
        sim_matrix1 = torch.exp(torch.mm(out_alls, out_alls.t().contiguous()) / temperature)
        mask1 = (torch.ones_like(sim_matrix1) - torch.eye(2 * batch_size, device=sim_matrix1.device)).bool()
        sim_matrix1 = sim_matrix1.masked_select(mask1).view(2 * batch_size, -1)[:batch_size, :]
        pos_org_all = torch.exp(torch.sum(out_org1 * out_all, dim=-1) / temperature)
        all_loss = (- torch.log(pos_org_all / sim_matrix1.sum(dim=-1))).mean()
        
        # shape-invariant loss
        sim_matrix2 = torch.exp(torch.mm(out_shapes, out_shapes.t().contiguous()) / temperature)
        mask2 = (torch.ones_like(sim_matrix2) - torch.eye(2 * batch_size, device=sim_matrix2.device)).bool()
        sim_matrix2 = sim_matrix2.masked_select(mask2).view(2 * batch_size, -1)[:batch_size, :]
        pos_org_shape = torch.exp(torch.sum(out_org2 * out_shape, dim=-1) / temperature)
        shape_loss = (- torch.log(pos_org_shape / sim_matrix2.sum(dim=-1))).mean()
        
        # color-invariant loss
        sim_matrix3 = torch.exp(torch.mm(out_colors, out_colors.t().contiguous()) / temperature)
        mask3 = (torch.ones_like(sim_matrix3) - torch.eye(2 * batch_size, device=sim_matrix3.device)).bool()
        sim_matrix3 = sim_matrix3.masked_select(mask3).view(2 * batch_size, -1)[:batch_size, :]
        pos_org_shape = torch.exp(torch.sum(out_org3 * out_color, dim=-1) / temperature)
        color_loss = (- torch.log(pos_org_shape / sim_matrix3.sum(dim=-1))).mean()
        
        loss = all_loss + shape_loss + color_loss
        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        total_num += batch_size
        total_loss += loss.item() * batch_size
        train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_num))

    return total_loss / total_num

In [None]:
class ResNet18(nn.Module):
    """ResNet-18 backbone with a learned per-query feature mask and one projection head per query.

    Query ids: 0 = all, 1 = shape, 2 = color (``g_list`` holds the heads in that order).

    Args:
        feature_dim: output width of each projection head.
    """

    def __init__(self, feature_dim=128):
        super(ResNet18, self).__init__()

        # encoder: ImageNet-pretrained ResNet-18 without its fc classifier and without the
        # stem max-pool (so early layers keep twice the spatial resolution).
        self.f = nn.Sequential(*[
            module for _, module in resnet18(pretrained=True).named_children()
            if not isinstance(module, (nn.Linear, nn.MaxPool2d))
        ])
        # One learnable 512-d mask (pre-sigmoid) per query id. The original passed
        # padding_idx=1, which zero-initialises the "shape" row and excludes it from
        # gradient updates, freezing the shape mask at sigmoid(0) = 0.5 everywhere.
        # The weight shape is unchanged, so existing checkpoints still load.
        self.query = nn.Embedding(num_embeddings=3, embedding_dim=512)

        # projection heads; also listed in g_list so forward can select one by query id
        # (keeps the existing state_dict key layout: g_all.*, g_shape.*, g_color.*, g_list.*)
        self.g_all = self._make_head(feature_dim)
        self.g_shape = self._make_head(feature_dim)
        self.g_color = self._make_head(feature_dim)
        self.g_list = nn.ModuleList([self.g_all, self.g_shape, self.g_color])

    @staticmethod
    def _make_head(feature_dim):
        """Build the 512 -> 512 (BN, ReLU) -> feature_dim MLP projection head."""
        return nn.Sequential(nn.Linear(512, 512, bias=False), nn.BatchNorm1d(512),
                             nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))

    def forward(self, x, query_type, normalize=True):
        """Encode ``x``, apply the query mask and project through the selected head.

        Args:
            x: image batch accepted by the encoder.
            query_type: LongTensor with one query id per sample. Each sample gets its
                own mask, but the head is chosen from ``query_type[0]`` alone, so all
                entries are expected to be equal.
            normalize: selects the return form (see below).

        Returns:
            If ``normalize`` is truthy: ``(L2-normalised masked feature,
            L2-normalised projection)``. Otherwise: ``(raw encoder feature,
            masked feature, L2-normalised projection)`` with the features left
            un-normalised.
        """
        org_feature = torch.flatten(self.f(x), start_dim=1)
        mask = torch.sigmoid(self.query(query_type))
        feature = org_feature * mask
        out = self.g_list[int(query_type[0])](feature)
        if not normalize:
            return org_feature, feature, F.normalize(out, dim=-1)
        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)

# Main

In [None]:
# Hyper-parameters. parse_args('') parses an empty argument list, so inside the notebook
# the defaults declared here are always used (the kernel's own sys.argv is ignored).
parser = argparse.ArgumentParser(description='Train SimCLR')
parser.add_argument('--feature_dim', default=128, type=int, help='Feature dim for latent vector')
parser.add_argument('--temperature', default=0.5, type=float, help='Temperature used in softmax')
# NOTE(review): k is not used for training in this notebook; it only appears in the checkpoint name.
parser.add_argument('--k', default=200, type=int, help='Top k most similar images used to predict the label')
parser.add_argument('--batch_size', default=128, type=int, help='Number of images in each mini-batch')
parser.add_argument('--epochs', default=500, type=int, help='Number of sweeps over the dataset to train')
parser.add_argument('--seed', default=1567010775, type=int, help='random seed')

# args parse
args = parser.parse_args('')  # settings become default
feature_dim, temperature, k = args.feature_dim, args.temperature, args.k
batch_size, epochs = args.batch_size, args.epochs

# Apply --seed (previously parsed but never used) before the model is built and the loader
# is iterated. This makes the head initialisation, the shuffle order and, via DataLoader's
# per-worker seeding, the random augmentations repeatable. cuDNN kernels may still be
# nondeterministic.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# data prepare
train_data = dataset
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True, drop_last=True)

# model setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ResNet18(feature_dim)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
# Move to `device` unconditionally. The original called .cuda() only inside the multi-GPU
# branch, so with exactly one GPU the model stayed on the CPU while train() sends every
# batch to CUDA.
model = model.to(device)

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

# Checkpoints go to results/, named after this run's hyper-parameters.
# When model is wrapped in nn.DataParallel, state_dict keys carry a "module." prefix.
save_name_pre = 'preact_looc_{}_{}_{}_{}_{}'.format(feature_dim, temperature, k, batch_size, epochs)
os.makedirs('results', exist_ok=True)

# training loop; `epoch` is a notebook-level name that train() also reads for its progress-bar label
for epoch in range(1, epochs + 1):
    train_loss = train(model, train_loader, optimizer)
    # snapshot after the first epoch and then every 50 epochs
    if epoch == 1 or epoch % 50 == 0:
        torch.save(model.state_dict(), 'results/epoch{}_{}_model_sneakers_mask.pth'.format(epoch, save_name_pre))

final_path = 'results/final_{}_model_sneakers_mask.pth'.format(save_name_pre)
torch.save(model.state_dict(), final_path)
# Report the path actually written (the old message named "<prefix>_model_sneakers.pth",
# a file this cell never creates).
print('\n', '[END] {} => The trained model has been saved!'.format(final_path))

# Inference

In [None]:
# Inference reuses the training run's hyper-parameters. The checkpoint path is derived
# from save_name_pre, and the projection width must equal the feature_dim the checkpoint
# was trained with (a hard-coded 128 made load_state_dict fail whenever --feature_dim differed).
resnet_feature_dim = feature_dim
PATH = 'results/final_{}_model_sneakers_mask.pth'.format(save_name_pre)

In [None]:
# Rebuild the model and load the final checkpoint for inference.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ResNet18(resnet_feature_dim)

# Load into the bare module first. Checkpoints written from an nn.DataParallel wrapper prefix
# every key with "module."; stripping it makes loading independent of how many GPUs were
# visible during training vs. now (the original wrapped first, so the two counts had to agree).
# NOTE(review): torch.load unpickles -- only load trusted checkpoints.
state_dict = torch.load(PATH, map_location=device)
state_dict = {
    (key[len('module.'):] if key.startswith('module.') else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
# Always move to `device`: the original only called .cuda() on the multi-GPU path, leaving a
# single-GPU model on the CPU while the inference loop sent inputs to CUDA.
model = model.to(device)
model.eval()

In [None]:
# Deterministic evaluation transform: resize + normalise only, with no random augmentation.
test_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
    ]
)

In [None]:
# Evaluation dataset: mode='meta' makes each item (transformed image, pid).
# Only transform_org's output is returned in 'meta' mode, but transform_shape and
# transform_color must be non-None for any transform to be applied at all.
# NOTE(review): this uses the same default meta pickle (val_sneakers_df.pkl) as the
# training set -- confirm that training and extraction on one file is intended.
testset = SneakersX(
    mode='meta',
    transform_org=test_transform,
    transform_all=test_transform,
    transform_shape=test_transform,
    transform_color=test_transform,
)

In [None]:
# Deterministic order (no shuffle) and every sample kept (no drop_last) for feature extraction.
test_loader = DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
    drop_last=False,
)

In [None]:
# Extract features for every test image under each query (0 = all, 1 = shape, 2 = color).
org_list = []            # raw encoder features (identical for every query)
feature_all_list = []    # masked features, one list per query
feature_shape_list = []
feature_color_list = []
out_all_list = []        # normalised projections of the three heads, concatenated per image
pid_list = []            # pid of each extracted row, in loader order

# Follow the model's own device instead of assuming CUDA.
model_device = next(model.parameters()).device
with torch.no_grad():
    for pos_1, meta in tqdm(test_loader):
        pos_1 = pos_1.to(model_device, non_blocking=True)
        n_batch = pos_1.size(0)
        all_q, shape_q, color_q = (
            torch.full((n_batch,), query_id, dtype=torch.long, device=model_device)
            for query_id in range(3)
        )

        org_feature_all, feature_all, out_all = model(pos_1, query_type=all_q, normalize=False)
        _, feature_shape, out_shape = model(pos_1, query_type=shape_q, normalize=False)
        _, feature_color, out_color = model(pos_1, query_type=color_q, normalize=False)
        out_total = torch.cat((out_all, out_shape, out_color), dim=1)

        # Keep results on the CPU: later cells only concatenate them and export them to CSV,
        # and holding every batch on the GPU would keep the whole test set's features in
        # device memory.
        org_list.append(org_feature_all.cpu())
        feature_all_list.append(feature_all.cpu())
        feature_shape_list.append(feature_shape.cpu())
        feature_color_list.append(feature_color.cpu())
        out_all_list.append(out_total.cpu())
        # In 'meta' mode the dataset's second item is the pid.
        pid_list.extend(meta)

# row index -> pid, consumed by the CSV export cells
img_label_dict = dict(enumerate(pid_list))

In [None]:
# Stack the per-batch results into one tensor per kind; rows follow test_loader order.
org_cat, feature_all_cat, feature_shape_cat, feature_color_cat, out_all_cat = (
    torch.cat(batches, dim=0)
    for batches in (org_list, feature_all_list, feature_shape_list, feature_color_list, out_all_list)
)

In [None]:
def export_features_csv(matrix, df_name, pid_by_row, meta):
    """Write ``matrix`` to ``df_name`` as CSV, one row per image, joined with its metadata.

    Args:
        matrix: 2-D tensor of per-image features (rows in test_loader order).
        df_name: output CSV path.
        pid_by_row: mapping row index -> pid (``img_label_dict``).
        meta: metadata DataFrame (``testset.meta``); its ``id`` column is the join key.

    Returns:
        The DataFrame that was written: feature columns ``0..D-1``, then ``modelId``,
        then the metadata columns. This is an inner join, so rows whose pid has no
        matching ``id`` are dropped.
    """
    # .double() keeps float64 columns, as the original matrix.tolist() path produced.
    feature_df = pd.DataFrame(matrix.detach().cpu().double().numpy())
    # The pid column is appended after however many feature columns the matrix has,
    # instead of a hard-coded index 512 that only fits 512-wide matrices (out_all_cat,
    # for instance, is 3 * feature_dim wide).
    feature_df['modelId'] = [pid_by_row[row] for row in range(len(feature_df))]
    # NOTE(review): `modelId` holds dataset pids (from meta['pid']) but is joined against
    # meta's `id` column -- confirm both identify the same thing.
    meta_keyed = meta.rename(columns={'id': 'modelId'})
    merged = pd.merge(feature_df, meta_keyed, on='modelId')
    merged.to_csv(df_name, index=False)
    return merged


# Export the all-query masked features.
matrix = feature_all_cat
df_name = './df_512_mask_all.csv'
total_df = export_features_csv(matrix, df_name, img_label_dict, testset.meta)

In [None]:
# Export the shape-query masked features: one CSV row per image with its feature values,
# `modelId` (the image pid) and the merged metadata columns.
matrix = feature_shape_cat
dimension = 512  # column index the appended pid lands in (features occupy 0..511)
df_name = './df_512_mask_shape.csv'

img_label_list = [row + [img_label_dict[j]] for j, row in enumerate(matrix.tolist())]

img_label_df = pd.DataFrame(img_label_list)
img_label_df_new = img_label_df.rename(columns={dimension: 'modelId'})
# NOTE(review): `modelId` holds dataset pids (from meta['pid']) but is joined against meta's
# `id` column; the default inner merge silently drops rows whose pid matches no `id`.
testset.meta_new = testset.meta.rename(columns={'id': 'modelId'})
total_df = img_label_df_new.merge(testset.meta_new, on="modelId")
total_df.to_csv(df_name, index=False)

In [None]:
# Export the color-query masked features: one CSV row per image with its feature values,
# `modelId` (the image pid) and the merged metadata columns.
matrix = feature_color_cat
dimension = 512  # column index the appended pid lands in (features occupy 0..511)
df_name = './df_512_mask_color.csv'

img_label_list = [row + [img_label_dict[j]] for j, row in enumerate(matrix.tolist())]

img_label_df = pd.DataFrame(img_label_list)
img_label_df_new = img_label_df.rename(columns={dimension: 'modelId'})
# NOTE(review): `modelId` holds dataset pids (from meta['pid']) but is joined against meta's
# `id` column; the default inner merge silently drops rows whose pid matches no `id`.
testset.meta_new = testset.meta.rename(columns={'id': 'modelId'})
total_df = img_label_df_new.merge(testset.meta_new, on="modelId")
total_df.to_csv(df_name, index=False)

In [None]:
# End.