In [1]:
import os
import yaml
import tqdm
import torch
import datetime
import requests
import torchvision
import numpy as np
import torch.nn as nn
from torch.optim import Adam
from autoencoder_cnn import autoencoder
from torch.utils.data.dataset import Dataset
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data.dataloader import DataLoader
from torchmetrics.image import StructuralSimilarityIndexMeasure

In [None]:
# Configuration placeholders: every name starts as None and is filled in by the
# configuration cell further down.
BATCH_SIZE = EPOCHS = URL = None
OUTPUT_PATH = DATASET_PATH = AUTOENCODER_PATH = PRETRAINED_PATH = None
LEARNING_RATE = FREEZE_ENCODER = None

device = None

In [None]:
# Path of the YAML configuration file; override it with the FINETUNING_CONFIG environment
# variable to run the notebook against another config without editing this cell.
configFile = os.environ.get("FINETUNING_CONFIG", "conf.yaml")

In [None]:
# yaml.safe_load parses YAML *text*: handing it the path string would just return that string.
# Open the file and parse its contents instead.
with open(configFile, encoding="utf-8") as config_stream:
    conf = yaml.safe_load(config_stream)
if not isinstance(conf, dict):
    raise ValueError("{} does not contain a YAML mapping".format(configFile))

In [None]:
# All fine-tuning settings come from the 'finetuning' section of the config;
# pretrained_path is the one value read from the 'autoencoder' section.
finetuning_conf = conf['finetuning']
BATCH_SIZE = finetuning_conf['batch_size']
EPOCHS = finetuning_conf['epochs']
URL = finetuning_conf['webhook']
OUTPUT_PATH = finetuning_conf['output_path']
DATASET_PATH = finetuning_conf['dataset_path']
AUTOENCODER_PATH = finetuning_conf['model_input']
FREEZE_ENCODER = finetuning_conf['freeze_encoder']
LEARNING_RATE = finetuning_conf['learning_rate']

PRETRAINED_PATH = conf['autoencoder']['pretrained_path']

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

print("Device: ", device)

In [2]:
# Rebuild the model to fine-tune. The first 11 layers of VGG16's feature extractor (weights from
# pretrained/vgg_face_dag_NEW.pth) become the encoder of the project autoencoder. The autoencoder
# is then restored from the checkpoint named by conf['finetuning']['model_input'] (AUTOENCODER_PATH).
if not AUTOENCODER_PATH:
    # torch.load('') can never succeed; fail early with an actionable message instead.
    raise ValueError("conf['finetuning']['model_input'] must name the autoencoder checkpoint to fine-tune")

# Checkpoints are loaded onto the CPU (the model is moved to `device` later), so weights saved
# from a GPU run also load on CPU-only machines.
# NOTE(review): PRETRAINED_PATH (conf['autoencoder']['pretrained_path']) is read but never used;
# confirm whether it is meant to replace the hard-coded VGG weights path below.
vgg = torchvision.models.vgg16(pretrained=False)
vgg.load_state_dict(torch.load('pretrained/vgg_face_dag_NEW.pth', map_location='cpu'))

encoder = vgg.features[:11]
del vgg

model = autoencoder(encoder)

model.load_state_dict(torch.load(AUTOENCODER_PATH, map_location='cpu'))



<All keys matched successfully>

In [3]:



def increasingLoss(losses):
    """Tell whether the three most recent losses are strictly increasing.

    :param losses: sequence of loss values, oldest first.
    :return: False when fewer than three values are available, otherwise the result of
        comparing the last three values with ``<``.
    """
    if len(losses) >= 3:
        oldest, middle, newest = losses[-3:]
        return oldest < middle < newest
    return False

