# Create Dataset

In [1]:
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import pandas as pd
import random
import torch

**Import Data:**

In [2]:
import csv

In [3]:
positive_samples = pd.read_csv("./data/msr_paraphrase_train.txt", sep="\t", encoding="utf-8", quoting=csv.QUOTE_NONE)
positive_samples.head()

Unnamed: 0,Quality,#1 ID,#2 ID,#1 String,#2 String
0,1,702876,702977,"Amrozi accused his brother, whom he called ""th...","Referring to him as only ""the witness"", Amrozi..."
1,0,2108705,2108831,Yucaipa owned Dominick's before selling the ch...,Yucaipa bought Dominick's in 1995 for $693 mil...
2,1,1330381,1330521,They had published an advertisement on the Int...,"On June 10, the ship's owners had published an..."
3,0,3344667,3344648,"Around 0335 GMT, Tab shares were up 19 cents, ...","Tab shares jumped 20 cents, or 4.6%, to set a ..."
4,1,1236820,1236712,"The stock rose $2.11, or about 11 percent, to ...",PG&E Corp. shares jumped $1.63 or 8 percent to...


In [4]:
positive_samples_test = pd.read_csv("./data/msr_paraphrase_test.txt", sep="\t", encoding="utf-8", quoting=csv.QUOTE_NONE)
positive_samples_test.head()

Unnamed: 0,Quality,#1 ID,#2 ID,#1 String,#2 String
0,1,1089874,1089925,"PCCW's chief operating officer, Mike Butcher, ...",Current Chief Operating Officer Mike Butcher a...
1,1,3019446,3019327,The world's two largest automakers said their ...,Domestic sales at both GM and No. 2 Ford Motor...
2,1,1945605,1945824,According to the federal Centers for Disease C...,The Centers for Disease Control and Prevention...
3,0,1430402,1430329,A tropical storm rapidly developed in the Gulf...,A tropical storm rapidly developed in the Gulf...
4,0,3354381,3354396,The company didn't detail the costs of the rep...,But company officials expect the costs of the ...


In [5]:
all_sentences = pd.read_csv("./data/msr_paraphrase_data.txt", sep="\t", encoding="utf-8", quoting=csv.QUOTE_NONE)
all_sentences.head()

Unnamed: 0,Sentence ID,String,Author,URL,Agency,Date,Web Date
0,702876,"Amrozi accused his brother, whom he called ""th...",Darren Goodsir,www.theage.com.au,*,June 5 2003,2003/06/04
1,702977,"Referring to him as only ""the witness"", Amrozi...",Darren Goodsir,www.smh.com.au,Sydney Morning Herald,June 5 2003,2003/06/04
2,2108705,Yucaipa owned Dominick's before selling the ch...,MICHAEL GIBBS,www.nwherald.com,*,*,2003/08/23
3,2108831,Yucaipa bought Dominick's in 1995 for $693 mil...,ALEX VEIGA,www.miami.com,*,*,2003/08/23
4,1330381,They had published an advertisement on the Int...,Philip Pangalos,www.alertnet.org,*,*,2003/06/25


**Add positive label:**

In [7]:
positive_pairs_test = positive_samples_test[['#1 String', '#2 String']].copy()
positive_pairs_test['label'] = 1

In [8]:
positive_pairs = positive_samples[['#1 String', '#2 String']].copy()
positive_pairs['label'] = 1

In [9]:
positive_pairs = pd.concat([positive_samples_test, positive_pairs])
positive_pairs.head()

Unnamed: 0,#1 String,#2 String,label
0,"PCCW's chief operating officer, Mike Butcher, ...",Current Chief Operating Officer Mike Butcher a...,1
1,The world's two largest automakers said their ...,Domestic sales at both GM and No. 2 Ford Motor...,1
2,According to the federal Centers for Disease C...,The Centers for Disease Control and Prevention...,1
3,A tropical storm rapidly developed in the Gulf...,A tropical storm rapidly developed in the Gulf...,1
4,The company didn't detail the costs of the rep...,But company officials expect the costs of the ...,1


**Sample negative pairs:**

In [10]:
random.seed(42)
sentences = all_sentences['String'].tolist()
negative_pairs = []

for _ in range(len(positive_pairs)):
    sent1, sent2 = random.sample(sentences, 2)
    negative_pairs.append((sent1, sent2, 0))

negative_pairs = pd.DataFrame(negative_pairs, columns=['#1 String', '#2 String', 'label'])

**Merge positive and negative pairs:**

In [11]:
dataset = pd.concat([positive_pairs, negative_pairs], ignore_index=True)
dataset = dataset.dropna().reset_index(drop=True)
dataset = dataset.sample(frac=1).reset_index(drop=True)

**Sample validation set:**

In [12]:
train_data, val_data = train_test_split(dataset, test_size=0.1, random_state=42)

**Save data:**

In [14]:
train_data.to_csv("./data/train_set.csv")
val_data.to_csv("./data/val_set.csv")