# Load Data

In [None]:
# Root directory holding the competition CSVs (train.csv, test.csv); later cells
# also write the model checkpoint directory and submission.csv under it.
data_path = "./kaggle_aes_data"

In [None]:
import pandas as pd
import numpy as np
import os

# Read the labelled training essays into a DataFrame
train_csv_path = os.path.join(data_path, "train.csv")
data_df = pd.read_csv(train_csv_path)

In [None]:
def one_hot_encoding(score, num_classes=6):
  """Encode a 1-based integer score as a one-hot float vector.

  Args:
    score: Score in ``1..num_classes``; any value accepted by ``int()``.
    num_classes: Length of the output vector.

  Returns:
    ``np.ndarray`` of shape ``(num_classes,)`` holding 1.0 at position
    ``score - 1`` and 0.0 elsewhere.

  Raises:
    ValueError: if ``score`` is outside ``1..num_classes``. Without this check a
      score of 0 would silently wrap around to the last class (index -1).
  """
  index = int(score) - 1
  if not 0 <= index < num_classes:
    raise ValueError(f"score must be in 1..{num_classes}, got {score!r}")
  vector = np.zeros(num_classes)
  vector[index] = 1
  return vector

In [None]:
# Model inputs: the raw essay texts
X_data = list(data_df['full_text'])

# Targets: each integer score turned into a length-6 one-hot vector
y_data = [one_hot_encoding(score) for score in data_df['score']]

In [None]:
# Peek at the first five essays, then their one-hot targets
for preview in (X_data[:5], y_data[:5]):
  print(preview)

['Many people have car where they live. The thing they don\'t know is that when you use a car alot of thing can happen\xa0like you can get in accidet or\xa0the smoke that the car has is bad to breath\xa0on if someone is walk but in VAUBAN,Germany they dont have that proble because 70 percent of vauban\'s families do not own cars,and 57 percent sold a car to move there. Street parkig ,driveways and home garages are forbidden\xa0on the outskirts of freiburd that near the French and Swiss borders. You probaly won\'t see a car in Vauban\'s streets because they are completely "car free" but\xa0If some that lives in VAUBAN that owns a car ownership is allowed,but there are only two places that you can park a large garages at the edge of the development,where a car owner buys a space but it not cheap to buy one they sell the space for you car for $40,000 along with a home. The vauban people completed this in 2006 ,they said that this an example of a growing trend in Europe,The untile states a

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the essays for validation, stratified on the one-hot score so
# both splits keep the same score distribution; fixed seed for repeatability.
X_train, X_valid, y_train, y_valid = train_test_split(
    X_data,
    y_data,
    test_size=0.2,
    stratify=y_data,
    random_state=42,
)

In [None]:
# Sanity check: inputs and targets stay aligned in each split
for inputs, targets in ((X_train, y_train), (X_valid, y_valid)):
  print(len(inputs), len(targets))

13845 13845
3462 3462


In [None]:
from transformers import BertTokenizerFast

# Tokenizer paired with the bert-base-uncased checkpoint used for the model
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Map-style dataset that tokenizes one essay per __getitem__ call
class DataSet(Dataset):
  """Lazily tokenized essays for a Hugging Face sequence-classification model.

  Args:
    tokenizer: Callable tokenizer; invoked per item with ``return_tensors='pt'``,
      ``padding='max_length'``, ``truncation=True`` and ``max_length``.
    X: Sequence of essay strings.
    y: Optional sequence of targets aligned with ``X``. When ``None`` (e.g. the
      unlabelled test set) only the encoded inputs are returned.
    max_length: Optional padding/truncation length forwarded to the tokenizer;
      ``None`` keeps the tokenizer's own default.
  """

  def __init__(self, tokenizer, X, y=None, max_length=None):
    super().__init__()
    self.X = X
    self.y = y
    self.tokenizer = tokenizer
    self.max_length = max_length

  def __len__(self):
    return len(self.X)

  def __getitem__(self, idx):
    encoded = self.tokenizer(
        self.X[idx],
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=self.max_length,
    )
    # The tokenizer returns a batch of one; drop that leading axis so the
    # DataLoader's collate step can stack items into (batch, seq_len). Building
    # a plain dict keeps every field the tokenizer emits, rather than hard-coding
    # three keys and writing tensors onto BatchEncoding attributes.
    inputs = {key: value.squeeze(0) for key, value in encoded.items()}

    if self.y is None:
      return inputs
    return inputs, self.y[idx]

# create datasets
train_dataset = DataSet(tokenizer, X_train, y_train)
valid_dataset = DataSet(tokenizer, X_valid, y_valid)

# Training drops the trailing partial batch so every optimizer step sees a full
# batch of 32. Validation must score every held-out essay, so its last partial
# batch is kept (drop_last=False).
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, drop_last=False)

