In [None]:
import importlib

import torch

In [None]:
# Check that tqdm progress bars render in this frontend by driving a one-step bar to completion.
from tqdm.auto import tqdm

_ = [step for step in tqdm(range(1))]

## Prepare dataset

In [None]:
from awe.data import qa_dataset, swde

# Re-execute qa_dataset so edits to its source are picked up without restarting the
# kernel (redundant on a fresh kernel). Only qa_dataset is reloaded; swde is not.
_ = importlib.reload(qa_dataset)

In [None]:
# Collect every page of the first website in the first vertical of the '-exact' SWDE variant.
sds = swde.Dataset(suffix='-exact')
pages = []
for website in sds.verticals[0].websites[:1]:
    pages.extend(website.pages)
len(pages)

In [None]:
# Optional preprocessing, disabled. Uncomment to run qa_dataset.prepare_entries on `pages`.
# It presumably materializes the QA entries read by QaEntryLoader below and skips
# pages that are already prepared — confirm in qa_dataset.
#qa_dataset.prepare_entries(pages, skip_existing=True)

In [None]:
# Indexable loader over the QA entries of `pages` (indexed below via loader[0]).
loader = qa_dataset.QaEntryLoader(pages)

In [None]:
# Optional consistency check, disabled. Uncomment to run loader.validate()
# (its checks are defined in qa_dataset, not visible here).
#loader.validate()

In [None]:
# Inspect the first entry; the cell displays its id.
entry = loader[0]
entry.id

In [None]:
# Display the answer spans of the inspected entry.
entry.get_all_answer_spans()

In [None]:
# Display the inspected entry's labels.
entry.labels

## Invoke Transformer

In [None]:
# Build the QA pipeline. The reload picks up edits to qa_model made since the kernel started.
from awe import qa_model
importlib.reload(qa_model)
pipeline = qa_model.QaPipeline()

In [None]:
# Load the pipeline's resources (presumably model weights and the tokenizer, since
# `pipeline.tokenizer` is used below — confirm). Any return value is displayed.
pipeline.load()

In [None]:
# Torch dataset over just the first page, tokenized with the pipeline's tokenizer.
# NOTE(review): this reload re-executes qa_dataset. Objects created earlier (e.g. `loader`)
# keep referring to the pre-reload classes.
importlib.reload(qa_dataset)
ds = qa_dataset.QaTorchDataset(pages[:1], pipeline.tokenizer)
len(ds)

In [None]:
# Optional check, disabled. Uncomment to run ds.validate_all()
# (presumably validates every item of `ds` — confirm in qa_dataset).
#ds.validate_all()

In [None]:
# Wrap the pipeline in the model that the Lightning trainer below validates, fits and predicts with.
# NOTE(review): qa_model is reloaded again after `pipeline` was built, so `pipeline` is an
# instance of the pre-reload QaPipeline class. If QaModel checks isinstance against
# qa_model.QaPipeline, that check would fail. On a clean run this second reload is redundant.
from awe import qa_model
importlib.reload(qa_model)
model = qa_model.QaModel(pipeline)

In [None]:
# Smoke-test metric computation on the first encoded example (off by default).
CHECK_METRICS = False
if CHECK_METRICS:
    # Item 0 as PyTorch tensors with a leading batch axis
    # (get_encodings presumably returns a Hugging Face BatchEncoding — confirm).
    batch = ds.get_encodings(0).convert_to_tensors('pt', prepend_batch_axis=True)
    metrics = model.compute_metrics(batch)
    # A bare `metrics` expression inside an `if` body is not the cell's final top-level
    # expression, so IPython would never render it; display it explicitly instead.
    from IPython.display import display
    display(metrics)

In [None]:
import torch.utils.data
import pytorch_lightning as pl

In [None]:
# Gym object; in this notebook it is only used to create the trainer's logger (g.create_logger()).
# NOTE(review): the meaning of the two positional None arguments and of the empty
# version_name is defined in awe.gym — confirm.
from awe import gym
g = gym.Gym(None, None, version_name='')

In [None]:
# Single-epoch trainer on every visible CUDA device (0 on a CPU-only machine),
# logging through the gym's logger.
# NOTE(review): `gpus=` was deprecated in PyTorch Lightning 1.7 and removed in 2.0 in favour
# of `accelerator=`/`devices=`; confirm against the pinned Lightning version.
n_gpus = torch.cuda.device_count()
trainer = pl.Trainer(
    logger=g.create_logger(),
    max_epochs=1,
    gpus=n_gpus,
)

