# ResNet Example

The model definition and training code are adapted from https://github.com/kuangliu/pytorch-cifar.

In [None]:
!(pip show torch-summary >& /dev/null || pip install --quiet torch-summary)

In [None]:
!(test -d .git || test -d mp1 || git clone https://github.com/mike10004/csgy6953-mp1.git mp1)

In [None]:
# change to correct branch here
!(test -d mp1 && cd mp1 && git switch dropout-1)
!(test -d mp1 && cd mp1 && git pull && git rev-parse --short HEAD)

In [None]:
!(test -d mp1 && pip install --quiet --editable mp1)
import site
site.main()

In [None]:
import dlmp1
import importlib
# Re-execute the package in case it was already imported before the
# install/pull cells ran (presumably to avoid restarting the kernel).
importlib.reload(dlmp1)
from pathlib import Path
# Sanity check: show which copy of the package was actually imported.
print("checked importable:", dlmp1, "at", Path(dlmp1.__file__).parent)

In [None]:
import os
import sys
import shutil
from typing import Optional

# Checkpoint directory inside Google Drive, relative to the Drive mount point.
# Set to an empty string to disable saving.
# Note that the first path component, MyDrive, is required.
GDRIVE_SAVE_DIR = "MyDrive/CS-GY 6953 DL/deep learning midterm project/checkpoints"

def prepare_mount() -> Optional[str]:
    """Mount Google Drive and return the local checkpoint directory, if possible.

    Returns:
        The path below the mount point that corresponds to ``GDRIVE_SAVE_DIR``,
        or ``None`` when ``GDRIVE_SAVE_DIR`` is empty or the ``google.colab``
        package cannot be imported.
    """
    mount_point = "/content/gdrive"
    checkpoint_dir = str(os.path.join(mount_point, GDRIVE_SAVE_DIR))
    if not GDRIVE_SAVE_DIR:
        # Saving has been disabled by configuration.
        return None
    try:
        # Only importable inside a Colab runtime.
        from google.colab import drive
        drive.mount(mount_point)
    except ImportError:
        print("(not saving because not in colab environment)")
        return None
    return checkpoint_dir

# Resolved once when this cell runs; None means checkpoints are not copied to Drive.
LOCAL_GDRIVE_SAVE_PATH = prepare_mount()

def upload_file(src_file: Path, dst_path: str, suppress_error: bool = False) -> Optional[str]:
    """Copy a local file into the mounted save directory.

    Args:
        src_file: File to copy.
        dst_path: Destination path relative to ``LOCAL_GDRIVE_SAVE_PATH``;
            missing parent directories are created.
        suppress_error: When true, a failure is reported on stderr and ``None``
            is returned instead of raising.

    Returns:
        The destination path as a string, or ``None`` when the copy did not
        happen and ``suppress_error`` is set.

    Raises:
        RuntimeError: If no save directory is available and ``suppress_error``
            is false.
        Exception: Whatever directory creation or the copy raised, when
            ``suppress_error`` is false.
    """
    if not LOCAL_GDRIVE_SAVE_PATH:
        # Without this check Path(None) raises TypeError before the try block,
        # which ignores suppress_error and aborts best-effort callers.
        message = f"no save directory is available for {src_file} -> {dst_path}"
        if suppress_error:
            print(f"suppressing save error: {message}", file=sys.stderr)
            return None
        raise RuntimeError(message)
    dst_file = Path(LOCAL_GDRIVE_SAVE_PATH) / dst_path
    try:
        dst_file.parent.mkdir(exist_ok=True, parents=True)
        shutil.copyfile(src_file, dst_file)
        return str(dst_file)
    except Exception as e:
        if suppress_error:
            print(f"suppressing save error {type(e)} {e} on file {src_file} -> {dst_file}", file=sys.stderr)
            return None
        raise

def upload_checkpoint(checkpoint_file: Path, infix: str) -> Optional[str]:
    """Copy a checkpoint to the save directory under a name that includes ``infix``.

    The destination file name is ``<stem>-<infix><suffix>``.

    Returns:
        The value returned by ``upload_file``, or ``None`` when no save
        directory is configured.
    """
    if not LOCAL_GDRIVE_SAVE_PATH:
        return None
    stem, suffix = checkpoint_file.stem, checkpoint_file.suffix
    return upload_file(checkpoint_file, f"{stem}-{infix}{suffix}")


In [None]:
from dlmp1.models.resnet import CustomResNetWithDropout
from dlmp1.models.resnet import Hyperparametry
from dlmp1.models.resnet import BlockSpec
# noinspection PyPackageRequirements
import torchsummary

def create_model():
    """Build the dropout-enabled ResNet trained by this notebook.

    The same dropout rate (0.5) is used before, between and after the blocks.
    """
    dropout_rate = 0.5
    hyperparametry = Hyperparametry(
        pre_blocks_dropout_rate=dropout_rate,
        post_blocks_dropout_rate=dropout_rate,
        between_blocks_dropout_rate=dropout_rate,
    )
    # NOTE(review): BlockSpec's positional arguments are presumably
    # (block count, channels) - confirm in dlmp1.models.resnet.
    block_specs = [
        BlockSpec(2, 64, stride=1),
        BlockSpec(5, 128, stride=2),
        BlockSpec(3, 256, stride=2),
    ]
    return CustomResNetWithDropout(block_specs, hyperparametry=hyperparametry)

def summarize_model():
    """Print the class name and trainable-parameter count of a freshly built model."""
    model = create_model()
    # verbose=0 keeps torchsummary from printing its own layer table.
    trainable = torchsummary.summary(model, verbose=0).trainable_params
    millions = trainable / 1_000_000
    print(type(model).__name__, f"{millions:.1f}m trainable parameters ({trainable})")
    # The model was built only for counting; drop the reference right away.
    del model


summarize_model()

In [None]:
from dlmp1.train import Partitioning
# Batch size handed to Partitioning.prepare.
BATCH_SIZE_TRAIN = 128
# Fixed seed; presumably makes the data partitioning reproducible (confirm in dlmp1.train).
DATASET = Partitioning.prepare(BATCH_SIZE_TRAIN, random_seed=12345)


In [None]:
%matplotlib inline
import numpy as np
from matplotlib.figure import Figure
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from dlmp1.train import History


def plot_epochs_curves(train_hist: History, val_hist: History, title: Optional[str] = None):
    """Plot training and validation loss and accuracy side by side.

    The left panel shows the losses as recorded; the right panel shows the
    accuracies multiplied by 100 on a fixed 0-100 axis.

    Args:
        train_hist: History whose ``losses`` and ``accs`` are plotted as "Train".
        val_hist: History whose ``losses`` and ``accs`` are plotted as "Validation".
        title: Optional figure title; omitted when empty or ``None``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    fig: Figure
    if title:
        fig.suptitle(title)
    # One row per panel:
    # (train values, validation values, scale factor, y limits, subject, y label)
    panels = [
        (train_hist.losses, val_hist.losses, 1.0, None, "Loss", "Cross-Entropy Loss"),
        (train_hist.accs, val_hist.accs, 100.0, (0.0, 100.0), "Accuracy", "Correct (%)"),
    ]
    for ax, (train_values, val_values, factor, y_bounds, subject, y_label) in zip(axes, panels):
        ax: Axes
        ax.set_title(subject)
        ax.set_xlabel("Epochs")
        ax.set_ylabel(y_label)
        for series_name, values in (("Train", train_values), ("Validation", val_values)):
            scaled = np.array(values) * factor
            # Each series gets an x-axis of its own length: a shared axis sized
            # to the longer history makes ax.plot raise when the lengths differ.
            ax.plot(np.arange(len(scaled)), scaled, label=f"{series_name} {subject}")
        ax.legend()
        if y_bounds is not None:
            ax.set_ylim(*y_bounds)
    plt.show()

In [0]:
from typing import Any
from tqdm import tqdm
import tabulate

import dlmp1.train
import dlmp1.utils
from dlmp1.train import TrainConfig
from dlmp1.train import EpochInference
from dlmp1.select import iterate_model_factories
from dlmp1.select import iterate_selectables

# Set to True to run the model-selection sweep defined in select_model().
DO_SELECT_MODEL = False

class TrainingManager:
    """Per-epoch training callback that reports progress and detects saturation.

    When a progress bar is supplied it is advanced once per epoch and its
    description shows the best validation accuracy seen so far. Training is
    flagged for stopping once the training accuracy reaches
    ``saturation_threshold``.
    """

    def __init__(self, progress_bar: Optional[tqdm] = None, progress_desc_prefix: Optional[str] = None):
        # Training accuracy (as a fraction) at or above which training is considered saturated.
        self.saturation_threshold = 0.995
        self.progress_bar = progress_bar
        self.progress_desc_prefix = progress_desc_prefix if progress_desc_prefix else ""
        # Best validation accuracy seen so far; only tracked when a progress bar is present.
        self.max_val_acc = None

    def _report_progress(self, val_inf: EpochInference) -> None:
        """Advance the progress bar and refresh its description on a new best accuracy."""
        bar = self.progress_bar
        bar.update(1)
        val_acc = val_inf.accuracy()
        is_new_best = self.max_val_acc is None or val_acc > self.max_val_acc
        if is_new_best:
            self.max_val_acc = val_acc
            bar.set_description(f"{self.progress_desc_prefix}acc {val_acc:1.2f}")

    def maybe_stop_training(self, epoch: int, lr: Any, train_inf: EpochInference, val_inf: EpochInference):
        """Return ``True`` when training accuracy has saturated, otherwise ``None``.

        Args:
            epoch: Epoch index, used only in the saturation message.
            lr: Current learning rate, used only in the saturation message.
            train_inf: Inference results on the training data for this epoch.
            val_inf: Inference results on the validation data for this epoch.
        """
        if self.progress_bar is not None:
            self._report_progress(val_inf)
        train_acc = train_inf.accuracy()
        saturated = train_acc >= self.saturation_threshold
        if not saturated:
            return None
        print(f"training accuracy saturated ({train_acc*100:.1f}%) at epoch", epoch, "with learning rate", lr)
        return True
    

def select_model():
    """Train every architecture/config combination and tabulate the results.

    Each selectable produced by ``iterate_selectables`` is trained for up to
    40 epochs with a progress bar. Its checkpoint is copied to
    ``model-selection/<tag>/`` in the save directory when one is available,
    and its learning curves are plotted. A table of the best validation
    accuracy per run is printed at the end.
    """
    tag = dlmp1.utils.timestamp()
    # NOTE(review): each list is presumably the number of blocks per stage -
    # confirm against dlmp1.select.iterate_model_factories.
    factories = iterate_model_factories([
            [2, 1, 1, 1],
            [2, 2, 2],
            [2, 4, 2],
            [2, 5, 2],
            [2, 5, 3],
            [3, 5, 3],
            # [2, 3, 2],
            # [2, 4, 3],
            # [3, 4, 3],
    ])
    epochs = 40
    # Same learning rate and seed for both runs so only the optimizer differs.
    configs = [
        TrainConfig(epoch_count=epochs, checkpoint_file="auto", learning_rate=0.01, seed=45678, optimizer_type="sgd", quiet=True),
        TrainConfig(epoch_count=epochs, checkpoint_file="auto", learning_rate=0.01, seed=45678, optimizer_type="adam", quiet=True),
    ]
    selectables = list(iterate_selectables(
        factories,
        configs,
    ))
    selectable_best_val_accs = []
    for selectable_index, selectable in enumerate(selectables):
        prefix = f"model {selectable_index+1}/{len(selectables)} "
        title = selectable.description or getattr(selectable.model_factory, "description", "")
        print(prefix, title)
        progress_bar = tqdm(total=selectable.train_config.epoch_count, desc=prefix, file=sys.stdout, position=0, leave=True)
        training_manager = TrainingManager(progress_bar, progress_desc_prefix=prefix)
        try:
            train_result = dlmp1.train.perform(
                model_provider=selectable.model_factory,
                dataset=DATASET,
                config=selectable.train_config,
                callback=training_manager.maybe_stop_training
            )
        finally:
            # Close the bar even when training raises so it is not left open.
            progress_bar.close()
        if train_result.early_stop_reason:
            print("training terminated early:", train_result.early_stop_reason)
        print("training duration:", train_result.duration_readable())
        if LOCAL_GDRIVE_SAVE_PATH:
            # Best-effort copy: a failed upload must not abort the sweep.
            uploaded_file = upload_file(train_result.checkpoint_file,
                        f"model-selection/{tag}/{selectable_index+1}-{train_result.checkpoint_file.name}",
                        suppress_error=True)
            print("uploaded", uploaded_file)
        else:
            # Outside Colab, or with saving disabled, there is nowhere to upload to.
            print("(checkpoint not uploaded because no save directory is available)")
        best_val_acc = max(train_result.val_history.accs, default=0)
        selectable_best_val_accs.append((best_val_acc, train_result.checkpoint_file, title))
        plot_epochs_curves(
            train_result.train_history,
            train_result.val_history,
            title=title or f"S{selectable_index+1} {train_result.checkpoint_file.stem}",
        )
    print()
    print(tabulate.tabulate(selectable_best_val_accs, headers=["validation acc", "checkpoint", "title"]))

# The sweep trains one model per selectable, so it only runs when explicitly enabled.
if DO_SELECT_MODEL:
    select_model()
    

In [0]:
import json

# Set to True to train the model built by create_model() with CONFIG.
DO_TRAIN = False
CONFIG = TrainConfig(
    epoch_count=100,
    learning_rate=0.1,
    lr_scheduler_spec="step:gamma=0.9;step_size=25",
    seed=987654321,
)

# Echo the configuration so the cell output records what was used.
print(json.dumps(CONFIG._asdict(), indent=2))

# Stays None when training is disabled; the cells below check for that.
TRAIN_RESULT = None
if DO_TRAIN:
    TRAIN_RESULT = dlmp1.train.perform(
        model_provider=create_model,
        dataset=DATASET,
        config=CONFIG,
    )

if TRAIN_RESULT is not None:
    # The run's timestamp is added to the file name when the checkpoint is copied.
    CHECKPOINT_DST_PATH = upload_checkpoint(TRAIN_RESULT.checkpoint_file, TRAIN_RESULT.timestamp)
    if CHECKPOINT_DST_PATH:
        print("copied checkpoint file to", CHECKPOINT_DST_PATH)

In [None]:

# Nothing to plot unless a training result exists.
if TRAIN_RESULT is not None:
    plot_epochs_curves(TRAIN_RESULT.train_history, TRAIN_RESULT.val_history)


In [None]:
import dlmp1.train

def evaluate_test_set():
    """Print the trained model's accuracy on the test set.

    Does nothing when ``TRAIN_RESULT`` is ``None``. Otherwise the model and
    device stored on ``TRAIN_RESULT`` are run over the test loader and the
    accuracy is printed with the correct/total counts.
    """
    # Test for None explicitly, matching the other cells that guard on TRAIN_RESULT.
    if TRAIN_RESULT is None:
        return
    testset_loader = Partitioning.prepare_test_loader(batch_size=100)
    inference = dlmp1.train.inference_all(
        TRAIN_RESULT.model,
        TRAIN_RESULT.device,
        testset_loader,
        show_progress=True,
    )
    # Blank line separates the report from the progress output above it.
    print()
    print(f"{inference.accuracy() * 100:.2f}% is accuracy on test set ({inference.correct}/{inference.total})")

# Returns immediately when there is no training result to evaluate.
evaluate_test_set()