# Model

In [None]:
from transformers import BertForSequenceClassification

# Seed torch's global RNGs before the classification head is freshly initialised
# (see the "newly initialized" warning) so that init, and the training-loader
# shuffle order drawn later from the same RNG, repeat across runs.
torch.manual_seed(42)

# BERT encoder with a 6-way classification head: one logit per score 1-6
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
model = model.to("cuda")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Loss

In [None]:
import torch.nn as nn

# Ordinal Log Loss: A simple log-based loss function for ordinal text classification, COLING 2022, Castagnos et al
class OrdinalLogLoss(nn.Module):
  """Ordinal Log Loss (OLL).

  With ``p = softmax(logits)`` and true class ``c``, the per-sample loss is
  ``-sum_i log(1 - p_i) * |i - c| ** alpha``: probability mass placed on classes
  far from ``c`` is penalised more than mass on neighbouring classes. The result
  is averaged over the batch.

  Args:
    alpha: Exponent applied to the class distance ``|i - c|``.
    eps: Lower bound on ``1 - p_i`` before taking the log. Without it, a
      saturated softmax (``p_c == 1.0`` in float32) evaluates
      ``log(0) * 0 = nan`` and poisons training.
  """

  def __init__(self, alpha=1.5, eps=1e-7):
    super().__init__()
    self.alpha = alpha
    self.eps = eps

  def forward(self, y_pred, y_true):
    """Return the mean OLL over the batch as a scalar tensor.

    Args:
      y_pred: Raw logits, shape ``(batch_size, num_classes)``.
      y_true: Either one-hot targets of shape ``(batch_size, num_classes)``
        (the argmax is taken as the class) or integer class indices of shape
        ``(batch_size,)``.
    """
    num_classes = y_pred.shape[-1]
    target = y_true.argmax(dim=1) if y_true.dim() == 2 else y_true
    target = target.to(device=y_pred.device, dtype=torch.long)

    # |i - c| for every (sample, class) pair via broadcasting: (batch, num_classes)
    classes = torch.arange(num_classes, device=y_pred.device)
    distance = (classes.unsqueeze(0) - target.unsqueeze(1)).abs().to(y_pred.dtype)

    probs = F.softmax(y_pred, dim=-1)
    log_complement = torch.log(torch.clamp(1 - probs, min=self.eps))
    per_sample = -(log_complement * distance.pow(self.alpha)).sum(dim=1)
    return per_sample.mean()

# Metrics

In [None]:
import numpy as np
from sklearn.metrics import cohen_kappa_score

def compute_qwk(y_pred, y_true):
  """Quadratic weighted Cohen's kappa between predicted and true classes.

  Args:
    y_pred: Tensor of logits or probabilities, shape ``(n, num_classes)``; the
      argmax along the last axis is the predicted class.
    y_true: Tensor of one-hot targets ``(n, num_classes)`` or integer class
      indices ``(n,)``.

  Returns:
    The score from ``sklearn.metrics.cohen_kappa_score(..., weights='quadratic')``.
  """
  # detach() + cpu() so the metric accepts tensors that require grad or live on GPU
  pred_labels = y_pred.detach().argmax(-1).cpu().numpy()
  true = y_true.detach()
  true_labels = (true.argmax(-1) if true.dim() > 1 else true).cpu().numpy()
  return cohen_kappa_score(true_labels, pred_labels, weights='quadratic')

# Training

In [None]:
import torch.optim as optim

# set optimizer
# AdamW applies weight decay decoupled from the adaptive gradient step
# (Loshchilov & Hutter), the form used for BERT fine-tuning; torch.optim.Adam
# would instead add it to the gradient as an L2 penalty that the per-parameter
# scaling then distorts.
lr = 2e-5
weight_decay = 0.01
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# number of passes over the training set
epochs = 10

In [None]:
# Best-checkpoint bookkeeping, updated during training.
# Start from -inf so the first validation result is always checkpointed, even if
# its QWK is <= 0; otherwise no best_model.pt might ever be written.
best_valid_loss = float('inf')
best_valid_qwk = float('-inf')