In [None]:
import numpy as np

SEED = 42  # fixed seed so the train/validation page split below is reproducible
rng = np.random.default_rng(SEED)

In [None]:
# Randomly split the pages into a training set and a held-out validation set.
# Sample positions rather than the Page objects themselves. `rng.choice(pages, ...)` first
# converts the list into a numpy array, and `p not in <ndarray>` then does an element-wise
# `==` scan for every page (quadratic overall). Generator.choice on an array samples
# positions and then indexes, so for the same seed this should pick the same pages.
# Capping the sample size at len(pages) avoids numpy's ValueError when there are fewer
# pages than requested.
N_TRAIN_PAGES = 1000
train_idx = set(
    rng.choice(len(pages), size=min(N_TRAIN_PAGES, len(pages)), replace=False).tolist()
)
# Both splits are plain lists in the original page order (the train loader shuffles anyway).
train_pages = [p for i, p in enumerate(pages) if i in train_idx]
val_pages = [p for i, p in enumerate(pages) if i not in train_idx]
assert len(train_pages) + len(val_pages) == len(pages)
len(train_pages), len(val_pages)

In [None]:
# Torch datasets for the train and validation page splits, both using the pipeline's tokenizer.
train_ds = qa_dataset.QaTorchDataset(train_pages, pipeline.tokenizer)
val_ds = qa_dataset.QaTorchDataset(val_pages, pipeline.tokenizer)

In [None]:
# One example per batch; only the training loader is shuffled.
DataLoader = torch.utils.data.DataLoader
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=1)
# `ds` covers only pages[:1]; this loader is evaluated both before and after fine-tuning below.
pred_loader = DataLoader(ds, batch_size=1)

In [None]:
# Baseline: validation metrics on `pred_loader` (pages[:1]) before fine-tuning.
trainer.validate(model, pred_loader)

In [None]:
# Baseline predictions on the same loader; decoded in the next cell.
preds = trainer.predict(model, pred_loader)

In [None]:
# Decode the baseline predictions with the dataset that produced them.
ds.decode_predictions(preds)

### Finetune

In [None]:
# Fine-tune on the training split and evaluate on the held-out `val_pages` split.
# Without this, `val_loader` is built but never used and training has no held-out signal.
trainer.fit(model, train_loader, val_dataloaders=val_loader)

In [None]:
# Re-run validation on `pred_loader` after fine-tuning, for comparison with the baseline.
# NOTE(review): pages[0] may also have been sampled into train_pages, so this is not
# necessarily a held-out score.
trainer.validate(model, pred_loader)

In [None]:
# Predictions of the fine-tuned model on the same single-page loader.
preds = trainer.predict(model, pred_loader)

In [None]:
# Decode the fine-tuned predictions with the dataset that produced them.
ds.decode_predictions(preds)

### Test on an unseen website

In [None]:
# Evaluation pages from a website not used for training (the second website of the same
# vertical), limited to its first 10 pages.
test_pages = []
for test_website in sds.verticals[0].websites[1:2]:
    test_pages.extend(test_website.pages[:10])
len(test_pages)

In [None]:
# Torch dataset over the unseen-website pages, using the same tokenizer.
test_ds = qa_dataset.QaTorchDataset(test_pages, pipeline.tokenizer)

In [None]:
# Optional preprocessing for the test pages, disabled.
# NOTE(review): this sits after test_ds is constructed, and unlike the training-page call
# it omits skip_existing=True. If QaTorchDataset needs prepared entries when it is built,
# run this before creating test_ds — confirm.
#qa_dataset.prepare_entries(test_pages)

In [None]:
# Unshuffled loader over the test dataset, one example per batch.
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=1)

In [None]:
# Validation metrics of the fine-tuned model on the unseen website.
trainer.validate(model, test_loader)

In [None]:
# Predictions on the unseen website; this overwrites the earlier `preds`.
preds = trainer.predict(model, test_loader)

In [None]:
# Decode the unseen-website predictions with the dataset that produced them (`test_ds`);
# decoding with `ds` (built from pages[:1]) would pair them with the wrong pages.
# NOTE(review): `preds[4:10]` slices the per-batch predictions. If decode_predictions pairs
# its input with dataset items starting at index 0, the slice would be misaligned — confirm.
test_ds.decode_predictions(preds[4:10])