# ResNet Example

The model definition and training code are adapted from https://github.com/kuangliu/pytorch-cifar.

In [None]:
!(pip show torch-summary >& /dev/null || pip install --quiet torch-summary)

In [None]:
# Clone the project repository into ./mp1, unless the notebook is already running
# inside a git checkout (.git present) or ./mp1 already exists.
!(test -d .git || test -d mp1 || git clone https://github.com/mike10004/csgy6953-mp1.git mp1)

In [None]:
# Change `main` below to check out a different branch of the cloned repo.
# Both commands are skipped when ./mp1 does not exist (e.g. when running from
# inside a checkout rather than from a fresh clone).
!(test -d mp1 && cd mp1 && git switch main)
# Fast-forward the branch, then print the short hash of the commit now checked out.
!(test -d mp1 && cd mp1 && git pull && git rev-parse --short HEAD)

In [None]:
# Install the cloned project in editable mode (skipped when ./mp1 is absent).
!(test -d mp1 && pip install --quiet --editable mp1)
# Re-run site initialization so the running interpreter adds the standard
# site-specific directories (and their .pth entries) to sys.path again;
# presumably this is what makes the fresh editable install importable
# without a kernel restart.
import site
site.main()

In [None]:
import importlib
from pathlib import Path

import dlmp1

# Reload the top-level package so a re-run of this cell re-executes its
# __init__; note importlib.reload does not reload already-imported submodules.
dlmp1 = importlib.reload(dlmp1)
print("checked importable:", dlmp1, "at", Path(dlmp1.__file__).parent)

In [None]:
import os
import shutil
from typing import Optional

# Drive-relative directory that checkpoints are copied into; set to "" to disable saving.
# Note that the first path component, MyDrive, is required.
GDRIVE_SAVE_DIR = "MyDrive/CS-GY 6953 DL/deep learning midterm project/checkpoints"

def prepare_mount() -> Optional[str]:
    """Mount Google Drive and return the local path of the checkpoint directory.

    Returns:
        ``/content/gdrive/<GDRIVE_SAVE_DIR>`` after mounting Drive at
        ``/content/gdrive``, or ``None`` when saving is disabled (empty
        ``GDRIVE_SAVE_DIR``) or ``google.colab`` cannot be imported.
    """
    if not GDRIVE_SAVE_DIR:
        return None
    # Guard only the import: an ImportError here means "not running in Colab",
    # whereas a failure inside drive.mount itself should surface, not be
    # misreported as a missing Colab environment.
    try:
        from google.colab import drive
    except ImportError:
        print("(not saving because not in colab environment)")
        return None
    save_path_root = "/content/gdrive"
    drive.mount(save_path_root)
    return os.path.join(save_path_root, GDRIVE_SAVE_DIR)

# Local path of the Drive checkpoint directory, or None when not saving.
LOCAL_GDRIVE_SAVE_PATH = prepare_mount()

def upload_checkpoint(checkpoint_file: Path, infix: str) -> Optional[str]:
    """Copy a checkpoint into the Drive save directory under a tagged name.

    The copy is named ``<stem>-<infix><suffix>``; for example ``ckpt.pt`` with
    infix ``"20240101"`` becomes ``ckpt-20240101.pt``.

    Args:
        checkpoint_file: local checkpoint file to copy (a ``str`` path is
            also accepted).
        infix: tag inserted between the file's stem and its suffix.

    Returns:
        The destination path as a string, or ``None`` when
        ``LOCAL_GDRIVE_SAVE_PATH`` is unset (saving disabled or not in Colab).
    """
    if not LOCAL_GDRIVE_SAVE_PATH:
        return None
    checkpoint_file = Path(checkpoint_file)
    # The checkpoint folder may not exist yet on a fresh Drive; copyfile does
    # not create parent directories and would raise FileNotFoundError.
    os.makedirs(LOCAL_GDRIVE_SAVE_PATH, exist_ok=True)
    filename = f"{checkpoint_file.stem}-{infix}{checkpoint_file.suffix}"
    dst_file = os.path.join(LOCAL_GDRIVE_SAVE_PATH, filename)
    shutil.copyfile(checkpoint_file, dst_file)
    return dst_file


In [None]:
from dlmp1.models.resnet import CustomResNet
from dlmp1.models.resnet import BlockSpec
# noinspection PyPackageRequirements
import torchsummary

def create_model():
    """Build a fresh three-stage CustomResNet.

    Each stage is a ``BlockSpec`` built from two positional arguments plus a
    ``stride`` keyword, taken from the tuples below in order.
    """
    # (BlockSpec positional arg 1, positional arg 2, stride). Presumably the
    # first two are (residual block count, channel width); confirm in BlockSpec.
    stage_args = (
        (2, 64, 1),
        (5, 128, 2),
        (3, 256, 2),
    )
    stages = [BlockSpec(count, width, stride=step) for count, width, step in stage_args]
    return CustomResNet(stages)

def summarize_model():
    """Instantiate a throwaway model and print its trainable-parameter count."""
    candidate = create_model()
    summary = torchsummary.summary(candidate, verbose=0)
    trainable = summary.trainable_params
    class_name = type(candidate).__name__
    print(class_name, f"{trainable/1_000_000:.1f}m trainable parameters ({trainable})")
    # Drop the reference so the model can be garbage-collected.
    del candidate

summarize_model()

In [None]:
from dlmp1.train import Partitioning
# Batch size handed to Partitioning.prepare (presumably for the training loader).
BATCH_SIZE_TRAIN = 128
# Build the dataset partitions once for the session; the fixed random_seed
# presumably makes the split reproducible — confirm in Partitioning.prepare.
DATASET = Partitioning.prepare(BATCH_SIZE_TRAIN, random_seed=12345)


In [None]:
import dlmp1.train
from dlmp1.train import TrainConfig

# Toggle for a model-selection step; select_model itself is not written yet.
DO_SELECT_MODEL = False

def select_model():
    """Placeholder for a model-selection step that has not been implemented.

    Raises:
        NotImplementedError: always.
    """
    # NotImplemented is a comparison sentinel, not an exception class; calling it
    # raised TypeError. NotImplementedError is the exception meant here.
    raise NotImplementedError("select_model is not implemented yet")

In [None]:
import json

# Set True to train in this session; otherwise TRAIN_RESULT stays None.
DO_TRAIN = False
CONFIG = TrainConfig(
    epoch_count=60,
    learning_rate=0.1,
    # Optional step learning-rate schedule (disabled); uncomment to enable.
    # lr_scheduler_spec="step:gamma=0.1;step_size=40",
    seed=987654321,
)

# Echo the effective configuration (TrainConfig exposes _asdict, presumably a NamedTuple).
print(json.dumps(CONFIG._asdict(), indent=2))

TRAIN_RESULT = None
if DO_TRAIN:
    TRAIN_RESULT = dlmp1.train.perform(
        model_provider=create_model,
        dataset=DATASET,
        config=CONFIG,
    )

if TRAIN_RESULT is not None:
    # Copy the checkpoint to Drive, tagging the filename with the run's timestamp;
    # upload_checkpoint returns None when saving is disabled.
    CHECKPOINT_DST_PATH = upload_checkpoint(TRAIN_RESULT.checkpoint_file, TRAIN_RESULT.timestamp)
    if CHECKPOINT_DST_PATH:
        print("copied checkpoint file to", CHECKPOINT_DST_PATH)

In [ ]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from dlmp1.train import History


def plot_epochs_curves(train_hist: History, val_hist: History):
    """Show side-by-side loss and accuracy curves for training vs. validation.

    Args:
        train_hist: training history; its ``losses`` and ``accs`` sequences are
            each plotted against their own index (0, 1, 2, ...).
        val_hist: validation history with the same attributes.

    Accuracy values are multiplied by 100 and drawn on a fixed 0-100 y-axis, so
    they are presumably fractions in [0, 1] — confirm against History.
    """
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    panels = (
        # (train series, validation series, scale, y-limits, title, y-axis label)
        (train_hist.losses, val_hist.losses, 1.0, None, "Loss", "Cross-Entropy Loss"),
        (train_hist.accs, val_hist.accs, 100.0, (0.0, 100.0), "Accuracy", "Correct (%)"),
    )
    for ax, (train_series, val_series, factor, y_bounds, subject, y_label) in zip(axes, panels):
        ax: Axes
        ax.set_title(subject)
        ax.set_xlabel("Epochs")
        ax.set_ylabel(y_label)
        train_values = np.array(train_series)
        val_values = np.array(val_series)
        # Give each series its own x-range. A single shared range of length
        # max(len(train), len(val)) makes ax.plot raise ValueError whenever the
        # two histories differ in length.
        ax.plot(np.arange(len(train_values)), train_values * factor, label=f"Train {subject}")
        ax.plot(np.arange(len(val_values)), val_values * factor, label=f"Validation {subject}")
        ax.legend()
        if y_bounds is not None:
            ax.set_ylim(*y_bounds)
    # Keep the two panels' axis labels from overlapping.
    fig.tight_layout()
    plt.show()

if TRAIN_RESULT is not None:
    plot_epochs_curves(TRAIN_RESULT.train_history, TRAIN_RESULT.val_history)


In [ ]:
import dlmp1.train

def evaluate_test_set():
    """Print the trained model's accuracy on the held-out test set.

    Does nothing when the module-level TRAIN_RESULT is falsy (e.g. None
    because training was skipped in this session).
    """
    if TRAIN_RESULT:
        loader = Partitioning.prepare_test_loader(batch_size=100)
        result = dlmp1.train.inference_all(
            TRAIN_RESULT.model,
            TRAIN_RESULT.device,
            loader,
            show_progress=True,
        )
        # Blank line, presumably to move past the progress output.
        print()
        pct = result.accuracy() * 100
        print(f"{pct:.2f}% is accuracy on test set ({result.correct}/{result.total})")

evaluate_test_set()