In [1]:
#!L
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from anchorless_det.model import CenterNet
from anchorless_det.train import train_cycle
from anchorless_det.face_dataset import dataset
from anchorless_det.utils import make_bboxes

In [7]:
#!L
# Seed torch (CPU and CUDA) and numpy's global RNG so randomly initialised layers
# and any later global-RNG use (e.g. DataLoader shuffling) are reproducible.
torch.manual_seed(0)
np.random.seed(0)
# Choose the device explicitly and move the model there *before* an optimizer is
# constructed from its parameters, as the torch.optim documentation recommends.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): the positional 2 is presumably the number of classes / heatmap
# channels — confirm against CenterNet's signature.
model = CenterNet(2, resnet=50, pretrained=True).to(DEVICE)

In [8]:
#!L
# Adam optimizer over every parameter of the model, with a fixed learning rate.
learning_rate = 1e-4
opt = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [9]:
#!L
# Split the face images into train/test parts, then wrap each part in a
# CenterFaceDataset sharing the same `df`. The keyword names indicate the
# HELEN, FG-NET and CelebA image folders; dataset_limit presumably caps the
# number of samples used (confirm in center_face_train_test_split).
# Only the test dataset is built with bin_mask=True.
split_kwargs = dict(
    helen_path='./data/helen/helen_1',
    fgnet_path='./data/fg_net/images',
    celeba_path='./data/celeba/img_align_celeba',
    dataset_limit=50000,
)
train_split, test_split, df = dataset.center_face_train_test_split(**split_kwargs)
train = dataset.CenterFaceDataset(train_split, df)
test = dataset.CenterFaceDataset(test_split, df, bin_mask=True)

In [10]:
#!L
# Mini-batch loaders: the training set is shuffled each epoch, the
# validation (test) set is iterated in a fixed order.
batch_size = 32
loader_kwargs = {'batch_size': batch_size}
train_gen = torch.utils.data.DataLoader(train, shuffle=True, **loader_kwargs)
val_gen = torch.utils.data.DataLoader(test, shuffle=False, **loader_kwargs)

In [11]:
#!L
# Train for 15 epochs under the run name 'centernet_v4' (how the name is used,
# e.g. for checkpoints, is decided inside train_cycle). The three returned
# histories are averaged per epoch by the plotting cells, which assume the
# same epoch count — keep them in sync if `epochs` changes.
history = train_cycle(
    model,
    opt,
    train_gen,
    val_gen,
    epochs=15,
    name='centernet_v4',
)
train_loss, val_loss, val_metric = history

In [25]:
#!L
# Per-epoch mean of the train/val loss histories returned by train_cycle.
# n_epochs must equal the `epochs=` value passed to train_cycle; reshape(n_epochs, -1)
# assumes each history holds the same number of entries per epoch
# (presumably one per batch — confirm in train_cycle).
n_epochs = 15
epochs_axis = np.arange(1, n_epochs + 1)
train_loss_per_epoch = np.array(train_loss).reshape(n_epochs, -1).mean(axis=-1)
val_loss_per_epoch = np.array(val_loss).reshape(n_epochs, -1).mean(axis=-1)

fig, ax = plt.subplots()
ax.plot(epochs_axis, train_loss_per_epoch, label='train loss')
ax.plot(epochs_axis, val_loss_per_epoch, label='val loss')
# Fixed y-range [0, 2]; values above 2 are clipped from view.
ax.set(xlabel='epoch', ylabel='mean loss', title='Loss per epoch', ylim=(0, 2))
ax.grid()
ax.legend()
plt.show()

In [23]:
#!L
# Per-epoch mean of the validation metric history returned by train_cycle.
# n_epochs must equal the `epochs=` value passed to train_cycle; reshape(n_epochs, -1)
# assumes an equal number of entries per epoch — confirm in train_cycle.
n_epochs = 15
epochs_axis = np.arange(1, n_epochs + 1)
val_metric_per_epoch = np.array(val_metric).reshape(n_epochs, -1).mean(axis=-1)

fig, ax = plt.subplots()
ax.plot(epochs_axis, val_metric_per_epoch, label='val metric')
ax.set(xlabel='epoch', ylabel='mean val metric', title='Validation metric per epoch')
ax.legend()
ax.grid()
plt.show()

In [17]:
#!L
# Visual sanity check on one test sample: predicted boxes drawn on the image
# (left) and the predicted heatmap summed over its channels (right).
SAMPLE_IDX = 1242
# Follow whatever device the model's parameters are on instead of assuming CUDA.
device = next(model.parameters()).device

# Items of `test` unpack into four fields here (that dataset was built with bin_mask=True).
img, mask, size, bin_mask = test[SAMPLE_IDX]
img = img.to(device).unsqueeze(0)
mask = mask.to(device).unsqueeze(0)
size = size.to(device).unsqueeze(0)
bin_mask = bin_mask.to(device).unsqueeze(0)

# eval() puts any BatchNorm/Dropout layers in inference mode, so these forward
# passes neither use batch statistics nor update BN running stats; no_grad()
# avoids building an autograd graph. The caller's train/eval mode is restored.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        pred_hm, pred_offset, pred_size = model(img)
        pred = model.predict(img)
finally:
    model.train(was_training)

fig, (ax_img, ax_hm) = plt.subplots(1, 2, figsize=(12, 6))
ax_img.imshow(img[0].permute(1, 2, 0).detach().cpu())
for box in pred:
    # Columns 2..5 are drawn as (x1, y1, x2, y2) — TODO confirm against CenterNet.predict.
    # float() accepts Python numbers, numpy scalars and 0-d tensors on any device.
    x1, y1, x2, y2 = (float(v) for v in box[2:6])
    ax_img.add_patch(Rectangle(
        (x1, y1),
        x2 - x1,
        y2 - y1, linewidth=2, edgecolor='r', facecolor='none'
    ))
ax_img.set_title('predicted boxes')
ax_hm.imshow(pred_hm[0].sum(dim=0).detach().cpu())
ax_hm.set_title('predicted heatmap (summed over channels)')
plt.show()