In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch import optim
import time
from tqdm import tqdm
import os
from torch.utils.data import Dataset
from PIL import Image

In [2]:
from sklearn import metrics
from sklearn.neighbors import KNeighborsClassifier

In [3]:
import pandas as pd

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 

In [6]:
train_path = '/home/ielab/dataset/ms1m_dataset/train_gan/'
test_path = '/home/ielab/dataset/ms1m_dataset/test_gan'

og_test_path = '/home/ielab/dataset/ms1m_dataset/test/'

In [7]:
result_path = '/home/ielab/project/samples/results/gan/'

In [8]:
col_list = os.listdir(train_path)

lb = {string : i for i,string in enumerate(col_list)}

In [9]:
batch_size = 64

In [10]:
num_classe = len(col_list)

In [11]:
trans = transforms.Compose([
    transforms.Resize((112,112)),
    transforms.ToTensor(),
    ])

In [12]:
class CustomDataset(Dataset):
    def __init__(self, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.dir_list = os.listdir(img_dir)
        self.transform = transform
        self.target_transform = target_transform

        self.files = []
        for folder in self.dir_list:
          folder_list = os.listdir(os.path.join(img_dir, folder))
          for file in folder_list:
            self.files.append([
                os.path.join(self.img_dir, folder+'/'+file),
                folder])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        image = Image.open(self.files[idx][0])
        image = image.convert('RGB')
        image = self.transform(image)

        label = self.files[idx][1]
        label = lb[label]

        return image, label

In [13]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

trainset = CustomDataset(train_path, trans)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

testset = CustomDataset(test_path, trans)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

In [14]:
class ArcFaceLoss(nn.Module):
    def __init__(self, num_classes, embedding_size, margin, scale):
        """
        ArcFace: Additive Angular Margin Loss for Deep Face Recognition
        (https://arxiv.org/pdf/1801.07698.pdf)
        Args:
            num_classes: The number of classes in your training dataset
            embedding_size: The size of the embeddings that you pass into
            margin: m in the paper, the angular margin penalty in radians
            scale: s in the paper, feature scale
        """
        super().__init__()
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.margin = margin
        self.scale = scale
        
        self.W = torch.nn.Parameter(torch.Tensor(num_classes, embedding_size))
        nn.init.xavier_normal_(self.W)
        
    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: (None, embedding_size)
            labels: (None,)
        Returns:
            loss: scalar
        """
        cosine = self.get_cosine(embeddings) # (None, n_classes)
        mask = self.get_target_mask(labels) # (None, n_classes)
        cosine_of_target_classes = cosine[mask == 1] # (None, )
        modified_cosine_of_target_classes = self.modify_cosine_of_target_classes(
            cosine_of_target_classes
        ) # (None, )
        diff = (modified_cosine_of_target_classes - cosine_of_target_classes).unsqueeze(1) # (None,1)
        logits = cosine + (mask * diff) # (None, n_classes)
        logits = self.scale_logits(logits) # (None, n_classes)
        return logits
        
    def get_cosine(self, embeddings):
        """
        Args:
            embeddings: (None, embedding_size)
        Returns:
            cosine: (None, n_classes)
        """
        cosine = F.linear(F.normalize(embeddings), F.normalize(self.W))
        return cosine
    
    def get_target_mask(self, labels):
        """
        Args:
            labels: (None,)
        Returns:
            mask: (None, n_classes)
        """
        batch_size = labels.size(0)
        onehot = torch.zeros(batch_size, self.num_classes, device=labels.device)
        onehot.scatter_(1, labels.unsqueeze(-1), 1)
        return onehot
        
    def modify_cosine_of_target_classes(self, cosine_of_target_classes):
        """
        Args:
            cosine_of_target_classes: (None,)
        Returns:
            modified_cosine_of_target_classes: (None,)
        """
        eps = 1e-6
        # theta in the paper
        angles = torch.acos(torch.clamp(cosine_of_target_classes, -1 + eps, 1 - eps))
        return torch.cos(angles + self.margin)
    
    def scale_logits(self, logits):
        """
        Args:
            logits: (None, n_classes)
        Returns:
            scaled_logits: (None, n_classes)
        """
        return logits * self.scale
    
class SoftmaxLoss(nn.Module):
    def __init__(self, num_classes, embedding_size):
        """
        Regular softmax loss (1 fc layer without bias + CrossEntropyLoss)
        Args:
            num_classes: The number of classes in your training dataset
            embedding_size: The size of the embeddings that you pass into
        """
        super().__init__()
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        
        self.W = torch.nn.Parameter(torch.Tensor(num_classes, embedding_size))
        nn.init.xavier_normal_(self.W)
        
    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: (None, embedding_size)
            labels: (None,)
        Returns:
            loss: scalar
        """
        logits = F.linear(embeddings, self.W)
        return nn.CrossEntropyLoss()(logits, labels)

In [15]:
class Embedder(nn.Module):
    def __init__(self, embedding_size):
        super().__init__()
        self.model = torchvision.models.resnet50(pretrained = False)
        self.dropout = nn.Dropout(0.5)
        
        self.classifier = nn.Linear(1000, embedding_size)

    def forward(self, images):
        outputs = self.model(images)
        outputs = self.dropout(outputs)
        outputs = self.classifier(outputs)
        return outputs

In [16]:
# Epoch 반복 횟수
repeat_num = 5


embedding_size = 512
max_epochs = 100

embedder = Embedder(embedding_size=embedding_size)
arcface = ArcFaceLoss(num_classes=num_classe, embedding_size=embedding_size,margin=0.3, scale=30.0)

if torch.cuda.device_count() > 1:
    embedder = nn.DataParallel(embedder)
    arcface = nn.DataParallel(arcface)
embedder = embedder.to(device)
arcface = arcface.to(device)

optimizer = optim.Adam(embedder.parameters(), lr=1e-3 ) 
criterion = nn.CrossEntropyLoss()
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 5, gamma = 0.9)


