Trong bài này chúng ta sẽ tìm hiểu về one-stage detector có tên là RetinaNet. Đầu tiên cùng tìm hiểu về Focal Loss, chính loss function này làm nên sự khác biệt của RetinaNet. Các bài toán object detection trước đó luôn gặp phải một vấn đề về class imbalance (ám chỉ sự chênh lệch giữa foreground và background là quá lớn).
Nhớ lại cross entropy (CE) cho bài toán phân loại đối với phân phối xác suất thực tế
trong đó
Các nhãn bằng 0 tương ứng với
Đóng góp vào loss function của các class là như nhau
Đồ thị trên được vẽ cho bài toán binary classification trong đó
Cùng xem hình trên với
Xem thêm ví dụ dưới đây
Đồ thị $-log(q_i)$
Chúng ta có 100000 easy examples (0.1 loss cho mỗi example) và 100 hard examples (2.3 loss cho mỗi example). Khi đó tập hợp lại ta được
- Loss cho tất cả easy examples = 100000×0.1 = 10000
- Loss cho tất cả hard examples = 100×2.3 = 230
- 10000 / 230 = 43. Loss từ easy examples lớn hơn rất nhiều loss từ hard examples.
Cross entropy loss không phải là sự lựa chọn tốt cho trường hợp rất mất cân bằng dữ liệu. Nếu dùng cross entropy thông thường này thì
- Ngay cả khi mô hình dự đoán sai foreground (predicted probability of foreground thấp, gần gốc tọa độ) thì loss do việc dự đoán foreground sai này vẫn nhỏ so với tổng loss, mô hình thậm chí không cần cải thiện thêm vẫn được (tất nhiên nếu dự đoán đúng foreground thì loss vẫn giảm). Điều này là không chấp nhận được. Mô hình ít quan tâm đến dự đoán đúng foreground vì loss do dự đoán sai foreground không ảnh hưởng nhiều đến loss chung. Phải nhấn mạnh rằng nếu dự đoán đúng foreground (predicted probability of foreground cao, gần điểm 1) loss có giảm nhưng không đáng kể so với tổng loss.
- Hay khi mô hình dự đoán đúng các easy examples - predicted probability of background, gần điểm 1 trên trục hoành thì tổng loss cho việc dự đoán đúng background vẫn lớn. Thử tưởng tượng nếu dự đoán sai các easy examples (đồ thị gần gốc tọa độ) thì loss sẽ cực kỳ lớn. Mô hình đang tập trung dự đoán đúng background để giảm loss.
Chính vì những điều trên chúng ta cần một loss function hiệu quả hơn giúp điều chỉnh loss lớn hơn khi dự đoán sai foreground (object). Điều này giúp chúng ta hạn chế dự đoán sai đối với foreground vì khi dự đoán sai loss sẽ tăng lên đáng kể.
Phương pháp thông thường để giải quyết vấn đề class imbalance là đưa vào trọng số
trong đó
Ta có công thức cho balanced cross entropy:
Việc đưa thêm trọng số của các class vào cross entropy giúp chúng ta giải quyết phần nào class imbalance (class ít hơn sẽ có hệ số đóng góp lớn hơn vào loss function). Tuy nhiên nó chưa thực sự thay đổi được gradient của loss function. Trong khi mô hình được huấn luyện trên mẫu mất cân bằng trầm trọng có giá trị gradient chịu ảnh hưởng phần lớn bởi class chiếm đa số. Do đó chúng ta cần một sự điều chỉnh triệt để hơn giúp gia tăng ảnh hưởng của nhóm thiểu số lên gradient. Đó là lý do Focal loss ra đời.
trong đó
-
$\gamma$ - tunable focusing parameter$\gamma \geq 0$ , thường chọn$\gamma \in [0,1]$ . Trong bài báo tác giả tìm được$\gamma = 2$ tốt nhất cho thử nghiệm của họ.$\gamma$ giúp ta thay đổi dễ dàng trọng số đóng góp của class đa số (background).$\gamma = 0$ chúng ta có balanced cross entropy.
Cùng đánh giá chi tiết hơn vấn đề này:
- Đối với class dễ dự đoán (class chiếm đa số) xác suất
$q_i$ của class cao, do đó$(1-q_i)^\gamma$ sẽ nhỏ. Ví dụ$q_i=0.9$ ,$\gamma = 2$ khi đó$(1-q_i)^\gamma = 1/100$ . Điều này đồng nghĩa với việc đóng góp vào loss của class đa số đã giảm đi 100 lần. - Đối với class dễ dự đoán sai (class thiểu số) giá trị
$q_i$ thường sẽ nhỏ, ví dụ$p_i=0.1$ khi đó$(1-q_i)^\gamma = 1/0.81=1.23$ . Đóng góp vào loss của class thiểu số cũng giảm 1.23 lần nhưng tổng thể trọng số đóng góp của class thiểu số lại tăng lên (bên trên đóng góp của class đa số giảm 100 lần).
Ở đây có predicted probability
với
Đối với hàm
Hàm
Khi đó
Tương tự như vậy ta có thể chứng minh được bất đẳng thức:
Như vậy nếu chọn
Điều này chứng tỏ hàm số
- Class đa số (dễ dự đoán)
$q_i$ lớn:$(1-q_i)^\gamma$ nhỏ nên sự ảnh hưởng lên gradient của các class đa số không đáng kể - Class thiểu số (khó dự đoán)
$q_i$ nhỏ:$(1-q_i)^\gamma$ lớn gần bằng 1, ảnh hưởng lên gradient sẽ lớn. Như vậy khi dự đoán sai class thiểu số, gradient sẽ thay đổi lớn để giảm focal loss.
Đồ thị Focal loss $FP(\mathbf{q}) = -\alpha_i (1-q_i)^{\gamma} \log(q_i)$, với $\alpha_i = 1$
Trường hợp
Trục hoành là xác suất dự đoán được của class
- Đối với class đa số (dễ dự đoán) xác suất
$q_i$ sẽ gần với điểm 1 - Đối với class thiểu số (khó dự đoán) xác suất
$q_i$ thấp nên sẽ gần gốc tọa độ hơn.
Nhận thấy khi
RetinaNet là sự kết hợp giữa các networks bao gồm backbone và subnets. Backbone của RetinaNet chính là Feature Pyramid Network dùng để trích xuất các feature map. Các mạng subnet dùng để thực hiện classification và box regression. RetinaNet dùng focal loss như đã giới thiệu bên trên.
Feature Pyramid Network (FPN)
FPN cung cấp multi-scale feature pyramid dựa trên bottom-up, top-down pathways và lateral (skip) connection.
RetinaNet architecture
Trong RetinaNet, FPN chỉ lấy các megered feature maps từ level P3 đến P7.
Subnets Ở đây class subnet và box subnet có chung cấu trúc, chỉ khác số channels ở layers cuối cùng.
Chú ý: class subnet và box subnet sử dụng riêng các parameters.
Megered feature maps có số channels
- Class subnet: đi qua 4 Conv layers
3x3
với số filters là$C$ , sau đó lại đi qua Conv layers3x3
với số filters là$K \times A$ . Trong đó$K$ - số classes,$A$ - số anchors tại một vị trí. - Box subnet: đi qua 4 Conv layers
3x3
với số filters là$C$ , sau đó lại đi qua Conv layers3x3
với số filters là$4A$ .$A$ thường được chọn bằng 9.
- https://arxiv.org/abs/1708.02002
- https://towardsdatascience.com/review-retinanet-focal-loss-object-detection-38fba6afabe4
- https://towardsdatascience.com/retinanet-the-beauty-of-focal-loss-e9ab132f2981
- https://phamdinhkhanh.github.io/2020/08/23/FocalLoss.html
- https://medium.com/swlh/focal-loss-what-why-and-how-df6735f26616
- https://maxhalford.github.io/blog/lightgbm-focal-loss/
- https://keras.io/examples/vision/retinanet/
- https://medium.com/@14prakash/the-intuition-behind-retinanet-eb636755607d
- https://deep-learning-study-note.readthedocs.io/en/latest/Part%202%20(Modern%20Practical%20Deep%20Networks)/12%20Applications/Computer%20Vision%20External/Focal%20Loss%20for%20Dense%20Object%20Detection.html
- https://leimao.github.io/cv/