In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to the local .py files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.tensorboard import SummaryWriter

from model import SiameseManhattanBERT, MeanPooler
from dataset import Dataset, Collator
from utils import set_global_seed
from train import train
from metrics import compute_metrics_on_df

2022-05-27 18:39:12.063990: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


In [3]:
# Seed the run through the project helper before any data or model work.
# NOTE(review): which RNGs the helper seeds (random / numpy / torch) is not
# visible here — confirm in utils.set_global_seed.
set_global_seed(42)

In [4]:
# Load the training CSV indexed by its `id` column and clean it in a single
# chained expression: no `inplace=True` and no column writes on an existing
# frame, so `df` is always rebuilt from the file as one new object.
df = (
    pd.read_csv('./data/train.csv', index_col='id')
    # Discard every row that has a missing value in any column.
    .dropna()
    # Lower-case both question columns; column order is unchanged because
    # `assign` overwrites existing columns in place.
    .assign(
        question1=lambda frame: frame['question1'].str.lower(),
        question2=lambda frame: frame['question2'].str.lower(),
    )
)

In [5]:
# Keep only the first 5000 rows (in file order) for this experiment.
df = df.head(5000)

In [6]:
# Display the frame; the notebook rendering elides the middle rows.
df

Unnamed: 0_level_0,qid1,qid2,question1,question2,is_duplicate
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,1,2,what is the step by step guide to invest in sh...,what is the step by step guide to invest in sh...,0
1,3,4,what is the story of kohinoor (koh-i-noor) dia...,what would happen if the indian government sto...,0
2,5,6,how can i increase the speed of my internet co...,how can internet speed be increased by hacking...,0
3,7,8,why am i mentally very lonely? how can i solve...,find the remainder when [math]23^{24}[/math] i...,0
4,9,10,"which one dissolve in water quikly sugar, salt...",which fish would survive in salt water?,0
...,...,...,...,...,...
4995,9850,9851,"what has a better roi, marketing on radio sta...",how does a pirate radio station work?,0
4996,9852,9853,which mobile is good for 50k?,which mobile is good within 20k?,0
4997,9854,9855,is the character jane in the movie predestinat...,what actually happened in predestination?,0
4998,9856,9857,how true are near death experiences?,are near death experiences (ndes) real?,1


In [7]:
# Class balance of the target: number of rows per `is_duplicate` value.
df.loc[:, 'is_duplicate'].value_counts()

0    3089
1    1911
Name: is_duplicate, dtype: int64

In [8]:
# Shuffled 75/25 hold-out split with a fixed seed, stratified on the target
# column so both parts are drawn with the same `is_duplicate` proportions.
df_train, df_test = train_test_split(
    df,
    stratify=df['is_duplicate'],
    shuffle=True,
    random_state=42,
    test_size=0.25,
)

In [9]:
# Report the row count of each split, one per line.
print(
    f'Train size: {df_train.shape[0]}',
    f'Test size:  {df_test.shape[0]}',
    sep='\n',
)

Train size: 3750
Test size:  1250


In [10]:
# Wrap each split in the project's `Dataset`: train first, then test.
train_dataset, test_dataset = (
    Dataset(df=split) for split in (df_train, df_test)
)

In [11]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [12]:
# Checkpoint shared by the encoder and its tokenizer.
model_name = 'distilroberta-base'

# The two loads are independent; both resolve the same checkpoint name.
bert_model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
# Options handed to the collator together with the tokenizer: PyTorch tensors,
# padding enabled, truncation enabled with a 512 maximum length.
tokenizer_kwargs = dict(
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=512,
)

# Project collator built from the tokenizer and its call options; it is passed
# to both DataLoaders as `collate_fn`.
collate_fn = Collator(tokenizer, tokenizer_kwargs)

In [14]:
# Both loaders use batches of 128 items and the shared collator. Only the
# training loader shuffles; the test loader keeps a fixed order.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, collate_fn=collate_fn, shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, collate_fn=collate_fn, shuffle=False,
)

In [15]:
# Assemble the siamese model from the pretrained encoder and the pooler,
# then move it to the selected device.
pooler = MeanPooler()
model = SiameseManhattanBERT(pooler=pooler, bert_model=bert_model)
model = model.to(device)

In [16]:
# Mean-squared-error objective, optimised with Adam at a 1e-5 learning rate
# over all of the model's parameters.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

