In [1]:
import os
import torch
import numpy as np
import torchvision
from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as T
import torch.nn.functional as F

from ig_pkg.datasets import get_datasets

from ig_pkg.models.generator import get_model
from ig_pkg.models.classifier import get_classifier
from ig_pkg.models.pretrained_models import get_pretrained_model

from ig_pkg.inputattribs.ig import ig
from ig_pkg.inputattribs.baseline_generator import get_baseline_generator

from ig_pkg.misc import process_heatmap, normalize_tensor, convert_to_img, label_to_class, tran, na_imshow

# Per-channel normalisation statistics (mean, std), one entry per image channel.
# Presumably the standard ImageNet RGB statistics used by torchvision's pretrained
# models — confirm against the transforms inside ig_pkg before relying on the order.
IMAGENET_MEAN, IMAGENET_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]


In [7]:
import time
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ig_pkg.loss.focal_loss import FocalLoss
from ig_pkg.loss.metrics import ArcMarginProduct, AddMarginProduct
import torchvision.models as models

# Fine-tune an ImageNet-pretrained ResNet-18 on the 'celebAHQ_whole' split and save
# its state_dict to /root/pretrained/celebahq_whole_flip.pth.
# Training runs for at most `num_epochs` epochs and stops early once the *training*
# accuracy (measured on the training batches in train mode, not on a held-out set)
# exceeds `target_train_acc` percent.
torch.manual_seed(0)  # make shuffling and the new head's initialisation reproducible

device = "cuda:0" if torch.cuda.is_available() else "cpu"
train_dataset, valid_dataset = get_datasets(name='celebAHQ_whole', data_path='/root/data/whole')
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

# Output size of the classifier head for this split (the 'celebAHQ_5' train split
# elsewhere in this notebook uses 2398).
num_classes = 3819

# Fail loudly if the dataset exposes a class list that disagrees with the head size:
# a smaller head makes CrossEntropyLoss fail on out-of-range labels, a larger one
# silently trains logits that never receive a positive target.
dataset_classes = getattr(train_dataset, 'classes', None)
if dataset_classes is not None:
    assert len(dataset_classes) == num_classes, (
        f"dataset has {len(dataset_classes)} classes but the head has {num_classes} outputs")
assert len(train_dataset) > 0, "training set is empty"

# NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet18_Weights.IMAGENET1K_V1`; kept for older torchvision versions.
model = models.resnet18(pretrained=True)
num_features = model.fc.in_features  # 512 for ResNet-18
model.fc = nn.Linear(num_features, num_classes)  # replace the pretrained classification head
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 30
target_train_acc = 99.0  # early-stop threshold on training accuracy, in percent
start_time = time.time()

for epoch in range(num_epochs):
    # --- training phase ---
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in train_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is exact
        # even when the final batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        # Accumulate as a plain Python int rather than a device tensor.
        running_corrects += (preds == labels).sum().item()

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = 100.0 * running_corrects / len(train_dataset)
    print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
    if epoch_acc > target_train_acc:
        break

save_path = '/root/pretrained/celebahq_whole_flip.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)


