In [1]:
import warnings
from datetime import datetime
from pathlib import Path
from enum import Enum, auto

import cv2
import numpy as np
import PIL
import torch
import torchvision.models as models
from IPython import get_ipython
from IPython.display import display
from torch import nn, optim
from torch.nn.functional import mse_loss
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from src.dataset import JetBotDataset

warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True


In [2]:
if "google.colab" in str(get_ipython()):
    from google.colab.patches import cv2_imshow

    imshow = cv2_imshow
else:

    def imshow(a):
        """
        img= img.clip(0, 255).astype('uint8')
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.axis('off')
        """
        a = a.clip(0, 255).astype("uint8")
        if a.ndim == 3:
            if a.shape[2] == 4:
                a = cv2.cvtColor(a, cv2.COLOR_BGRA2RGBA)
            else:
                a = cv2.cvtColor(a, cv2.COLOR_BGR2RGB)
        display(PIL.Image.fromarray(a))


In [3]:
train_dataset = JetBotDataset("dataset/augmented", use_next=True)
test_dataset = JetBotDataset("dataset/augmented", split_type="test", use_next=True)


In [4]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True, num_workers=2)


In [5]:
def train(
    model: nn.Module,
    train_dataloader: DataLoader,
    test_dataloader: DataLoader,
    optimizer: optim.Optimizer,
    no_epochs: int,
    model_save_dir: str,
    save_archive: bool = False,
):
    best_loss = np.inf

    # creating unique timestamp for the run
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    model_root_path = Path(model_save_dir)
    model_root_path.mkdir(parents=True, exist_ok=True)

    model_best_path = model_root_path / timestamp / "best"
    model_best_path.mkdir(parents=True, exist_ok=True)
    model_best_primitive_log_path = model_best_path / "primitive_log.txt"

    model_archive_path = model_root_path / timestamp / "archive"
    model_archive_path.mkdir(parents=True, exist_ok=True)
    model_archive_primitive_log_path = model_archive_path / "primitive_log.txt"

    fp_best_primitive_log = open(model_best_primitive_log_path, "w")
    fp_archive_primitive_log = open(model_archive_primitive_log_path, "w")

    with open(model_archive_primitive_log_path, "w") as fp_archive_primitive_log:

        for epoch in tqdm(range(no_epochs)):

            model.train()
            train_loss = 0.0
            for images, labels in iter(train_dataloader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = mse_loss(outputs, labels)
                train_loss += loss
                loss.backward()
                optimizer.step()

            train_loss /= len(train_dataloader)

            model.eval()

            test_loss = 0.0
            for images, labels in iter(test_dataloader):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)

                loss = mse_loss(outputs, labels)
                test_loss += float(loss)
            test_loss /= len(test_dataloader)

            print(f"Epoch: {epoch} | Train loss: {train_loss} | Test loss: {test_loss}")

            if save_archive:
                epoch_save = f"epoch-{epoch}.pt"
                torch.save(model.state_dict(), model_archive_path / epoch_save)
                fp_archive_primitive_log.write(
                    f"Epoch: {epoch} | Train loss: {train_loss} | Test loss: {test_loss} \n"
                )

            if test_loss < best_loss:
                best_save = "best.pt"
                torch.save(model.state_dict(), model_best_path / best_save)

                with open(model_best_primitive_log_path, "w") as fp_best_primitive_log:
                    fp_best_primitive_log.write(
                        f"Epoch: {epoch} | Train loss: {train_loss} | Test loss: {test_loss}"
                    )
                best_loss = test_loss


In [6]:
class Models(Enum):
    SQUEEZENET_1_1 = auto()
    MOBILENETV3_SMALL = auto()
    MOBILENETV3_LARGE = auto()
    RESNET_18 = auto()


# choose what to learn
run_models = [Models.SQUEEZENET_1_1, Models.MOBILENETV3_SMALL]


### SqueezeNet_1.1


