In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import pandas as pd
from sklearn.model_selection import train_test_split
import re
import unicodedata
import string

# Set random seed for PyTorch CPU operations
torch.manual_seed(42)


# Preprocessing functions
def unicode_to_ascii(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

def clean_text(text):
    text = unicode_to_ascii(text.lower().strip())
    text = re.sub(r"i'm", "i am", text)
    text = re.sub(r"\r", "", text)
    text = re.sub(r"he's", "he is", text)
    text = re.sub(r"she's", "she is", text)
    text = re.sub(r"it's", "it is", text)
    text = re.sub(r"that's", "that is", text)
    text = re.sub(r"what's", "that is", text)
    text = re.sub(r"where's", "where is", text)
    text = re.sub(r"how's", "how is", text)
    text = re.sub(r"\'ll", " will", text)
    text = re.sub(r"\'ve", " have", text)
    text = re.sub(r"\'re", " are", text)
    text = re.sub(r"\'d", " would", text)
    text = re.sub(r"\'re", " are", text)
    text = re.sub(r"won't", "will not", text)
    text = re.sub(r"can't", "cannot", text)
    text = re.sub(r"n't", " not", text)
    text = re.sub(r"n'", "ng", text)
    text = re.sub(r"'bout", "about", text)
    text = re.sub(r"'til", "until", text)
    text = re.sub(r"[-()\"#/@;:<>{}`+=~|.!?,]", "", text)
    text = text.translate(str.maketrans('', '', string.punctuation)) 
    text = re.sub("(\\W)"," ",text) 
    text = re.sub('\S*\d\S*\s*','', text)
    text =  "<sos> " +  text + " <eos>"
    
    return text
    
    

# Custom Dataset class
class DialogDataset(Dataset):
    def __init__(self, questions, answers, src_vocab, tgt_vocab):
        self.questions = questions
        self.answers = answers
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.tokenizer = get_tokenizer('basic_english')

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        src = [self.src_vocab[token] for token in self.tokenizer(self.questions[idx])]
        tgt = [self.tgt_vocab[token] for token in self.tokenizer(self.answers[idx])]
        return torch.tensor(src), torch.tensor(tgt)

    @staticmethod
    def collate_fn(batch):
        src_batch, tgt_batch = [], []
        for src_item, tgt_item in batch:
            src_batch.append(torch.cat([torch.tensor([src_vocab["<sos>"]]), src_item, torch.tensor([src_vocab["<eos>"]])], dim=0))
            tgt_batch.append(torch.cat([torch.tensor([tgt_vocab["<sos>"]]), tgt_item, torch.tensor([tgt_vocab["<eos>"]])], dim=0))
        src_batch = pad_sequence(src_batch, padding_value=src_vocab["<pad>"]).transpose(0, 1)
        tgt_batch = pad_sequence(tgt_batch, padding_value=tgt_vocab["<pad>"]).transpose(0, 1)
        return src_batch, tgt_batch

# Load and preprocess data
data = pd.read_csv("./dialogs.txt", sep='\t', header=None, names=['question', 'answer'])
data["question"] = data.question.apply(clean_text)
data["answer"] = data.answer.apply(clean_text)

# Split data
train_data, val_data = train_test_split(data, test_size=0.2)

# Build vocabularies
tokenizer = get_tokenizer('basic_english')

def build_vocab(data):
    vocab = build_vocab_from_iterator(map(tokenizer, data), specials=["<pad>", "<sos>", "<eos>"])
    vocab.set_default_index(vocab["<pad>"])
    return vocab


src_vocab = build_vocab(train_data['question'])
tgt_vocab = build_vocab(train_data['answer'])

# Create datasets
train_dataset = DialogDataset(train_data['question'].tolist(), train_data['answer'].tolist(), src_vocab, tgt_vocab)
val_dataset = DialogDataset(val_data['question'].tolist(), val_data['answer'].tolist(), src_vocab, tgt_vocab)

# DataLoader
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=DialogDataset.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=DialogDataset.collate_fn)


# checking dimension batch from DataLoader
for src, tgt in train_loader:
    print("Batch shapes:", src.shape, tgt.shape)
    break

#save vocabularies
import pickle

# Assuming src_vocab and tgt_vocab are your source and target vocabularies
with open('src_vocab.pkl', 'wb') as f:
    pickle.dump(src_vocab, f)

with open('tgt_vocab.pkl', 'wb') as f:
    pickle.dump(tgt_vocab, f)
    
print(len(src_vocab))
print(len(tgt_vocab))


#save train and val data
train_data.to_pickle('train_data.pkl')
val_data.to_pickle('val_data.pkl')

Batch shapes: torch.Size([64, 19]) torch.Size([64, 19])
2090
2137
Max index in source batch: 2084
Max index in target batch: 1990
Max index in source batch: 2007
Max index in target batch: 2061
Max index in source batch: 2067
Max index in target batch: 2131
Max index in source batch: 2010
Max index in target batch: 2081
Max index in source batch: 2031
Max index in target batch: 2119
Max index in source batch: 2068
Max index in target batch: 1887
