In [None]:
# Steps partially taken from https://debuggercafe.com/advanced-facial-keypoint-detection-with-pytorch/

import os

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from sklearn.model_selection import train_test_split
import timm


%matplotlib inline
%config Completer.use_jedi = False

# Project root. Override it with the HAT_ON_THE_HEAD_ROOT environment variable on other machines.
# os.path.join(..., "") guarantees a trailing separator, because the paths below are built by
# plain string concatenation.
ROOT = os.path.join(os.environ.get("HAT_ON_THE_HEAD_ROOT", "/home/lenin/code/hat_on_the_head/"), "")
DATA = ROOT + "data/"
KP_DATA = DATA + "kaggle_keypoints/"
RANDOM_SEED = 42

In [None]:
def load(path):
    """Read an image file and return it converted from OpenCV's BGR order to RGB.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing file or unsupported format).
    """
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising. Without this check
    # the problem surfaces later as an opaque cv2.error from cvtColor.
    if img is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def load_item(root, row):
    """Load the image and target keypoint for one dataset row.

    ``row.img`` is a file name relative to ``root``. ``row.our_kpts`` holds "x,y" pairs
    separated by ";". Every pair is parsed, but only the first one is returned, wrapped in a
    one-element list.
    """
    image = load(root + row.img)
    pairs = [list(map(float, pair.split(","))) for pair in row.our_kpts.split(";")]
    return image, pairs[:1]

def show_w_kpts(img, kpts, figsize=(10, 10)):
    """Display ``img`` with keypoints drawn as blue dots.

    Args:
        img: any image ``plt.imshow`` accepts (e.g. HxWx3 uint8, or floats in [0, 1]).
        kpts: either rows of keypoints, where the first two columns are (x, y), or a flat
            ``[x1, y1, x2, y2, ...]`` sequence such as the model's raw output. May be empty.
        figsize: matplotlib figure size in inches.
    """
    plt.figure(figsize=figsize)
    plt.imshow(img)
    pts = np.asarray(kpts, dtype=float)
    # 2-D input: keep (x, y) columns. Flat (or empty) input: regroup into (x, y) pairs.
    pts = pts[:, :2] if pts.ndim == 2 else pts.reshape(-1, 2)
    if len(pts):
        # One vectorized call instead of one Line2D per keypoint.
        plt.plot(pts[:, 0], pts[:, 1], 'b.')
    plt.show()

raw_df = pd.read_csv(KP_DATA + "training_frames_keypoints.csv")

# (x, y) column pairs to keep as targets. The CSV has numeric columns "0".."135" (dropped
# below); presumably they are 68 interleaved x/y keypoints. Only the pair "78"/"79" is used.
# Append e.g. ("84", "85") to track another keypoint.
KPT_COLUMNS = [("78", "79")]


def join_kpts(row):
    """Serialize the selected keypoints of one CSV row as "x1,y1;x2,y2;..."."""
    return ";".join(f"{row[x]},{row[y]}" for x, y in KPT_COLUMNS)


# Final frame: `img` (file name taken from the unnamed index column) and `our_kpts`.
# The raw frame is left untouched so re-running this cell is idempotent.
df = (
    raw_df
    .assign(img=raw_df["Unnamed: 0"], our_kpts=raw_df.apply(join_kpts, axis=1))
    .drop(columns=[str(i) for i in range(136)] + ["Unnamed: 0"])
)
# df.to_csv(DATA + "our_train_kaggle_keypoints.csv", index=False)

print(f"total images {len(df)}")
df.head()

In [None]:
H = 224
W = 224

# ImageNet statistics (the defaults of A.Normalize), written out so `denormalize` can invert them.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline. The augmentations are currently disabled (kept commented for experiments),
# so it behaves like the validation pipeline. ToTensorV2 is applied inside KeypointsDataset,
# so both pipelines return HxWx3 float arrays.
tfms_train = A.Compose([
#     A.LongestMaxSize(448),
#     A.ShiftScaleRotate(border_mode=0, value=0, shift_limit=0.4, scale_limit=0.3, p=0.8),
#     A.RandomBrightnessContrast(p=0.2),
#     A.CLAHE(),
#     A.RandomCrop(320, 320),
    A.Resize(H, W),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
], keypoint_params=A.KeypointParams(format='xy'))
tfms_valid = A.Compose([
    A.Resize(H, W),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
], keypoint_params=A.KeypointParams(format='xy'))


def denormalize(img):
    """Invert A.Normalize(IMAGENET_MEAN, IMAGENET_STD) for display; returns floats clipped to [0, 1]."""
    return np.clip(img * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN), 0.0, 1.0)


# Preview one reproducibly chosen training item after the training transforms.
sample = df.sample(1, random_state=RANDOM_SEED).iloc[0]
orig_img, orig_kpts = load_item(KP_DATA + "training/", sample)

res = tfms_train(image=orig_img, keypoints=orig_kpts)
img_tfmd = res["image"]
kpts_tfmd = res["keypoints"]
# The normalized image falls outside imshow's [0, 1] float range; undo the normalization to view it.
show_w_kpts(denormalize(img_tfmd), kpts_tfmd)

