In [1]:
import torch
from lib.backbone import StackedHGNetV1
from lib.utility import get_dataloader

In [2]:
class Config_300W:
    """Model settings for the 300W landmark layout (landmark indices 0-67).

    Attributes:
        classes_num: forwarded to ``StackedHGNetV1`` as ``classes_num``.
        edge_info: nine ``(flag, landmark_indices)`` pairs, one per facial
            contour. The boolean flag presumably marks closed contours
            (eyes, lips) -- confirm against ``lib.backbone``.
        nstack, add_coord, decoder_type: forwarded to ``StackedHGNetV1``.
        width, height: 256 x 256; presumably the network input size.
        use_AAM: stored on the config object; consumer not visible here.
    """

    def __init__(self):
        # (flag, first landmark, last landmark) -- both bounds inclusive.
        contour_rows = [
            (False, 0, 16),   # FaceContour
            (False, 17, 21),  # RightEyebrow
            (False, 22, 26),  # LeftEyebrow
            (False, 27, 30),  # NoseLine
            (False, 31, 35),  # Nose
            (True, 36, 41),   # RightEye
            (True, 42, 47),   # LeftEye
            (True, 48, 59),   # OuterLip
            (True, 60, 67),   # InnerLip
        ]

        self.classes_num = [68, 9, 68]
        self.edge_info = tuple(
            (flag, tuple(range(first, last + 1)))
            for flag, first, last in contour_rows
        )
        self.nstack = 4
        self.add_coord = True
        self.decoder_type = "default"
        self.width = 256
        self.height = 256
        self.use_AAM = True

In [3]:
class Config_WFLW:
    """Model settings for the WFLW landmark layout.

    Attributes:
        classes_num: forwarded to ``StackedHGNetV1`` as ``classes_num``.
        edge_info: nine ``(flag, landmark_indices)`` pairs, one per facial
            contour, covering landmark indices 0-95. The boolean flag
            presumably marks closed contours -- confirm against
            ``lib.backbone``.
        nstack, add_coord, decoder_type: forwarded to ``StackedHGNetV1``.
        width, height: 256 x 256; presumably the network input size.
        use_AAM: stored on the config object; consumer not visible here.
    """

    def __init__(self):
        # (flag, first landmark, last landmark) -- both bounds inclusive.
        contour_rows = [
            (False, 0, 32),   # FaceContour
            (True, 33, 41),   # RightEyebrow
            (True, 42, 50),   # LeftEyebrow
            (False, 51, 54),  # NoseLine
            (False, 55, 59),  # Nose
            (True, 60, 67),   # RightEye
            (True, 68, 75),   # LeftEye
            (True, 76, 87),   # OuterLip
            (True, 88, 95),   # InnerLip
        ]

        self.classes_num = [98, 9, 98]
        self.edge_info = tuple(
            (flag, tuple(range(first, last + 1)))
            for flag, first, last in contour_rows
        )
        self.nstack = 4
        self.add_coord = True
        self.decoder_type = "default"
        self.width = 256
        self.height = 256
        self.use_AAM = True

In [4]:
# Build the 300W network from its config object.
config_300W = Config_300W()
net_300W = StackedHGNetV1(config=config_300W,
                        classes_num=config_300W.classes_num,
                        edge_info=config_300W.edge_info,
                        nstack=config_300W.nstack,
                        add_coord=config_300W.add_coord,
                        decoder_type=config_300W.decoder_type)
# Pretrained 300W weights, read from the working directory and mapped to CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints obtained from a trusted source.
model_path = "300W_STARLoss_NME_2_87.pkl"
checkpoint = torch.load(model_path,map_location=torch.device('cpu'))
# The network weights are stored under the "net" key of the checkpoint.
net_300W.load_state_dict(checkpoint["net"])

  checkpoint = torch.load(model_path,map_location=torch.device('cpu'))


<All keys matched successfully>