# torch.save does not create missing parent directories, so create the
# checkpoint directory up front.
save_path = os.path.join(data_path, "model")
os.makedirs(save_path, exist_ok=True)

In [None]:
# Ordinal log loss with distance exponent alpha = 1.5 (replaces an
# nn.CrossEntropyLoss() baseline)
loss_fn = OrdinalLogLoss(alpha=1.5)

In [None]:
def eval(model, valid_loader):
  """Evaluate ``model`` on every batch of ``valid_loader``.

  Uses the notebook-level ``loss_fn`` and ``compute_qwk``.

  Args:
    model: Hugging Face sequence-classification model; batches are moved to the
      device its parameters live on.
    valid_loader: DataLoader yielding ``(encoded_inputs_dict, targets)`` pairs.

  Returns:
    Tuple ``(loss, qwk)``:
    - loss: the sample-weighted mean of ``loss_fn`` over all batches.
    - qwk: computed once over all predictions. QWK is not additive, so
      averaging per-batch scores would not give the validation-set QWK.

  Raises:
    ValueError: if ``valid_loader`` yields no batches.

  Note: leaves the model in eval mode; call ``model.train()`` afterwards to
  resume training.
  """
  device = next(model.parameters()).device
  model.eval()

  total_loss = 0.0
  num_samples = 0
  all_logits, all_targets = [], []

  with torch.no_grad():
    # Iterate the loader exactly once, so each essay is tokenized once.
    for encoded, y_true in valid_loader:
      encoded = {key: value.to(device) for key, value in encoded.items()}
      y_pred = model(**encoded).logits.cpu()

      batch_size = y_true.shape[0]
      total_loss += loss_fn(y_pred, y_true).item() * batch_size
      num_samples += batch_size
      all_logits.append(y_pred)
      all_targets.append(y_true)

  if num_samples == 0:
    raise ValueError("valid_loader yielded no batches")

  loss = total_loss / num_samples
  qwk = compute_qwk(torch.cat(all_logits), torch.cat(all_targets))
  return loss, qwk

In [None]:
from tqdm.auto import tqdm

# training
EVAL_EVERY = 100  # optimizer steps between validation passes

for epoch in tqdm(range(epochs), desc="epochs"):
  model.train()
  train_loss = 0.0
  num_steps = len(train_loader)
  progress = tqdm(train_loader, total=num_steps, desc=f"epoch {epoch + 1}/{epochs}", leave=False)

  # Consume the loader exactly once per epoch: the batch unpacked here is the
  # batch that gets trained on. There is no second iterator tokenizing a
  # different batch.
  for i, (encoded, y_true) in enumerate(progress):
    optimizer.zero_grad()

    encoded = {key: value.to('cuda') for key, value in encoded.items()}
    # Loss and metric are computed on CPU, alongside the targets.
    y_pred = model(**encoded).logits.cpu()

    loss = loss_fn(y_pred, y_true)
    loss.backward()
    optimizer.step()

    step_loss = loss.item()
    train_loss += step_loss
    # Running metrics go to the progress bar instead of two printed lines per step.
    progress.set_postfix(loss=f"{step_loss:.3f}", qwk=f"{compute_qwk(y_pred, y_true):.3f}")

    # Validate periodically and after the final step, so the end-of-epoch weights
    # are also considered for the best checkpoint.
    is_last_step = i == num_steps - 1
    if (i % EVAL_EVERY == 0 and i != 0) or is_last_step:
      valid_loss, valid_qwk = eval(model, valid_loader)
      print(f"Epoch: {epoch+1}/{epochs} Iter: {i+1}/{num_steps} Valid Loss: {valid_loss:.3f} Valid QWK: {valid_qwk:.3f}")

      if valid_qwk > best_valid_qwk:
        best_valid_qwk = valid_qwk
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), os.path.join(save_path, "best_model.pt"))

      model.train()  # eval() switched the model to eval mode

  train_loss /= num_steps
  print(f"Epoch: {epoch+1}/{epochs} Train Loss: {train_loss:.3f}")

