In [1]:
import os

# Root folder for the dataset and run artifacts when running locally.
# It is replaced by a Google Drive folder when the notebook runs on Colab.
BASE_PATH = ".DATA"

try:
    from google.colab import drive
except ImportError:
    # Not running on Google Colab: keep the local data root.
    drive = None

if drive is not None:
    # Use an absolute mount point. The previous relative "content/drive" is
    # resolved against the current working directory, and any resulting mount
    # error was silently swallowed by a bare `except`.
    # NOTE(review): confirm /content/drive is the intended mount location.
    DRIVE_PATH = "/content/drive"

    try:
        drive.mount(DRIVE_PATH)
    except Exception as mount_error:
        # Stay best-effort (fall back to the local root) but say so, instead of
        # hiding the failure and crashing later on a missing dataset folder.
        print(f"Google Drive mount failed ({mount_error!r}); falling back to BASE_PATH={BASE_PATH!r}")
    else:
        BASE_PATH = os.path.join(DRIVE_PATH, "MyDrive", "TP_MACHINE_LEARNING")

In [2]:
import csv

import tqdm

import torch
import torch.nn as nn

from torch.utils.data import DataLoader

from torchvision import datasets, models, transforms

from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    precision_score,
    recall_score,
    f1_score
)

In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using: {device}")

Using: cpu


In [4]:
# Values shared by both preprocessing pipelines. The mean/std triplets are the
# commonly used ImageNet channel statistics.
_IMAGE_SIZE = (224, 224)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
_BATCH_SIZE = 16

# Training pipeline: same as evaluation plus a random horizontal flip.
train_tf = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Evaluation pipeline (validation and test): deterministic, no augmentation.
val_tf = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Each split lives in its own sub-folder of <BASE_PATH>/DATASET.
_dataset_root = os.path.join(BASE_PATH, "DATASET")

train_ds = datasets.ImageFolder(os.path.join(_dataset_root, "train"), transform=train_tf)
val_ds = datasets.ImageFolder(os.path.join(_dataset_root, "val"), transform=val_tf)
test_ds = datasets.ImageFolder(os.path.join(_dataset_root, "test"), transform=val_tf)

# Only the training loader shuffles; evaluation loaders keep the dataset order.
train_loader = DataLoader(train_ds, batch_size=_BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=_BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=_BATCH_SIZE)

In [5]:
# Folder that holds this experiment's artifacts (the checkpoint file and the
# CSV logs are written below it by later cells).
RUNS_PATH = os.path.join(os.path.join(BASE_PATH, "RUNS"), "RESNET")

# Create the whole directory chain; an already existing folder is not an error.
os.makedirs(name=RUNS_PATH, exist_ok=True)

In [6]:
# ResNet-18 initialised from the IMAGENET1K_V1 weights, with the final fully
# connected layer replaced so it emits one logit per class of the training set.
model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, len(train_ds.classes))

CHECKPOINT_PATH = os.path.join(RUNS_PATH, "CHECKPOINT.pt")

# Run state used when no checkpoint file exists yet:
#   E      - epoch index at which the training loop stops
#   Ei     - next epoch index to run (resume point)
#   Estag  - consecutive epochs without improvement of the tracked metric
#   training_classification_header / training_backbone - which parts are trainable
#   best_metric - best value of the tracked validation metric so far
CHECKPOINT = dict(E=100, Ei=0, Estag=0, training_classification_header=True, training_backbone=False, best_metric=-torch.inf)

if os.path.exists(CHECKPOINT_PATH):
    # Load onto the CPU regardless of where the checkpoint was written, so a run
    # started on a GPU machine can be resumed on a CPU-only one (and vice versa).
    # Tensors are copied to the right device by the load_state_dict calls below.
    CHECKPOINT = torch.load(CHECKPOINT_PATH, map_location="cpu")

if "W" in CHECKPOINT:
    model.load_state_dict(CHECKPOINT["W"], strict=False)