In [None]:
class KeypointsDataset(torch.utils.data.Dataset):
    """Dataset of (image_tensor, keypoints) pairs for keypoint regression.

    Each item is the ``aug``-transformed image converted by ``ToTensorV2``, plus a flat
    float32 tensor ``[x1, y1, x2, y2, ...]``. The coordinates are divided by the transformed
    image's width/height, so visible points fall in [0, 1].

    Args:
        root: directory prefix prepended to ``row.img`` when loading images.
        df: frame with ``img`` and ``our_kpts`` columns (see ``load_item``).
        aug: albumentations pipeline called as ``aug(image=..., keypoints=...)``. Defaults to
            an identity ``Compose`` with ``KeypointParams(format='xy')``.
        max_retries: extra attempts allowed when ``aug`` drops every keypoint before an
            error is raised.
    """

    def __init__(self, root, df, aug=None, max_retries=100):
        self.root = root
        self.df = df
        # Build the default per instance (not once at class-definition time). Give it
        # KeypointParams, since albumentations needs them to process `keypoints=`.
        self.aug = aug if aug is not None else A.Compose([], keypoint_params=A.KeypointParams(format='xy'))
        self.max_retries = max_retries
        self.to_tensor = ToTensorV2()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        orig_img, orig_kpts = load_item(self.root, row)
        # Random augmentations can push the keypoint out of frame, and it is then dropped.
        # Retry, but with a bound: a deterministic pipeline (e.g. validation) would otherwise
        # loop forever on a label that lies outside the image.
        for _ in range(self.max_retries + 1):
            res = self.aug(image=orig_img, keypoints=orig_kpts)
            if len(res["keypoints"]) >= 1:
                break
        else:
            raise RuntimeError(
                f"augmentation dropped all keypoints of {row.img!r} "
                f"in {self.max_retries + 1} attempts"
            )
        img = res["image"]
        # Normalize by the actual transformed size (equal to W x H when the pipeline resizes).
        img_h, img_w = img.shape[:2]
        kpts = np.asarray(res["keypoints"], dtype=np.float64) / np.array([img_w, img_h])
        img_tensor = self.to_tensor(image=img)["image"]
        return img_tensor, torch.as_tensor(kpts.reshape(-1), dtype=torch.float32)

In [None]:
# Hold out 15% of the images for validation; the fixed seed makes the shuffle reproducible.
train_df, val_df = train_test_split(df, random_state=RANDOM_SEED, shuffle=True, test_size=0.15)

In [None]:
batch_size = 64
num_workers = 0  # load batches in the main process

# Both splits read from the same image folder; only the transform pipeline differs.
train_ds = KeypointsDataset(root=KP_DATA + "training/", df=train_df, aug=tfms_train)
val_ds = KeypointsDataset(root=KP_DATA + "training/", df=val_df, aug=tfms_valid)

# Loader settings shared by both splits; only the training loader reshuffles each epoch.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_dl = torch.utils.data.DataLoader(dataset=train_ds, shuffle=True, **loader_kwargs)
val_dl = torch.utils.data.DataLoader(dataset=val_ds, shuffle=False, **loader_kwargs)

In [None]:
# Pretrained EfficientNet-B0 whose classification head is replaced by a freshly initialized
# linear layer with 2 outputs: the (x, y) of the single target keypoint.
model = timm.create_model('efficientnet_b0', pretrained=True)
model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=2)

In [None]:
import pytorch_lightning as pl
from torchmetrics import functional as metrics
# Seed the global RNGs (Lightning helper) for the rest of the run.
# NOTE(review): this executes after the model cell above has created the new classifier head,
# so that head's random initialization is not controlled by this seed. Seeding in the config
# cell would make the whole notebook reproducible.
pl.seed_everything(RANDOM_SEED)