print("Best valid qwk: ", best_valid_qwk)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 0/10 Iter: 1/432 Loss: 2.530
Epoch: 0/10 Iter: 1/432 QWK(batch): -0.020


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch: 0/10 Iter: 2/432 Loss: 2.273
Epoch: 0/10 Iter: 2/432 QWK(batch): 0.118
Epoch: 0/10 Iter: 3/432 Loss: 2.382
Epoch: 0/10 Iter: 3/432 QWK(batch): 0.178
Epoch: 0/10 Iter: 4/432 Loss: 2.195
Epoch: 0/10 Iter: 4/432 QWK(batch): -0.031
Epoch: 0/10 Iter: 5/432 Loss: 1.853
Epoch: 0/10 Iter: 5/432 QWK(batch): 0.065
Epoch: 0/10 Iter: 6/432 Loss: 2.181
Epoch: 0/10 Iter: 6/432 QWK(batch): -0.023
Epoch: 0/10 Iter: 7/432 Loss: 1.853
Epoch: 0/10 Iter: 7/432 QWK(batch): 0.148
Epoch: 0/10 Iter: 8/432 Loss: 2.058
Epoch: 0/10 Iter: 8/432 QWK(batch): -0.165
Epoch: 0/10 Iter: 9/432 Loss: 2.060
Epoch: 0/10 Iter: 9/432 QWK(batch): 0.003
Epoch: 0/10 Iter: 10/432 Loss: 1.929
Epoch: 0/10 Iter: 10/432 QWK(batch): -0.169
Epoch: 0/10 Iter: 11/432 Loss: 1.784
Epoch: 0/10 Iter: 11/432 QWK(batch): 0.003
Epoch: 0/10 Iter: 12/432 Loss: 2.053
Epoch: 0/10 Iter: 12/432 QWK(batch): -0.033
Epoch: 0/10 Iter: 13/432 Loss: 1.852
Epoch: 0/10 Iter: 13/432 QWK(batch): -0.083
Epoch: 0/10 Iter: 14/432 Loss: 1.709
Epoch: 0/10 I

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 1/10 Iter: 1/432 Loss: 0.894
Epoch: 1/10 Iter: 1/432 QWK(batch): 0.690
Epoch: 1/10 Iter: 2/432 Loss: 1.109
Epoch: 1/10 Iter: 2/432 QWK(batch): 0.569
Epoch: 1/10 Iter: 3/432 Loss: 1.031
Epoch: 1/10 Iter: 3/432 QWK(batch): 0.746
Epoch: 1/10 Iter: 4/432 Loss: 0.938
Epoch: 1/10 Iter: 4/432 QWK(batch): 0.782
Epoch: 1/10 Iter: 5/432 Loss: 0.710
Epoch: 1/10 Iter: 5/432 QWK(batch): 0.825
Epoch: 1/10 Iter: 6/432 Loss: 1.416
Epoch: 1/10 Iter: 6/432 QWK(batch): 0.449
Epoch: 1/10 Iter: 7/432 Loss: 0.697
Epoch: 1/10 Iter: 7/432 QWK(batch): 0.817
Epoch: 1/10 Iter: 8/432 Loss: 1.055
Epoch: 1/10 Iter: 8/432 QWK(batch): 0.774
Epoch: 1/10 Iter: 9/432 Loss: 1.030
Epoch: 1/10 Iter: 9/432 QWK(batch): 0.726
Epoch: 1/10 Iter: 10/432 Loss: 0.854
Epoch: 1/10 Iter: 10/432 QWK(batch): 0.801
Epoch: 1/10 Iter: 11/432 Loss: 0.847
Epoch: 1/10 Iter: 11/432 QWK(batch): 0.811
Epoch: 1/10 Iter: 12/432 Loss: 0.812
Epoch: 1/10 Iter: 12/432 QWK(batch): 0.806
Epoch: 1/10 Iter: 13/432 Loss: 0.835
Epoch: 1/10 Iter: 13/

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 2/10 Iter: 1/432 Loss: 1.063
Epoch: 2/10 Iter: 1/432 QWK(batch): 0.664
Epoch: 2/10 Iter: 2/432 Loss: 1.120
Epoch: 2/10 Iter: 2/432 QWK(batch): 0.701
Epoch: 2/10 Iter: 3/432 Loss: 0.821
Epoch: 2/10 Iter: 3/432 QWK(batch): 0.816
Epoch: 2/10 Iter: 4/432 Loss: 0.730
Epoch: 2/10 Iter: 4/432 QWK(batch): 0.784
Epoch: 2/10 Iter: 5/432 Loss: 0.774
Epoch: 2/10 Iter: 5/432 QWK(batch): 0.716
Epoch: 2/10 Iter: 6/432 Loss: 0.931
Epoch: 2/10 Iter: 6/432 QWK(batch): 0.626
Epoch: 2/10 Iter: 7/432 Loss: 0.829
Epoch: 2/10 Iter: 7/432 QWK(batch): 0.624
Epoch: 2/10 Iter: 8/432 Loss: 0.762
Epoch: 2/10 Iter: 8/432 QWK(batch): 0.787
Epoch: 2/10 Iter: 9/432 Loss: 0.922
Epoch: 2/10 Iter: 9/432 QWK(batch): 0.806
Epoch: 2/10 Iter: 10/432 Loss: 0.832
Epoch: 2/10 Iter: 10/432 QWK(batch): 0.678
Epoch: 2/10 Iter: 11/432 Loss: 0.674
Epoch: 2/10 Iter: 11/432 QWK(batch): 0.753
Epoch: 2/10 Iter: 12/432 Loss: 0.889
Epoch: 2/10 Iter: 12/432 QWK(batch): 0.660
Epoch: 2/10 Iter: 13/432 Loss: 0.977
Epoch: 2/10 Iter: 13/

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 3/10 Iter: 1/432 Loss: 0.796
Epoch: 3/10 Iter: 1/432 QWK(batch): 0.755
Epoch: 3/10 Iter: 2/432 Loss: 0.720
Epoch: 3/10 Iter: 2/432 QWK(batch): 0.774
Epoch: 3/10 Iter: 3/432 Loss: 0.828
Epoch: 3/10 Iter: 3/432 QWK(batch): 0.754
Epoch: 3/10 Iter: 4/432 Loss: 0.613
Epoch: 3/10 Iter: 4/432 QWK(batch): 0.765
Epoch: 3/10 Iter: 5/432 Loss: 0.799
Epoch: 3/10 Iter: 5/432 QWK(batch): 0.675
Epoch: 3/10 Iter: 6/432 Loss: 0.733
Epoch: 3/10 Iter: 6/432 QWK(batch): 0.691
Epoch: 3/10 Iter: 7/432 Loss: 0.640
Epoch: 3/10 Iter: 7/432 QWK(batch): 0.787
Epoch: 3/10 Iter: 8/432 Loss: 0.859
Epoch: 3/10 Iter: 8/432 QWK(batch): 0.745
Epoch: 3/10 Iter: 9/432 Loss: 0.537
Epoch: 3/10 Iter: 9/432 QWK(batch): 0.810
Epoch: 3/10 Iter: 10/432 Loss: 0.890
Epoch: 3/10 Iter: 10/432 QWK(batch): 0.709
Epoch: 3/10 Iter: 11/432 Loss: 0.926
Epoch: 3/10 Iter: 11/432 QWK(batch): 0.668
Epoch: 3/10 Iter: 12/432 Loss: 0.821
Epoch: 3/10 Iter: 12/432 QWK(batch): 0.785
Epoch: 3/10 Iter: 13/432 Loss: 0.739
Epoch: 3/10 Iter: 13/

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 4/10 Iter: 1/432 Loss: 0.980
Epoch: 4/10 Iter: 1/432 QWK(batch): 0.758
Epoch: 4/10 Iter: 2/432 Loss: 1.093
Epoch: 4/10 Iter: 2/432 QWK(batch): 0.783
Epoch: 4/10 Iter: 3/432 Loss: 0.933
Epoch: 4/10 Iter: 3/432 QWK(batch): 0.734
Epoch: 4/10 Iter: 4/432 Loss: 0.784
Epoch: 4/10 Iter: 4/432 QWK(batch): 0.797
Epoch: 4/10 Iter: 5/432 Loss: 0.953
Epoch: 4/10 Iter: 5/432 QWK(batch): 0.702
Epoch: 4/10 Iter: 6/432 Loss: 0.982
Epoch: 4/10 Iter: 6/432 QWK(batch): 0.553
Epoch: 4/10 Iter: 7/432 Loss: 0.792
Epoch: 4/10 Iter: 7/432 QWK(batch): 0.750
Epoch: 4/10 Iter: 8/432 Loss: 1.024
Epoch: 4/10 Iter: 8/432 QWK(batch): 0.681
Epoch: 4/10 Iter: 9/432 Loss: 0.941
Epoch: 4/10 Iter: 9/432 QWK(batch): 0.635
Epoch: 4/10 Iter: 10/432 Loss: 0.736
Epoch: 4/10 Iter: 10/432 QWK(batch): 0.749
Epoch: 4/10 Iter: 11/432 Loss: 0.819
Epoch: 4/10 Iter: 11/432 QWK(batch): 0.579
Epoch: 4/10 Iter: 12/432 Loss: 0.781
Epoch: 4/10 Iter: 12/432 QWK(batch): 0.662
Epoch: 4/10 Iter: 13/432 Loss: 0.984
Epoch: 4/10 Iter: 13/

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 5/10 Iter: 1/432 Loss: 0.646
Epoch: 5/10 Iter: 1/432 QWK(batch): 0.877
Epoch: 5/10 Iter: 2/432 Loss: 0.822
Epoch: 5/10 Iter: 2/432 QWK(batch): 0.762
Epoch: 5/10 Iter: 3/432 Loss: 0.829
Epoch: 5/10 Iter: 3/432 QWK(batch): 0.701
Epoch: 5/10 Iter: 4/432 Loss: 0.552
Epoch: 5/10 Iter: 4/432 QWK(batch): 0.891
Epoch: 5/10 Iter: 5/432 Loss: 0.700
Epoch: 5/10 Iter: 5/432 QWK(batch): 0.725
Epoch: 5/10 Iter: 6/432 Loss: 0.944
Epoch: 5/10 Iter: 6/432 QWK(batch): 0.701
Epoch: 5/10 Iter: 7/432 Loss: 0.632
Epoch: 5/10 Iter: 7/432 QWK(batch): 0.861
Epoch: 5/10 Iter: 8/432 Loss: 0.566
Epoch: 5/10 Iter: 8/432 QWK(batch): 0.841
Epoch: 5/10 Iter: 9/432 Loss: 0.730
Epoch: 5/10 Iter: 9/432 QWK(batch): 0.744
Epoch: 5/10 Iter: 10/432 Loss: 0.893
Epoch: 5/10 Iter: 10/432 QWK(batch): 0.746
Epoch: 5/10 Iter: 11/432 Loss: 0.859
Epoch: 5/10 Iter: 11/432 QWK(batch): 0.625
Epoch: 5/10 Iter: 12/432 Loss: 0.871
Epoch: 5/10 Iter: 12/432 QWK(batch): 0.743
Epoch: 5/10 Iter: 13/432 Loss: 0.576
Epoch: 5/10 Iter: 13/

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 6/10 Iter: 1/432 Loss: 0.648
Epoch: 6/10 Iter: 1/432 QWK(batch): 0.792
Epoch: 6/10 Iter: 2/432 Loss: 0.973
Epoch: 6/10 Iter: 2/432 QWK(batch): 0.586
Epoch: 6/10 Iter: 3/432 Loss: 0.871
Epoch: 6/10 Iter: 3/432 QWK(batch): 0.703
Epoch: 6/10 Iter: 4/432 Loss: 0.931
Epoch: 6/10 Iter: 4/432 QWK(batch): 0.721
Epoch: 6/10 Iter: 5/432 Loss: 0.757
Epoch: 6/10 Iter: 5/432 QWK(batch): 0.788
Epoch: 6/10 Iter: 6/432 Loss: 0.825
Epoch: 6/10 Iter: 6/432 QWK(batch): 0.667
Epoch: 6/10 Iter: 7/432 Loss: 0.546
Epoch: 6/10 Iter: 7/432 QWK(batch): 0.867
Epoch: 6/10 Iter: 8/432 Loss: 0.612
Epoch: 6/10 Iter: 8/432 QWK(batch): 0.871
Epoch: 6/10 Iter: 9/432 Loss: 0.713
Epoch: 6/10 Iter: 9/432 QWK(batch): 0.787
Epoch: 6/10 Iter: 10/432 Loss: 1.128
Epoch: 6/10 Iter: 10/432 QWK(batch): 0.652
Epoch: 6/10 Iter: 11/432 Loss: 0.632
Epoch: 6/10 Iter: 11/432 QWK(batch): 0.894
Epoch: 6/10 Iter: 12/432 Loss: 0.832
Epoch: 6/10 Iter: 12/432 QWK(batch): 0.776
Epoch: 6/10 Iter: 13/432 Loss: 0.653
Epoch: 6/10 Iter: 13/

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 7/10 Iter: 1/432 Loss: 0.839
Epoch: 7/10 Iter: 1/432 QWK(batch): 0.670
Epoch: 7/10 Iter: 2/432 Loss: 0.665
Epoch: 7/10 Iter: 2/432 QWK(batch): 0.860
Epoch: 7/10 Iter: 3/432 Loss: 0.717
Epoch: 7/10 Iter: 3/432 QWK(batch): 0.754
Epoch: 7/10 Iter: 4/432 Loss: 0.731
Epoch: 7/10 Iter: 4/432 QWK(batch): 0.692
Epoch: 7/10 Iter: 5/432 Loss: 0.782
Epoch: 7/10 Iter: 5/432 QWK(batch): 0.820
Epoch: 7/10 Iter: 6/432 Loss: 0.664
Epoch: 7/10 Iter: 6/432 QWK(batch): 0.794
Epoch: 7/10 Iter: 7/432 Loss: 0.593
Epoch: 7/10 Iter: 7/432 QWK(batch): 0.767
Epoch: 7/10 Iter: 8/432 Loss: 0.702
Epoch: 7/10 Iter: 8/432 QWK(batch): 0.828
Epoch: 7/10 Iter: 9/432 Loss: 0.650
Epoch: 7/10 Iter: 9/432 QWK(batch): 0.786
Epoch: 7/10 Iter: 10/432 Loss: 0.569
Epoch: 7/10 Iter: 10/432 QWK(batch): 0.863
Epoch: 7/10 Iter: 11/432 Loss: 0.815
Epoch: 7/10 Iter: 11/432 QWK(batch): 0.746
Epoch: 7/10 Iter: 12/432 Loss: 0.731
Epoch: 7/10 Iter: 12/432 QWK(batch): 0.859
Epoch: 7/10 Iter: 13/432 Loss: 0.947
Epoch: 7/10 Iter: 13/

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 8/10 Iter: 1/432 Loss: 0.903
Epoch: 8/10 Iter: 1/432 QWK(batch): 0.755
Epoch: 8/10 Iter: 2/432 Loss: 0.834
Epoch: 8/10 Iter: 2/432 QWK(batch): 0.620
Epoch: 8/10 Iter: 3/432 Loss: 0.931
Epoch: 8/10 Iter: 3/432 QWK(batch): 0.682
Epoch: 8/10 Iter: 4/432 Loss: 0.700
Epoch: 8/10 Iter: 4/432 QWK(batch): 0.767
Epoch: 8/10 Iter: 5/432 Loss: 1.146
Epoch: 8/10 Iter: 5/432 QWK(batch): 0.623
Epoch: 8/10 Iter: 6/432 Loss: 0.641
Epoch: 8/10 Iter: 6/432 QWK(batch): 0.798
Epoch: 8/10 Iter: 7/432 Loss: 0.838
Epoch: 8/10 Iter: 7/432 QWK(batch): 0.800
Epoch: 8/10 Iter: 8/432 Loss: 0.570
Epoch: 8/10 Iter: 8/432 QWK(batch): 0.855
Epoch: 8/10 Iter: 9/432 Loss: 0.873
Epoch: 8/10 Iter: 9/432 QWK(batch): 0.738
Epoch: 8/10 Iter: 10/432 Loss: 0.564
Epoch: 8/10 Iter: 10/432 QWK(batch): 0.869
Epoch: 8/10 Iter: 11/432 Loss: 0.669
Epoch: 8/10 Iter: 11/432 QWK(batch): 0.813
Epoch: 8/10 Iter: 12/432 Loss: 0.632
Epoch: 8/10 Iter: 12/432 QWK(batch): 0.809
Epoch: 8/10 Iter: 13/432 Loss: 0.614
Epoch: 8/10 Iter: 13/

  0%|          | 0/432 [00:00<?, ?it/s]

