In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import cv2
from torch.nn import CrossEntropyLoss
from torch.nn import MSELoss
import pandas as pd
from src.pytorch_datasets.driving_dataset import DrivingDataset
from src.dataset_augmentations.dataset_augmentator import *
from tqdm.notebook import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader

In [3]:
class MetricMonitor:
    """Tracks a running sum, update count and average for named metrics.

    Each metric is identified by name and gets its own record the first time
    it is passed to :meth:`update`; no up-front registration is needed.

    Args:
        float_precision: Number of decimal places used when the averages are
            rendered by ``str()``.
    """

    def __init__(self, float_precision=3):
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        """Discard every metric recorded so far."""
        # Maps metric name -> {"val": running sum, "count": number of updates,
        # "avg": val / count}.
        self.metrics = {}

    def update(self, metric_name, val):
        """Add ``val`` to the metric called ``metric_name`` and refresh its average.

        Args:
            metric_name: Key identifying the metric.
            val: Value to add to the metric's running sum.
        """
        # Create the per-metric record lazily so that any metric name works.
        metric = self.metrics.setdefault(metric_name, {"val": 0, "count": 0, "avg": 0})

        metric["val"] += val
        metric["count"] += 1
        metric["avg"] = metric["val"] / metric["count"]

    def __str__(self):
        """Return ``"name: avg | name: avg | ..."`` for all recorded metrics."""
        return " | ".join(
            "{metric_name}: {avg:.{float_precision}f}".format(
                metric_name=metric_name, avg=metric["avg"], float_precision=self.float_precision
            )
            for (metric_name, metric) in self.metrics.items()
        )

In [4]:
class Averager:
    """Running arithmetic mean of the values handed to :meth:`send`."""

    def __init__(self):
        # Start from the same empty state that reset() restores.
        self.reset()

    def send(self, value):
        """Fold ``value`` into the running total and count it as one iteration."""
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        """Mean of everything sent since the last reset; ``0`` if nothing was sent."""
        nothing_sent = self.iterations == 0
        return 0 if nothing_sent else 1.0 * self.current_total / self.iterations

    def reset(self):
        """Clear the accumulated total and the iteration counter."""
        self.current_total = 0.0
        self.iterations = 0.0

In [5]:
# Root folder of the dataset. Set the DRIVING_DATASET_FOLDER environment
# variable to point at your own copy without editing the notebook; the
# default below is used when the variable is not set.
dataset_folder = os.environ.get("DRIVING_DATASET_FOLDER", "E:\\temp\\datasets")

# Locations derived from the dataset root.
classes_path = os.path.join(dataset_folder, 'classes.csv')
train_folder = os.path.join(dataset_folder, 'train')
val_folder = os.path.join(dataset_folder, 'val')

In [6]:
# Load the class table from the tab-separated file at classes_path, naming
# its two columns 'id' and 'name', then print it for inspection.
classes = pd.read_csv(
    classes_path,
    names=['id', 'name'],
    sep='\t',
)
print(classes)

   id           name
0   0           bike
1   1            bus
2   2            car
3   3          motor
4   4         person
5   5          rider
6   6  traffic light
7   7   traffic sign
8   8          train
9   9          truck


In [7]:
# '__background__' is placed at index 0, so the names read from the class
# table start at index 1.
class_names = ['__background__', *classes['name'].values.tolist()]
class_names

['__background__',
 'bike',
 'bus',
 'car',
 'motor',
 'person',
 'rider',
 'traffic light',
 'traffic sign',
 'train',
 'truck']

In [8]:
# Wrap the train and validation folders in DrivingDataset objects.
train_dataset = DrivingDataset(train_folder, class_names, transforms=get_train_transform())
# NOTE(review): the validation set is also built with get_train_transform().
# If that transform applies random augmentation, validation results will not
# be comparable between epochs — confirm whether a dedicated validation
# transform should be used here instead.
val_dataset = DrivingDataset(val_folder, class_names, transforms=get_train_transform())

In [9]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision


def create_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box predictor.

    The torchvision model is constructed with ``pretrained=False`` and its
    ROI-head box predictor is replaced by a ``FastRCNNPredictor`` sized for
    ``num_classes`` outputs.

    Args:
        num_classes: Number of classes passed to the new box predictor.

    Returns:
        The modified torchvision detection model.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)

    # The replacement head must accept the same feature size as the stock one.
    roi_heads = detector.roi_heads
    predictor_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(predictor_in_features, num_classes)

    return detector

In [10]:
# Weights for the loss terms.
# NOTE(review): neither constant is referenced by the code visible in this
# notebook — confirm whether they are still needed.
LABELS = 1.0
BBOX = 1.0

In [21]:
def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a tuple of per-field tuples.

    For example ``[(img0, tgt0), (img1, tgt1)]`` becomes
    ``((img0, img1), (tgt0, tgt1))``. An empty batch yields ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)


