Edit the first code cell below to customize your environment (paths, datasets and training hyperparameters).

In [1]:
# --- Runtime environment -----------------------------------------------------
use_colab = True        # True when running on Google Colab
resume_training = True  # True to continue training from an existing checkpoint; safe to leave on
                        # for a first run, because a missing checkpoint falls back to training from scratch

# --- Paths ----------------------------------------------------------------------
# NOTE(review): ckpt_dir uses "MyDrive" while utils_dir uses "My Drive" — both appear to work
# on this Colab setup, but consider using one spelling consistently.
data_dir = "/content/data"  # directory holding the input data
ckpt_dir = "/content/drive/MyDrive/Colab Notebooks/Dissertation Results/Checkpoints/"  # where checkpoints are saved
utils_dir = "/content/drive/My Drive/Colab Notebooks/dissertation_resources"  # directory with the utilities (`utils` in the zip)

# --- Data -----------------------------------------------------------------------
dataset_names = ['brats']  # dataset names, ordered [source, target1, target2]
train_imgs_no = 20_000     # training images per dataset; smaller datasets use fewer distinct images,
                           # which the dataloader cycles through to reach this count
val_imgs_no = train_imgs_no // 20  # validation (or test) images per dataset
loader_img_dim = 256       # input image size for the dataloader (images are resized automatically)

# --- Training -------------------------------------------------------------------
batch_size = 1   # dataloader batch size
num_epochs = 15  # number of training epochs

In [2]:
if use_colab:
    import subprocess
    import sys

    from google.colab import drive

    # Mount Google Drive so the checkpoint and utils directories are reachable.
    drive.mount('/content/drive')

    # Install gdown into the kernel's own interpreter (the pure-Python equivalent of
    # `%pip install`); check=True surfaces an install failure instead of ignoring it.
    subprocess.run([sys.executable, "-m", "pip", "install", "--quiet", "gdown"], check=True)

    # Make the project's `utils` package importable. The guard keeps sys.path free of
    # duplicate entries when this cell is re-run.
    if utils_dir not in sys.path:
        sys.path.append(utils_dir)

Mounted at /content/drive


In [3]:
import os
import torch
import time
import numpy as np
from utils.data_processing_utils import extract_and_store_data_from_datasets, get_model
from utils.DomainDataset import DomainDataset, collate_fn
from utils.UDAFasterRCNN import UDAFasterRCNN
from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

Using device: cuda


In [4]:
# Prepare the raw data for the datasets in `dataset_names` under `data_dir`.
# Implemented in utils.data_processing_utils; judging by the cell output it unzips
# dataset archives into data_dir and removes the zip files — confirm there.
extract_and_store_data_from_datasets(data_dir, dataset_names)

Unzipped bmshare.zip successfully to /content/data
Removed bmshare.zip successfully.


In [5]:
# Pre-split source-domain directories (dataset_names[0] is the source dataset).
source_train_dir = os.path.join(data_dir, dataset_names[0], "splitted_data", "train")
source_val_dir = os.path.join(data_dir, dataset_names[0], "splitted_data", "val")

# Integer domain id for every known dataset; only the source id is used here.
domain_labels = {'brats': 0, 'bmshare': 1, 'isles': 2}
if dataset_names[0] not in domain_labels:
    raise KeyError(
        f"Unknown source dataset {dataset_names[0]!r}; expected one of {sorted(domain_labels)}"
    )
source_domain_label = domain_labels[dataset_names[0]]

# Shared DataLoader settings. Pinned (page-locked) host memory only speeds up
# host-to-GPU copies, so it is enabled only when training on CUDA.
LOADER_NUM_WORKERS = 2
PIN_MEMORY = device == "cuda"

# Datasets. `length` is the per-epoch sample budget from the config cell; per the config
# comments, DomainDataset cycles smaller datasets to reach it (presumably — see utils).
source_train_dataset = DomainDataset(source_train_dir, domain_label=source_domain_label,
                                     img_size=loader_img_dim, img_extension=".png", length=train_imgs_no)
source_val_dataset = DomainDataset(source_val_dir, domain_label=source_domain_label,
                                   img_size=loader_img_dim, img_extension=".png", length=val_imgs_no)

# Loaders: shuffle only the training split; collate_fn keeps variable-size targets as lists.
source_train_loader = DataLoader(source_train_dataset, batch_size=batch_size, shuffle=True,
                                 num_workers=LOADER_NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)
