In [None]:
import sys

# Make the project root (parent of the notebook's working directory) importable so
# that `core.*` resolves. Guarded so re-running this cell doesn't keep appending
# duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import os
import traceback
import matplotlib.pyplot as plt

from core.datasets import *
from core.networks import *
from core.models import *

# Parameters

In [None]:
# Compute device: the first CUDA GPU when one is available, otherwise fall back to
# the CPU so the notebook still runs (slowly) on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Fixed seed so weight initialisation and DataLoader shuffling are reproducible.
seed = 0
torch.manual_seed(seed)

# Optimisation
lr = 1e-3
n_epochs = 34
early_stop_threshold = 1e-2  # stop once classifier val loss rises >1% over the previous epoch

# Data loading
batch_size = 32
n_workers = 4
shuffle = True  # whether the training DataLoader reshuffles every epoch

# Network hyperparameters
cae_latent_dim = 32
cae_stride = 2
resnet_model_no = 34  # first argument to ResNet(...); presumably the depth (ResNet-34)

# Paths. The CelebA root defaults to the original author's machine; set the
# CELEBA_DIR environment variable to point elsewhere instead of editing this cell.
dir_celeba = os.environ.get('CELEBA_DIR', '/Users/Linsu Han/Documents/Data/celeba')
dir_data = os.path.join(dir_celeba, 'clean', '')  # trailing '' keeps the trailing separator
dir_load = None  # directory holding cae.pth / cls.pth to resume from, or None to start fresh
dir_save = '../resources/models/'
path_metadata = os.path.join(dir_celeba, 'list_attr_celeba.csv')
# CelebA attribute columns used as classifier targets (one network output per entry).
features = ['Attractive', 'Bags_Under_Eyes', 'Bangs', 'Chubby', 'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Smiling', 'Wearing_Lipstick', 'Young']

# Initialize Dataloader

In [None]:
# Full CelebA dataset built from the image directory, the attribute CSV and the
# selected attribute columns. Items are unpacked as (idx, x, y) in the training
# loop, so presumably (index, image tensor, attribute targets) — TODO confirm in
# core.datasets.
dataset = CelebA(
    dir_data,
    path_metadata,
    features,
)

In [None]:
# Sizes for an 80/20 train/validation split (train size rounded down).
train_len = int(0.8 * len(dataset))
val_len = len(dataset) - train_len
print(train_len, val_len)

In [None]:
# Contiguous split by index (no shuffling before splitting):
#   [0, train_len)            -> training
#   [train_len, len(dataset)) -> validation
dataset_train = torch.utils.data.Subset(dataset, list(range(train_len)))
dataset_val = torch.utils.data.Subset(dataset, list(range(train_len, len(dataset))))

In [None]:
# Training batches are reshuffled each epoch according to `shuffle`. Validation
# order is fixed: shuffling it has no benefit and only makes the recorded
# validation losses vary between runs.
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle, num_workers=n_workers)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=n_workers)

# Initialize Networks

In [None]:
# Input image shape as (channels, height, width) — assumed to match the images
# produced by core.datasets.CelebA; TODO confirm.
x_shape = (3, 224, 224)
network_cae = ConvAutoencoder(cae_latent_dim, *x_shape, stride=cae_stride).to(device)
# Use the configured `resnet_model_no` (previously a hard-coded 34 made that
# parameter dead) and derive the input channels from x_shape.
network_resnet = ResNet(resnet_model_no, len(features), in_channels=x_shape[0]).to(device)

# Initialize Models

In [None]:
# Wrap each network in its model class, both with the same learning rate `lr`.
# The wrappers provide update / eval / loss_history / save / load as used below.
model_cae, model_cls = (
    ModelCAE(network_cae, lr=lr),
    ModelSigmoidClassifier(network_resnet, lr=lr),
)

# Loading Saved Models

In [None]:
# Optionally resume from checkpoints written by the "Saving Models" cell
# (cae.pth / cls.pth). os.path.join works whether or not dir_load ends with a
# path separator, unlike plain string concatenation.
if dir_load is not None:
    model_cae.load(os.path.join(dir_load, 'cae.pth'))
    model_cls.load(os.path.join(dir_load, 'cls.pth'))

# Training Loop

In [None]:
# Put both networks in training mode. The trailing `;` stops Jupyter from echoing
# the value returned by `.train()` (for a torch nn.Module that is the module itself,
# i.e. a full dump of the ResNet architecture).
model_cae.network.train()
model_cls.network.train();