In [5]:
# Build the WFLW network from its config object. No checkpoint is loaded in
# this cell, so the weights are whatever the constructor produces.
config_WFLW = Config_WFLW()
net_WFLW = StackedHGNetV1(config=config_WFLW,
                        classes_num=config_WFLW.classes_num,
                        edge_info=config_WFLW.edge_info,
                        nstack=config_WFLW.nstack,
                        add_coord=config_WFLW.add_coord,
                        decoder_type=config_WFLW.decoder_type)

In [6]:
import torch
import torch.nn as nn
import torch.nn.init as init


def initialize_weights(m):
    """Re-initialise the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply`` so that it is called once for
    every submodule of a network.

    Initialisation rules:
        * ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU), zero bias.
        * ``nn.BatchNorm2d``: weight 1, bias 0 (skipped when ``affine=False``).
        * ``nn.Linear``: Xavier-uniform weights, zero bias.
        * any other module that owns a ``weight`` tensor with at least two
          dimensions: Xavier-uniform weights, zero bias if it has one.
        * everything else is left untouched.

    Args:
        m: the module to initialise.

    Returns:
        None. Parameters are modified in place.
    """
    if isinstance(m, nn.Conv2d):
        # Using Kaiming He initialization for Conv2d layers
        init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False both parameters are None and there is nothing
        # to initialise.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        # Xavier initialization for Linear layers
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    else:
        # Generic fallback for custom layers that own a weight tensor.
        weight = getattr(m, 'weight', None)
        # Xavier needs fan-in and fan-out, which are undefined for tensors
        # with fewer than two dimensions (e.g. LayerNorm / BatchNorm1d
        # weights); such modules keep their own default initialisation.
        if not torch.is_tensor(weight) or weight.dim() < 2:
            return
        init.xavier_uniform_(weight)
        # Not every weighted module defines a `bias` attribute.
        bias = getattr(m, 'bias', None)
        if torch.is_tensor(bias):
            init.zeros_(bias)

# Run initialize_weights on every submodule of net_WFLW (nn.Module.apply
# visits all children recursively and returns the module, hence the repr below).
net_WFLW.apply(initialize_weights)

