# Dog Keypoints - Save to Google Drive

In [None]:
# Mount Google Drive at /content/drive so that later cells can write into it.
from google.colab import drive
drive.mount('/content/drive')
# Output directory on the mounted drive; later cells write the checkpoint,
# the training curve and the metrics file here.
SAVE_DIR = '/content/drive/MyDrive/dogfacs_keypoints'
# IPython shell escape: `-p` creates missing parents and does not fail when
# the directory already exists.
!mkdir -p {SAVE_DIR}
print(f"Files will be saved to: {SAVE_DIR}")

In [None]:
# IPython shell escape: install timm and albumentations, which are imported
# further down; -q reduces pip's output.
!pip install -q timm albumentations

In [None]:
import os
from getpass import getpass

# The Kaggle download in the next cell authenticates through the
# KAGGLE_USERNAME / KAGGLE_KEY environment variables.  They are requested at
# run time instead of being written into the notebook: a key stored in source
# travels with every copy and revision of the file.
# NOTE(review): the API key that used to be hard-coded in this cell must be
# treated as leaked and revoked in the Kaggle account settings.
for _name, _prompt in (
    ("KAGGLE_USERNAME", "Kaggle username: "),
    ("KAGGLE_KEY", "Kaggle API key: "),
):
    if os.environ.get(_name):
        # Already provided by the runtime; keep it.
        continue
    try:
        _value = getpass(_prompt).strip()
    except (EOFError, OSError):
        # No interactive input is available (e.g. a headless run).
        _value = ""
    if _value:
        os.environ[_name] = _value
    else:
        print(f"{_name} is not set; the Kaggle download will fail until it is provided.")

In [None]:
# IPython shell escape: fetch the georgemartvel/dogflw dataset into
# /content/data and unzip it there (the later DATA path points inside it);
# --force downloads again even if a copy is present.
!kaggle datasets download -d georgemartvel/dogflw -p /content/data --unzip --force

In [None]:
import json, random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import timm
from PIL import Image
from torchvision import transforms
import albumentations as A
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Seed the three RNGs used by this notebook (Python, NumPy, PyTorch).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when one is attached, otherwise fall back to the CPU so that
# the later `.to(device)` calls do not fail on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Root folder of the unpacked dataset (contains one sub-folder per split).
DATA = Path("/content/data/DogFLW")
# Number of landmarks kept per image.
NUM_KP = 46

In [None]:
def load_split(name, root=None, num_kp=None):
    """Collect the usable (image, landmarks) samples of one dataset split.

    The split is expected at ``root/name`` with an ``images`` and a ``labels``
    folder; every ``labels/<stem>.json`` is paired with ``images/<stem>.<ext>``.

    Args:
        name: Split folder name, e.g. ``"train"`` or ``"test"``.
        root: Dataset root. Defaults to the module-level ``DATA``.
        num_kp: Number of leading landmarks to keep. Defaults to ``NUM_KP``.

    Returns:
        A list of ``{"img": <path string>, "kp": float32 array (num_kp, 2)}``
        dicts, ordered by label file name.

    A sample is left out when its image is missing, its label file cannot be
    read or parsed, or its landmarks are not exactly ``num_kp`` finite,
    non-negative (x, y) pairs.
    """
    root = DATA if root is None else Path(root)
    num_kp = NUM_KP if num_kp is None else num_kp
    image_dir = root / name / "images"
    label_dir = root / name / "labels"
    extensions = (".png", ".jpg", ".PNG", ".JPG")

    samples = []
    rejected = 0
    # glob order depends on the filesystem; sorting keeps the sample order
    # (and any seeded shuffle applied to the result) identical between runs.
    for label_path in sorted(label_dir.glob("*.json")):
        candidates = (image_dir / f"{label_path.stem}{ext}" for ext in extensions)
        image_path = next((c for c in candidates if c.exists()), None)
        if image_path is None:
            continue
        try:
            with open(label_path) as f:
                record = json.load(f)
            kp = np.array(record["landmarks"][:num_kp], dtype=np.float32)
        except (OSError, ValueError, KeyError, TypeError, IndexError):
            # Unreadable file, invalid JSON, or an unexpected label layout.
            rejected += 1
            continue
        # Too few landmarks would only surface later as an IndexError while
        # rendering heatmaps, so the shape is checked here.
        if kp.shape != (num_kp, 2) or not np.isfinite(kp).all() or (kp < 0).any():
            rejected += 1
            continue
        samples.append({"img": str(image_path), "kp": kp})

    print(f"{name}: {len(samples)}")
    if rejected:
        print(f"{name}: skipped {rejected} label file(s) with unreadable or invalid landmarks")
    return samples