Epoch: 9/10 Iter: 1/432 Loss: 0.790
Epoch: 9/10 Iter: 1/432 QWK(batch): 0.779
Epoch: 9/10 Iter: 2/432 Loss: 0.600
Epoch: 9/10 Iter: 2/432 QWK(batch): 0.887
Epoch: 9/10 Iter: 3/432 Loss: 0.991
Epoch: 9/10 Iter: 3/432 QWK(batch): 0.678
Epoch: 9/10 Iter: 4/432 Loss: 0.632
Epoch: 9/10 Iter: 4/432 QWK(batch): 0.902
Epoch: 9/10 Iter: 5/432 Loss: 0.584
Epoch: 9/10 Iter: 5/432 QWK(batch): 0.881
Epoch: 9/10 Iter: 6/432 Loss: 0.914
Epoch: 9/10 Iter: 6/432 QWK(batch): 0.659
Epoch: 9/10 Iter: 7/432 Loss: 0.545
Epoch: 9/10 Iter: 7/432 QWK(batch): 0.842
Epoch: 9/10 Iter: 8/432 Loss: 0.723
Epoch: 9/10 Iter: 8/432 QWK(batch): 0.769
Epoch: 9/10 Iter: 9/432 Loss: 0.706
Epoch: 9/10 Iter: 9/432 QWK(batch): 0.870
Epoch: 9/10 Iter: 10/432 Loss: 0.716
Epoch: 9/10 Iter: 10/432 QWK(batch): 0.822
Epoch: 9/10 Iter: 11/432 Loss: 0.597
Epoch: 9/10 Iter: 11/432 QWK(batch): 0.836
Epoch: 9/10 Iter: 12/432 Loss: 0.833
Epoch: 9/10 Iter: 12/432 QWK(batch): 0.781
Epoch: 9/10 Iter: 13/432 Loss: 0.632
Epoch: 9/10 Iter: 13/