In [None]:
# Jointly train the autoencoder and the classifier. After each epoch, record the
# mean training/validation losses of both models. Stop early once the classifier's
# validation loss rises by more than `early_stop_threshold` (relative) over the
# previous epoch.
#
# Exceptions are intentionally not caught here. A failing batch used to be skipped
# after an interactive `breakpoint()`. That hangs non-interactive runs, and it makes
# the per-epoch loss averages below include stale entries from the previous epoch.
# After a failure, run `%debug` in a new cell for a post-mortem instead.
torch.cuda.empty_cache()
df_cae = []  # one loss record (dict) per epoch for the autoencoder
df_cls = []  # one loss record (dict) per epoch for the classifier

loss_cae_val = np.inf
loss_cls_val = np.inf

for epoch in range(n_epochs):
    print(f'Epoch: {epoch}')

    # (Re-)enter training mode at the start of every epoch, because the
    # validation pass below switches the networks to eval mode.
    model_cae.network.train()
    model_cls.network.train()

    for idx, x, y in tqdm(dataloader_train):
        x = x.to(device)
        y = y.to(device)
        model_cae.update(x)
        model_cls.update(x, y)

    # Assumes each `update` call appends exactly one entry to
    # loss_history['training'] — TODO confirm in core.models.
    loss_cae_train = np.mean(model_cae.loss_history['training'][-len(dataloader_train):])
    loss_cls_train = np.mean(model_cls.loss_history['training'][-len(dataloader_train):])
    print('Training Loss (cae):', loss_cae_train)
    print('Training Loss (cls):', loss_cls_train)

    # Validate in eval mode so layers such as batch norm / dropout (if the networks
    # contain them) do not update running statistics from validation data.
    # NOTE(review): redundant but harmless if the wrappers' .eval() already does this.
    model_cae.network.eval()
    model_cls.network.eval()
    for idx, x, y in tqdm(dataloader_val):
        x = x.to(device)
        y = y.to(device)
        out_cae = model_cae.eval(x)
        out_cls = model_cls.eval(x, y)

    loss_cae_val_prev = loss_cae_val
    loss_cls_val_prev = loss_cls_val

    # Same one-entry-per-call assumption as above, for loss_history['validation'].
    loss_cae_val = np.mean(model_cae.loss_history['validation'][-len(dataloader_val):])
    loss_cls_val = np.mean(model_cls.loss_history['validation'][-len(dataloader_val):])
    print('Validation Loss (cae):', loss_cae_val)
    print('Validation Loss (cls):', loss_cls_val)

    info_cae = {'Epoch':epoch, 'Model':'cae', 'Training Loss':loss_cae_train, 'Validation Loss':loss_cae_val}
    info_cls = {'Epoch':epoch, 'Model':'cls', 'Training Loss':loss_cls_train, 'Validation Loss':loss_cls_val}

    df_cae.append(info_cae)
    df_cls.append(info_cls)

    print('-'*13)

    # Multiplicative form of `val/prev > 1 + threshold`: identical for a positive
    # previous loss, but it cannot divide by zero. On the first epoch the previous
    # loss is inf, so this never triggers.
    early_stop = loss_cls_val > loss_cls_val_prev * (1 + early_stop_threshold)
    if early_stop:
        break

In [None]:
# Turn the per-epoch autoencoder loss records into a DataFrame with columns
# Epoch, Model, Training Loss, Validation Loss.
df_cae = pd.DataFrame(data=df_cae)

In [None]:
# Turn the per-epoch classifier loss records into a DataFrame with columns
# Epoch, Model, Training Loss, Validation Loss.
df_cls = pd.DataFrame(data=df_cls)

In [None]:
# Autoencoder loss curves per epoch. Labels and a legend distinguish training
# from validation, and the trailing `;` suppresses the artist repr.
fig, ax = plt.subplots()
ax.plot(df_cae['Epoch'], df_cae['Training Loss'], label='Training')
ax.plot(df_cae['Epoch'], df_cae['Validation Loss'], label='Validation')
ax.set(title='Autoencoder (cae) loss', xlabel='Epoch', ylabel='Loss')
ax.legend();

In [None]:
# Classifier loss curves per epoch. Labels and a legend distinguish training
# from validation, and the trailing `;` suppresses the artist repr.
fig, ax = plt.subplots()
ax.plot(df_cls['Epoch'], df_cls['Training Loss'], label='Training')
ax.plot(df_cls['Epoch'], df_cls['Validation Loss'], label='Validation')
ax.set(title='Classifier (cls) loss', xlabel='Epoch', ylabel='Loss')
ax.legend();

# Saving Models

In [None]:
# Persist both models under dir_save, creating the directory if needed.
# exist_ok=True replaces the racy exists()-then-makedirs() check, and
# os.path.join does not rely on dir_save ending with a separator.
os.makedirs(dir_save, exist_ok=True)
model_cae.save(os.path.join(dir_save, 'cae.pth'))
model_cls.save(os.path.join(dir_save, 'cls.pth'))