flip
[Train #0] Loss: 8.4526 Acc: 0.0395% Time: 35.4661s
[Train #1] Loss: 7.7940 Acc: 0.9089% Time: 71.2390s
[Train #2] Loss: 7.0296 Acc: 4.7287% Time: 107.3937s
[Train #3] Loss: 6.2181 Acc: 12.7107% Time: 143.7977s
[Train #4] Loss: 5.4099 Acc: 23.7355% Time: 180.4143s
[Train #5] Loss: 4.6099 Acc: 36.3541% Time: 216.4243s
[Train #6] Loss: 3.8152 Acc: 49.9078% Time: 252.1479s
[Train #7] Loss: 3.0492 Acc: 63.3562% Time: 288.0173s
[Train #8] Loss: 2.3572 Acc: 75.2107% Time: 324.2294s
[Train #9] Loss: 1.7414 Acc: 84.6681% Time: 360.0335s
[Train #10] Loss: 1.2201 Acc: 93.1639% Time: 396.0384s
[Train #11] Loss: 0.8040 Acc: 98.2745% Time: 431.8322s
[Train #12] Loss: 0.4969 Acc: 99.7761% Time: 467.9667s


In [4]:
import time
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ig_pkg.loss.focal_loss import FocalLoss
from ig_pkg.loss.metrics import ArcMarginProduct, AddMarginProduct
import torchvision.models as models

# Fine-tune an ImageNet-pretrained ResNet-18 on the 'celebAHQ_5' dataset loaded from
# /root/data/train and save its state_dict to /root/pretrained/celebahq_train_flip.pth.
# Training runs for at most `num_epochs` epochs and stops early once the *training*
# accuracy (measured on the training batches in train mode, not on a held-out set)
# exceeds `target_train_acc` percent.
torch.manual_seed(0)  # make shuffling and the new head's initialisation reproducible

device = "cuda:0" if torch.cuda.is_available() else "cpu"
train_dataset, valid_dataset = get_datasets(name='celebAHQ_5', data_path='/root/data/train')
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

# Output size of the classifier head for this split (the 'celebAHQ_whole' split
# elsewhere in this notebook uses 3819).
num_classes = 2398

# Fail loudly if the dataset exposes a class list that disagrees with the head size:
# a smaller head makes CrossEntropyLoss fail on out-of-range labels, a larger one
# silently trains logits that never receive a positive target.
dataset_classes = getattr(train_dataset, 'classes', None)
if dataset_classes is not None:
    assert len(dataset_classes) == num_classes, (
        f"dataset has {len(dataset_classes)} classes but the head has {num_classes} outputs")
assert len(train_dataset) > 0, "training set is empty"

# NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet18_Weights.IMAGENET1K_V1`; kept for older torchvision versions.
model = models.resnet18(pretrained=True)
num_features = model.fc.in_features  # 512 for ResNet-18
model.fc = nn.Linear(num_features, num_classes)  # replace the pretrained classification head
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 30
target_train_acc = 99.0  # early-stop threshold on training accuracy, in percent
start_time = time.time()

for epoch in range(num_epochs):
    # --- training phase ---
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in train_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is exact
        # even when the final batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        # Accumulate as a plain Python int rather than a device tensor.
        running_corrects += (preds == labels).sum().item()

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = 100.0 * running_corrects / len(train_dataset)
    print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
    if epoch_acc > target_train_acc:
        break

save_path = '/root/pretrained/celebahq_train_flip.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)

flip
[Train #0] Loss: 7.3216 Acc: 1.6690% Time: 104.6927s
[Train #1] Loss: 5.5499 Acc: 11.7503% Time: 210.6489s
[Train #2] Loss: 3.9619 Acc: 29.8465% Time: 308.8367s
[Train #3] Loss: 2.7269 Acc: 49.8036% Time: 407.9734s
[Train #4] Loss: 1.8144 Acc: 66.9538% Time: 513.0342s
[Train #5] Loss: 1.1648 Acc: 80.8729% Time: 617.6198s
[Train #6] Loss: 0.7369 Acc: 89.1423% Time: 722.1929s
[Train #7] Loss: 0.4588 Acc: 94.5064% Time: 827.2891s
[Train #8] Loss: 0.2678 Acc: 97.6214% Time: 931.3937s
[Train #9] Loss: 0.1543 Acc: 99.0628% Time: 1038.3805s


In [5]:
import time
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ig_pkg.loss.focal_loss import FocalLoss
from ig_pkg.loss.metrics import ArcMarginProduct, AddMarginProduct
import torchvision.models as models

# Fine-tune an ImageNet-pretrained ResNet-18 on the 'celebAHQ_5' dataset loaded from
# /root/data/identity_celebahq and save its state_dict to
# /root/pretrained/celebahq_original_<final train accuracy>.pth.
# Training runs for at most `num_epochs` epochs and stops early once the *training*
# accuracy (measured on the training batches in train mode, not on a held-out set)
# exceeds `target_train_acc` percent.
torch.manual_seed(0)  # make shuffling and the new head's initialisation reproducible

device = "cuda:0" if torch.cuda.is_available() else "cpu"
train_dataset, valid_dataset = get_datasets(name='celebAHQ_5', data_path='/root/data/identity_celebahq')
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

class_names = train_dataset.classes
# NOTE(review): the directory below (celebahq_identity) differs from data_path
# (identity_celebahq) — confirm which one is intended before re-enabling.
# np.save('/root/data/celebahq_identity/identity_original.npy', class_names)

# Output size of the classifier head; must match the dataset's class list.
num_classes = 6217
# A smaller head makes CrossEntropyLoss fail on out-of-range labels, a larger one
# silently trains logits that never receive a positive target.
assert len(class_names) == num_classes, (
    f"dataset has {len(class_names)} classes but the head has {num_classes} outputs")
assert len(train_dataset) > 0, "training set is empty"

# NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet18_Weights.IMAGENET1K_V1`; kept for older torchvision versions.
model = models.resnet18(pretrained=True)
num_features = model.fc.in_features  # 512 for ResNet-18
model.fc = nn.Linear(num_features, num_classes)  # replace the pretrained classification head
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 30
target_train_acc = 99.0  # early-stop threshold on training accuracy, in percent
start_time = time.time()

for epoch in range(num_epochs):
    # --- training phase ---
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in train_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is exact
        # even when the final batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        # Accumulate as a plain Python int rather than a device tensor.
        running_corrects += (preds == labels).sum().item()

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = 100.0 * running_corrects / len(train_dataset)
    print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
    if epoch_acc > target_train_acc:
        break

# Format the accuracy explicitly; interpolating the raw value produced filenames
# such as 'celebahq_original_99.09000396728516.pth'.
save_path = f'/root/pretrained/celebahq_original_{epoch_acc:.2f}.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)

flip
[Train #0] Loss: 8.2136 Acc: 1.1600% Time: 147.1768s
[Train #1] Loss: 6.6772 Acc: 7.7500% Time: 276.9724s
[Train #2] Loss: 5.3267 Acc: 19.8033% Time: 417.5162s
[Train #3] Loss: 4.1686 Acc: 34.1200% Time: 558.7775s
[Train #4] Loss: 3.2264 Acc: 47.7167% Time: 700.2342s
[Train #5] Loss: 2.4480 Acc: 59.7633% Time: 839.9531s
[Train #6] Loss: 1.8121 Acc: 70.5167% Time: 969.7769s
[Train #7] Loss: 1.2896 Acc: 79.4267% Time: 1099.5536s
[Train #8] Loss: 0.8690 Acc: 87.0167% Time: 1239.1792s
[Train #9] Loss: 0.5616 Acc: 92.6433% Time: 1369.0747s
[Train #10] Loss: 0.3361 Acc: 97.0133% Time: 1500.3617s
[Train #11] Loss: 0.1942 Acc: 99.0900% Time: 1639.3315s


In [6]:
import time
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ig_pkg.loss.focal_loss import FocalLoss
from ig_pkg.loss.metrics import ArcMarginProduct, AddMarginProduct
import torchvision.models as models

# Fine-tune an ImageNet-pretrained ResNet-18 on the 'celebAHQ_5' dataset loaded from
# /root/data/identity_celebahq and save its state_dict to
# /root/pretrained/celebahq_original_flip.pth.
# Training runs for at most `num_epochs` epochs and stops early once the *training*
# accuracy (measured on the training batches in train mode, not on a held-out set)
# exceeds `target_train_acc` percent.
torch.manual_seed(0)  # make shuffling and the new head's initialisation reproducible

device = "cuda:0" if torch.cuda.is_available() else "cpu"
train_dataset, valid_dataset = get_datasets(name='celebAHQ_5', data_path='/root/data/identity_celebahq')
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

class_names = train_dataset.classes
# NOTE(review): the directory below (celebahq_identity) differs from data_path
# (identity_celebahq) — confirm which one is intended before re-enabling.
# np.save('/root/data/celebahq_identity/identity_original.npy', class_names)

# Output size of the classifier head; must match the dataset's class list.
num_classes = 6217
# A smaller head makes CrossEntropyLoss fail on out-of-range labels, a larger one
# silently trains logits that never receive a positive target.
assert len(class_names) == num_classes, (
    f"dataset has {len(class_names)} classes but the head has {num_classes} outputs")
assert len(train_dataset) > 0, "training set is empty"

# NOTE: `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet18_Weights.IMAGENET1K_V1`; kept for older torchvision versions.
model = models.resnet18(pretrained=True)
num_features = model.fc.in_features  # 512 for ResNet-18
model.fc = nn.Linear(num_features, num_classes)  # replace the pretrained classification head
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 30
target_train_acc = 99.0  # early-stop threshold on training accuracy, in percent
start_time = time.time()

for epoch in range(num_epochs):
    # --- training phase ---
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in train_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is exact
        # even when the final batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        # Accumulate as a plain Python int rather than a device tensor.
        running_corrects += (preds == labels).sum().item()

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = 100.0 * running_corrects / len(train_dataset)
    print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
    if epoch_acc > target_train_acc:
        break

save_path = '/root/pretrained/celebahq_original_flip.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)

flip
[Train #0] Loss: 8.2226 Acc: 0.9433% Time: 130.2989s
[Train #1] Loss: 6.7668 Acc: 6.8633% Time: 261.0708s
[Train #2] Loss: 5.4323 Acc: 18.3667% Time: 392.4425s
[Train #3] Loss: 4.2724 Acc: 32.9167% Time: 523.3748s
[Train #4] Loss: 3.3012 Acc: 46.5333% Time: 653.9012s
[Train #5] Loss: 2.5012 Acc: 59.0767% Time: 783.5817s
[Train #6] Loss: 1.8478 Acc: 70.0500% Time: 914.2168s
[Train #7] Loss: 1.3232 Acc: 78.8067% Time: 1051.4545s
[Train #8] Loss: 0.8940 Acc: 86.6667% Time: 1183.8315s
[Train #9] Loss: 0.5820 Acc: 92.2900% Time: 1324.8905s
[Train #10] Loss: 0.3470 Acc: 96.9033% Time: 1457.5039s
[Train #11] Loss: 0.1978 Acc: 99.1067% Time: 1589.2342s
