# Set up environment

In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
import torch
print('Setup complete. Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))

Setup complete. Using torch 1.9.1+cu102 _CudaDeviceProperties(name='Quadro RTX 8000', major=7, minor=5, total_memory=49152MB, multi_processor_count=72)


In [5]:
# import relevant libraries
import pandas as pd
import numpy as np
from statistics import mean, stdev
import cv2
import random
import matplotlib.pyplot as plt
from PIL import Image
import torch
import os
%matplotlib inline

In [6]:
# clone repo
if os.path.exists('C:/Users/dental-1/Documents/dent_seg/3_training_code/missing_teeth_and_tooth_numbering/yolov5/') == False:
    !git clone https://github.com/ultralytics/yolov5
    %cd yolov5
else:
    %cd yolov5
    !git init
    !git pull https://github.com/ultralytics/yolov5

C:\Users\dental-1\Documents\dent_seg\3_training_code\missing_teeth_and_tooth_numbering\yolov5
Reinitialized existing Git repository in C:/Users/dental-1/Documents/dent_seg/3_training_code/missing_teeth_and_tooth_numbering/yolov5/.git/
Already up to date.


From https://github.com/ultralytics/yolov5
 * branch            HEAD       -> FETCH_HEAD


# Define functions

In [7]:
def closest(lst, K): 
    close =  lst[min(range(len(lst)), key = lambda i: abs(lst[i]-K))]
    if close == min(lst): 
        return 'top'
    elif close == max(lst): 
        return 'bottom'
    else: return 'There is a problem!'

In [8]:
# assign teeth to quadrant
def tooth_quad(data):
    data['quadrant'] = 0
    for i in range(data.shape[0]):
        if data['x_center'][i] < 0.5:
            if data['y_center'][i] < 0.5:
                data['quadrant'][i]=1
            else: data['quadrant'][i]=4
        else:
            if data['y_center'][i] < 0.5:
                data['quadrant'][i]=2
            else: data['quadrant'][i]=3

    # calculate x_left and x_right values for bounding boxes
    data['x_left'] = data['x_center'] - data['width']/2
    data['x_right'] = data['x_center'] + data['width']/2
    data['x_values'] = list(zip(data.x_left, data.x_right))
    data.drop(['x_left', 'x_right'], axis=1, inplace=True)

    # calculate y_left and y_right values for bounding boxes
    data['y_left'] = data['y_center'] - data['height']/2
    data['y_right'] = data['y_center'] + data['height']/2
    data['y_values'] = list(zip(data.y_left, data.y_right))
    data.drop(['y_left', 'y_right'], axis=1, inplace=True)

    data.sort_values(by=['quadrant', 'x_center'], inplace=True, ignore_index=True)

    # reassign quadrant number if more than 9 teeth per quadrant
    if (len(data[data['quadrant']==1])<8) & ((len(data[data['quadrant']==2])>8)):
        data.iloc[data[data['quadrant']==2].index[0],5] = 1
    elif (len(data[data['quadrant']==4])<8) & ((len(data[data['quadrant']==3])>8)):
        data.iloc[data[data['quadrant']==3].index[0],5] = 4
    elif (len(data[data['quadrant']==1])>8) & ((len(data[data['quadrant']==2])<8)):
        data.iloc[data[data['quadrant']==1].index[-1],5] = 2
    elif (len(data[data['quadrant']==4])>8) & ((len(data[data['quadrant']==3])<8)):
        data.iloc[data[data['quadrant']==4].index[-1],5] = 3

    # reassign quadrant number if more than 2 incisors per quadrant    
    while ((len(data[(data['quadrant']==1) & (data['tooth_label']==3)])<2) & (len(data[(data['quadrant']==2) & (data['tooth_label']==3)])>2)):
        data.iloc[data[data['quadrant']==2].index[0],5] = 1
    while ((len(data[(data['quadrant']==1) & (data['tooth_label']==3)])>2) & (len(data[(data['quadrant']==2) & (data['tooth_label']==3)])<2)):
        data.iloc[data[data['quadrant']==1].index[-1],5] = 2
    while ((len(data[(data['quadrant']==4) & (data['tooth_label']==3)])<2) & (len(data[(data['quadrant']==3) & (data['tooth_label']==3)])>2)):
        data.iloc[data[data['quadrant']==3].index[0],5] = 4
    while ((len(data[(data['quadrant']==4) & (data['tooth_label']==3)])>2) & (len(data[(data['quadrant']==3) & (data['tooth_label']==3)])<2)):
        data.iloc[data[data['quadrant']==4].index[-1],5] = 3

    #check against average y value of each quadrant
    if (data[data['quadrant']==1].shape[0]>0):
        if (data[data['quadrant']==4].shape[0]>0):
            y_center_1 = mean(data[data['quadrant']==1]['y_center'])
            y_center_4 = mean(data[data['quadrant']==4]['y_center'])
            mean_y_center_14 = tuple(list((y_center_1, y_center_4)))
    if (data[data['quadrant']==2].shape[0]>0):
        if (data[data['quadrant']==3].shape[0]>0):
            y_center_2 = mean(data[data['quadrant']==2]['y_center'])
            y_center_3 = mean(data[data['quadrant']==3]['y_center'])
            mean_y_center_23 = tuple(list((y_center_2, y_center_3)))
    
    for i in data.index:
        if (data[data['quadrant']==1].shape[0]>0) and (data[data['quadrant']==4].shape[0]>0):
            if (data.loc[i, 'quadrant']==1) and (closest(mean_y_center_14, data.loc[i, 'y_center'])=='bottom'):
                data.loc[i, 'quadrant'] = 4
            elif (data.loc[i, 'quadrant']==4) and (closest(mean_y_center_14, data.loc[i, 'y_center'])=='top'):
                data.loc[i, 'quadrant'] = 1
        if (data[data['quadrant']==2].shape[0]>0) and (data[data['quadrant']==3].shape[0]>0):
            if (data.loc[i, 'quadrant']==2) and (closest(mean_y_center_23, data.loc[i, 'y_center'])=='bottom'):
                data.loc[i, 'quadrant'] = 3
            elif (data.loc[i, 'quadrant']==3) and (closest(mean_y_center_23, data.loc[i, 'y_center'])=='top'):
                data.loc[i, 'quadrant'] = 2


    data.sort_values(by=['quadrant','x_center'], inplace=True, ignore_index=True)
    return(data)

In [9]:
# get dataframe of missing teeth and append to df
def tooth_num(data):
    # create df for each quadrant
    q1_data = data[data['quadrant']==1]
    q1_data.sort_values(by='x_center', ascending=False, inplace=True, ignore_index=True)
    q1_data['ISO_tooth_label'] = ''
    for i in range(q1_data.shape[0]):
        q1_data.iloc[i, q1_data.columns.get_loc('ISO_tooth_label')] = '1'+str(i+1)

    q2_data = data[data['quadrant']==2]
    q2_data.sort_values(by='x_center', inplace=True, ignore_index=True)
    q2_data['ISO_tooth_label'] = ''
    for i in range(q2_data.shape[0]):
        q2_data.iloc[i, q2_data.columns.get_loc('ISO_tooth_label')] = '2'+str(i+1)

    q3_data = data[data['quadrant']==3]
    q3_data.sort_values(by='x_center', inplace=True, ignore_index=True)
    q3_data['ISO_tooth_label'] = ''
    for i in range(q3_data.shape[0]):
        q3_data.iloc[i, q3_data.columns.get_loc('ISO_tooth_label')] = '3'+str(i+1)

    q4_data = data[data['quadrant']==4]
    q4_data.sort_values(by='x_center', ascending=False, inplace=True, ignore_index=True)
    q4_data['ISO_tooth_label'] = ''
    for i in range(q4_data.shape[0]):
        q4_data.iloc[i, q4_data.columns.get_loc('ISO_tooth_label')] = '4'+str(i+1)

    data = pd.concat([q1_data, q2_data, q3_data, q4_data], ignore_index=True)
    return(data)

In [10]:
def find_missing_teeth(data, quadrant_dict):
    # add missing teeth column to df
    data['missing'] = 0
    # loop through and find missing teeth
    missing_teeth = []
    for quadrant in quadrant_dict.keys():
        if len(quadrant_dict[quadrant]) < 8:
            for i in range(len(quadrant_dict[quadrant])-1):
                tooth_dist = quadrant_dict[quadrant]['x_values'][i+1][0]-quadrant_dict[quadrant]['x_values'][i][1]
                if tooth_dist > 0.005: #average_dist:
                    x_center = tooth_dist/2 + quadrant_dict[quadrant]['x_values'][i][1]
                    y_center = (quadrant_dict[quadrant]['y_center'][i+1]+quadrant_dict[quadrant]['y_center'][i])/2
                    w = tooth_dist
                    h = (quadrant_dict[quadrant]['height'][i+1]+quadrant_dict[quadrant]['height'][i])/2
                    x_val = tuple(list((x_center-w/2, x_center+w/2)))
                    y_val = tuple(list((y_center-h/2, y_center+h/2)))
                    num_teeth = 4
                    missing = 1
                    missing_teeth.append((num_teeth, x_center, y_center, w, h, int(quadrant), missing, x_val, y_val))

    missing_teeth = pd.DataFrame(missing_teeth, columns=['tooth_label', 'x_center', 'y_center', 'width', 'height', 'quadrant','missing', 'x_values', 'y_values'])
    # append missing teeth to label df
    data = data.append(missing_teeth)
    data.sort_values(by=['quadrant','x_center'], inplace=True, ignore_index=True)
    return(data)

In [11]:
## Change tooth label for missing teeth to category if it lies between 2 teeth of the same category for class 3 and 0
def check_class_between(index, series):
    classes = [0,1,2,3]
    new_class = 9
    if len(series) >=2:
        if series[index-1]==series[index+1]:
            new_class = series[index-1]
        # # tooth 2
        elif (series[index-1]==classes[3]) and (series[index+1]==classes[2]):
            new_class = classes[3]
        # tooth 3
        elif (series[index-1]==classes[3]) and (series[index+1]==classes[1]):
            new_class = classes[2]
        # tooth 4
        elif (series[index-1]==classes[2]) and (series[index+1]==classes[1]):
            new_class = classes[1]
        # tooth 5
        elif (series[index-1]==classes[1]) and (series[index-2]!=classes[1]) and (series[index+1]==classes[0]):
            new_class = classes[1]
        # tooth 6
        elif (series[index-1]==classes[1]) and (series[index-2]==classes[1]):
            new_class = classes[0]
    return new_class

In [12]:
## Create dictionary of dataframes for each quadrant
def dict_quad(data):
    q1_data = data[data['quadrant']==1]
    q1_data.sort_values(by='x_center', inplace=True)
    q1_data.reset_index(drop=True, inplace=True)

    q2_data = data[data['quadrant']==2]
    q2_data.reset_index(drop=True, inplace=True)

    q3_data = data[data['quadrant']==3]
    q3_data.reset_index(drop=True, inplace=True)

    q4_data = data[data['quadrant']==4]
    q4_data.sort_values(by='x_center', inplace=True)
    q4_data.reset_index(drop=True, inplace=True)

    quadrant_dict = {'1': q1_data, '2': q2_data, '3': q3_data, '4': q4_data}
    return(quadrant_dict)

In [13]:
## Visualize image with missing teeth and ISO tooth numbering
def visualize_toothnum(img, data):
#     dummy = img.copy()
#     for i in range(data.shape[0]):
#         if data['missing'][i] == 0:
#             color = (255, 0, 0)
#         else: color = (0, 255, 255)
#         start = (int(data['x_values'][i][0]*img.shape[1]), int(data['y_values'][i][0]*img.shape[0]))
#         end = (int(data['x_values'][i][1]*img.shape[1]), int(data['y_values'][i][1]*img.shape[0]))
#         if ((data['quadrant'][i] == 1) | (data['quadrant'][i] == 2)):
#             # label = (start[0], start[1]-25)
#             label = (round(start[0]+((end[0]-start[0])/2))-15, start[1]-25)
#         # else: label = (start[0], end[1]+45)
#         else: label = (round(start[0]+((end[0]-start[0])/2))-20, end[1]+45)
#         # cv2.rectangle(dummy, start, end, color, 3)
#         cv2.putText(dummy, text= ('%s') %(data['tooth_label'][i]), 
#                         org=label,fontFace= cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=color,
#                         thickness=4, lineType=cv2.LINE_AA)

#     plt.figure(figsize=(14,7))
#     plt.imshow(dummy)
    dummy = img.copy()
    for i in range(data.shape[0]):
        if data['missing'][i] == 0:
            color = (255, 0, 0)
        else: color = (0, 255, 255)
        start = (int(data['x_values'][i][0]*img.shape[1]), int(data['y_values'][i][0]*img.shape[0]))
        end = (int(data['x_values'][i][1]*img.shape[1]), int(data['y_values'][i][1]*img.shape[0]))
        if ((data['quadrant'][i] == 1) | (data['quadrant'][i] == 2)):
            # label = (start[0], start[1]-25)
            label = (round(start[0]+((end[0]-start[0])/2))-15, start[1]-25)
        # else: label = (start[0], end[1]+45)
        else: label = (round(start[0]+((end[0]-start[0])/2))-20, end[1]+45)
        # cv2.rectangle(dummy, start, end, color, 3)
        cv2.putText(dummy, text= ('%s') %(data['ISO_tooth_label'][i]), 
                        org=label,fontFace= cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=color,
                        thickness=4, lineType=cv2.LINE_AA)

    plt.figure(figsize=(14,7))
    plt.imshow(dummy)
#     plt.savefig(file_path+file_name)
    return()

In [14]:
## Verify position of canines/first molar or change tooth number if mismatched:
# tooth_label - list of integers
# ISO_tooth_labels - list of integers
# quadrant - '1', '2', '3' or '4'
def check_toothtype_position(tooth_label, ISO_tooth_label, quadrant):
    flag_pm = 0
    flag_m = 0
    for i in range(len(tooth_label)):
        
        # check canine position
        if ((tooth_label[i] == 2) & (ISO_tooth_label[i] != quadrant*10+3)):
            ISO_tooth_label.insert(i,quadrant*10+3)
            if ISO_tooth_label[i]>=ISO_tooth_label[i+1]:
                n = int(ISO_tooth_label[i])-int(ISO_tooth_label[i+1])+1
                for j in range(i+1,len(tooth_label)):
                    ISO_tooth_label[j] += n
        
        # check premolar1 position
        elif  (tooth_label[i] == 1):
            if (ISO_tooth_label[i] != quadrant*10+4): 
                if flag_pm==0:
                    ISO_tooth_label.insert(i,quadrant*10+4)
                    flag_pm = flag_pm+1
                    if ISO_tooth_label[i]>=ISO_tooth_label[i+1]:
                        n = int(ISO_tooth_label[i])-int(ISO_tooth_label[i+1])+1
                        for j in range(i+1,len(tooth_label)):
                            ISO_tooth_label[j] += n
            else: flag_pm=1
                
    
        # check molar1 position
        elif  (tooth_label[i] == 0):
            if (ISO_tooth_label[i] != quadrant*10+6):
                if flag_m==0:
                    ISO_tooth_label.insert(i,quadrant*10+6)
                    flag_m = flag_m+1
                    if ISO_tooth_label[i]>=ISO_tooth_label[i+1]:
                        n = int(ISO_tooth_label[i])-int(ISO_tooth_label[i+1])+1
                        for j in range(i+1,len(tooth_label)):
                            ISO_tooth_label[j] += n
            else: flag_m=1
  
    return ISO_tooth_label[:len(tooth_label)]

# Segment teeth

Make sure to change desired images in `--source` and file name in `--name`

In [16]:
# run model against specific image
!python detect.py --agnostic --weights ../../../4_weights/tooth_classification_weight.pt --conf 0.4 --source ../../../2_dental_images/raw_images/cate8-00074.jpg --name cate8-00074 --save-txt --save-conf

[34m[1mdetect: [0mweights=['../../../4_weights/tooth_classification_weight.pt'], source=../../../2_dental_images/raw_images/cate8-00074.jpg, imgsz=[640, 640], conf_thres=0.4, iou_thres=0.45, max_det=1000, device=, view_img=False, save_txt=True, save_conf=True, save_crop=False, nosave=False, classes=None, agnostic_nms=True, augment=False, visualize=False, update=False, project=runs\detect, name=cate8-00074, exist_ok=False, line_thickness=3, hide_labels=False, hide_conf=False, half=False, dnn=False
YOLOv5  v6.0-103-g7a39803 torch 1.9.1+cu102 CUDA:0 (Quadro RTX 8000, 49152MiB)

Fusing layers... 
Model Summary: 232 layers, 7254609 parameters, 0 gradients
image 1/1 C:\Users\dental-1\Documents\dent_seg\2_dental_images\raw_images\cate8-00074.jpg: 384x640 5 molars, 4 premolars, 4 canines, 8 incisors, Done. (0.000s)
Speed: 0.0ms pre-process, 0.0ms inference, 15.6ms NMS per image at shape (1, 3, 640, 640)
Results saved to [1mruns\detect\cate8-00074[0m
1 labels saved to runs\detect\cate8-000

In [15]:
# define label/image path
label_path = 'runs/detect/cate8-00075/labels/cate8-00075.txt'
img_path = 'C:/Users/dental-1/Documents/dent_seg/2_dental_images/raw_images/cate8-00075.jpg'

In [16]:
# read in bounding boxes
data = pd.read_csv(label_path, header=None, sep=' ')
data.rename(columns={0: 'tooth_label', 1: 'x_center', 2: 'y_center', 3: 'width', 4: 'height', 5: 'confidence'}, inplace=True)
data.drop(columns = ['confidence'], inplace=True)

# read in image
img = cv2.imread(img_path)

data.head()

Unnamed: 0,tooth_label,x_center,y_center,width,height
0,2,0.589151,0.630878,0.04219,0.250222
1,3,0.553993,0.618456,0.035158,0.193434
2,3,0.468106,0.616238,0.036163,0.188997
3,3,0.526871,0.613576,0.036163,0.187223
4,2,0.435962,0.630878,0.054244,0.220053


In [17]:
# calculate x and y bounding box edges
data['xmin'] = data['x_center'] - data['width']/2
data['xmax'] = data['x_center'] + data['width']/2
data['ymin'] = data['y_center'] - data['height']/2
data['ymax'] = data['y_center'] + data['height']/2

# convert to pixels
data['xmin'] = data['xmin']*img.shape[0]
data['xmax'] = data['xmax']*img.shape[0]
data['ymin'] = data['ymin']*img.shape[1]
data['ymax'] = data['ymax']*img.shape[1]

data.head()

Unnamed: 0,tooth_label,x_center,y_center,width,height,xmin,xmax,ymin,ymax
0,2,0.589151,0.630878,0.04219,0.250222,640.199168,687.747186,1006.982097,1505.174099
1,3,0.553993,0.618456,0.035158,0.193434,604.538465,644.161757,1038.782349,1423.909443
2,3,0.468106,0.616238,0.036163,0.188997,507.177781,547.933143,1038.783344,1415.076372
3,3,0.526871,0.613576,0.036163,0.187223,573.405936,614.161298,1035.24932,1408.010312
4,2,0.435962,0.630878,0.054244,0.220053,460.762624,521.895724,1037.015337,1475.14086


# Find missing teeth

In [18]:
data = tooth_quad(data)

In [19]:
quadrant_dict = dict_quad(data)

In [20]:
data = find_missing_teeth(data, quadrant_dict)

In [21]:
data = tooth_num(data)

In [22]:
# assigning class to missing tooth if between 2 teeth:
for i in range(len(data.tooth_label)):
    if data.tooth_label[i]==4:
        data.tooth_label[i] = check_class_between(i, data.tooth_label)

In [23]:
data = tooth_quad(data)

In [24]:
data = tooth_num(data)

In [25]:
quadrant_dict = dict_quad(data)

In [26]:
for quadrant in quadrant_dict.keys():
        quadrant_dict[quadrant].sort_values(by=['ISO_tooth_label'], inplace=True, ignore_index=True)
        premol = quadrant_dict[quadrant].index[quadrant_dict[quadrant]['tooth_label'] == 1].tolist()
        can = quadrant_dict[quadrant].index[quadrant_dict[quadrant]['tooth_label'] == 2].tolist()
        if len(premol)==3:
            if len(can)==1:
                quadrant_dict[quadrant]['tooth_label'][premol[2]]=0
                quadrant_dict[quadrant]['ISO_tooth_label'][premol[2]]=int(quadrant)*10+6
            else:
                quadrant_dict[quadrant]['tooth_label'][premol[0]]=2
                quadrant_dict[quadrant]['ISO_tooth_label'][premol[0]]=int(quadrant)*10+3
        tooth_label = list(quadrant_dict[quadrant]['tooth_label'].astype('int64'))
        ISO_tooth_label = list(quadrant_dict[quadrant]['ISO_tooth_label'].astype('int64'))
        if len(tooth_label) != 8:
            quadrant_dict[quadrant]['ISO_tooth_label'] = check_toothtype_position(tooth_label, ISO_tooth_label, int(quadrant))

In [27]:
data = pd.concat([quadrant_dict['1'], quadrant_dict['2'], quadrant_dict['3'], quadrant_dict['4']], ignore_index=True)
data.ISO_tooth_label = data.ISO_tooth_label.astype(int)
data.sort_values(by='ISO_tooth_label', ascending=True, inplace=True, ignore_index=True)

In [None]:
visualize_toothnum(img, data)