In [None]:
# Widen the notebook's `.container` element to 95% of the page width.
# `display` and `HTML` come from the public `IPython.display` module; importing
# `display` through `IPython.core.display` is a deprecated path.
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

In [None]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from datasets import WebDataset
from models import WebObjExtractionNet
from train import train_model, evaluate_model
from utils import custom_collate_fn, pkl_load, print_and_log

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Seed NumPy's global RNG and PyTorch's RNG with one shared value so that
# repeated runs start from the same random state.
# NOTE(review): Python's built-in `random` module is not seeded here — confirm
# that nothing in the data pipeline relies on it.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)
# cuDNN determinism switches, left disabled by the author. Uncomment both to
# request deterministic cuDNN kernels and turn off the cuDNN auto-tuner.
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

In [None]:
# ---- Task parameters ----
# The list index is the class id: 0: BG, 1: Price, 2: Image, 3: Title
CLASS_NAMES = ['BG', 'Price', 'Image', 'Title']
N_CLASSES = len(CLASS_NAMES)  # 4
IMG_HEIGHT = 1280  # images are assumed square, so this is also the width
EVAL_INTERVAL = 5  # evaluate the model every EVAL_INTERVAL epochs
TRAIN_DATA_DIR = '../data/web_data/train'
TEST_DATA_DIR = '../data/web_data/test'
OUTPUT_DIR = 'results'

# ---- Hyperparameters ----
N_EPOCHS = 20
LEARNING_RATE = 1e-3
BACKBONE = 'alexnet'  # other option noted by the author: 'resnet'
BATCH_SIZE = 50
ROI_POOL_OUTPUT_SIZE = (3, 3)
TRAINABLE_CONVNET = True
WEIGHTED_LOSS = False

# Per-class loss weights: either hand-set weights (inversely proportional to the
# number of examples per class) or uniform weights, i.e. an unweighted loss.
WEIGHTS = torch.Tensor([1, 100, 100, 100]) if WEIGHTED_LOSS else torch.ones(N_CLASSES)
if WEIGHTED_LOSS:
    print('Weighted loss with class weights:', WEIGHTS)

# Make sure the results directory exists. `exist_ok=True` replaces the separate
# existence check, so there is no window between checking and creating, and
# re-running this cell is harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# The log file name encodes the hyperparameter configuration.
# NOTE: running the same configuration again overwrites the previous log file.
_log_name_fields = (
    OUTPUT_DIR,
    BACKBONE,
    BATCH_SIZE,
    ROI_POOL_OUTPUT_SIZE[0],
    LEARNING_RATE,
    WEIGHTED_LOSS,
)
LOG_FILE = '%s/logs %s batch-%d roi-%d lr-%.0e wt_loss-%d.txt' % _log_name_fields
print('logs will be saved in \"%s\"' % LOG_FILE)

# Write the run header. Only the first call passes 'w' (presumably the file
# mode that starts a fresh log — confirm against utils.print_and_log).
print_and_log('Batch Size: %d' % BATCH_SIZE, LOG_FILE, 'w')
print_and_log('RoI Pool Output Size: (%d, %d)' % ROI_POOL_OUTPUT_SIZE, LOG_FILE)
print_and_log('Learning Rate: %.0e\n' % LEARNING_RATE, LOG_FILE)

In [None]:
# Keyword arguments shared by the train and test loaders.
_loader_common = dict(num_workers=4, collate_fn=custom_collate_fn, drop_last=False)

# Training split: shuffled mini-batches of BATCH_SIZE samples.
train_dataset = WebDataset(TRAIN_DATA_DIR)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, **_loader_common)

# Test split: one sample per batch, in dataset order.
test_dataset = WebDataset(TEST_DATA_DIR)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, **_loader_common)

In [None]:
# Report how many samples each split contains.
for _split_label, _split_dataset in (('Train Images:', train_dataset),
                                     ('Test  Images:', test_dataset)):
    print(_split_label, len(_split_dataset))

In [None]:
# Build the network from the configuration above and move it to the selected device.
model = WebObjExtractionNet(
    ROI_POOL_OUTPUT_SIZE,
    IMG_HEIGHT,
    N_CLASSES,
    BACKBONE,
    TRAINABLE_CONVNET,
    CLASS_NAMES,
)
model = model.to(device)

# Adam over all model parameters; cross-entropy summed (not averaged) over the
# batch, with the per-class weights chosen in the configuration.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(weight=WEIGHTS, reduction='sum')
criterion = criterion.to(device)

# Run training. The test loader and EVAL_INTERVAL are passed so train_model can
# evaluate periodically, and LOG_FILE so it can log — details live in train.py.
model = train_model(
    model,
    train_loader,
    optimizer,
    criterion,
    N_EPOCHS,
    device,
    test_loader,
    EVAL_INTERVAL,
    LOG_FILE,
)