In [7]:
if Models.SQUEEZENET_1_1 in run_models:
    squeezenet1_1 = models.squeezenet1_1(pretrained=True)
    squeezenet1_1.classifier[1] = nn.Conv2d(512, 2, kernel_size=(1, 1), stride=(1, 1))
    squeezenet1_1.num_classes = 2

    squeezenet1_1 = squeezenet1_1.to(device)

    squeezenet1_1_optimizer = optim.Adam(squeezenet1_1.parameters())

    train(
        squeezenet1_1,
        train_dataloader,
        test_dataloader,
        squeezenet1_1_optimizer,
        10,
        "models/SqueezeNet1_1",
    )


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 0 | Train loss: 0.20041827857494354 | Test loss: 0.11037209095514339
Epoch: 1 | Train loss: 0.15903347730636597 | Test loss: 0.11999999828960585
Epoch: 2 | Train loss: 0.1511111855506897 | Test loss: 0.10952231188750138
Epoch: 3 | Train loss: 0.14466983079910278 | Test loss: 0.1167365063143813
Epoch: 4 | Train loss: 0.13906697928905487 | Test loss: 0.1220374249893686
Epoch: 5 | Train loss: 0.1350688636302948 | Test loss: 0.11807350514699584
Epoch: 6 | Train loss: 0.13204923272132874 | Test loss: 0.13024793953999228
Epoch: 7 | Train loss: 0.12985582649707794 | Test loss: 0.12180828027751135
Epoch: 8 | Train loss: 0.12908801436424255 | Test loss: 0.11649295938727648
Epoch: 9 | Train loss: 0.1288008689880371 | Test loss: 0.1306081786751747


### MobileNetV3_small


In [8]:
if Models.MOBILENETV3_SMALL in run_models:
    mobilenetv3_small = models.mobilenet_v3_small(pretrained=True)
    mobilenetv3_small.classifier[3] = nn.Linear(
        in_features=1024, out_features=2, bias=True
    )
    mobilenetv3_small.num_classes = 2

    mobilenetv3_small = mobilenetv3_small.to(device)

    mobilenetv3_small_optimizer = optim.Adam(mobilenetv3_small.parameters())

    train(
        mobilenetv3_small,
        train_dataloader,
        test_dataloader,
        mobilenetv3_small_optimizer,
        no_epochs=10,
        model_save_dir="models/MobileNetV3_small",
        save_archive=False,
    )


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 0 | Train loss: 0.06602617353200912 | Test loss: 0.14281956462756448
Epoch: 1 | Train loss: 0.029850967228412628 | Test loss: 0.11480542466692302
Epoch: 2 | Train loss: 0.01799316145479679 | Test loss: 0.1213773884203123
Epoch: 3 | Train loss: 0.01392361056059599 | Test loss: 0.1093647695429947
Epoch: 4 | Train loss: 0.01171116717159748 | Test loss: 0.11567393269227899
Epoch: 5 | Train loss: 0.01003354787826538 | Test loss: 0.12134120406825906
Epoch: 6 | Train loss: 0.008867496624588966 | Test loss: 0.12490960955619812
Epoch: 7 | Train loss: 0.007495146244764328 | Test loss: 0.1344660245206045
Epoch: 8 | Train loss: 0.007638414856046438 | Test loss: 0.12575296895659488
Epoch: 9 | Train loss: 0.006347826682031155 | Test loss: 0.12435894313713779


### MobileNetV3_large


In [9]:
if Models.MOBILENETV3_LARGE in run_models:
    mobilenetv3_large = models.mobilenet_v3_large(pretrained=True)
    mobilenetv3_large.classifier[3] = nn.Linear(
        in_features=1280, out_features=2, bias=True
    )
    mobilenetv3_large = mobilenetv3_large.to(device)
    mobilenetv3_large.num_classes = 2

    mobilenetv3_large_optimizer = optim.AdamW(mobilenetv3_large.parameters())

    train(
        mobilenetv3_large,
        train_dataloader,
        test_dataloader,
        mobilenetv3_large_optimizer,
        no_epochs=10,
        model_save_dir="models/MobileNetV3_large",
        save_archive=False,
    )


### ResNet_18


In [10]:
if Models.RESNET_18 in run_models:
    resnet18 = models.resnet18(pretrained=True)
    resnet18.fc = nn.Linear(in_features=512, out_features=1000, bias=True)
    resnet18.num_classes = 2

    resnet18 = resnet18.to(device)

    resnet18_optimizer = optim.Adam(resnet18.parameters())

    train(
        resnet18,
        train_dataloader,
        test_dataloader,
        resnet18_optimizer,
        no_epochs=10,
        model_save_dir="models/ResNet18",
        save_archive=False,
    )