In [None]:
class HatModule(pl.LightningModule):
    """Lightning wrapper that trains ``model`` to regress keypoint coordinates.

    Args:
        model: network whose output has the same shape as the target keypoint batch.
        optimizer_name: "Adam" (builds AdamW) or "SGD".
        optimizer_hparams: keyword arguments forwarded to the optimizer, e.g. {"lr": 1e-3}.
    """

    def __init__(self, model, optimizer_name, optimizer_hparams):
        super().__init__()
        # Records all init args in self.hparams and in checkpoints (exported to hparams.yaml).
        # NOTE(review): this includes the `model` nn.Module itself. It is kept so that
        # checkpoints stay loadable without passing `model` explicitly.
        self.save_hyperparameters()
        self.model = model
        # SmoothL1 loss between predicted and target coordinates.
        self.loss_module = torch.nn.SmoothL1Loss()
        # Example input for visualizing the graph in Tensorboard
        # self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

    def forward(self, imgs):
        """Return the raw model predictions for a batch of images."""
        return self.model(imgs)

    def configure_optimizers(self):
        """Build the optimizer named in hparams together with a MultiStepLR schedule.

        Raises:
            ValueError: if ``optimizer_name`` is neither "Adam" nor "SGD".
        """
        name = self.hparams.optimizer_name
        if name == "Adam":
            # AdamW is Adam with decoupled weight decay (https://arxiv.org/pdf/1711.05101.pdf).
            optimizer = torch.optim.AdamW(self.model.parameters(), **self.hparams.optimizer_hparams)
        elif name == "SGD":
            optimizer = torch.optim.SGD(self.model.parameters(), **self.hparams.optimizer_hparams)
        else:
            # An explicit raise instead of `assert False`, which is stripped under `python -O`.
            raise ValueError(f'Unknown optimizer: "{name}"')

        # Multiply the learning rate by 0.1 after epochs 100 and 150.
        # NOTE(review): the Trainer in this notebook stops at 30 epochs, so this never fires.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; Lightning calls .backward() on the returned tensor."""
        imgs, keypoints = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, keypoints)
        self.log("train_loss", loss, prog_bar=True, on_step=True)
        return loss

    def _shared_eval_step(self, batch, stage):
        """Compute and log RMSE and loss for one evaluation batch under `{stage}_*` names."""
        imgs, keypoints = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, keypoints)
        rmse = metrics.mean_squared_error(preds, keypoints, squared=False)
        # Epoch-level aggregation only (Lightning's default for eval loops: weighted average
        # over batches). Per-step eval values are noisy and duplicate the metric as *_step.
        self.log(f"{stage}_rmse", rmse, prog_bar=True)
        self.log(f"{stage}_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        self._shared_eval_step(batch, "val")

    def test_step(self, batch, batch_idx):
        # Logged as test_* so test results are not confused with validation metrics.
        self._shared_eval_step(batch, "test")

In [None]:
# Use the first GPU when one is available; otherwise fall back to CPU instead of failing.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

trainer = pl.Trainer(
    # Single GPU if present. NOTE(review): `gpus` is deprecated in newer Lightning releases
    # (accelerator/devices replace it); kept to match the installed version.
    gpus=1 if device.startswith("cuda") else 0,
    # Fixed training length (no early stopping is configured).
    max_epochs=30,
    callbacks=[
        # Keep the checkpoint with the lowest val_loss; weights only, no optimizer state.
        pl.callbacks.ModelCheckpoint(
            save_weights_only=True, mode="min", monitor="val_loss", verbose=True,
        ),
        pl.callbacks.LearningRateMonitor("epoch"),
    ],
)

module = HatModule(model, 'Adam', {"lr": 0.001})

In [None]:
# Train `module` on the training loader and validate on the validation loader. The
# ModelCheckpoint callback configured above keeps the weights with the lowest val_loss.
trainer.fit(module, train_dataloaders=train_dl, val_dataloaders=val_dl)

In [None]:
# Checkpoint from an earlier run (lightning_logs/version_18). After training in this session,
# `trainer.checkpoint_callback.best_model_path` points at this run's best checkpoint instead.
ckpt = ROOT + "notebooks/lightning_logs/version_18/checkpoints/epoch=6-step=321.ckpt"
# load_from_checkpoint is a classmethod that RETURNS a new module; it does not load weights
# in place, so the result must be assigned. Passing `model` supplies the backbone explicitly
# instead of relying on it being pickled into the checkpoint's hyperparameters, and
# map_location lets a GPU-saved checkpoint load on the current device.
module = HatModule.load_from_checkpoint(ckpt, model=model, map_location=device)

In [None]:
# Switch to inference mode (dropout off, batch-norm uses its running statistics) before predicting.
module.eval()

In [None]:
# Sanity check: the first validation image with its ground-truth keypoint.
sample = val_df.iloc[0]
orig_img, orig_kpts = load_item(KP_DATA + "training/", sample)

# Validation-pipeline view of the same item (resized + normalized), kept for inspection.
res = tfms_valid(image=orig_img, keypoints=orig_kpts)
img_tfmd, kpts_tfmd = res["image"], res["keypoints"]

# Plot the untouched image, because the normalized `img_tfmd` is not directly displayable.
show_w_kpts(orig_img, orig_kpts)

In [None]:
# Run the model on an arbitrary image and draw the predicted keypoint on it.
TEST_IMG_PATH = "/home/lenin/img1.png"

# Same preprocessing as validation (resize to the training resolution + normalization), plus
# the HWC -> CHW tensor conversion that KeypointsDataset normally applies.
tfms_test = A.Compose([
    A.Resize(H, W),
    A.Normalize(),
    ToTensorV2(),
])

img = load(TEST_IMG_PATH)
img_h, img_w = img.shape[:2]
batch = tfms_test(image=img)["image"].unsqueeze(0)

module.eval()
with torch.no_grad():
    # Use the module's own device rather than assuming `device`: after fit or checkpoint
    # loading, its weights may not live on `device`.
    out = module(batch.to(module.device))

# Training targets were divided by the (resized) image width/height, so predictions are
# fractions of the image size. Multiply by the ORIGINAL size to draw on the unresized image.
# reshape(-1, 2) groups the flat output into (x, y) pairs, whatever the number of keypoints.
kpts = out.cpu().numpy()[0].reshape(-1, 2) * np.array([img_w, img_h])
show_w_kpts(img, kpts)