In [None]:
%pip install ultralytics -q

In [None]:
%env WANDB_DISABLED=True

In [31]:
from ultralytics import YOLO
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import json
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import shutil
from PIL import Image

In [None]:
# Dataset locations: the KITTI training images, the directory of per-image label files that
# are paired with them below, and the list of class names used for the dataset YAML.
base_dir = Path('./KITTI')
img_path = base_dir.joinpath('data_object_image_2', 'training', 'image_2')
label_path = Path('./labels_with_dont_care')
classes = json.loads(Path('./classes_with_dont_care.json').read_text())

classes

In [None]:
# Pair every image with its YOLO label file by file stem (000123.<ext> <-> 000123.txt;
# ultralytics looks for labels as '<stem>.txt'). Zipping two independently sorted listings
# would silently shift every pair after the first extra or missing file — pathlib's '*'
# also matches hidden files such as .DS_Store — whereas matching on the stem cannot
# misalign, and the assert fails loudly if some image has no label.
ims = sorted(
    p for p in img_path.glob('*')
    if p.is_file() and not p.name.startswith('.')
)
label_by_stem = {p.stem: p for p in label_path.glob('*.txt')}
unlabeled = [im.name for im in ims if im.stem not in label_by_stem]
assert not unlabeled, f'{len(unlabeled)} images have no label file, e.g. {unlabeled[:5]}'
labels = [label_by_stem[im.stem] for im in ims]
pairs = list(zip(ims, labels))
pairs[:2]

In [None]:
# 90/10 train/validation split. random_state pins the shuffle so the split — and therefore
# the reported validation metrics — is reproducible across kernel restarts.
train, test = train_test_split(pairs, test_size=0.1, shuffle=True, random_state=42)
len(train), len(test)

In [6]:
# Start from empty output directories. The copy cells below only ever add files, so
# re-running the notebook with a different split would otherwise leave earlier files behind
# and the same image could end up in both train/ and valid/ (train/val leakage).
# NOTE: this deletes ./train and ./valid wholesale — keep nothing else in them.
train_path = Path('train').resolve()
valid_path = Path('valid').resolve()
for split_dir in (train_path, valid_path):
    if split_dir.exists():
        shutil.rmtree(split_dir)
    split_dir.mkdir()

In [None]:
# Copy each (image, label) pair of the training split into the flat train/ directory,
# keeping the original file names (image first, then its label).
for image_file, label_file in tqdm(train):
    for src in (image_file, label_file):
        shutil.copy(src, train_path / src.name)

In [None]:
# Copy each (image, label) pair of the held-out split into the flat valid/ directory,
# keeping the original file names (image first, then its label).
for image_file, label_file in tqdm(test):
    for src in (image_file, label_file):
        shutil.copy(src, valid_path / src.name)

In [10]:
# Write the ultralytics dataset config: class names, class count, train/val image dirs.
# Every string scalar goes through json.dumps, i.e. becomes a double-quoted YAML string, so
# class names such as "No"/"null" aren't coerced to booleans/None by a YAML 1.1 loader and
# paths containing ': ' or ' #' still parse back verbatim.
yaml_lines = ['names:']
yaml_lines += [f'- {json.dumps(str(c))}' for c in classes]
yaml_lines.append(f'nc: {len(classes)}')
yaml_lines.append(f'train: {json.dumps(str(train_path))}')
yaml_lines.append(f'val: {json.dumps(str(valid_path))}')
yaml_file = '\n'.join(yaml_lines) + '\n'
with open('kitti.yaml', 'w', encoding='utf-8') as f:
    f.write(yaml_file)

In [None]:
# Echo the generated dataset config. Pure Python rather than `!cat`, so this also works on
# kernels without a POSIX shell (e.g. Windows).
print(Path('kitti.yaml').read_text())

In [None]:
# Fine-tune from the pretrained yolov8n.pt checkpoint. (Building a fresh model from
# 'yolov8n.yaml' first is unnecessary: it would be discarded by this assignment.)
model = YOLO('yolov8n.pt')

In [None]:
# Fine-tune on the KITTI config written above, on CPU, with run outputs under
# ./yolov8n-kitti/. Hyperparameters: at most 10 epochs, early-stopping patience of 3 epochs,
# and MixUp augmentation with probability 0.1 (semantics per the ultralytics training-args
# docs — confirm for the installed version).
train_results = model.train(
    data='./kitti.yaml',
    project='yolov8n-kitti',
    device='cpu',
    epochs=10,
    patience=3,
    mixup=0.1,
)

In [None]:
# Evaluate the fine-tuned weights; metrics and plots go to a run named 'val'.
# NOTE(review): dataset/project/device aren't passed here and appear to be inherited from
# the training run's arguments — confirm for the installed ultralytics version.
valid_results = model.val(
    name='val',
)

In [None]:
# Show the training curves from the directory ultralytics actually used for this run.
# Re-running training makes ultralytics create train2, train3, ..., so a literal
# 'yolov8n-kitti/train' path would silently display the first run's (stale) figure.
plt.figure(figsize=(10,20))
with Image.open(Path(model.trainer.save_dir) / 'results.png') as results_img:
    plt.imshow(np.asarray(results_img))
plt.axis('off')
plt.show()

In [None]:
# Show the confusion matrix from the directory ultralytics actually used for this run
# (re-runs go to train2, train3, ..., so a literal 'yolov8n-kitti/train' path goes stale).
plt.figure(figsize=(10,20))
with Image.open(Path(model.trainer.save_dir) / 'confusion_matrix.png') as cm_img:
    plt.imshow(np.asarray(cm_img))
plt.axis('off')
plt.show()

In [None]:
# Predict on up to 20 distinct, randomly chosen held-out images and save the annotated
# outputs under a run named 'predict'. A fixed seed keeps the selection reproducible.
# Sampling without replacement avoids predicting the same image twice (duplicates would
# share an output file name, so one result would overwrite the other).
n_samples = min(20, len(test))
sample_idx = np.random.default_rng(0).choice(len(test), size=n_samples, replace=False)
preds = model.predict([test[idx][0] for idx in sample_idx], save=True, name='predict')

In [29]:
# List the annotated images written by the prediction cell. Read the directory from the
# predictor itself: ultralytics increments the run name (predict, predict2, ...) when the
# target already exists, so a literal 'yolov8n-kitti/predict' would show a previous run's
# output after a re-run. Sorted for a stable display order.
preds = sorted(p for p in Path(model.predictor.save_dir).iterdir() if p.is_file())

In [None]:
def plot_images(images, width=15, height_per_image=4):
    """Display the image files in `images` stacked vertically, one per row, without axes.

    Args:
        images: sequence of image file paths (anything ``PIL.Image.open`` accepts).
        width: figure width in inches.
        height_per_image: figure height allotted to each row, in inches (with the
            defaults, 20 images give a 15x80-inch figure).

    Prints a note and returns without plotting when `images` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        print('No images to plot.')
        return
    # squeeze=False keeps `axes` a 2-D array even for a single image; the default returns a
    # bare Axes in that case, which supports neither `.flat` nor indexing.
    fig, axes = plt.subplots(
        num_images, 1, figsize=(width, height_per_image * num_images), squeeze=False
    )
    for ax, img_path in zip(axes[:, 0], images):
        # Load the pixels and close the file immediately rather than leaving one handle
        # open per image.
        with Image.open(img_path) as img:
            ax.imshow(np.asarray(img))
        ax.axis('off')
    plt.tight_layout()
    plt.show()


plot_images(preds)