StackedHGNetV1(
  (pre): Sequential(
    (0): CoordConvTh(
      (addcoords): AddCoordsTh()
      (conv): Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (relu): ReLU()
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResBlock(
      (relu): ReLU()
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): ConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      )
      (skip_layer): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size

In [7]:
# Display the layer structure of the 300W network.
net_300W

StackedHGNetV1(
  (pre): Sequential(
    (0): CoordConvTh(
      (addcoords): AddCoordsTh()
      (conv): Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (relu): ReLU()
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResBlock(
      (relu): ReLU()
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): ConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      )
      (skip_layer): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size

In [8]:
# Display the layer structure of the WFLW network for comparison with net_300W.
net_WFLW

StackedHGNetV1(
  (pre): Sequential(
    (0): CoordConvTh(
      (addcoords): AddCoordsTh()
      (conv): Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (relu): ReLU()
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResBlock(
      (relu): ReLU()
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): ConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): ConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      )
      (skip_layer): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size

In [9]:
def transfer_params(source_model,target_model):
    """Copy every shape-compatible state-dict entry from source to target.

    An entry is copied when its key exists in the target's state dict and both
    tensors have the same shape. All other target entries keep their current
    values. The target model is updated in place.

    Args:
        source_model: module whose parameters/buffers are read.
        target_model: module that receives the matching entries.

    Returns:
        None.
    """
    merged_state = target_model.state_dict()

    for key, tensor in source_model.state_dict().items():
        # Skip entries the target lacks or whose shapes disagree
        # (e.g. layers sized by the number of landmarks).
        if key not in merged_state:
            continue
        if tensor.shape != merged_state[key].shape:
            continue
        merged_state[key] = tensor

    target_model.load_state_dict(merged_state)

In [10]:
# Seed the WFLW network with every 300W tensor whose name and shape match;
# all other WFLW tensors keep their current values.
transfer_params(net_300W,net_WFLW)

In [11]:
# List every learnable parameter of the WFLW network together with its shape.
for name, param in net_WFLW.named_parameters():
    print(name,param.shape)

pre.0.conv.weight torch.Size([64, 6, 7, 7])
pre.0.conv.bias torch.Size([64])
pre.0.bn.weight torch.Size([64])
pre.0.bn.bias torch.Size([64])
pre.1.bn1.weight torch.Size([64])
pre.1.bn1.bias torch.Size([64])
pre.1.conv1.conv.weight torch.Size([64, 64, 1, 1])
pre.1.conv1.conv.bias torch.Size([64])
pre.1.bn2.weight torch.Size([64])
pre.1.bn2.bias torch.Size([64])
pre.1.conv2.conv.weight torch.Size([64, 64, 3, 3])
pre.1.conv2.conv.bias torch.Size([64])
pre.1.bn3.weight torch.Size([64])
pre.1.bn3.bias torch.Size([64])
pre.1.conv3.conv.weight torch.Size([128, 64, 1, 1])
pre.1.conv3.conv.bias torch.Size([128])
pre.1.skip_layer.conv.weight torch.Size([128, 64, 1, 1])
pre.1.skip_layer.conv.bias torch.Size([128])
pre.3.bn1.weight torch.Size([128])
pre.3.bn1.bias torch.Size([128])
pre.3.conv1.conv.weight torch.Size([64, 128, 1, 1])
pre.3.conv1.conv.bias torch.Size([64])
pre.3.bn2.weight torch.Size([64])
pre.3.bn2.bias torch.Size([64])
pre.3.conv2.conv.weight torch.Size([64, 64, 3, 3])
pre.3.conv2

In [12]:
# Standard deviation of the conv weights in out_pointmaps[3] of the WFLW
# network, to compare against the same layer of the pretrained 300W network.
net_WFLW.out_pointmaps[3].conv.weight.std()

tensor(0.1425, grad_fn=<StdBackward0>)

In [13]:
# Same statistic for the pretrained 300W network.
# NOTE(review): a different value than for net_WFLW presumably means this
# layer was not copied by transfer_params (shape mismatch) -- confirm.
net_300W.out_pointmaps[3].conv.weight.std()

tensor(0.0435, grad_fn=<StdBackward0>)

In [14]:
class Data_Config:
    """Dataset, loader and evaluation settings for WFLW.

    The object is handed to ``get_dataloader`` and to ``LitSTAR``. Attribute
    groups: file locations, loader/batch settings, image geometry and
    normalisation constants, landmark topology (``edge_info``,
    ``flip_mapping``) and validation settings (``norm_type``,
    ``nme_left_index``, ``nme_right_index``).
    """

    def __init__(self):
        # --- dataset identity and file locations ---
        self.data_definition = "WFLW"
        self.train_tsv_file = 'pub_annot/WFLW/train.tsv'
        self.val_tsv_file = 'pub_annot/WFLW/test.tsv'
        self.train_pic_dir = 'pub_dataset/WFLW/WFLW_images'
        self.val_pic_dir = 'pub_dataset/WFLW/WFLW_images'

        # --- loader settings ---
        self.loader_type = 'alignment'
        self.batch_size = 16
        self.val_batch_size = 32
        self.train_num_workers = 1
        self.val_num_workers = 1

        # --- image geometry and normalisation constants ---
        self.width = 256
        self.height = 256
        self.channels = 3
        self.means = (127.5, 127.5, 127.5)
        self.scale = 0.00784313725490196

        # --- labels and augmentation ---
        self.classes_num = [98, 9, 98]
        self.crop_op = True
        self.aug_prob = 1.0
        self.label_num = 12

        # (flag, first landmark, last landmark) -- both bounds inclusive.
        contour_rows = [
            (False, 0, 32),   # FaceContour
            (True, 33, 41),   # RightEyebrow
            (True, 42, 50),   # LeftEyebrow
            (False, 51, 54),  # NoseLine
            (False, 55, 59),  # Nose
            (True, 60, 67),   # RightEye
            (True, 68, 75),   # LeftEye
            (True, 76, 87),   # OuterLip
            (True, 88, 95),   # InnerLip
        ]
        self.edge_info = tuple(
            (flag, tuple(range(first, last + 1)))
            for flag, first, last in contour_rows
        )

        # Index pairs grouped by facial region; presumably the landmarks that
        # swap places under a horizontal flip -- confirm in the data loader.
        cheek_pairs = tuple([i, 32 - i] for i in range(16))
        eyebrow_pairs = (
            [33, 46], [34, 45], [35, 44], [36, 43], [37, 42],
            [38, 50], [39, 49], [40, 48], [41, 47],
        )
        eye_pairs = (
            [60, 72], [61, 71], [62, 70], [63, 69], [64, 68],
            [65, 75], [66, 74], [67, 73],
        )
        nose_pairs = ([55, 59], [56, 58])
        outer_lip_pairs = ([76, 82], [77, 81], [78, 80], [87, 83], [86, 84])
        tail_pairs = ([88, 92], [89, 91], [95, 93], [96, 97])
        self.flip_mapping = (
            cheek_pairs
            + eyebrow_pairs
            + eye_pairs
            + nose_pairs
            + outer_lip_pairs
            + tail_pairs
        )

        self.encoder_type = 'default'

        # --- validation ---
        self.norm_type = 'default'
        self.nme_left_index = 60
        self.nme_right_index = 72


In [15]:
data_config = Data_Config()
# Training loader; world_rank=0 / world_size=1 are passed explicitly
# (presumably a single-process, non-distributed run -- confirm in lib.utility).
train_loader = get_dataloader(data_config, data_type='train', world_rank=0, world_size=1)
# Validation loader; "val" is passed positionally (presumably as data_type)
# and the remaining arguments keep get_dataloader's defaults.
val_loader = get_dataloader(data_config, "val")

In [16]:
import numpy as np
import pandas as pd
import cv2
import re
import matplotlib.pyplot as plt
import torch

In [17]:

import torch.nn.functional as F
from lib.loss import *
from lib.metric import NME, FR_AUC

In [18]:
class Train_Config:
    """Training hyper-parameters and the per-output loss/metric layout.

    Each hourglass stack contributes either three outputs (when ``use_AAM`` is
    true) or a single output, so ``label_num`` is ``nstack * 3`` or ``nstack``.
    ``loss_weights``, ``criterion_labels``, ``metrics`` and ``criterions`` are
    parallel lists with one entry per output.

    Attributes:
        nstack: number of stacks.
        use_AAM: whether each stack has the two extra AWing-supervised outputs.
        label_num: total number of supervised outputs.
        loss_func: criterion label used for the first output of every stack.
        star_w, star_dist: arguments forwarded to the STAR loss constructors.
        loss_weights, criterion_labels, metrics: see ``set_criterions``.
        criterions: loss callables built by ``get_criterions``.
        batch_weight: scalar multiplied into every individual loss.
        optimizer, learn_rate, weight_decay, betas: optimiser settings.
        gamma, milestones: learning-rate schedule settings; milestones are
            the epochs 20, 30, ..., 90.
    """

    def __init__(self):
        self.nstack = 4
        self.use_AAM = True
        self.label_num = self.nstack * 3 if self.use_AAM else self.nstack

        self.loss_func = "STARLoss_v2"

        # STAR Loss paras
        self.star_w = 1
        self.star_dist = 'smoothl1'

        # Both helpers read the attributes assigned above, so they must be
        # called after them.
        self.loss_weights, self.criterion_labels, self.metrics = self.set_criterions()
        self.criterions = self.get_criterions()

        self.batch_weight = 1.0

        self.optimizer = "adam"
        self.learn_rate = 0.00001
        self.weight_decay = 0.00001
        self.betas = [0.9, 0.999]
        self.gamma = 0.9
        self.milestones = [10 * i for i in range(2, 10)]

    def set_criterions(self):
        """Build the per-output loss weights, criterion labels and metric names.

        Stack ``i`` is scaled by ``2**i / 2**(nstack - 1)``: the last stack has
        factor 1 and every earlier stack half the factor of its successor.

        Returns:
            Tuple ``(loss_weights, criterion_labels, metrics)`` of equally long
            lists. ``metrics`` holds ``"NME"`` for outputs that are evaluated
            and ``None`` for the others.
        """
        loss_weights, criterions, metrics = [], [], []
        for i in range(self.nstack):
            factor = (2 ** i) / (2 ** (self.nstack - 1))
            if self.use_AAM:
                loss_weights += [factor * weight for weight in [1.0, 10.0, 10.0]]
                criterions += [self.loss_func, "AWingLoss", "AWingLoss"]
                metrics += ["NME", None, None]
            else:
                loss_weights += [factor * weight for weight in [1.0]]
                criterions += [self.loss_func, ]
                metrics += ["NME", ]
        return loss_weights, criterions, metrics

    def get_criterions(self):
        """Instantiate one loss callable per entry of ``criterion_labels``.

        Returns:
            List of ``label_num`` loss callables, in output order.

        Raises:
            ValueError: if a label is not one of the supported criterion names.
        """
        # Zero-argument factories: a loss class is only looked up and
        # instantiated when its label is actually requested.
        factories = {
            "AWingLoss": lambda: AWingLoss(),
            "smoothl1": lambda: SmoothL1Loss(),
            "l1": lambda: F.l1_loss,
            "l2": lambda: F.mse_loss,
            "STARLoss": lambda: STARLoss(dist=self.star_dist, w=self.star_w),
            "STARLoss_v2": lambda: STARLoss_v2(dist=self.star_dist, w=self.star_w),
        }

        criterions = []
        for k in range(self.label_num):
            label = self.criterion_labels[k]
            if label not in factories:
                # An explicit raise instead of `assert False`: asserts are
                # stripped under `python -O`, which would silently reuse the
                # criterion of the previous iteration.
                raise ValueError(
                    "Unsupported criterion label {!r} at index {}; expected one of {}".format(
                        label, k, sorted(factories)))
            criterions.append(factories[label]())
        return criterions

In [19]:
# Instantiating Train_Config also builds the loss objects (its __init__ calls
# get_criterions), so the loss classes from lib.loss must already be imported.
train_config = Train_Config()

In [20]:
import lightning as L
from lib.metric import NME, FR_AUC

In [21]:

# define the LightningModule
class LitSTAR(L.LightningModule):
    """LightningModule that trains and validates a stacked landmark network.

    Args:
        net: the network; calling it on an image batch must return the triple
            ``(output, heatmaps, landmarks)``.
        data_config: supplies ``label_num``, ``nme_left_index``,
            ``nme_right_index`` and ``data_definition``.
        model_config: supplies ``nstack``.
        train_config: supplies the criteria, loss weights, metric names and
            optimiser / scheduler settings.

    Attributes:
        ave_losses: running per-output sum of the (detached) training losses.
        list_nmes: per-output lists of NME values collected during validation;
            reset at the end of every validation epoch.
    """

    def __init__(self, net, data_config, model_config, train_config):
        super().__init__()
        self.net = net
        self.data_config = data_config
        self.model_config = model_config
        self.train_config = train_config
        self.ave_losses = [0] * self.train_config.label_num
        self.list_nmes = [[] for _ in range(self.data_config.label_num)]

    def _unpack_labels(self, sample):
        """Return the batch labels as a flat list with one entry per output.

        ``sample["label"]`` is either a list of tensors or a single tensor
        that is split along dimension 1. The resulting list is repeated
        ``nstack`` times so that index ``k`` lines up with network output ``k``.
        """
        raw = sample["label"]
        if isinstance(raw, list):
            per_stack = [label.float() for label in raw]
        else:
            stacked = raw.float()
            per_stack = [stacked[:, k] for k in range(stacked.shape[1])]
        return self.model_config.nstack * per_stack

    def training_step(self, sample, batch_idx):
        """Run one optimisation step and return the weighted total loss."""
        imgs = sample["data"].float()
        labels = self._unpack_labels(sample)

        # forward
        output, heatmaps, landmarks = self.net(imgs)

        losses, sum_loss = self.compute_loss(output, labels, heatmaps, landmarks)

        # Accumulate detached values only: summing the live loss tensors would
        # keep the autograd graph of every past step referenced.
        self.ave_losses = [
            running + (loss.detach() if torch.is_tensor(loss) else loss)
            for running, loss in zip(self.ave_losses, losses)
        ]
        avg_loss = sum(losses) / len(losses)

        # Logging to TensorBoard (if installed) by default
        self.log("sum_loss", sum_loss, prog_bar=True)
        self.log("AVG_loss", avg_loss, prog_bar=True)
        return sum_loss

    def validation_step(self, sample, batch_idx):
        """Collect the per-sample NME of every evaluated output for this batch."""
        metric_nme = NME(nme_left_index=self.data_config.nme_left_index,
                         nme_right_index=self.data_config.nme_right_index)

        imgs = sample["data"].float()
        labels = self._unpack_labels(sample)

        # forward
        output, heatmaps, landmarks = self.net(imgs)

        for k in range(self.data_config.label_num):
            # Outputs whose metric entry is None are not evaluated.
            if self.train_config.metrics[k] is not None:
                self.list_nmes[k] += metric_nme.test(output[k], labels[k])

    def on_validation_epoch_end(self):
        """Log NME / FR / AUC for every evaluated output, then reset the buffers."""
        metric_fr_auc = FR_AUC(data_definition=self.data_config.data_definition)
        # metric_fr_auc.test(...) is concatenated to [mean NME] and the result
        # is unpacked into three values below, so it is expected to return a
        # two-element list (FR, AUC).
        metrics = [
            [torch.mean(torch.tensor(nmes)), ] + metric_fr_auc.test(torch.tensor(nmes))
            for nmes in self.list_nmes
        ]

        for k, metric in enumerate(metrics):
            nme, fr, auc = metric
            # Outputs without collected NMEs have an empty list whose mean is
            # NaN; they are skipped.
            if not torch.isnan(nme):
                # Same output-index -> stack mapping as in compute_loss.
                stack_no = k // 3 if self.train_config.use_AAM else k
                self.log(f"Stack{stack_no}_NME", nme, on_epoch=True)
                self.log(f"Stack{stack_no}_FR", fr, on_epoch=True)
                self.log(f"Stack{stack_no}_AUC", auc, on_epoch=True)

        self.list_nmes = [[] for _ in range(self.data_config.label_num)]

    def compute_loss(self, output, labels, heatmap=None, landmarks=None):
        """Evaluate every configured criterion.

        Args:
            output: per-output predictions, indexed like ``labels``.
            labels: per-output targets.
            heatmap: per-stack heatmaps, consumed by the STAR losses.
            landmarks: accepted for interface compatibility; unused here.

        Returns:
            Tuple ``(losses, sum_loss)``: the list of individual losses (each
            scaled by ``batch_weight``) and their sum weighted by
            ``loss_weights``.

        Raises:
            NotImplementedError: if a criterion label is not handled here.
        """
        batch_weight = self.train_config.batch_weight
        sum_loss = 0
        losses = list()
        for k in range(self.train_config.label_num):
            label = self.train_config.criterion_labels[k]
            criterion = self.train_config.criterions[k]

            if label in ['smoothl1', 'l1', 'l2', 'WingLoss', 'AWingLoss']:
                loss = criterion(output[k], labels[k])
            elif label in ["STARLoss", "STARLoss_v2"]:
                # STAR losses take the heatmap of the stack that output k
                # belongs to (three outputs per stack when AAM is enabled).
                _k = int(k / 3) if self.train_config.use_AAM else k
                loss = criterion(heatmap[_k], labels[k])
            else:
                # `assert NotImplementedError` is always truthy and would let
                # a stale or undefined `loss` through; fail explicitly instead.
                raise NotImplementedError(
                    "No loss dispatch for criterion label {!r} (output {})".format(label, k))
            loss = batch_weight * loss
            sum_loss += self.train_config.loss_weights[k] * loss
            losses.append(loss)
        return losses, sum_loss

    def configure_optimizers(self):
        """Create the optimiser selected in ``train_config`` plus a MultiStepLR schedule.

        Raises:
            ValueError: if ``train_config.optimizer`` is not "sgd", "adam" or
                "rmsprop".
        """
        params = self.net.parameters()
        name = self.train_config.optimizer

        if name == "sgd":
            optimizer = torch.optim.SGD(
                params,
                lr=self.train_config.learn_rate,
                momentum=self.train_config.momentum,
                weight_decay=self.train_config.weight_decay,
                nesterov=self.train_config.nesterov)
        elif name == "adam":
            # NOTE(review): only the learning rate is forwarded here; any
            # weight_decay / betas defined on train_config are not used.
            optimizer = torch.optim.Adam(
                params,
                lr=self.train_config.learn_rate)
        elif name == "rmsprop":
            optimizer = torch.optim.RMSprop(
                params,
                lr=self.train_config.learn_rate,
                momentum=self.train_config.momentum,
                alpha=self.train_config.alpha,
                eps=self.train_config.epsilon,
                weight_decay=self.train_config.weight_decay
            )
        else:
            # Explicit raise: an `assert False` is stripped under `python -O`
            # and would pass optimizer=None on to the scheduler.
            raise ValueError("Unsupported optimizer: {!r}".format(name))

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.train_config.milestones,
            gamma=self.train_config.gamma)
        config_dict = {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler
            }
        }
        return config_dict


# Wrap the WFLW network, together with its data / model / training configs,
# in the LightningModule.
STAR = LitSTAR(net_WFLW, data_config, config_WFLW, train_config)

In [22]:
from lightning.pytorch.callbacks import LearningRateMonitor
# Callback that logs the current learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

In [23]:
# Checkpoint of an earlier run to resume from; set to None to start from the
# freshly built STAR module instead.
ckpt_path = 'lightning_logs/version_28/checkpoints/epoch=47-step=22512.ckpt'
if ckpt_path is not None:
    # The constructor arguments are passed explicitly as keyword arguments;
    # LitSTAR.__init__ does not call save_hyperparameters, so they are
    # presumably not stored in the checkpoint -- confirm.
    STAR = LitSTAR.load_from_checkpoint(ckpt_path,
    net = net_WFLW, 
    data_config = data_config, 
    model_config = config_WFLW,
    train_config = train_config)

In [26]:
# Take the logger from the same `lightning` distribution that provides
# `L.Trainer`, `L.LightningModule` and `LearningRateMonitor`. The standalone
# `pytorch_lightning` package defines separate classes that Lightning's own
# isinstance checks do not recognise.
from lightning.pytorch.loggers import TensorBoardLogger

# With the logger's default name ("lightning_logs") this writes to
# ./lightning_logs/version_28, i.e. the run the checkpoint was taken from.
log_dir = "."
version = "version_28"
logger = TensorBoardLogger(save_dir=log_dir, version=version)

In [27]:
# Train for up to 100 epochs, logging on every step, with the learning-rate
# monitor callback and the TensorBoard logger defined above.
trainer = L.Trainer(
    max_epochs=100,
    limit_train_batches = None,
    log_every_n_steps=1,
    callbacks=[lr_monitor],
    logger=logger
    )
# Passing ckpt_path makes fit() restore the saved training state before
# continuing (see "Restoring states from the checkpoint path" in the log).
trainer.fit(
    model=STAR, 
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    ckpt_path=ckpt_path
    )

/home/groups/sammer/haogeh/Python/asymm/STAR/.venv/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/groups/sammer/haogeh/Python/asymm/STAR/.venv/l ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/groups/sammer/haogeh/Python/asymm/STAR/.venv/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory ./lightning_logs/version_28/checkpoints exists and is not empty.
Restoring states from the checkpoint path at lightning_logs/version_28/checkpoints/epoch=47-step=22512.ckpt
/home/groups/sammer/haogeh/Python/asymm/STAR/.venv/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:362: The dirpath has changed from '/home/groups/sammer/haogeh/Py

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]