def send_webHook(url, text, timeout=10):
    """Send a timestamped status message to a webhook (best effort).

    Notifications must never interrupt training, so a missing URL, network errors and
    non-200 replies are reported and swallowed instead of raised.

    :param url: webhook endpoint; falsy values (None, "") skip the call entirely.
    :param text: message body; it is prefixed with "at time <YYYY-mm-dd HH:MM:SS> -> ".
    :param timeout: seconds to wait for the server before giving up.
    :return: True if the server answered with HTTP 200, False otherwise.
    """
    if not url:
        return False

    current_time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    try:
        # NOTE(review): a GET request carrying a body is unusual; many webhook services expect
        # POST. Kept as GET to match the existing receiver; confirm.
        response = requests.get(url, data=f"at time {current_time_str} -> {text}", timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: {exc}")
        return False

    if response.status_code == 200:
        print("OK")
        return True
    print("Error")
    return False

In [4]:
# Apply the project autoencoder's freezing helpers before training: first its batch-norm freeze,
# then freeze_encoder with the FREEZE_ENCODER config value (presumably a bool toggle — confirm).
# NOTE(review): both methods live in autoencoder_cnn (not visible here). If freeze_batchNorm works
# by putting batch-norm layers in eval mode, the model.train(True) call in train() switches them
# back to training mode — confirm.
model.freeze_batchNorm()
model.freeze_encoder(FREEZE_ENCODER)

In [None]:
def evaluate_epoch(model, dataset, device, batch_size=None):
    """Compute the mean reconstruction loss ``1 - SSIM + MSE`` of ``model`` over ``dataset``.

    :param model: autoencoder to evaluate; it is switched to eval mode and left there.
    :param dataset: a DataLoader yielding image batches, or a Dataset of single images.
        A Dataset is wrapped in a DataLoader so every sample gets a leading batch dimension.
    :param device: torch device the images are moved to (the model should already live there).
    :param batch_size: batch size used only when ``dataset`` is not already a DataLoader
        (defaults to 1, i.e. a per-sample mean).
    :return: the loss averaged over batches, as a Python float.
    :raises ValueError: if ``dataset`` yields no batches.
    """
    model.eval()

    if not isinstance(dataset, DataLoader):
        # Iterating a bare Dataset would feed un-batched (C, H, W) tensors to the model.
        dataset = DataLoader(dataset, batch_size=batch_size or 1, shuffle=False)

    ssim = StructuralSimilarityIndexMeasure(data_range=1).to(device)
    mse = nn.MSELoss().to(device)
    running_vloss = 0.0
    n_batches = 0

    with torch.no_grad():
        for img in dataset:
            # Same float cast as the training loop.
            img = img.float().to(device)
            out = model(img)
            vloss = 1 - ssim(out, img) + mse(out, img)
            # .item() keeps the accumulator a plain float rather than a device tensor.
            running_vloss += vloss.item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("evaluate_epoch: dataset yielded no batches")
    return running_vloss / n_batches

In [None]:
class EarlyStopper:
    """Signal when a validation loss has stopped improving.

    A loss below the best seen so far becomes the new best and resets the strike counter.
    A loss more than ``min_delta`` above the best adds a strike. Anything in between changes
    nothing. ``early_stop`` returns True once ``patience`` strikes have accumulated.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return True when training should stop."""
        improved = validation_loss < self.min_validation_loss
        if improved:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Written as "not (a > b)" rather than "a <= b" so NaN losses never add a strike.
        regressed = validation_loss > (self.min_validation_loss + self.min_delta)
        if not regressed:
            return False

        self.counter += 1
        return bool(self.counter >= self.patience)


class FacesDataset(Dataset):
    r"""
    Dataset of bona fide face images, indexed by identity.

    Expected layout: ``im_path/<expression>/<id>.<ext>``, where ``<id>`` is a numeric identity.
    ``__getitem__`` returns either the 'neutral' or the 'smile' image of an identity, chosen at
    random, as a tensor produced by ``torchvision.transforms.ToTensor``.
    """
    def __init__(self, im_path, im_ext='png'):
        r"""
        :param im_path: root folder containing one sub-folder per expression.
        :param im_ext: nominal image extension; stored for reference, not used to filter files.
        """
        self.im_ext = im_ext
        self.identities = self.load_images(im_path)
        # Sorted so the index -> identity mapping (and hence any random split) is stable across
        # runs; os.listdir order is arbitrary.
        self.keys = sorted(self.identities)

    def load_images(self, im_path):
        r"""
        Index every image under ``im_path`` by identity and expression folder.

        Entries at the root that are not directories, and files whose name does not start with
        a numeric identity (e.g. hidden files), are skipped.

        :param im_path: root folder of images.
        :return: dict mapping int identity -> {expression folder name: image path}.
        :raises AssertionError: if ``im_path`` does not exist.
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)

        identities = {}
        for expression in sorted(os.listdir(im_path)):
            expression_dir = os.path.join(im_path, expression)
            if not os.path.isdir(expression_dir):
                continue  # stray files at the root (README, .DS_Store, ...)
            for file_name in sorted(os.listdir(expression_dir)):
                stem = file_name.split('.')[0]  # the identity number precedes the first dot
                if not stem.isdigit():
                    continue
                identities.setdefault(int(stem), {})[expression] = os.path.join(expression_dir, file_name)

        print('Found {} identities.'.format(len(identities)))
        return identities

    def __len__(self):
        return len(self.identities)

    def __getitem__(self, index):
        """Return the neutral or smiling image (50/50) of the ``index``-th identity as a tensor."""
        expression = "smile" if np.random.randint(2) == 1 else "neutral"
        identity = self.keys[index]
        paths = self.identities[identity]

        if expression not in paths:
            # Use the identity's other expression rather than failing half of the time.
            fallback = "neutral" if expression == "smile" else "smile"
            if fallback not in paths:
                raise KeyError("identity {} has neither a 'neutral' nor a 'smile' image".format(identity))
            expression = fallback

        # pil_loader opens the file inside a `with` block (no leaked handle) and converts to
        # 3-channel RGB, the input format of the VGG16 convolutions used as the encoder.
        image = torchvision.datasets.folder.pil_loader(paths[expression])
        return torchvision.transforms.ToTensor()(image)

In [None]:
ds = FacesDataset(DATASET_PATH)

# 80/20 split over dataset indices (one index per identity). The seeded generator keeps the
# held-out identities the same across kernel restarts, so checkpoints are always evaluated on
# data they were not trained on.
SPLIT_SEED = 42
train_size = int(0.8 * len(ds))
test_size = len(ds) - train_size
train_ds, test_ds = torch.utils.data.random_split(
    ds, [train_size, test_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [None]:
def _run_train_epoch(model, loader, optimizer, ssim, mse):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The loss is ``1 - SSIM + MSE`` between the reconstruction and the input image. Images are
    moved to the module-level ``device``. Returns NaN if ``loader`` yields no batches.
    """
    # NOTE(review): train(True) puts every sub-module, batch-norm layers included, back into
    # training mode. If model.freeze_batchNorm() relies on eval mode it must be re-applied — confirm.
    model.train(True)
    batch_losses = []
    for image in tqdm.tqdm(loader):
        optimizer.zero_grad()
        image = image.float().to(device)
        reconstructed = model(image)
        loss = 1 - ssim(reconstructed, image) + mse(reconstructed, image)  # combined SSIM + MSE loss
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return float(np.mean(batch_losses)) if batch_losses else float('nan')


def train(model, dataset, eval_set):
    """Fine-tune ``model`` on ``dataset`` and validate it on ``eval_set`` after every epoch.

    Uses the module-level configuration (BATCH_SIZE, EPOCHS, LEARNING_RATE, OUTPUT_PATH, URL,
    device). After each epoch it does the following:
    - writes the checkpoint ``ae_Casia_<epoch>.pth`` to OUTPUT_PATH;
    - prints a status line, also sent to the webhook when URL is set;
    - asks EarlyStopper (patience 3) whether the validation loss has stopped improving.

    :param model: autoencoder, already moved to ``device``.
    :param dataset: training Dataset of image tensors.
    :param eval_set: validation Dataset of image tensors.
    :return: ``(train_loss, eval_loss)`` of the last completed epoch as floats (mean training
        batch loss, mean validation loss); both are NaN if no epoch ran.
    """
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Order does not matter for a mean loss; a fixed order keeps evaluation deterministic.
    eval_loader = DataLoader(eval_set, batch_size=BATCH_SIZE, shuffle=False)

    # Checkpoints are written to OUTPUT_PATH, so that is the directory that must exist.
    os.makedirs(OUTPUT_PATH, exist_ok=True)

    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
    # Multiply the learning rate by 0.98 after every epoch.
    scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    early_stopper = EarlyStopper(patience=3, min_delta=0.001)

    ssim = StructuralSimilarityIndexMeasure(data_range=1).to(device)
    mse = nn.MSELoss().to(device)

    epoch_losses = []  # mean training loss of each epoch, for the divergence check below
    last_epoch_change = 0
    train_loss = eval_loss = float('nan')

    for epoch_idx in range(EPOCHS):
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        train_loss = _run_train_epoch(model, train_loader, optimizer, ssim, mse)
        epoch_losses.append(train_loss)

        # If the mean epoch loss rose over the last three epochs, halve the learning rate the
        # scheduler has actually reached. Editing param_groups in place keeps Adam's moment
        # estimates and lets the exponential decay continue from the halved rate.
        current_lr = optimizer.param_groups[0]['lr']
        if (epoch_idx - last_epoch_change > 2) and increasingLoss(epoch_losses) and current_lr > 0.001:
            for group in optimizer.param_groups:
                group['lr'] *= 0.5
            last_epoch_change = epoch_idx
        else:
            scheduler.step()

        # float() accepts either a float or a one-element tensor from evaluate_epoch.
        eval_loss = float(evaluate_epoch(model, eval_loader, device))

        result = ('Finished epoch:{} | Loss : {:.4f} | Learning Rate: {} | Eval loss: {}'.format(
            epoch_idx + 1,
            train_loss,
            optimizer.param_groups[0]['lr'],
            eval_loss
        ))

        if URL:
            send_webHook(URL, result)
        print(result)

        torch.save(model.state_dict(), os.path.join(OUTPUT_PATH, "ae_Casia_{}.pth".format(epoch_idx)))
        if early_stopper.early_stop(eval_loss):
            print('Early stopping')
            break

    print('Done Training ...')
    return train_loss, eval_loss

In [None]:
loss = []       # history of mean losses on the training split
eval_loss = []  # history of mean losses on the validation split

model.to(device)

# evaluate_epoch iterates whatever it is given. A bare Subset yields single (C, H, W) images
# with no batch dimension, so wrap both splits in DataLoaders (no shuffling needed for a mean).
starting_loss = float(evaluate_epoch(model, DataLoader(train_ds, batch_size=BATCH_SIZE), device))
starting_eval_loss = float(evaluate_epoch(model, DataLoader(test_ds, batch_size=BATCH_SIZE), device))

print("Starting Loss: ", starting_loss)
print("Starting Eval Loss: ", starting_eval_loss)

loss.append(starting_loss)
eval_loss.append(starting_eval_loss)

In [None]:
train_loss, train_eval_loss = train(model, train_ds, test_ds)

# float() turns one-element tensors (possibly on the GPU or still attached to the autograd
# graph) into plain numbers the histories can be plotted from; it is a no-op for floats.
loss.append(float(train_loss))
eval_loss.append(float(train_eval_loss))

In [None]:
# `plot`/`show` live in matplotlib.pyplot; the top-level matplotlib package has neither.
import matplotlib.pyplot as plt

# float() makes the plot work whether the histories hold floats or one-element tensors.
fig, ax = plt.subplots()
ax.plot([float(v) for v in loss], marker='o', label='train split')
ax.plot([float(v) for v in eval_loss], marker='o', label='validation split')
ax.set(xlabel='measurement (0 = before fine-tuning)', ylabel='loss (1 - SSIM + MSE)',
       title='Fine-tuning loss')
ax.legend()
plt.show()