# ResNet Example

In [None]:
!(pip show torch-summary || pip install torch-summary)

In [None]:
!(test -d .git || test -d mp1 || git clone https://github.com/mike10004/csgy6953-mp1.git mp1)

In [None]:
# Switch the cloned repo to the branch this notebook should run (edit the branch name here if needed)
!(cd mp1 && git switch flexy-resnet)

In [None]:
!(test -d mp1 && cd mp1 && git pull)
!(test -d mp1 && cp -r mp1/dlmp1 ./)

In [ ]:

# Folder (relative to the Google Drive mount point used by get_save_path) that
# receives copies of training checkpoints. Set to an empty string to disable
# saving entirely.
# NOTE(review): Colab normally exposes the user's "My Drive" as a `MyDrive/`
# subfolder of the mount point, so this path may need a "MyDrive/" prefix —
# confirm on Colab before relying on checkpoint saving.
GDRIVE_SAVE_DIR = "CS-GY 6953 DL/deep learning midterm project/checkpoints"


In [None]:
import dlmp1
from pathlib import Path
# Show which copy of the package was imported: dlmp1 is a package, so its
# __file__ is dlmp1/__init__.py and two .parent hops give the directory that
# contains the dlmp1/ folder.
print(
    "imported",
    dlmp1,
    "from",
    Path(dlmp1.__file__).parent.parent,
)

In [None]:
from torch import nn
from dlmp1.models.resnet import CustomResNet
from dlmp1.models.resnet import BlockSpec
import torchsummary

# Network under study: a CustomResNet built from three BlockSpec stages,
# listed in the order they are passed to the constructor.
# NOTE(review): assumes BlockSpec(<block count>, <channel width>, stride=...);
# confirm against dlmp1.models.resnet.BlockSpec.
MODEL = CustomResNet([
    BlockSpec(blocks, channels, stride=stride)
    for blocks, channels, stride in (
        (2, 64, 1),
        (5, 128, 2),
        (3, 256, 2),
    )
])

def summarize_model(model: nn.Module):
    """Print the model's class name followed by its trainable-parameter count.

    The count comes from ``torchsummary.summary`` called with ``verbose=0``
    (quiet mode, per the torch-summary docs). It is printed in millions with
    one decimal place, followed by the exact integer in parentheses.
    """
    trainable = torchsummary.summary(model, verbose=0).trainable_params
    label = type(model).__name__
    print(label, f"{trainable / 1_000_000:.1f}m trainable parameters ({trainable})")


summarize_model(MODEL)

In [ ]:
from dlmp1.main import Dataset
# Mini-batch size handed to Dataset.acquire (presumably the training loader's
# batch size — confirm in dlmp1.main.Dataset).
BATCH_SIZE_TRAIN = 128
# Dataset object passed to dlmp1.main.perform in the training cell.
DATASET = Dataset.acquire(BATCH_SIZE_TRAIN)


In [None]:
import dlmp1.main
from dlmp1.main import TrainConfig

# Flip to True to run training; while False, TRAIN_RESULT stays None and the
# checkpoint-copy cell (guarded by `TRAIN_RESULT is not None`) is skipped.
DO_TRAIN = False
CONFIG = TrainConfig(epoch_count=50, learning_rate=0.1, verbose_scheduler=True)

# Whatever dlmp1.main.perform returns for this model/dataset/config, or None
# when training is disabled.
TRAIN_RESULT = (
    dlmp1.main.perform(model=MODEL, dataset=DATASET, config=CONFIG)
    if DO_TRAIN
    else None
)


In [ ]:
import os
import sys
import shutil
import datetime
from typing import Optional


def get_save_path(
    checkpoint_file: Path,
    salt: Optional[str] = None,
    save_dir: Optional[str] = None,
) -> Optional[str]:
    """Return the Google Drive destination path for a checkpoint, or None.

    When running in Colab, this mounts Google Drive at ``/content/gdrive``. It
    then builds ``<mount>/<save_dir>/<stem><salt><suffix>`` from the
    checkpoint's file name.

    Args:
        checkpoint_file: Local checkpoint path. Only its stem and suffix are
            used. A plain ``str`` path is accepted as well.
        salt: Text inserted between the stem and the suffix. Defaults to
            ``"-YYYYmmdd-HHMMSS"`` built from the current local time.
        save_dir: Directory relative to the Drive mount point. Defaults to the
            notebook-level ``GDRIVE_SAVE_DIR``. An empty value disables saving.

    Returns:
        The destination path as a string. Returns None when saving is disabled
        or when ``google.colab`` cannot be imported (i.e. not running in Colab).
    """
    if save_dir is None:
        save_dir = GDRIVE_SAVE_DIR
    if not save_dir:
        return None
    # Only the import distinguishes "not in Colab"; keep the try that narrow so
    # real failures from drive.mount are not misreported.
    try:
        from google.colab import drive
    except ImportError:
        print("not saving because not in colab environment", file=sys.stderr)
        return None
    save_path_root = "/content/gdrive"
    drive.mount(save_path_root)
    checkpoint_file = Path(checkpoint_file)
    if salt is None:
        # Timestamp salt keeps successive saves of the same checkpoint distinct.
        salt = "-" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    file_name = checkpoint_file.stem + salt + checkpoint_file.suffix
    return os.path.join(save_path_root, save_dir, file_name)
        

if TRAIN_RESULT is not None:
    # Copy the training checkpoint to the Drive destination, if get_save_path
    # returns one (it returns None when saving is disabled or outside Colab).
    CHECKPOINT_DST_PATH = get_save_path(TRAIN_RESULT.checkpoint_file)
    if CHECKPOINT_DST_PATH:
        # shutil.copyfile does not create missing directories, so make sure the
        # destination folder exists (e.g. on the first save to a new folder).
        os.makedirs(os.path.dirname(CHECKPOINT_DST_PATH), exist_ok=True)
        shutil.copyfile(TRAIN_RESULT.checkpoint_file, CHECKPOINT_DST_PATH)
        print("copied checkpoint file to", CHECKPOINT_DST_PATH)