# Move the model BEFORE the optimizer is built and its state restored:
# Optimizer.load_state_dict places the restored state on the device the
# parameters are on at that moment. Moving the model afterwards would leave the
# Adam buffers on the CPU while the parameters and gradients are on the GPU.
model.to(device)

# Freeze everything according to the backbone flag, then let the head flag
# override the setting for the classifier layer.
for param in model.parameters():
    param.requires_grad = CHECKPOINT["training_backbone"]

for param in model.fc.parameters():
    param.requires_grad = CHECKPOINT["training_classification_header"]

# Only hand the optimizer the parameter groups that are actually being trained.
params_to_optimize = list()

if CHECKPOINT["training_backbone"]:
    params_to_optimize += [ p for n, p in model.named_parameters() if n not in [ "fc.weight", "fc.bias" ] ]

if CHECKPOINT["training_classification_header"]:
    params_to_optimize += list(model.fc.parameters())

optimizer = torch.optim.Adam(params_to_optimize, lr=1e-3)

if "optimizer" in CHECKPOINT:
    optimizer.load_state_dict(CHECKPOINT["optimizer"])

# Keep the model summary as the cell's displayed output.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
# Number of training samples per class index.
labels = torch.tensor(train_ds.targets)
class_counts = torch.bincount(labels)

# Inverse-frequency weights, rescaled so they sum to 1: classes with fewer
# training samples receive a larger weight in the loss.
inverse_frequency = 1.0 / class_counts
class_weights = (inverse_frequency / inverse_frequency.sum()).to(device)