def train(train_loader, model, optimizer, epoch, params):
    """Run one optimisation pass over ``train_loader``.

    For every batch the images and target tensors are moved to
    ``params['device']``, the model is called with both, the values of the
    returned loss dict are summed, and one optimizer step is taken.

    Side effects: each batch loss is appended to the module-level
    ``train_loss_list`` and sent to the module-level ``train_loss_hist``.

    Args:
        train_loader: Iterable yielding ``(images, target)`` batches.
        model: Model that returns a dict of losses when called with
            ``(images, target)``.
        optimizer: Optimizer updating ``model``'s parameters.
        epoch: Current epoch number (not used inside this function).
        params: Settings dict; only ``params['device']`` is read here.
    """
    print('Training')

    model.train()
    batches = tqdm(train_loader)

    for images, target in batches:
        optimizer.zero_grad()

        device = params['device']
        images = [image.to(device) for image in images]
        target = [
            {key: tensor.to(device) for key, tensor in sample.items()}
            for sample in target
        ]

        loss_dict = model(images, target)
        total_loss = sum(loss_dict.values())

        # Record the scalar loss before the backward pass.
        loss_value = total_loss.item()
        train_loss_list.append(loss_value)
        train_loss_hist.send(loss_value)

        total_loss.backward()
        optimizer.step()

        batches.set_description(f"Training. Loss: {loss_value:.4f}")


def validate(val_loader, bbox_criterion, classes_criterion, model, epoch, params):
    """Compute the detection loss over ``val_loader`` without updating weights.

    torchvision detection models return a dict of losses only when called in
    training mode with targets. In eval mode they return a list of per-image
    prediction dicts, which cannot be fed to element-wise criteria. The
    validation loss is therefore taken from the model's own loss dict, with
    the model in training mode and gradients disabled by ``torch.no_grad()``.

    Side effects: each batch loss is appended to the module-level
    ``val_loss_list`` and sent to the module-level ``val_loss_hist``. The
    model is left in eval mode on return.

    Args:
        val_loader: Iterable yielding ``(images, target)`` batches.
        bbox_criterion: Unused; kept so existing callers keep working.
        classes_criterion: Unused; kept so existing callers keep working.
        model: Detection model returning a loss dict for ``(images, target)``
            in training mode.
        epoch: Current epoch number, shown in the progress description.
        params: Settings dict; only ``params['device']`` is read here.
    """
    # NOTE(review): training mode also changes the behaviour of layers such as
    # dropout or non-frozen batch norm — confirm this is acceptable for the
    # model in use.
    model.train()
    stream = tqdm(val_loader)

    try:
        with torch.no_grad():
            for images, target in stream:
                device = params['device']
                images = [image.to(device) for image in images]
                target = [{k: v.to(device) for k, v in t.items()} for t in target]

                loss_dict = model(images, target)
                loss_value = sum(loss_dict.values()).item()

                # Feed the trackers that the epoch summary reads.
                val_loss_list.append(loss_value)
                val_loss_hist.send(loss_value)

                stream.set_description(
                    "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=loss_value)
                )
    finally:
        # Leave the model in inference mode, even if validation is interrupted.
        model.eval()


def train_and_validate(model, train_dataset, val_dataset, params):
    """Train ``model`` for ``params["epochs"]`` epochs, validating after each one.

    Args:
        model: Detection model to train; it is moved to ``params["device"]``.
        train_dataset: Dataset used for the training loader (shuffled).
        val_dataset: Dataset used for the validation loader (not shuffled).
        params: Settings dict with the keys ``"device"``, ``"lr"``,
            ``"batch_size"`` and ``"epochs"``.

    Returns:
        The same ``model`` object after training.

    Side effects: resets the module-level ``train_loss_hist`` and
    ``val_loss_hist`` at the start of each epoch and prints their values at
    the end of it.
    """
    device = params["device"]
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=params["batch_size"],
        shuffle=True,
        pin_memory=pin_memory,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=params["batch_size"],
        shuffle=False,
        pin_memory=pin_memory,
        collate_fn=collate_fn
    )

    # The batches are moved to `device` inside train()/validate(), so the
    # model has to live there too. Move it before the optimizer is created.
    model.to(device)

    bbox_criterion = MSELoss().to(device)
    labels_criterion = CrossEntropyLoss().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])

    for epoch in range(1, params["epochs"] + 1):
        train_loss_hist.reset()
        val_loss_hist.reset()

        train(train_loader, model, optimizer, epoch, params)
        # Keyword arguments keep each object bound to the parameter it is
        # meant for, independent of validate()'s positional order.
        validate(
            val_loader=val_loader,
            bbox_criterion=bbox_criterion,
            classes_criterion=labels_criterion,
            model=model,
            epoch=epoch,
            params=params,
        )

        print(f"Epoch #{epoch} train loss: {train_loss_hist.value:.3f}")
        print(f"Epoch #{epoch} validation loss: {val_loss_hist.value:.3f}")

    return model

In [12]:
# Training settings read by train_and_validate(), train() and validate().
params = dict(
    device="cpu",
    lr=0.003,
    batch_size=16,
    epochs=10,
)

In [13]:
# Per-epoch running means of the training and validation loss.
train_loss_hist, val_loss_hist = Averager(), Averager()
# Per-batch loss values collected across the whole run.
train_loss_list, val_loss_list = [], []

In [22]:
# Build a detector with one output per entry in class_names (background
# included), then run the full train/validate loop. `model` is rebound to the
# model returned by train_and_validate.
model = create_model(len(class_names))
model = train_and_validate(model, train_dataset, val_dataset, params)



Training


  0%|          | 0/4375 [00:00<?, ?it/s]

KeyboardInterrupt: 

Error in callback <bound method AutoreloadMagics.post_execute_hook of <IPython.extensions.autoreload.AutoreloadMagics object at 0x000001B63E8AF950>> (for post_execute):


KeyboardInterrupt: 