In [1]:
from datasets import KorSTSDatasets, KorSTS_collate_fn, bucket_pair_indices
from models import SBERT_with_KLUE_BERT
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from torch.optim import Adam

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

train_x_dir = "KorSTS/train_x.npy"
train_y_dir = "KorSTS/train_y.npy"
valid_x_dir = "KorSTS/valid_x.npy"
valid_y_dir = "KorSTS/valid_y.npy"

train_datasets = KorSTSDatasets(train_x_dir, train_y_dir)
valid_datasets = KorSTSDatasets(valid_x_dir, valid_y_dir)

train_seq_lengths = []
for s1, s2 in train_datasets.x:
    train_seq_lengths.append((len(s1), len(s2)))
train_sampler = bucket_pair_indices(train_seq_lengths, batch_size=16, max_pad_len=10)

train_loader = DataLoader(
    train_datasets, 
    collate_fn=KorSTS_collate_fn, 
    batch_sampler=train_sampler
)
valid_loader = DataLoader(
    valid_datasets,
    batch_size=16
)

In [3]:
model = SBERT_with_KLUE_BERT()

epochs = 1
criterion = nn.MSELoss()
optimizer = Adam(params=model.parameters(), lr=.2e-5)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
for name, param in model.named_parameters():
    print(name, param.shape)

bert.embeddings.word_embeddings.weight torch.Size([32000, 768])
bert.embeddings.position_embeddings.weight torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight torch.Size([2, 768])
bert.embeddings.LayerNorm.weight torch.Size([768])
bert.embeddings.LayerNorm.bias torch.Size([768])
bert.encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder

In [5]:
train_loss = []
valid_loss = []

for epoch in range(epochs):
    for data in train_loader:
        s1, s2, label = data
        logits = model(s1, s2)
        loss = criterion(logits, label)
        train_loss.append(loss.detach())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    val_loss = 0
    with torch.no_grad():
        for i, data in enumerate(valid_loader):
            s1, s2, label = data
            logits = model(s1, s2)
            loss = criterion(logits, label)
            val_loss += loss
    valid_loss.append(val_loss.detach()/i)


KeyboardInterrupt: 

In [6]:
train_loss

[tensor(21.4045),
 tensor(5.1395),
 tensor(5.2421),
 tensor(7.4910),
 tensor(7.2007),
 tensor(7.2855),
 tensor(4.5699),
 tensor(3.8467)]