## Setup

In [None]:
# Mount Google Drive so that anything written under /content/drive
# (e.g. DRIVE_ROOT in the next cell) is stored on the user's Drive.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import os

# --- Local Colab paths -------------------------------------------------
REPO_ROOT = "/content/ml-practices"          # git checkout of the course repo
LAB2_ROOT = os.path.join(REPO_ROOT, "lab2")  # added to sys.path further down
DATASET_ROOT = "/content/dataset"            # unzip target for the dataset archive

TRAIN_DATASET, TEST_DATASET = (
    os.path.join(DATASET_ROOT, "train_dataset"),
    os.path.join(DATASET_ROOT, "test_dataset"),
)
TRAIN_CSV, TEST_CSV = (
    os.path.join(DATASET_ROOT, "train.csv"),
    os.path.join(DATASET_ROOT, "test.csv"),
)

# --- Output folder on the mounted Google Drive (used as log_dir) -------
DRIVE_ROOT = "/content/drive/MyDrive/colab/hw2/"
os.makedirs(DRIVE_ROOT, exist_ok=True)

In [None]:
if not os.path.isdir(REPO_ROOT):
    !git clone https://github.com/bhbbbbb/ml-practices.git
    !cd {REPO_ROOT} && git checkout colab
    !pip install git+https://github.com/benjs/nfnets_pytorch
else:
    !cd {REPO_ROOT} && git pull 

### Download dataset

In [None]:
if not os.path.isdir(TRAIN_DATASET):
    !pip install --upgrade --no-cache-dir gdown
    !gdown --id 1hj2zrZI3Nd-C6nlGOE1crgR_gnpoKHQh --output 'dataset.zip'
    !unzip -q dataset.zip -d '/content/dataset' # the -d should be the same as DataPath

else:
    print("File already exists.")

In [None]:
import sys
# Make the packages in the cloned repo's lab2 directory (imgclf, hw2) importable.
# This must run before the project imports below.
sys.path.append(LAB2_ROOT)
from imgclf.model_utils import ModelUtils
from imgclf.dataset import Dataset
from imgclf.models import FatLeNet5, FakeVGG16
from hw2.utils import DatasetUtils
from hw2.config import Hw2Config
import torch.cuda
# Uncomment to fail fast when the runtime has no CUDA device available.
# assert torch.cuda.is_available()

## Training

In [None]:
def train(config, model, epochs, split_ratio=None):
    """Train ``model`` from scratch on the training CSV and plot its history.

    Args:
        config: Project config object (e.g. ``Hw2Config``), forwarded to
            ``Dataset.split`` and ``ModelUtils.start_new_training``.
        model: Model instance to train (e.g. ``FakeVGG16(config)``).
        epochs: Number of epochs, passed as the first argument of ``utils.train``.
        split_ratio: Ratios forwarded to ``Dataset.split``. Defaults to
            ``[0.7, 0.15]`` (presumably train/valid fractions with the
            remainder held out -- TODO confirm against ``Dataset.split``).
    """
    if split_ratio is None:
        split_ratio = [0.7, 0.15]  # fresh list per call; never a shared default
    # The second value returned by load_csv (categories) is not needed here.
    df, _ = DatasetUtils.load_csv(TRAIN_CSV, TRAIN_DATASET)
    datasets = Dataset.split(df, split_ratio=split_ratio, config=config)
    utils = ModelUtils.start_new_training(model=model, config=config)
    utils.train(epochs, *datasets)
    utils.plot_history()

## Resume Training from a Checkpoint

In [None]:
def train_from(config, model, epochs, checkpoint_path, split_ratio=None):
    """Continue training ``model`` from an explicit checkpoint file.

    Args:
        config: Project config object, forwarded to ``Dataset.split`` and
            ``ModelUtils.load_checkpoint``.
        model: Model instance whose weights are restored from the checkpoint.
        epochs: Number of epochs, passed as the first argument of ``utils.train``.
        checkpoint_path: Path handed to ``ModelUtils.load_checkpoint``.
        split_ratio: Ratios forwarded to ``Dataset.split``; defaults to
            ``[0.7, 0.15]``, matching the original hard-coded value.
    """
    if split_ratio is None:
        split_ratio = [0.7, 0.15]  # fresh list per call; never a shared default
    # The second value returned by load_csv (categories) is not needed here.
    df, _ = DatasetUtils.load_csv(TRAIN_CSV, TRAIN_DATASET)
    datasets = Dataset.split(df, split_ratio=split_ratio, config=config)
    utils = ModelUtils.load_checkpoint(model=model, config=config, checkpoint_path=checkpoint_path)
    utils.train(epochs, *datasets)
    utils.plot_history()

