In [None]:
import sys
import os
import time
import yaml
import copy
from pathlib import Path
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import models
import ops.trains as trains
import ops.tests as tests
import ops.datasets as datasets
import ops.schedulers as schedulers

In [None]:
# Load the experiment configuration.
# NOTE(review): the path template previously kept an unformatted "%s"
# placeholder, so `open` looked for a directory literally named "%s". The
# placeholder is now filled with the current directory -- confirm that this is
# the intended project root.
config_root = "."
config_path = "%s/configs/imagenet_vit.yaml" % config_root
with open(config_path) as f:
    # `safe_load` only builds plain Python containers and scalars. Calling
    # `yaml.load` without an explicit `Loader` is rejected by recent PyYAML
    # releases, and is unsafe on older ones.
    args = yaml.safe_load(f)

# Take an independent copy of each config section so that later in-place
# changes to a section cannot leak into `args`, which is dumped to the log
# directory further down. Only the requested section is copied (not the whole
# config); a missing section still yields None.
dataset_args = copy.deepcopy(args.get("dataset"))
train_args = copy.deepcopy(args.get("train"))
val_args = copy.deepcopy(args.get("val"))
model_args = copy.deepcopy(args.get("model"))
optim_args = copy.deepcopy(args.get("optim"))
env_args = copy.deepcopy(args.get("env"))


# Build the raw train/test datasets from the `dataset` config section.
train_set, test_set = datasets.get_dataset(**dataset_args, download=True)
dataset_name = dataset_args["name"]
num_classes = len(train_set.classes)

# Loader settings come from the `train` / `val` config sections, falling back
# to 4 workers and a batch size of 128 when a key is absent.
train_loader_options = {
    "shuffle": True,  # only the training loader shuffles
    "num_workers": train_args.get("num_workers", 4),
    "batch_size": train_args.get("batch_size", 128),
}
test_loader_options = {
    "num_workers": val_args.get("num_workers", 4),
    "batch_size": val_args.get("batch_size", 128),
}

# From here on `dataset_train` / `dataset_test` refer to the DataLoaders.
dataset_train = DataLoader(train_set, **train_loader_options)
dataset_test = DataLoader(test_set, **test_loader_options)

# Report the number of samples (not batches) in each split.
split_summary = (len(dataset_train.dataset), len(dataset_test.dataset), num_classes)
print("Train: %s, Test: %s, Classes: %s" % split_summary)


# Identifier of the architecture requested from the project's model factory.
name = "daa_50"

# Build the network for `num_classes` outputs and wrap it in `nn.DataParallel`
# in one step. After wrapping, the underlying network is reachable as
# `model.module`.
model = nn.DataParallel(models.get_model(name, num_classes=num_classes))



# Create a timestamped TensorBoard run directory and record the run's inputs.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# `model` may be a parallel wrapper at this point. Such wrappers do not forward
# custom attributes of the wrapped network, so `model.name` would raise
# AttributeError; read the name from the wrapped module instead.
_parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
_base_model = model.module if isinstance(model, _parallel_wrappers) else model
log_dir = os.path.join("runs", dataset_name, _base_model.name, current_time)
writer = SummaryWriter(log_dir)

# Keep a copy of the full configuration and the model structure next to the
# TensorBoard event files so the run can be reproduced and inspected later.
with open(os.path.join(log_dir, "config.yaml"), "w") as f:
    yaml.dump(args, f)
with open(os.path.join(log_dir, "model.log"), "w") as f:
    f.write(repr(model))

print("Create TensorBoard log dir: ", log_dir)



# Whether a CUDA device is available; passed through to the training routine.
gpu = torch.cuda.is_available()

# Optimizer and main scheduler are built from the `optim` config section.
optimizer, train_scheduler = trains.get_optimizer(model, **optim_args)

# The warmup length is counted in iterations: batches per epoch times the
# configured number of warmup epochs (0 when the key is absent).
warmup_epochs = train_args.get("warmup_epochs", 0)
warmup_steps = len(dataset_train) * warmup_epochs
warmup_scheduler = schedulers.WarmupScheduler(optimizer, warmup_steps)

# Set `snapshot=N` to save snapshots every N epochs.
run_options = {
    "snapshot": -1,
    "dataset_name": dataset_name,
    "uid": current_time,
}
trains.train(
    model, optimizer,
    dataset_train, dataset_test,
    train_scheduler, warmup_scheduler,
    train_args, val_args, gpu,
    writer,
    **run_options,
)