# Test

In [None]:
# Unlabelled test essays: a DataSet built with y=None yields only encoded inputs.
test_df = pd.read_csv(os.path.join(data_path, "test.csv"))
test_data = list(test_df['full_text'])
test_dataset = DataSet(tokenizer, test_data, y=None)

# A single batch holding every test essay, in file order.
# NOTE(review): one batch of the whole test set may not fit in GPU memory if
# the hidden test set is large.
test_loader = DataLoader(
    test_dataset,
    batch_size=len(test_dataset),
    shuffle=False,
    drop_last=False,
)
test_iter = iter(test_loader)

In [None]:
# Number of test essays, then number of batches
for sized in (test_dataset, test_loader):
  print(len(sized))

3
1


In [None]:
# Rebuild the architecture, then restore the best checkpoint saved during training.
best_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
# map_location="cpu" lets the checkpoint load regardless of the device it was saved from.
best_model.load_state_dict(torch.load(os.path.join(save_path, "best_model.pt"), map_location="cpu"))
# Inference needs the model on the GPU alongside the inputs, with dropout disabled.
best_model = best_model.to("cuda").eval()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [None]:
# Predict a class for every test essay with the restored best checkpoint (not the
# last training state), in eval mode (dropout off) and without building autograd graphs.
best_model = best_model.to('cuda').eval()