source_val_loader = DataLoader(source_val_dataset, batch_size=batch_size, shuffle=False,
                               num_workers=LOADER_NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)

In [6]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Only the first entry of ``optimizer.param_groups`` is consulted. If the
    optimizer has no parameter groups at all, ``None`` is returned.
    """
    groups = iter(optimizer.param_groups)
    try:
        first_group = next(groups)
    except StopIteration:
        return None
    return first_group['lr']

In [7]:
def _batch_to_device(batch, device):
    """Unpack an ``(images, targets)`` batch and move every tensor onto ``device``.

    Assumes ``images`` is a sequence of tensors and ``targets`` a sequence of
    dicts of tensors, as yielded through ``collate_fn`` — TODO confirm in utils.
    """
    images, targets = batch
    images = [img.to(device) for img in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    return images, targets


def _save_checkpoint(path, state):
    """``torch.save`` ``state`` to ``path`` through a temporary file and ``os.replace``.

    An interrupted write (e.g. a Colab runtime disconnect) then leaves the
    previous checkpoint at ``path`` intact instead of a truncated file.
    """
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def _run_train_epoch(loader, model, optimizer, device, max_grad_norm=1.0):
    """Run one optimisation pass over ``loader`` and return the mean per-batch total loss.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0
    num_steps = 0
    for batch in tqdm(loader):
        images, targets = _batch_to_device(batch, device)
        optimizer.zero_grad()
        # The model returns a mapping of named loss tensors; optimise their sum.
        losses = model(images, targets)
        total_loss = sum(losses.values())
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        loss_sum += total_loss.item()
        num_steps += 1
    if num_steps == 0:
        raise ValueError("Training loader yielded no batches")
    return loss_sum / num_steps