# Class-weighted cross-entropy used for training and evaluation alike.
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [8]:
def train(model: nn.Module, optimizer: torch.optim.Optimizer, criterion: nn.Module, loader: DataLoader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The model is switched to training mode. For every batch the inputs are
    moved to the module-level ``device``, gradients are reset, the loss is
    back-propagated and one optimizer step is taken.

    Args:
        model: Network to optimise.
        optimizer: Optimizer that updates the model's parameters.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss tensor.
        loader: Iterable of ``(imgs, labels)`` batches that also supports ``len()``.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()

    running_loss = 0.0

    for batch_imgs, batch_labels in loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        # Gradients accumulate by default, so clear them before each batch.
        optimizer.zero_grad()

        batch_loss: torch.Tensor = criterion(model(batch_imgs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)

In [9]:
def eval(model: nn.Module, loader: torch.utils.data.DataLoader, criterion: nn.Module):
    """Evaluate ``model`` on ``loader`` and return a dictionary of metrics.

    The model is switched to evaluation mode and gradients are disabled.
    Predictions are the arg-max over dimension 1 of the model outputs. The
    reported loss is the per-batch loss weighted by batch size and divided by
    the total number of evaluated samples.

    Args:
        model: Network to evaluate.
        loader: Iterable of ``(imgs, labels)`` batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss tensor.

    Returns:
        Dict with the keys ``accuracy``, ``balanced_accuracy``,
        ``precision_macro``, ``recall_macro``, ``f1_macro``, ``f1_weighted``
        and ``loss``.
    """
    model.eval()

    predicted, expected = [], []
    weighted_loss_sum = 0.0

    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_imgs)

            # Weight each batch loss by its size so a smaller batch does not
            # count as much as a full one.
            batch_loss = criterion(logits, batch_labels).item()
            weighted_loss_sum += batch_loss * batch_labels.size(0)

            predicted.extend(logits.argmax(1).cpu().tolist())
            expected.extend(batch_labels.cpu().tolist())

    mean_loss = weighted_loss_sum / len(expected)

    report = {}
    report["accuracy"] = accuracy_score(expected, predicted)
    report["balanced_accuracy"] = balanced_accuracy_score(expected, predicted)
    report["precision_macro"] = precision_score(expected, predicted, average="macro", zero_division=0)
    report["recall_macro"] = recall_score(expected, predicted, average="macro", zero_division=0)
    report["f1_macro"] = f1_score(expected, predicted, average="macro", zero_division=0)
    report["f1_weighted"] = f1_score(expected, predicted, average="weighted", zero_division=0)
    report["loss"] = mean_loss
    return report


In [None]:
train_log_path = os.path.join(RUNS_PATH, "train_log.csv")
val_log_path = os.path.join(RUNS_PATH, "val_log.csv")

# Metric columns of val_log.csv, in file order. Each one is logged twice: once
# for the validation split and once, prefixed with "test_", for the test split.
LOGGED_METRICS = [
    "loss",
    "accuracy",
    "balanced_accuracy",
    "precision_macro",
    "recall_macro",
    "f1_macro",
    "f1_weighted",
]

# Consecutive epochs without a better validation balanced accuracy that are
# tolerated before training stops early.
EARLY_STOPPING_PATIENCE = 10

# Write the CSV headers only once so that a resumed run appends to the same logs.
if not os.path.exists(train_log_path):
    with open(train_log_path, "w", newline="") as f:
        csv.writer(f).writerow([ "epoch", "loss" ])

if not os.path.exists(val_log_path):
    with open(val_log_path, "w", newline="") as f:
        csv.writer(f).writerow(
            ["epoch"] + LOGGED_METRICS + [f"test_{name}" for name in LOGGED_METRICS]
        )

for epoch in tqdm.tqdm(range(CHECKPOINT["Ei"], CHECKPOINT["E"]), leave=False):
    # A finalized checkpoint (see the end of this cell) no longer has "Estag".
    # If it records an early stop, keep the run stopped instead of raising
    # KeyError or silently training further.
    if "Eearly" in CHECKPOINT:
        break

    if CHECKPOINT.get("Estag", 0) >= EARLY_STOPPING_PATIENCE:
        CHECKPOINT["Eearly"] = epoch
        break

    train_loss = train(model=model, optimizer=optimizer, loader=train_loader, criterion=criterion)
    metrics = eval(model=model, loader=val_loader, criterion=criterion)
    test_metrics = eval(model=model, loader=test_loader, criterion=criterion)

    with open(train_log_path, "a", newline="") as f_train, open(val_log_path, "a", newline="") as f_val:
        csv.writer(f_train).writerow([epoch + 1, train_loss])
        csv.writer(f_val).writerow(
            [epoch + 1]
            + [metrics[name] for name in LOGGED_METRICS]
            + [test_metrics[name] for name in LOGGED_METRICS]
        )
        # Force the rows to disk so the logs survive an abrupt session end.
        f_train.flush(); f_val.flush()
        os.fsync(f_train.fileno()); os.fsync(f_val.fileno())

    # Store a plain Python float: the metric may be a NumPy scalar, which the
    # restricted (weights-only) unpickler of torch.load can reject on reload.
    val_score = float(metrics["balanced_accuracy"])

    if val_score > CHECKPOINT["best_metric"]:
        CHECKPOINT["best_metric"] = val_score
        # state_dict() returns references to the live parameter tensors, so the
        # snapshot must be cloned. Otherwise "best_weights" keeps changing with
        # every later optimizer step and ends up equal to the last weights.
        CHECKPOINT["best_weights"] = {
            key: value.detach().clone() for key, value in model.state_dict().items()
        }

        tqdm.tqdm.write(f"Best model updated at epoch {epoch + 1} -> balanced_acc = {CHECKPOINT['best_metric']:.4f}")

        CHECKPOINT["Estag"] = 0
    else:
        CHECKPOINT["Estag"] = CHECKPOINT.get("Estag", 0) + 1

    # Persist the full resume state after every epoch.
    CHECKPOINT["Ei"] = epoch + 1
    CHECKPOINT["W"] = model.state_dict()
    CHECKPOINT["optimizer"] = optimizer.state_dict()

    torch.save(CHECKPOINT, CHECKPOINT_PATH)

# Finalize: drop the state that is only needed to resume training. The keys may
# already be absent when this cell is re-run on a finalized checkpoint.
CHECKPOINT.pop("optimizer", None)
CHECKPOINT.pop("Estag", None)

torch.save(CHECKPOINT, CHECKPOINT_PATH)

                                                 