In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from archehr import PROJECT_DIR
from archehr.data.dataset import QADataset
from archehr.data.utils import load_data, make_query_sentence_pairs


In [50]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Read the cases stored under data/1.1/dev and cut them 80/20, in file order,
# into a training part and a validation part.
data_path = PROJECT_DIR / "data" / "1.1" / "dev"
data = load_data(data_path)
n_cases = len(data)
split_at = int(0.8 * n_cases)
data_train, data_val = data[:split_at], data[split_at:]

# Pretrained sequence-classification checkpoint and its matching tokenizer.
model_name = "cross-encoder/nli-deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

In [4]:
# Training split: build the query/sentence pairs, then wrap them in a
# QADataset constructed with the shared tokenizer.
pairs_train = make_query_sentence_pairs(data_train)
dataset_train = QADataset(pairs_train, tokenizer)

# Validation split: same two steps.
pairs_val = make_query_sentence_pairs(data_val)
dataset_val = QADataset(pairs_val, tokenizer)

In [5]:
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader

# Collator that pads the examples of a batch and hands back PyTorch tensors.
collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, return_tensors="pt")

# Shuffled training loader: 16 examples per batch, assembled by the collator.
dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, collate_fn=collator)

In [6]:
# Pull a single batch from the training loader and show it; `batch` stays
# bound afterwards so it can be reused as a sample input.
batch = next(iter(dataloader_train))
print(batch)

{'input_ids': tensor([[    1,   273, 93260,  ...,     0,     0,     0],
        [    1,  5047,   342,  ...,     0,     0,     0],
        [    1,   273,   481,  ...,     0,     0,     0],
        ...,
        [    1,   273,   268,  ...,     0,     0,     0],
        [    1,  2709,   315,  ...,     0,     0,     0],
        [    1,  1396,   343,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([2, 1, 1, 0, 1, 2, 0, 2, 1, 2, 1, 1, 1, 2, 2, 1])}


In [10]:
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
# Total number of model parameters, expressed in millions.
sum(p.numel() for p in model.parameters()) / 1e6

184.424451

In [15]:
# Forward pass on the sample batch. The batch carries a `labels` entry, and the
# printed output shows a loss alongside the logits.
output = model(**batch)
print(output)

SequenceClassifierOutput(loss=tensor(6.4338, grad_fn=<NllLossBackward0>), logits=tensor([[ 6.3799, -4.4442, -1.0335],
        [-1.7905, -2.2732,  4.1379],
        [-0.5677, -3.6543,  4.7269],
        [-1.2216, -3.6515,  5.2742],
        [-2.1872, -1.7249,  3.8649],
        [-2.1213, -3.4478,  5.8751],
        [-2.5685, -2.6757,  5.2950],
        [-4.4795,  4.2381, -1.1220],
        [ 4.7607, -5.0976,  1.3202],
        [-2.4224, -2.9208,  5.4491],
        [ 1.4655, -4.5428,  4.0046],
        [-0.3691, -3.9276,  4.8572],
        [-2.0808, -3.2519,  5.4992],
        [ 7.1560, -4.4744, -1.9675],
        [-2.6215, -2.7454,  5.3977],
        [-0.1862, -4.6314,  5.5866]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [16]:
# Backpropagate the sample-batch loss so the parameters' `.grad` fields are
# populated for inspection.
loss = output.loss
loss.backward()

In [20]:
# Keep only the first tensor yielded by model.parameters() in `p`.
p = next(iter(model.parameters()))

In [25]:
# Number of elements in `p`, the first parameter tensor of the model.
p.numel()

98380800

In [26]:
# Element count of the gradient stored on `p`.
p.grad.numel()

98380800

In [27]:
from archehr.cross_encode.nli_deberta import remove_last_layer

In [59]:
# Reload a fresh copy of the checkpoint and its tokenizer, replacing the
# `model` / `tokenizer` objects created earlier in the notebook.
model_name = "cross-encoder/nli-deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Hand the model to the project helper; it prints the trainable-parameter count.
# NOTE(review): whether `model_` is a new module or `model` modified in place
# is not visible from this notebook — confirm in archehr.cross_encode.nli_deberta.
model_ = remove_last_layer(model)

model has 2307 trainable parameters.


In [52]:
# Leave only parameters whose name contains 'classifier' trainable; every
# other parameter is frozen.
for name, param in model_.named_parameters():
    param.requires_grad = 'classifier' in name

In [53]:
# Number of parameter elements that are still trainable after freezing.
sum(p.numel() for p in model_.parameters() if p.requires_grad)

2307

In [15]:
import torch

# AdamW over every parameter of `model`; no learning rate or weight decay is
# passed, so the optimizer's defaults apply.
# NOTE(review): the freezing cell toggles requires_grad on `model_`, while this
# optimizer and the training loop use `model` — confirm both names refer to the
# same module, otherwise the freeze does not affect what is trained.
optimizer = torch.optim.AdamW(model.parameters())

In [None]:
from tqdm import tqdm

# One pass over the training batches: forward, backward, optimizer update.
model.train()

for batch in tqdm(dataloader_train):
    # Reset gradients before each step. PyTorch adds into `.grad` on every
    # backward() call, so without this each update would also include the
    # gradients of all previous batches (and of any earlier exploratory
    # backward pass run in this kernel).
    optimizer.zero_grad()

    # The batch includes `labels`, so the model output carries the loss.
    outputs = model(**batch)
    loss = outputs.loss

    loss.backward()
    optimizer.step()
    

 20%|██        | 13/64 [03:09<12:24, 14.60s/it]


KeyboardInterrupt: 

: 