def _run_val_epoch(loader, model, device):
    """Return the mean per-batch total loss over ``loader``, computed without gradients.

    NOTE(review): the model is intentionally NOT switched to eval mode (the
    original ``model.eval()`` call was commented out). If ``get_model`` returns a
    torchvision-style detector, it only returns the loss dict in train mode (eval
    mode returns detections), which is presumably why. Dropout or non-frozen
    BatchNorm layers, if any, therefore behave as in training here (BatchNorm
    running stats would even be updated) — confirm this is acceptable.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    loss_sum = 0.0
    num_steps = 0
    with torch.no_grad():
        for batch in tqdm(loader):
            images, targets = _batch_to_device(batch, device)
            losses = model(images, targets)
            loss_sum += sum(losses.values()).item()
            num_steps += 1
    if num_steps == 0:
        raise ValueError("Validation loader yielded no batches")
    return loss_sum / num_steps


def train(source_train_loader, source_val_loader,
             model, optimizer, scheduler, max_epochs, root_dir,
             start_epoch=1, resume_training=False, ckpt_tag=None):
    """Train ``model`` on the source domain, validating and checkpointing every epoch.

    Each epoch runs a training pass over ``source_train_loader`` and a loss-only
    pass over ``source_val_loader``. The latest state is written to
    ``New_Checkpoint_Simple_<tag>.pth`` in ``root_dir``. Whenever the validation
    loss improves, it is also written to ``New_Best_Checkpoint_Simple_<tag>.pth``.

    Args:
        source_train_loader: DataLoader yielding ``(images, targets)`` training batches.
        source_val_loader: DataLoader yielding ``(images, targets)`` validation batches.
        model: model already placed on its target device. It is called as
            ``model(images, targets)`` and must return a dict of loss tensors.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional scheduler stepped with the validation loss
            (e.g. ``ReduceLROnPlateau``); ``None`` disables scheduling.
        max_epochs: last epoch to run (inclusive, 1-based).
        root_dir: checkpoint directory; created if it does not exist.
        start_epoch: first epoch when not resuming from a checkpoint.
        resume_training: if True and a checkpoint exists, restore the model,
            optimizer, scheduler (when saved), history and epoch counter from it.
        ckpt_tag: suffix used in the checkpoint file names. It defaults to the
            global ``dataset_names[0]``, matching existing checkpoint names.

    Returns:
        dict: per-epoch history with keys ``train_total_loss``, ``val_total_loss``,
        ``lr``, ``train_epoch_time`` and ``val_epoch_time``.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if ckpt_tag is None:
        ckpt_tag = dataset_names[0]
    os.makedirs(root_dir, exist_ok=True)
    ckpt_path = os.path.join(root_dir, f"New_Checkpoint_Simple_{ckpt_tag}.pth")
    best_ckpt_path = os.path.join(root_dir, f"New_Best_Checkpoint_Simple_{ckpt_tag}.pth")

    history = {
        'train_total_loss': [],
        'val_total_loss': [],
        'lr': [],
        'train_epoch_time': [], 'val_epoch_time': [],
    }
    best_val_loss = float('inf')

    print('resume_training: ', resume_training)
    if resume_training and os.path.exists(ckpt_path):
        print("Loading checkpoint…")
        # weights_only=False unpickles arbitrary objects: only load checkpoints you wrote yourself.
        checkpoint = torch.load(ckpt_path, weights_only=False, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Older checkpoints carry no scheduler state; the scheduler then starts fresh.
        if scheduler is not None and checkpoint.get('scheduler_state_dict') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Merge so keys missing from older checkpoints still exist as empty lists.
        history.update(checkpoint.get('history', {}))
        start_epoch = checkpoint['epoch'] + 1
        best_val_loss = checkpoint['best_val_loss']
        print(f"Resuming training from epoch {start_epoch} with best_val_loss = {best_val_loss:.4f}")
    else:
        print("Starting training from scratch")

    # Batches are moved to wherever the model's parameters live.
    device = next(model.parameters()).device

    for epoch in range(start_epoch, max_epochs + 1):
        print(f"\nEpoch {epoch}")

        print("Train:", end="", flush=True)
        epoch_train_start = time.time()
        train_total_loss = _run_train_epoch(source_train_loader, model, optimizer, device)
        elapsed1 = time.time() - epoch_train_start
        print(f"\nEpoch {epoch} training time: {int(elapsed1//60)}m {int(elapsed1%60)}s")

        print("Val:", end="", flush=True)
        epoch_val_start = time.time()
        val_total_loss = _run_val_epoch(source_val_loader, model, device)
        elapsed2 = time.time() - epoch_val_start
        print(f"\nEpoch {epoch} validation time: {int(elapsed2//60)}m {int(elapsed2%60)}s")

        print(f"\n[Epoch {epoch}]")
        print(f"Train Loss: {train_total_loss:.4f}")
        print(f" Val  Loss: {val_total_loss:.4f}")

        history['train_total_loss'].append(train_total_loss)
        history['val_total_loss'].append(val_total_loss)
        history['lr'].append(optimizer.param_groups[0]['lr'])  # lr used during this epoch
        history['train_epoch_time'].append(elapsed1)
        history['val_epoch_time'].append(elapsed2)

        # Step the scheduler BEFORE checkpointing, so a resumed run continues with the
        # already-reduced learning rate and the scheduler's plateau counters.
        if scheduler is not None:
            old_lr = get_lr(optimizer)
            scheduler.step(val_total_loss)
            new_lr = get_lr(optimizer)
            if new_lr != old_lr:
                print(f"Learning rate changed from {old_lr:.6f} to {new_lr:.6f}")

        # Update the best loss before saving, so both checkpoints record the new value
        # (previously the best checkpoint stored the stale, previous best).
        is_best = val_total_loss < best_val_loss
        if is_best:
            best_val_loss = val_total_loss

        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
            'best_val_loss': best_val_loss,
            'history': history,
        }
        if is_best:
            print("Saving best model…")
            _save_checkpoint(best_ckpt_path, state)
        _save_checkpoint(ckpt_path, state)

    return history


In [None]:
# One class per dataset plus background (2 with the current dataset_names).
# NOTE(review): NUM_CLASSES is not used in this cell; the model below is built with a
# literal num_classes=2, which matches NUM_CLASSES only while dataset_names has one entry.
# TODO decide which is intended.
NUM_CLASSES = len(dataset_names) + 1

# Detection model from utils.data_processing_utils, moved to the selected device.
model = get_model(num_classes=2)
model.to(device)

# SGD with momentum and weight decay.
sgd_config = {'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 1e-4}
optimizer = SGD(model.parameters(), **sgd_config)

# Halve the learning rate once the monitored value (the validation loss passed to
# scheduler.step) has failed to decrease for more than 5 consecutive epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Train on the source domain, writing checkpoints into ckpt_dir.
train(
    source_train_loader=source_train_loader,
    source_val_loader=source_val_loader,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    max_epochs=num_epochs,
    root_dir=ckpt_dir,
    resume_training=resume_training,
)

resume_training:  False
Starting training from scratch

Epoch 1
Train:

  1%|          | 126/20000 [00:40<1:40:44,  3.29it/s]