# Load both splits, then hold out the last 10% of the shuffled training
# samples for validation.
train_raw = load_split("train")
test_data = load_split("test")
random.shuffle(train_raw)
cut = int(0.9 * len(train_raw))
train_data, val_data = train_raw[:cut], train_raw[cut:]

In [None]:
import warnings


class DS(Dataset):
    """Keypoint dataset yielding (normalized image, heatmaps, keypoints).

    Each item of ``data`` is a dict with an ``"img"`` path and a ``"kp"``
    float32 array of (x, y) rows in original-image pixels.  Images are resized
    to IMG_SIZE x IMG_SIZE, keypoints are rescaled to match, and one
    HM_SIZE x HM_SIZE heatmap per keypoint is rendered.
    """

    IMG_SIZE = 256   # side length of the network input
    HM_SIZE = 64     # side length of each target heatmap
    HM_RADIUS = 3    # the blob around a keypoint covers +/- this many cells

    # (2R+1) x (2R+1) blob exp(-(dy^2 + dx^2) / 8), built once instead of
    # being recomputed cell by cell for every keypoint of every sample.
    _OFFSETS = np.arange(-HM_RADIUS, HM_RADIUS + 1)
    _KERNEL = np.exp(-(_OFFSETS[:, None] ** 2 + _OFFSETS[None, :] ** 2) / 8).astype(np.float32)

    def __init__(self, data, aug=False):
        """Store the samples and build the tensor / augmentation pipelines.

        Args:
            data: List of ``{"img": path, "kp": array}`` samples.
            aug: When true, apply random flip / rotation / brightness-contrast.
        """
        self.data = data
        self.tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # NOTE(review): HorizontalFlip mirrors the coordinates but keeps the
        # landmark indices.  If the landmark scheme has left/right pairs, a
        # flipped sample carries swapped labels -- confirm, and remap the
        # indices after flipping if so.
        self.aug = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=15, p=0.5),
                A.RandomBrightnessContrast(p=0.3),
            ],
            keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
        ) if aug else None

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _render_heatmaps(kp, num_kp):
        """Return a (num_kp, HM_SIZE, HM_SIZE) float32 stack of keypoint blobs.

        ``kp`` holds (x, y) rows in IMG_SIZE pixel coordinates.  Each keypoint
        is truncated to a heatmap cell and the blob is pasted around it,
        cropped at the borders.
        """
        size, radius = DS.HM_SIZE, DS.HM_RADIUS
        stride = DS.IMG_SIZE // DS.HM_SIZE
        hm = np.zeros((num_kp, size, size), dtype=np.float32)
        for k in range(num_kp):
            cx = max(0, min(size - 1, int(kp[k, 0] / stride)))
            cy = max(0, min(size - 1, int(kp[k, 1] / stride)))
            y0, y1 = max(cy - radius, 0), min(cy + radius + 1, size)
            x0, x1 = max(cx - radius, 0), min(cx + radius + 1, size)
            hm[k, y0:y1, x0:x1] = DS._KERNEL[
                y0 - cy + radius:y1 - cy + radius,
                x0 - cx + radius:x1 - cx + radius,
            ]
        return hm

    def __getitem__(self, i):
        """Return ``(image tensor, heatmap tensor, keypoint tensor)`` for sample ``i``."""
        d = self.data[i]
        # The context manager closes the file handle once the pixels are decoded.
        with Image.open(d["img"]) as src:
            im = src.convert("RGB")
        ow, oh = im.size
        kp = d["kp"].copy()  # never modify the array stored in self.data
        side = self.IMG_SIZE
        im = im.resize((side, side))
        kp[:, 0] *= side / ow
        kp[:, 1] *= side / oh
        kp = np.clip(kp, 0, side - 1)
        if self.aug is not None:
            try:
                r = self.aug(image=np.array(im), keypoints=[(k[0], k[1]) for k in kp])
                aug_im = Image.fromarray(r["image"])
                aug_kp = np.clip(np.array(r["keypoints"], dtype=np.float32), 0, side - 1)
                if aug_kp.shape != kp.shape:
                    raise ValueError(f"augmentation returned keypoints of shape {aug_kp.shape}")
            except Exception as err:
                # Best effort: keep the un-augmented sample, but say so.
                warnings.warn(f"augmentation failed for {d['img']}: {err!r}; using the original sample")
            else:
                # Commit both together so image and keypoints cannot diverge.
                im, kp = aug_im, aug_kp
        hm = self._render_heatmaps(kp, NUM_KP)
        return self.tf(im), torch.from_numpy(hm), torch.from_numpy(kp)

