# Install necessary packages

In [None]:
!pip install -U --quiet pip wheel jupyter
!pip install --quiet numpy pytorch-lightning rank-bm25 torch tqdm transformers

# Import necessary packages

In [None]:
import os

from pytorch_lightning import Trainer
from transformers import AutoTokenizer

from Data import QADataModule
from Model import QAModel

# Set Environment Variables

In [None]:
# Export the tokenizer parallelism switch into this process's environment
# (presumably read by the Hugging Face tokenizers backend -- confirm).
os.environ.update({'TOKENIZERS_PARALLELISM': 'True'})

# Set Variables

In [None]:
# Number of training epochs; passed to the Trainer as `max_epochs`.
num_epoch: int = 1

# Load source

In [None]:
# Project data module built with its default arguments; it is passed as
# `datamodule=` to the Trainer's fit and predict calls.
data_loader = QADataModule()

# Instantiate model


In [None]:
# Tokenizer for the uncased BERT checkpoint; it is handed to the QA model.
tokenizer = AutoTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)
model = QAModel(tokenizer)

# Trainer configuration. Every limit_*_batches value is the integer 1, so each
# stage (train / val / test / predict) processes a single batch per run.
trainer_options = dict(
    max_epochs=num_epoch,
    accelerator='auto',
    default_root_dir='model',
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
    limit_train_batches=1,
    limit_val_batches=1,
    limit_test_batches=1,
    limit_predict_batches=1,
)
trainer = Trainer(**trainer_options)

# Train

In [None]:
# Unfreeze the model before training (assumes QAModel is a LightningModule,
# whose unfreeze() re-enables gradients -- confirm against Model.py).
model.unfreeze()
# Run the training loop; the data module supplies the dataloaders.
trainer.fit(model, datamodule=data_loader)

# Test

In [None]:
# NOTE: the test stage is currently disabled; uncomment the two lines below to run it.
# model.freeze()
# trainer.test(model, datamodule=data_loader)

# Predict


In [None]:
# Freeze the model before inference (assumes QAModel is a LightningModule -- confirm).
model.freeze()
# return_predictions=True makes predict() hand back the collected per-batch outputs.
predictions = trainer.predict(model, datamodule=data_loader, return_predictions=True)

In [None]:
# Flatten the per-batch prediction dicts into two parallel lists.
questions = []
answers = []
for batch in predictions:
    questions.extend(batch['questions'])
    answers.extend(batch['answers'])

In [None]:
# Write one "question ||| answer" line per prediction to the submission file.
target_path: str = 'source/test-submit-out.txt'

with open(target_path, mode='w', encoding='UTF-8') as target:
    for question, answer in zip(questions, answers):
        # print() appends the trailing newline for each record.
        print(f'{question} ||| {answer}', file=target)

In [None]:
# Display the written submission file using the shell's `cat` command.
!cat 'source/test-submit-out.txt'
