In [7]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizer, BertModel
from pytorch_pretrained_bert import BertAdam

from sklearn.metrics import mean_absolute_error, mean_squared_error, median_absolute_error
from tqdm import tqdm

In [8]:
data = pd.read_csv('CLEAN.csv',index_col=0)
# print(data['misarticulation_index'].values[:15])
data = data.loc[data['first_lang_english']==1]
data.drop(columns=['first_lang_english'],inplace=True)
data = data[['response_text','misarticulation_index']]

data['misarticulation_index'] = (data['misarticulation_index']/0.33333).astype(int).astype(float)/18
# print(data['misarticulation_index'].values[:15])

In [31]:
split_ = np.random.RandomState(seed=0).permutation(data.shape[0])
num_train = int(data.shape[0]*0.7)

data_train, data_test =data[:num_train], data[num_train:]

In [32]:
class TMPDataset(Dataset):
    def __init__(self, df):
        self.df = df
        self.X = df['response_text'].values
        self.y = df['misarticulation_index'].values

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

batch_size = 8
    
train_dataset = TMPDataset(data_train)
test_dataset = TMPDataset(data_test)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

In [33]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [41]:
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path = "bert-base-uncased") # bert-base-uncased
model = BertModel.from_pretrained(pretrained_model_name_or_path = "bert-base-uncased") # bert-base-uncased
# tokenizer.save_pretrained('./')
# model.save_pretrained('./')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [35]:
regressor = nn.Sequential(nn.Linear(768,64),nn.ELU(),nn.Linear(64,1))

In [37]:
optimizer = BertAdam(list(model.parameters())+list(regressor.parameters()), lr=5e-5,
                     weight_decay=1e-2, warmup=0.2, t_total=10*(num_train//batch_size))

In [38]:
loss_fn = torch.nn.L1Loss()

In [39]:
for i in range(10):    
    for X,y in tqdm(train_dataloader):
#         y = y.cuda()
        model.train()
        regressor.train()
        optimizer.zero_grad()
        
        encoded_input = tokenizer(list(X), padding=True, truncation=True, return_tensors="pt",max_length=512)
#         encoded_input = {k:v.cuda() for k,v in encoded_input.items()}
        output = model(**encoded_input)
        # print(output.pooler_output)
        logits = regressor(output.pooler_output)
        loss = loss_fn(logits.flatten(), y)
        # print(loss)
        loss.backward()
        optimizer.step()

        # break
    print(f"---------------------epoch {i}----------------------")
    with torch.no_grad():
        model.eval()
        regressor.eval()
        y_pred_list = []
        y_true_list = []
        for X,y in test_dataloader:
            encoded_input = tokenizer(list(X), padding=True, truncation=True, return_tensors="pt",max_length=512)
            output = model(**encoded_input)
            y_pred = regressor(output.pooler_output).flatten().cpu().numpy()
            y_pred_list.append(y_pred)
            y_true_list.append(y)
        y_pred = np.hstack(y_pred_list)
        y_true = np.hstack(y_true_list)
        ms = [mean_absolute_error, mean_squared_error, median_absolute_error]
        for m in ms:
            print(m, m(y_true*6, y_pred*6))

100%|███████████████████████████████████████████| 44/44 [12:44<00:00, 17.38s/it]


---------------------epoch 0----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.9457354233191725
<function mean_squared_error at 0x7fbba2d95d30> 1.5703517851530884
<function median_absolute_error at 0x7fbba2d95f70> 0.9560967683792114


100%|███████████████████████████████████████████| 44/44 [14:41<00:00, 20.04s/it]


---------------------epoch 1----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.8810899388371853
<function mean_squared_error at 0x7fbba2d95d30> 1.2417257453197745
<function median_absolute_error at 0x7fbba2d95f70> 0.7887363632520039


100%|███████████████████████████████████████████| 44/44 [32:23<00:00, 44.16s/it]


---------------------epoch 2----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.8684325414268594
<function mean_squared_error at 0x7fbba2d95d30> 1.2159604016073853
<function median_absolute_error at 0x7fbba2d95f70> 0.7304771343866985


100%|███████████████████████████████████████████| 44/44 [28:10<00:00, 38.41s/it]


---------------------epoch 3----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.8721571081040199
<function mean_squared_error at 0x7fbba2d95d30> 1.30842385881367
<function median_absolute_error at 0x7fbba2d95f70> 0.7337990601857505


100%|███████████████████████████████████████████| 44/44 [12:59<00:00, 17.72s/it]


---------------------epoch 4----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.9099314177506849
<function mean_squared_error at 0x7fbba2d95d30> 1.336962712286562
<function median_absolute_error at 0x7fbba2d95f70> 0.8078422149022422


100%|███████████████████████████████████████████| 44/44 [13:30<00:00, 18.42s/it]


---------------------epoch 5----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.959207788894051
<function mean_squared_error at 0x7fbba2d95d30> 1.4863309249105594
<function median_absolute_error at 0x7fbba2d95f70> 0.7732179860273997


100%|███████████████████████████████████████████| 44/44 [13:01<00:00, 17.77s/it]


---------------------epoch 6----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.9878640403611617
<function mean_squared_error at 0x7fbba2d95d30> 1.5303982849254942
<function median_absolute_error at 0x7fbba2d95f70> 0.8752762079238892


100%|███████████████████████████████████████████| 44/44 [13:21<00:00, 18.22s/it]


---------------------epoch 7----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.9859430426568314
<function mean_squared_error at 0x7fbba2d95d30> 1.511545148947009
<function median_absolute_error at 0x7fbba2d95f70> 0.839295546213786


100%|███████████████████████████████████████████| 44/44 [13:12<00:00, 18.01s/it]


---------------------epoch 8----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.9083902200585917
<function mean_squared_error at 0x7fbba2d95d30> 1.3428554445483674
<function median_absolute_error at 0x7fbba2d95f70> 0.744165857632955


100%|███████████████████████████████████████████| 44/44 [13:04<00:00, 17.84s/it]


---------------------epoch 9----------------------
<function mean_absolute_error at 0x7fbba2d953a0> 0.9001904543031727
<function mean_squared_error at 0x7fbba2d95d30> 1.325231106782395
<function median_absolute_error at 0x7fbba2d95f70> 0.7205376625061035