# Datasets: only the training samples are augmented.
train_ds = DS(train_data, aug=True)
val_ds = DS(val_data, aug=False)
test_ds = DS(test_data, aug=False)

# Loaders: batches of 32 loaded in the main process; only training shuffles.
train_ld = DataLoader(dataset=train_ds, batch_size=32, shuffle=True, num_workers=0)
val_ld = DataLoader(dataset=val_ds, batch_size=32, shuffle=False, num_workers=0)

In [None]:
class Model(nn.Module):
    """Heatmap regressor: ResNet-50 features, three upsampling stages, 1x1 output conv.

    Every transposed convolution (kernel 4, stride 2, padding 1) doubles the
    spatial resolution; the final 1x1 convolution emits NUM_KP channels, one
    heatmap per keypoint.
    """

    def __init__(self):
        super().__init__()
        # Pretrained backbone returning only its last feature map.
        self.bb = timm.create_model(
            "resnet50",
            pretrained=True,
            features_only=True,
            out_indices=[-1],
        )
        # Three identical deconv -> batch norm -> ReLU stages; only the first
        # one consumes the 2048-channel backbone output.
        stages = []
        channels_in = 2048
        for _ in range(3):
            stages.append(nn.ConvTranspose2d(channels_in, 256, 4, 2, 1))
            stages.append(nn.BatchNorm2d(256))
            stages.append(nn.ReLU(True))
            channels_in = 256
        self.head = nn.Sequential(*stages)
        self.out = nn.Conv2d(256, NUM_KP, 1)

    def forward(self, x):
        """Map an image batch to one heatmap per keypoint."""
        features = self.bb(x)[-1]
        upsampled = self.head(features)
        return self.out(upsampled)

# Build the network and move its parameters to the selected device.
model=Model().to(device)
def decode(hm, stride=4):
    """Convert heatmaps to keypoint coordinates by taking each map's peak.

    Args:
        hm: Tensor of shape (B, K, H, W), one heatmap per keypoint.
        stride: Factor from heatmap cells to image pixels (default 4).

    Returns:
        A float CPU tensor of shape (B, K, 2) holding (x, y) of every peak,
        scaled by ``stride``.
    """
    B, K, H, W = hm.shape
    # One argmax over the flattened maps replaces a Python loop over B*K maps.
    peak = hm.reshape(B, K, H * W).argmax(dim=-1)
    xs = peak % W
    ys = torch.div(peak, W, rounding_mode="floor")
    return (torch.stack((xs, ys), dim=-1) * stride).float().cpu()
