In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from datasets import WebDataset
from models import WebObjExtractionNet
from train import train_model, evaluate_model
from utils import custom_collate_fn, count_parameters, pkl_load

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Parameters
N_CLASSES = 4 # 0: BG, 1: Price, 2: Image, 3: Title
IMG_HEIGHT = 1440 # Image assumed to have same height and width
EVAL_INTERVAL = 5 # Number of Epochs after which model is evaluated

# Hyperparameters
BATCH_SIZE = 10
ROI_POOL_OUTPUT_SIZE = (5,5)
LEARNABLE_CONVNET = False
LEARNING_RATE = 1e-3
N_EPOCHS = 5

# Reproducibility
# Seed torch's global RNGs (CPU and CUDA) here, before the later cells build the
# DataLoaders and the model. The train loader uses shuffle=True, which draws from
# this RNG, so without a seed every Restart & Run All yields a different run.
# NOTE(review): bit-exact GPU runs may additionally need cuDNN determinism
# settings (see PyTorch's reproducibility notes).
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Settings shared by both splits; only the data directory and shuffling differ.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=4,
    collate_fn=custom_collate_fn,
    drop_last=False,
)

# Training split: reshuffled every epoch.
train_dataset = WebDataset('../data/web_data/train')
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Test split: fixed order.
test_dataset = WebDataset('../data/web_data/test')
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Build the network and move it to the selected device.
model = WebObjExtractionNet(ROI_POOL_OUTPUT_SIZE, IMG_HEIGHT, N_CLASSES, LEARNABLE_CONVNET)
model = model.to(device)
print('Trainable parameters in model:', count_parameters(model))

# Adam is handed every model parameter. Parameters that never receive a gradient
# (presumably the frozen convnet when LEARNABLE_CONVNET is False) are skipped by
# Adam's step.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# train_model also receives the test loader and EVAL_INTERVAL, presumably for
# periodic evaluation during training — confirm against train.py.
model = train_model(
    model, train_loader, optimizer, criterion, N_EPOCHS, device,
    test_loader, EVAL_INTERVAL,
)