In [None]:
# `display` is taken from the public `IPython.display` module: importing it
# from `IPython.core.display` has been deprecated since IPython 7.14.
from IPython.display import display, HTML

# Inject a CSS rule that forces elements with the `container` class to 95% width.
display(HTML("<style>.container { width:95% !important; }</style>"))

In [None]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim

from datasets import load_data
from models import WebObjExtractionNet
from train import train_model, evaluate_model
from utils import print_and_log

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Seed NumPy's and PyTorch's global random number generators with one shared value.
seed = 1
np.random.seed(seed=seed)
torch.manual_seed(seed)
# Left disabled by the original author (presumably for stricter cuDNN
# reproducibility -- confirm against the PyTorch reproducibility notes):
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

In [None]:
# Parameters
N_CLASSES = 4
CLASS_NAMES = ['BG', 'Price', 'Image', 'Title']
IMG_HEIGHT = 1280 # Image assumed to have same height and width
EVAL_INTERVAL = 5 # Number of Epochs after which model is evaluated
DATA_DIR = '../data/' # Contains .png and .pkl files for train and test data
OUTPUT_DIR = 'results' # logs are saved here!
MODEL_SAVE_DIR = 'saved_models' # trained models are saved here!
TRAIN_SPLIT_ID_FILE = 'train_imgs.txt' # each line should contain name of the training image (without file extension)
TEST_SPLIT_ID_FILE = 'test_imgs.txt'

# Create the output directories if they are missing. `exist_ok=True` makes the
# call a no-op when the directory is already there and avoids the
# check-then-create race of a separate `os.path.exists` test.
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# Image ids are parsed as 32-bit integers, one per line of the split file.
# `ndmin=1` keeps the result one-dimensional even when a split file lists a
# single id (np.loadtxt would otherwise return a 0-d array that cannot be
# iterated or measured with len()).
train_img_ids = np.loadtxt(TRAIN_SPLIT_ID_FILE, dtype=np.int32, ndmin=1)
test_img_ids = np.loadtxt(TEST_SPLIT_ID_FILE, dtype=np.int32, ndmin=1)

In [None]:
# Hyperparameters
N_EPOCHS = 20
LEARNING_RATE = 1e-3
BACKBONE = 'alexnet'
BATCH_SIZE = 50
ROI_POOL_OUTPUT_SIZE = (3, 3)
TRAINABLE_CONVNET = True
WEIGHTED_LOSS = False

# Per-class loss weights: uniform by default, replaced by the hand-set values
# below when WEIGHTED_LOSS is switched on.
weights = torch.ones(N_CLASSES)
if WEIGHTED_LOSS:
    # weight inversely proportional to number of examples for the class
    weights = torch.Tensor([1, 100, 100, 100])
    print('Weighted loss with class weights:', weights)

# Build the train/test data loaders from the image ids of each split.
train_loader, test_loader = load_data(DATA_DIR, train_img_ids, test_img_ids, BATCH_SIZE)

In [None]:
# NOTE: the file names below are derived only from the hyperparameters listed
# in `params`; re-running the same configuration overwrites the previous log
# (the first print_and_log call opens the file with mode 'w').
params = '%(backbone)s batch-%(batch)d roi-%(roi)d lr-%(lr).0e wt_loss-%(wt)d' % {
    'backbone': BACKBONE,
    'batch': BATCH_SIZE,
    'roi': ROI_POOL_OUTPUT_SIZE[0],
    'lr': LEARNING_RATE,
    'wt': WEIGHTED_LOSS,
}
log_file = '{}/logs {}.txt'.format(OUTPUT_DIR, params)
model_save_file = '{}/saved_model {}.pth'.format(MODEL_SAVE_DIR, params)

print('logs will be saved in "{}"'.format(log_file))

# Write the run header: batch size first (starting a fresh log), then the rest.
print_and_log('Batch Size: %d' % BATCH_SIZE, log_file, 'w')
print_and_log('RoI Pool Output Size: (%d, %d)' % ROI_POOL_OUTPUT_SIZE, log_file)
print_and_log('Learning Rate: %.0e\n' % LEARNING_RATE, log_file)

In [None]:
# Build the network from the configuration above and move it to `device`.
model = WebObjExtractionNet(
    ROI_POOL_OUTPUT_SIZE,
    IMG_HEIGHT,
    N_CLASSES,
    BACKBONE,
    TRAINABLE_CONVNET,
    CLASS_NAMES,
)
model = model.to(device)

# Adam over all model parameters; summed (not averaged) cross-entropy that
# uses the per-class `weights`, moved to the same device as the model.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(weight=weights, reduction='sum')
criterion = criterion.to(device)

# Train for N_EPOCHS; the test loader, EVAL_INTERVAL and log file are handed
# to train_model as well, and its return value replaces `model`.
model = train_model(
    model,
    train_loader,
    optimizer,
    criterion,
    N_EPOCHS,
    device,
    test_loader,
    EVAL_INTERVAL,
    log_file,
)

In [None]:
# Persist only the model's state dict (not the whole module object), then
# record the destination path in the run log.
torch.save(obj=model.state_dict(), f=model_save_file)
print_and_log('Model can be restored from "{}"'.format(model_save_file), log_file)