In [9]:
import os
import argparse

import torch

import mixed_precision
from stats import StatTracker
from datasets import Dataset, build_dataset, get_dataset, get_encoder_size
from model import Model
from checkpoint import Checkpointer
from task_self_supervised import train_self_supervised
from task_classifiers import train_classifiers
from task_unsupervised import train_unsupervised

In [None]:
# Set up datasets, logging and checkpointing, then train the AMDIM encoder on
# YFCC100M. The STL10 objects built here are not used in this cell; they are
# referenced by the (commented-out) classifier-training cell further down.

# create target output dir if it doesn't exist yet; makedirs(exist_ok=True)
# avoids the check-then-create race of isdir()+mkdir() and also creates any
# missing parent directories
training_checkpoint_dir = './checkpoint'
os.makedirs(training_checkpoint_dir, exist_ok=True)

# enable mixed-precision computation
mixed_precision.enable_mixed_precision()

# set the RNG seeds (probably more hidden elsewhere...)
SEED = 1
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# get encoder training dataset
yfcc_dataset = get_dataset('YFCC100M')
yfcc_encoder_size = get_encoder_size(yfcc_dataset)

# get classifier training dataset
# NOTE(review): stl10_encoder_size is not read anywhere in the visible cells
stl10_dataset = get_dataset('STL10')
stl10_encoder_size = get_encoder_size(stl10_dataset)

# get a helper object for tensorboard logging
stat_tracker = StatTracker(log_dir='./logs')

# get dataloaders for training AMDIM's encoder module
yfcc_train_loader, yfcc_test_loader, yfcc_num_classes = build_dataset(
    dataset=yfcc_dataset,
    batch_size=4,
    input_dir='./yfcc100m',
    labeled_only=False,
)

# get dataloaders for testing the classifier
stl10_train_loader, stl10_test_loader, stl10_num_classes = build_dataset(
    dataset=stl10_dataset,
    batch_size=200,
    labeled_only=True,
)

# the device is hard-coded to CUDA, so this cell requires a GPU
torch_device = torch.device('cuda')
checkpointer = Checkpointer(training_checkpoint_dir)

# create new unsupervised model with random parameters
encoder_model = Model(
    ndf=128,
    n_classes=yfcc_num_classes,
    n_rkhs=1024,
    tclip=20.0,
    n_depth=3,
    encoder_size=yfcc_encoder_size,
    use_bn=False,
)

# initialise the weights and register the model with the checkpointer before
# it is moved to the target device (same order as before)
encoder_model.init_weights(init_scale=1.0)
checkpointer.track_new_model(encoder_model)

# train encoder model
encoder_model = encoder_model.to(torch_device)

train_encoder_task = train_unsupervised
# arguments are passed positionally; 0.0002 is presumably the learning
# rate -- TODO confirm against train_unsupervised's signature
train_encoder_task(
    encoder_model,
    0.0002,
    yfcc_dataset,
    yfcc_train_loader,
    yfcc_test_loader,
    stat_tracker,
    checkpointer,
    training_checkpoint_dir,
    torch_device,
)

log_dir: ./logs
Files already downloaded and verified


In [4]:
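# train classifier model
# NOTE: this cell is intentionally disabled; uncomment the lines below to run
# classifier training on STL10 from the encoder checkpoint.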
# encoder_model_path = os.path.join(training_checkpoint_dir, 'amdim_cpt.pth')
# classifier_model = checkpointer.restore_model_from_checkpoint(encoder_model_path, training_classifier=True)
# classifier_model = classifier_model.to(torch_device)

# train_classifier_task = train_classifiers
# train_classifier_task(classifier_model, 0.0002, stl10_dataset, stl10_train_loader,
#     stl10_test_loader, stat_tracker, checkpointer, training_checkpoint_dir, torch_device)