# Batch the test set here rather than pushing it through as one all-essays
# batch, so GPU memory stays bounded if the test set is large.
class_batches = []
with torch.no_grad():
  for encoded in DataLoader(test_dataset, batch_size=32, shuffle=False):
    encoded = {key: value.to('cuda') for key, value in encoded.items()}
    class_batches.append(best_model(**encoded).logits.argmax(-1).cpu())

# 0-based class indices (score - 1) as a NumPy array, in test-file order
y_test_pred = torch.cat(class_batches).numpy()

In [None]:
# Two independent copies of the test frame: one kept as-is, one that receives the
# predicted scores shifted back to the 1-based score scale.
test_true_df, test_pred_df = test_df.copy(), test_df.copy()
test_pred_df['score'] = 1 + y_test_pred

In [None]:
# Display the raw test frame (essay_id, full_text; no score column)
test_df

Unnamed: 0,essay_id,full_text
0,000d118,Many people have car where they live. The thin...
1,000fe60,I am a scientist at NASA that is discussing th...
2,001ab80,People always wish they had the same technolog...


In [None]:
# Unmodified copy of test_df. test.csv has no score column, so despite its name
# this frame carries no ground-truth labels.
test_true_df

Unnamed: 0,essay_id,full_text
0,000d118,Many people have car where they live. The thin...
1,000fe60,I am a scientist at NASA that is discussing th...
2,001ab80,People always wish they had the same technolog...


In [None]:
# Test essays with the predicted 1-based score column appended
test_pred_df

Unnamed: 0,essay_id,full_text,score
0,000d118,Many people have car where they live. The thin...,3
1,000fe60,I am a scientist at NASA that is discussing th...,3
2,001ab80,People always wish they had the same technolog...,5


In [None]:
# Write only the id and predicted score; the essay text does not belong in the
# submission file.
# NOTE(review): assumes the expected format is exactly essay_id,score (as in the
# competition's sample_submission.csv) — confirm.
test_pred_df[["essay_id", "score"]].to_csv(os.path.join(data_path, "submission.csv"), index=False)