def retrain(config, model, epochs, split_ratio=None):
    """Continue training ``model`` from the latest checkpoint for ``config``.

    Args:
        config: Project config object, forwarded to ``Dataset.split`` and
            ``ModelUtils.load_last_checkpoint``.
        model: Model instance whose weights are restored from the checkpoint.
        epochs: Number of epochs, passed as the first argument of ``utils.train``.
        split_ratio: Ratios forwarded to ``Dataset.split``; defaults to
            ``[0.7, 0.15]``, matching the original hard-coded value.
    """
    if split_ratio is None:
        split_ratio = [0.7, 0.15]  # fresh list per call; never a shared default
    # The second value returned by load_csv (categories) is not needed here.
    df, _ = DatasetUtils.load_csv(TRAIN_CSV, TRAIN_DATASET)
    datasets = Dataset.split(df, split_ratio=split_ratio, config=config)
    utils = ModelUtils.load_last_checkpoint(model=model, config=config)
    # Same calling convention as train() and train_from(): epochs first, then
    # the unpacked dataset splits.
    utils.train(epochs, *datasets)
    utils.plot_history()

## Inference

In [None]:
def inference(config, categories: list, model):
    """Run the latest checkpoint of ``model`` over the test images.

    Args:
        config: Project config object used for the dataset and checkpoint lookup.
        categories: Class labels passed through to ``utils.inference``
            (presumably the second value returned by ``DatasetUtils.load_csv``
            -- TODO confirm).
        model: Model instance whose weights are restored from the last checkpoint.

    Returns:
        Whatever ``utils.inference(..., confidence=True)`` returns.
    """
    test_df = DatasetUtils.load_test_csv(csv_path=TEST_CSV, images_root=TEST_DATASET)
    test_set = Dataset(test_df, config=config, mode="inference")
    restored = ModelUtils.load_last_checkpoint(model=model, config=config)
    return restored.inference(test_set, categories, confidence=True)

## Run: Baseline Model (Hw2Config)

In [None]:
# Baseline hyper-parameters; log_dir points at the folder on the mounted Drive.
config = Hw2Config(log_dir=DRIVE_ROOT)
config.display()

In [None]:
# Disabled: uncomment to train the FakeVGG16 baseline from scratch for 100 epochs.
# train(config, FakeVGG16(config), 100)

## Pretrained Model (NFNet-F1)

In [None]:
if not os.path.isfile("/content/F1_haiku.npz")
    !wget https://storage.googleapis.com/dm-nfnets/$F1_haiku.npz

In [None]:
from nfnet.nfnet_model_utils import NfnetModelUtils
from nfnet.config import NfnetConfig

## Run: Fine-tune NFNet-F1

In [None]:
# Hyper-parameters for the NFNet run; log_dir points at the folder on the mounted Drive.
config = NfnetConfig(log_dir=DRIVE_ROOT)
config.learning_rate = 0.1
for _phase in ("train", "eval"):
    config.batch_size[_phase] = 32
config.num_workers = 2
config.num_class = 10
config.display()

In [None]:
# Train the NFNet with the config from the previous cell.
# NOTE(review): the downloaded F1 weights are only referenced by the
# commented-out line below; confirm whether init_model loads pretrained weights.
# nfnet = pretrained_nfnet("/content/F1_haiku.npz")
epochs = 50
# Same split call as the train()/retrain() helpers ([0.7, 0.15]); `cat` is unused here.
df, cat = DatasetUtils.load_csv(TRAIN_CSV, TRAIN_DATASET)
datasets = Dataset.split(df, split_ratio=[0.7, 0.15], config=config)
model = NfnetModelUtils.init_model(config)
# Presumably resumes from the most recent checkpoint for `config` -- TODO confirm
# how load_last_checkpoint behaves when no checkpoint exists yet.
utils = NfnetModelUtils.load_last_checkpoint(model, config)
utils.train(epochs, *datasets)