In [17]:
# TensorBoard writer created with default arguments.
writer = SummaryWriter()
# Number of training epochs passed to `train`.
n_epochs = 10

In [18]:
# Run the project's training loop for `n_epochs`; it prints the per-epoch loss
# and metrics for both loaders and is handed the TensorBoard `writer`.
# NOTE(review): in the recorded run, accuracy falls below the 0.6176 seen in
# epochs 1-3 once recall approaches 1.0 (epoch 5 onwards), while ROC AUC stays
# around 0.74 — the decision threshold used for the metrics is worth checking.
try:
    train(
        n_epochs=n_epochs,
        model=model,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        optimizer=optimizer,
        criterion=criterion,
        writer=writer,
        device=device,
    )
finally:
    # Close the writer even when training fails or is interrupted, so any
    # buffered events are written out and the event file is released.
    # Closing an already-closed SummaryWriter is a no-op.
    writer.close()

Epoch [1 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.24it/s]


Train loss: 0.3811951686938604

Train metrics:
{'accuracy': 0.6178666666666667, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'roc_auc': 0.7105588988335556, 'log_loss': 7.104348402751214}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.82it/s]


Test loss:  0.3827373743057251

Test metrics:
{'accuracy': 0.6176, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'roc_auc': 0.7176301298588679, 'log_loss': 6.266035835900185}

Epoch [2 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.28it/s]


Train loss: 0.38181160589059193

Train metrics:
{'accuracy': 0.6178666666666667, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'roc_auc': 0.6972090447106417, 'log_loss': 5.5332819519519205}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.89it/s]


Test loss:  0.3814109623432159

Test metrics:
{'accuracy': 0.6176, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'roc_auc': 0.708142736358315, 'log_loss': 4.296085237297527}

Epoch [3 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.23it/s]


Train loss: 0.37792660196622213

Train metrics:
{'accuracy': 0.6178666666666667, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'roc_auc': 0.6570796693392478, 'log_loss': 2.7982189663557313}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.87it/s]


Test loss:  0.32436237931251527

Test metrics:
{'accuracy': 0.6176, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'roc_auc': 0.7031429531510829, 'log_loss': 1.3119145934010943}

Epoch [4 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.23it/s]


Train loss: 0.36409761508305866

Train metrics:
{'accuracy': 0.6328, 'precision': 0.5316027088036117, 'recall': 0.3286810886252617, 'f1': 0.40620957309184996, 'roc_auc': 0.6261266810048969, 'log_loss': 0.7692702736047616}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.90it/s]


Test loss:  0.26161422282457353

Test metrics:
{'accuracy': 0.5232, 'precision': 0.4438095238095238, 'recall': 0.9748953974895398, 'f1': 0.6099476439790575, 'roc_auc': 0.7412361523619572, 'log_loss': 0.721356813019514}

Epoch [5 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.23it/s]


Train loss: 0.31347309003273643

Train metrics:
{'accuracy': 0.4162666666666667, 'precision': 0.39488320355951056, 'recall': 0.9909281228192603, 'f1': 0.5647245973354543, 'roc_auc': 0.7218676483565599, 'log_loss': 0.8822590948482355}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.86it/s]


Test loss:  0.36413970589637756

Test metrics:
{'accuracy': 0.3872, 'precision': 0.3842443729903537, 'recall': 1.0, 'f1': 0.5551684088269454, 'roc_auc': 0.7431737377241096, 'log_loss': 0.9686330836296082}

Epoch [6 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.22it/s]


Train loss: 0.2928679202993711

Train metrics:
{'accuracy': 0.38453333333333334, 'precision': 0.3830526597166533, 'recall': 1.0, 'f1': 0.5539234634712022, 'roc_auc': 0.7388633002044116, 'log_loss': 1.0271163654590647}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.88it/s]


Test loss:  0.39792837798595426

Test metrics:
{'accuracy': 0.384, 'precision': 0.38301282051282054, 'recall': 1.0, 'f1': 0.5538818076477405, 'roc_auc': 0.7508983350315432, 'log_loss': 1.0681754980921745}

Epoch [7 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.28it/s]


Train loss: 0.29130504131317136

Train metrics:
{'accuracy': 0.384, 'precision': 0.382847982901416, 'recall': 1.0, 'f1': 0.5537094281298299, 'roc_auc': 0.751537454435058, 'log_loss': 1.0796592754562695}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.87it/s]