for num in range(repeat_num):
    start = time.time()
    print(f'epoch :{num}')
    for epoch in range(max_epochs):
        e_time = time.time()
        for i, (images, labels) in enumerate(trainloader):
            embedder.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            embeddings = embedder(images)

            logits = arcface(embeddings, labels)
            loss = criterion(logits, labels)
            loss.backward() 
            optimizer.step() 
        
        print(f'Epoch: {epoch} - Batch: {i*batch_size} - Loss: {loss:.6f} - Time:{time.time() - e_time}')
        with open(result_path + 'loss-'+str(num) +'.txt', 'a') as file:
            file.write(f'{loss:.6f}\n')
                    
    torch.save(embedder, result_path + str(num)  + '_model.pth')
    
    #  정확도 측정
    train_results = []
    train_labels = []
    test_results = []
    test_labels = []

    embedder.eval()
    with torch.no_grad():
      for img, label in trainloader:
        img = img.cuda()
        train_results.append(embedder(img).cpu().detach().numpy())
        train_labels.append(label)

    train_results = np.concatenate(train_results)
    train_labels = np.concatenate(train_labels)

    embedder.eval()
    with torch.no_grad():
      for img, label in testloader:
        img = img.cuda()
        test_results.append(embedder(img).cpu().detach().numpy())
        test_labels.append(label)

    test_results = np.concatenate(test_results)
    test_labels = np.concatenate(test_labels)
    
    # kNN 모델 선언
    k = 50
    model = KNeighborsClassifier(n_neighbors = k)
    # 모델 학습
    model.fit(train_results, train_labels)
    #knn 검증
    train_pred = model.predict(train_results)
    train_acc = (train_pred == train_labels).mean()
    with open(result_path + 'train_acc-'+str(num) +'.txt', 'a') as file:
        file.write(f'{train_acc:.6f}\n')
    print(f'train acc : {train_acc}')
    test_pred = model.predict(test_results)
    test_acc = (test_pred == test_labels).mean()
    with open(result_path + 'test_acc-'+str(num) +'.txt', 'a') as file:
        file.write(f'{test_acc:.6f}\n')
    print(f'test acc : {test_acc}')
    # F1-Score
    f1 = metrics.classification_report(test_labels, test_pred)
    with open(result_path + 'test_f1-'+str(num) +'.txt', 'a') as file:
        file.write(f'{f1}')



