In [5]:
import os
import argparse

import torch

import mixed_precision
from stats import StatTracker
from datasets import Dataset, build_dataset, get_dataset, get_encoder_size
from model import Model
from checkpoint import Checkpointer
from task_self_supervised import train_self_supervised
from task_classifiers import train_classifiers
from task_unsupervised import train_unsupervised

In [8]:
# ---- run configuration: module-level settings read by main() ----

# reproducibility / numerics
seed = 1        # handed to torch.manual_seed and torch.cuda.manual_seed
is_amp = True   # when True, main() turns on mixed-precision computation

# data source
dataset_name = 'STL10'
input_dir = '/mnt/imagenet'
stl10_batch_size = 200        # forwarded to build_dataset as batch_size
is_train_classifiers = False  # forwarded to build_dataset as labeled_only

# where results go (the log directory is output_dir/run_name)
output_dir = './runs'
run_name = 'default_run'

# model settings, forwarded to Model(...)
input_ndf = 128  # feature width for encoder
input_n_rkhs = 1024
input_tclip = 20.0
input_n_depth = 3
used_batchnorm = False

# optimisation
input_learning_rate = 2e-4

def main():
    """Build the dataset, helpers and a fresh model from the module-level settings, then train.

    Reads the configuration globals defined in the parameters cell
    (``output_dir``, ``is_amp``, ``seed``, ``dataset_name``, ``run_name``,
    ``stl10_batch_size``, ``input_dir``, ``is_train_classifiers``,
    ``input_ndf``, ``input_n_rkhs``, ``input_tclip``, ``input_n_depth``,
    ``used_batchnorm`` and ``input_learning_rate``).

    Side effects: creates ``output_dir`` if it does not exist, seeds the torch
    RNGs, enables mixed precision when ``is_amp`` is set, and moves the model
    to the ``'cuda'`` device before handing it to ``train_classifiers``.

    Returns:
        None.
    """
    # makedirs(exist_ok=True) replaces the isdir()/mkdir() pair: it also
    # creates missing parent directories and does not fail if the directory
    # is created by someone else between the check and the creation.
    os.makedirs(output_dir, exist_ok=True)

    # enable mixed-precision computation if desired
    if is_amp:
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # get the dataset
    dataset = get_dataset(dataset_name)
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(output_dir, run_name)
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = build_dataset(
        dataset=dataset,
        batch_size=stl10_batch_size,
        input_dir=input_dir,
        labeled_only=is_train_classifiers,
    )

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(output_dir)

    # create new model with random parameters
    model = Model(
        ndf=input_ndf,
        n_classes=num_classes,
        n_rkhs=input_n_rkhs,
        tclip=input_tclip,
        n_depth=input_n_depth,
        encoder_size=encoder_size,
        use_bn=used_batchnorm,
    )
    model.init_weights(init_scale=1.0)
    checkpointer.track_new_model(model)

    model = model.to(torch_device)

    # select which type of training to do
    # NOTE(review): the task is fixed to train_classifiers regardless of
    # is_train_classifiers (which only feeds labeled_only above), while
    # train_self_supervised / train_unsupervised are imported but unused
    # here -- confirm this is intended.
    task = train_classifiers
    task(model, input_learning_rate, dataset, train_loader,
         test_loader, stat_tracker, checkpointer, output_dir, torch_device)

In [None]:
# Start training only when this code runs as the top-level module
# (i.e. __name__ is "__main__"), not when it is imported.
if __name__ == "__main__":
    main()

log_dir: ./runs/default_run
Files already downloaded and verified
Files already downloaded and verified
Using a 64x64 encoder
Epoch 0, 100 updates -- 0.2052 sec/update