def pck(p, g, th=0.1):
    """Return the fraction of keypoints whose error is below ``th * 256``.

    ``p`` and ``g`` are predicted and ground-truth coordinates with the
    coordinate pair on the last axis; the error is their Euclidean distance.
    The result is a Python float.
    """
    squared_error = ((p - g) ** 2).sum(-1)
    distance = torch.sqrt(squared_error)
    within = distance < th * 256
    return within.float().mean().item()

In [None]:
# One constant drives both the loop length and the cosine schedule horizon,
# so the two cannot drift apart.
EPOCHS = 50

crit = nn.MSELoss()
opt = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
sch = optim.lr_scheduler.CosineAnnealingLR(opt, EPOCHS)
# Start below any reachable score so the first epoch always writes a
# checkpoint; otherwise a run stuck at 0% would leave a missing or stale
# keypoints_best.pt in SAVE_DIR.
best = float("-inf")
hist = []  # validation PCK per epoch

for ep in range(EPOCHS):
    # ---- train for one epoch on the heatmap regression loss ----
    model.train()
    for x, hm, _ in tqdm(train_ld, leave=False, desc=f"Ep{ep+1}"):
        x, hm = x.to(device), hm.to(device)
        opt.zero_grad()
        loss = crit(model(x), hm)
        loss.backward()
        opt.step()

    # ---- validate ----
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for x, hm, gt in val_ld:
            preds.append(decode(model(x.to(device)).cpu()))
            targets.append(gt)
    sch.step()
    # Score all validation keypoints in one call: averaging per-batch scores
    # would over-weight the smaller last batch.
    vp = pck(torch.cat(preds), torch.cat(targets)) if preds else float("nan")
    hist.append(vp)

    m = ""
    if vp > best:
        best = vp
        torch.save(model.state_dict(), f"{SAVE_DIR}/keypoints_best.pt")
        m = " * SAVED TO DRIVE"
    print(f"Ep{ep+1:2d} PCK:{vp*100:.1f}%{m}")
print(f"\nBest: {best*100:.1f}%")

In [None]:
# Validation PCK per epoch, in percent, with a dashed red reference line at 75.
fig, ax = plt.subplots()
ax.plot([score * 100 for score in hist])
ax.axhline(75, color="r", ls="--")
ax.set_title(f"Best PCK: {best*100:.1f}%")
# Keep a copy of the figure next to the checkpoint before displaying it.
fig.savefig(f"{SAVE_DIR}/curve.png")
plt.show()

In [None]:
# Evaluate the best checkpoint on the held-out test split.
test_ld = DataLoader(test_ds, batch_size=32, num_workers=0)
# map_location lets a checkpoint written on a GPU runtime load on whatever
# device this session uses.
model.load_state_dict(torch.load(f"{SAVE_DIR}/keypoints_best.pt", map_location=device))
model.eval()
preds, targets = [], []
with torch.no_grad():
    for x, _, gt in test_ld:
        preds.append(decode(model(x.to(device)).cpu()))
        targets.append(gt)
# Score all test keypoints in one call: averaging per-batch scores would
# over-weight the smaller last batch.
test_pck = pck(torch.cat(preds), torch.cat(targets)) if preds else float("nan")
print("="*40)
print(f"TEST PCK: {test_pck*100:.1f}%")
print("="*40)

# Persist both scores as plain floats next to the checkpoint.
with open(f"{SAVE_DIR}/metrics.json", "w") as f:
    json.dump({"best_val_pck": float(best), "test_pck": float(test_pck)}, f)
print(f"\nSaved to Google Drive: {SAVE_DIR}")
# IPython shell escape: list the output directory to show what was written.
!ls -la {SAVE_DIR}