epoch :0
Epoch: 0 - Batch: 154176 - Loss: 16.572945 - Time:357.68433475494385
Epoch: 1 - Batch: 154176 - Loss: 16.799253 - Time:354.8403744697571
Epoch: 2 - Batch: 154176 - Loss: 16.459936 - Time:354.91351079940796
Epoch: 3 - Batch: 154176 - Loss: 15.805743 - Time:354.9035131931305
Epoch: 4 - Batch: 154176 - Loss: 15.308963 - Time:354.9006268978119
Epoch: 5 - Batch: 154176 - Loss: 14.646939 - Time:354.92612767219543
Epoch: 6 - Batch: 154176 - Loss: 13.073686 - Time:354.97727489471436
Epoch: 7 - Batch: 154176 - Loss: 10.913660 - Time:354.95422172546387
Epoch: 8 - Batch: 154176 - Loss: 7.061776 - Time:354.89097118377686
Epoch: 9 - Batch: 154176 - Loss: 7.079303 - Time:354.9812250137329
Epoch: 10 - Batch: 154176 - Loss: 6.961909 - Time:354.96593713760376
Epoch: 11 - Batch: 154176 - Loss: 4.173846 - Time:354.93353509902954
Epoch: 12 - Batch: 154176 - Loss: 4.219118 - Time:354.89596700668335
Epoch: 13 - Batch: 154176 - Loss: 3.103828 - Time:354.96120500564575
Epoch: 14 - Batch: 154176 - Los

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 0 - Batch: 154176 - Loss: 0.050388 - Time:355.15546202659607
Epoch: 1 - Batch: 154176 - Loss: 0.182207 - Time:355.0204565525055
Epoch: 2 - Batch: 154176 - Loss: 0.015397 - Time:354.96494579315186
Epoch: 3 - Batch: 154176 - Loss: 0.367057 - Time:354.9551160335541
Epoch: 4 - Batch: 154176 - Loss: 0.016822 - Time:354.96637630462646
Epoch: 5 - Batch: 154176 - Loss: 0.173450 - Time:355.0339078903198
Epoch: 6 - Batch: 154176 - Loss: 0.021558 - Time:354.97234654426575
Epoch: 7 - Batch: 154176 - Loss: 0.018314 - Time:354.9522588253021
Epoch: 8 - Batch: 154176 - Loss: 0.154696 - Time:354.98444414138794
Epoch: 9 - Batch: 154176 - Loss: 0.073390 - Time:355.0129282474518
Epoch: 10 - Batch: 154176 - Loss: 0.030394 - Time:355.01617527008057
Epoch: 11 - Batch: 154176 - Loss: 0.028555 - Time:354.9887173175812
Epoch: 12 - Batch: 154176 - Loss: 0.579171 - Time:354.97189807891846
Epoch: 13 - Batch: 154176 - Loss: 0.047128 - Time:354.99556946754456
Epoch: 14 - Batch: 154176 - Loss: 0.100585 - Time:

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 0 - Batch: 154176 - Loss: 0.029301 - Time:355.14159655570984
Epoch: 1 - Batch: 154176 - Loss: 0.015841 - Time:355.03322076797485
Epoch: 2 - Batch: 154176 - Loss: 0.373524 - Time:355.04896545410156
Epoch: 3 - Batch: 154176 - Loss: 0.007534 - Time:355.00809478759766
Epoch: 4 - Batch: 154176 - Loss: 0.016918 - Time:355.05256390571594
Epoch: 5 - Batch: 154176 - Loss: 0.048696 - Time:355.0219178199768
Epoch: 6 - Batch: 154176 - Loss: 0.027094 - Time:355.02407598495483
Epoch: 7 - Batch: 154176 - Loss: 0.019335 - Time:355.09976053237915
Epoch: 8 - Batch: 154176 - Loss: 0.012522 - Time:355.0571217536926
Epoch: 9 - Batch: 154176 - Loss: 0.027545 - Time:355.0568497180939
Epoch: 10 - Batch: 154176 - Loss: 0.015253 - Time:355.075966835022
Epoch: 11 - Batch: 154176 - Loss: 0.007637 - Time:355.0751795768738
Epoch: 12 - Batch: 154176 - Loss: 0.008777 - Time:355.10882115364075
Epoch: 13 - Batch: 154176 - Loss: 0.008426 - Time:355.1550934314728
Epoch: 14 - Batch: 154176 - Loss: 0.019301 - Time:3

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 0 - Batch: 154176 - Loss: 0.036050 - Time:355.2459325790405
Epoch: 1 - Batch: 154176 - Loss: 0.028912 - Time:355.08333110809326
Epoch: 2 - Batch: 154176 - Loss: 0.009572 - Time:355.1014151573181
Epoch: 3 - Batch: 154176 - Loss: 0.008002 - Time:355.14819836616516
Epoch: 4 - Batch: 154176 - Loss: 0.015061 - Time:355.0971682071686
Epoch: 5 - Batch: 154176 - Loss: 0.012050 - Time:355.0558297634125
Epoch: 6 - Batch: 154176 - Loss: 0.088712 - Time:355.0619671344757
Epoch: 7 - Batch: 154176 - Loss: 0.009223 - Time:355.1595878601074
Epoch: 8 - Batch: 154176 - Loss: 0.006119 - Time:355.08141899108887
Epoch: 9 - Batch: 154176 - Loss: 0.111193 - Time:355.0972936153412
Epoch: 10 - Batch: 154176 - Loss: 0.008907 - Time:355.1262857913971
Epoch: 11 - Batch: 154176 - Loss: 0.014611 - Time:355.09205889701843
Epoch: 12 - Batch: 154176 - Loss: 0.005718 - Time:355.11952114105225
Epoch: 13 - Batch: 154176 - Loss: 0.047975 - Time:355.06729555130005
Epoch: 14 - Batch: 154176 - Loss: 0.008870 - Time:35

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 0 - Batch: 154176 - Loss: 0.007321 - Time:355.1912908554077
Epoch: 1 - Batch: 154176 - Loss: 0.018132 - Time:355.08872628211975
Epoch: 2 - Batch: 154176 - Loss: 0.004213 - Time:355.11781334877014
Epoch: 3 - Batch: 154176 - Loss: 0.009181 - Time:355.1110680103302
Epoch: 4 - Batch: 154176 - Loss: 0.002648 - Time:355.1199731826782
Epoch: 5 - Batch: 154176 - Loss: 0.004245 - Time:355.0886559486389
Epoch: 6 - Batch: 154176 - Loss: 0.004584 - Time:355.116286277771
Epoch: 7 - Batch: 154176 - Loss: 0.003216 - Time:355.1081006526947
Epoch: 8 - Batch: 154176 - Loss: 0.008734 - Time:355.11406087875366
Epoch: 9 - Batch: 154176 - Loss: 0.003900 - Time:355.0896887779236
Epoch: 10 - Batch: 154176 - Loss: 0.015795 - Time:355.13466906547546
Epoch: 11 - Batch: 154176 - Loss: 0.003123 - Time:355.31604528427124
Epoch: 12 - Batch: 154176 - Loss: 0.007725 - Time:355.1312212944031
Epoch: 13 - Batch: 154176 - Loss: 0.031019 - Time:355.14549803733826
Epoch: 14 - Batch: 154176 - Loss: 0.004701 - Time:355

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [17]:
#embedder = torch.load(result_path + "gan_model.pth", map_location=device)
#print(embedder)