Test loss:  0.3961103677749634

Test metrics:
{'accuracy': 0.3872, 'precision': 0.3842443729903537, 'recall': 1.0, 'f1': 0.5551684088269454, 'roc_auc': 0.7511706809460837, 'log_loss': 1.0664963161945342}

Epoch [8 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.28it/s]


Train loss: 0.2819138323267301

Train metrics:
{'accuracy': 0.39466666666666667, 'precision': 0.38686131386861317, 'recall': 0.9986043265875785, 'f1': 0.5576773187840999, 'roc_auc': 0.7437811364829452, 'log_loss': 1.0357280184398094}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.88it/s]


Test loss:  0.3732678651809692

Test metrics:
{'accuracy': 0.4096, 'precision': 0.3930921052631579, 'recall': 1.0, 'f1': 0.564344746162928, 'roc_auc': 0.745170941097405, 'log_loss': 1.0071663561582564}

Epoch [9 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.24it/s]


Train loss: 0.2755727122227351

Train metrics:
{'accuracy': 0.44, 'precision': 0.40522875816993464, 'recall': 0.9951151430565248, 'f1': 0.5759289176090467, 'roc_auc': 0.7390637362544691, 'log_loss': 0.9554663797269265}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.83it/s]


Test loss:  0.34384775459766387

Test metrics:
{'accuracy': 0.4504, 'precision': 0.40889276373147343, 'recall': 0.9811715481171548, 'f1': 0.5772307692307692, 'roc_auc': 0.742257788280183, 'log_loss': 0.9346191865950823}

Epoch [10 / 10]



loop over train batches: 100%|██████████| 30/30 [00:13<00:00,  2.27it/s]


Train loss: 0.2769174804290136

Train metrics:
{'accuracy': 0.4898666666666667, 'precision': 0.42642550582464744, 'recall': 0.9706908583391486, 'f1': 0.5925452609158679, 'roc_auc': 0.7337762001240264, 'log_loss': 0.8910946500132482}



loop over test batches: 100%|██████████| 10/10 [00:01<00:00,  7.87it/s]

Test loss:  0.3293678373098373

Test metrics:
{'accuracy': 0.4864, 'precision': 0.42435424354243545, 'recall': 0.9623430962343096, 'f1': 0.58898847631242, 'roc_auc': 0.7406318425217334, 'log_loss': 0.9024921199902892}






---

In [19]:
# Switch the model to evaluation mode before computing the final metrics;
# the trailing semicolon suppresses the cell's output.
model.eval();

In [20]:
# Recompute the metrics over the training dataset's frame, reusing the
# tokenizer and tokenizer options held by the train loader's collator.
train_metrics = compute_metrics_on_df(
    df=train_dataloader.dataset.df,
    model=model,
    tokenizer_kwargs=train_dataloader.collate_fn.tokenizer_kwargs,
    tokenizer=train_dataloader.collate_fn.tokenizer,
)

vectorize question1: 100%|██████████| 3750/3750 [00:16<00:00, 228.16it/s]
vectorize question2: 100%|██████████| 3750/3750 [00:16<00:00, 230.25it/s]


In [21]:
# Display the metrics computed on the training split.
train_metrics

{'accuracy': 0.49306666666666665,
 'precision': 0.4278668310727497,
 'recall': 0.9685973482205164,
 'f1': 0.5935428693607013,
 'roc_auc': 0.7405934352751185,
 'log_loss': 0.8905066748956839}

In [22]:
# Recompute the metrics over the test dataset's frame, reusing the
# tokenizer and tokenizer options held by the test loader's collator.
test_metrics = compute_metrics_on_df(
    df=test_dataloader.dataset.df,
    model=model,
    tokenizer_kwargs=test_dataloader.collate_fn.tokenizer_kwargs,
    tokenizer=test_dataloader.collate_fn.tokenizer,
)

vectorize question1: 100%|██████████| 1250/1250 [00:05<00:00, 231.67it/s]
vectorize question2: 100%|██████████| 1250/1250 [00:05<00:00, 231.93it/s]


In [23]:
# Display the metrics computed on the held-out test split.
test_metrics

{'accuracy': 0.4896,
 'precision': 0.42578849721706863,
 'recall': 0.9602510460251046,
 'f1': 0.589974293059126,
 'roc_auc': 0.7400383723198994,
 'log_loss': 0.8910870019569993}